Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-10 17:14:36 +00:00)

Commit a244b1ffd2: Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	Makefile
#	Package.swift
#	ci/run.sh
#	docs/backend/SYCL.md
#	examples/llama-bench/llama-bench.cpp
#	examples/server/CMakeLists.txt
#	examples/server/README.md
#	ggml/CMakeLists.txt
#	ggml/src/CMakeLists.txt
#	grammars/README.md
#	scripts/sync-ggml-am.sh
#	scripts/sync-ggml.last
#	scripts/sync-ggml.sh
#	tests/run-json-schema-to-grammar.mjs
#	tests/test-backend-ops.cpp

55 changed files with 31704 additions and 2916 deletions.
@@ -24,6 +24,16 @@ insert_final_newline = unset
 [examples/server/public/*]
 indent_size = 2
 
+[examples/server/public/deps_*]
+trim_trailing_whitespace = unset
+indent_style = unset
+indent_size = unset
+
+[examples/server/deps_*]
+trim_trailing_whitespace = unset
+indent_style = unset
+indent_size = unset
+
 [examples/llama.swiftui/llama.swiftui.xcodeproj/*]
 indent_style = tab
 
@@ -1005,6 +1005,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
     if (s == "f16") {
         return GGML_TYPE_F16;
     }
+    if (s == "bf16") {
+        return GGML_TYPE_BF16;
+    }
     if (s == "q8_0") {
         return GGML_TYPE_Q8_0;
     }
@@ -3748,10 +3748,7 @@ class JaisModel(Model):
 
         # Embeddings scale
         self.embeddings_scale = 1.0
-        # note: For some JAIS flavors, output is tied to (same as) wte in original model
-        self.output_is_wte = False
         if 'mup_embeddings_scale' in self.hparams:
-            self.output_is_wte = True # Hack (?)
             self.embeddings_scale = self.hparams['mup_embeddings_scale']
         elif 'embeddings_scale' in self.hparams:
             self.embeddings_scale = self.hparams['embeddings_scale']
@@ -3808,10 +3805,7 @@ class JaisModel(Model):
 
         if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
             tensors.append((new_name, data_torch * self.embeddings_scale))
-            if self.output_is_wte:
-                tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch * self.width_scale))
         elif new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
-            assert not self.output_is_wte
             tensors.append((new_name, data_torch * self.width_scale))
         else:
             tensors.append((new_name, data_torch))
@@ -1,7 +1,7 @@
 import * as readline from 'node:readline'
 import { stdin, stdout } from 'node:process'
 import { readFileSync } from 'node:fs'
-import { SchemaConverter } from './public/json-schema-to-grammar.mjs'
+import { SchemaConverter } from './public_legacy/json-schema-to-grammar.mjs'
 
 const args = process.argv.slice(2);
 const grammarJsonSchemaFile = args.find(
@@ -6,5 +6,20 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
 PUBLIC=$DIR/public
 
 echo "download js bundle files"
-curl https://npm.reversehttp.com/@preact/signals-core,@preact/signals,htm/preact,preact,preact/hooks > $PUBLIC/index.js
-echo >> $PUBLIC/index.js # add newline
+
+# Note for contributors: Always pin to a specific version "maj.min.patch" to avoid breaking the CI
+
+curl -L https://cdn.tailwindcss.com/3.4.14 > $PUBLIC/deps_tailwindcss.js
+echo >> $PUBLIC/deps_tailwindcss.js # add newline
+
+curl -L https://cdnjs.cloudflare.com/ajax/libs/daisyui/4.12.14/styled.min.css > $PUBLIC/deps_daisyui.min.css
+curl -L https://cdnjs.cloudflare.com/ajax/libs/daisyui/4.12.14/themes.min.css >> $PUBLIC/deps_daisyui.min.css
+echo >> $PUBLIC/deps_daisyui.min.css # add newline
+
+curl -L https://unpkg.com/vue@3.5.12/dist/vue.esm-browser.js > $PUBLIC/deps_vue.esm-browser.js
+echo >> $PUBLIC/deps_vue.esm-browser.js # add newline
+
+curl -L https://cdnjs.cloudflare.com/ajax/libs/markdown-it/13.0.2/markdown-it.js > $PUBLIC/deps_markdown-it.js
+echo >> $PUBLIC/deps_markdown-it.js # add newline
+
+ls -lah $PUBLIC
@@ -1,12 +1,16 @@
 const paramDefaults = {
   stream: true,
-  n_predict: 500,
   temperature: 0.2,
-  stop: ["</s>"]
 };
 
 let generation_settings = null;
 
+export class CompletionError extends Error {
+  constructor(message, name, data) {
+    super(message);
+    this.name = name;
+  }
+};
+
 // Completes the prompt as a generator. Recommended for most use cases.
 //
@@ -29,7 +33,7 @@ export async function* llama(prompt, params = {}, config = {}) {
 
   const completionParams = { ...paramDefaults, ...params, prompt };
 
-  const response = await fetch(`${api_url}/completion`, {
+  const response = await fetch(`${api_url}${config.endpoint || '/completion'}`, {
     method: 'POST',
     body: JSON.stringify(completionParams),
     headers: {
@@ -41,6 +45,18 @@ export async function* llama(prompt, params = {}, config = {}) {
     signal: controller.signal,
   });
 
+  const status = response.status;
+  if (status !== 200) {
+    try {
+      const body = await response.json();
+      if (body && body.error && body.error.message) {
+        throw new CompletionError(body.error.message, 'ServerError');
+      }
+    } catch (err) {
+      throw new CompletionError(err.message, 'ServerError');
+    }
+  }
+
   const reader = response.body.getReader();
   const decoder = new TextDecoder();
 
@@ -78,7 +94,12 @@ export async function* llama(prompt, params = {}, config = {}) {
       for (const line of lines) {
         const match = regex.exec(line);
         if (match) {
-          result[match[1]] = match[2]
+          result[match[1]] = match[2];
+          if (result.data === '[DONE]') {
+            cont = false;
+            break;
+          }
 
           // since we know this is llama.cpp, let's just decode the json in data
           if (result.data) {
             result.data = JSON.parse(result.data);
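The two caller-visible changes in this hunk are the configurable `endpoint` and the `CompletionError` raised on non-200 responses. A minimal consumption sketch (the server URL is illustrative, and this assumes a fetch-capable runtime such as a browser or Node 18+):

    import { llama, CompletionError } from './completion.js';

    try {
      // Stream tokens from the default '/completion' endpoint;
      // pass e.g. { endpoint: '/infill' } in config to target another route.
      const config = { api_url: 'http://localhost:8080' };
      for await (const chunk of llama('Tell me a joke', { n_predict: 64 }, config)) {
        console.log(chunk.data.content);
      }
    } catch (e) {
      if (e instanceof CompletionError) {
        console.error(`server error (${e.name}): ${e.message}`);
      } else {
        throw e;
      }
    }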
examples/server/public/deps_daisyui.min.css (new file, vendored, 13 lines): diff suppressed because one or more lines are too long
examples/server/public/deps_markdown-it.js (new file, 8442 lines): diff suppressed because it is too large
examples/server/public/deps_tailwindcss.js (new file, 82 lines): diff suppressed because one or more lines are too long
examples/server/public/deps_vue.esm-browser.js (new file, 18160 lines): diff suppressed because it is too large
(one further file, name not captured): diff suppressed because it is too large
examples/server/public_legacy/completion.js (new file, 209 lines):
@@ -0,0 +1,209 @@
+const paramDefaults = {
+  stream: true,
+  n_predict: 500,
+  temperature: 0.2,
+  stop: ["</s>"]
+};
+
+let generation_settings = null;
+
+
+// Completes the prompt as a generator. Recommended for most use cases.
+//
+// Example:
+//
+//    import { llama } from '/completion.js'
+//
+//    const request = llama("Tell me a joke", {n_predict: 800})
+//    for await (const chunk of request) {
+//      document.write(chunk.data.content)
+//    }
+//
+export async function* llama(prompt, params = {}, config = {}) {
+  let controller = config.controller;
+  const api_url = config.api_url?.replace(/\/+$/, '') || "";
+
+  if (!controller) {
+    controller = new AbortController();
+  }
+
+  const completionParams = { ...paramDefaults, ...params, prompt };
+
+  const response = await fetch(`${api_url}${config.endpoint || '/completion'}`, {
+    method: 'POST',
+    body: JSON.stringify(completionParams),
+    headers: {
+      'Connection': 'keep-alive',
+      'Content-Type': 'application/json',
+      'Accept': 'text/event-stream',
+      ...(params.api_key ? {'Authorization': `Bearer ${params.api_key}`} : {})
+    },
+    signal: controller.signal,
+  });
+
+  const reader = response.body.getReader();
+  const decoder = new TextDecoder();
+
+  let content = "";
+  let leftover = ""; // Buffer for partially read lines
+
+  try {
+    let cont = true;
+
+    while (cont) {
+      const result = await reader.read();
+      if (result.done) {
+        break;
+      }
+
+      // Add any leftover data to the current chunk of data
+      const text = leftover + decoder.decode(result.value);
+
+      // Check if the last character is a line break
+      const endsWithLineBreak = text.endsWith('\n');
+
+      // Split the text into lines
+      let lines = text.split('\n');
+
+      // If the text doesn't end with a line break, then the last line is incomplete
+      // Store it in leftover to be added to the next chunk of data
+      if (!endsWithLineBreak) {
+        leftover = lines.pop();
+      } else {
+        leftover = ""; // Reset leftover if we have a line break at the end
+      }
+
+      // Parse all sse events and add them to result
+      const regex = /^(\S+):\s(.*)$/gm;
+      for (const line of lines) {
+        const match = regex.exec(line);
+        if (match) {
+          result[match[1]] = match[2];
+          if (result.data === '[DONE]') {
+            cont = false;
+            break;
+          }
+
+          // since we know this is llama.cpp, let's just decode the json in data
+          if (result.data) {
+            result.data = JSON.parse(result.data);
+            content += result.data.content;
+
+            // yield
+            yield result;
+
+            // if we got a stop token from server, we will break here
+            if (result.data.stop) {
+              if (result.data.generation_settings) {
+                generation_settings = result.data.generation_settings;
+              }
+              cont = false;
+              break;
+            }
+          }
+          if (result.error) {
+            try {
+              result.error = JSON.parse(result.error);
+              if (result.error.message.includes('slot unavailable')) {
+                // Throw an error to be caught by upstream callers
+                throw new Error('slot unavailable');
+              } else {
+                console.error(`llama.cpp error [${result.error.code} - ${result.error.type}]: ${result.error.message}`);
+              }
+            } catch(e) {
+              console.error(`llama.cpp error ${result.error}`)
+            }
+          }
+        }
+      }
+    }
+  } catch (e) {
+    if (e.name !== 'AbortError') {
+      console.error("llama error: ", e);
+    }
+    throw e;
+  }
+  finally {
+    controller.abort();
+  }
+
+  return content;
+}
+
+// Call llama, return an event target that you can subscribe to
+//
+// Example:
+//
+//    import { llamaEventTarget } from '/completion.js'
+//
+//    const conn = llamaEventTarget(prompt)
+//    conn.addEventListener("message", (chunk) => {
+//      document.write(chunk.detail.content)
+//    })
+//
+export const llamaEventTarget = (prompt, params = {}, config = {}) => {
+  const eventTarget = new EventTarget();
+  (async () => {
+    let content = "";
+    for await (const chunk of llama(prompt, params, config)) {
+      if (chunk.data) {
+        content += chunk.data.content;
+        eventTarget.dispatchEvent(new CustomEvent("message", { detail: chunk.data }));
+      }
+      if (chunk.data.generation_settings) {
+        eventTarget.dispatchEvent(new CustomEvent("generation_settings", { detail: chunk.data.generation_settings }));
+      }
+      if (chunk.data.timings) {
+        eventTarget.dispatchEvent(new CustomEvent("timings", { detail: chunk.data.timings }));
+      }
+    }
+    eventTarget.dispatchEvent(new CustomEvent("done", { detail: { content } }));
+  })();
+  return eventTarget;
+}
+
+// Call llama, return a promise that resolves to the completed text. This does not support streaming
+//
+// Example:
+//
+//     llamaPromise(prompt).then((content) => {
+//       document.write(content)
+//     })
+//
+//     or
+//
+//     const content = await llamaPromise(prompt)
+//     document.write(content)
+//
+export const llamaPromise = (prompt, params = {}, config = {}) => {
+  return new Promise(async (resolve, reject) => {
+    let content = "";
+    try {
+      for await (const chunk of llama(prompt, params, config)) {
+        content += chunk.data.content;
+      }
+      resolve(content);
+    } catch (error) {
+      reject(error);
+    }
+  });
+};
+
+/**
+ * (deprecated)
+ */
+export const llamaComplete = async (params, controller, callback) => {
+  for await (const chunk of llama(params.prompt, params, { controller })) {
+    callback(chunk);
+  }
+}
+
+// Get the model info from the server. This is useful for getting the context window and so on.
+export const llamaModelInfo = async (config = {}) => {
+  if (!generation_settings) {
+    const api_url = config.api_url?.replace(/\/+$/, '') || "";
+    const props = await fetch(`${api_url}/props`).then(r => r.json());
+    generation_settings = props.default_generation_settings;
+  }
+  return generation_settings;
+}
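A note on the SSE parsing above: each server-sent event arrives as `field: value` lines, and the regex copies every field onto the read result before the `data` field is JSON-decoded. A self-contained sketch of that one step, with an illustrative payload:

    const line = 'data: {"content":" world","stop":false}';
    const match = /^(\S+):\s(.*)$/.exec(line); // one-shot here; the file reuses a stateful /gm regex
    const result = {};
    if (match) {
      result[match[1]] = match[2];        // result.data is still the raw JSON string
      result.data = JSON.parse(result.data);
      console.log(result.data.content);   // " world"
    }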
(binary image changed: 4 KiB before, 4 KiB after)
examples/server/public_legacy/index.html (new file, 1303 lines): diff suppressed because it is too large
examples/server/public_legacy/loading.html (new file, 12 lines):
@@ -0,0 +1,12 @@
+<!DOCTYPE html>
+<html>
+<head>
+    <meta http-equiv="refresh" content="5">
+</head>
+<body>
+    <div id="loading">
+        The model is loading. Please wait.<br/>
+        The user interface will appear soon.
+    </div>
+</body>
+</html>
@@ -15,22 +15,13 @@
 #define MIMETYPE_JSON "application/json; charset=utf-8"
 
 // auto generated files (update with ./deps.sh)
-#include "colorthemes.css.hpp"
-#include "style.css.hpp"
-#include "theme-beeninorder.css.hpp"
-#include "theme-ketivah.css.hpp"
-#include "theme-mangotango.css.hpp"
-#include "theme-playground.css.hpp"
-#include "theme-polarnight.css.hpp"
-#include "theme-snowstorm.css.hpp"
 #include "index.html.hpp"
-#include "index-new.html.hpp"
-#include "index.js.hpp"
 #include "completion.js.hpp"
-#include "system-prompts.js.hpp"
-#include "prompt-formats.js.hpp"
-#include "json-schema-to-grammar.mjs.hpp"
 #include "loading.html.hpp"
+#include "deps_daisyui.min.css.hpp"
+#include "deps_markdown-it.js.hpp"
+#include "deps_tailwindcss.js.hpp"
+#include "deps_vue.esm-browser.js.hpp"
 
 #include <atomic>
 #include <condition_variable>
@@ -2286,16 +2277,6 @@ int main(int argc, char ** argv) {
     std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
 
     svr->set_default_headers({{"Server", "llama.cpp"}});
 
-    // CORS preflight
-    svr->Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) {
-        // Access-Control-Allow-Origin is already set by middleware
-        res.set_header("Access-Control-Allow-Credentials", "true");
-        res.set_header("Access-Control-Allow-Methods",     "POST");
-        res.set_header("Access-Control-Allow-Headers",     "*");
-        return res.set_content("", "text/html"); // blank response, no data
-    });
-
     svr->set_logger(log_server_request);
 
     auto res_error = [](httplib::Response & res, const json & error_data) {
@@ -2408,6 +2389,14 @@ int main(int argc, char ** argv) {
     // register server middlewares
     svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+        // If this is OPTIONS request, skip validation because browsers don't include Authorization header
+        if (req.method == "OPTIONS") {
+            res.set_header("Access-Control-Allow-Credentials", "true");
+            res.set_header("Access-Control-Allow-Methods",     "GET, POST");
+            res.set_header("Access-Control-Allow-Headers",     "*");
+            res.set_content("", "text/html"); // blank response, no data
+            return httplib::Server::HandlerResponse::Handled; // skip further processing
+        }
         if (!middleware_server_state(req, res)) {
             return httplib::Server::HandlerResponse::Handled;
         }
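The early return above matters for browser clients: a cross-origin request that carries an `Authorization` header is preceded by an automatic `OPTIONS` preflight that does not carry that header, so running API-key validation on the preflight would break CORS. A sketch of the kind of call that triggers it (origin, port, and key are illustrative):

    // Sent from a page on a different origin than the llama.cpp server.
    const res = await fetch('http://localhost:8080/completion', {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        'Authorization': 'Bearer my-api-key', // absent from the automatic OPTIONS preflight
      },
      body: JSON.stringify({ prompt: 'Hello', n_predict: 8 }),
    });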
@@ -3117,33 +3106,19 @@ int main(int argc, char ** argv) {
     // register static assets routes
     if (!params.public_path.empty()) {
         // Set the base directory for serving static files
-        svr->set_base_dir(params.public_path);
-    }
-
-    if (!params.api_keys.empty()) {
-        // for now, if API key is set, web UI is unusable
-        svr->Get("/", [&](const httplib::Request &, httplib::Response & res) {
-            return res.set_content("Web UI is disabled because API key is set.", "text/html; charset=utf-8");
-        });
+        bool is_found = svr->set_mount_point("/", params.public_path);
+        if (!is_found) {
+            LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
+            return 1;
+        }
     } else {
         // using embedded static files
         svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
-        svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
         svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
-        svr->Get("/json-schema-to-grammar.mjs", handle_static_file(json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
-
-        // add new-ui files
-        svr->Get("/colorthemes.css", handle_static_file(colorthemes_css, colorthemes_css_len, "text/css; charset=utf-8"));
-        svr->Get("/style.css", handle_static_file(style_css, style_css_len, "text/css; charset=utf-8"));
-        svr->Get("/theme-beeninorder.css", handle_static_file(theme_beeninorder_css, theme_beeninorder_css_len, "text/css; charset=utf-8"));
-        svr->Get("/theme-ketivah.css", handle_static_file(theme_ketivah_css, theme_ketivah_css_len, "text/css; charset=utf-8"));
-        svr->Get("/theme-mangotango.css", handle_static_file(theme_mangotango_css, theme_mangotango_css_len, "text/css; charset=utf-8"));
-        svr->Get("/theme-playground.css", handle_static_file(theme_playground_css, theme_playground_css_len, "text/css; charset=utf-8"));
-        svr->Get("/theme-polarnight.css", handle_static_file(theme_polarnight_css, theme_polarnight_css_len, "text/css; charset=utf-8"));
-        svr->Get("/theme-snowstorm.css", handle_static_file(theme_snowstorm_css, theme_snowstorm_css_len, "text/css; charset=utf-8"));
-        svr->Get("/index-new.html", handle_static_file(index_new_html, index_new_html_len, "text/html; charset=utf-8"));
-        svr->Get("/system-prompts.js", handle_static_file(system_prompts_js, system_prompts_js_len, "text/javascript; charset=utf-8"));
-        svr->Get("/prompt-formats.js", handle_static_file(prompt_formats_js, prompt_formats_js_len, "text/javascript; charset=utf-8"));
+        svr->Get("/deps_daisyui.min.css", handle_static_file(deps_daisyui_min_css, deps_daisyui_min_css_len, "text/css; charset=utf-8"));
+        svr->Get("/deps_markdown-it.js", handle_static_file(deps_markdown_it_js, deps_markdown_it_js_len, "text/javascript; charset=utf-8"));
+        svr->Get("/deps_tailwindcss.js", handle_static_file(deps_tailwindcss_js, deps_tailwindcss_js_len, "text/javascript; charset=utf-8"));
+        svr->Get("/deps_vue.esm-browser.js", handle_static_file(deps_vue_esm_browser_js, deps_vue_esm_browser_js_len, "text/javascript; charset=utf-8"));
     }
 
     // register API routes
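With the embedded routes above, the new web UI and its pinned dependencies are served straight from compiled-in byte arrays, each with the MIME type listed in the hunk. A quick client-side check (host and port are illustrative):

    const res = await fetch('http://localhost:8080/deps_tailwindcss.js');
    console.log(res.status, res.headers.get('content-type'));
    // expected: 200 'text/javascript; charset=utf-8'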
@@ -64,5 +64,5 @@ Feature: Security
       | localhost       | Access-Control-Allow-Origin      | localhost       |
       | web.mydomain.fr | Access-Control-Allow-Origin      | web.mydomain.fr |
       | origin          | Access-Control-Allow-Credentials | true            |
-      | web.mydomain.fr | Access-Control-Allow-Methods     | POST            |
+      | web.mydomain.fr | Access-Control-Allow-Methods     | GET, POST       |
       | web.mydomain.fr | Access-Control-Allow-Headers     | *               |
@@ -515,7 +515,7 @@ extern "C" {
         GGML_OP_WIN_UNPART,
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
-        GGML_OP_RWKV_WKV,
+        GGML_OP_RWKV_WKV6,
 
         GGML_OP_UNARY,
 
@@ -1752,6 +1752,9 @@ extern "C" {
             struct ggml_tensor * a,
             enum ggml_prec       prec);
 
+    GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
+            const struct ggml_tensor * a);
+
     // TODO: needs to be adapted to ggml_flash_attn_ext
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
             struct ggml_context * ctx,
@@ -1825,7 +1828,7 @@ extern "C" {
             struct ggml_tensor * pw,
             struct ggml_tensor * ph);
 
-    GGML_API struct ggml_tensor * ggml_rwkv_wkv(
+    GGML_API struct ggml_tensor * ggml_rwkv_wkv6(
             struct ggml_context * ctx,
             struct ggml_tensor * k,
             struct ggml_tensor * v,
@@ -412,6 +412,8 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .gemm = ggml_gemm_q4_0_4x8_q8_0,
     },
     [GGML_TYPE_Q4_0_8_8] = {
+        .vec_dot      = NULL,
+        .vec_dot_type = GGML_TYPE_Q8_0,
         .nrows        = 1,
         .ncols        = 8,
         .gemv         = ggml_gemv_q4_0_8x8_q8_0,
@@ -11678,24 +11680,30 @@ static void ggml_compute_forward_add_rel_pos(
     }
 }
 
-// ggml_compute_forward_rwkv_wkv
+// ggml_compute_forward_rwkv_wkv6
 
-static void ggml_compute_forward_rwkv_wkv_f32(
+static void ggml_compute_forward_rwkv_wkv6_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const size_t T = dst->src[1]->ne[3];
-    const size_t C = dst->ne[0];
-    const size_t H = dst->src[1]->ne[2];
-    const size_t n_seqs = dst->src[5]->ne[1];
+    const int64_t T = dst->src[1]->ne[3];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t n_seqs = dst->src[5]->ne[1];
+    const int64_t head_size = C / HEADS;
 
     float * dst_data = (float *) dst->data;
     float * state = ((float *) dst->data) + C * T;
 
-    if (params->ith != 0) {
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
         return;
     }
 
-    memset(dst_data, 0, T * C * sizeof(float));
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                        (HEADS * (ith + 1)) / nth : HEADS;
 
     float * k = (float *) dst->src[0]->data;
     float * v = (float *) dst->src[1]->data;
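The rewritten prologue replaces the old single-thread early-out with a per-thread slice of the heads. In the code's integer arithmetic, thread $i$ of $n$ handles the half-open head range

$$h_{\mathrm{start}} = \left\lfloor \frac{H\,i}{n} \right\rfloor, \qquad h_{\mathrm{end}} = \min\!\left( \left\lfloor \frac{H\,(i+1)}{n} \right\rfloor,\; H \right)$$

with $H = \mathrm{HEADS}$, $i = \mathrm{ith}$, $n = \mathrm{nth}$, so the heads partition evenly across threads; the `memset` plus `ggml_barrier` in the next hunk keep the shared `dst` initialization ordered before any thread writes.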
@@ -11703,54 +11711,160 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     float * time_faaaa = (float *) dst->src[3]->data;
     float * time_decay = (float *) dst->src[4]->data;
 
-    size_t t_stride = H * (C / H);
+    size_t t_stride = HEADS * head_size; // Same to C
 
-    size_t h_stride = C / H;
-    size_t h_stride_2d = (C / H) * (C / H);
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
 
-    // basically fused operations:
-    // dst = r @ (time_faaaa * (k @ v) + state),
-    // state = time_decay * state + (k @ v),
-    // recursive through each token
-    for (size_t t = 0; t < T; t++) {
-        size_t t_offset = t * t_stride;
-        size_t state_offset = (C / H) * C * (t / (T / n_seqs));
-        float * state_cur = state + state_offset;
-        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
-
-        for (size_t h = 0; h < H; h++) {
-            size_t h_offset = h * h_stride;
-            size_t t_h_offset = t_offset + h_offset;
-            size_t h_2d_offset = h * h_stride_2d;
-
-            for (size_t i = 0; i < C / H; i++) {
-                size_t t_h_i_offset = t_h_offset + i;
-                size_t h_i_offset = h_offset + i;
-                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                float k_val = k[t_h_i_offset];
-                float r_val = r[t_h_i_offset];
-                float time_faaaa_val = time_faaaa[h_i_offset];
-                // RWKV v6: different time_decay for each token.
-                float time_decay_val = time_decay[t_h_i_offset];
-
-                for (size_t j = 0; j < C / H; j ++) {
-                    size_t t_h_j_offset = t_h_offset + j;
-                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                    float v_val = v[t_h_j_offset];
-                    float kv_val = v_val * k_val;
-                    float prev_state_val = state_prev[h_2d_i_j_offset];
-                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
-                    dst_data[t_h_j_offset] += temp_val * r_val;
-                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
-                }
-            }
-        }
-    }
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
+#if defined(__AVX__) && !defined(__AVX512F__)
+    #define GGML_F32X GGML_F32x8
+    #define GGML_F32X_SET1 GGML_F32x8_SET1
+    #define GGML_F32X_LOAD GGML_F32x8_LOAD
+    #define GGML_F32X_STORE GGML_F32x8_STORE
+    #define GGML_F32X_MUL GGML_F32x8_MUL
+    #define GGML_F32X_FMA GGML_F32x8_FMA
+    #define WKV_VECTOR_SIZE 8
+#elif defined(__AVX512F__)
+    #define GGML_F32X GGML_F32x16
+    #define GGML_F32X_SET1 GGML_F32x16_SET1
+    #define GGML_F32X_LOAD GGML_F32x16_LOAD
+    #define GGML_F32X_STORE GGML_F32x16_STORE
+    #define GGML_F32X_MUL GGML_F32x16_MUL
+    #define GGML_F32X_FMA GGML_F32x16_FMA
+    #define WKV_VECTOR_SIZE 16
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    #define GGML_F32X GGML_F32x4
+    #define GGML_F32X_SET1 GGML_F32x4_SET1
+    #define GGML_F32X_LOAD GGML_F32x4_LOAD
+    #define GGML_F32X_STORE GGML_F32x4_STORE
+    #define GGML_F32X_MUL GGML_F32x4_MUL
+    #define GGML_F32X_FMA GGML_F32x4_FMA
+    #define WKV_VECTOR_SIZE 4
+#endif
+
+#ifdef WKV_VECTOR_SIZE
+    const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
+
+    for (int64_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+        size_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+        for (int64_t h = h_start; h < h_end; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (int64_t i = 0; i < head_size; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_i_offset = h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float k_val = k[t_h_i_offset];
+                float r_val = r[t_h_i_offset];
+                float time_faaaa_val = time_faaaa[h_i_offset];
+                float time_decay_val = time_decay[t_h_i_offset];
+
+                // Broadcast scalar values to vectors
+                GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                GGML_F32X r_vec = GGML_F32X_SET1(r_val);
+                GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
+                GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
+
+                for (int64_t j = 0; j < vec_count; j++) {
+                    size_t base_j = j * WKV_VECTOR_SIZE;
+                    size_t t_h_j_offset = t_h_offset + base_j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                    // Load x elements at once
+                    GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                    GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                    GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                    // Compute kv = v * k
+                    GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                    // Compute temp = kv * time_faaaa + prev_state
+                    GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
+
+                    // Update dst: dst += temp * r
+                    dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
+                    GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                    // Update state: state = prev_state * time_decay + kv
+                    GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
+                    GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
+                }
+
+                // Handle remaining elements, this will not be used.
+                for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                    float v_val = v[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                    dst_data[t_h_j_offset] += temp_val * r_val;
+                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                }
+            }
+        }
+    }
+#else
+    // basically fused operations:
+    // dst = r @ (time_faaaa * (k @ v) + state),
+    // state = time_decay * state + (k @ v),
+    // recursive through each token
+    for (int64_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+        size_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+        for (int64_t h = h_start; h < h_end; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (int64_t i = 0; i < head_size; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_i_offset = h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float k_val = k[t_h_i_offset];
+                float r_val = r[t_h_i_offset];
+                float time_faaaa_val = time_faaaa[h_i_offset];
+                // RWKV v6: different time_decay for each token.
+                float time_decay_val = time_decay[t_h_i_offset];
+
+                for (int64_t j = 0; j < head_size; j++) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float v_val = v[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                    dst_data[t_h_j_offset] += temp_val * r_val;
+                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                }
+            }
+        }
+    }
+#endif
 }
 
-static void ggml_compute_forward_rwkv_wkv(
+static void ggml_compute_forward_rwkv_wkv6(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
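For reference, both the scalar fallback and the SIMD path implement the fused recurrence that the comments describe. Per token $t$, head $h$, and channel pair $(i, j)$, with $S$ the per-sequence state, $\mathit{tf}$ = `time_faaaa` and $\mathit{td}$ the per-token `time_decay`:

$$\begin{aligned} \mathrm{dst}[t,h,j] &\mathrel{+}= r[t,h,i]\,\bigl(\mathit{tf}[h,i]\,k[t,h,i]\,v[t,h,j] + S[h,i,j]\bigr) \\ S[h,i,j] &\leftarrow \mathit{td}[t,h,i]\,S[h,i,j] + k[t,h,i]\,v[t,h,j] \end{aligned}$$

The vector path walks $j$ in strides of `WKV_VECTOR_SIZE`, and the inline comments imply the convention that `GGML_F32X_FMA(a, b, c)` computes $a + b \cdot c$.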
@@ -11759,7 +11873,7 @@ static void ggml_compute_forward_rwkv_wkv(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rwkv_wkv_f32(params, dst);
+                ggml_compute_forward_rwkv_wkv6_f32(params, dst);
             } break;
         default:
            {
@@ -12511,9 +12625,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add_rel_pos(params, tensor);
             } break;
-        case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_WKV6:
             {
-                ggml_compute_forward_rwkv_wkv(params, tensor);
+                ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
         case GGML_OP_MAP_UNARY:
             {
|
||||||
case GGML_OP_WIN_PART:
|
case GGML_OP_WIN_PART:
|
||||||
case GGML_OP_WIN_UNPART:
|
case GGML_OP_WIN_UNPART:
|
||||||
case GGML_OP_GET_REL_POS:
|
case GGML_OP_GET_REL_POS:
|
||||||
case GGML_OP_RWKV_WKV:
|
case GGML_OP_RWKV_WKV6:
|
||||||
case GGML_OP_MAP_UNARY:
|
case GGML_OP_MAP_UNARY:
|
||||||
case GGML_OP_MAP_BINARY:
|
case GGML_OP_MAP_BINARY:
|
||||||
case GGML_OP_MAP_CUSTOM1_F32:
|
case GGML_OP_MAP_CUSTOM1_F32:
|
||||||
|
|
|
@@ -38,7 +38,7 @@ bool g_mul_mat_q = false;
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
-#include "ggml-cuda/rwkv-wkv.cuh"
+#include "ggml-cuda/wkv6.cuh"
 
 #include <algorithm>
 #include <array>
@@ -2324,8 +2324,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
-        case GGML_OP_RWKV_WKV:
-            ggml_cuda_op_rwkv_wkv(ctx, dst);
+        case GGML_OP_RWKV_WKV6:
+            ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
@@ -3158,12 +3158,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
-        case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_WKV6:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
 #endif
+            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
+                return false;
+            }
             if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
                 return true;
             }
@@ -13,9 +13,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
 
-    const int32_t precision = KQV->op_params[3];
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    if (precision != GGML_PREC_DEFAULT) {
+    if (prec != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
             constexpr int cols_per_block = 16;
             switch (Q->ne[0]) {
|
@ -301,11 +301,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||||
|
|
||||||
ggml_cuda_set_device(ctx.device);
|
ggml_cuda_set_device(ctx.device);
|
||||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||||
const int32_t precision = KQV->op_params[3];
|
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||||
|
|
||||||
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||||
if (cc >= CC_OFFSET_AMD) {
|
if (cc >= CC_OFFSET_AMD) {
|
||||||
if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||||
} else {
|
} else {
|
||||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||||
|
@ -332,7 +332,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
|
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
|
||||||
if (precision == GGML_PREC_DEFAULT) {
|
if (prec == GGML_PREC_DEFAULT) {
|
||||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||||
return;
|
return;
|
||||||
} else if(Q->ne[0] <= 128) {
|
} else if(Q->ne[0] <= 128) {
|
||||||
|
|
|
(deleted file; from the #include renames nearby this is ggml-cuda's rwkv-wkv.cuh)
@@ -1,5 +0,0 @@
-#include "common.cuh"
-
-#define CUDA_WKV_BLOCK_SIZE 64
-
-void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -1,5 +1,5 @@
 #include "common.cuh"
-#include "rwkv-wkv.cuh"
+#include "wkv6.cuh"
 
 static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
     const int tid = threadIdx.x;
@@ -64,7 +64,7 @@ static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const
     }
 }
 
-void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const float * k_d = (const float *)dst->src[0]->data;
     const float * v_d = (const float *)dst->src[1]->data;
     const float * r_d = (const float *)dst->src[2]->data;
@@ -83,7 +83,7 @@ void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
     GGML_ASSERT(C % H == 0);
-    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
+    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64
 
     rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
 }
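The unchanged launch line pins down the geometry that the strengthened assert documents: one block per (sequence, head) pair and one thread per channel within a head,

$$\text{blocks} = B \cdot H, \qquad \text{threads per block} = C / H = \texttt{CUDA\_WKV\_BLOCK\_SIZE} = 64,$$

which is why the kernel currently requires the RWKV6 head size of 64.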
ggml/src/ggml-cuda/wkv6.cuh (new file, 5 lines):
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_WKV_BLOCK_SIZE 64
+
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -36,16 +36,20 @@ static struct ggml_backend_metal_device_context {
     id<MTLDevice> mtl_device;
     int           mtl_device_ref_count;
 
-    bool support_simdgroup_reduction;
-    bool support_simdgroup_mm;
+    bool has_simdgroup_reduction;
+    bool has_simdgroup_mm;
+    bool has_bfloat;
+    bool use_bfloat;
 
     char name[128];
 } g_ggml_ctx_dev_main = {
     /*.mtl_device              =*/ nil,
     /*.mtl_device_ref_count    =*/ 0,
-    /*.support_simdgroup_reduction =*/ false,
-    /*.support_simdgroup_mm    =*/ false,
+    /*.has_simdgroup_reduction =*/ false,
+    /*.has_simdgroup_mm        =*/ false,
+    /*.has_bfloat              =*/ false,
+    /*.use_bfloat              =*/ false,
     /*.name                    =*/ "",
 };
 
 // acquire
@@ -55,10 +59,19 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
     if (ctx->mtl_device == nil) {
         ctx->mtl_device = MTLCreateSystemDefaultDevice();
 
-        ctx->support_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
-        ctx->support_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
+        ctx->has_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+        ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
 
-        ctx->support_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+        ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+
+        ctx->has_bfloat  = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
+        ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
+
+#if defined(GGML_METAL_USE_BF16)
+        ctx->use_bfloat = ctx->has_bfloat;
+#else
+        ctx->use_bfloat = false;
+#endif
 
         strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
     }
@@ -120,6 +133,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
@@ -146,10 +160,14 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
     GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
@@ -170,10 +188,11 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
-  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
   //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
   //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
+  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
@@ -195,6 +214,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
@@ -216,6 +236,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
@@ -256,6 +277,12 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@@ -287,12 +314,14 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
@@ -300,8 +329,11 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
     GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
@ -480,7 +512,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||||
// dictionary of preprocessor macros
|
// dictionary of preprocessor macros
|
||||||
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
||||||
|
|
||||||
MTLCompileOptions* options = [MTLCompileOptions new];
|
if (ctx_dev->use_bfloat) {
|
||||||
|
[prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"];
|
||||||
|
}
|
||||||
|
|
||||||
|
MTLCompileOptions * options = [MTLCompileOptions new];
|
||||||
options.preprocessorMacros = prep;
|
options.preprocessorMacros = prep;
|
||||||
|
|
||||||
//[options setFastMathEnabled:false];
|
//[options setFastMathEnabled:false];
|
||||||
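Note: the hunk above is where the host-side opt-in reaches the shader compiler, via the GGML_METAL_USE_BF16 preprocessor macro. A hedged C sketch of the resulting gating (the helper name is hypothetical; the field names come from the diff): bfloat kernels need both hardware support and a build that compiled them in, otherwise bf16 ops must fall back to another backend.

static bool metal_use_bfloat(bool has_bfloat /* device capability */) {
#if defined(GGML_METAL_USE_BF16)
    return has_bfloat;  // compiled in: use bf16 when the GPU supports it
#else
    return false;       // compiled out: bf16 ops are rejected in supports_op
#endif
}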
@@ -530,9 +566,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         }
     }
 
-    GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
-    GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
-    GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
+    GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false");
+    GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false");
+    GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
+    GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
+    GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
 
     ctx->capture_next_compute = false;
     ctx->capture_started = false;
@@ -568,6 +606,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
         id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
         kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
+        GGML_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+                (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+                (int) kernel->pipeline.threadExecutionWidth); \
         [metal_function release], \
         if (error) { \
             GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
@@ -578,8 +619,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
             GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
         }
 
-    const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
-    const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
+    const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm;
+    const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
+    const bool use_bfloat = ctx_dev->use_bfloat;
 
     // simd_sum and simd_max requires MTLGPUFamilyApple7
 
@@ -607,14 +649,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
@@ -635,101 +678,108 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, support_simdgroup_reduction);
-        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, support_simdgroup_reduction);
-        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, support_simdgroup_reduction);
-        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction);
+        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction);
+        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction);
+        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
@@ -745,58 +795,69 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
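The new cpy_f32_bf16 / cpy_bf16_f32 / cpy_bf16_bf16 kernels are plain format conversions: bf16 keeps the sign, exponent, and top 7 mantissa bits of an IEEE-754 float. A hedged C sketch of the scalar conversion these copy kernels correspond to (equivalent to ggml's CPU reference path; round-to-nearest-even on store, NaN handling omitted for brevity):

#include <stdint.h>
#include <string.h>

static inline float bf16_to_f32(uint16_t h) {
    uint32_t u = (uint32_t) h << 16; // bf16 payload is the high half of a float
    float f;
    memcpy(&f, &u, sizeof(f));
    return f;
}

static inline uint16_t f32_to_bf16(float f) {
    uint32_t u;
    memcpy(&u, &f, sizeof(u));
    u += 0x7FFF + ((u >> 16) & 1);   // round to nearest, ties to even
    return (uint16_t) (u >> 16);
}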
@@ -886,15 +947,18 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs
 }
 
 static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
-    for (size_t i = 0, n = 3; i < n; ++i) {
-        if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
-            return false;
+    const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm;
+    const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
+    const bool use_bfloat = ctx_dev->use_bfloat;
+
+    if (!use_bfloat) {
+        for (size_t i = 0, n = 3; i < n; ++i) {
+            if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
+                return false;
+            }
         }
     }
 
-    const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
-    const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
-
     switch (op->op) {
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(op)) {
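In effect, the blanket "reject any BF16 operand" rule is now conditional: it only applies when bfloat is disabled, so those ops fall back to another backend instead of hitting missing kernels. A hedged restatement of that early-out as a standalone C predicate (the helper name is hypothetical; the 3-source bound mirrors the loop in the hunk):

static bool metal_rejects_bf16(const struct ggml_tensor * op, bool use_bfloat) {
    if (use_bfloat) {
        return false; // bfloat kernels are available: nothing to reject here
    }
    for (size_t i = 0, n = 3; i < n; ++i) {
        if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
            return true; // no bf16 kernels compiled: let another backend run it
        }
    }
    return false;
}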
@@ -932,7 +996,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     case GGML_OP_SOFT_MAX:
     case GGML_OP_RMS_NORM:
     case GGML_OP_GROUP_NORM:
-        return support_simdgroup_reduction;
+        return has_simdgroup_reduction;
     case GGML_OP_NORM:
     case GGML_OP_ROPE:
         return true;
@@ -952,13 +1016,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         if (op->src[1]->type != op->src[2]->type) {
             return false;
         }
-        return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
+        return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
     case GGML_OP_SSM_CONV:
     case GGML_OP_SSM_SCAN:
         return true;
     case GGML_OP_MUL_MAT:
     case GGML_OP_MUL_MAT_ID:
-        return support_simdgroup_reduction &&
+        return has_simdgroup_reduction &&
             (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
     case GGML_OP_CPY:
     case GGML_OP_DUP:
@@ -969,6 +1033,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             switch (op->type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
+                case GGML_TYPE_BF16:
                 case GGML_TYPE_Q8_0:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_1:
@@ -981,10 +1046,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             }
         case GGML_TYPE_F16:
             switch (op->type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
                     return true;
                 default:
+                    return false;
+            }
+        case GGML_TYPE_BF16:
+            switch (op->type) {
+                case GGML_TYPE_F32:
+                case GGML_TYPE_BF16:
+                    return true;
+                default:
                     return false;
             }
         default:
@@ -1070,7 +1143,7 @@ static void ggml_metal_encode_node(
     const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
     const uint64_t nb21 = src2 ? src2->nb[1] : 0;
     const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-    const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+    const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
 
     const int64_t ne0 = dst ? dst->ne[0] : 0;
     const int64_t ne1 = dst ? dst->ne[1] : 0;
@@ -1855,6 +1928,7 @@ static void ggml_metal_encode_node(
             switch (src0->type) {
                 case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
                 case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
+                case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
                 default: break;
             }
 
@@ -1863,6 +1937,7 @@ static void ggml_metal_encode_node(
             switch (src0->type) {
                 case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
                 case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
+                case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
                 case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
                 case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
                 case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
@@ -1940,6 +2015,25 @@ static void ggml_metal_encode_node(
                         nrows = 4;
                     }
                 } break;
+            case GGML_TYPE_BF16:
+                {
+                    nth0 = 32;
+                    nth1 = 1;
+                    if (src1t == GGML_TYPE_F32) {
+                        if (ne11 * ne12 < 4) {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
+                        } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
+                            nrows = ne11;
+                        } else {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
+                            nrows = 4;
+                        }
+                    } else {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
+                        nrows = 4;
+                    }
+                } break;
             case GGML_TYPE_Q4_0:
                 {
                     nth0 = 8;
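The BF16 case mirrors the existing F16 dispatch: tiny batches (ne11 * ne12 < 4) take the 1-row kernel, long 4-aligned rows take the L4 variant that processes ne11 rows at once, and everything else takes the generic kernel at 4 rows per threadgroup. A hedged standalone restatement of that selection in C (the enum and function name are hypothetical; the thresholds are the ones in the hunk):

enum mv_variant { MV_GENERIC, MV_1ROW, MV_L4 };

static enum mv_variant pick_bf16_mv(int64_t ne00, int64_t ne01, int64_t ne11, int64_t ne12) {
    if (ne11 * ne12 < 4) {
        return MV_1ROW;    // tiny batch: one dst row per threadgroup
    }
    if (ne00 >= 128 && ne01 >= 8 && ne00 % 4 == 0) {
        return MV_L4;      // long, 4-aligned rows: vectorized loads, nrows = ne11
    }
    return MV_GENERIC;     // default path, nrows = 4
}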
@@ -2158,12 +2252,12 @@ static void ggml_metal_encode_node(
         if ([device supportsFamily:MTLGPUFamilyApple7] &&
             ne00 % 32 == 0 && ne00 >= 64 &&
             dst_rows > dst_rows_min) {
 
             // some Metal matrix data types require aligned pointers
             // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
             switch (src0->type) {
                 case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
                 case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
+                case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
                 default: break;
             }
 
@@ -2172,6 +2266,7 @@ static void ggml_metal_encode_node(
             switch (src0->type) {
                 case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
                 case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
+                case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break;
                 case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
                 case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
                 case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
@@ -2241,6 +2336,13 @@ static void ggml_metal_encode_node(
                     nth1 = 1;
                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
                 } break;
+            case GGML_TYPE_BF16:
+                {
+                    GGML_ASSERT(src1t == GGML_TYPE_F32);
+                    nth0 = 32;
+                    nth1 = 1;
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
+                } break;
             case GGML_TYPE_Q4_0:
                 {
                     nth0 = 8;
@@ -2438,6 +2540,7 @@ static void ggml_metal_encode_node(
             switch (src0->type) {
                 case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
                 case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
+                case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline; break;
                 case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
                 case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
                 case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
|
@@ -2962,6 +3065,23 @@ static void ggml_metal_encode_node(
                             }
                         }
                     } break;
+                case GGML_TYPE_BF16:
+                    {
+                        switch (ne00) {
+                            case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
+                            case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
+                            case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
+                            case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
+                            case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
+                            case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
+                            default:
+                                {
+                                    GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                    GGML_LOG_ERROR("add template specialization for this size\n");
+                                    GGML_ABORT("add template specialization for this size");
+                                }
+                        }
+                    } break;
                 case GGML_TYPE_Q4_0:
                     {
                         switch (ne00) {
@@ -3062,6 +3182,7 @@ static void ggml_metal_encode_node(
                 {
                     switch (src1->type) {
                         case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+                        case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
                         case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
                         case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
                         case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
@@ -3079,6 +3200,7 @@ static void ggml_metal_encode_node(
                 {
                     switch (src1->type) {
                         case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                        case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
                         case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
                         case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
                         case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
@@ -3123,18 +3245,15 @@ static void ggml_metal_encode_node(
         [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
         [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
         [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
-        [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
-        [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
-        [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
-        [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
-        [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
-        [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
-        [encoder setBytes:&scale length:sizeof( float) atIndex:23];
-        [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
-        [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
-        [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
-        [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
-        [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
+        [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17];
+        [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18];
+        [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19];
+        [encoder setBytes:&scale length:sizeof( float) atIndex:20];
+        [encoder setBytes:&max_bias length:sizeof( float) atIndex:21];
+        [encoder setBytes:&m0 length:sizeof(m0) atIndex:22];
+        [encoder setBytes:&m1 length:sizeof(m1) atIndex:23];
+        [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24];
+        [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];

         if (!use_vec_kernel) {
             // half8x8 kernel
@@ -3145,11 +3264,14 @@ static void ggml_metal_encode_node(
             GGML_ASSERT(nqptg % 8  == 0);
             GGML_ASSERT(ncpsg % 32 == 0);

+            // 2*(2*ncpsg + nqptg)*(nsg)
+            // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
+            //
             // 16*32*(nsg)
             // the shared memory needed for the simdgroups to load the KV cache
             // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
             //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))

            int64_t nsgmax = 2;
@@ -3183,12 +3305,12 @@ static void ggml_metal_encode_node(

             // ne00 + 2*ncpsg*(nsg)
             // for each query, we load it as f16 in shared memory (ne00)
-            // and store the attention scores (nqptg x ncpsg) as f32
+            // and store the soft_max values and the mask
             //
-            // 2*ne00*(nsg)
-            // each simdgroup has a full f32 head vector in shared mem to accumulate results
+            // ne00*(nsg)
+            // each simdgroup has a full f16 head vector in shared mem to accumulate results
             //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))

            int64_t nsgmax = 2;
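For a sanity check of the two FATTN_SMEM formulas above, the standalone sketch below evaluates both for one assumed configuration (head size 128, 8 queries per threadgroup, 32 cache values per step, 2 simdgroups). These numbers are illustrative only, not taken from the diff:

// Editor's illustrative sketch (not part of the diff): evaluate both shared-memory formulas.
#include <cstdio>

static size_t pad16(size_t n) { return (n + 15) / 16 * 16; } // stand-in for GGML_PAD(x, 16)

int main() {
    const size_t ne00 = 128, nqptg = 8, ncpsg = 32, nsg = 2; // assumed sample values

    // half8x8 kernel: Q tile + 2*(2*ncpsg + nqptg) soft_max/mask/diag values per simdgroup
    // + 16*32 KV staging elements per simdgroup, all stored as half (sizeof(float)/2 bytes)
    const size_t smem_mat = pad16((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*nsg) + 16*32*nsg) * (sizeof(float)/2));

    // vec kernel: Q as f16 + 2*ncpsg soft_max/mask values per simdgroup + one ne00 accumulator per simdgroup
    const size_t smem_vec = pad16((nqptg*(ne00 + 2*ncpsg*nsg) + ne00*nsg) * (sizeof(float)/2));

    printf("mat: %zu bytes, vec: %zu bytes\n", smem_mat, smem_vec); // prints: mat: 8704 bytes, vec: 4608 bytes
    return 0;
}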
@@ -3237,6 +3359,7 @@ static void ggml_metal_encode_node(
                 switch (dstt) {
                     case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
                     case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
+                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
                     case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
                     case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
                     case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
@@ -3254,6 +3377,14 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("not implemented");
                 };
             } break;
+        case GGML_TYPE_BF16:
+            {
+                switch (dstt) {
+                    case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
+                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
+                    default: GGML_ASSERT(false && "not implemented");
+                };
+            } break;
         default: GGML_ABORT("not implemented");
     }
(file diff suppressed because it is too large)
@@ -26,5 +26,8 @@
 #include "softmax.hpp"
 #include "tsembd.hpp"
 #include "im2col.hpp"
+#include "wkv6.hpp"
+#include "outprod.hpp"
+#include "element_wise.hpp"

 #endif // GGML_SYCL_BACKEND_HPP
@@ -62,3 +62,43 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) {
     }
     return sycl_down_blk_size;
 }
+
+void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                          const ggml_tensor *src1, ggml_tensor *dst,
+                          const ggml_sycl_op_flatten_t op) try {
+    const int64_t nrows0 = ggml_nrows(src0);
+
+    const bool use_src1 = src1 != nullptr;
+    const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+
+    GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+    GGML_ASSERT(              dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+
+    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
+    ggml_tensor_extra_gpu * dst_extra  = (ggml_tensor_extra_gpu *)  dst->extra;
+
+    // dd = data device
+    float * src0_ddf = (float *) src0->data;
+    float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
+    float *  dst_ddf = (float *)  dst->data;
+
+    ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
+    ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
+    ggml_sycl_pool_alloc<float>  dst_f(ctx.pool());
+
+    ggml_sycl_set_device(ctx.device);
+    queue_ptr main_stream = ctx.stream();
+    // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
+    //     ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
+
+    // do the computation
+    op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
+    // print_ggml_tensor("tensor", dst);
+}
+catch (sycl::exception const &exc) {
+
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
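To make the role of ggml_sycl_op_flatten concrete, here is a hypothetical unary op wired through it. The callback name sqrt_f32_sycl_op is invented for illustration, but its signature matches the ggml_sycl_op_flatten_t typedef declared in the header hunk below; the wrapper supplies the raw float pointers and the queue:

// Editor's illustrative sketch (assumed helper, not part of the diff):
static void sqrt_f32_sycl_op(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
                             const ggml_tensor * src1, ggml_tensor * dst,
                             const float * src0_dd, const float * src1_dd, float * dst_dd,
                             const queue_ptr & stream) {
    // unary op: src1 is unused; pointers are USM device memory, as in the wrapper above
    const int64_t n = ggml_nelements(src0);
    stream->parallel_for(sycl::range<1>(n), [=](sycl::id<1> i) {
        dst_dd[i] = sycl::sqrt(src0_dd[i]);
    });
    (void) ctx; (void) src1; (void) dst; (void) src1_dd;
}

// call site inside the backend (src1 == nullptr for a unary op):
//   ggml_sycl_op_flatten(ctx, dst->src[0], nullptr, dst, sqrt_f32_sycl_op);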
@@ -404,4 +404,262 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {

 int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);

+typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                       const ggml_tensor *src1,
+                                       ggml_tensor *dst, const float *src0_dd,
+                                       const float *src1_dd, float *dst_dd,
+                                       const queue_ptr &main_stream);
+
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
+        int ne0, int ne1, int ne2, int ne3,
+        int ne10, int ne11, int ne12, int ne13,
+        /*int s0, */ int s1,  int s2,  int s3,
+        /*int s10,*/ int s11, int s12, int s13,
+        const sycl::nd_item<3> &item_ct1) {
+    const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                    item_ct1.get_local_id(2);
+    const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                    item_ct1.get_local_id(1));
+    const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+                    item_ct1.get_local_id(0)) / ne3;
+    const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+                    item_ct1.get_local_id(0)) % ne3;
+
+    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+        return;
+    }
+
+    const int i11 = i1 % ne11;
+    const int i12 = i2 % ne12;
+    const int i13 = i3 % ne13;
+
+    const size_t i_src0 =  i3*s3  +  i2*s2  + i1*s1;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  = i_src0;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
+
+    for (int i0 = i0s; i0 < ne0;
+         i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
+        const int i10 = i0 % ne10;
+        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+    }
+}
+
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+        int ne0, int ne1, int ne2, int ne3,
+        int ne10, int ne11, int ne12, int ne13,
+        /*int s0, */ int s1,  int s2,  int s3,
+        /*int s10,*/ int s11, int s12, int s13,
+        const sycl::nd_item<3> &item_ct1) {
+
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    const int i3 = i/(ne2*ne1*ne0);
+    const int i2 = (i/(ne1*ne0)) % ne2;
+    const int i1 = (i/ne0) % ne1;
+    const int i0 = i % ne0;
+
+    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+        return;
+    }
+
+    const int i11 = i1 % ne11;
+    const int i12 = i2 % ne12;
+    const int i13 = i3 % ne13;
+
+    const size_t i_src0 =  i3*s3  +  i2*s2  + i1*s1;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  = i_src0;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
+
+    const int i10 = i0 % ne10;
+    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+}
+
+
+template<float (*bin_op)(const float, const float)>
+struct bin_bcast_sycl {
+    template <typename src0_t, typename src1_t, typename dst_t>
+    void operator()(ggml_backend_sycl_context & ctx,
+                    const struct ggml_tensor *src0,
+                    const struct ggml_tensor *src1, struct ggml_tensor *dst,
+                    const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
+                    queue_ptr stream) {
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        int nr0 = ne10/ne0;
+        int nr1 = ne11/ne1;
+        int nr2 = ne12/ne2;
+        int nr3 = ne13/ne3;
+
+        int nr[4] = { nr0, nr1, nr2, nr3 };
+
+        // collapse dimensions until first broadcast dimension
+        int64_t cne0[] = {ne0, ne1, ne2, ne3};
+        int64_t cne1[] = {ne10, ne11, ne12, ne13};
+        size_t cnb0[] = {nb0, nb1, nb2, nb3};
+        size_t cnb1[] = {nb10, nb11, nb12, nb13};
+        auto collapse = [](int64_t cne[]) {
+            cne[0] *= cne[1];
+            cne[1] = cne[2];
+            cne[2] = cne[3];
+            cne[3] = 1;
+        };
+
+        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
+            cnb[1] *= cne[1];
+            cnb[2] *= cne[2];
+            cnb[3] *= cne[3];
+        };
+
+        for (int i = 0; i < 4; i++) {
+            if (nr[i] != 1) {
+                break;
+            }
+            if (i > 0) {
+                collapse_nb(cnb0, cne0);
+                collapse_nb(cnb1, cne1);
+                collapse(cne0);
+                collapse(cne1);
+            }
+        }
+        {
+            int64_t ne0 = cne0[0];
+            int64_t ne1 = cne0[1];
+            int64_t ne2 = cne0[2];
+            int64_t ne3 = cne0[3];
+
+            int64_t ne10 = cne1[0];
+            int64_t ne11 = cne1[1];
+            int64_t ne12 = cne1[2];
+            int64_t ne13 = cne1[3];
+
+            size_t nb0 = cnb0[0];
+            size_t nb1 = cnb0[1];
+            size_t nb2 = cnb0[2];
+            size_t nb3 = cnb0[3];
+
+            size_t nb10 = cnb1[0];
+            size_t nb11 = cnb1[1];
+            size_t nb12 = cnb1[2];
+            size_t nb13 = cnb1[3];
+
+            size_t s0 = nb0 / sizeof(dst_t);
+            size_t s1 = nb1 / sizeof(dst_t);
+            size_t s2 = nb2 / sizeof(dst_t);
+            size_t s3 = nb3 / sizeof(dst_t);
+
+            size_t s10 = nb10 / sizeof(src1_t);
+            size_t s11 = nb11 / sizeof(src1_t);
+            size_t s12 = nb12 / sizeof(src1_t);
+            size_t s13 = nb13 / sizeof(src1_t);
+
+            GGML_ASSERT(s0 == 1);
+            GGML_ASSERT(s10 == 1);
+
+            const int block_size = 128;
+
+            int64_t hne0 = std::max(ne0/2LL, 1LL);
+
+            sycl::range<3> block_dims(1, 1, 1);
+            block_dims[2] = std::min<unsigned int>(hne0, block_size);
+            block_dims[1] = std::min<unsigned int>(
+                ne1, block_size / (unsigned int)block_dims[2]);
+            block_dims[0] = std::min(
+                std::min<unsigned int>(
+                    ne2 * ne3, block_size / (unsigned int)block_dims[2] /
+                                   (unsigned int)block_dims[1]),
+                64U);
+
+            sycl::range<3> block_nums(
+                (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
+                (ne1 + block_dims[1] - 1) / block_dims[1],
+                (hne0 + block_dims[2] - 1) / block_dims[2]);
+
+            if (block_nums[0] > 65535) {
+                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
+                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
+                {
+                    dpct::has_capability_or_fail(stream->get_device(),
+                                                 {sycl::aspect::fp16});
+
+                    stream->parallel_for(
+                        sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
+                                              sycl::range<3>(1, 1, block_size),
+                                          sycl::range<3>(1, 1, block_size)),
+                        [=](sycl::nd_item<3> item_ct1) {
+                            k_bin_bcast_unravel<bin_op>(
+                                src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
+                                ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
+                                s13, item_ct1);
+                        });
+                }
+            } else {
+                /*
+                DPCT1049:16: The work-group size passed to the SYCL kernel may
+                exceed the limit. To get the device limit, query
+                info::device::max_work_group_size. Adjust the work-group size if
+                needed.
+                */
+                dpct::has_capability_or_fail(stream->get_device(),
+                                             {sycl::aspect::fp16});
+
+                stream->parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
+                                            ne2, ne3, ne10, ne11, ne12, ne13,
+                                            s1, s2, s3, s11, s12, s13,
+                                            item_ct1);
+                    });
+            }
+        }
+    }
+};
+
+template <class op>
+inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                   const ggml_tensor *src1, ggml_tensor *dst,
+                                   const float *src0_dd, const float *src1_dd,
+                                   float *dst_dd,
+                                   const queue_ptr &main_stream) {
+
+    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
+             (sycl::half *)dst_dd, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+        op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
+             main_stream);
+    } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
+        op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
+             main_stream);
+    } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
+        op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
+             main_stream);
+    } else {
+        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+                ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+        GGML_ABORT("fatal error");
+    }
+}
+
+
+void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                          const ggml_tensor *src1, ggml_tensor *dst,
+                          const ggml_sycl_op_flatten_t op);
+
 #endif // GGML_SYCL_COMMON_HPP
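The dimension-collapse step in bin_bcast_sycl above is easiest to see on concrete shapes. The host-only sketch below (assumed shapes, no SYCL required) fuses leading non-broadcast dimensions exactly as the collapse lambda does:

// Editor's illustrative sketch: collapse dims until the first broadcast dimension.
#include <cstdio>
#include <cstdint>

int main() {
    int64_t cne0[4] = {4, 3, 2, 2}; // assumed dst shape  {ne0..ne3}
    int64_t cne1[4] = {4, 3, 1, 1}; // assumed src1 shape {ne10..ne13}: broadcast starts at dim 2
    int     nr[4]   = {1, 1, 2, 2}; // per-dim repeat factor (1 = no broadcast)

    auto collapse = [](int64_t cne[]) { // same as the lambda in the header above
        cne[0] *= cne[1]; cne[1] = cne[2]; cne[2] = cne[3]; cne[3] = 1;
    };

    for (int i = 0; i < 4; i++) {
        if (nr[i] != 1) break;          // stop at the first broadcast dimension
        if (i > 0) { collapse(cne0); collapse(cne1); }
    }

    // dims 0 and 1 are fused: dst becomes 12 x 2 x 2 x 1, src1 becomes 12 x 1 x 1 x 1,
    // so the kernels index one fused fast dimension instead of two.
    printf("dst:  %lld x %lld x %lld x %lld\n", (long long)cne0[0], (long long)cne0[1], (long long)cne0[2], (long long)cne0[3]);
    printf("src1: %lld x %lld x %lld x %lld\n", (long long)cne1[0], (long long)cne1[1], (long long)cne1[2], (long long)cne1[3]);
    return 0;
}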
@@ -106,6 +106,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
                     concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
                 });
             break;
+        // dim >=2 will be dispatched to the default path
         default:
             stream->parallel_for(
                 sycl::nd_range<3>(gridDim *
ggml/src/ggml-sycl/element_wise.cpp (new file, 1011 lines; diff suppressed because it is too large)

ggml/src/ggml-sycl/element_wise.hpp (new file, 76 lines):
@@ -0,0 +1,76 @@
+#ifndef GGML_SYCL_ELEMENTWISE_HPP
+#define GGML_SYCL_ELEMENTWISE_HPP
+
+#include "common.hpp"
+
+static __dpct_inline__ float op_repeat(const float a, const float b) {
+    return b;
+    GGML_UNUSED(a);
+}
+
+static __dpct_inline__ float op_add(const float a, const float b) {
+    return a + b;
+}
+
+static __dpct_inline__ float op_sub(const float a, const float b) {
+    return a - b;
+}
+
+static __dpct_inline__ float op_mul(const float a, const float b) {
+    return a * b;
+}
+
+static __dpct_inline__ float op_div(const float a, const float b) {
+    return a / b;
+}
+
+
+void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_sin(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_cos(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_exp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_log(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_neg(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_step(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_sub(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+#endif // GGML_SYCL_ELEMENTWISE_HPP
ggml/src/ggml-sycl/outprod.cpp (new file, 55 lines):
@@ -0,0 +1,55 @@
+#include <sycl/sycl.hpp>
+#include "outprod.hpp"
+
+
+void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+                           const ggml_tensor* src1, ggml_tensor* dst) {
+
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    // Get SYCL queue
+    dpct::queue_ptr stream = ctx.stream();
+
+    // Dimension checks
+    GGML_ASSERT(ne01 == ne11);  // Inner dimensions must match
+    GGML_ASSERT(ne0 == ne00);   // Output rows match src0 rows
+    GGML_ASSERT(ne1 == ne10);   // Output cols match src1 cols
+
+    // Get data pointers
+    const float* src0_d = (const float*)src0->data;
+    const float* src1_d = (const float*)src1->data;
+    float* dst_d = (float*)dst->data;
+
+    // GEMM parameters
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+
+    // Handle transposition of src1
+    const bool src1_T = ggml_is_transposed(src1);
+    const oneapi::mkl::transpose src1_op =
+        src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
+    const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
+
+    try {
+        // Perform matrix multiplication using oneMKL GEMM
+        oneapi::mkl::blas::gemm(*stream,
+            oneapi::mkl::transpose::nontrans, src1_op,
+            ne0, ne1, ne01,
+            alpha,
+            src0_d, ne00,
+            src1_d, ldb,
+            beta,
+            dst_d, ne0);
+    }
+    catch (sycl::exception const& exc) {
+        std::cerr << exc.what() << std::endl;
+        GGML_ASSERT(false);
+    }
+}
ggml/src/ggml-sycl/outprod.hpp (new file, 11 lines):
@@ -0,0 +1,11 @@
+#ifndef GGML_SYCL_OUTPROD_HPP
+#define GGML_SYCL_OUTPROD_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+                           const ggml_tensor* src1, ggml_tensor* dst);
+
+
+#endif // GGML_SYCL_OUTPROD_HPP
@@ -25,6 +25,11 @@
 #define SYCL_RELU_BLOCK_SIZE 256
 #define SYCL_HARDSIGMOID_BLOCK_SIZE 256
 #define SYCL_HARDSWISH_BLOCK_SIZE 256
+#define SYCL_EXP_BLOCK_SIZE 256
+#define SYCL_NEG_BLOCK_SIZE 256
+#define SYCL_SIGMOID_BLOCK_SIZE 256
+#define SYCL_SQRT_BLOCK_SIZE 256
+#define SYCL_SIN_BLOCK_SIZE 256
 #define SYCL_SQR_BLOCK_SIZE 256
 #define SYCL_CPY_BLOCK_SIZE 32
 #define SYCL_SCALE_BLOCK_SIZE 256
@@ -41,6 +46,7 @@
 #define SYCL_ACC_BLOCK_SIZE 256
 #define SYCL_IM2COL_BLOCK_SIZE 256
 #define SYCL_POOL2D_BLOCK_SIZE 256
+#define SYCL_ARGMAX_BLOCK_SIZE 256
 #define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
 #define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
ggml/src/ggml-sycl/wkv6.cpp (new file, 138 lines):
@@ -0,0 +1,138 @@
+#include <sycl/sycl.hpp>
+#include "wkv6.hpp"
+
+constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
+
+// Helper function for the main kernel
+static void rwkv_wkv_f32_kernel(
+        const int B, const int T, const int C, const int H,
+        const float* k, const float* v, const float* r,
+        const float* tf, const float* td, const float* s,
+        float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
+
+    const int tid = item_ct1.get_local_id(2);
+    const int bid = item_ct1.get_group(2);
+
+    const int head_size = WKV_BLOCK_SIZE;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    // Set up shared memory pointers
+    float* _k = shared_mem;
+    float* _r = _k + head_size;
+    float* _tf = _r + head_size;
+    float* _td = _tf + head_size;
+
+    // Local state array
+    float state[WKV_BLOCK_SIZE];
+
+    // Load initial state
+#pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    // Sync threads before shared memory operations
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    // Load time-mixing parameters
+    _tf[tid] = tf[head_i * head_size + tid];
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    // Main sequence processing loop
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
+         t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
+         t += C) {
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        // Load current timestep data to shared memory
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        const float _v = v[t];
+        float y = 0;
+
+        // Process in chunks of 4 for better vectorization
+        sycl::float4 k4, r4, tf4, td4, s4, kv4;
+#pragma unroll
+        for (int j = 0; j < head_size; j += 4) {
+            // Load data in vec4 chunks
+            k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
+            td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
+            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            // Compute key-value product
+            sycl::float4 kv4 = k4 * _v;
+
+            // Accumulate weighted sum
+            y += sycl::dot(r4, tf4 * kv4 + s4);
+
+            // Update state
+            s4 = s4 * td4 + kv4;
+
+            // Store updated state
+            state[j] = s4.x();
+            state[j+1] = s4.y();
+            state[j+2] = s4.z();
+            state[j+3] = s4.w();
+        }
+
+        dst[t] = y;
+    }
+
+    // Save final state
+#pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+                            const ggml_tensor* src1, ggml_tensor* dst) {
+
+    const float* k_d = (const float*)dst->src[0]->data;
+    const float* v_d = (const float*)dst->src[1]->data;
+    const float* r_d = (const float*)dst->src[2]->data;
+    const float* tf_d = (const float*)dst->src[3]->data;
+    const float* td_d = (const float*)dst->src[4]->data;
+    const float* s_d = (const float*)dst->src[5]->data;
+    float* dst_d = (float*)dst->data;
+
+    const int64_t B = dst->src[5]->ne[1];
+    const int64_t T = dst->src[0]->ne[3];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[2];
+
+    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
+
+    dpct::queue_ptr stream = ctx.stream();
+
+    // Calculate execution configuration
+    const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td
+    sycl::range<3> block_dims(1, 1, C / H);
+    sycl::range<3> grid_dims(1, 1, B * H);
+
+    // Submit kernel
+    stream->submit([&](sycl::handler& cgh) {
+        sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+        cgh.parallel_for(
+            sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                rwkv_wkv_f32_kernel(
+                    B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
+                    item_ct1, shared_mem_acc.get_pointer()
+                );
+            });
+    });
+}
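The float4 inner loop in wkv6.cpp is a 4-wide unrolling of a simple per-lane recurrence; a scalar restatement (illustrative only, same symbols as the kernel) makes the math explicit:

// Editor's illustrative sketch: scalar form of the WKV6 inner loop for one (head, tid)
// lane at one timestep:
//   y      = sum_j r[j] * (tf[j] * k[j] * v + state[j])
//   state' = state[j] * td[j] + k[j] * v
static float wkv6_step(int head_size, const float* k, const float* r,
                       const float* tf, const float* td, float v, float* state) {
    float y = 0.0f;
    for (int j = 0; j < head_size; ++j) {
        const float kv = k[j] * v;          // key-value product
        y       += r[j] * (tf[j] * kv + state[j]);
        state[j] = state[j] * td[j] + kv;   // decayed state update
    }
    return y;
}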
ggml/src/ggml-sycl/wkv6.hpp (new file, 10 lines):
@@ -0,0 +1,10 @@
+#ifndef GGML_SYCL_WKV6_HPP
+#define GGML_SYCL_WKV6_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                            const ggml_tensor *src1, ggml_tensor * dst);
+
+
+#endif // GGML_SYCL_WKV6_HPP
@@ -975,7 +975,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "WIN_UNPART",
     "GET_REL_POS",
     "ADD_REL_POS",
-    "RWKV_WKV",
+    "RWKV_WKV6",

     "UNARY",
@@ -1070,7 +1070,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "win_unpart(x)",
     "get_rel_pos(x)",
     "add_rel_pos(x)",
-    "rwkv_wkv(k, v, r, tf, td, s)",
+    "rwkv_wkv6(k, v, r, tf, td, s)",

     "unary(x)",
@@ -4228,6 +4228,15 @@ void ggml_flash_attn_ext_set_prec(
     ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
 }

+enum ggml_prec ggml_flash_attn_ext_get_prec(
+        const struct ggml_tensor * a) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int32_t prec_i32 = ggml_get_op_params_i32(a, 3);
+
+    return (enum ggml_prec) prec_i32;
+}
+
 // ggml_flash_attn_back

 struct ggml_tensor * ggml_flash_attn_back(
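With the new getter, the precision chosen for a flash-attention node round-trips through op-param slot 3. A minimal usage sketch; `attn` is a hypothetical, already-built GGML_OP_FLASH_ATTN_EXT node:

// Editor's illustrative sketch: force f32 accumulation on one node, then read it back.
ggml_flash_attn_ext_set_prec(attn, GGML_PREC_F32);
GGML_ASSERT(ggml_flash_attn_ext_get_prec(attn) == GGML_PREC_F32);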
@@ -4503,9 +4512,9 @@ struct ggml_tensor * ggml_add_rel_pos_inplace(
     return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
 }

-// ggml_rwkv_wkv
+// ggml_rwkv_wkv6

-struct ggml_tensor * ggml_rwkv_wkv(
+struct ggml_tensor * ggml_rwkv_wkv6(
         struct ggml_context * ctx,
         struct ggml_tensor * k,
         struct ggml_tensor * v,
@@ -4537,7 +4546,7 @@ struct ggml_tensor * ggml_rwkv_wkv(
     const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

-    result->op = GGML_OP_RWKV_WKV;
+    result->op = GGML_OP_RWKV_WKV6;
     result->src[0] = k;
     result->src[1] = v;
     result->src[2] = r;
@@ -6084,7 +6093,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         } break;
     case GGML_OP_GET_REL_POS:
     case GGML_OP_ADD_REL_POS:
-    case GGML_OP_RWKV_WKV:
+    case GGML_OP_RWKV_WKV6:
     case GGML_OP_MAP_UNARY:
     case GGML_OP_MAP_BINARY:
     case GGML_OP_MAP_CUSTOM1_F32:
klite.embd (file diff suppressed because one or more lines are too long)
@@ -1876,8 +1876,11 @@ static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
 static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (llama_sampler_dry *) smpl->ctx;

-    // nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying
-    auto * result = llama_sampler_init_dry(nullptr, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
+    llama_vocab dummy_vocab;
+
+    // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
+    auto * result = llama_sampler_init_dry_impl(dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);

     // Copy the state, including the processed breakers
     {
         auto * result_ctx = (llama_sampler_dry *) result->ctx;
@@ -7052,7 +7052,7 @@ static const std::map<llm_tensor, llm_tensor_info> llm_tensor_info_mapping = {
     {LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
-    {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV}},
+    {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
     {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@@ -7168,7 +7168,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
             ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
             op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C);
         } break;
-    case GGML_OP_RWKV_WKV:
+    case GGML_OP_RWKV_WKV6:
         {
             // FIXME
             const int64_t S = 123;
@@ -7181,7 +7181,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
             ggml_tensor * tf = w;
             ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
             ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H);
-            op_tensor = ggml_rwkv_wkv(ctx, k, v, r, tf, td, state);
+            op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state);
         } break;
     default:
         GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name);
@@ -10145,7 +10145,7 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
     v = ggml_transpose(ctx, v);
     r = ggml_transpose(ctx, r);

-    struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
+    struct ggml_tensor * wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
     cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
     *wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));