mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 09:34:37 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts: # README.md # examples/main/README.md # examples/run/run.cpp
This commit is contained in:
commit
39fad991cc
31 changed files with 1319 additions and 620 deletions
|
@ -675,7 +675,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
));
|
));
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
{"--no-context-shift"},
|
{"--no-context-shift"},
|
||||||
string_format("disables context shift on inifinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
|
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
|
||||||
[](common_params & params) {
|
[](common_params & params) {
|
||||||
params.ctx_shift = false;
|
params.ctx_shift = false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -249,16 +249,30 @@ class chat_template {
|
||||||
inputs.add_generation_prompt = false;
|
inputs.add_generation_prompt = false;
|
||||||
full = apply(inputs);
|
full = apply(inputs);
|
||||||
}
|
}
|
||||||
|
auto eos_pos_last = full.rfind(eos_token_);
|
||||||
if (full.find(prefix) != 0) {
|
if (eos_pos_last == prefix.size() - eos_token_.size() ||
|
||||||
if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) {
|
(full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
|
||||||
prefix = prefix.substr(0, prefix.size() - eos_token_.size());
|
full = full.substr(0, eos_pos_last);
|
||||||
}
|
}
|
||||||
|
size_t common_prefix_length = 0;
|
||||||
|
for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
|
||||||
|
if (prefix[i] != full[i]) {
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
if (full.find(prefix) != 0) {
|
if (prefix[i] == '<') {
|
||||||
|
// DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
|
||||||
|
// but it removes thinking tags for past messages.
|
||||||
|
// The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
common_prefix_length = i + 1;
|
||||||
|
}
|
||||||
|
auto example = full.substr(common_prefix_length);
|
||||||
|
if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
|
||||||
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
|
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
|
||||||
|
} else {
|
||||||
|
tool_call_example_ = example;
|
||||||
}
|
}
|
||||||
tool_call_example_ = full.substr(prefix.size());
|
|
||||||
}
|
}
|
||||||
} catch (const std::exception & e) {
|
} catch (const std::exception & e) {
|
||||||
fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
|
fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
|
||||||
|
@ -363,7 +377,7 @@ class chat_template {
|
||||||
if (polyfill_tools) {
|
if (polyfill_tools) {
|
||||||
adjusted_messages = add_system(inputs.messages,
|
adjusted_messages = add_system(inputs.messages,
|
||||||
"You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
|
"You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
|
||||||
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_));
|
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
|
||||||
} else {
|
} else {
|
||||||
adjusted_messages = inputs.messages;
|
adjusted_messages = inputs.messages;
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
#include "ggml.h" // for ggml_log_level
|
#include "ggml.h" // for ggml_log_level
|
||||||
|
|
||||||
|
#define LOG_CLR_TO_EOL "\033[K\r"
|
||||||
#define LOG_COL_DEFAULT "\033[0m"
|
#define LOG_COL_DEFAULT "\033[0m"
|
||||||
#define LOG_COL_BOLD "\033[1m"
|
#define LOG_COL_BOLD "\033[1m"
|
||||||
#define LOG_COL_RED "\033[31m"
|
#define LOG_COL_RED "\033[31m"
|
||||||
|
|
|
@ -1385,6 +1385,13 @@ static std::string strip(const std::string & s) {
|
||||||
return s.substr(start, end - start + 1);
|
return s.substr(start, end - start + 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::string capitalize(const std::string & s) {
|
||||||
|
if (s.empty()) return s;
|
||||||
|
auto result = s;
|
||||||
|
result[0] = std::toupper(result[0]);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
static std::string html_escape(const std::string & s) {
|
static std::string html_escape(const std::string & s) {
|
||||||
std::string result;
|
std::string result;
|
||||||
result.reserve(s.size());
|
result.reserve(s.size());
|
||||||
|
@ -1462,6 +1469,9 @@ public:
|
||||||
if (method->get_name() == "strip") {
|
if (method->get_name() == "strip") {
|
||||||
vargs.expectArgs("strip method", {0, 0}, {0, 0});
|
vargs.expectArgs("strip method", {0, 0}, {0, 0});
|
||||||
return Value(strip(str));
|
return Value(strip(str));
|
||||||
|
} else if (method->get_name() == "capitalize") {
|
||||||
|
vargs.expectArgs("capitalize method", {0, 0}, {0, 0});
|
||||||
|
return Value(capitalize(str));
|
||||||
} else if (method->get_name() == "endswith") {
|
} else if (method->get_name() == "endswith") {
|
||||||
vargs.expectArgs("endswith method", {1, 1}, {0, 0});
|
vargs.expectArgs("endswith method", {1, 1}, {0, 0});
|
||||||
auto suffix = vargs.args[0].get<std::string>();
|
auto suffix = vargs.args[0].get<std::string>();
|
||||||
|
@ -1792,7 +1802,7 @@ private:
|
||||||
auto left = parseStringConcat();
|
auto left = parseStringConcat();
|
||||||
if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression");
|
if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression");
|
||||||
|
|
||||||
static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not[\r\n\s]+in\b)");
|
static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not\s+in\b)");
|
||||||
static std::regex not_tok(R"(not\b)");
|
static std::regex not_tok(R"(not\b)");
|
||||||
std::string op_str;
|
std::string op_str;
|
||||||
while (!(op_str = consumeToken(compare_tok)).empty()) {
|
while (!(op_str = consumeToken(compare_tok)).empty()) {
|
||||||
|
@ -2171,7 +2181,7 @@ private:
|
||||||
using TemplateTokenIterator = TemplateTokenVector::const_iterator;
|
using TemplateTokenIterator = TemplateTokenVector::const_iterator;
|
||||||
|
|
||||||
std::vector<std::string> parseVarNames() {
|
std::vector<std::string> parseVarNames() {
|
||||||
static std::regex varnames_regex(R"(((?:\w+)(?:[\r\n\s]*,[\r\n\s]*(?:\w+))*)[\r\n\s]*)");
|
static std::regex varnames_regex(R"(((?:\w+)(?:\s*,\s*(?:\w+))*)\s*)");
|
||||||
|
|
||||||
std::vector<std::string> group;
|
std::vector<std::string> group;
|
||||||
if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names");
|
if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names");
|
||||||
|
@ -2194,13 +2204,13 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
TemplateTokenVector tokenize() {
|
TemplateTokenVector tokenize() {
|
||||||
static std::regex comment_tok(R"(\{#([-~]?)([\s\S\r\n]*?)([-~]?)#\})");
|
static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})");
|
||||||
static std::regex expr_open_regex(R"(\{\{([-~])?)");
|
static std::regex expr_open_regex(R"(\{\{([-~])?)");
|
||||||
static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)");
|
static std::regex block_open_regex(R"(^\{%([-~])?\s*)");
|
||||||
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
|
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
|
||||||
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
|
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
|
||||||
static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})");
|
static std::regex expr_close_regex(R"(\s*([-~])?\}\})");
|
||||||
static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})");
|
static std::regex block_close_regex(R"(\s*([-~])?%\})");
|
||||||
|
|
||||||
TemplateTokenVector tokens;
|
TemplateTokenVector tokens;
|
||||||
std::vector<std::string> group;
|
std::vector<std::string> group;
|
||||||
|
@ -2284,7 +2294,7 @@ private:
|
||||||
auto post_space = parseBlockClose();
|
auto post_space = parseBlockClose();
|
||||||
tokens.push_back(std::make_unique<EndGenerationTemplateToken>(location, pre_space, post_space));
|
tokens.push_back(std::make_unique<EndGenerationTemplateToken>(location, pre_space, post_space));
|
||||||
} else if (keyword == "set") {
|
} else if (keyword == "set") {
|
||||||
static std::regex namespaced_var_regex(R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))");
|
static std::regex namespaced_var_regex(R"((\w+)\s*\.\s*(\w+))");
|
||||||
|
|
||||||
std::string ns;
|
std::string ns;
|
||||||
std::vector<std::string> var_names;
|
std::vector<std::string> var_names;
|
||||||
|
@ -2336,6 +2346,11 @@ private:
|
||||||
throw std::runtime_error("Unexpected block: " + keyword);
|
throw std::runtime_error("Unexpected block: " + keyword);
|
||||||
}
|
}
|
||||||
} else if (std::regex_search(it, end, match, non_text_open_regex)) {
|
} else if (std::regex_search(it, end, match, non_text_open_regex)) {
|
||||||
|
if (!match.position()) {
|
||||||
|
if (match[0] != "{#")
|
||||||
|
throw std::runtime_error("Internal error: Expected a comment");
|
||||||
|
throw std::runtime_error("Missing end of comment tag");
|
||||||
|
}
|
||||||
auto text_end = it + match.position();
|
auto text_end = it + match.position();
|
||||||
text = std::string(it, text_end);
|
text = std::string(it, text_end);
|
||||||
it = text_end;
|
it = text_end;
|
||||||
|
@ -2400,7 +2415,7 @@ private:
|
||||||
|
|
||||||
auto text = text_token->text;
|
auto text = text_token->text;
|
||||||
if (post_space == SpaceHandling::Strip) {
|
if (post_space == SpaceHandling::Strip) {
|
||||||
static std::regex trailing_space_regex(R"((\s|\r|\n)+$)");
|
static std::regex trailing_space_regex(R"(\s+$)");
|
||||||
text = std::regex_replace(text, trailing_space_regex, "");
|
text = std::regex_replace(text, trailing_space_regex, "");
|
||||||
} else if (options.lstrip_blocks && it != end) {
|
} else if (options.lstrip_blocks && it != end) {
|
||||||
auto i = text.size();
|
auto i = text.size();
|
||||||
|
@ -2410,7 +2425,7 @@ private:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (pre_space == SpaceHandling::Strip) {
|
if (pre_space == SpaceHandling::Strip) {
|
||||||
static std::regex leading_space_regex(R"(^(\s|\r|\n)+)");
|
static std::regex leading_space_regex(R"(^\s+)");
|
||||||
text = std::regex_replace(text, leading_space_regex, "");
|
text = std::regex_replace(text, leading_space_regex, "");
|
||||||
} else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast<ExpressionTemplateToken*>((*(it - 2)).get())) {
|
} else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast<ExpressionTemplateToken*>((*(it - 2)).get())) {
|
||||||
if (text.length() > 0 && text[0] == '\n') {
|
if (text.length() > 0 && text[0] == '\n') {
|
||||||
|
|
|
@ -9,7 +9,7 @@ struct common_speculative_params {
|
||||||
int n_draft = 16; // max drafted tokens
|
int n_draft = 16; // max drafted tokens
|
||||||
int n_reuse = 256;
|
int n_reuse = 256;
|
||||||
|
|
||||||
float p_min = 0.9f; // min probabiliy required to accept a token in the draft
|
float p_min = 0.9f; // min probability required to accept a token in the draft
|
||||||
};
|
};
|
||||||
|
|
||||||
struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
|
struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
|
||||||
|
|
205
docs/backend/OPENCL.md
Normal file
205
docs/backend/OPENCL.md
Normal file
|
@ -0,0 +1,205 @@
|
||||||
|
# llama.cpp for OpenCL
|
||||||
|
|
||||||
|
- [Background](#background)
|
||||||
|
- [OS](#os)
|
||||||
|
- [Hardware](#hardware)
|
||||||
|
- [DataType Supports](#datatype-supports)
|
||||||
|
- [Model Preparation](#model-preparation)
|
||||||
|
- [CMake Options](#cmake-options)
|
||||||
|
- [Android](#android)
|
||||||
|
- [Windows 11 Arm64](#windows-11-arm64)
|
||||||
|
- [Known Issue](#known-issues)
|
||||||
|
- [TODO](#todo)
|
||||||
|
|
||||||
|
## Background
|
||||||
|
|
||||||
|
OpenCL (Open Computing Language) is an open, royalty-free standard for cross-platform, parallel programming of diverse accelerators found in supercomputers, cloud servers, personal computers, mobile devices and embedded platforms. OpenCL specifies a programming language (based on C99) for programming these devices and application programming interfaces (APIs) to control the platform and execute programs on the compute devices. Similar to CUDA, OpenCL has been widely used to program GPUs and is supported by most GPU vendors.
|
||||||
|
|
||||||
|
### Llama.cpp + OpenCL
|
||||||
|
|
||||||
|
The llama.cpp OpenCL backend is designed to enable llama.cpp on **Qualcomm Adreno GPU** firstly via OpenCL. Thanks to the portabilty of OpenCL, the OpenCL backend can also run on certain Intel GPUs although the performance is not optimal.
|
||||||
|
|
||||||
|
## OS
|
||||||
|
|
||||||
|
| OS | Status | Verified |
|
||||||
|
|---------|---------|------------------------------------------------|
|
||||||
|
| Android | Support | Snapdragon 8 Gen 3, Snapdragon 8 Elite |
|
||||||
|
| Windows | Support | Windows 11 Arm64 with Snapdragon X Elite |
|
||||||
|
| Linux | Support | Ubuntu 22.04 WSL2 with Intel 12700H |
|
||||||
|
|
||||||
|
## Hardware
|
||||||
|
|
||||||
|
### Adreno GPU
|
||||||
|
|
||||||
|
**Verified devices**
|
||||||
|
|
||||||
|
| Adreno GPU | Status |
|
||||||
|
|:------------------------------------:|:-------:|
|
||||||
|
| Adreno 750 (Snapdragon 8 Gen 3) | Support |
|
||||||
|
| Adreno 830 (Snapdragon 8 Elite) | Support |
|
||||||
|
| Adreno X85 (Snapdragon X Elite) | Support |
|
||||||
|
|
||||||
|
## DataType Supports
|
||||||
|
|
||||||
|
| DataType | Status |
|
||||||
|
|:----------------------:|:--------------------------:|
|
||||||
|
| Q4_0 | Support |
|
||||||
|
| Q6_K | Support, but not optimized |
|
||||||
|
|
||||||
|
## Model Preparation
|
||||||
|
|
||||||
|
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model prepration.
|
||||||
|
|
||||||
|
Currently we support `Q4_0` quantization and have optimize for it. To achieve best performance on Adreno GPU, add `--pure` to `llama-quantize`. For example,
|
||||||
|
|
||||||
|
```sh
|
||||||
|
./llama-quantize --pure ggml-model-qwen2.5-3b-f16.gguf ggml-model-qwen-3b-Q4_0.gguf Q4_0
|
||||||
|
```
|
||||||
|
|
||||||
|
Since `Q6_K` is also supported, `Q4_0` quantization without `--pure` will also work. However, the performance will be worse compared to pure `Q4_0` quantization.
|
||||||
|
|
||||||
|
## CMake Options
|
||||||
|
|
||||||
|
The OpenCL backend has the following CMake options that control the behavior of the backend.
|
||||||
|
|
||||||
|
| CMake options | Default value | Description |
|
||||||
|
|:---------------------------------:|:--------------:|:------------------------------------------|
|
||||||
|
| `GGML_OPENCL_EMBED_KERNELS` | `ON` | Embed OpenCL kernels into the executable. |
|
||||||
|
| `GGML_OPENCL_USE_ADRENO_KERNELS` | `ON` | Use kernels optimized for Adreno. |
|
||||||
|
|
||||||
|
## Android
|
||||||
|
|
||||||
|
Ubuntu 22.04 is used for targeting Android. Make sure the following tools are accessible from command line,
|
||||||
|
|
||||||
|
* Git
|
||||||
|
* CMake 3.29
|
||||||
|
* Ninja
|
||||||
|
* Python3
|
||||||
|
|
||||||
|
### I. Setup Environment
|
||||||
|
|
||||||
|
1. **Install NDK**
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cd ~
|
||||||
|
wget https://dl.google.com/android/repository/commandlinetools-linux-8512546_latest.zip && \
|
||||||
|
unzip commandlinetools-linux-8512546_latest.zip && \
|
||||||
|
mkdir -p ~/android-sdk/cmdline-tools && \
|
||||||
|
mv cmdline-tools latest && \
|
||||||
|
mv latest ~/android-sdk/cmdline-tools/ && \
|
||||||
|
rm -rf commandlinetools-linux-8512546_latest.zip
|
||||||
|
|
||||||
|
yes | ~/android-sdk/cmdline-tools/latest/bin/sdkmanager "ndk;26.3.11579264"
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Install OpenCL Headers and Library**
|
||||||
|
|
||||||
|
```sh
|
||||||
|
mkdir -p ~/dev/llm
|
||||||
|
cd ~/dev/llm
|
||||||
|
|
||||||
|
git clone https://github.com/KhronosGroup/OpenCL-Headers && \
|
||||||
|
cd OpenCL-Headers && \
|
||||||
|
cp -r CL ~/android-sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
|
||||||
|
|
||||||
|
cd ~/dev/llm
|
||||||
|
|
||||||
|
git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader && \
|
||||||
|
cd OpenCL-ICD-Loader && \
|
||||||
|
mkdir build_ndk26 && cd build_ndk26 && \
|
||||||
|
cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release \
|
||||||
|
-DCMAKE_TOOLCHAIN_FILE=$HOME/android-sdk/ndk/26.3.11579264/build/cmake/android.toolchain.cmake \
|
||||||
|
-DOPENCL_ICD_LOADER_HEADERS_DIR=$HOME/android-sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include \
|
||||||
|
-DANDROID_ABI=arm64-v8a \
|
||||||
|
-DANDROID_PLATFORM=24 \
|
||||||
|
-DANDROID_STL=c++_shared && \
|
||||||
|
ninja && \
|
||||||
|
cp libOpenCL.so ~/android-sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android
|
||||||
|
```
|
||||||
|
|
||||||
|
### II. Build llama.cpp
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cd ~/dev/llm
|
||||||
|
|
||||||
|
git clone https://github.com/ggerganov/llama.cpp && \
|
||||||
|
cd llama.cpp && \
|
||||||
|
mkdir build-android && cd build-android
|
||||||
|
|
||||||
|
cmake .. -G Ninja \
|
||||||
|
-DCMAKE_TOOLCHAIN_FILE=$HOME/android-sdk/ndk/26.3.11579264/build/cmake/android.toolchain.cmake \
|
||||||
|
-DANDROID_ABI=arm64-v8a \
|
||||||
|
-DANDROID_PLATFORM=android-28 \
|
||||||
|
-DBUILD_SHARED_LIBS=OFF \
|
||||||
|
-DGGML_OPENCL=ON
|
||||||
|
|
||||||
|
ninja
|
||||||
|
```
|
||||||
|
|
||||||
|
## Windows 11 Arm64
|
||||||
|
|
||||||
|
A Snapdragon X Elite device with Windows 11 Arm64 is used. Make sure the following tools are accessible from command line,
|
||||||
|
|
||||||
|
* Git
|
||||||
|
* CMake 3.29
|
||||||
|
* Clang 19
|
||||||
|
* Ninja
|
||||||
|
* Visual Studio 2022
|
||||||
|
|
||||||
|
Powershell is used for the following instructions.
|
||||||
|
|
||||||
|
### I. Setup Environment
|
||||||
|
|
||||||
|
1. **Install OpenCL Headers and Library**
|
||||||
|
|
||||||
|
```powershell
|
||||||
|
mkdir -p ~/dev/llm
|
||||||
|
|
||||||
|
cd ~/dev/llm
|
||||||
|
git clone https://github.com/KhronosGroup/OpenCL-Headers && cd OpenCL-Headers
|
||||||
|
mkdir build && cd build
|
||||||
|
cmake .. -G Ninja `
|
||||||
|
-DBUILD_TESTING=OFF `
|
||||||
|
-DOPENCL_HEADERS_BUILD_TESTING=OFF `
|
||||||
|
-DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF `
|
||||||
|
-DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl"
|
||||||
|
cmake --build . --target install
|
||||||
|
|
||||||
|
cd ~/dev/llm
|
||||||
|
git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader && cd OpenCL-ICD-Loader
|
||||||
|
mkdir build && cd build
|
||||||
|
cmake .. -G Ninja `
|
||||||
|
-DCMAKE_BUILD_TYPE=Release `
|
||||||
|
-DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" `
|
||||||
|
-DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl"
|
||||||
|
cmake --build . --target install
|
||||||
|
```
|
||||||
|
|
||||||
|
### II. Build llama.cpp
|
||||||
|
|
||||||
|
```powershell
|
||||||
|
|
||||||
|
mkdir -p ~/dev/llm
|
||||||
|
cd ~/dev/llm
|
||||||
|
|
||||||
|
git clone https://github.com/ggerganov/llama.cpp && cd llama.cpp
|
||||||
|
mkdir build && cd build
|
||||||
|
|
||||||
|
cmake .. -G Ninja `
|
||||||
|
-DCMAKE_TOOLCHAIN_FILE="$HOME/dev/llm/llama.cpp/cmake/arm64-windows-llvm.cmake" `
|
||||||
|
-DCMAKE_BUILD_TYPE=Release `
|
||||||
|
-DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" `
|
||||||
|
-DBUILD_SHARED_LIBS=OFF `
|
||||||
|
-DGGML_OPENCL=ON
|
||||||
|
ninja
|
||||||
|
```
|
||||||
|
|
||||||
|
## Known Issues
|
||||||
|
|
||||||
|
- Qwen2.5 0.5B model produces gibberish output with Adreno kernels.
|
||||||
|
|
||||||
|
## TODO
|
||||||
|
|
||||||
|
- Fix Qwen2.5 0.5B
|
||||||
|
- Optimization for Q6_K
|
||||||
|
- Support and optimization for Q4_K
|
Binary file not shown.
|
@ -1600,6 +1600,10 @@ struct server_queue {
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||||
|
if (!running) {
|
||||||
|
QUE_DBG("%s", "terminate\n");
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (queue_tasks.empty()) {
|
if (queue_tasks.empty()) {
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
break;
|
break;
|
||||||
|
@ -1620,11 +1624,11 @@ struct server_queue {
|
||||||
QUE_DBG("%s", "waiting for new tasks\n");
|
QUE_DBG("%s", "waiting for new tasks\n");
|
||||||
{
|
{
|
||||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||||
if (queue_tasks.empty()) {
|
|
||||||
if (!running) {
|
if (!running) {
|
||||||
QUE_DBG("%s", "terminate\n");
|
QUE_DBG("%s", "terminate\n");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (queue_tasks.empty()) {
|
||||||
condition_tasks.wait(lock, [&]{
|
condition_tasks.wait(lock, [&]{
|
||||||
return (!queue_tasks.empty() || !running);
|
return (!queue_tasks.empty() || !running);
|
||||||
});
|
});
|
||||||
|
@ -2275,7 +2279,7 @@ struct server_context {
|
||||||
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
|
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
|
||||||
result.probs.push_back({
|
result.probs.push_back({
|
||||||
cur_p->data[i].id,
|
cur_p->data[i].id,
|
||||||
common_detokenize(ctx, {cur_p->data[i].id}, special),
|
common_token_to_piece(ctx, cur_p->data[i].id, special),
|
||||||
cur_p->data[i].p
|
cur_p->data[i].p
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -2297,7 +2301,7 @@ struct server_context {
|
||||||
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
|
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
|
||||||
result.probs.push_back({
|
result.probs.push_back({
|
||||||
cur[i].id,
|
cur[i].id,
|
||||||
common_detokenize(ctx, {cur[i].id}, special),
|
common_token_to_piece(ctx, cur[i].id, special),
|
||||||
cur[i].p
|
cur[i].p
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -4430,6 +4434,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// clean up function, to be called before exit
|
// clean up function, to be called before exit
|
||||||
auto clean_up = [&svr]() {
|
auto clean_up = [&svr]() {
|
||||||
|
SRV_INF("%s: cleaning up before exit...\n", __func__);
|
||||||
svr->stop();
|
svr->stop();
|
||||||
llama_backend_free();
|
llama_backend_free();
|
||||||
};
|
};
|
||||||
|
@ -4446,10 +4451,6 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!was_bound) {
|
if (!was_bound) {
|
||||||
//LOG_ERROR("couldn't bind HTTP server socket", {
|
|
||||||
// {"hostname", params.hostname},
|
|
||||||
// {"port", params.port},
|
|
||||||
//});
|
|
||||||
LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port);
|
LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port);
|
||||||
clean_up();
|
clean_up();
|
||||||
return 1;
|
return 1;
|
||||||
|
@ -4466,7 +4467,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
if (!ctx_server.load_model(params)) {
|
if (!ctx_server.load_model(params)) {
|
||||||
clean_up();
|
clean_up();
|
||||||
t.join();
|
// t.join(); // FIXME: see below
|
||||||
LOG_ERR("%s: exiting due to model loading error\n", __func__);
|
LOG_ERR("%s: exiting due to model loading error\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
@ -4490,13 +4491,10 @@ int main(int argc, char ** argv) {
|
||||||
});
|
});
|
||||||
|
|
||||||
shutdown_handler = [&](int) {
|
shutdown_handler = [&](int) {
|
||||||
|
// this will unblock start_loop()
|
||||||
ctx_server.queue_tasks.terminate();
|
ctx_server.queue_tasks.terminate();
|
||||||
};
|
};
|
||||||
|
|
||||||
LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
|
|
||||||
|
|
||||||
ctx_server.queue_tasks.start_loop();
|
|
||||||
|
|
||||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||||
struct sigaction sigint_action;
|
struct sigaction sigint_action;
|
||||||
sigint_action.sa_handler = signal_handler;
|
sigint_action.sa_handler = signal_handler;
|
||||||
|
@ -4511,8 +4509,13 @@ int main(int argc, char ** argv) {
|
||||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
|
||||||
|
|
||||||
|
// this call blocks the main thread until queue_tasks.terminate() is called
|
||||||
|
ctx_server.queue_tasks.start_loop();
|
||||||
|
|
||||||
clean_up();
|
clean_up();
|
||||||
t.join();
|
// t.join(); // FIXME: http thread may stuck if there is an on-going request. we don't need to care about this for now as the HTTP connection will already be closed at this point, but it's better to fix this
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
7
examples/server/webui/package-lock.json
generated
7
examples/server/webui/package-lock.json
generated
|
@ -13,6 +13,7 @@
|
||||||
"@vscode/markdown-it-katex": "^1.1.1",
|
"@vscode/markdown-it-katex": "^1.1.1",
|
||||||
"autoprefixer": "^10.4.20",
|
"autoprefixer": "^10.4.20",
|
||||||
"daisyui": "^4.12.14",
|
"daisyui": "^4.12.14",
|
||||||
|
"dexie": "^4.0.11",
|
||||||
"highlight.js": "^11.10.0",
|
"highlight.js": "^11.10.0",
|
||||||
"katex": "^0.16.15",
|
"katex": "^0.16.15",
|
||||||
"postcss": "^8.4.49",
|
"postcss": "^8.4.49",
|
||||||
|
@ -2338,6 +2339,12 @@
|
||||||
"url": "https://github.com/sponsors/wooorm"
|
"url": "https://github.com/sponsors/wooorm"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/dexie": {
|
||||||
|
"version": "4.0.11",
|
||||||
|
"resolved": "https://registry.npmjs.org/dexie/-/dexie-4.0.11.tgz",
|
||||||
|
"integrity": "sha512-SOKO002EqlvBYYKQSew3iymBoN2EQ4BDw/3yprjh7kAfFzjBYkaMNa/pZvcA7HSWlcKSQb9XhPe3wKyQ0x4A8A==",
|
||||||
|
"license": "Apache-2.0"
|
||||||
|
},
|
||||||
"node_modules/didyoumean": {
|
"node_modules/didyoumean": {
|
||||||
"version": "1.2.2",
|
"version": "1.2.2",
|
||||||
"resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz",
|
"resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz",
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
"@vscode/markdown-it-katex": "^1.1.1",
|
"@vscode/markdown-it-katex": "^1.1.1",
|
||||||
"autoprefixer": "^10.4.20",
|
"autoprefixer": "^10.4.20",
|
||||||
"daisyui": "^4.12.14",
|
"daisyui": "^4.12.14",
|
||||||
|
"dexie": "^4.0.11",
|
||||||
"highlight.js": "^11.10.0",
|
"highlight.js": "^11.10.0",
|
||||||
"katex": "^0.16.15",
|
"katex": "^0.16.15",
|
||||||
"postcss": "^8.4.49",
|
"postcss": "^8.4.49",
|
||||||
|
|
|
@ -3,6 +3,7 @@ import { useAppContext } from '../utils/app.context';
|
||||||
import { Message, PendingMessage } from '../utils/types';
|
import { Message, PendingMessage } from '../utils/types';
|
||||||
import { classNames } from '../utils/misc';
|
import { classNames } from '../utils/misc';
|
||||||
import MarkdownDisplay, { CopyButton } from './MarkdownDisplay';
|
import MarkdownDisplay, { CopyButton } from './MarkdownDisplay';
|
||||||
|
import { ChevronLeftIcon, ChevronRightIcon } from '@heroicons/react/24/outline';
|
||||||
|
|
||||||
interface SplitMessage {
|
interface SplitMessage {
|
||||||
content: PendingMessage['content'];
|
content: PendingMessage['content'];
|
||||||
|
@ -12,17 +13,24 @@ interface SplitMessage {
|
||||||
|
|
||||||
export default function ChatMessage({
|
export default function ChatMessage({
|
||||||
msg,
|
msg,
|
||||||
|
siblingLeafNodeIds,
|
||||||
|
siblingCurrIdx,
|
||||||
id,
|
id,
|
||||||
scrollToBottom,
|
onRegenerateMessage,
|
||||||
|
onEditMessage,
|
||||||
|
onChangeSibling,
|
||||||
isPending,
|
isPending,
|
||||||
}: {
|
}: {
|
||||||
msg: Message | PendingMessage;
|
msg: Message | PendingMessage;
|
||||||
|
siblingLeafNodeIds: Message['id'][];
|
||||||
|
siblingCurrIdx: number;
|
||||||
id?: string;
|
id?: string;
|
||||||
scrollToBottom: (requiresNearBottom: boolean) => void;
|
onRegenerateMessage(msg: Message): void;
|
||||||
|
onEditMessage(msg: Message, content: string): void;
|
||||||
|
onChangeSibling(sibling: Message['id']): void;
|
||||||
isPending?: boolean;
|
isPending?: boolean;
|
||||||
}) {
|
}) {
|
||||||
const { viewingConversation, replaceMessageAndGenerate, config } =
|
const { viewingChat, config } = useAppContext();
|
||||||
useAppContext();
|
|
||||||
const [editingContent, setEditingContent] = useState<string | null>(null);
|
const [editingContent, setEditingContent] = useState<string | null>(null);
|
||||||
const timings = useMemo(
|
const timings = useMemo(
|
||||||
() =>
|
() =>
|
||||||
|
@ -37,6 +45,8 @@ export default function ChatMessage({
|
||||||
: null,
|
: null,
|
||||||
[msg.timings]
|
[msg.timings]
|
||||||
);
|
);
|
||||||
|
const nextSibling = siblingLeafNodeIds[siblingCurrIdx + 1];
|
||||||
|
const prevSibling = siblingLeafNodeIds[siblingCurrIdx - 1];
|
||||||
|
|
||||||
// for reasoning model, we split the message into content and thought
|
// for reasoning model, we split the message into content and thought
|
||||||
// TODO: implement this as remark/rehype plugin in the future
|
// TODO: implement this as remark/rehype plugin in the future
|
||||||
|
@ -64,13 +74,7 @@ export default function ChatMessage({
|
||||||
return { content: actualContent, thought, isThinking };
|
return { content: actualContent, thought, isThinking };
|
||||||
}, [msg]);
|
}, [msg]);
|
||||||
|
|
||||||
if (!viewingConversation) return null;
|
if (!viewingChat) return null;
|
||||||
|
|
||||||
const regenerate = async () => {
|
|
||||||
replaceMessageAndGenerate(viewingConversation.id, msg.id, undefined, () =>
|
|
||||||
scrollToBottom(true)
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="group" id={id}>
|
<div className="group" id={id}>
|
||||||
|
@ -105,13 +109,12 @@ export default function ChatMessage({
|
||||||
</button>
|
</button>
|
||||||
<button
|
<button
|
||||||
className="btn mt-2"
|
className="btn mt-2"
|
||||||
onClick={() =>
|
onClick={() => {
|
||||||
replaceMessageAndGenerate(
|
if (msg.content !== null) {
|
||||||
viewingConversation.id,
|
setEditingContent(null);
|
||||||
msg.id,
|
onEditMessage(msg as Message, editingContent);
|
||||||
editingContent
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
}}
|
||||||
>
|
>
|
||||||
Submit
|
Submit
|
||||||
</button>
|
</button>
|
||||||
|
@ -196,10 +199,35 @@ export default function ChatMessage({
|
||||||
{msg.content !== null && (
|
{msg.content !== null && (
|
||||||
<div
|
<div
|
||||||
className={classNames({
|
className={classNames({
|
||||||
'mx-4 mt-2 mb-2': true,
|
'flex items-center gap-2 mx-4 mt-2 mb-2': true,
|
||||||
'text-right': msg.role === 'user',
|
'flex-row-reverse': msg.role === 'user',
|
||||||
})}
|
})}
|
||||||
>
|
>
|
||||||
|
{siblingLeafNodeIds && siblingLeafNodeIds.length > 1 && (
|
||||||
|
<div className="flex gap-1 items-center opacity-60 text-sm">
|
||||||
|
<button
|
||||||
|
className={classNames({
|
||||||
|
'btn btn-sm btn-ghost p-1': true,
|
||||||
|
'opacity-20': !prevSibling,
|
||||||
|
})}
|
||||||
|
onClick={() => prevSibling && onChangeSibling(prevSibling)}
|
||||||
|
>
|
||||||
|
<ChevronLeftIcon className="h-4 w-4" />
|
||||||
|
</button>
|
||||||
|
<span>
|
||||||
|
{siblingCurrIdx + 1} / {siblingLeafNodeIds.length}
|
||||||
|
</span>
|
||||||
|
<button
|
||||||
|
className={classNames({
|
||||||
|
'btn btn-sm btn-ghost p-1': true,
|
||||||
|
'opacity-20': !nextSibling,
|
||||||
|
})}
|
||||||
|
onClick={() => nextSibling && onChangeSibling(nextSibling)}
|
||||||
|
>
|
||||||
|
<ChevronRightIcon className="h-4 w-4" />
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
{/* user message */}
|
{/* user message */}
|
||||||
{msg.role === 'user' && (
|
{msg.role === 'user' && (
|
||||||
<button
|
<button
|
||||||
|
@ -216,7 +244,11 @@ export default function ChatMessage({
|
||||||
{!isPending && (
|
{!isPending && (
|
||||||
<button
|
<button
|
||||||
className="badge btn-mini show-on-hover mr-2"
|
className="badge btn-mini show-on-hover mr-2"
|
||||||
onClick={regenerate}
|
onClick={() => {
|
||||||
|
if (msg.content !== null) {
|
||||||
|
onRegenerateMessage(msg as Message);
|
||||||
|
}
|
||||||
|
}}
|
||||||
disabled={msg.content === null}
|
disabled={msg.content === null}
|
||||||
>
|
>
|
||||||
🔄 Regenerate
|
🔄 Regenerate
|
||||||
|
|
|
@ -1,28 +1,59 @@
|
||||||
import { useEffect, useState } from 'react';
|
import { useEffect, useMemo, useState } from 'react';
|
||||||
import { useAppContext } from '../utils/app.context';
|
import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
|
||||||
import StorageUtils from '../utils/storage';
|
|
||||||
import { useNavigate } from 'react-router';
|
|
||||||
import ChatMessage from './ChatMessage';
|
import ChatMessage from './ChatMessage';
|
||||||
import { CanvasType, PendingMessage } from '../utils/types';
|
import { CanvasType, Message, PendingMessage } from '../utils/types';
|
||||||
import { classNames } from '../utils/misc';
|
import { classNames, throttle } from '../utils/misc';
|
||||||
import CanvasPyInterpreter from './CanvasPyInterpreter';
|
import CanvasPyInterpreter from './CanvasPyInterpreter';
|
||||||
|
import StorageUtils from '../utils/storage';
|
||||||
|
|
||||||
export default function ChatScreen() {
|
/**
|
||||||
const {
|
* A message display is a message node with additional information for rendering.
|
||||||
viewingConversation,
|
* For example, siblings of the message node are stored as their last node (aka leaf node).
|
||||||
sendMessage,
|
*/
|
||||||
isGenerating,
|
export interface MessageDisplay {
|
||||||
stopGenerating,
|
msg: Message | PendingMessage;
|
||||||
pendingMessages,
|
siblingLeafNodeIds: Message['id'][];
|
||||||
canvasData,
|
siblingCurrIdx: number;
|
||||||
} = useAppContext();
|
isPending?: boolean;
|
||||||
const [inputMsg, setInputMsg] = useState('');
|
}
|
||||||
const navigate = useNavigate();
|
|
||||||
|
|
||||||
const currConvId = viewingConversation?.id ?? '';
|
function getListMessageDisplay(
|
||||||
const pendingMsg: PendingMessage | undefined = pendingMessages[currConvId];
|
msgs: Readonly<Message[]>,
|
||||||
|
leafNodeId: Message['id']
|
||||||
|
): MessageDisplay[] {
|
||||||
|
const currNodes = StorageUtils.filterByLeafNodeId(msgs, leafNodeId, true);
|
||||||
|
const res: MessageDisplay[] = [];
|
||||||
|
const nodeMap = new Map<Message['id'], Message>();
|
||||||
|
for (const msg of msgs) {
|
||||||
|
nodeMap.set(msg.id, msg);
|
||||||
|
}
|
||||||
|
// find leaf node from a message node
|
||||||
|
const findLeafNode = (msgId: Message['id']): Message['id'] => {
|
||||||
|
let currNode: Message | undefined = nodeMap.get(msgId);
|
||||||
|
while (currNode) {
|
||||||
|
if (currNode.children.length === 0) break;
|
||||||
|
currNode = nodeMap.get(currNode.children.at(-1) ?? -1);
|
||||||
|
}
|
||||||
|
return currNode?.id ?? -1;
|
||||||
|
};
|
||||||
|
// traverse the current nodes
|
||||||
|
for (const msg of currNodes) {
|
||||||
|
const parentNode = nodeMap.get(msg.parent ?? -1);
|
||||||
|
if (!parentNode) continue;
|
||||||
|
const siblings = parentNode.children;
|
||||||
|
if (msg.type !== 'root') {
|
||||||
|
res.push({
|
||||||
|
msg,
|
||||||
|
siblingLeafNodeIds: siblings.map(findLeafNode),
|
||||||
|
siblingCurrIdx: siblings.indexOf(msg.id),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
const scrollToBottom = (requiresNearBottom: boolean) => {
|
const scrollToBottom = throttle(
|
||||||
|
(requiresNearBottom: boolean, delay: number = 80) => {
|
||||||
const mainScrollElem = document.getElementById('main-scroll');
|
const mainScrollElem = document.getElementById('main-scroll');
|
||||||
if (!mainScrollElem) return;
|
if (!mainScrollElem) return;
|
||||||
const spaceToBottom =
|
const spaceToBottom =
|
||||||
|
@ -32,36 +63,107 @@ export default function ChatScreen() {
|
||||||
if (!requiresNearBottom || spaceToBottom < 50) {
|
if (!requiresNearBottom || spaceToBottom < 50) {
|
||||||
setTimeout(
|
setTimeout(
|
||||||
() => mainScrollElem.scrollTo({ top: mainScrollElem.scrollHeight }),
|
() => mainScrollElem.scrollTo({ top: mainScrollElem.scrollHeight }),
|
||||||
1
|
delay
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
80
|
||||||
|
);
|
||||||
|
|
||||||
|
export default function ChatScreen() {
|
||||||
|
const {
|
||||||
|
viewingChat,
|
||||||
|
sendMessage,
|
||||||
|
isGenerating,
|
||||||
|
stopGenerating,
|
||||||
|
pendingMessages,
|
||||||
|
canvasData,
|
||||||
|
replaceMessageAndGenerate,
|
||||||
|
} = useAppContext();
|
||||||
|
const [inputMsg, setInputMsg] = useState('');
|
||||||
|
|
||||||
|
// keep track of leaf node for rendering
|
||||||
|
const [currNodeId, setCurrNodeId] = useState<number>(-1);
|
||||||
|
const messages: MessageDisplay[] = useMemo(() => {
|
||||||
|
if (!viewingChat) return [];
|
||||||
|
else return getListMessageDisplay(viewingChat.messages, currNodeId);
|
||||||
|
}, [currNodeId, viewingChat]);
|
||||||
|
|
||||||
|
const currConvId = viewingChat?.conv.id ?? null;
|
||||||
|
const pendingMsg: PendingMessage | undefined =
|
||||||
|
pendingMessages[currConvId ?? ''];
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
// reset to latest node when conversation changes
|
||||||
|
setCurrNodeId(-1);
|
||||||
|
// scroll to bottom when conversation changes
|
||||||
|
scrollToBottom(false, 1);
|
||||||
|
}, [currConvId]);
|
||||||
|
|
||||||
|
const onChunk: CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => {
|
||||||
|
if (currLeafNodeId) {
|
||||||
|
setCurrNodeId(currLeafNodeId);
|
||||||
|
}
|
||||||
|
scrollToBottom(true);
|
||||||
};
|
};
|
||||||
|
|
||||||
// scroll to bottom when conversation changes
|
|
||||||
useEffect(() => {
|
|
||||||
scrollToBottom(false);
|
|
||||||
}, [viewingConversation?.id]);
|
|
||||||
|
|
||||||
const sendNewMessage = async () => {
|
const sendNewMessage = async () => {
|
||||||
if (inputMsg.trim().length === 0 || isGenerating(currConvId)) return;
|
if (inputMsg.trim().length === 0 || isGenerating(currConvId ?? '')) return;
|
||||||
const convId = viewingConversation?.id ?? StorageUtils.getNewConvId();
|
|
||||||
const lastInpMsg = inputMsg;
|
const lastInpMsg = inputMsg;
|
||||||
setInputMsg('');
|
setInputMsg('');
|
||||||
if (!viewingConversation) {
|
|
||||||
// if user is creating a new conversation, redirect to the new conversation
|
|
||||||
navigate(`/chat/${convId}`);
|
|
||||||
}
|
|
||||||
scrollToBottom(false);
|
scrollToBottom(false);
|
||||||
// auto scroll as message is being generated
|
setCurrNodeId(-1);
|
||||||
const onChunk = () => scrollToBottom(true);
|
// get the last message node
|
||||||
if (!(await sendMessage(convId, inputMsg, onChunk))) {
|
const lastMsgNodeId = messages.at(-1)?.msg.id ?? null;
|
||||||
|
if (!(await sendMessage(currConvId, lastMsgNodeId, inputMsg, onChunk))) {
|
||||||
// restore the input message if failed
|
// restore the input message if failed
|
||||||
setInputMsg(lastInpMsg);
|
setInputMsg(lastInpMsg);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const handleEditMessage = async (msg: Message, content: string) => {
|
||||||
|
if (!viewingChat) return;
|
||||||
|
setCurrNodeId(msg.id);
|
||||||
|
scrollToBottom(false);
|
||||||
|
await replaceMessageAndGenerate(
|
||||||
|
viewingChat.conv.id,
|
||||||
|
msg.parent,
|
||||||
|
content,
|
||||||
|
onChunk
|
||||||
|
);
|
||||||
|
setCurrNodeId(-1);
|
||||||
|
scrollToBottom(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleRegenerateMessage = async (msg: Message) => {
|
||||||
|
if (!viewingChat) return;
|
||||||
|
setCurrNodeId(msg.parent);
|
||||||
|
scrollToBottom(false);
|
||||||
|
await replaceMessageAndGenerate(
|
||||||
|
viewingChat.conv.id,
|
||||||
|
msg.parent,
|
||||||
|
null,
|
||||||
|
onChunk
|
||||||
|
);
|
||||||
|
setCurrNodeId(-1);
|
||||||
|
scrollToBottom(false);
|
||||||
|
};
|
||||||
|
|
||||||
const hasCanvas = !!canvasData;
|
const hasCanvas = !!canvasData;
|
||||||
|
|
||||||
|
// due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. appears once in the saved conversation and once in the pendingMsg)
|
||||||
|
const pendingMsgDisplay: MessageDisplay[] =
|
||||||
|
pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id
|
||||||
|
? [
|
||||||
|
{
|
||||||
|
msg: pendingMsg,
|
||||||
|
siblingLeafNodeIds: [],
|
||||||
|
siblingCurrIdx: 0,
|
||||||
|
isPending: true,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
: [];
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={classNames({
|
className={classNames({
|
||||||
|
@ -81,24 +183,19 @@ export default function ChatScreen() {
|
||||||
<div id="messages-list" className="grow">
|
<div id="messages-list" className="grow">
|
||||||
<div className="mt-auto flex justify-center">
|
<div className="mt-auto flex justify-center">
|
||||||
{/* placeholder to shift the message to the bottom */}
|
{/* placeholder to shift the message to the bottom */}
|
||||||
{viewingConversation ? '' : 'Send a message to start'}
|
{viewingChat ? '' : 'Send a message to start'}
|
||||||
</div>
|
</div>
|
||||||
{viewingConversation?.messages.map((msg) => (
|
{[...messages, ...pendingMsgDisplay].map((msg) => (
|
||||||
<ChatMessage
|
<ChatMessage
|
||||||
key={msg.id}
|
key={msg.msg.id}
|
||||||
msg={msg}
|
msg={msg.msg}
|
||||||
scrollToBottom={scrollToBottom}
|
siblingLeafNodeIds={msg.siblingLeafNodeIds}
|
||||||
|
siblingCurrIdx={msg.siblingCurrIdx}
|
||||||
|
onRegenerateMessage={handleRegenerateMessage}
|
||||||
|
onEditMessage={handleEditMessage}
|
||||||
|
onChangeSibling={setCurrNodeId}
|
||||||
/>
|
/>
|
||||||
))}
|
))}
|
||||||
|
|
||||||
{pendingMsg && (
|
|
||||||
<ChatMessage
|
|
||||||
msg={pendingMsg}
|
|
||||||
scrollToBottom={scrollToBottom}
|
|
||||||
isPending
|
|
||||||
id="pending-msg"
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* chat input */}
|
{/* chat input */}
|
||||||
|
@ -118,10 +215,10 @@ export default function ChatScreen() {
|
||||||
id="msg-input"
|
id="msg-input"
|
||||||
dir="auto"
|
dir="auto"
|
||||||
></textarea>
|
></textarea>
|
||||||
{isGenerating(currConvId) ? (
|
{isGenerating(currConvId ?? '') ? (
|
||||||
<button
|
<button
|
||||||
className="btn btn-neutral ml-2"
|
className="btn btn-neutral ml-2"
|
||||||
onClick={() => stopGenerating(currConvId)}
|
onClick={() => stopGenerating(currConvId ?? '')}
|
||||||
>
|
>
|
||||||
Stop
|
Stop
|
||||||
</button>
|
</button>
|
||||||
|
|
|
@ -25,12 +25,12 @@ export default function Header() {
|
||||||
);
|
);
|
||||||
}, [selectedTheme]);
|
}, [selectedTheme]);
|
||||||
|
|
||||||
const { isGenerating, viewingConversation } = useAppContext();
|
const { isGenerating, viewingChat } = useAppContext();
|
||||||
const isCurrConvGenerating = isGenerating(viewingConversation?.id ?? '');
|
const isCurrConvGenerating = isGenerating(viewingChat?.conv.id ?? '');
|
||||||
|
|
||||||
const removeConversation = () => {
|
const removeConversation = () => {
|
||||||
if (isCurrConvGenerating || !viewingConversation) return;
|
if (isCurrConvGenerating || !viewingChat) return;
|
||||||
const convId = viewingConversation.id;
|
const convId = viewingChat?.conv.id;
|
||||||
if (window.confirm('Are you sure to delete this conversation?')) {
|
if (window.confirm('Are you sure to delete this conversation?')) {
|
||||||
StorageUtils.remove(convId);
|
StorageUtils.remove(convId);
|
||||||
navigate('/');
|
navigate('/');
|
||||||
|
@ -38,9 +38,9 @@ export default function Header() {
|
||||||
};
|
};
|
||||||
|
|
||||||
const downloadConversation = () => {
|
const downloadConversation = () => {
|
||||||
if (isCurrConvGenerating || !viewingConversation) return;
|
if (isCurrConvGenerating || !viewingChat) return;
|
||||||
const convId = viewingConversation.id;
|
const convId = viewingChat?.conv.id;
|
||||||
const conversationJson = JSON.stringify(viewingConversation, null, 2);
|
const conversationJson = JSON.stringify(viewingChat, null, 2);
|
||||||
const blob = new Blob([conversationJson], { type: 'application/json' });
|
const blob = new Blob([conversationJson], { type: 'application/json' });
|
||||||
const url = URL.createObjectURL(blob);
|
const url = URL.createObjectURL(blob);
|
||||||
const a = document.createElement('a');
|
const a = document.createElement('a');
|
||||||
|
@ -75,7 +75,8 @@ export default function Header() {
|
||||||
|
|
||||||
{/* action buttons (top right) */}
|
{/* action buttons (top right) */}
|
||||||
<div className="flex items-center">
|
<div className="flex items-center">
|
||||||
<div v-if="messages.length > 0" className="dropdown dropdown-end">
|
{viewingChat && (
|
||||||
|
<div className="dropdown dropdown-end">
|
||||||
{/* "..." button */}
|
{/* "..." button */}
|
||||||
<button
|
<button
|
||||||
tabIndex={0}
|
tabIndex={0}
|
||||||
|
@ -107,6 +108,8 @@ export default function Header() {
|
||||||
</li>
|
</li>
|
||||||
</ul>
|
</ul>
|
||||||
</div>
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
<div className="tooltip tooltip-bottom" data-tip="Settings">
|
<div className="tooltip tooltip-bottom" data-tip="Settings">
|
||||||
<button className="btn" onClick={() => setShowSettings(true)}>
|
<button className="btn" onClick={() => setShowSettings(true)}>
|
||||||
{/* settings button */}
|
{/* settings button */}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import { useEffect, useMemo, useState } from 'react';
|
import { useEffect, useState } from 'react';
|
||||||
import { classNames } from '../utils/misc';
|
import { classNames } from '../utils/misc';
|
||||||
import { Conversation } from '../utils/types';
|
import { Conversation } from '../utils/types';
|
||||||
import StorageUtils from '../utils/storage';
|
import StorageUtils from '../utils/storage';
|
||||||
|
@ -7,16 +7,17 @@ import { useNavigate, useParams } from 'react-router';
|
||||||
export default function Sidebar() {
|
export default function Sidebar() {
|
||||||
const params = useParams();
|
const params = useParams();
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
const currConv = useMemo(
|
|
||||||
() => StorageUtils.getOneConversation(params.convId ?? ''),
|
|
||||||
[params.convId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const [conversations, setConversations] = useState<Conversation[]>([]);
|
const [conversations, setConversations] = useState<Conversation[]>([]);
|
||||||
|
const [currConv, setCurrConv] = useState<Conversation | null>(null);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const handleConversationChange = () => {
|
StorageUtils.getOneConversation(params.convId ?? '').then(setCurrConv);
|
||||||
setConversations(StorageUtils.getAllConversations());
|
}, [params.convId]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const handleConversationChange = async () => {
|
||||||
|
setConversations(await StorageUtils.getAllConversations());
|
||||||
};
|
};
|
||||||
StorageUtils.onConversationChanged(handleConversationChange);
|
StorageUtils.onConversationChanged(handleConversationChange);
|
||||||
handleConversationChange();
|
handleConversationChange();
|
||||||
|
@ -82,11 +83,11 @@ export default function Sidebar() {
|
||||||
onClick={() => navigate(`/chat/${conv.id}`)}
|
onClick={() => navigate(`/chat/${conv.id}`)}
|
||||||
dir="auto"
|
dir="auto"
|
||||||
>
|
>
|
||||||
<span className="truncate">{conv.messages[0].content}</span>
|
<span className="truncate">{conv.name}</span>
|
||||||
</div>
|
</div>
|
||||||
))}
|
))}
|
||||||
<div className="text-center text-xs opacity-40 mt-auto mx-4">
|
<div className="text-center text-xs opacity-40 mt-auto mx-4">
|
||||||
Conversations are saved to browser's localStorage
|
Conversations are saved to browser's IndexedDB
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
@ -5,6 +5,7 @@ import {
|
||||||
Conversation,
|
Conversation,
|
||||||
Message,
|
Message,
|
||||||
PendingMessage,
|
PendingMessage,
|
||||||
|
ViewingChat,
|
||||||
} from './types';
|
} from './types';
|
||||||
import StorageUtils from './storage';
|
import StorageUtils from './storage';
|
||||||
import {
|
import {
|
||||||
|
@ -13,24 +14,25 @@ import {
|
||||||
getSSEStreamAsync,
|
getSSEStreamAsync,
|
||||||
} from './misc';
|
} from './misc';
|
||||||
import { BASE_URL, CONFIG_DEFAULT, isDev } from '../Config';
|
import { BASE_URL, CONFIG_DEFAULT, isDev } from '../Config';
|
||||||
import { matchPath, useLocation } from 'react-router';
|
import { matchPath, useLocation, useNavigate } from 'react-router';
|
||||||
|
|
||||||
interface AppContextValue {
|
interface AppContextValue {
|
||||||
// conversations and messages
|
// conversations and messages
|
||||||
viewingConversation: Conversation | null;
|
viewingChat: ViewingChat | null;
|
||||||
pendingMessages: Record<Conversation['id'], PendingMessage>;
|
pendingMessages: Record<Conversation['id'], PendingMessage>;
|
||||||
isGenerating: (convId: string) => boolean;
|
isGenerating: (convId: string) => boolean;
|
||||||
sendMessage: (
|
sendMessage: (
|
||||||
convId: string,
|
convId: string | null,
|
||||||
|
leafNodeId: Message['id'] | null,
|
||||||
content: string,
|
content: string,
|
||||||
onChunk?: CallbackGeneratedChunk
|
onChunk: CallbackGeneratedChunk
|
||||||
) => Promise<boolean>;
|
) => Promise<boolean>;
|
||||||
stopGenerating: (convId: string) => void;
|
stopGenerating: (convId: string) => void;
|
||||||
replaceMessageAndGenerate: (
|
replaceMessageAndGenerate: (
|
||||||
convId: string,
|
convId: string,
|
||||||
origMsgId: Message['id'],
|
parentNodeId: Message['id'], // the parent node of the message to be replaced
|
||||||
content?: string,
|
content: string | null,
|
||||||
onChunk?: CallbackGeneratedChunk
|
onChunk: CallbackGeneratedChunk
|
||||||
) => Promise<void>;
|
) => Promise<void>;
|
||||||
|
|
||||||
// canvas
|
// canvas
|
||||||
|
@ -44,23 +46,33 @@ interface AppContextValue {
|
||||||
setShowSettings: (show: boolean) => void;
|
setShowSettings: (show: boolean) => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
// for now, this callback is only used for scrolling to the bottom of the chat
|
// this callback is used for scrolling to the bottom of the chat and switching to the last node
|
||||||
type CallbackGeneratedChunk = () => void;
|
export type CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => void;
|
||||||
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
const AppContext = createContext<AppContextValue>({} as any);
|
const AppContext = createContext<AppContextValue>({} as any);
|
||||||
|
|
||||||
|
const getViewingChat = async (convId: string): Promise<ViewingChat | null> => {
|
||||||
|
const conv = await StorageUtils.getOneConversation(convId);
|
||||||
|
if (!conv) return null;
|
||||||
|
return {
|
||||||
|
conv: conv,
|
||||||
|
// all messages from all branches, not filtered by last node
|
||||||
|
messages: await StorageUtils.getMessages(convId),
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
export const AppContextProvider = ({
|
export const AppContextProvider = ({
|
||||||
children,
|
children,
|
||||||
}: {
|
}: {
|
||||||
children: React.ReactElement;
|
children: React.ReactElement;
|
||||||
}) => {
|
}) => {
|
||||||
const { pathname } = useLocation();
|
const { pathname } = useLocation();
|
||||||
|
const navigate = useNavigate();
|
||||||
const params = matchPath('/chat/:convId', pathname);
|
const params = matchPath('/chat/:convId', pathname);
|
||||||
const convId = params?.params?.convId;
|
const convId = params?.params?.convId;
|
||||||
|
|
||||||
const [viewingConversation, setViewingConversation] =
|
const [viewingChat, setViewingChat] = useState<ViewingChat | null>(null);
|
||||||
useState<Conversation | null>(null);
|
|
||||||
const [pendingMessages, setPendingMessages] = useState<
|
const [pendingMessages, setPendingMessages] = useState<
|
||||||
Record<Conversation['id'], PendingMessage>
|
Record<Conversation['id'], PendingMessage>
|
||||||
>({});
|
>({});
|
||||||
|
@ -75,12 +87,12 @@ export const AppContextProvider = ({
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
// also reset the canvas data
|
// also reset the canvas data
|
||||||
setCanvasData(null);
|
setCanvasData(null);
|
||||||
const handleConversationChange = (changedConvId: string) => {
|
const handleConversationChange = async (changedConvId: string) => {
|
||||||
if (changedConvId !== convId) return;
|
if (changedConvId !== convId) return;
|
||||||
setViewingConversation(StorageUtils.getOneConversation(convId));
|
setViewingChat(await getViewingChat(changedConvId));
|
||||||
};
|
};
|
||||||
StorageUtils.onConversationChanged(handleConversationChange);
|
StorageUtils.onConversationChanged(handleConversationChange);
|
||||||
setViewingConversation(StorageUtils.getOneConversation(convId ?? ''));
|
getViewingChat(convId ?? '').then(setViewingChat);
|
||||||
return () => {
|
return () => {
|
||||||
StorageUtils.offConversationChanged(handleConversationChange);
|
StorageUtils.offConversationChanged(handleConversationChange);
|
||||||
};
|
};
|
||||||
|
@ -118,23 +130,39 @@ export const AppContextProvider = ({
|
||||||
|
|
||||||
const generateMessage = async (
|
const generateMessage = async (
|
||||||
convId: string,
|
convId: string,
|
||||||
onChunk?: CallbackGeneratedChunk
|
leafNodeId: Message['id'],
|
||||||
|
onChunk: CallbackGeneratedChunk
|
||||||
) => {
|
) => {
|
||||||
if (isGenerating(convId)) return;
|
if (isGenerating(convId)) return;
|
||||||
|
|
||||||
const config = StorageUtils.getConfig();
|
const config = StorageUtils.getConfig();
|
||||||
const currConversation = StorageUtils.getOneConversation(convId);
|
const currConversation = await StorageUtils.getOneConversation(convId);
|
||||||
if (!currConversation) {
|
if (!currConversation) {
|
||||||
throw new Error('Current conversation is not found');
|
throw new Error('Current conversation is not found');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const currMessages = StorageUtils.filterByLeafNodeId(
|
||||||
|
await StorageUtils.getMessages(convId),
|
||||||
|
leafNodeId,
|
||||||
|
false
|
||||||
|
);
|
||||||
const abortController = new AbortController();
|
const abortController = new AbortController();
|
||||||
setAbort(convId, abortController);
|
setAbort(convId, abortController);
|
||||||
|
|
||||||
|
if (!currMessages) {
|
||||||
|
throw new Error('Current messages are not found');
|
||||||
|
}
|
||||||
|
|
||||||
|
const pendingId = Date.now() + 1;
|
||||||
let pendingMsg: PendingMessage = {
|
let pendingMsg: PendingMessage = {
|
||||||
id: Date.now() + 1,
|
id: pendingId,
|
||||||
|
convId,
|
||||||
|
type: 'text',
|
||||||
|
timestamp: pendingId,
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
content: null,
|
content: null,
|
||||||
|
parent: leafNodeId,
|
||||||
|
children: [],
|
||||||
};
|
};
|
||||||
setPending(convId, pendingMsg);
|
setPending(convId, pendingMsg);
|
||||||
|
|
||||||
|
@ -144,7 +172,7 @@ export const AppContextProvider = ({
|
||||||
...(config.systemMessage.length === 0
|
...(config.systemMessage.length === 0
|
||||||
? []
|
? []
|
||||||
: [{ role: 'system', content: config.systemMessage } as APIMessage]),
|
: [{ role: 'system', content: config.systemMessage } as APIMessage]),
|
||||||
...normalizeMsgsForAPI(currConversation?.messages ?? []),
|
...normalizeMsgsForAPI(currMessages),
|
||||||
];
|
];
|
||||||
if (config.excludeThoughtOnReq) {
|
if (config.excludeThoughtOnReq) {
|
||||||
messages = filterThoughtFromMsgs(messages);
|
messages = filterThoughtFromMsgs(messages);
|
||||||
|
@ -205,8 +233,7 @@ export const AppContextProvider = ({
|
||||||
const lastContent = pendingMsg.content || '';
|
const lastContent = pendingMsg.content || '';
|
||||||
if (addedContent) {
|
if (addedContent) {
|
||||||
pendingMsg = {
|
pendingMsg = {
|
||||||
id: pendingMsg.id,
|
...pendingMsg,
|
||||||
role: 'assistant',
|
|
||||||
content: lastContent + addedContent,
|
content: lastContent + addedContent,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -221,7 +248,7 @@ export const AppContextProvider = ({
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
setPending(convId, pendingMsg);
|
setPending(convId, pendingMsg);
|
||||||
onChunk?.();
|
onChunk(); // don't need to switch node for pending message
|
||||||
}
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
setPending(convId, null);
|
setPending(convId, null);
|
||||||
|
@ -236,37 +263,53 @@ export const AppContextProvider = ({
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pendingMsg.content) {
|
if (pendingMsg.content !== null) {
|
||||||
StorageUtils.appendMsg(currConversation.id, {
|
await StorageUtils.appendMsg(pendingMsg as Message, leafNodeId);
|
||||||
id: pendingMsg.id,
|
|
||||||
content: pendingMsg.content,
|
|
||||||
role: pendingMsg.role,
|
|
||||||
timings: pendingMsg.timings,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
setPending(convId, null);
|
setPending(convId, null);
|
||||||
onChunk?.(); // trigger scroll to bottom
|
onChunk(pendingId); // trigger scroll to bottom and switch to the last node
|
||||||
};
|
};
|
||||||
|
|
||||||
const sendMessage = async (
|
const sendMessage = async (
|
||||||
convId: string,
|
convId: string | null,
|
||||||
|
leafNodeId: Message['id'] | null,
|
||||||
content: string,
|
content: string,
|
||||||
onChunk?: CallbackGeneratedChunk
|
onChunk: CallbackGeneratedChunk
|
||||||
): Promise<boolean> => {
|
): Promise<boolean> => {
|
||||||
if (isGenerating(convId) || content.trim().length === 0) return false;
|
if (isGenerating(convId ?? '') || content.trim().length === 0) return false;
|
||||||
|
|
||||||
StorageUtils.appendMsg(convId, {
|
if (convId === null || convId.length === 0 || leafNodeId === null) {
|
||||||
id: Date.now(),
|
const conv = await StorageUtils.createConversation(
|
||||||
|
content.substring(0, 256)
|
||||||
|
);
|
||||||
|
convId = conv.id;
|
||||||
|
leafNodeId = conv.currNode;
|
||||||
|
// if user is creating a new conversation, redirect to the new conversation
|
||||||
|
navigate(`/chat/${convId}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const now = Date.now();
|
||||||
|
const currMsgId = now;
|
||||||
|
StorageUtils.appendMsg(
|
||||||
|
{
|
||||||
|
id: currMsgId,
|
||||||
|
timestamp: now,
|
||||||
|
type: 'text',
|
||||||
|
convId,
|
||||||
role: 'user',
|
role: 'user',
|
||||||
content,
|
content,
|
||||||
});
|
parent: leafNodeId,
|
||||||
|
children: [],
|
||||||
|
},
|
||||||
|
leafNodeId
|
||||||
|
);
|
||||||
|
onChunk(currMsgId);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await generateMessage(convId, onChunk);
|
await generateMessage(convId, currMsgId, onChunk);
|
||||||
return true;
|
return true;
|
||||||
} catch (_) {
|
} catch (_) {
|
||||||
// rollback
|
// TODO: rollback
|
||||||
StorageUtils.popMsg(convId);
|
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
|
@ -279,22 +322,33 @@ export const AppContextProvider = ({
|
||||||
// if content is undefined, we remove last assistant message
|
// if content is undefined, we remove last assistant message
|
||||||
const replaceMessageAndGenerate = async (
|
const replaceMessageAndGenerate = async (
|
||||||
convId: string,
|
convId: string,
|
||||||
origMsgId: Message['id'],
|
parentNodeId: Message['id'], // the parent node of the message to be replaced
|
||||||
content?: string,
|
content: string | null,
|
||||||
onChunk?: CallbackGeneratedChunk
|
onChunk: CallbackGeneratedChunk
|
||||||
) => {
|
) => {
|
||||||
if (isGenerating(convId)) return;
|
if (isGenerating(convId)) return;
|
||||||
|
|
||||||
StorageUtils.filterAndKeepMsgs(convId, (msg) => msg.id < origMsgId);
|
if (content !== null) {
|
||||||
if (content) {
|
const now = Date.now();
|
||||||
StorageUtils.appendMsg(convId, {
|
const currMsgId = now;
|
||||||
id: Date.now(),
|
StorageUtils.appendMsg(
|
||||||
|
{
|
||||||
|
id: currMsgId,
|
||||||
|
timestamp: now,
|
||||||
|
type: 'text',
|
||||||
|
convId,
|
||||||
role: 'user',
|
role: 'user',
|
||||||
content,
|
content,
|
||||||
});
|
parent: parentNodeId,
|
||||||
|
children: [],
|
||||||
|
},
|
||||||
|
parentNodeId
|
||||||
|
);
|
||||||
|
parentNodeId = currMsgId;
|
||||||
}
|
}
|
||||||
|
onChunk(parentNodeId);
|
||||||
|
|
||||||
await generateMessage(convId, onChunk);
|
await generateMessage(convId, parentNodeId, onChunk);
|
||||||
};
|
};
|
||||||
|
|
||||||
const saveConfig = (config: typeof CONFIG_DEFAULT) => {
|
const saveConfig = (config: typeof CONFIG_DEFAULT) => {
|
||||||
|
@ -306,7 +360,7 @@ export const AppContextProvider = ({
|
||||||
<AppContext.Provider
|
<AppContext.Provider
|
||||||
value={{
|
value={{
|
||||||
isGenerating,
|
isGenerating,
|
||||||
viewingConversation,
|
viewingChat,
|
||||||
pendingMessages,
|
pendingMessages,
|
||||||
sendMessage,
|
sendMessage,
|
||||||
stopGenerating,
|
stopGenerating,
|
||||||
|
|
|
@ -4,7 +4,6 @@ import { APIMessage, Message } from './types';
|
||||||
|
|
||||||
// ponyfill for missing ReadableStream asyncIterator on Safari
|
// ponyfill for missing ReadableStream asyncIterator on Safari
|
||||||
import { asyncIterator } from '@sec-ant/readable-stream/ponyfill/asyncIterator';
|
import { asyncIterator } from '@sec-ant/readable-stream/ponyfill/asyncIterator';
|
||||||
import { isDev } from '../Config';
|
|
||||||
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
export const isString = (x: any) => !!x.toLowerCase;
|
export const isString = (x: any) => !!x.toLowerCase;
|
||||||
|
@ -23,7 +22,7 @@ export async function* getSSEStreamAsync(fetchResponse: Response) {
|
||||||
.pipeThrough(new TextLineStream());
|
.pipeThrough(new TextLineStream());
|
||||||
// @ts-expect-error asyncIterator complains about type, but it should work
|
// @ts-expect-error asyncIterator complains about type, but it should work
|
||||||
for await (const line of asyncIterator(lines)) {
|
for await (const line of asyncIterator(lines)) {
|
||||||
if (isDev) console.log({ line });
|
//if (isDev) console.log({ line });
|
||||||
if (line.startsWith('data:') && !line.endsWith('[DONE]')) {
|
if (line.startsWith('data:') && !line.endsWith('[DONE]')) {
|
||||||
const data = JSON.parse(line.slice(5));
|
const data = JSON.parse(line.slice(5));
|
||||||
yield data;
|
yield data;
|
||||||
|
@ -55,7 +54,7 @@ export const copyStr = (textToCopy: string) => {
|
||||||
/**
|
/**
|
||||||
* filter out redundant fields upon sending to API
|
* filter out redundant fields upon sending to API
|
||||||
*/
|
*/
|
||||||
export function normalizeMsgsForAPI(messages: Message[]) {
|
export function normalizeMsgsForAPI(messages: Readonly<Message[]>) {
|
||||||
return messages.map((msg) => {
|
return messages.map((msg) => {
|
||||||
return {
|
return {
|
||||||
role: msg.role,
|
role: msg.role,
|
||||||
|
@ -88,3 +87,23 @@ export function classNames(classes: Record<string, boolean>): string {
|
||||||
|
|
||||||
export const delay = (ms: number) =>
|
export const delay = (ms: number) =>
|
||||||
new Promise((resolve) => setTimeout(resolve, ms));
|
new Promise((resolve) => setTimeout(resolve, ms));
|
||||||
|
|
||||||
|
export const throttle = <T extends unknown[]>(
|
||||||
|
callback: (...args: T) => void,
|
||||||
|
delay: number
|
||||||
|
) => {
|
||||||
|
let isWaiting = false;
|
||||||
|
|
||||||
|
return (...args: T) => {
|
||||||
|
if (isWaiting) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
callback(...args);
|
||||||
|
isWaiting = true;
|
||||||
|
|
||||||
|
setTimeout(() => {
|
||||||
|
isWaiting = false;
|
||||||
|
}, delay);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
|
@ -2,7 +2,8 @@
|
||||||
// format: { [convId]: { id: string, lastModified: number, messages: [...] } }
|
// format: { [convId]: { id: string, lastModified: number, messages: [...] } }
|
||||||
|
|
||||||
import { CONFIG_DEFAULT } from '../Config';
|
import { CONFIG_DEFAULT } from '../Config';
|
||||||
import { Conversation, Message } from './types';
|
import { Conversation, Message, TimingReport } from './types';
|
||||||
|
import Dexie, { Table } from 'dexie';
|
||||||
|
|
||||||
const event = new EventTarget();
|
const event = new EventTarget();
|
||||||
|
|
||||||
|
@ -17,85 +18,154 @@ const dispatchConversationChange = (convId: string) => {
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const db = new Dexie('LlamacppWebui') as Dexie & {
|
||||||
|
conversations: Table<Conversation>;
|
||||||
|
messages: Table<Message>;
|
||||||
|
};
|
||||||
|
|
||||||
|
// https://dexie.org/docs/Version/Version.stores()
|
||||||
|
db.version(1).stores({
|
||||||
|
// Unlike SQL, you don’t need to specify all properties but only the one you wish to index.
|
||||||
|
conversations: '&id, lastModified',
|
||||||
|
messages: '&id, convId, [convId+id], timestamp',
|
||||||
|
});
|
||||||
|
|
||||||
// convId is a string prefixed with 'conv-'
|
// convId is a string prefixed with 'conv-'
|
||||||
const StorageUtils = {
|
const StorageUtils = {
|
||||||
/**
|
/**
|
||||||
* manage conversations
|
* manage conversations
|
||||||
*/
|
*/
|
||||||
getAllConversations(): Conversation[] {
|
async getAllConversations(): Promise<Conversation[]> {
|
||||||
const res = [];
|
await migrationLStoIDB().catch(console.error); // noop if already migrated
|
||||||
for (const key in localStorage) {
|
return (await db.conversations.toArray()).sort(
|
||||||
if (key.startsWith('conv-')) {
|
(a, b) => b.lastModified - a.lastModified
|
||||||
res.push(JSON.parse(localStorage.getItem(key) ?? '{}'));
|
);
|
||||||
}
|
|
||||||
}
|
|
||||||
res.sort((a, b) => b.lastModified - a.lastModified);
|
|
||||||
return res;
|
|
||||||
},
|
},
|
||||||
/**
|
/**
|
||||||
* can return null if convId does not exist
|
* can return null if convId does not exist
|
||||||
*/
|
*/
|
||||||
getOneConversation(convId: string): Conversation | null {
|
async getOneConversation(convId: string): Promise<Conversation | null> {
|
||||||
return JSON.parse(localStorage.getItem(convId) || 'null');
|
return (await db.conversations.where('id').equals(convId).first()) ?? null;
|
||||||
},
|
},
|
||||||
/**
|
/**
|
||||||
* if convId does not exist, create one
|
* get all message nodes in a conversation
|
||||||
*/
|
*/
|
||||||
appendMsg(convId: string, msg: Message): void {
|
async getMessages(convId: string): Promise<Message[]> {
|
||||||
if (msg.content === null) return;
|
return await db.messages.where({ convId }).toArray();
|
||||||
const conv = StorageUtils.getOneConversation(convId) || {
|
},
|
||||||
id: convId,
|
/**
|
||||||
lastModified: Date.now(),
|
* use in conjunction with getMessages to filter messages by leafNodeId
|
||||||
messages: [],
|
* includeRoot: whether to include the root node in the result
|
||||||
|
* if node with leafNodeId does not exist, return the path with the latest timestamp
|
||||||
|
*/
|
||||||
|
filterByLeafNodeId(
|
||||||
|
msgs: Readonly<Message[]>,
|
||||||
|
leafNodeId: Message['id'],
|
||||||
|
includeRoot: boolean
|
||||||
|
): Readonly<Message[]> {
|
||||||
|
const res: Message[] = [];
|
||||||
|
const nodeMap = new Map<Message['id'], Message>();
|
||||||
|
for (const msg of msgs) {
|
||||||
|
nodeMap.set(msg.id, msg);
|
||||||
|
}
|
||||||
|
let startNode: Message | undefined = nodeMap.get(leafNodeId);
|
||||||
|
if (!startNode) {
|
||||||
|
// if not found, we return the path with the latest timestamp
|
||||||
|
let latestTime = -1;
|
||||||
|
for (const msg of msgs) {
|
||||||
|
if (msg.timestamp > latestTime) {
|
||||||
|
startNode = msg;
|
||||||
|
latestTime = msg.timestamp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// traverse the path from leafNodeId to root
|
||||||
|
// startNode can never be undefined here
|
||||||
|
let currNode: Message | undefined = startNode;
|
||||||
|
while (currNode) {
|
||||||
|
if (currNode.type !== 'root' || (currNode.type === 'root' && includeRoot))
|
||||||
|
res.push(currNode);
|
||||||
|
currNode = nodeMap.get(currNode.parent ?? -1);
|
||||||
|
}
|
||||||
|
res.sort((a, b) => a.timestamp - b.timestamp);
|
||||||
|
return res;
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* create a new conversation with a default root node
|
||||||
|
*/
|
||||||
|
async createConversation(name: string): Promise<Conversation> {
|
||||||
|
const now = Date.now();
|
||||||
|
const msgId = now;
|
||||||
|
const conv: Conversation = {
|
||||||
|
id: `conv-${now}`,
|
||||||
|
lastModified: now,
|
||||||
|
currNode: msgId,
|
||||||
|
name,
|
||||||
};
|
};
|
||||||
conv.messages.push(msg);
|
await db.conversations.add(conv);
|
||||||
conv.lastModified = Date.now();
|
// create a root node
|
||||||
localStorage.setItem(convId, JSON.stringify(conv));
|
await db.messages.add({
|
||||||
dispatchConversationChange(convId);
|
id: msgId,
|
||||||
|
convId: conv.id,
|
||||||
|
type: 'root',
|
||||||
|
timestamp: now,
|
||||||
|
role: 'system',
|
||||||
|
content: '',
|
||||||
|
parent: -1,
|
||||||
|
children: [],
|
||||||
|
});
|
||||||
|
return conv;
|
||||||
},
|
},
|
||||||
/**
|
/**
|
||||||
* Get new conversation id
|
* if convId does not exist, throw an error
|
||||||
*/
|
*/
|
||||||
getNewConvId(): string {
|
async appendMsg(
|
||||||
return `conv-${Date.now()}`;
|
msg: Exclude<Message, 'parent' | 'children'>,
|
||||||
|
parentNodeId: Message['id']
|
||||||
|
): Promise<void> {
|
||||||
|
if (msg.content === null) return;
|
||||||
|
const { convId } = msg;
|
||||||
|
await db.transaction('rw', db.conversations, db.messages, async () => {
|
||||||
|
const conv = await StorageUtils.getOneConversation(convId);
|
||||||
|
const parentMsg = await db.messages
|
||||||
|
.where({ convId, id: parentNodeId })
|
||||||
|
.first();
|
||||||
|
// update the currNode of conversation
|
||||||
|
if (!conv) {
|
||||||
|
throw new Error(`Conversation ${convId} does not exist`);
|
||||||
|
}
|
||||||
|
if (!parentMsg) {
|
||||||
|
throw new Error(
|
||||||
|
`Parent message ID ${parentNodeId} does not exist in conversation ${convId}`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
await db.conversations.update(convId, {
|
||||||
|
lastModified: Date.now(),
|
||||||
|
currNode: msg.id,
|
||||||
|
});
|
||||||
|
// update parent
|
||||||
|
await db.messages.update(parentNodeId, {
|
||||||
|
children: [...parentMsg.children, msg.id],
|
||||||
|
});
|
||||||
|
// create message
|
||||||
|
await db.messages.add({
|
||||||
|
...msg,
|
||||||
|
parent: parentNodeId,
|
||||||
|
children: [],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
dispatchConversationChange(convId);
|
||||||
},
|
},
|
||||||
/**
|
/**
|
||||||
* remove conversation by id
|
* remove conversation by id
|
||||||
*/
|
*/
|
||||||
remove(convId: string): void {
|
async remove(convId: string): Promise<void> {
|
||||||
localStorage.removeItem(convId);
|
await db.transaction('rw', db.conversations, db.messages, async () => {
|
||||||
|
await db.conversations.delete(convId);
|
||||||
|
await db.messages.where({ convId }).delete();
|
||||||
|
});
|
||||||
dispatchConversationChange(convId);
|
dispatchConversationChange(convId);
|
||||||
},
|
},
|
||||||
/**
|
|
||||||
* remove all conversations
|
|
||||||
*/
|
|
||||||
filterAndKeepMsgs(
|
|
||||||
convId: string,
|
|
||||||
predicate: (msg: Message) => boolean
|
|
||||||
): void {
|
|
||||||
const conv = StorageUtils.getOneConversation(convId);
|
|
||||||
if (!conv) return;
|
|
||||||
conv.messages = conv.messages.filter(predicate);
|
|
||||||
conv.lastModified = Date.now();
|
|
||||||
localStorage.setItem(convId, JSON.stringify(conv));
|
|
||||||
dispatchConversationChange(convId);
|
|
||||||
},
|
|
||||||
/**
|
|
||||||
* remove last message from conversation
|
|
||||||
*/
|
|
||||||
popMsg(convId: string): Message | undefined {
|
|
||||||
const conv = StorageUtils.getOneConversation(convId);
|
|
||||||
if (!conv) return;
|
|
||||||
const msg = conv.messages.pop();
|
|
||||||
conv.lastModified = Date.now();
|
|
||||||
if (conv.messages.length === 0) {
|
|
||||||
StorageUtils.remove(convId);
|
|
||||||
} else {
|
|
||||||
localStorage.setItem(convId, JSON.stringify(conv));
|
|
||||||
}
|
|
||||||
dispatchConversationChange(convId);
|
|
||||||
return msg;
|
|
||||||
},
|
|
||||||
|
|
||||||
// event listeners
|
// event listeners
|
||||||
onConversationChanged(callback: CallbackConversationChanged) {
|
onConversationChanged(callback: CallbackConversationChanged) {
|
||||||
|
@ -136,3 +206,79 @@ const StorageUtils = {
|
||||||
};
|
};
|
||||||
|
|
||||||
export default StorageUtils;
|
export default StorageUtils;
|
||||||
|
|
||||||
|
// Migration from localStorage to IndexedDB
|
||||||
|
|
||||||
|
// these are old types, LS prefix stands for LocalStorage
|
||||||
|
interface LSConversation {
|
||||||
|
id: string; // format: `conv-{timestamp}`
|
||||||
|
lastModified: number; // timestamp from Date.now()
|
||||||
|
messages: LSMessage[];
|
||||||
|
}
|
||||||
|
interface LSMessage {
|
||||||
|
id: number;
|
||||||
|
role: 'user' | 'assistant' | 'system';
|
||||||
|
content: string;
|
||||||
|
timings?: TimingReport;
|
||||||
|
}
|
||||||
|
async function migrationLStoIDB() {
|
||||||
|
if (localStorage.getItem('migratedToIDB')) return;
|
||||||
|
const res: LSConversation[] = [];
|
||||||
|
for (const key in localStorage) {
|
||||||
|
if (key.startsWith('conv-')) {
|
||||||
|
res.push(JSON.parse(localStorage.getItem(key) ?? '{}'));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (res.length === 0) return;
|
||||||
|
await db.transaction('rw', db.conversations, db.messages, async () => {
|
||||||
|
let migratedCount = 0;
|
||||||
|
for (const conv of res) {
|
||||||
|
const { id: convId, lastModified, messages } = conv;
|
||||||
|
const firstMsg = messages[0];
|
||||||
|
const lastMsg = messages.at(-1);
|
||||||
|
if (messages.length < 2 || !firstMsg || !lastMsg) {
|
||||||
|
console.log(
|
||||||
|
`Skipping conversation ${convId} with ${messages.length} messages`
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const name = firstMsg.content ?? '(no messages)';
|
||||||
|
await db.conversations.add({
|
||||||
|
id: convId,
|
||||||
|
lastModified,
|
||||||
|
currNode: lastMsg.id,
|
||||||
|
name,
|
||||||
|
});
|
||||||
|
const rootId = messages[0].id - 2;
|
||||||
|
await db.messages.add({
|
||||||
|
id: rootId,
|
||||||
|
convId: convId,
|
||||||
|
type: 'root',
|
||||||
|
timestamp: rootId,
|
||||||
|
role: 'system',
|
||||||
|
content: '',
|
||||||
|
parent: -1,
|
||||||
|
children: [firstMsg.id],
|
||||||
|
});
|
||||||
|
for (let i = 0; i < messages.length; i++) {
|
||||||
|
const msg = messages[i];
|
||||||
|
await db.messages.add({
|
||||||
|
...msg,
|
||||||
|
type: 'text',
|
||||||
|
convId: convId,
|
||||||
|
timestamp: msg.id,
|
||||||
|
parent: i === 0 ? rootId : messages[i - 1].id,
|
||||||
|
children: i === messages.length - 1 ? [] : [messages[i + 1].id],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
migratedCount++;
|
||||||
|
console.log(
|
||||||
|
`Migrated conversation ${convId} with ${messages.length} messages`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
console.log(
|
||||||
|
`Migrated ${migratedCount} conversations from localStorage to IndexedDB`
|
||||||
|
);
|
||||||
|
localStorage.setItem('migratedToIDB', '1');
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
|
@ -5,11 +5,46 @@ export interface TimingReport {
|
||||||
predicted_ms: number;
|
predicted_ms: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* What is conversation "branching"? It is a feature that allows the user to edit an old message in the history, while still keeping the conversation flow.
|
||||||
|
* Inspired by ChatGPT / Claude / Hugging Chat where you edit a message, a new branch of the conversation is created, and the old message is still visible.
|
||||||
|
*
|
||||||
|
* We use the same node-based structure like other chat UIs, where each message has a parent and children. A "root" message is the first message in a conversation, which will not be displayed in the UI.
|
||||||
|
*
|
||||||
|
* root
|
||||||
|
* ├── message 1
|
||||||
|
* │ └── message 2
|
||||||
|
* │ └── message 3
|
||||||
|
* └── message 4
|
||||||
|
* └── message 5
|
||||||
|
*
|
||||||
|
* In the above example, assuming that user wants to edit message 2, a new branch will be created:
|
||||||
|
*
|
||||||
|
* ├── message 2
|
||||||
|
* │ └── message 3
|
||||||
|
* └── message 6
|
||||||
|
*
|
||||||
|
* Message 2 and 6 are siblings, and message 6 is the new branch.
|
||||||
|
*
|
||||||
|
* We only need to know the last node (aka leaf) to get the current branch. In the above example, message 5 is the leaf of branch containing message 4 and 5.
|
||||||
|
*
|
||||||
|
* For the implementation:
|
||||||
|
* - StorageUtils.getMessages() returns list of all nodes
|
||||||
|
* - StorageUtils.filterByLeafNodeId() filters the list of nodes from a given leaf node
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Note: the term "message" and "node" are used interchangeably in this context
|
||||||
export interface Message {
|
export interface Message {
|
||||||
id: number;
|
id: number;
|
||||||
|
convId: string;
|
||||||
|
type: 'text' | 'root';
|
||||||
|
timestamp: number; // timestamp from Date.now()
|
||||||
role: 'user' | 'assistant' | 'system';
|
role: 'user' | 'assistant' | 'system';
|
||||||
content: string;
|
content: string;
|
||||||
timings?: TimingReport;
|
timings?: TimingReport;
|
||||||
|
// node based system for branching
|
||||||
|
parent: Message['id'];
|
||||||
|
children: Message['id'][];
|
||||||
}
|
}
|
||||||
|
|
||||||
export type APIMessage = Pick<Message, 'role' | 'content'>;
|
export type APIMessage = Pick<Message, 'role' | 'content'>;
|
||||||
|
@ -17,7 +52,13 @@ export type APIMessage = Pick<Message, 'role' | 'content'>;
|
||||||
export interface Conversation {
|
export interface Conversation {
|
||||||
id: string; // format: `conv-{timestamp}`
|
id: string; // format: `conv-{timestamp}`
|
||||||
lastModified: number; // timestamp from Date.now()
|
lastModified: number; // timestamp from Date.now()
|
||||||
messages: Message[];
|
currNode: Message['id']; // the current message node being viewed
|
||||||
|
name: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ViewingChat {
|
||||||
|
conv: Readonly<Conversation>;
|
||||||
|
messages: Readonly<Message[]>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export type PendingMessage = Omit<Message, 'content'> & {
|
export type PendingMessage = Omit<Message, 'content'> & {
|
||||||
|
|
|
@ -10,8 +10,6 @@ extern "C" {
|
||||||
#define GGML_VK_NAME "Vulkan"
|
#define GGML_VK_NAME "Vulkan"
|
||||||
#define GGML_VK_MAX_DEVICES 16
|
#define GGML_VK_MAX_DEVICES 16
|
||||||
|
|
||||||
GGML_BACKEND_API void ggml_vk_instance_init(void);
|
|
||||||
|
|
||||||
// backend API
|
// backend API
|
||||||
GGML_BACKEND_API ggml_backend_t ggml_backend_vk_init(size_t dev_num);
|
GGML_BACKEND_API ggml_backend_t ggml_backend_vk_init(size_t dev_num);
|
||||||
|
|
||||||
|
|
|
@ -473,7 +473,6 @@ GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)
|
||||||
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
|
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
|
||||||
GGML_TABLE_END()
|
GGML_TABLE_END()
|
||||||
|
|
||||||
//#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics
|
|
||||||
GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
|
GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
|
||||||
0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
|
0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
|
||||||
0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
|
0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
|
||||||
|
@ -508,7 +507,6 @@ GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
|
||||||
0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
|
0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
|
||||||
0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
|
0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
|
||||||
GGML_TABLE_END()
|
GGML_TABLE_END()
|
||||||
//#endif
|
|
||||||
|
|
||||||
|
|
||||||
GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256)
|
GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256)
|
||||||
|
|
|
@ -284,14 +284,14 @@ struct ggml_backend_cpu_device_context {
|
||||||
&hKey) == ERROR_SUCCESS) {
|
&hKey) == ERROR_SUCCESS) {
|
||||||
DWORD cpu_brand_size = 0;
|
DWORD cpu_brand_size = 0;
|
||||||
if (RegQueryValueExA(hKey,
|
if (RegQueryValueExA(hKey,
|
||||||
TEXT("ProcessorNameString"),
|
"ProcessorNameString",
|
||||||
NULL,
|
NULL,
|
||||||
NULL,
|
NULL,
|
||||||
NULL,
|
NULL,
|
||||||
&cpu_brand_size) == ERROR_SUCCESS) {
|
&cpu_brand_size) == ERROR_SUCCESS) {
|
||||||
description.resize(cpu_brand_size);
|
description.resize(cpu_brand_size);
|
||||||
if (RegQueryValueExA(hKey,
|
if (RegQueryValueExA(hKey,
|
||||||
TEXT("ProcessorNameString"),
|
"ProcessorNameString",
|
||||||
NULL,
|
NULL,
|
||||||
NULL,
|
NULL,
|
||||||
(LPBYTE)&description[0], // NOLINT
|
(LPBYTE)&description[0], // NOLINT
|
||||||
|
@ -534,9 +534,6 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
|
||||||
if (ggml_cpu_has_dotprod()) {
|
if (ggml_cpu_has_dotprod()) {
|
||||||
features.push_back({ "DOTPROD", "1" });
|
features.push_back({ "DOTPROD", "1" });
|
||||||
}
|
}
|
||||||
if (ggml_cpu_has_matmul_int8()) {
|
|
||||||
features.push_back({ "MATMUL_INT8", "1" });
|
|
||||||
}
|
|
||||||
if (ggml_cpu_get_sve_cnt() > 0) {
|
if (ggml_cpu_get_sve_cnt() > 0) {
|
||||||
static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
|
static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
|
||||||
features.push_back({ "SVE_CNT", sve_cnt.c_str() });
|
features.push_back({ "SVE_CNT", sve_cnt.c_str() });
|
||||||
|
|
|
@ -71,6 +71,47 @@
|
||||||
#define GGML_CUDA_CC_QY1 210
|
#define GGML_CUDA_CC_QY1 210
|
||||||
#define GGML_CUDA_CC_QY2 220
|
#define GGML_CUDA_CC_QY2 220
|
||||||
|
|
||||||
|
#ifdef __CUDA_ARCH_LIST__
|
||||||
|
constexpr bool ggml_cuda_has_arch_impl(int) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class ... Archs>
|
||||||
|
constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) {
|
||||||
|
return arch == first || ggml_cuda_has_arch_impl(arch, rest...);
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr bool ggml_cuda_has_arch(const int arch) {
|
||||||
|
return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) {
|
||||||
|
if (cur == 0) {
|
||||||
|
GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch);
|
||||||
|
}
|
||||||
|
return cur;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class ... Archs>
|
||||||
|
constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) {
|
||||||
|
if (first <= arch && first > cur) {
|
||||||
|
return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...);
|
||||||
|
} else {
|
||||||
|
return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr int ggml_cuda_highest_compiled_arch(const int arch) {
|
||||||
|
return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
static int ggml_cuda_highest_compiled_arch(const int arch) {
|
||||||
|
return arch;
|
||||||
|
}
|
||||||
|
#endif // __CUDA_ARCH_LIST__
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
||||||
|
|
||||||
#if defined(_MSC_VER)
|
#if defined(_MSC_VER)
|
||||||
|
@ -124,11 +165,11 @@ static const char * cu_get_error_str(CUresult err) {
|
||||||
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
|
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
|
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
|
||||||
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
|
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
|
||||||
#else
|
#else
|
||||||
#define GGML_CUDA_ASSUME(x)
|
#define GGML_CUDA_ASSUME(x)
|
||||||
#endif // CUDART_VERSION >= 11100
|
#endif // CUDART_VERSION >= 11010
|
||||||
|
|
||||||
#ifdef GGML_CUDA_F16
|
#ifdef GGML_CUDA_F16
|
||||||
typedef half dfloat; // dequantize float
|
typedef half dfloat; // dequantize float
|
||||||
|
@ -162,18 +203,32 @@ typedef float2 dfloat2;
|
||||||
#define FLASH_ATTN_AVAILABLE
|
#define FLASH_ATTN_AVAILABLE
|
||||||
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||||
|
|
||||||
static constexpr bool fast_fp16_available(const int cc) {
|
static bool fp16_available(const int cc) {
|
||||||
|
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool fast_fp16_available(const int cc) {
|
||||||
|
return fp16_available(cc) && cc != 610;
|
||||||
|
}
|
||||||
|
|
||||||
|
// To be used for feature selection of external libraries, e.g. cuBLAS.
|
||||||
|
static bool fast_fp16_hardware_available(const int cc) {
|
||||||
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
|
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Any FP16 tensor cores are available.
|
// Any FP16 tensor core instructions are available for ggml code.
|
||||||
static constexpr bool fp16_mma_available(const int cc) {
|
static bool fp16_mma_available(const int cc) {
|
||||||
|
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
|
||||||
|
}
|
||||||
|
|
||||||
|
// To be used for feature selection of external libraries, e.g. cuBLAS.
|
||||||
|
static bool fp16_mma_hardware_available(const int cc) {
|
||||||
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
|
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
|
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
|
||||||
static constexpr bool new_mma_available(const int cc) {
|
static bool new_mma_available(const int cc) {
|
||||||
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
|
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
|
||||||
}
|
}
|
||||||
|
|
||||||
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
|
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
|
||||||
|
|
|
@ -599,7 +599,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
||||||
case GGML_TYPE_Q5_1:
|
case GGML_TYPE_Q5_1:
|
||||||
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
|
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_PASCAL) {
|
if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
|
||||||
return dequantize_block_q8_0_f16_cuda;
|
return dequantize_block_q8_0_f16_cuda;
|
||||||
}
|
}
|
||||||
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
|
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
|
||||||
|
|
|
@ -1868,14 +1868,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||||||
|
|
||||||
const int cc = ggml_cuda_info().devices[id].cc;
|
const int cc = ggml_cuda_info().devices[id].cc;
|
||||||
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
|
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
|
||||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
|
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
|
||||||
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
|
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
const int cc = ggml_cuda_info().devices[ctx.device].cc;
|
const int cc = ggml_cuda_info().devices[ctx.device].cc;
|
||||||
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
|
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
|
||||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
|
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
|
||||||
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
|
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
|
||||||
}
|
}
|
||||||
|
|
||||||
// debug helpers
|
// debug helpers
|
||||||
|
@ -2845,7 +2845,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
|
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
|
||||||
cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
|
cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
|
||||||
if (err != cudaSuccess) {
|
if (err != cudaSuccess) {
|
||||||
// clear the error
|
// clear the error
|
||||||
|
@ -2857,8 +2857,10 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
#else
|
#else
|
||||||
|
GGML_UNUSED(buffer);
|
||||||
|
GGML_UNUSED(size);
|
||||||
return false;
|
return false;
|
||||||
#endif
|
#endif // CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
|
void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
|
||||||
|
@ -3210,8 +3212,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
|
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
|
return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
|
||||||
return cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
||||||
}
|
}
|
||||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||||
|
|
|
@ -18,7 +18,7 @@ void ggml_cuda_op_mul_mat_q(
|
||||||
const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
|
const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
|
||||||
|
|
||||||
int id = ggml_cuda_get_device();
|
int id = ggml_cuda_get_device();
|
||||||
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
const int cc = ggml_cuda_info().devices[id].cc;
|
||||||
|
|
||||||
// the main device has a larger memory buffer to hold the results from all GPUs
|
// the main device has a larger memory buffer to hold the results from all GPUs
|
||||||
// nrows_dst == nrows of the matrix that the kernel writes into
|
// nrows_dst == nrows of the matrix that the kernel writes into
|
||||||
|
@ -27,7 +27,8 @@ void ggml_cuda_op_mul_mat_q(
|
||||||
// The stream-k decomposition is only faster for recent NVIDIA GPUs.
|
// The stream-k decomposition is only faster for recent NVIDIA GPUs.
|
||||||
// Also its fixup needs to allocate a temporary buffer in the memory pool.
|
// Also its fixup needs to allocate a temporary buffer in the memory pool.
|
||||||
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
|
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
|
||||||
const bool use_stream_k = compute_capability >= GGML_CUDA_CC_VOLTA && compute_capability < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
|
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA &&
|
||||||
|
cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
|
||||||
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
|
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
|
||||||
|
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
|
@ -138,7 +139,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc < GGML_CUDA_CC_DP4A) {
|
if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,7 +148,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||||
#endif //GGML_CUDA_FORCE_MMQ
|
#endif //GGML_CUDA_FORCE_MMQ
|
||||||
|
|
||||||
if (cc < GGML_CUDA_CC_OFFSET_AMD) {
|
if (cc < GGML_CUDA_CC_OFFSET_AMD) {
|
||||||
return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc) && !GGML_CUDA_CC_IS_GCN(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc) && !GGML_CUDA_CC_IS_GCN(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||||
|
|
|
@ -87,12 +87,13 @@ struct tile_x_sizes {
|
||||||
int sc;
|
int sc;
|
||||||
};
|
};
|
||||||
|
|
||||||
static constexpr int get_mmq_x_max_host(const int cc) {
|
static int get_mmq_x_max_host(const int cc) {
|
||||||
return new_mma_available(cc) ? 128 :
|
return new_mma_available(cc) ? 128 :
|
||||||
|
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ?
|
||||||
#ifdef GGML_CUDA_FORCE_MMQ
|
#ifdef GGML_CUDA_FORCE_MMQ
|
||||||
cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128 : 64;
|
128 : 64;
|
||||||
#else
|
#else
|
||||||
cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
|
MMQ_DP4A_MAX_BATCH_SIZE : 64;
|
||||||
#endif // GGML_CUDA_FORCE_MMQ
|
#endif // GGML_CUDA_FORCE_MMQ
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -120,8 +121,9 @@ static constexpr __device__ int get_mmq_x_max_device() {
|
||||||
#endif // NEW_MMA_AVAILABLE
|
#endif // NEW_MMA_AVAILABLE
|
||||||
}
|
}
|
||||||
|
|
||||||
static constexpr int get_mmq_y_host(const int cc) {
|
static int get_mmq_y_host(const int cc) {
|
||||||
return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64);
|
return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
|
||||||
|
(ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64);
|
||||||
}
|
}
|
||||||
|
|
||||||
static constexpr __device__ int get_mmq_y_device() {
|
static constexpr __device__ int get_mmq_y_device() {
|
||||||
|
@ -2829,7 +2831,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
|
||||||
const int mmq_x_max = get_mmq_x_max_host(cc);
|
const int mmq_x_max = get_mmq_x_max_host(cc);
|
||||||
const int mmq_y = get_mmq_y_host(cc);
|
const int mmq_y = get_mmq_y_host(cc);
|
||||||
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
|
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
|
||||||
const bool use_stream_k = cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
|
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
|
||||||
|
|
||||||
int mmq_x_best = 0;
|
int mmq_x_best = 0;
|
||||||
int nparts_best = INT_MAX;
|
int nparts_best = INT_MAX;
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
|
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
|
||||||
#define USE_CUB
|
#define USE_CUB
|
||||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
|
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
|
||||||
|
|
||||||
#ifdef USE_CUB
|
#ifdef USE_CUB
|
||||||
#include <cub/cub.cuh>
|
#include <cub/cub.cuh>
|
||||||
|
|
|
@ -171,6 +171,7 @@ struct vk_device_struct {
|
||||||
uint32_t subgroup_size;
|
uint32_t subgroup_size;
|
||||||
uint32_t shader_core_count;
|
uint32_t shader_core_count;
|
||||||
bool uma;
|
bool uma;
|
||||||
|
bool prefer_host_memory;
|
||||||
bool float_controls_rte_fp16;
|
bool float_controls_rte_fp16;
|
||||||
|
|
||||||
bool subgroup_size_control;
|
bool subgroup_size_control;
|
||||||
|
@ -188,12 +189,12 @@ struct vk_device_struct {
|
||||||
|
|
||||||
size_t idx;
|
size_t idx;
|
||||||
|
|
||||||
bool mul_mat_l;
|
bool mul_mat_l[GGML_TYPE_COUNT];
|
||||||
bool mul_mat_m;
|
bool mul_mat_m[GGML_TYPE_COUNT];
|
||||||
bool mul_mat_s;
|
bool mul_mat_s[GGML_TYPE_COUNT];
|
||||||
bool mul_mat_id_l;
|
bool mul_mat_id_l[GGML_TYPE_COUNT];
|
||||||
bool mul_mat_id_m;
|
bool mul_mat_id_m[GGML_TYPE_COUNT];
|
||||||
bool mul_mat_id_s;
|
bool mul_mat_id_s[GGML_TYPE_COUNT];
|
||||||
|
|
||||||
// set to true to indicate that some shaders need to be compiled after the dryrun
|
// set to true to indicate that some shaders need to be compiled after the dryrun
|
||||||
bool need_compiles {};
|
bool need_compiles {};
|
||||||
|
@ -1298,7 +1299,9 @@ static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk:
|
||||||
static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
|
static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
|
||||||
vk_buffer buf;
|
vk_buffer buf;
|
||||||
try {
|
try {
|
||||||
if (device->uma) {
|
if (device->prefer_host_memory) {
|
||||||
|
buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
||||||
|
} else if (device->uma) {
|
||||||
// Fall back to host memory type
|
// Fall back to host memory type
|
||||||
buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
|
buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
|
||||||
} else {
|
} else {
|
||||||
|
@ -1382,7 +1385,33 @@ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ
|
||||||
return {64, 64};
|
return {64, 64};
|
||||||
};
|
};
|
||||||
|
|
||||||
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id) {
|
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
|
||||||
|
|
||||||
|
uint32_t lut_size = 0;
|
||||||
|
switch (src0_type) {
|
||||||
|
case GGML_TYPE_IQ2_XXS:
|
||||||
|
lut_size = 8*256;
|
||||||
|
break;
|
||||||
|
case GGML_TYPE_IQ2_XS:
|
||||||
|
lut_size = 8*512;
|
||||||
|
break;
|
||||||
|
case GGML_TYPE_IQ2_S:
|
||||||
|
lut_size = 8*1024;
|
||||||
|
break;
|
||||||
|
case GGML_TYPE_IQ3_XXS:
|
||||||
|
lut_size = 4*256;
|
||||||
|
break;
|
||||||
|
case GGML_TYPE_IQ3_S:
|
||||||
|
lut_size = 4*512;
|
||||||
|
break;
|
||||||
|
case GGML_TYPE_IQ4_NL:
|
||||||
|
case GGML_TYPE_IQ4_XS:
|
||||||
|
lut_size = 4*16;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
// Needs to be kept up to date on shader changes
|
// Needs to be kept up to date on shader changes
|
||||||
const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
|
const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
|
||||||
const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
|
const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
|
||||||
|
@ -1392,7 +1421,13 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
||||||
const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0;
|
const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0;
|
||||||
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
|
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
|
||||||
|
|
||||||
return (load_bufs + mmid_row_ids + coopmat_stage) <= device->properties.limits.maxComputeSharedMemorySize;
|
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
|
||||||
|
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
||||||
|
|
||||||
|
VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
|
||||||
|
"mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", supported=" << supported);
|
||||||
|
|
||||||
|
return supported;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_vk_load_shaders(vk_device& device) {
|
static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
|
@ -1476,62 +1511,32 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
m_align = 64;
|
m_align = 64;
|
||||||
s_align = 32;
|
s_align = 32;
|
||||||
|
|
||||||
// Fallback to smaller sizes if there's not enough shared memory. Given the current shaders
|
for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
|
||||||
// and tile sizes, this should handle 16KB, 32KB, and 48KB+.
|
ggml_type t = (ggml_type)i;
|
||||||
// This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders.
|
|
||||||
// But the numbers happen to work out for 32KB shared memory size that when using the medium
|
|
||||||
// size there's enough room for everything, and we assert for this.
|
|
||||||
uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
|
|
||||||
if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
|
|
||||||
l_warptile = m_warptile;
|
|
||||||
l_wg_denoms = m_wg_denoms;
|
|
||||||
shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
|
|
||||||
GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
|
|
||||||
}
|
|
||||||
if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
|
|
||||||
// assert mul_mat_mat_id shaders will fit.
|
|
||||||
GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
|
|
||||||
}
|
|
||||||
|
|
||||||
shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
|
|
||||||
if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
|
|
||||||
if (device->properties.limits.maxComputeSharedMemorySize == 32768) {
|
|
||||||
l_warptile_mmq = m_warptile_mmq;
|
|
||||||
l_mmq_wg_denoms = m_mmq_wg_denoms;
|
|
||||||
} else {
|
|
||||||
l_warptile_mmq = s_warptile_mmq;
|
|
||||||
l_mmq_wg_denoms = s_mmq_wg_denoms;
|
|
||||||
}
|
|
||||||
shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
|
|
||||||
GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
|
|
||||||
}
|
|
||||||
if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
|
|
||||||
// assert mul_mat_mat_id shaders will fit.
|
|
||||||
GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
|
|
||||||
}
|
|
||||||
// Disable medium and large matrix multiplication if not enough shared memory is available
|
// Disable medium and large matrix multiplication if not enough shared memory is available
|
||||||
// Check mmq warptiles as the largest configuration
|
// Check mmq warptiles as the largest configuration
|
||||||
// Throw an error if not enough for any matrix multiplication is available
|
// Throw an error if not enough for any matrix multiplication is available
|
||||||
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) {
|
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) {
|
||||||
std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
|
std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
|
||||||
throw std::runtime_error("Shared memory size too small for matrix multiplication.");
|
throw std::runtime_error("Shared memory size too small for matrix multiplication.");
|
||||||
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) {
|
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) {
|
||||||
device->mul_mat_m = false;
|
device->mul_mat_m[i] = false;
|
||||||
device->mul_mat_l = false;
|
device->mul_mat_l[i] = false;
|
||||||
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) {
|
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) {
|
||||||
device->mul_mat_l = false;
|
device->mul_mat_l[i] = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disable mul_mat_id if not enough shared memory is available
|
// Disable mul_mat_id if not enough shared memory is available
|
||||||
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) {
|
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true, t)) {
|
||||||
device->mul_mat_id_s = false;
|
device->mul_mat_id_s[i] = false;
|
||||||
device->mul_mat_id_m = false;
|
device->mul_mat_id_m[i] = false;
|
||||||
device->mul_mat_id_l = false;
|
device->mul_mat_id_l[i] = false;
|
||||||
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) {
|
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true, t)) {
|
||||||
device->mul_mat_id_m = false;
|
device->mul_mat_id_m[i] = false;
|
||||||
device->mul_mat_id_l = false;
|
device->mul_mat_id_l[i] = false;
|
||||||
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) {
|
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true, t)) {
|
||||||
device->mul_mat_id_l = false;
|
device->mul_mat_id_l[i] = false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1688,119 +1693,116 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||||
if (device->coopmat_support) {
|
if (device->coopmat_support) {
|
||||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||||
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||||
if (device->mul_mat ## ID ## _l) \
|
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
|
||||||
if (device->mul_mat ## ID ## _m) \
|
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
|
||||||
if (device->mul_mat ## ID ## _s) \
|
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
|
||||||
if (device->mul_mat ## ID ## _l) \
|
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
|
||||||
if (device->mul_mat ## ID ## _m) \
|
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
|
||||||
if (device->mul_mat ## ID ## _s) \
|
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
|
||||||
|
|
||||||
// Create 2 variants, {f16,f32} accumulator
|
// Create 2 variants, {f16,f32} accumulator
|
||||||
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||||
if (device->coopmat_acc_f16_support) { \
|
if (device->coopmat_acc_f16_support) { \
|
||||||
CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||||
} \
|
} \
|
||||||
if (device->coopmat_acc_f32_support) { \
|
if (device->coopmat_acc_f32_support) { \
|
||||||
CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
|
|
||||||
if (device->coopmat_acc_f16_support) {
|
if (device->coopmat_acc_f16_support) {
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
|
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
} else {
|
} else {
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
|
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
||||||
CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
||||||
|
|
||||||
if (device->coopmat_acc_f16_support) {
|
if (device->coopmat_acc_f16_support) {
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
|
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
} else {
|
} else {
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
|
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
#undef CREATE_MM2
|
#undef CREATE_MM2
|
||||||
#undef CREATE_MM
|
#undef CREATE_MM
|
||||||
|
@ -1808,141 +1810,135 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||||
if (device->fp16) {
|
if (device->fp16) {
|
||||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||||
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||||
if (device->mul_mat ## ID ## _l) \
|
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
||||||
if (device->mul_mat ## ID ## _m) \
|
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
||||||
if (device->mul_mat ## ID ## _s) \
|
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
||||||
if (device->mul_mat ## ID ## _l) \
|
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
|
||||||
if (device->mul_mat ## ID ## _m) \
|
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
|
||||||
if (device->mul_mat ## ID ## _s) \
|
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
||||||
|
|
||||||
// Create 2 variants, {f16,f32} accumulator
|
// Create 2 variants, {f16,f32} accumulator
|
||||||
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||||
CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||||
CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||||
|
|
||||||
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
|
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
|
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
|
|
||||||
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
||||||
CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
||||||
|
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
|
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
}
|
|
||||||
#undef CREATE_MM2
|
#undef CREATE_MM2
|
||||||
#undef CREATE_MM
|
#undef CREATE_MM
|
||||||
} else {
|
} else {
|
||||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||||
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||||
if (device->mul_mat ## ID ## _l) \
|
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
||||||
if (device->mul_mat ## ID ## _m) \
|
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
||||||
if (device->mul_mat ## ID ## _s) \
|
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
||||||
if (device->mul_mat ## ID ## _l) \
|
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
|
||||||
if (device->mul_mat ## ID ## _m) \
|
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
|
||||||
if (device->mul_mat ## ID ## _s) \
|
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
||||||
|
|
||||||
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
|
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
|
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
|
|
||||||
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
|
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
||||||
CREATE_MM(pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
||||||
|
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
|
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
}
|
|
||||||
#undef CREATE_MM
|
#undef CREATE_MM
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2210,6 +2206,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||||
device->physical_device = physical_devices[dev_num];
|
device->physical_device = physical_devices[dev_num];
|
||||||
const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
|
const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
|
||||||
|
|
||||||
|
const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
|
||||||
|
device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
|
||||||
|
|
||||||
bool fp16_storage = false;
|
bool fp16_storage = false;
|
||||||
bool fp16_compute = false;
|
bool fp16_compute = false;
|
||||||
bool maintenance4_support = false;
|
bool maintenance4_support = false;
|
||||||
|
@ -2631,35 +2630,37 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||||
|
|
||||||
// Shaders
|
// Shaders
|
||||||
// Disable matmul tile sizes early if performance low or not supported
|
// Disable matmul tile sizes early if performance low or not supported
|
||||||
|
for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
|
||||||
switch (device->vendor_id) {
|
switch (device->vendor_id) {
|
||||||
#ifndef GGML_VULKAN_RUN_TESTS
|
#ifndef GGML_VULKAN_RUN_TESTS
|
||||||
case VK_VENDOR_ID_AMD:
|
case VK_VENDOR_ID_AMD:
|
||||||
case VK_VENDOR_ID_INTEL:
|
case VK_VENDOR_ID_INTEL:
|
||||||
device->mul_mat_l = false;
|
device->mul_mat_l[i] = false;
|
||||||
device->mul_mat_m = true;
|
device->mul_mat_m[i] = true;
|
||||||
device->mul_mat_s = true;
|
device->mul_mat_s[i] = true;
|
||||||
device->mul_mat_id_l = false;
|
device->mul_mat_id_l[i] = false;
|
||||||
device->mul_mat_id_m = true;
|
device->mul_mat_id_m[i] = true;
|
||||||
device->mul_mat_id_s = true;
|
device->mul_mat_id_s[i] = true;
|
||||||
break;
|
break;
|
||||||
case VK_VENDOR_ID_APPLE:
|
case VK_VENDOR_ID_APPLE:
|
||||||
device->mul_mat_l = false;
|
device->mul_mat_l[i] = false;
|
||||||
device->mul_mat_m = true;
|
device->mul_mat_m[i] = true;
|
||||||
device->mul_mat_s = false;
|
device->mul_mat_s[i] = false;
|
||||||
device->mul_mat_id_l = false;
|
device->mul_mat_id_l[i] = false;
|
||||||
device->mul_mat_id_m = true;
|
device->mul_mat_id_m[i] = true;
|
||||||
device->mul_mat_id_s = false;
|
device->mul_mat_id_s[i] = false;
|
||||||
break;
|
break;
|
||||||
#endif
|
#endif
|
||||||
default:
|
default:
|
||||||
device->mul_mat_l = true;
|
device->mul_mat_l[i] = true;
|
||||||
device->mul_mat_m = true;
|
device->mul_mat_m[i] = true;
|
||||||
device->mul_mat_s = true;
|
device->mul_mat_s[i] = true;
|
||||||
device->mul_mat_id_l = true;
|
device->mul_mat_id_l[i] = true;
|
||||||
device->mul_mat_id_m = true;
|
device->mul_mat_id_m[i] = true;
|
||||||
device->mul_mat_id_s = true;
|
device->mul_mat_id_s[i] = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ggml_vk_load_shaders(device);
|
ggml_vk_load_shaders(device);
|
||||||
|
|
||||||
|
@ -2800,14 +2801,12 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
||||||
static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
|
static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
|
||||||
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
|
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
|
||||||
|
|
||||||
void ggml_vk_instance_init() {
|
static void ggml_vk_instance_init() {
|
||||||
if (vk_instance_initialized) {
|
if (vk_instance_initialized) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
VK_LOG_DEBUG("ggml_vk_instance_init()");
|
VK_LOG_DEBUG("ggml_vk_instance_init()");
|
||||||
|
|
||||||
vk_instance_initialized = true;
|
|
||||||
|
|
||||||
uint32_t api_version = vk::enumerateInstanceVersion();
|
uint32_t api_version = vk::enumerateInstanceVersion();
|
||||||
|
|
||||||
if (api_version < VK_API_VERSION_1_2) {
|
if (api_version < VK_API_VERSION_1_2) {
|
||||||
|
@ -2858,6 +2857,7 @@ void ggml_vk_instance_init() {
|
||||||
GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n");
|
GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n");
|
||||||
}
|
}
|
||||||
vk_instance.instance = vk::createInstance(instance_create_info);
|
vk_instance.instance = vk::createInstance(instance_create_info);
|
||||||
|
vk_instance_initialized = true;
|
||||||
|
|
||||||
size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
|
size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
|
||||||
|
|
||||||
|
@ -2882,7 +2882,7 @@ void ggml_vk_instance_init() {
|
||||||
// Make sure at least one device exists
|
// Make sure at least one device exists
|
||||||
if (devices.empty()) {
|
if (devices.empty()) {
|
||||||
std::cerr << "ggml_vulkan: Error: No devices found." << std::endl;
|
std::cerr << "ggml_vulkan: Error: No devices found." << std::endl;
|
||||||
GGML_ABORT("fatal error");
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default to using all dedicated GPUs
|
// Default to using all dedicated GPUs
|
||||||
|
@ -3764,31 +3764,31 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
|
||||||
return split_k;
|
return split_k;
|
||||||
}
|
}
|
||||||
|
|
||||||
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
|
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
|
||||||
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
|
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
|
||||||
|
|
||||||
if (ctx->device->coopmat2) {
|
if (ctx->device->coopmat2) {
|
||||||
if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) {
|
if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
|
||||||
return aligned ? mmp->a_l : mmp->l;
|
return aligned ? mmp->a_l : mmp->l;
|
||||||
}
|
}
|
||||||
if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) {
|
if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) {
|
||||||
return aligned ? mmp->a_m : mmp->m;
|
return aligned ? mmp->a_m : mmp->m;
|
||||||
}
|
}
|
||||||
return aligned ? mmp->a_s : mmp->s;
|
return aligned ? mmp->a_s : mmp->s;
|
||||||
}
|
}
|
||||||
|
|
||||||
if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) {
|
if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
|
||||||
return aligned ? mmp->a_s : mmp->s;
|
return aligned ? mmp->a_s : mmp->s;
|
||||||
}
|
}
|
||||||
if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) {
|
if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) {
|
||||||
return aligned ? mmp->a_m : mmp->m;
|
return aligned ? mmp->a_m : mmp->m;
|
||||||
}
|
}
|
||||||
return aligned ? mmp->a_l : mmp->l;
|
return aligned ? mmp->a_l : mmp->l;
|
||||||
}
|
}
|
||||||
|
|
||||||
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
|
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
|
||||||
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
|
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
|
||||||
return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align;
|
return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_vk_matmul(
|
static void ggml_vk_matmul(
|
||||||
|
@ -3815,31 +3815,31 @@ static void ggml_vk_matmul(
|
||||||
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
|
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
|
||||||
}
|
}
|
||||||
|
|
||||||
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
|
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
|
||||||
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
|
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
|
||||||
|
|
||||||
if (ctx->device->coopmat2) {
|
if (ctx->device->coopmat2) {
|
||||||
if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) {
|
if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
|
||||||
return aligned ? mmp->a_l : mmp->l;
|
return aligned ? mmp->a_l : mmp->l;
|
||||||
}
|
}
|
||||||
if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) {
|
if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) {
|
||||||
return aligned ? mmp->a_m : mmp->m;
|
return aligned ? mmp->a_m : mmp->m;
|
||||||
}
|
}
|
||||||
return aligned ? mmp->a_s : mmp->s;
|
return aligned ? mmp->a_s : mmp->s;
|
||||||
}
|
}
|
||||||
|
|
||||||
if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) {
|
if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) {
|
||||||
return aligned ? mmp->a_s : mmp->s;
|
return aligned ? mmp->a_s : mmp->s;
|
||||||
}
|
}
|
||||||
if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) {
|
if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) {
|
||||||
return aligned ? mmp->a_m : mmp->m;
|
return aligned ? mmp->a_m : mmp->m;
|
||||||
}
|
}
|
||||||
return aligned ? mmp->a_l : mmp->l;
|
return aligned ? mmp->a_l : mmp->l;
|
||||||
}
|
}
|
||||||
|
|
||||||
static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
|
static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
|
||||||
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
|
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
|
||||||
return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align;
|
return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_vk_matmul_id(
|
static void ggml_vk_matmul_id(
|
||||||
|
@ -4020,10 +4020,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||||
const int y_ne = ne11 * ne10;
|
const int y_ne = ne11 * ne10;
|
||||||
const int d_ne = ne11 * ne01;
|
const int d_ne = ne11 * ne01;
|
||||||
|
|
||||||
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
|
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
|
||||||
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
|
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
|
||||||
|
|
||||||
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
|
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
|
||||||
|
|
||||||
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
|
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
|
||||||
|
|
||||||
|
@ -4602,10 +4602,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||||
const uint64_t y_ne = ne11 * ne10;
|
const uint64_t y_ne = ne11 * ne10;
|
||||||
const uint64_t d_ne = ne21 * ne20;
|
const uint64_t d_ne = ne21 * ne20;
|
||||||
|
|
||||||
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1));
|
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
|
||||||
const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
|
const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
|
||||||
|
|
||||||
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned);
|
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
|
||||||
|
|
||||||
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
||||||
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
||||||
|
@ -8044,13 +8044,14 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
{
|
{
|
||||||
|
ggml_type src0_type = op->src[0]->type;
|
||||||
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||||
const vk_device& device = ggml_vk_get_device(ctx->device);
|
const vk_device& device = ggml_vk_get_device(ctx->device);
|
||||||
if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) {
|
if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
|
||||||
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
|
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
switch (op->src[0]->type) {
|
switch (src0_type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
|
@ -8356,8 +8357,13 @@ ggml_backend_reg_t ggml_backend_vk_reg() {
|
||||||
/* .iface = */ ggml_backend_vk_reg_i,
|
/* .iface = */ ggml_backend_vk_reg_i,
|
||||||
/* .context = */ nullptr,
|
/* .context = */ nullptr,
|
||||||
};
|
};
|
||||||
|
try {
|
||||||
|
ggml_vk_instance_init();
|
||||||
return ®
|
return ®
|
||||||
|
} catch (const vk::SystemError& e) {
|
||||||
|
VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extension availability
|
// Extension availability
|
||||||
|
|
|
@ -1392,7 +1392,7 @@ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tenso
|
||||||
(t0->nb[3] == t1->nb[3]);
|
(t0->nb[3] == t1->nb[3]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if t1 can be represented as a repeatition of t0
|
// check if t1 can be represented as a repetition of t0
|
||||||
bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
||||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||||
|
|
||||||
|
|
|
@ -116,7 +116,7 @@ struct llama_grammar {
|
||||||
llama_partial_utf8 partial_utf8;
|
llama_partial_utf8 partial_utf8;
|
||||||
|
|
||||||
// lazy grammars wait for trigger words or tokens before constraining the sampling.
|
// lazy grammars wait for trigger words or tokens before constraining the sampling.
|
||||||
// we still ahve trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
|
// we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
|
||||||
// (useful e.g. for tool_choice=required)
|
// (useful e.g. for tool_choice=required)
|
||||||
bool lazy = false;
|
bool lazy = false;
|
||||||
bool awaiting_trigger = false; // Initialized to true for lazy grammars only
|
bool awaiting_trigger = false; // Initialized to true for lazy grammars only
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue