mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
Merge branch 'upstream' into concedo_experimental
This commit is contained in:
commit
0cddbe1f0b
30 changed files with 2561 additions and 1226 deletions
|
@ -814,13 +814,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
).set_env("LLAMA_ARG_FLASH_ATTN"));
|
).set_env("LLAMA_ARG_FLASH_ATTN"));
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
{"-p", "--prompt"}, "PROMPT",
|
{"-p", "--prompt"}, "PROMPT",
|
||||||
ex == LLAMA_EXAMPLE_MAIN
|
"prompt to start generation with; for system message, use -sys",
|
||||||
? "prompt to start generation with\nif -cnv is set, this will be used as system prompt"
|
|
||||||
: "prompt to start generation with",
|
|
||||||
[](common_params & params, const std::string & value) {
|
[](common_params & params, const std::string & value) {
|
||||||
params.prompt = value;
|
params.prompt = value;
|
||||||
}
|
}
|
||||||
).set_excludes({LLAMA_EXAMPLE_SERVER}));
|
).set_excludes({LLAMA_EXAMPLE_SERVER}));
|
||||||
|
add_opt(common_arg(
|
||||||
|
{"-sys", "--system-prompt"}, "PROMPT",
|
||||||
|
"system prompt to use with model (if applicable, depending on chat template)",
|
||||||
|
[](common_params & params, const std::string & value) {
|
||||||
|
params.system_prompt = value;
|
||||||
|
}
|
||||||
|
).set_examples({LLAMA_EXAMPLE_MAIN}));
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
{"--no-perf"},
|
{"--no-perf"},
|
||||||
string_format("disable internal libllama performance timings (default: %s)", params.no_perf ? "true" : "false"),
|
string_format("disable internal libllama performance timings (default: %s)", params.no_perf ? "true" : "false"),
|
||||||
|
@ -2448,6 +2453,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
params.vocoder.use_guide_tokens = true;
|
params.vocoder.use_guide_tokens = true;
|
||||||
}
|
}
|
||||||
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
|
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
|
||||||
|
add_opt(common_arg(
|
||||||
|
{"--tts-speaker-file"}, "FNAME",
|
||||||
|
"speaker file path for audio generation",
|
||||||
|
[](common_params & params, const std::string & value) {
|
||||||
|
params.vocoder.speaker_file = value;
|
||||||
|
}
|
||||||
|
).set_examples({LLAMA_EXAMPLE_TTS}));
|
||||||
|
|
||||||
// model-specific
|
// model-specific
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
|
|
|
@ -196,6 +196,8 @@ struct common_params_vocoder {
|
||||||
std::string model = ""; // model path // NOLINT
|
std::string model = ""; // model path // NOLINT
|
||||||
std::string model_url = ""; // model url to download // NOLINT
|
std::string model_url = ""; // model url to download // NOLINT
|
||||||
|
|
||||||
|
std::string speaker_file = ""; // speaker file path // NOLINT
|
||||||
|
|
||||||
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
|
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -257,6 +259,7 @@ struct common_params {
|
||||||
std::string hf_repo = ""; // HF repo // NOLINT
|
std::string hf_repo = ""; // HF repo // NOLINT
|
||||||
std::string hf_file = ""; // HF file // NOLINT
|
std::string hf_file = ""; // HF file // NOLINT
|
||||||
std::string prompt = ""; // NOLINT
|
std::string prompt = ""; // NOLINT
|
||||||
|
std::string system_prompt = ""; // NOLINT
|
||||||
std::string prompt_file = ""; // store the external prompt file name // NOLINT
|
std::string prompt_file = ""; // store the external prompt file name // NOLINT
|
||||||
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
|
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
|
||||||
std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
|
std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
|
||||||
|
|
|
@ -32,8 +32,6 @@
|
||||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
static const char * DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant";
|
|
||||||
|
|
||||||
static llama_context ** g_ctx;
|
static llama_context ** g_ctx;
|
||||||
static llama_model ** g_model;
|
static llama_model ** g_model;
|
||||||
static common_sampler ** g_smpl;
|
static common_sampler ** g_smpl;
|
||||||
|
@ -220,6 +218,10 @@ int main(int argc, char ** argv) {
|
||||||
// print chat template example in conversation mode
|
// print chat template example in conversation mode
|
||||||
if (params.conversation_mode) {
|
if (params.conversation_mode) {
|
||||||
if (params.enable_chat_template) {
|
if (params.enable_chat_template) {
|
||||||
|
if (!params.prompt.empty()) {
|
||||||
|
LOG_WRN("*** User-specified prompt in conversation mode will be ignored, did you mean to set --system-prompt (-sys) instead?\n");
|
||||||
|
}
|
||||||
|
|
||||||
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.get(), params.use_jinja).c_str());
|
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.get(), params.use_jinja).c_str());
|
||||||
} else {
|
} else {
|
||||||
LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
|
LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
|
||||||
|
@ -264,6 +266,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
std::vector<llama_token> embd_inp;
|
std::vector<llama_token> embd_inp;
|
||||||
|
|
||||||
|
bool waiting_for_first_input = params.conversation_mode && params.enable_chat_template && params.system_prompt.empty();
|
||||||
auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
|
auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
|
||||||
common_chat_msg new_msg;
|
common_chat_msg new_msg;
|
||||||
new_msg.role = role;
|
new_msg.role = role;
|
||||||
|
@ -275,11 +278,20 @@ int main(int argc, char ** argv) {
|
||||||
};
|
};
|
||||||
|
|
||||||
{
|
{
|
||||||
auto prompt = (params.conversation_mode && params.enable_chat_template)
|
std::string prompt;
|
||||||
// format the system prompt in conversation mode (fallback to default if empty)
|
|
||||||
? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
|
if (params.conversation_mode && params.enable_chat_template) {
|
||||||
|
// format the system prompt in conversation mode (will use template default if empty)
|
||||||
|
prompt = params.system_prompt;
|
||||||
|
|
||||||
|
if (!prompt.empty()) {
|
||||||
|
prompt = chat_add_and_format("system", prompt);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
// otherwise use the prompt as is
|
// otherwise use the prompt as is
|
||||||
: params.prompt;
|
prompt = params.prompt;
|
||||||
|
}
|
||||||
|
|
||||||
if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
|
if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
|
||||||
LOG_DBG("tokenize the prompt\n");
|
LOG_DBG("tokenize the prompt\n");
|
||||||
embd_inp = common_tokenize(ctx, prompt, true, true);
|
embd_inp = common_tokenize(ctx, prompt, true, true);
|
||||||
|
@ -293,7 +305,7 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should not run without any tokens
|
// Should not run without any tokens
|
||||||
if (embd_inp.empty()) {
|
if (!params.conversation_mode && embd_inp.empty()) {
|
||||||
if (add_bos) {
|
if (add_bos) {
|
||||||
embd_inp.push_back(llama_vocab_bos(vocab));
|
embd_inp.push_back(llama_vocab_bos(vocab));
|
||||||
LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str());
|
LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str());
|
||||||
|
@ -477,8 +489,8 @@ int main(int argc, char ** argv) {
|
||||||
LOG_INF( " - Press Ctrl+C to interject at any time.\n");
|
LOG_INF( " - Press Ctrl+C to interject at any time.\n");
|
||||||
#endif
|
#endif
|
||||||
LOG_INF( "%s", control_message);
|
LOG_INF( "%s", control_message);
|
||||||
if (params.conversation_mode && params.enable_chat_template && params.prompt.empty()) {
|
if (params.conversation_mode && params.enable_chat_template && params.system_prompt.empty()) {
|
||||||
LOG_INF( " - Using default system message. To change it, set a different value via -p PROMPT or -f FILE argument.\n");
|
LOG_INF( " - Not using system message. To change it, set a different value via -sys PROMPT\n");
|
||||||
}
|
}
|
||||||
LOG_INF("\n");
|
LOG_INF("\n");
|
||||||
|
|
||||||
|
@ -774,7 +786,7 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// deal with end of generation tokens in interactive mode
|
// deal with end of generation tokens in interactive mode
|
||||||
if (llama_vocab_is_eog(vocab, common_sampler_last(smpl))) {
|
if (!waiting_for_first_input && llama_vocab_is_eog(vocab, common_sampler_last(smpl))) {
|
||||||
LOG_DBG("found an EOG token\n");
|
LOG_DBG("found an EOG token\n");
|
||||||
|
|
||||||
if (params.interactive) {
|
if (params.interactive) {
|
||||||
|
@ -794,12 +806,12 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// if current token is not EOG, we add it to current assistant message
|
// if current token is not EOG, we add it to current assistant message
|
||||||
if (params.conversation_mode) {
|
if (params.conversation_mode && !waiting_for_first_input) {
|
||||||
const auto id = common_sampler_last(smpl);
|
const auto id = common_sampler_last(smpl);
|
||||||
assistant_ss << common_token_to_piece(ctx, id, false);
|
assistant_ss << common_token_to_piece(ctx, id, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (n_past > 0 && is_interacting) {
|
if ((n_past > 0 || waiting_for_first_input) && is_interacting) {
|
||||||
LOG_DBG("waiting for user input\n");
|
LOG_DBG("waiting for user input\n");
|
||||||
|
|
||||||
if (params.conversation_mode) {
|
if (params.conversation_mode) {
|
||||||
|
@ -889,11 +901,12 @@ int main(int argc, char ** argv) {
|
||||||
input_echo = false; // do not echo this again
|
input_echo = false; // do not echo this again
|
||||||
}
|
}
|
||||||
|
|
||||||
if (n_past > 0) {
|
if (n_past > 0 || waiting_for_first_input) {
|
||||||
if (is_interacting) {
|
if (is_interacting) {
|
||||||
common_sampler_reset(smpl);
|
common_sampler_reset(smpl);
|
||||||
}
|
}
|
||||||
is_interacting = false;
|
is_interacting = false;
|
||||||
|
waiting_for_first_input = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Binary file not shown.
|
@ -144,6 +144,7 @@ def test_apply_chat_template():
|
||||||
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
|
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
|
||||||
({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
|
({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
|
||||||
({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
|
({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
|
||||||
|
({"type": "json_schema", "json_schema": {"schema": {"const": "foooooo"}}}, 10, "\"foooooo\""),
|
||||||
({"type": "json_object"}, 10, "(\\{|John)+"),
|
({"type": "json_object"}, 10, "(\\{|John)+"),
|
||||||
({"type": "sound"}, 0, None),
|
({"type": "sound"}, 0, None),
|
||||||
# invalid response format (expected to fail)
|
# invalid response format (expected to fail)
|
||||||
|
|
|
@ -26,7 +26,10 @@ from re import RegexFlag
|
||||||
import wget
|
import wget
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_HTTP_TIMEOUT = 12 if "LLAMA_SANITIZE" not in os.environ else 30
|
DEFAULT_HTTP_TIMEOUT = 12
|
||||||
|
|
||||||
|
if "LLAMA_SANITIZE" in os.environ or "GITHUB_ACTION" in os.environ:
|
||||||
|
DEFAULT_HTTP_TIMEOUT = 30
|
||||||
|
|
||||||
|
|
||||||
class ServerResponse:
|
class ServerResponse:
|
||||||
|
|
|
@ -590,8 +590,8 @@ static json oaicompat_completion_params_parse(
|
||||||
if (response_type == "json_object") {
|
if (response_type == "json_object") {
|
||||||
json_schema = json_value(response_format, "schema", json::object());
|
json_schema = json_value(response_format, "schema", json::object());
|
||||||
} else if (response_type == "json_schema") {
|
} else if (response_type == "json_schema") {
|
||||||
json json_schema = json_value(response_format, "json_schema", json::object());
|
auto schema_wrapper = json_value(response_format, "json_schema", json::object());
|
||||||
json_schema = json_value(json_schema, "schema", json::object());
|
json_schema = json_value(schema_wrapper, "schema", json::object());
|
||||||
} else if (!response_type.empty() && response_type != "text") {
|
} else if (!response_type.empty() && response_type != "text") {
|
||||||
throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
|
throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,7 @@ import { useEffect, useMemo, useRef, useState } from 'react';
|
||||||
import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
|
import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
|
||||||
import ChatMessage from './ChatMessage';
|
import ChatMessage from './ChatMessage';
|
||||||
import { CanvasType, Message, PendingMessage } from '../utils/types';
|
import { CanvasType, Message, PendingMessage } from '../utils/types';
|
||||||
import { classNames, throttle } from '../utils/misc';
|
import { classNames, cleanCurrentUrl, throttle } from '../utils/misc';
|
||||||
import CanvasPyInterpreter from './CanvasPyInterpreter';
|
import CanvasPyInterpreter from './CanvasPyInterpreter';
|
||||||
import StorageUtils from '../utils/storage';
|
import StorageUtils from '../utils/storage';
|
||||||
import { useVSCodeContext } from '../utils/llama-vscode';
|
import { useVSCodeContext } from '../utils/llama-vscode';
|
||||||
|
@ -18,6 +18,24 @@ export interface MessageDisplay {
|
||||||
isPending?: boolean;
|
isPending?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* If the current URL contains "?m=...", prefill the message input with the value.
|
||||||
|
* If the current URL contains "?q=...", prefill and SEND the message.
|
||||||
|
*/
|
||||||
|
const prefilledMsg = {
|
||||||
|
content() {
|
||||||
|
const url = new URL(window.location.href);
|
||||||
|
return url.searchParams.get('m') ?? url.searchParams.get('q') ?? '';
|
||||||
|
},
|
||||||
|
shouldSend() {
|
||||||
|
const url = new URL(window.location.href);
|
||||||
|
return url.searchParams.has('q');
|
||||||
|
},
|
||||||
|
clear() {
|
||||||
|
cleanCurrentUrl(['m', 'q']);
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
function getListMessageDisplay(
|
function getListMessageDisplay(
|
||||||
msgs: Readonly<Message[]>,
|
msgs: Readonly<Message[]>,
|
||||||
leafNodeId: Message['id']
|
leafNodeId: Message['id']
|
||||||
|
@ -81,7 +99,7 @@ export default function ChatScreen() {
|
||||||
canvasData,
|
canvasData,
|
||||||
replaceMessageAndGenerate,
|
replaceMessageAndGenerate,
|
||||||
} = useAppContext();
|
} = useAppContext();
|
||||||
const [inputMsg, setInputMsg] = useState('');
|
const [inputMsg, setInputMsg] = useState(prefilledMsg.content());
|
||||||
const inputRef = useRef<HTMLTextAreaElement>(null);
|
const inputRef = useRef<HTMLTextAreaElement>(null);
|
||||||
|
|
||||||
const { extraContext, clearExtraContext } = useVSCodeContext(
|
const { extraContext, clearExtraContext } = useVSCodeContext(
|
||||||
|
@ -172,6 +190,22 @@ export default function ChatScreen() {
|
||||||
|
|
||||||
const hasCanvas = !!canvasData;
|
const hasCanvas = !!canvasData;
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (prefilledMsg.shouldSend()) {
|
||||||
|
// send the prefilled message if needed
|
||||||
|
sendNewMessage();
|
||||||
|
} else {
|
||||||
|
// otherwise, focus on the input and move the cursor to the end
|
||||||
|
if (inputRef.current) {
|
||||||
|
inputRef.current.focus();
|
||||||
|
inputRef.current.selectionStart = inputRef.current.value.length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
prefilledMsg.clear();
|
||||||
|
// no need to keep track of sendNewMessage
|
||||||
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
|
}, [inputRef]);
|
||||||
|
|
||||||
// due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. appears once in the saved conversation and once in the pendingMsg)
|
// due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. appears once in the saved conversation and once in the pendingMsg)
|
||||||
const pendingMsgDisplay: MessageDisplay[] =
|
const pendingMsgDisplay: MessageDisplay[] =
|
||||||
pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id
|
pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id
|
||||||
|
|
|
@ -148,13 +148,13 @@ const SETTING_SECTIONS: SettingSection[] = [
|
||||||
fields: [
|
fields: [
|
||||||
{
|
{
|
||||||
type: SettingInputType.CHECKBOX,
|
type: SettingInputType.CHECKBOX,
|
||||||
label: 'Expand though process by default for generating message',
|
label: 'Expand thought process by default when generating messages',
|
||||||
key: 'showThoughtInProgress',
|
key: 'showThoughtInProgress',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: SettingInputType.CHECKBOX,
|
type: SettingInputType.CHECKBOX,
|
||||||
label:
|
label:
|
||||||
'Exclude thought process when sending request to API (Recommended for DeepSeek-R1)',
|
'Exclude thought process when sending requests to API (Recommended for DeepSeek-R1)',
|
||||||
key: 'excludeThoughtOnReq',
|
key: 'excludeThoughtOnReq',
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
@ -247,7 +247,7 @@ const SETTING_SECTIONS: SettingSection[] = [
|
||||||
This feature uses{' '}
|
This feature uses{' '}
|
||||||
<OpenInNewTab href="https://pyodide.org">pyodide</OpenInNewTab>,
|
<OpenInNewTab href="https://pyodide.org">pyodide</OpenInNewTab>,
|
||||||
downloaded from CDN. To use this feature, ask the LLM to generate
|
downloaded from CDN. To use this feature, ask the LLM to generate
|
||||||
python code inside a markdown code block. You will see a "Run"
|
Python code inside a Markdown code block. You will see a "Run"
|
||||||
button on the code block, near the "Copy" button.
|
button on the code block, near the "Copy" button.
|
||||||
</small>
|
</small>
|
||||||
</>
|
</>
|
||||||
|
@ -274,7 +274,7 @@ export default function SettingDialog({
|
||||||
);
|
);
|
||||||
|
|
||||||
const resetConfig = () => {
|
const resetConfig = () => {
|
||||||
if (window.confirm('Are you sure to reset all settings?')) {
|
if (window.confirm('Are you sure you want to reset all settings?')) {
|
||||||
setLocalConfig(CONFIG_DEFAULT);
|
setLocalConfig(CONFIG_DEFAULT);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -296,9 +296,9 @@ export default function SettingDialog({
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
} else if (mustBeNumeric) {
|
} else if (mustBeNumeric) {
|
||||||
const trimedValue = value.toString().trim();
|
const trimmedValue = value.toString().trim();
|
||||||
const numVal = Number(trimedValue);
|
const numVal = Number(trimmedValue);
|
||||||
if (isNaN(numVal) || !isNumeric(numVal) || trimedValue.length === 0) {
|
if (isNaN(numVal) || !isNumeric(numVal) || trimmedValue.length === 0) {
|
||||||
alert(`Value for ${key} must be numeric`);
|
alert(`Value for ${key} must be numeric`);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -118,3 +118,11 @@ export const throttle = <T extends unknown[]>(
|
||||||
}, delay);
|
}, delay);
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const cleanCurrentUrl = (removeQueryParams: string[]) => {
|
||||||
|
const url = new URL(window.location.href);
|
||||||
|
removeQueryParams.forEach((param) => {
|
||||||
|
url.searchParams.delete(param);
|
||||||
|
});
|
||||||
|
window.history.replaceState({}, '', url.toString());
|
||||||
|
};
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
#include "sampling.h"
|
#include "sampling.h"
|
||||||
#include "log.h"
|
#include "log.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
#include "json.hpp"
|
||||||
|
|
||||||
#define _USE_MATH_DEFINES // For M_PI on MSVC
|
#define _USE_MATH_DEFINES // For M_PI on MSVC
|
||||||
|
|
||||||
|
@ -16,6 +17,13 @@
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
|
enum outetts_version {
|
||||||
|
OUTETTS_V0_2,
|
||||||
|
OUTETTS_V0_3,
|
||||||
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
// Terminal utils
|
// Terminal utils
|
||||||
//
|
//
|
||||||
|
@ -371,7 +379,7 @@ static std::string replace_numbers_with_words(const std::string & input_text) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
|
// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
|
||||||
static std::string process_text(const std::string & text) {
|
static std::string process_text(const std::string & text, const outetts_version tts_version = OUTETTS_V0_2) {
|
||||||
|
|
||||||
// For now I skipped text romanization as I am unsure how to handle
|
// For now I skipped text romanization as I am unsure how to handle
|
||||||
// uroman and MeCab implementations in C++
|
// uroman and MeCab implementations in C++
|
||||||
|
@ -401,7 +409,8 @@ static std::string process_text(const std::string & text) {
|
||||||
if (c == ' ') {
|
if (c == ' ') {
|
||||||
prompt_clean += "<|text_sep|>";
|
prompt_clean += "<|text_sep|>";
|
||||||
*/
|
*/
|
||||||
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>");
|
std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>";
|
||||||
|
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), separator);
|
||||||
|
|
||||||
return processed_text;
|
return processed_text;
|
||||||
}
|
}
|
||||||
|
@ -425,8 +434,8 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
|
||||||
prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
|
prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
|
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str, const outetts_version tts_version = OUTETTS_V0_2) {
|
||||||
const std::string& delimiter = "<|text_sep|>";
|
const std::string& delimiter = (tts_version == OUTETTS_V0_3 ? "<|space|>" : "<|text_sep|>");
|
||||||
|
|
||||||
std::vector<llama_token> result;
|
std::vector<llama_token> result;
|
||||||
size_t start = 0;
|
size_t start = 0;
|
||||||
|
@ -452,6 +461,78 @@ static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab,
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static json speaker_from_file(const std::string & speaker_file) {
|
||||||
|
std::ifstream file(speaker_file);
|
||||||
|
if (!file) {
|
||||||
|
LOG_ERR("%s: Failed to open file '%s' for reading\n", __func__, speaker_file.c_str());
|
||||||
|
return json();
|
||||||
|
}
|
||||||
|
|
||||||
|
json speaker = json::parse(file);
|
||||||
|
return speaker;
|
||||||
|
}
|
||||||
|
|
||||||
|
static outetts_version get_tts_version(llama_model *model, json speaker = json::object()) {
|
||||||
|
if (speaker.contains("version")) {
|
||||||
|
std::string version = speaker["version"].get<std::string>();
|
||||||
|
if (version == "0.2") {
|
||||||
|
return OUTETTS_V0_2;
|
||||||
|
} else if (version == "0.3") {
|
||||||
|
return OUTETTS_V0_3;
|
||||||
|
} else {
|
||||||
|
LOG_ERR("%s: Unsupported speaker version '%s'\n", __func__, version.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also could get version from model itself
|
||||||
|
const char *chat_template = llama_model_chat_template(model, nullptr);
|
||||||
|
if (chat_template && std::string(chat_template) == "outetts-0.3") {
|
||||||
|
return OUTETTS_V0_3;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use 0.2 as the default version
|
||||||
|
return OUTETTS_V0_2;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string audio_text_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) {
|
||||||
|
std::string audio_text = "<|text_start|>";
|
||||||
|
|
||||||
|
if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) {
|
||||||
|
std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>";
|
||||||
|
for (const auto &word : speaker["words"]) {
|
||||||
|
audio_text += word["word"].get<std::string>() + separator;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return audio_text;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string audio_data_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) {
|
||||||
|
std::string audio_data = "<|audio_start|>\n";
|
||||||
|
|
||||||
|
if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) {
|
||||||
|
std::string code_start = (tts_version == OUTETTS_V0_3) ? "" : "<|code_start|>";
|
||||||
|
std::string code_end = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|code_end|>";
|
||||||
|
for (const auto &word : speaker["words"]) {
|
||||||
|
std::string word_text = word["word"].get<std::string>();
|
||||||
|
double duration = word["duration"].get<double>();
|
||||||
|
std::vector<int> codes = word["codes"].get<std::vector<int>>();
|
||||||
|
|
||||||
|
// Create the audio output entry
|
||||||
|
std::ostringstream word_entry;
|
||||||
|
word_entry << word_text << "<|t_" << std::fixed << std::setprecision(2)
|
||||||
|
<< duration << "|>" + code_start;
|
||||||
|
for (const auto &Code : codes) {
|
||||||
|
word_entry << "<|" << Code << "|>";
|
||||||
|
}
|
||||||
|
word_entry << code_end << "\n";
|
||||||
|
audio_data += word_entry.str();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return audio_data;
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
common_params params;
|
common_params params;
|
||||||
|
|
||||||
|
@ -523,34 +604,9 @@ int main(int argc, char ** argv) {
|
||||||
std::vector<llama_token> codes;
|
std::vector<llama_token> codes;
|
||||||
std::vector<llama_token> guide_tokens;
|
std::vector<llama_token> guide_tokens;
|
||||||
|
|
||||||
// process prompt and generate voice codes
|
// the default speaker profile is from: https://github.com/edwko/OuteTTS/blob/main/outetts/version/v1/default_speakers/en_male_1.json
|
||||||
{
|
std::string audio_text = "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>";
|
||||||
LOG_INF("%s: constructing prompt ..\n", __func__);
|
std::string audio_data = R"(<|audio_start|>
|
||||||
|
|
||||||
std::vector<llama_token> prompt_inp;
|
|
||||||
|
|
||||||
prompt_init(prompt_inp, vocab);
|
|
||||||
|
|
||||||
prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
|
|
||||||
|
|
||||||
// convert the input text into the necessary format expected by OuteTTS
|
|
||||||
{
|
|
||||||
std::string prompt_clean = process_text(params.prompt);
|
|
||||||
if (params.vocoder.use_guide_tokens) {
|
|
||||||
guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
|
|
||||||
}
|
|
||||||
|
|
||||||
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
|
|
||||||
|
|
||||||
prompt_add(prompt_inp, vocab, prompt_clean, false, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
prompt_add(prompt_inp, vocab, "<|text_end|>\n", false, true);
|
|
||||||
|
|
||||||
// disabled to save time on tokenizing each time
|
|
||||||
// TODO: load voices from the json files
|
|
||||||
#if 0
|
|
||||||
const std::string voice_data = R"(<|audio_start|>
|
|
||||||
the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|>
|
the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|>
|
||||||
overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|>
|
overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|>
|
||||||
package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|>
|
package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|>
|
||||||
|
@ -582,117 +638,170 @@ it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><
|
||||||
looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
|
looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
|
||||||
lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";
|
lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";
|
||||||
|
|
||||||
auto tmp = common_tokenize(vocab, voice_data, false, true);
|
// audio data for 0.3 version
|
||||||
printf("\n\n");
|
outetts_version tts_version = get_tts_version(model_ttc);
|
||||||
for (int i = 0; i < tmp.size(); ++i) {
|
if (tts_version == OUTETTS_V0_3) {
|
||||||
printf("%d, ", tmp[i]);
|
audio_text = std::regex_replace(audio_text, std::regex(R"(<\|text_sep\|>)"), "<|space|>");
|
||||||
|
audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_start\|>)"), "");
|
||||||
|
audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_end\|>)"), "<|space|>");
|
||||||
|
}
|
||||||
|
|
||||||
|
// load speaker if given
|
||||||
|
if (!params.vocoder.speaker_file.empty()) {
|
||||||
|
LOG_INF("%s: loading speaker ..\n", __func__);
|
||||||
|
json speaker = speaker_from_file(params.vocoder.speaker_file);
|
||||||
|
if (speaker.empty()) {
|
||||||
|
LOG_ERR("%s: Failed to load speaker file '%s'\n", __func__, params.vocoder.speaker_file.c_str());
|
||||||
|
return 1;
|
||||||
}
|
}
|
||||||
printf("\n\n");
|
audio_text = audio_text_from_speaker(speaker, tts_version);
|
||||||
|
audio_data = audio_data_from_speaker(speaker, tts_version);
|
||||||
|
}
|
||||||
|
|
||||||
|
// process prompt and generate voice codes
|
||||||
|
{
|
||||||
|
LOG_INF("%s: constructing prompt ..\n", __func__);
|
||||||
|
|
||||||
|
std::vector<llama_token> prompt_inp;
|
||||||
|
|
||||||
|
prompt_init(prompt_inp, vocab);
|
||||||
|
|
||||||
|
prompt_add(prompt_inp, vocab, audio_text, false, true);
|
||||||
|
|
||||||
|
// convert the input text into the necessary format expected by OuteTTS
|
||||||
|
{
|
||||||
|
std::string prompt_clean = process_text(params.prompt, tts_version);
|
||||||
|
if (params.vocoder.use_guide_tokens) {
|
||||||
|
guide_tokens = prepare_guide_tokens(vocab, prompt_clean, tts_version);
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
|
||||||
|
|
||||||
|
prompt_add(prompt_inp, vocab, prompt_clean, false, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt_add(prompt_inp, vocab, "<|text_end|>\n", false, true);
|
||||||
|
|
||||||
|
if (!params.vocoder.speaker_file.empty()) {
|
||||||
|
prompt_add(prompt_inp, vocab, audio_data, false, true);
|
||||||
|
} else {
|
||||||
|
// disabled to save time on tokenizing each time
|
||||||
|
#if 1
|
||||||
|
const std::string voice_data = audio_data;
|
||||||
|
|
||||||
|
auto tmp = common_tokenize(vocab, voice_data, false, true);
|
||||||
|
printf("\n\n");
|
||||||
|
for (size_t i = 0; i < tmp.size(); ++i) {
|
||||||
|
printf("%d, ", tmp[i]);
|
||||||
|
}
|
||||||
|
printf("\n\n");
|
||||||
|
prompt_add(prompt_inp, tmp);
|
||||||
#else
|
#else
|
||||||
prompt_add(prompt_inp, llama_tokens {
|
prompt_add(prompt_inp, llama_tokens {
|
||||||
151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,
|
151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,
|
||||||
152460, 153375, 151670, 198, 74455, 155808, 151669, 151799,
|
152460, 153375, 151670, 198, 74455, 155808, 151669, 151799,
|
||||||
151873, 151863, 152446, 152372, 152204, 152728, 152229, 152470,
|
151873, 151863, 152446, 152372, 152204, 152728, 152229, 152470,
|
||||||
151970, 153413, 152419, 153334, 153289, 153374, 153199, 152040,
|
151970, 153413, 152419, 153334, 153289, 153374, 153199, 152040,
|
||||||
153260, 152721, 152680, 153297, 152419, 153248, 152400, 152691,
|
153260, 152721, 152680, 153297, 152419, 153248, 152400, 152691,
|
||||||
153368, 153437, 151670, 198, 1722, 155828, 151669, 152607,
|
153368, 153437, 151670, 198, 1722, 155828, 151669, 152607,
|
||||||
152256, 152991, 152299, 152688, 153163, 153016, 152789, 153198,
|
152256, 152991, 152299, 152688, 153163, 153016, 152789, 153198,
|
||||||
152712, 151911, 153107, 152623, 152170, 152395, 152852, 152207,
|
152712, 151911, 153107, 152623, 152170, 152395, 152852, 152207,
|
||||||
152461, 153321, 153309, 151750, 152137, 153340, 152573, 152267,
|
152461, 153321, 153309, 151750, 152137, 153340, 152573, 152267,
|
||||||
153347, 151789, 152681, 153339, 151992, 152512, 151751, 152179,
|
153347, 151789, 152681, 153339, 151992, 152512, 151751, 152179,
|
||||||
153434, 153180, 152900, 153440, 152474, 153122, 153129, 151904,
|
153434, 153180, 152900, 153440, 152474, 153122, 153129, 151904,
|
||||||
152311, 151670, 198, 1499, 155791, 151669, 152276, 152454,
|
152311, 151670, 198, 1499, 155791, 151669, 152276, 152454,
|
||||||
153354, 152544, 153204, 153272, 152708, 153433, 152319, 153226,
|
153354, 152544, 153204, 153272, 152708, 153433, 152319, 153226,
|
||||||
153043, 152325, 153267, 152622, 151670, 198, 4250, 155797,
|
153043, 152325, 153267, 152622, 151670, 198, 4250, 155797,
|
||||||
151669, 153454, 153342, 151989, 152458, 153420, 152303, 152271,
|
151669, 153454, 153342, 151989, 152458, 153420, 152303, 152271,
|
||||||
152827, 153036, 153196, 151708, 153263, 152561, 153207, 152213,
|
152827, 153036, 153196, 151708, 153263, 152561, 153207, 152213,
|
||||||
152112, 153204, 151722, 152542, 151670, 198, 19789, 155796,
|
152112, 153204, 151722, 152542, 151670, 198, 19789, 155796,
|
||||||
151669, 153353, 153182, 152345, 152471, 152477, 153014, 152002,
|
151669, 153353, 153182, 152345, 152471, 152477, 153014, 152002,
|
||||||
152191, 151734, 152312, 152810, 152237, 153224, 153169, 153224,
|
152191, 151734, 152312, 152810, 152237, 153224, 153169, 153224,
|
||||||
152244, 153387, 153404, 151670, 198, 16069, 155811, 151669,
|
152244, 153387, 153404, 151670, 198, 16069, 155811, 151669,
|
||||||
152265, 151946, 151808, 152412, 152363, 152305, 153156, 152733,
|
152265, 151946, 151808, 152412, 152363, 152305, 153156, 152733,
|
||||||
152810, 153157, 152016, 152100, 152069, 153234, 152317, 152589,
|
152810, 153157, 152016, 152100, 152069, 153234, 152317, 152589,
|
||||||
152707, 153121, 153341, 152159, 152114, 153156, 153001, 153504,
|
152707, 153121, 153341, 152159, 152114, 153156, 153001, 153504,
|
||||||
153376, 152272, 152433, 152325, 151941, 151670, 198, 285,
|
153376, 152272, 152433, 152325, 151941, 151670, 198, 285,
|
||||||
155788, 151669, 152238, 152255, 153427, 152318, 153009, 152381,
|
155788, 151669, 152238, 152255, 153427, 152318, 153009, 152381,
|
||||||
152474, 152680, 152157, 153255, 152324, 151682, 151670, 198,
|
152474, 152680, 152157, 153255, 152324, 151682, 151670, 198,
|
||||||
32955, 155804, 151669, 153490, 153419, 152364, 152405, 152682,
|
32955, 155804, 151669, 153490, 153419, 152364, 152405, 152682,
|
||||||
152206, 152078, 153369, 152725, 153193, 153027, 152946, 152488,
|
152206, 152078, 153369, 152725, 153193, 153027, 152946, 152488,
|
||||||
153070, 151883, 152890, 152489, 153144, 153375, 152358, 151685,
|
153070, 151883, 152890, 152489, 153144, 153375, 152358, 151685,
|
||||||
152494, 152117, 152740, 151670, 198, 37448, 480, 155840, 151669,
|
152494, 152117, 152740, 151670, 198, 37448, 480, 155840, 151669,
|
||||||
151902, 152720, 153377, 152027, 152378, 152821, 153207, 153459,
|
151902, 152720, 153377, 152027, 152378, 152821, 153207, 153459,
|
||||||
153028, 153068, 152507, 153255, 152158, 152921, 151958, 152609,
|
153028, 153068, 152507, 153255, 152158, 152921, 151958, 152609,
|
||||||
152748, 152822, 152286, 151714, 152730, 152377, 152353, 152470,
|
152748, 152822, 152286, 151714, 152730, 152377, 152353, 152470,
|
||||||
152606, 152162, 152186, 153071, 152244, 153118, 153375, 153018,
|
152606, 152162, 152186, 153071, 152244, 153118, 153375, 153018,
|
||||||
152712, 153098, 152976, 152336, 151843, 153202, 152297, 151736,
|
152712, 153098, 152976, 152336, 151843, 153202, 152297, 151736,
|
||||||
153380, 153502, 152702, 152115, 153181, 152735, 153277, 153457,
|
153380, 153502, 152702, 152115, 153181, 152735, 153277, 153457,
|
||||||
152393, 153112, 152595, 151670, 198, 19098, 155808, 151669,
|
152393, 153112, 152595, 151670, 198, 19098, 155808, 151669,
|
||||||
152464, 153452, 152595, 153312, 151937, 151933, 153197, 152239,
|
152464, 153452, 152595, 153312, 151937, 151933, 153197, 152239,
|
||||||
153163, 152922, 153402, 152034, 152591, 153438, 152215, 151673,
|
153163, 152922, 153402, 152034, 152591, 153438, 152215, 151673,
|
||||||
152005, 151785, 152642, 151924, 153278, 151805, 151974, 153482,
|
152005, 151785, 152642, 151924, 153278, 151805, 151974, 153482,
|
||||||
152718, 152862, 153347, 151670, 198, 72, 155780, 151669, 151795,
|
152718, 152862, 153347, 151670, 198, 72, 155780, 151669, 151795,
|
||||||
152111, 152746, 152377, 153471, 152309, 151670, 198, 19016,
|
152111, 152746, 152377, 153471, 152309, 151670, 198, 19016,
|
||||||
155788, 151669, 153181, 152271, 152190, 152842, 152224, 152701,
|
155788, 151669, 153181, 152271, 152190, 152842, 152224, 152701,
|
||||||
152939, 152536, 152091, 151815, 152733, 151672, 151670, 198,
|
152939, 152536, 152091, 151815, 152733, 151672, 151670, 198,
|
||||||
14689, 155788, 151669, 152291, 152072, 152942, 151734, 153042,
|
14689, 155788, 151669, 152291, 152072, 152942, 151734, 153042,
|
||||||
153504, 152589, 153333, 151839, 151941, 153038, 153180, 151670,
|
153504, 152589, 153333, 151839, 151941, 153038, 153180, 151670,
|
||||||
198, 36996, 8303, 155832, 151669, 152231, 152256, 152835,
|
198, 36996, 8303, 155832, 151669, 152231, 152256, 152835,
|
||||||
152801, 152985, 153400, 152393, 152818, 152765, 152249, 152600,
|
152801, 152985, 153400, 152393, 152818, 152765, 152249, 152600,
|
||||||
151699, 152302, 152752, 153018, 153009, 151992, 153054, 152847,
|
151699, 152302, 152752, 153018, 153009, 151992, 153054, 152847,
|
||||||
153354, 153228, 152662, 153355, 152532, 153393, 151782, 152458,
|
153354, 153228, 152662, 153355, 152532, 153393, 151782, 152458,
|
||||||
152048, 152757, 152428, 153195, 151906, 153006, 153178, 153250,
|
152048, 152757, 152428, 153195, 151906, 153006, 153178, 153250,
|
||||||
152331, 152284, 152780, 153138, 153319, 151980, 153142, 152418,
|
152331, 152284, 152780, 153138, 153319, 151980, 153142, 152418,
|
||||||
152228, 152733, 151670, 198, 9096, 155801, 151669, 151698,
|
152228, 152733, 151670, 198, 9096, 155801, 151669, 151698,
|
||||||
153321, 152217, 153039, 152935, 153400, 152122, 152531, 153106,
|
153321, 152217, 153039, 152935, 153400, 152122, 152531, 153106,
|
||||||
152169, 152892, 152957, 151851, 152427, 152826, 152451, 151851,
|
152169, 152892, 152957, 151851, 152427, 152826, 152451, 151851,
|
||||||
152901, 152885, 152594, 153446, 153080, 151670, 198, 14689,
|
152901, 152885, 152594, 153446, 153080, 151670, 198, 14689,
|
||||||
155795, 151669, 152658, 151700, 153321, 152450, 152530, 153191,
|
155795, 151669, 152658, 151700, 153321, 152450, 152530, 153191,
|
||||||
151673, 151690, 151698, 152714, 152846, 152981, 153171, 153384,
|
151673, 151690, 151698, 152714, 152846, 152981, 153171, 153384,
|
||||||
153364, 153188, 153246, 151670, 198, 1055, 155779, 151669,
|
153364, 153188, 153246, 151670, 198, 1055, 155779, 151669,
|
||||||
151869, 152388, 152711, 153334, 151736, 151670, 198, 1782,
|
151869, 152388, 152711, 153334, 151736, 151670, 198, 1782,
|
||||||
155780, 151669, 153483, 153240, 152241, 152558, 152697, 153046,
|
155780, 151669, 153483, 153240, 152241, 152558, 152697, 153046,
|
||||||
151670, 198, 5804, 1363, 155820, 151669, 152941, 152764, 152605,
|
151670, 198, 5804, 1363, 155820, 151669, 152941, 152764, 152605,
|
||||||
153034, 153434, 153372, 153347, 151887, 152453, 152758, 152133,
|
153034, 153434, 153372, 153347, 151887, 152453, 152758, 152133,
|
||||||
152510, 152694, 152431, 152321, 153088, 152676, 152223, 152581,
|
152510, 152694, 152431, 152321, 153088, 152676, 152223, 152581,
|
||||||
152459, 152015, 152502, 153063, 152712, 153294, 153451, 153032,
|
152459, 152015, 152502, 153063, 152712, 153294, 153451, 153032,
|
||||||
152903, 152859, 152989, 151748, 152669, 152661, 152650, 152409,
|
152903, 152859, 152989, 151748, 152669, 152661, 152650, 152409,
|
||||||
151861, 151670, 198, 300, 7973, 155828, 151669, 153095, 152469,
|
151861, 151670, 198, 300, 7973, 155828, 151669, 153095, 152469,
|
||||||
152988, 152894, 151819, 152391, 153019, 152058, 153062, 153230,
|
152988, 152894, 151819, 152391, 153019, 152058, 153062, 153230,
|
||||||
151826, 152112, 152306, 152264, 152769, 153390, 152384, 152435,
|
151826, 152112, 152306, 152264, 152769, 153390, 152384, 152435,
|
||||||
152790, 153393, 152983, 152540, 152252, 152034, 153107, 152540,
|
152790, 153393, 152983, 152540, 152252, 152034, 153107, 152540,
|
||||||
151919, 151893, 152558, 152817, 152946, 152956, 152129, 152715,
|
151919, 151893, 152558, 152817, 152946, 152956, 152129, 152715,
|
||||||
153131, 153490, 151734, 152271, 152707, 151734, 153321, 152450,
|
153131, 153490, 151734, 152271, 152707, 151734, 153321, 152450,
|
||||||
151670, 198, 8088, 155792, 151669, 152452, 153497, 153353,
|
151670, 198, 8088, 155792, 151669, 152452, 153497, 153353,
|
||||||
152679, 152533, 152382, 152374, 152611, 153341, 153163, 152285,
|
152679, 152533, 152382, 152374, 152611, 153341, 153163, 152285,
|
||||||
153411, 152495, 153141, 152320, 151670, 198, 1199, 155781,
|
153411, 152495, 153141, 152320, 151670, 198, 1199, 155781,
|
||||||
151669, 151764, 152360, 153295, 152634, 153342, 152199, 152271,
|
151669, 151764, 152360, 153295, 152634, 153342, 152199, 152271,
|
||||||
151670, 198, 43366, 155799, 151669, 152308, 151682, 152889,
|
151670, 198, 43366, 155799, 151669, 152308, 151682, 152889,
|
||||||
152016, 152385, 152629, 152495, 151826, 153321, 152958, 152180,
|
152016, 152385, 152629, 152495, 151826, 153321, 152958, 152180,
|
||||||
151886, 153432, 152922, 152128, 153024, 153040, 152593, 152287,
|
151886, 153432, 152922, 152128, 153024, 153040, 152593, 152287,
|
||||||
151677, 151670, 198, 53660, 155808, 151669, 151727, 152092,
|
151677, 151670, 198, 53660, 155808, 151669, 151727, 152092,
|
||||||
152680, 153331, 151699, 152316, 152938, 152289, 152433, 153384,
|
152680, 153331, 151699, 152316, 152938, 152289, 152433, 153384,
|
||||||
151781, 153137, 153259, 152175, 153213, 152291, 151869, 152691,
|
151781, 153137, 153259, 152175, 153213, 152291, 151869, 152691,
|
||||||
152489, 151941, 152049, 152034, 153053, 152179, 153160, 151676,
|
152489, 151941, 152049, 152034, 153053, 152179, 153160, 151676,
|
||||||
153367, 151670, 198, 268, 4123, 480, 155821, 151669, 152350,
|
153367, 151670, 198, 268, 4123, 480, 155821, 151669, 152350,
|
||||||
152173, 152536, 151991, 151960, 153144, 153013, 152358, 152234,
|
152173, 152536, 151991, 151960, 153144, 153013, 152358, 152234,
|
||||||
153135, 152291, 153235, 152143, 152583, 152402, 153483, 152678,
|
153135, 152291, 153235, 152143, 152583, 152402, 153483, 152678,
|
||||||
152192, 152533, 152946, 151797, 153103, 152310, 152293, 151825,
|
152192, 152533, 152946, 151797, 153103, 152310, 152293, 151825,
|
||||||
152548, 153442, 152109, 152659, 153325, 152781, 152570, 152957,
|
152548, 153442, 152109, 152659, 153325, 152781, 152570, 152957,
|
||||||
151752, 152265, 153381, 152515, 151670, 198, 437, 155787,
|
151752, 152265, 153381, 152515, 151670, 198, 437, 155787,
|
||||||
151669, 152957, 152659, 151975, 152709, 152402, 152836, 152174,
|
151669, 152957, 152659, 151975, 152709, 152402, 152836, 152174,
|
||||||
151792, 153409, 153327, 152990, 151670, 198, 275, 155781,
|
151792, 153409, 153327, 152990, 151670, 198, 275, 155781,
|
||||||
151669, 152520, 153038, 152067, 153273, 153185, 152265, 152974,
|
151669, 152520, 153038, 152067, 153273, 153185, 152265, 152974,
|
||||||
151670, 198, 94273, 155799, 151669, 152953, 152938, 153427,
|
151670, 198, 94273, 155799, 151669, 152953, 152938, 153427,
|
||||||
152244, 151920, 153423, 152929, 152367, 153052, 152129, 152331,
|
152244, 151920, 153423, 152929, 152367, 153052, 152129, 152331,
|
||||||
152257, 152987, 152777, 153448, 152408, 151696, 152408, 152326,
|
152257, 152987, 152777, 153448, 152408, 151696, 152408, 152326,
|
||||||
152699, 151670, 198, 385, 16239, 155828, 151669, 152306, 152268,
|
152699, 151670, 198, 385, 16239, 155828, 151669, 152306, 152268,
|
||||||
153438, 153228, 152978, 152957, 153153, 153393, 152795, 152110,
|
153438, 153228, 152978, 152957, 153153, 153393, 152795, 152110,
|
||||||
152918, 152923, 152467, 152331, 153053, 153330, 151889, 153444,
|
152918, 152923, 152467, 152331, 153053, 153330, 151889, 153444,
|
||||||
152234, 152624, 151779, 152801, 152784, 152139, 152222, 152751,
|
152234, 152624, 151779, 152801, 152784, 152139, 152222, 152751,
|
||||||
152512, 153287, 153141, 153052, 151840, 152589, 152508, 153499,
|
152512, 153287, 153141, 153052, 151840, 152589, 152508, 153499,
|
||||||
152109, 152255, 151739, 152267, 152759, 153318, 153165, 153349,
|
152109, 152255, 151739, 152267, 152759, 153318, 153165, 153349,
|
||||||
151670,});
|
151670,});
|
||||||
#endif
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
// print the prompt token-by-token
|
// print the prompt token-by-token
|
||||||
|
|
||||||
|
|
|
@ -2,10 +2,8 @@
|
||||||
#include "ggml-backend.h"
|
#include "ggml-backend.h"
|
||||||
#include "ggml-impl.h"
|
#include "ggml-impl.h"
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <codecvt>
|
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
#include <locale>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
@ -72,14 +70,15 @@
|
||||||
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
static std::wstring utf8_to_utf16(const std::string & str) {
|
namespace fs = std::filesystem;
|
||||||
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
|
|
||||||
return converter.from_bytes(str);
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::string utf16_to_utf8(const std::wstring & str) {
|
static std::string path_str(const fs::path & path) {
|
||||||
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
|
std::string u8path;
|
||||||
return converter.to_bytes(str);
|
try {
|
||||||
|
u8path = path.u8string();
|
||||||
|
} catch (...) {
|
||||||
|
}
|
||||||
|
return u8path;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(__clang__)
|
#if defined(__clang__)
|
||||||
|
@ -96,12 +95,12 @@ struct dl_handle_deleter {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static dl_handle * dl_load_library(const std::wstring & path) {
|
static dl_handle * dl_load_library(const fs::path & path) {
|
||||||
// suppress error dialogs for missing DLLs
|
// suppress error dialogs for missing DLLs
|
||||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||||
|
|
||||||
HMODULE handle = LoadLibraryW(path.c_str());
|
HMODULE handle = LoadLibraryW(path.wstring().c_str());
|
||||||
|
|
||||||
SetErrorMode(old_mode);
|
SetErrorMode(old_mode);
|
||||||
|
|
||||||
|
@ -129,8 +128,8 @@ struct dl_handle_deleter {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static void * dl_load_library(const std::wstring & path) {
|
static void * dl_load_library(const fs::path & path) {
|
||||||
dl_handle * handle = dlopen(utf16_to_utf8(path).c_str(), RTLD_NOW | RTLD_LOCAL);
|
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
|
||||||
|
|
||||||
return handle;
|
return handle;
|
||||||
}
|
}
|
||||||
|
@ -217,11 +216,11 @@ struct ggml_backend_registry {
|
||||||
devices.push_back(device);
|
devices.push_back(device);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) {
|
ggml_backend_reg_t load_backend(const fs::path & path, bool silent) {
|
||||||
dl_handle_ptr handle { dl_load_library(path) };
|
dl_handle_ptr handle { dl_load_library(path) };
|
||||||
if (!handle) {
|
if (!handle) {
|
||||||
if (!silent) {
|
if (!silent) {
|
||||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(path).c_str());
|
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(path).c_str());
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -229,7 +228,7 @@ struct ggml_backend_registry {
|
||||||
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||||
if (score_fn && score_fn() == 0) {
|
if (score_fn && score_fn() == 0) {
|
||||||
if (!silent) {
|
if (!silent) {
|
||||||
GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, utf16_to_utf8(path).c_str());
|
GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path_str(path).c_str());
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -237,7 +236,7 @@ struct ggml_backend_registry {
|
||||||
auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init");
|
auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init");
|
||||||
if (!backend_init_fn) {
|
if (!backend_init_fn) {
|
||||||
if (!silent) {
|
if (!silent) {
|
||||||
GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, utf16_to_utf8(path).c_str());
|
GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path_str(path).c_str());
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -246,16 +245,17 @@ struct ggml_backend_registry {
|
||||||
if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) {
|
if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) {
|
||||||
if (!silent) {
|
if (!silent) {
|
||||||
if (!reg) {
|
if (!reg) {
|
||||||
GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, utf16_to_utf8(path).c_str());
|
GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n",
|
||||||
|
__func__, path_str(path).c_str());
|
||||||
} else {
|
} else {
|
||||||
GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n",
|
GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n",
|
||||||
__func__, utf16_to_utf8(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
|
__func__, path_str(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str());
|
GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_str(path).c_str());
|
||||||
|
|
||||||
register_backend(reg, std::move(handle));
|
register_backend(reg, std::move(handle));
|
||||||
|
|
||||||
|
@ -391,14 +391,14 @@ ggml_backend_t ggml_backend_init_best(void) {
|
||||||
|
|
||||||
// Dynamic loading
|
// Dynamic loading
|
||||||
ggml_backend_reg_t ggml_backend_load(const char * path) {
|
ggml_backend_reg_t ggml_backend_load(const char * path) {
|
||||||
return get_reg().load_backend(utf8_to_utf16(path), false);
|
return get_reg().load_backend(path, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_backend_unload(ggml_backend_reg_t reg) {
|
void ggml_backend_unload(ggml_backend_reg_t reg) {
|
||||||
get_reg().unload_backend(reg, true);
|
get_reg().unload_backend(reg, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::wstring get_executable_path() {
|
static fs::path get_executable_path() {
|
||||||
#if defined(__APPLE__)
|
#if defined(__APPLE__)
|
||||||
// get executable path
|
// get executable path
|
||||||
std::vector<char> path;
|
std::vector<char> path;
|
||||||
|
@ -416,7 +416,7 @@ static std::wstring get_executable_path() {
|
||||||
if (last_slash != std::string::npos) {
|
if (last_slash != std::string::npos) {
|
||||||
base_path = base_path.substr(0, last_slash);
|
base_path = base_path.substr(0, last_slash);
|
||||||
}
|
}
|
||||||
return utf8_to_utf16(base_path + "/");
|
return base_path + "/";
|
||||||
#elif defined(__linux__) || defined(__FreeBSD__)
|
#elif defined(__linux__) || defined(__FreeBSD__)
|
||||||
std::string base_path = ".";
|
std::string base_path = ".";
|
||||||
std::vector<char> path(1024);
|
std::vector<char> path(1024);
|
||||||
|
@ -442,7 +442,7 @@ static std::wstring get_executable_path() {
|
||||||
path.resize(path.size() * 2);
|
path.resize(path.size() * 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
return utf8_to_utf16(base_path + "/");
|
return base_path + "/";
|
||||||
#elif defined(_WIN32)
|
#elif defined(_WIN32)
|
||||||
std::vector<wchar_t> path(MAX_PATH);
|
std::vector<wchar_t> path(MAX_PATH);
|
||||||
DWORD len = GetModuleFileNameW(NULL, path.data(), path.size());
|
DWORD len = GetModuleFileNameW(NULL, path.data(), path.size());
|
||||||
|
@ -462,74 +462,69 @@ static std::wstring get_executable_path() {
|
||||||
return L""; //fix for freebsd compile
|
return L""; //fix for freebsd compile
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::wstring backend_filename_prefix() {
|
static fs::path backend_filename_prefix() {
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
return L"ggml-";
|
return fs::u8path("ggml-");
|
||||||
#else
|
#else
|
||||||
return L"libggml-";
|
return fs::u8path("libggml-");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::wstring backend_filename_suffix() {
|
static fs::path backend_filename_extension() {
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
return L".dll";
|
return fs::u8path(".dll");
|
||||||
#else
|
#else
|
||||||
return L".so";
|
return fs::u8path(".so");
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::wstring path_separator() {
|
|
||||||
#ifdef _WIN32
|
|
||||||
return L"\\";
|
|
||||||
#else
|
|
||||||
return L"/";
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, const char * user_search_path) {
|
static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, const char * user_search_path) {
|
||||||
// enumerate all the files that match [lib]ggml-name-*.[so|dll] in the search paths
|
// enumerate all the files that match [lib]ggml-name-*.[so|dll] in the search paths
|
||||||
// TODO: search system paths
|
const fs::path name_path = fs::u8path(name);
|
||||||
std::wstring file_prefix = backend_filename_prefix() + utf8_to_utf16(name) + L"-";
|
const fs::path file_prefix = backend_filename_prefix().native() + name_path.native() + fs::u8path("-").native();
|
||||||
std::vector<std::wstring> search_paths;
|
const fs::path file_extension = backend_filename_extension();
|
||||||
|
|
||||||
|
std::vector<fs::path> search_paths;
|
||||||
if (user_search_path == nullptr) {
|
if (user_search_path == nullptr) {
|
||||||
search_paths.push_back(L"." + path_separator());
|
// default search paths: executable directory, current directory
|
||||||
search_paths.push_back(get_executable_path());
|
search_paths.push_back(get_executable_path());
|
||||||
|
search_paths.push_back(fs::current_path());
|
||||||
} else {
|
} else {
|
||||||
search_paths.push_back(utf8_to_utf16(user_search_path) + path_separator());
|
search_paths.push_back(user_search_path);
|
||||||
}
|
}
|
||||||
|
|
||||||
int best_score = 0;
|
int best_score = 0;
|
||||||
std::wstring best_path;
|
fs::path best_path;
|
||||||
|
|
||||||
namespace fs = std::filesystem;
|
|
||||||
for (const auto & search_path : search_paths) {
|
for (const auto & search_path : search_paths) {
|
||||||
if (!fs::exists(search_path)) {
|
if (!fs::exists(search_path)) {
|
||||||
|
GGML_LOG_DEBUG("%s: search path %s does not exist\n", __func__, path_str(search_path).c_str());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
|
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
|
||||||
for (const auto & entry : dir_it) {
|
for (const auto & entry : dir_it) {
|
||||||
if (entry.is_regular_file()) {
|
if (entry.is_regular_file()) {
|
||||||
std::wstring filename = entry.path().filename().wstring();
|
auto filename = entry.path().filename().native();
|
||||||
std::wstring ext = entry.path().extension().wstring();
|
auto ext = entry.path().extension().native();
|
||||||
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
if (filename.find(file_prefix) == 0 && ext == file_extension) {
|
||||||
dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
|
dl_handle_ptr handle { dl_load_library(entry) };
|
||||||
if (!handle && !silent) {
|
if (!handle && !silent) {
|
||||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(entry.path()).c_str());
|
||||||
}
|
}
|
||||||
if (handle) {
|
if (handle) {
|
||||||
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||||
if (score_fn) {
|
if (score_fn) {
|
||||||
int s = score_fn();
|
int s = score_fn();
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
|
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_str(entry.path()).c_str(), s);
|
||||||
#endif
|
#endif
|
||||||
if (s > best_score) {
|
if (s > best_score) {
|
||||||
best_score = s;
|
best_score = s;
|
||||||
best_path = entry.path().wstring();
|
best_path = entry.path();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (!silent) {
|
if (!silent) {
|
||||||
GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, path_str(entry.path()).c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -541,7 +536,8 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||||
if (best_score == 0) {
|
if (best_score == 0) {
|
||||||
// try to load the base backend
|
// try to load the base backend
|
||||||
for (const auto & search_path : search_paths) {
|
for (const auto & search_path : search_paths) {
|
||||||
std::wstring path = search_path + backend_filename_prefix() + utf8_to_utf16(name) + backend_filename_suffix();
|
fs::path filename = backend_filename_prefix().native() + name_path.native() + backend_filename_extension().native();
|
||||||
|
fs::path path = search_path.native() + filename.native();
|
||||||
if (fs::exists(path)) {
|
if (fs::exists(path)) {
|
||||||
return get_reg().load_backend(path, silent);
|
return get_reg().load_backend(path, silent);
|
||||||
}
|
}
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -190,10 +190,11 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
|
||||||
}
|
}
|
||||||
} // namespace ggml::cpu::kleidiai
|
} // namespace ggml::cpu::kleidiai
|
||||||
|
|
||||||
static void ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
|
GGML_API enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
|
||||||
tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
|
tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
|
||||||
|
|
||||||
GGML_UNUSED(buffer);
|
GGML_UNUSED(buffer);
|
||||||
|
return GGML_STATUS_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
|
static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
|
||||||
|
|
|
@ -294,11 +294,13 @@ static void ggml_cuda_op_bin_bcast(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {
|
const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
GGML_ASSERT(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
|
||||||
|
|
||||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||||
op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
|
op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
|
||||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||||
|
op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (half *) dst_dd, stream);
|
||||||
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
|
||||||
op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
|
op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
|
||||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
||||||
op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
|
op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
|
||||||
|
|
|
@ -1,34 +1,45 @@
|
||||||
#include "clamp.cuh"
|
#include "clamp.cuh"
|
||||||
|
|
||||||
static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
|
static __device__ __forceinline__ float op_clamp(float x, float min, float max) {
|
||||||
|
return fminf(fmaxf(x, min), max);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
static __global__ void op_clamp_kernel(const T * x, T * dst, const T min, const T max, const int k) {
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
if (i >= k) {
|
if (i >= k) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
|
dst[i] = (T)op_clamp((float)x[i], (float)min, (float)max);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) {
|
template <class T>
|
||||||
|
static void clamp_cuda(const T * x, T * dst, const T min, const T max, const int k, cudaStream_t stream) {
|
||||||
const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
|
const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
|
||||||
clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
|
op_clamp_kernel<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
const float * src0_d = (const float *)src0->data;
|
const void * src0_d = src0->data;
|
||||||
float * dst_d = (float *)dst->data;
|
void * dst_d = dst->data;
|
||||||
cudaStream_t stream = ctx.stream();
|
cudaStream_t stream = ctx.stream();
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT(src0->type == dst->type);
|
||||||
|
|
||||||
float min;
|
float min;
|
||||||
float max;
|
float max;
|
||||||
memcpy(&min, dst->op_params, sizeof(float));
|
memcpy(&min, dst->op_params, sizeof(float));
|
||||||
memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
|
memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
|
||||||
|
|
||||||
clamp_f32_cuda(src0_d, dst_d, min, max, ggml_nelements(src0), stream);
|
if (src0->type == GGML_TYPE_F16) {
|
||||||
|
clamp_cuda((const half *)src0_d, (half *)dst_d, (half)min, (half)max, ggml_nelements(src0), stream);
|
||||||
|
} else {
|
||||||
|
clamp_cuda((const float *)src0_d, (float *)dst_d, (float)min, (float)max, ggml_nelements(src0), stream);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,6 +62,7 @@
|
||||||
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
|
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
|
||||||
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
|
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
|
||||||
|
|
||||||
|
#define GGML_CUDA_CC_IS_AMD(cc) (cc >= GGML_CUDA_CC_OFFSET_AMD)
|
||||||
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
|
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
|
||||||
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
|
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
|
||||||
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
|
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
|
||||||
|
@ -196,6 +197,10 @@ typedef float2 dfloat2;
|
||||||
#define FP16_MMA_AVAILABLE
|
#define FP16_MMA_AVAILABLE
|
||||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||||
|
|
||||||
|
#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
|
||||||
|
#define FP16_MMA_AVAILABLE
|
||||||
|
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
|
||||||
|
|
||||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
||||||
#define NEW_MMA_AVAILABLE
|
#define NEW_MMA_AVAILABLE
|
||||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
||||||
|
@ -223,12 +228,18 @@ static bool fast_fp16_hardware_available(const int cc) {
|
||||||
|
|
||||||
// Any FP16 tensor core instructions are available for ggml code.
|
// Any FP16 tensor core instructions are available for ggml code.
|
||||||
static bool fp16_mma_available(const int cc) {
|
static bool fp16_mma_available(const int cc) {
|
||||||
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
|
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
|
||||||
|
return false;
|
||||||
|
#else
|
||||||
|
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
|
||||||
|
GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
|
||||||
|
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
|
||||||
}
|
}
|
||||||
|
|
||||||
// To be used for feature selection of external libraries, e.g. cuBLAS.
|
// To be used for feature selection of external libraries, e.g. cuBLAS.
|
||||||
static bool fp16_mma_hardware_available(const int cc) {
|
static bool fp16_mma_hardware_available(const int cc) {
|
||||||
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
|
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA ||
|
||||||
|
GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
|
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
|
||||||
|
|
|
@ -57,12 +57,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
|
||||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||||
|
|
||||||
const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
|
const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
|
||||||
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||||
GGML_UNUSED(Q_v);
|
GGML_UNUSED(Q_v);
|
||||||
|
|
||||||
T sum = 0.0f;
|
T sum = 0.0f;
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
|
||||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||||
|
|
||||||
const int ib = k_KQ / QI8_1;
|
const int ib = k_KQ / QI8_1;
|
||||||
|
@ -70,7 +71,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
|
||||||
const int shift = k_KQ & (QI8_1/2);
|
const int shift = k_KQ & (QI8_1/2);
|
||||||
|
|
||||||
const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
|
const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
|
||||||
const int u = Q_q8[k_KQ_0/WARP_SIZE];
|
const int u = Q_q8[k_KQ_0/warp_size];
|
||||||
|
|
||||||
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
||||||
|
|
||||||
|
@ -78,14 +79,14 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
|
||||||
if (std::is_same<T, half>::value) {
|
if (std::is_same<T, half>::value) {
|
||||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||||
|
|
||||||
const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
|
const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size];
|
||||||
sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
|
sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
|
||||||
} else
|
} else
|
||||||
#endif // FP16_AVAILABLE
|
#endif // FP16_AVAILABLE
|
||||||
{
|
{
|
||||||
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
||||||
|
|
||||||
sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
|
sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -97,12 +98,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
|
||||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||||
|
|
||||||
const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
|
const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
|
||||||
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||||
GGML_UNUSED(Q_v);
|
GGML_UNUSED(Q_v);
|
||||||
|
|
||||||
T sum = 0.0f;
|
T sum = 0.0f;
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
|
||||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||||
|
|
||||||
const int ib = k_KQ / QI8_1;
|
const int ib = k_KQ / QI8_1;
|
||||||
|
@ -110,7 +112,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
|
||||||
const int shift = k_KQ & (QI8_1/2);
|
const int shift = k_KQ & (QI8_1/2);
|
||||||
|
|
||||||
const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
|
const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
|
||||||
const int u = Q_q8[k_KQ_0/WARP_SIZE];
|
const int u = Q_q8[k_KQ_0/warp_size];
|
||||||
|
|
||||||
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
||||||
|
|
||||||
|
@ -118,7 +120,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
|
||||||
if (std::is_same<T, half>::value) {
|
if (std::is_same<T, half>::value) {
|
||||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||||
|
|
||||||
const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
|
const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size];
|
||||||
const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
|
const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
|
||||||
sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
|
sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
|
||||||
} else
|
} else
|
||||||
|
@ -126,8 +128,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
|
||||||
{
|
{
|
||||||
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
||||||
|
|
||||||
const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
|
const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
|
||||||
const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
|
const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
|
||||||
|
|
||||||
sum += (T) (sumid4d8 + m4s8scaled);
|
sum += (T) (sumid4d8 + m4s8scaled);
|
||||||
}
|
}
|
||||||
|
@ -141,12 +143,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
|
||||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||||
|
|
||||||
const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
|
const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
|
||||||
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||||
GGML_UNUSED(Q_v);
|
GGML_UNUSED(Q_v);
|
||||||
|
|
||||||
T sum = 0.0f;
|
T sum = 0.0f;
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
|
||||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||||
|
|
||||||
const int ib = k_KQ / QI8_1;
|
const int ib = k_KQ / QI8_1;
|
||||||
|
@ -161,7 +164,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
|
||||||
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
||||||
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
||||||
|
|
||||||
const int u = Q_q8[k_KQ_0/WARP_SIZE];
|
const int u = Q_q8[k_KQ_0/warp_size];
|
||||||
|
|
||||||
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
||||||
|
|
||||||
|
@ -169,14 +172,14 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
|
||||||
if (std::is_same<T, half>::value) {
|
if (std::is_same<T, half>::value) {
|
||||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||||
|
|
||||||
const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
|
const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size];
|
||||||
sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
|
sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
|
||||||
} else
|
} else
|
||||||
#endif // FP16_AVAILABLE
|
#endif // FP16_AVAILABLE
|
||||||
{
|
{
|
||||||
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
||||||
|
|
||||||
sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
|
sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -188,12 +191,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
|
||||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||||
|
|
||||||
const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
|
const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
|
||||||
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||||
GGML_UNUSED(Q_v);
|
GGML_UNUSED(Q_v);
|
||||||
|
|
||||||
T sum = 0.0f;
|
T sum = 0.0f;
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
|
||||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||||
|
|
||||||
const int ib = k_KQ / QI8_1;
|
const int ib = k_KQ / QI8_1;
|
||||||
|
@ -208,7 +212,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
|
||||||
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
||||||
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
||||||
|
|
||||||
const int u = Q_q8[k_KQ_0/WARP_SIZE];
|
const int u = Q_q8[k_KQ_0/warp_size];
|
||||||
|
|
||||||
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
||||||
|
|
||||||
|
@ -216,7 +220,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
|
||||||
if (std::is_same<T, half>::value) {
|
if (std::is_same<T, half>::value) {
|
||||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||||
|
|
||||||
const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
|
const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size];
|
||||||
const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
|
const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
|
||||||
sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
|
sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
|
||||||
} else
|
} else
|
||||||
|
@ -224,8 +228,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
|
||||||
{
|
{
|
||||||
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
||||||
|
|
||||||
const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
|
const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
|
||||||
const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
|
const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
|
||||||
|
|
||||||
sum += (T) (sumid5d8 + m5s8scaled);
|
sum += (T) (sumid5d8 + m5s8scaled);
|
||||||
}
|
}
|
||||||
|
@ -239,12 +243,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
|
||||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||||
|
|
||||||
const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
|
const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
|
||||||
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||||
GGML_UNUSED(Q_v);
|
GGML_UNUSED(Q_v);
|
||||||
|
|
||||||
T sum = 0.0f;
|
T sum = 0.0f;
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
|
||||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||||
|
|
||||||
const int ib = k_KQ / QI8_0;
|
const int ib = k_KQ / QI8_0;
|
||||||
|
@ -255,13 +260,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
|
||||||
T Q_d;
|
T Q_d;
|
||||||
if (std::is_same<T, half>::value) {
|
if (std::is_same<T, half>::value) {
|
||||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||||
Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]);
|
Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
|
||||||
} else {
|
} else {
|
||||||
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
||||||
Q_d = Q_ds[k_KQ_0/WARP_SIZE].x;
|
Q_d = Q_ds[k_KQ_0/warp_size].x;
|
||||||
}
|
}
|
||||||
|
|
||||||
sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d);
|
sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
|
||||||
}
|
}
|
||||||
|
|
||||||
return sum;
|
return sum;
|
||||||
|
@ -272,6 +277,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
|
||||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
|
||||||
|
|
||||||
const half2 * K_h2 = (const half2 *) K_c;
|
const half2 * K_h2 = (const half2 *) K_c;
|
||||||
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||||
GGML_UNUSED(Q_q8);
|
GGML_UNUSED(Q_q8);
|
||||||
GGML_UNUSED(Q_ds_v);
|
GGML_UNUSED(Q_ds_v);
|
||||||
|
|
||||||
|
@ -282,11 +288,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
|
||||||
half2 sum2 = make_half2(0.0f, 0.0f);
|
half2 sum2 = make_half2(0.0f, 0.0f);
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
|
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
|
||||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||||
|
|
||||||
const half2 K_ik = K_h2[k_KQ];
|
const half2 K_ik = K_h2[k_KQ];
|
||||||
sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
|
sum2 += K_ik * Q_h2[k_KQ_0/warp_size];
|
||||||
}
|
}
|
||||||
|
|
||||||
return __low2half(sum2) + __high2half(sum2);
|
return __low2half(sum2) + __high2half(sum2);
|
||||||
|
@ -298,12 +304,12 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
|
||||||
float sum = 0.0f;
|
float sum = 0.0f;
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
|
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
|
||||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||||
|
|
||||||
const half2 K_ik = K_h2[k_KQ];
|
const half2 K_ik = K_h2[k_KQ];
|
||||||
sum += __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x;
|
sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x;
|
||||||
sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y;
|
sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y;
|
||||||
}
|
}
|
||||||
|
|
||||||
return sum;
|
return sum;
|
||||||
|
@ -698,6 +704,8 @@ void launch_fattn(
|
||||||
|
|
||||||
GGML_ASSERT(Q->ne[3] == 1);
|
GGML_ASSERT(Q->ne[3] == 1);
|
||||||
|
|
||||||
|
const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
|
||||||
|
|
||||||
ggml_cuda_pool & pool = ctx.pool();
|
ggml_cuda_pool & pool = ctx.pool();
|
||||||
cudaStream_t main_stream = ctx.stream();
|
cudaStream_t main_stream = ctx.stream();
|
||||||
const int id = ggml_cuda_get_device();
|
const int id = ggml_cuda_get_device();
|
||||||
|
@ -750,7 +758,7 @@ void launch_fattn(
|
||||||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||||
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
||||||
|
|
||||||
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
const dim3 block_dim(warp_size, nwarps, 1);
|
||||||
dim3 blocks_num;
|
dim3 blocks_num;
|
||||||
if (parallel_blocks == 0) {
|
if (parallel_blocks == 0) {
|
||||||
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
|
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
|
||||||
|
@ -796,6 +804,8 @@ void launch_fattn(
|
||||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||||
|
|
||||||
|
GGML_ASSERT(block_dim.x % warp_size == 0);
|
||||||
|
GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
|
||||||
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
|
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
|
||||||
(const char *) Q->data,
|
(const char *) Q->data,
|
||||||
K_data,
|
K_data,
|
||||||
|
|
|
@ -7,14 +7,19 @@
|
||||||
#include "fattn-wmma-f16.cuh"
|
#include "fattn-wmma-f16.cuh"
|
||||||
|
|
||||||
#ifdef FP16_MMA_AVAILABLE
|
#ifdef FP16_MMA_AVAILABLE
|
||||||
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||||
#include <mma.h>
|
#include <mma.h>
|
||||||
|
namespace wmma = nvcuda::wmma;
|
||||||
|
#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
|
||||||
|
#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
|
||||||
|
#include <rocwmma/rocwmma.hpp>
|
||||||
|
namespace wmma = rocwmma;
|
||||||
|
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||||
#endif // FP16_MMA_AVAILABLE
|
#endif // FP16_MMA_AVAILABLE
|
||||||
|
|
||||||
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
||||||
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
|
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
|
||||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
|
||||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
|
||||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
||||||
static __global__ void flash_attn_ext_f16(
|
static __global__ void flash_attn_ext_f16(
|
||||||
const char * __restrict__ Q,
|
const char * __restrict__ Q,
|
||||||
const char * __restrict__ K,
|
const char * __restrict__ K,
|
||||||
|
@ -51,7 +56,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
const int ne1,
|
const int ne1,
|
||||||
const int ne2,
|
const int ne2,
|
||||||
const int ne3) {
|
const int ne3) {
|
||||||
#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
|
||||||
// Skip unused kernel variants for faster compilation:
|
// Skip unused kernel variants for faster compilation:
|
||||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
|
@ -60,6 +65,8 @@ static __global__ void flash_attn_ext_f16(
|
||||||
|
|
||||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||||
|
|
||||||
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||||
|
|
||||||
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
|
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
|
||||||
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
||||||
|
|
||||||
|
@ -68,11 +75,11 @@ static __global__ void flash_attn_ext_f16(
|
||||||
constexpr int frag_m = ncols == 8 ? 32 : 16;
|
constexpr int frag_m = ncols == 8 ? 32 : 16;
|
||||||
constexpr int frag_n = ncols == 8 ? 8 : 16;
|
constexpr int frag_n = ncols == 8 ? 8 : 16;
|
||||||
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
|
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
|
||||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
|
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
|
||||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
|
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
|
||||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
|
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
|
||||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
|
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
|
||||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
|
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
|
||||||
|
|
||||||
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
|
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
|
||||||
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
|
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
|
||||||
|
@ -132,9 +139,9 @@ static __global__ void flash_attn_ext_f16(
|
||||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||||
const int j = j0 + threadIdx.y;
|
const int j = j0 + threadIdx.y;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||||
const int i = i0 + threadIdx.x;
|
const int i = i0 + threadIdx.x;
|
||||||
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
|
if (i0 + warp_size > D/2 && i >= D/2) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
|
VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
|
||||||
|
@ -146,9 +153,9 @@ static __global__ void flash_attn_ext_f16(
|
||||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||||
const int j = j0 + threadIdx.y;
|
const int j = j0 + threadIdx.y;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
|
for (int i0 = 0; i0 < D; i0 += warp_size) {
|
||||||
const int i = i0 + threadIdx.x;
|
const int i = i0 + threadIdx.x;
|
||||||
if (i0 + WARP_SIZE > D && i >= D) {
|
if (i0 + warp_size > D && i >= D) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
|
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
|
||||||
|
@ -162,7 +169,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
for (int i0 = 0; i0 < D; i0 += 16) {
|
for (int i0 = 0; i0 < D; i0 += 16) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||||
nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
|
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -176,20 +183,20 @@ static __global__ void flash_attn_ext_f16(
|
||||||
frag_c_KQ KQ_c[ncols/frag_n];
|
frag_c_KQ KQ_c[ncols/frag_n];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||||
nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
|
wmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
|
||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
|
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
|
||||||
frag_a_K K_a;
|
frag_a_K K_a;
|
||||||
nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||||
nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
|
wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||||
nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
|
wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -202,27 +209,27 @@ static __global__ void flash_attn_ext_f16(
|
||||||
const int j = j0 + threadIdx.y;
|
const int j = j0 + threadIdx.y;
|
||||||
|
|
||||||
if (std::is_same<KQ_acc_t, float>::value) {
|
if (std::is_same<KQ_acc_t, float>::value) {
|
||||||
float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
|
float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
|
||||||
const int k = k0 + threadIdx.x;
|
const int k = k0 + threadIdx.x;
|
||||||
|
|
||||||
KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
|
KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];
|
||||||
|
|
||||||
if (use_logit_softcap) {
|
if (use_logit_softcap) {
|
||||||
KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
|
KQ_f_tmp[k0/warp_size] = logit_softcap*tanhf(KQ_f_tmp[k0/warp_size]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
float KQ_max_new = KQ_max_f[j0/nwarps];
|
float KQ_max_new = KQ_max_f[j0/nwarps];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
|
||||||
const int k = k0 + threadIdx.x;
|
const int k = k0 + threadIdx.x;
|
||||||
|
|
||||||
KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
|
KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
|
||||||
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
|
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
|
||||||
}
|
}
|
||||||
KQ_max_new = warp_reduce_max(KQ_max_new);
|
KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
|
||||||
|
|
||||||
const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
|
const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
|
||||||
KQ_max_scale_f[j0/nwarps] = expf(diff);
|
KQ_max_scale_f[j0/nwarps] = expf(diff);
|
||||||
|
@ -233,48 +240,48 @@ static __global__ void flash_attn_ext_f16(
|
||||||
|
|
||||||
float KQ_rowsum_add = 0.0f;
|
float KQ_rowsum_add = 0.0f;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
|
||||||
const int k = k0 + threadIdx.x;
|
const int k = k0 + threadIdx.x;
|
||||||
|
|
||||||
const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
|
const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
|
||||||
KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
|
KQ_f_tmp[k0/warp_size] = expf(diff);
|
||||||
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
|
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
|
||||||
KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
|
KQ_f_tmp[k0/warp_size] = 0.0f;
|
||||||
}
|
}
|
||||||
KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
|
KQ_rowsum_add += KQ_f_tmp[k0/warp_size];
|
||||||
KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
|
KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/warp_size];
|
||||||
}
|
}
|
||||||
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
|
KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
|
||||||
|
|
||||||
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
||||||
KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
|
KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
|
||||||
} else {
|
} else {
|
||||||
half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
|
half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
|
||||||
const int k = k0 + threadIdx.x;
|
const int k = k0 + threadIdx.x;
|
||||||
|
|
||||||
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
|
KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];
|
||||||
|
|
||||||
if (use_logit_softcap) {
|
if (use_logit_softcap) {
|
||||||
// There is no dedicated tangens hyperbolicus function for half2.
|
// There is no dedicated tangens hyperbolicus function for half2.
|
||||||
KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
|
KQ2_tmp[k0/warp_size] = h2exp(KQ2_tmp[k0/warp_size]*make_half2(2.0f, 2.0f));
|
||||||
KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
|
KQ2_tmp[k0/warp_size] = (KQ2_tmp[k0/warp_size] - make_half2(1.0f, 1.0f))
|
||||||
/(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
|
/(KQ2_tmp[k0/warp_size] + make_half2(1.0f, 1.0f));
|
||||||
|
|
||||||
KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
|
KQ2_tmp[k0/warp_size] *= logit_softcap_2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
|
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
|
||||||
const int k = k0 + threadIdx.x;
|
const int k = k0 + threadIdx.x;
|
||||||
|
|
||||||
KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
|
KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
|
||||||
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
|
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
|
||||||
}
|
}
|
||||||
KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
|
KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
|
||||||
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
|
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
|
||||||
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
|
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
|
||||||
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
||||||
|
@ -283,17 +290,17 @@ static __global__ void flash_attn_ext_f16(
|
||||||
|
|
||||||
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
|
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
|
||||||
const int k = k0 + threadIdx.x;
|
const int k = k0 + threadIdx.x;
|
||||||
|
|
||||||
const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
|
const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
|
||||||
KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
|
KQ2_tmp[k0/warp_size] = h2exp(diff);
|
||||||
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
||||||
*((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
|
*((uint32_t *) &KQ2_tmp[k0/warp_size]) &= ftz_mask;
|
||||||
KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
|
KQ_rowsum_add += KQ2_tmp[k0/warp_size];
|
||||||
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
|
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/warp_size];
|
||||||
}
|
}
|
||||||
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
|
KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
|
||||||
|
|
||||||
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
||||||
KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
|
KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
|
||||||
|
@ -308,7 +315,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
|
||||||
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||||
nvcuda::wmma::load_matrix_sync(
|
wmma::load_matrix_sync(
|
||||||
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
|
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
|
||||||
KQ + j0*(kqar*kqs_padded) + k,
|
KQ + j0*(kqar*kqs_padded) + k,
|
||||||
kqar*kqs_padded);
|
kqar*kqs_padded);
|
||||||
|
@ -320,7 +327,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
|
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||||
nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
|
wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
@ -328,10 +335,10 @@ static __global__ void flash_attn_ext_f16(
|
||||||
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||||
|
|
||||||
frag_a_V v_a;
|
frag_a_V v_a;
|
||||||
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||||
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
|
wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -343,10 +350,10 @@ static __global__ void flash_attn_ext_f16(
|
||||||
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
|
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||||
nvcuda::wmma::store_matrix_sync(
|
wmma::store_matrix_sync(
|
||||||
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
||||||
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
|
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
|
||||||
D_padded, nvcuda::wmma::mem_col_major);
|
D_padded, wmma::mem_col_major);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -364,9 +371,9 @@ static __global__ void flash_attn_ext_f16(
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||||
const int i = i0 + threadIdx.x;
|
const int i = i0 + threadIdx.x;
|
||||||
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
|
if (i0 + warp_size > D/2 && i >= D/2) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -398,9 +405,9 @@ static __global__ void flash_attn_ext_f16(
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
|
for (int i0 = 0; i0 < D; i0 += warp_size) {
|
||||||
const int i = i0 + threadIdx.x;
|
const int i = i0 + threadIdx.x;
|
||||||
if (i0 + WARP_SIZE > D && i >= D) {
|
if (i0 + warp_size > D && i >= D) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
float dst_val = VKQ[j_VKQ*D_padded + i];
|
float dst_val = VKQ[j_VKQ*D_padded + i];
|
||||||
|
@ -425,7 +432,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr int get_max_power_of_2(int x) {
|
constexpr int get_max_power_of_2(int x) {
|
||||||
|
@ -515,6 +522,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
|
||||||
const ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * Q = dst->src[0];
|
||||||
|
|
||||||
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||||
|
const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
|
||||||
|
|
||||||
if (prec != GGML_PREC_DEFAULT) {
|
if (prec != GGML_PREC_DEFAULT) {
|
||||||
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
|
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
|
||||||
|
@ -571,7 +579,8 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||||
|
if (Q->ne[1] <= 8 && Q->ne[0] % warp_size == 0) {
|
||||||
constexpr int cols_per_block = 8;
|
constexpr int cols_per_block = 8;
|
||||||
switch (Q->ne[0]) {
|
switch (Q->ne[0]) {
|
||||||
case 64:
|
case 64:
|
||||||
|
@ -592,6 +601,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||||
|
|
||||||
if (Q->ne[1] <= 32) {
|
if (Q->ne[1] <= 32) {
|
||||||
constexpr int cols_per_block = 16;
|
constexpr int cols_per_block = 16;
|
||||||
|
|
|
@ -250,10 +250,18 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||||
|
|
||||||
ggml_cuda_set_device(ctx.device);
|
ggml_cuda_set_device(ctx.device);
|
||||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||||
|
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
|
||||||
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||||
|
|
||||||
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
|
||||||
if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
|
if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
|
||||||
|
#if defined(GGML_HIP_ROCWMMA_FATTN)
|
||||||
|
if (fp16_mma_available(cc)) {
|
||||||
|
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
|
||||||
|
|
||||||
|
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||||
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||||
} else {
|
} else {
|
||||||
|
@ -291,7 +299,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||||
const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
|
const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
|
||||||
K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
|
K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
|
||||||
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) {
|
if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0 && !mma_fast_for_bs1) {
|
||||||
if (prec == GGML_PREC_DEFAULT) {
|
if (prec == GGML_PREC_DEFAULT) {
|
||||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -2152,6 +2152,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||||
break;
|
break;
|
||||||
case GGML_OP_UNARY:
|
case GGML_OP_UNARY:
|
||||||
switch (ggml_get_unary_op(dst)) {
|
switch (ggml_get_unary_op(dst)) {
|
||||||
|
case GGML_UNARY_OP_ABS:
|
||||||
|
ggml_cuda_op_abs(ctx, dst);
|
||||||
|
break;
|
||||||
|
case GGML_UNARY_OP_SGN:
|
||||||
|
ggml_cuda_op_sgn(ctx, dst);
|
||||||
|
break;
|
||||||
case GGML_UNARY_OP_NEG:
|
case GGML_UNARY_OP_NEG:
|
||||||
ggml_cuda_op_neg(ctx, dst);
|
ggml_cuda_op_neg(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
@ -2249,6 +2255,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||||
case GGML_OP_CLAMP:
|
case GGML_OP_CLAMP:
|
||||||
ggml_cuda_op_clamp(ctx, dst);
|
ggml_cuda_op_clamp(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_LOG:
|
||||||
|
ggml_cuda_op_log(ctx, dst);
|
||||||
|
break;
|
||||||
case GGML_OP_NONE:
|
case GGML_OP_NONE:
|
||||||
case GGML_OP_RESHAPE:
|
case GGML_OP_RESHAPE:
|
||||||
case GGML_OP_VIEW:
|
case GGML_OP_VIEW:
|
||||||
|
@ -2967,6 +2976,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
switch (op->op) {
|
switch (op->op) {
|
||||||
case GGML_OP_UNARY:
|
case GGML_OP_UNARY:
|
||||||
switch (ggml_get_unary_op(op)) {
|
switch (ggml_get_unary_op(op)) {
|
||||||
|
case GGML_UNARY_OP_ABS:
|
||||||
|
case GGML_UNARY_OP_SGN:
|
||||||
case GGML_UNARY_OP_NEG:
|
case GGML_UNARY_OP_NEG:
|
||||||
case GGML_UNARY_OP_STEP:
|
case GGML_UNARY_OP_STEP:
|
||||||
case GGML_UNARY_OP_GELU:
|
case GGML_UNARY_OP_GELU:
|
||||||
|
@ -3149,7 +3160,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
return false;
|
return false;
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_SILU_BACK:
|
case GGML_OP_SILU_BACK:
|
||||||
return ggml_is_contiguous(op->src[0]);
|
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||||
break;
|
break;
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
|
@ -3173,6 +3184,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
case GGML_OP_SIN:
|
case GGML_OP_SIN:
|
||||||
case GGML_OP_COS:
|
case GGML_OP_COS:
|
||||||
case GGML_OP_CLAMP:
|
case GGML_OP_CLAMP:
|
||||||
|
case GGML_OP_LOG:
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_CONT:
|
case GGML_OP_CONT:
|
||||||
return op->src[0]->type != GGML_TYPE_BF16;
|
return op->src[0]->type != GGML_TYPE_BF16;
|
||||||
|
|
|
@ -1,305 +1,213 @@
|
||||||
#include "unary.cuh"
|
#include "unary.cuh"
|
||||||
|
|
||||||
static __global__ void neg_f32(const float * x, float * dst, const int k) {
|
static __device__ __forceinline__ float op_abs(float x) {
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
return fabsf(x);
|
||||||
|
|
||||||
if (i >= k) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
dst[i] = -x[i];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void step_f32(const float * x, float * dst, const int k) {
|
static __device__ __forceinline__ float op_sgn(float x) {
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
return (x > 0.f ? 1.f : ((x < 0.f ? -1.f : 0.f)));
|
||||||
|
|
||||||
if (i >= k) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
dst[i] = x[i] > 0.0f;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void gelu_f32(const float * x, float * dst, const int k) {
|
static __device__ __forceinline__ float op_neg(float x) {
|
||||||
|
return -x;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float op_step(float x) {
|
||||||
|
return x > 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float op_gelu(float x) {
|
||||||
const float GELU_COEF_A = 0.044715f;
|
const float GELU_COEF_A = 0.044715f;
|
||||||
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
||||||
|
|
||||||
if (i >= k) {
|
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
float xi = x[i];
|
|
||||||
dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void gelu_quick_f32(const float * x, float * dst, int k) {
|
static __device__ __forceinline__ float op_gelu_quick(float x) {
|
||||||
const float GELU_QUICK_COEF = -1.702f;
|
const float GELU_QUICK_COEF = -1.702f;
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
||||||
if (i >= k) {
|
return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x)));
|
||||||
return;
|
|
||||||
}
|
|
||||||
dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void silu_f32(const float * x, float * dst, const int k) {
|
static __device__ __forceinline__ float op_silu(float x) {
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
return x / (1.0f + expf(-x));
|
||||||
|
|
||||||
if (i >= k) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
dst[i] = x[i] / (1.0f + expf(-x[i]));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void silu_back_f32(
|
static __device__ __forceinline__ float op_tanh(float x) {
|
||||||
const float * grad, const float * xf, float * dst, const int k) {
|
return tanhf(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float op_relu(float x) {
|
||||||
|
return fmaxf(x, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float op_sigmoid(float x) {
|
||||||
|
return 1.0f / (1.0f + expf(-x));
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float op_hardsigmoid(float x) {
|
||||||
|
return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float op_hardswish(float x) {
|
||||||
|
return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float op_exp(float x) {
|
||||||
|
return expf(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float op_sqr(float x) {
|
||||||
|
return x * x;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float op_sqrt(float x) {
|
||||||
|
return sqrtf(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float op_sin(float x) {
|
||||||
|
return sinf(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float op_cos(float x) {
|
||||||
|
return cosf(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float op_log(float x) {
|
||||||
|
return logf(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <float (*op)(float), typename T>
|
||||||
|
static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
if (i >= k) {
|
if (i >= k) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const float xfi = xf[i];
|
dst[i] = (T)op((float)x[i]);
|
||||||
const float s = 1.0f / (1.0f + expf(-xfi));
|
|
||||||
dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void tanh_f32(const float * x, float * dst, int k) {
|
template <float (*op)(float), typename T>
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
static void unary_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
||||||
if (i >= k) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
dst[i] = tanhf(x[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __global__ void relu_f32(const float * x, float * dst, const int k) {
|
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
||||||
|
|
||||||
if (i >= k) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
dst[i] = fmaxf(x[i], 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __global__ void sigmoid_f32(const float * x, float * dst, const int k) {
|
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
||||||
|
|
||||||
if (i >= k) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
dst[i] = 1.0f / (1.0f + expf(-x[i]));
|
|
||||||
}
|
|
||||||
|
|
||||||
static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
|
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
||||||
|
|
||||||
if (i >= k) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
|
|
||||||
}
|
|
||||||
|
|
||||||
static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
|
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
||||||
|
|
||||||
if (i >= k) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
|
|
||||||
}
|
|
||||||
|
|
||||||
static __global__ void exp_f32(const float * x, float * dst, const int k) {
|
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
||||||
|
|
||||||
if (i >= k) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
dst[i] = expf(x[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
|
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
||||||
if (i >= k) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
|
|
||||||
}
|
|
||||||
|
|
||||||
static __global__ void sqr_f32(const float * x, float * dst, const int k) {
|
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
||||||
|
|
||||||
if (i >= k) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
dst[i] = x[i] * x[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
|
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
||||||
|
|
||||||
if (i >= k) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
dst[i] = sqrtf(x[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __global__ void sin_f32(const float * x, float * dst, const int k) {
|
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
||||||
|
|
||||||
if (i >= k) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
dst[i] = sinf(x[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __global__ void cos_f32(const float * x, float * dst, const int k) {
|
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
||||||
|
|
||||||
if (i >= k) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
dst[i] = cosf(x[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void neg_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
||||||
const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
|
const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
|
||||||
neg_f32<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
unary_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void step_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
template <float (*op)(float)>
|
||||||
const int num_blocks = (k + CUDA_STEP_BLOCK_SIZE - 1) / CUDA_STEP_BLOCK_SIZE;
|
void ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
step_f32<<<num_blocks, CUDA_STEP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
const void * src0_d = src0->data;
|
||||||
|
void * dst_d = dst->data;
|
||||||
|
cudaStream_t stream = ctx.stream();
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT(src0->type == dst->type);
|
||||||
|
|
||||||
|
if (src0->type == GGML_TYPE_F16) {
|
||||||
|
unary_cuda<op>((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
|
||||||
|
} else {
|
||||||
|
unary_cuda<op>((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
|
ggml_cuda_op_unary<op_abs>(ctx, dst);
|
||||||
gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
|
ggml_cuda_op_unary<op_sgn>(ctx, dst);
|
||||||
gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
||||||
const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
|
|
||||||
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void silu_back_f32_cuda(const float * grad, const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
||||||
const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
|
|
||||||
silu_back_f32<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
||||||
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
|
|
||||||
tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
||||||
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
|
||||||
relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void sigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
||||||
const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE;
|
|
||||||
sigmoid_f32<<<num_blocks, CUDA_SIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
||||||
const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
|
|
||||||
hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
||||||
const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
|
|
||||||
hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
||||||
const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
|
|
||||||
exp_f32<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
|
|
||||||
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
|
||||||
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
||||||
const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
|
|
||||||
sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
||||||
const int num_blocks = (k + CUDA_SQRT_BLOCK_SIZE - 1) / CUDA_SQRT_BLOCK_SIZE;
|
|
||||||
sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void sin_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
||||||
const int num_blocks = (k + CUDA_SIN_BLOCK_SIZE - 1) / CUDA_SIN_BLOCK_SIZE;
|
|
||||||
sin_f32<<<num_blocks, CUDA_SIN_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
||||||
const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
|
|
||||||
cos_f32<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
ggml_cuda_op_unary<op_neg>(ctx, dst);
|
||||||
const float * src0_d = (const float *)src0->data;
|
|
||||||
float * dst_d = (float *)dst->data;
|
|
||||||
cudaStream_t stream = ctx.stream();
|
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
||||||
|
|
||||||
neg_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
ggml_cuda_op_unary<op_step>(ctx, dst);
|
||||||
const float * src0_d = (const float *)src0->data;
|
|
||||||
float * dst_d = (float *)dst->data;
|
|
||||||
cudaStream_t stream = ctx.stream();
|
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
||||||
|
|
||||||
step_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
ggml_cuda_op_unary<op_gelu>(ctx, dst);
|
||||||
const float * src0_d = (const float *)src0->data;
|
}
|
||||||
float * dst_d = (float *)dst->data;
|
|
||||||
cudaStream_t stream = ctx.stream();
|
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_unary<op_gelu_quick>(ctx, dst);
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
||||||
|
|
||||||
gelu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
ggml_cuda_op_unary<op_silu>(ctx, dst);
|
||||||
const float * src0_d = (const float *)src0->data;
|
}
|
||||||
float * dst_d = (float *)dst->data;
|
|
||||||
cudaStream_t stream = ctx.stream();
|
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_unary<op_tanh>(ctx, dst);
|
||||||
|
}
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
ggml_cuda_op_unary<op_relu>(ctx, dst);
|
||||||
|
}
|
||||||
|
|
||||||
silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_unary<op_sigmoid>(ctx, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_unary<op_hardsigmoid>(ctx, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_unary<op_hardswish>(ctx, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_unary<op_exp>(ctx, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_unary<op_sqr>(ctx, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_unary<op_sqrt>(ctx, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_unary<op_sin>(ctx, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_unary<op_cos>(ctx, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_unary<op_log>(ctx, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* silu_back */
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float op_silu_back(float grad, float x) {
|
||||||
|
const float s = 1.0f / (1.0f + expf(-x));
|
||||||
|
return grad * s * (1.0f + x * (1.0f - s));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
static __global__ void silu_back_kernel(const T * grad, const T * xf, T * dst, const int k) {
|
||||||
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
|
if (i >= k) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dst[i] = (T)op_silu_back((float)grad[i], (float)xf[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
static void silu_back_cuda(const T * grad, const T * x, T * dst, const int k, cudaStream_t stream) {
|
||||||
|
const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
|
||||||
|
silu_back_kernel<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
@ -314,179 +222,58 @@ void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT(src0->type == dst->type);
|
||||||
|
|
||||||
silu_back_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(src0), stream);
|
if (src0->type == GGML_TYPE_F16) {
|
||||||
|
silu_back_cuda((const half *)src0_d, (const half *)src1_d, (half *)dst_d, ggml_nelements(src0), stream);
|
||||||
|
} else {
|
||||||
|
silu_back_cuda((const float*)src0_d, (const float*)src1_d, (float *)dst_d, ggml_nelements(src0), stream);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
/* leaky relu */
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
|
||||||
const float * src0_d = (const float *)src0->data;
|
|
||||||
float * dst_d = (float *)dst->data;
|
|
||||||
cudaStream_t stream = ctx.stream();
|
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
static __device__ __forceinline__ float op_leaky_relu(float x, const float negative_slope) {
|
||||||
|
return fmaxf(x, 0) + fminf(x, 0.0f) * negative_slope;
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
||||||
|
|
||||||
gelu_quick_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
template <class T>
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
static __global__ void leaky_relu_kernel(const T * x, T * dst, const int k, const float negative_slope) {
|
||||||
const float * src0_d = (const float *)src0->data;
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
float * dst_d = (float *)dst->data;
|
|
||||||
cudaStream_t stream = ctx.stream();
|
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
if (i >= k) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
dst[i] = (T)op_leaky_relu((float)x[i], negative_slope);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
||||||
|
|
||||||
tanh_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
template <class T>
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
static void leaky_relu_cuda(const T * x, T * dst, const int k, const float negative_slope, cudaStream_t stream) {
|
||||||
const float * src0_d = (const float *)src0->data;
|
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
||||||
float * dst_d = (float *)dst->data;
|
leaky_relu_kernel<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
|
||||||
cudaStream_t stream = ctx.stream();
|
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
||||||
|
|
||||||
relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
|
||||||
const float * src0_d = (const float *)src0->data;
|
|
||||||
float * dst_d = (float *)dst->data;
|
|
||||||
cudaStream_t stream = ctx.stream();
|
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
||||||
|
|
||||||
sigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
|
||||||
const float * src0_d = (const float *)src0->data;
|
|
||||||
float * dst_d = (float *)dst->data;
|
|
||||||
cudaStream_t stream = ctx.stream();
|
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
||||||
|
|
||||||
hardsigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
|
||||||
const float * src0_d = (const float *)src0->data;
|
|
||||||
float * dst_d = (float *)dst->data;
|
|
||||||
cudaStream_t stream = ctx.stream();
|
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
||||||
|
|
||||||
hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
|
||||||
const float * src0_d = (const float *)src0->data;
|
|
||||||
float * dst_d = (float *)dst->data;
|
|
||||||
cudaStream_t stream = ctx.stream();
|
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
||||||
|
|
||||||
exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
const float * src0_d = (const float *)src0->data;
|
const void * src0_d = src0->data;
|
||||||
float * dst_d = (float *)dst->data;
|
void * dst_d = dst->data;
|
||||||
cudaStream_t stream = ctx.stream();
|
cudaStream_t stream = ctx.stream();
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT(src0->type == dst->type);
|
||||||
|
|
||||||
float negative_slope;
|
float negative_slope;
|
||||||
memcpy(&negative_slope, dst->op_params, sizeof(float));
|
memcpy(&negative_slope, dst->op_params, sizeof(float));
|
||||||
|
|
||||||
leaky_relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), negative_slope, stream);
|
if (src0->type == GGML_TYPE_F16) {
|
||||||
}
|
leaky_relu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), negative_slope, stream);
|
||||||
|
} else {
|
||||||
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream);
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
}
|
||||||
const float * src0_d = (const float *)src0->data;
|
|
||||||
float * dst_d = (float *)dst->data;
|
|
||||||
cudaStream_t stream = ctx.stream();
|
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
||||||
|
|
||||||
sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
|
||||||
const float * src0_d = (const float *)src0->data;
|
|
||||||
float * dst_d = (float *)dst->data;
|
|
||||||
cudaStream_t stream = ctx.stream();
|
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
||||||
|
|
||||||
sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
|
||||||
const float * src0_d = (const float *)src0->data;
|
|
||||||
float * dst_d = (float *)dst->data;
|
|
||||||
cudaStream_t stream = ctx.stream();
|
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
||||||
|
|
||||||
sin_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
|
||||||
const float * src0_d = (const float *)src0->data;
|
|
||||||
float * dst_d = (float *)dst->data;
|
|
||||||
cudaStream_t stream = ctx.stream();
|
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
||||||
|
|
||||||
cos_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,10 @@
|
||||||
#define CUDA_SIN_BLOCK_SIZE 256
|
#define CUDA_SIN_BLOCK_SIZE 256
|
||||||
#define CUDA_COS_BLOCK_SIZE 256
|
#define CUDA_COS_BLOCK_SIZE 256
|
||||||
|
|
||||||
|
void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
@ -49,3 +53,5 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
|
@ -1200,7 +1200,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||||
case GGML_UNARY_OP_GELU_QUICK:
|
case GGML_UNARY_OP_GELU_QUICK:
|
||||||
case GGML_UNARY_OP_SILU:
|
case GGML_UNARY_OP_SILU:
|
||||||
case GGML_UNARY_OP_ELU:
|
case GGML_UNARY_OP_ELU:
|
||||||
return ggml_is_contiguous(op->src[0]);
|
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -1210,21 +1210,26 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||||
case GGML_OP_TRANSPOSE:
|
case GGML_OP_TRANSPOSE:
|
||||||
case GGML_OP_PERMUTE:
|
case GGML_OP_PERMUTE:
|
||||||
case GGML_OP_CONCAT:
|
case GGML_OP_CONCAT:
|
||||||
|
return true;
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
case GGML_OP_SUB:
|
case GGML_OP_SUB:
|
||||||
case GGML_OP_ACC:
|
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
case GGML_OP_DIV:
|
case GGML_OP_DIV:
|
||||||
|
return op->src[0]->type == GGML_TYPE_F32;
|
||||||
|
case GGML_OP_ACC:
|
||||||
case GGML_OP_REPEAT:
|
case GGML_OP_REPEAT:
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
case GGML_OP_CLAMP:
|
|
||||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||||
return true;
|
return true;
|
||||||
|
case GGML_OP_CLAMP:
|
||||||
|
return op->src[0]->type == GGML_TYPE_F32;
|
||||||
case GGML_OP_SQR:
|
case GGML_OP_SQR:
|
||||||
case GGML_OP_SQRT:
|
case GGML_OP_SQRT:
|
||||||
case GGML_OP_SIN:
|
case GGML_OP_SIN:
|
||||||
case GGML_OP_COS:
|
case GGML_OP_COS:
|
||||||
return ggml_is_contiguous(op->src[0]);
|
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||||
|
case GGML_OP_LOG:
|
||||||
|
return false; // TODO: implement
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
|
@ -1254,10 +1259,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||||
case GGML_OP_UPSCALE:
|
case GGML_OP_UPSCALE:
|
||||||
case GGML_OP_PAD:
|
case GGML_OP_PAD:
|
||||||
case GGML_OP_PAD_REFLECT_1D:
|
case GGML_OP_PAD_REFLECT_1D:
|
||||||
case GGML_OP_ARANGE:
|
|
||||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
|
return op->src[0]->type == GGML_TYPE_F32;
|
||||||
|
case GGML_OP_ARANGE:
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
if (op->src[1]->type != op->src[2]->type) {
|
if (op->src[1]->type != op->src[2]->type) {
|
||||||
|
|
|
@ -29,6 +29,7 @@
|
||||||
#include "wkv6.hpp"
|
#include "wkv6.hpp"
|
||||||
#include "outprod.hpp"
|
#include "outprod.hpp"
|
||||||
#include "element_wise.hpp"
|
#include "element_wise.hpp"
|
||||||
|
#include "cpy.hpp"
|
||||||
#include "gla.hpp"
|
#include "gla.hpp"
|
||||||
|
|
||||||
#endif // GGML_SYCL_BACKEND_HPP
|
#endif // GGML_SYCL_BACKEND_HPP
|
||||||
|
|
|
@ -34,6 +34,7 @@
|
||||||
#pragma clang diagnostic ignored "-Wnested-anon-types"
|
#pragma clang diagnostic ignored "-Wnested-anon-types"
|
||||||
#include "ggml-common.h"
|
#include "ggml-common.h"
|
||||||
#pragma clang diagnostic pop
|
#pragma clang diagnostic pop
|
||||||
|
#include "ggml-impl.h"
|
||||||
|
|
||||||
void* ggml_sycl_host_malloc(size_t size);
|
void* ggml_sycl_host_malloc(size_t size);
|
||||||
void ggml_sycl_host_free(void* ptr);
|
void ggml_sycl_host_free(void* ptr);
|
||||||
|
|
701
ggml/src/ggml-sycl/cpy.cpp
Normal file
701
ggml/src/ggml-sycl/cpy.cpp
Normal file
|
@ -0,0 +1,701 @@
|
||||||
|
#include "cpy.hpp"
|
||||||
|
|
||||||
|
#include <float.h>
|
||||||
|
|
||||||
|
#include "dequantize.hpp"
|
||||||
|
|
||||||
|
static __dpct_inline__ int best_index_int8(int n, const int8_t * val, float x) {
|
||||||
|
if (x <= val[0]) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
if (x >= val[n - 1]) {
|
||||||
|
return n - 1;
|
||||||
|
}
|
||||||
|
int ml = 0, mu = n - 1;
|
||||||
|
while (mu - ml > 1) {
|
||||||
|
int mav = (ml + mu) / 2;
|
||||||
|
if (x < val[mav]) {
|
||||||
|
mu = mav;
|
||||||
|
} else {
|
||||||
|
ml = mav;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return x - val[mu - 1] < val[mu] - x ? mu - 1 : mu;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
|
||||||
|
const float * xi = (const float *) cxi;
|
||||||
|
float * dsti = (float *) cdsti;
|
||||||
|
|
||||||
|
*dsti = *xi;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void cpy_1_f32_f16(const char * cxi, char * cdsti) {
|
||||||
|
const float * xi = (const float *) cxi;
|
||||||
|
sycl::half * dsti = (sycl::half *) cdsti;
|
||||||
|
|
||||||
|
*dsti = sycl::vec<float, 1>(*xi).convert<sycl::half, sycl::rounding_mode::automatic>()[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
static void cpy_1_f16_f16(const char * cxi, char * cdsti) {
|
||||||
|
const sycl::half * xi = (const sycl::half *) cxi;
|
||||||
|
sycl::half * dsti = (sycl::half *) cdsti;
|
||||||
|
|
||||||
|
*dsti = *xi;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
|
||||||
|
const sycl::half * xi = (const sycl::half *) cxi;
|
||||||
|
float * dsti = (float *) cdsti;
|
||||||
|
|
||||||
|
*dsti = *xi;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
|
||||||
|
const int16_t * xi = (const int16_t *) cxi;
|
||||||
|
int16_t * dsti = (int16_t *) cdsti;
|
||||||
|
|
||||||
|
*dsti = *xi;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void cpy_1_i32_i32(const char * cxi, char * cdsti) {
|
||||||
|
const int32_t * xi = (const int32_t *) cxi;
|
||||||
|
int32_t * dsti = (int32_t *) cdsti;
|
||||||
|
|
||||||
|
*dsti = *xi;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <cpy_kernel_t cpy_1>
|
||||||
|
static void cpy_f32_f16(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
|
||||||
|
const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
|
||||||
|
const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
|
||||||
|
const sycl::nd_item<3> & item_ct1) {
|
||||||
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
|
||||||
|
|
||||||
|
if (i >= ne) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
|
||||||
|
// then combine those indices with the corresponding byte offsets to get the total offsets
|
||||||
|
const int i03 = i / (ne00 * ne01 * ne02);
|
||||||
|
const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
|
||||||
|
const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
|
||||||
|
const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
|
||||||
|
const int x_offset = i00 * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;
|
||||||
|
|
||||||
|
const int i13 = i / (ne10 * ne11 * ne12);
|
||||||
|
const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
|
||||||
|
const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
|
||||||
|
const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
|
||||||
|
const int dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
|
||||||
|
|
||||||
|
cpy_1(cx + x_offset, cdst + dst_offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
|
||||||
|
const float * xi = (const float *) cxi;
|
||||||
|
block_q8_0 * dsti = (block_q8_0 *) cdsti;
|
||||||
|
|
||||||
|
float amax = 0.0f; // absolute max
|
||||||
|
|
||||||
|
for (int j = 0; j < QK8_0; j++) {
|
||||||
|
const float v = xi[j];
|
||||||
|
amax = sycl::fmax(amax, sycl::fabs((float) v));
|
||||||
|
}
|
||||||
|
|
||||||
|
const float d = amax / ((1 << 7) - 1);
|
||||||
|
const float id = d ? 1.0f / d : 0.0f;
|
||||||
|
|
||||||
|
dsti->d = d;
|
||||||
|
|
||||||
|
for (int j = 0; j < QK8_0; ++j) {
|
||||||
|
const float x0 = xi[j] * id;
|
||||||
|
|
||||||
|
dsti->qs[j] = sycl::round((float) x0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
|
||||||
|
float * cdstf = (float *) (cdsti);
|
||||||
|
|
||||||
|
for (int j = 0; j < QK8_0; j += 2) {
|
||||||
|
dfloat2 dq;
|
||||||
|
dequantize_q8_0(cxi, 0, j, dq);
|
||||||
|
*(cdstf + j) = dq.x();
|
||||||
|
*(cdstf + j + 1) = dq.y();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
|
||||||
|
const float * xi = (const float *) cxi;
|
||||||
|
block_q4_0 * dsti = (block_q4_0 *) cdsti;
|
||||||
|
|
||||||
|
float amax = 0.0f;
|
||||||
|
float vmax = 0.0f;
|
||||||
|
|
||||||
|
for (int j = 0; j < QK4_0; ++j) {
|
||||||
|
const float v = xi[j];
|
||||||
|
if (amax < sycl::fabs((float) v)) {
|
||||||
|
amax = sycl::fabs((float) v);
|
||||||
|
vmax = v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const float d = vmax / -8;
|
||||||
|
const float id = d ? 1.0f / d : 0.0f;
|
||||||
|
|
||||||
|
dsti->d = d;
|
||||||
|
|
||||||
|
for (int j = 0; j < QK4_0 / 2; ++j) {
|
||||||
|
const float x0 = xi[0 + j] * id;
|
||||||
|
const float x1 = xi[QK4_0 / 2 + j] * id;
|
||||||
|
|
||||||
|
const uint8_t xi0 = dpct::min(15, (int8_t) (x0 + 8.5f));
|
||||||
|
const uint8_t xi1 = dpct::min(15, (int8_t) (x1 + 8.5f));
|
||||||
|
|
||||||
|
dsti->qs[j] = xi0;
|
||||||
|
dsti->qs[j] |= xi1 << 4;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
|
||||||
|
const float * xi = (const float *) cxi;
|
||||||
|
block_q4_1 * dsti = (block_q4_1 *) cdsti;
|
||||||
|
|
||||||
|
float vmin = FLT_MAX;
|
||||||
|
float vmax = -FLT_MAX;
|
||||||
|
|
||||||
|
for (int j = 0; j < QK4_1; ++j) {
|
||||||
|
const float v = xi[j];
|
||||||
|
|
||||||
|
if (v < vmin) {
|
||||||
|
vmin = v;
|
||||||
|
}
|
||||||
|
if (v > vmax) {
|
||||||
|
vmax = v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const float d = (vmax - vmin) / ((1 << 4) - 1);
|
||||||
|
const float id = d ? 1.0f / d : 0.0f;
|
||||||
|
|
||||||
|
dsti->dm.x() = d;
|
||||||
|
dsti->dm.y() = vmin;
|
||||||
|
|
||||||
|
for (int j = 0; j < QK4_1 / 2; ++j) {
|
||||||
|
const float x0 = (xi[0 + j] - vmin) * id;
|
||||||
|
const float x1 = (xi[QK4_1 / 2 + j] - vmin) * id;
|
||||||
|
|
||||||
|
const uint8_t xi0 = dpct::min(15, (int8_t) (x0 + 0.5f));
|
||||||
|
const uint8_t xi1 = dpct::min(15, (int8_t) (x1 + 0.5f));
|
||||||
|
|
||||||
|
dsti->qs[j] = xi0;
|
||||||
|
dsti->qs[j] |= xi1 << 4;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Quantize one block of QK5_0 floats into a block_q5_0:
// 5-bit symmetric quantization (shared scale d, no offset). The low 4 bits of
// each quant are packed pairwise into qs; the 5th bit of every quant is
// collected into the 32-bit qh bitfield.
static void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    block_q5_0 * dsti = (block_q5_0 *) cdsti;

    float amax = 0.0f; // largest absolute value in the block
    float vmax = 0.0f; // signed value that attains amax

    for (int j = 0; j < QK5_0; ++j) {
        const float v = xi[j];
        if (amax < sycl::fabs((float) v)) {
            amax = sycl::fabs((float) v);
            vmax = v;
        }
    }

    // scale chosen so that vmax maps to -16 (the most negative 5-bit level)
    const float d = vmax / -16;
    const float id = d ? 1.0f / d : 0.0f;

    dsti->d = d;

    uint32_t qh = 0; // accumulates the high (5th) bit of each quant
    for (int j = 0; j < QK5_0 / 2; ++j) {
        const float x0 = xi[0 + j] * id;
        const float x1 = xi[QK5_0 / 2 + j] * id;

        // 16.5f = +16 bias into [0, 32) plus 0.5 for round-to-nearest;
        // clamp to the 5-bit maximum
        const uint8_t xi0 = dpct::min(31, (int8_t) (x0 + 16.5f));
        const uint8_t xi1 = dpct::min(31, (int8_t) (x1 + 16.5f));

        // low nibbles of the pair share one byte; 5th bits go into qh
        dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0 / 2);
    }
    memcpy(dsti->qh, &qh, sizeof(qh));
}
|
||||||
|
|
||||||
|
// Quantize one block of QK5_1 floats into a block_q5_1:
// 5-bit asymmetric quantization (shared scale d plus minimum), with the 5th
// bit of each quant packed into the 32-bit qh bitfield.
static void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    block_q5_1 * dsti = (block_q5_1 *) cdsti;

    float min = xi[0];
    float max = xi[0];

    // find the min/max of the block
    for (int j = 1; j < QK5_1; ++j) {
        const float v = xi[j];
        min = v < min ? v : min;
        max = v > max ? v : max;
    }

    // map [min, max] onto the 5-bit range [0, 31]
    const float d = (max - min) / 31;
    const float id = d ? 1.0f / d : 0.0f;

    dsti->dm.x() = d;   // scale
    dsti->dm.y() = min; // offset (block minimum)

    uint32_t qh = 0; // accumulates the high (5th) bit of each quant
    for (int j = 0; j < QK5_1 / 2; ++j) {
        const float x0 = (xi[0 + j] - min) * id;
        const float x1 = (xi[QK5_1 / 2 + j] - min) * id;

        // +0.5f rounds to nearest level (values are already in [0, 31])
        const uint8_t xi0 = (uint8_t) (x0 + 0.5f);
        const uint8_t xi1 = (uint8_t) (x1 + 0.5f);

        // low nibbles of the pair share one byte; 5th bits go into qh
        dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1 / 2);
    }
    memcpy(dsti->qh, &qh, sizeof(qh));
}
|
||||||
|
|
||||||
|
// Quantize one block of QK4_NL floats into a block_iq4_nl: each value becomes
// a 4-bit index into the non-linear kvalues_iq4nl lookup table. After picking
// indices, the scale is refined with a weighted least-squares fit
// (d = sumqx / sumq2).
static void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    block_iq4_nl * dsti = (block_iq4_nl *) cdsti;

    float amax = 0.0f; // largest absolute value in the block
    float vmax = 0.0f; // signed value that attains amax

    for (int j = 0; j < QK4_NL; ++j) {
        const float v = xi[j];
        if (amax < sycl::fabs((float) v)) {
            amax = sycl::fabs((float) v);
            vmax = v;
        }
    }

    // initial scale: map vmax onto the first (most negative) table entry
    float d = vmax / kvalues_iq4nl[0];
    const float id = d ? 1.0f / d : 0.0f;

    float sumqx = 0, sumq2 = 0;
    for (int j = 0; j < QK4_NL / 2; ++j) {
        const float x0 = xi[0 + j] * id;
        const float x1 = xi[QK4_NL / 2 + j] * id;
        // nearest table entry for each scaled value
        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
        dsti->qs[j] = xi0 | (xi1 << 4);
        const float v0 = kvalues_iq4nl[xi0];
        const float v1 = kvalues_iq4nl[xi1];
        // weights are the squared inputs, so larger values dominate the fit
        const float w0 = xi[0 + j] * xi[0 + j];
        const float w1 = xi[QK4_NL / 2 + j] * xi[QK4_NL / 2 + j];
        sumqx += w0 * v0 * xi[j] + w1 * v1 * xi[QK4_NL / 2 + j];
        sumq2 += w0 * v0 * v0 + w1 * v1 * v1;
    }

    // least-squares optimal scale; fall back to the initial d if degenerate
    dsti->d = sumq2 > 0 ? sumqx / sumq2 : d;
}
|
||||||
|
|
||||||
|
template <dequantize_kernel_t dequant, int qk> static void cpy_blck_q_f32(const char * cxi, char * cdsti) {
|
||||||
|
float * cdstf = (float *) (cdsti);
|
||||||
|
|
||||||
|
for (int j = 0; j < qk / 2; j++) {
|
||||||
|
dfloat2 dq;
|
||||||
|
dequant(cxi, 0, j, dq);
|
||||||
|
*(cdstf + j) = dq.x();
|
||||||
|
*(cdstf + j + qk / 2) = dq.y();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy/quantize kernel: each work-item handles one block of qk source floats
// and writes one quantized destination block via cpy_blck. Both tensors may be
// non-contiguous; element strides (nbXX) are in bytes.
template <cpy_kernel_t cpy_blck, int qk>
static void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
                      const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
                      const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
                      const sycl::nd_item<3> & item_ct1) {
    // flat index of the first source element handled by this work-item
    const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk;

    if (i >= ne) {
        return;
    }

    // decompose the flat index into per-dimension source indices, then combine
    // with the byte strides to get the source byte offset
    const int i03 = i / (ne00 * ne01 * ne02);
    const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
    const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
    const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
    const int x_offset = i00 * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;

    // same decomposition for the destination; i10 is divided by qk because the
    // innermost destination stride is per quantized block, not per element
    const int i13 = i / (ne10 * ne11 * ne12);
    const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
    const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
    const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
    const int dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;

    cpy_blck(cx + x_offset, cdst + dst_offset);
}
|
||||||
|
|
||||||
|
// Copy/dequantize kernel: each work-item reads one quantized source block and
// writes qk consecutive destination floats via cpy_blck. Both tensors may be
// non-contiguous; element strides (nbXX) are in bytes.
template <cpy_kernel_t cpy_blck, int qk>
static void cpy_q_f32(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
                      const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
                      const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
                      const sycl::nd_item<3> & item_ct1) {
    // flat index of the first destination element handled by this work-item
    const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk;

    if (i >= ne) {
        return;
    }

    // decompose the flat index into per-dimension source indices; i00 is
    // divided by qk because the innermost source stride is per quantized block
    const int i03 = i / (ne00 * ne01 * ne02);
    const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
    const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
    const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
    const int x_offset = (i00 / qk) * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;

    // same decomposition for the (element-strided) float destination
    const int i13 = i / (ne10 * ne11 * ne12);
    const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
    const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
    const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
    const int dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;

    cpy_blck(cx + x_offset, cdst + dst_offset);
}
|
||||||
|
|
||||||
|
// Launch the F16 -> F32 element-wise copy kernel: one work-item per element,
// SYCL_CPY_BLOCK_SIZE work-items per work-group.
static void ggml_cpy_f16_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                  const int nb12, const int nb13, queue_ptr stream) {
    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
    {
        // the device must support fp16
        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });

        stream->parallel_for(
            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
            [=](sycl::nd_item<3> item_ct1) {
                cpy_f32_f16<cpy_1_f16_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
                                           nb10, nb11, nb12, nb13, item_ct1);
            });
    }
}
|
||||||
|
|
||||||
|
static void ggml_cpy_f32_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
||||||
|
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
|
||||||
|
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||||
|
const int nb12, const int nb13, queue_ptr stream) {
|
||||||
|
const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
|
||||||
|
{
|
||||||
|
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||||
|
|
||||||
|
stream->parallel_for(
|
||||||
|
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
||||||
|
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
cpy_f32_f16<cpy_1_f32_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
|
||||||
|
nb10, nb11, nb12, nb13, item_ct1);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Launch the F32 -> F16 element-wise copy kernel: one work-item per element,
// SYCL_CPY_BLOCK_SIZE work-items per work-group.
static void ggml_cpy_f32_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                  const int nb12, const int nb13, queue_ptr stream) {
    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
    {
        // the device must support fp16
        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });

        stream->parallel_for(
            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
            [=](sycl::nd_item<3> item_ct1) {
                cpy_f32_f16<cpy_1_f32_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
                                           nb10, nb11, nb12, nb13, item_ct1);
            });
    }
}
|
||||||
|
|
||||||
|
// Launch the F32 -> Q8_0 quantizing copy: one work-item per Q8_0 block.
// ne must be a multiple of the block size QK8_0.
static void ggml_cpy_f32_q8_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                   const int nb12, const int nb13, queue_ptr stream) {
    GGML_ASSERT(ne % QK8_0 == 0);
    const int num_blocks = ne / QK8_0;
    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
                         [=](sycl::nd_item<3> item_ct1) {
                             cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
                         });
}
|
||||||
|
|
||||||
|
// Launch the Q8_0 -> F32 dequantizing copy. Launches ne work-items; inside the
// kernel each item's index is scaled by QK8_0 and out-of-range items return.
static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                   const int nb12, const int nb13, queue_ptr stream) {
    const int num_blocks = ne;
    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
                         [=](sycl::nd_item<3> item_ct1) {
                             cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
                         });
}
|
||||||
|
|
||||||
|
// Launch the F32 -> Q4_0 quantizing copy: one work-item per Q4_0 block.
// ne must be a multiple of the block size QK4_0.
static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                   const int nb12, const int nb13, queue_ptr stream) {
    GGML_ASSERT(ne % QK4_0 == 0);
    const int num_blocks = ne / QK4_0;
    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
                         [=](sycl::nd_item<3> item_ct1) {
                             cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
                         });
}
|
||||||
|
|
||||||
|
static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
||||||
|
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
|
||||||
|
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||||
|
const int nb12, const int nb13, queue_ptr stream) {
|
||||||
|
const int num_blocks = ne;
|
||||||
|
stream->parallel_for(
|
||||||
|
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
|
||||||
|
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
||||||
|
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
||||||
|
item_ct1);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Launch the F32 -> Q4_1 quantizing copy: one work-item per Q4_1 block.
// ne must be a multiple of the block size QK4_1.
static void ggml_cpy_f32_q4_1_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                   const int nb12, const int nb13, queue_ptr stream) {
    GGML_ASSERT(ne % QK4_1 == 0);
    const int num_blocks = ne / QK4_1;
    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
                         [=](sycl::nd_item<3> item_ct1) {
                             cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
                         });
}
|
||||||
|
|
||||||
|
// Launch the Q4_1 -> F32 dequantizing copy. Launches ne work-items; inside the
// kernel each item's index is scaled by QK4_1 and out-of-range items return.
static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                   const int nb12, const int nb13, queue_ptr stream) {
    const int num_blocks = ne;
    stream->parallel_for(
        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
            cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
                                                                     nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
                                                                     item_ct1);
        });
}
|
||||||
|
|
||||||
|
// Launch the F32 -> Q5_0 quantizing copy: one work-item per Q5_0 block.
// ne must be a multiple of the block size QK5_0.
static void ggml_cpy_f32_q5_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                   const int nb12, const int nb13, queue_ptr stream) {
    GGML_ASSERT(ne % QK5_0 == 0);
    const int num_blocks = ne / QK5_0;
    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
                         [=](sycl::nd_item<3> item_ct1) {
                             cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
                         });
}
|
||||||
|
|
||||||
|
// Launch the Q5_0 -> F32 dequantizing copy. Launches ne work-items; inside the
// kernel each item's index is scaled by QK5_0 and out-of-range items return.
static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                   const int nb12, const int nb13, queue_ptr stream) {
    const int num_blocks = ne;
    stream->parallel_for(
        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
            cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
                                                                     nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
                                                                     item_ct1);
        });
}
|
||||||
|
|
||||||
|
// Launch the F32 -> Q5_1 quantizing copy: one work-item per Q5_1 block.
// ne must be a multiple of the block size QK5_1.
static void ggml_cpy_f32_q5_1_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                   const int nb12, const int nb13, queue_ptr stream) {
    GGML_ASSERT(ne % QK5_1 == 0);
    const int num_blocks = ne / QK5_1;
    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
                         [=](sycl::nd_item<3> item_ct1) {
                             cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
                         });
}
|
||||||
|
|
||||||
|
// Launch the Q5_1 -> F32 dequantizing copy. Launches ne work-items; inside the
// kernel each item's index is scaled by QK5_1 and out-of-range items return.
static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                   const int nb12, const int nb13, queue_ptr stream) {
    const int num_blocks = ne;
    stream->parallel_for(
        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
            cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
                                                                     nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
                                                                     item_ct1);
        });
}
|
||||||
|
|
||||||
|
// Launch the F32 -> IQ4_NL quantizing copy: one work-item per IQ4_NL block.
// ne must be a multiple of the block size QK4_NL.
static void ggml_cpy_f32_iq4_nl_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                     const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                     const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                     const int nb12, const int nb13, queue_ptr stream) {
    GGML_ASSERT(ne % QK4_NL == 0);
    const int num_blocks = ne / QK4_NL;
    stream->parallel_for(
        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
            cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
                                                   ne12, nb10, nb11, nb12, nb13, item_ct1);
        });
}
|
||||||
|
|
||||||
|
// Launch the F16 -> F16 element-wise copy kernel: one work-item per element,
// SYCL_CPY_BLOCK_SIZE work-items per work-group.
static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                  const int nb12, const int nb13, queue_ptr stream) {
    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
    {
        // the device must support fp16
        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });

        stream->parallel_for(
            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
            [=](sycl::nd_item<3> item_ct1) {
                cpy_f32_f16<cpy_1_f16_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
                                           nb10, nb11, nb12, nb13, item_ct1);
            });
    }
}
|
||||||
|
|
||||||
|
// Launch the I16 -> I16 element-wise copy kernel: one work-item per element,
// SYCL_CPY_BLOCK_SIZE work-items per work-group. No fp16 capability check is
// needed for integer copies, hence the commented-out call.
static void ggml_cpy_i16_i16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                  const int nb12, const int nb13, queue_ptr stream) {
    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
    {
        // dpct::has_capability_or_fail(stream->get_device(),
        //                              {sycl::aspect::fp16});

        stream->parallel_for(
            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
            [=](sycl::nd_item<3> item_ct1) {
                cpy_f32_f16<cpy_1_i16_i16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
                                           nb10, nb11, nb12, nb13, item_ct1);
            });
    }
}
|
||||||
|
|
||||||
|
// Launch the I32 -> I32 element-wise copy kernel: one work-item per element,
// SYCL_CPY_BLOCK_SIZE work-items per work-group. No fp16 capability check is
// needed for integer copies, hence the commented-out call.
static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                  const int nb12, const int nb13, queue_ptr stream) {
    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
    {
        // dpct::has_capability_or_fail(stream->get_device(),
        //                              {sycl::aspect::fp16});

        stream->parallel_for(
            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
            [=](sycl::nd_item<3> item_ct1) {
                cpy_f32_f16<cpy_1_i32_i32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
                                           nb10, nb11, nb12, nb13, item_ct1);
            });
    }
}
|
||||||
|
|
||||||
|
// Generic tensor copy: copies/converts all elements of src0 into src1,
// dispatching on the (src0->type, src1->type) pair. Handles plain copies
// (F32/F16/I16/I32), quantizing copies (F32 -> Q8_0/Q4_0/Q4_1/Q5_0/Q5_1/IQ4_NL)
// and dequantizing copies (Q8_0/Q4_0/Q4_1/Q5_0/Q5_1 -> F32).
// Aborts on an unsupported type combination.
void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1) try {
    const int64_t ne = ggml_nelements(src0);
    GGML_ASSERT(ne == ggml_nelements(src1));

    // the kernels compute byte offsets in 32-bit ints
    GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
    GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);

    // declares the ne0x/ne1x and nb0x/nb1x locals used below
    GGML_TENSOR_BINARY_OP_LOCALS01;

    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
    queue_ptr main_stream = ctx.stream();

    char * src0_ddc = (char *) src0->data;
    char * src1_ddc = (char *) src1->data;
    GGML_SYCL_DEBUG("[SYCL] %s: Tensor supplied: %s to %s\n", __func__, ggml_type_name(src0->type),
                    ggml_type_name(src1->type));

    // dispatch on the source/destination type pair
    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_f32_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                              nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
        ggml_cpy_f32_f16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                              nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
        ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
        ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
        ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_f16_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                              nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
        ggml_cpy_f16_f16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                              nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) {
        ggml_cpy_i16_i16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                              nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
        ggml_cpy_i32_i32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                              nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q4_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q4_1_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q8_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
        ggml_cpy_f32_q5_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q5_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
        ggml_cpy_f32_q5_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q5_1_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
        ggml_cpy_f32_iq4_nl_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
                                 nb10, nb11, nb12, nb13, main_stream);
    } else {
        GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type),
                       ggml_type_name(src1->type));
        GGML_ABORT("fatal error");
    }
} catch (const sycl::exception & exc) {
    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
    std::exit(1);
}
|
||||||
|
|
||||||
|
// DUP operator: duplicate dst->src[0] into dst, implemented as a plain copy.
void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    // TODO: why do we pass dst as src1 here?
    GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
    // reuse the generic copy: src0 = input tensor, src1 = destination tensor
    ggml_sycl_cpy(ctx, dst->src[0], dst);
    GGML_SYCL_DEBUG("[SYCL] call %s done\n", __func__);
}
|
11
ggml/src/ggml-sycl/cpy.hpp
Normal file
11
ggml/src/ggml-sycl/cpy.hpp
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
#ifndef GGML_SYCL_CPY_HPP
#define GGML_SYCL_CPY_HPP

#include "common.hpp"

// element-wise copy kernel operating on raw byte pointers
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

// copy (converting/quantizing/dequantizing as needed) src0 into src1
void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1);
// DUP operator: duplicate dst->src[0] into dst
void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

#endif // GGML_SYCL_CPY_HPP
|
|
@ -1285,8 +1285,6 @@ std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(q
|
||||||
// struct ggml_sycl_pool_vmm : public ggml_sycl_pool
|
// struct ggml_sycl_pool_vmm : public ggml_sycl_pool
|
||||||
|
|
||||||
/// kernels
|
/// kernels
|
||||||
|
|
||||||
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
|
|
||||||
typedef void (*ggml_sycl_op_mul_mat_t)(
|
typedef void (*ggml_sycl_op_mul_mat_t)(
|
||||||
ggml_backend_sycl_context & ctx,
|
ggml_backend_sycl_context & ctx,
|
||||||
const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
|
const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
|
||||||
|
@ -1468,193 +1466,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy a single fp32 value between raw byte buffers.
static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
    *(float *) cdsti = *(const float *) cxi;
}
|
|
||||||
|
|
||||||
// Convert a single fp32 value to fp16 (sycl::half) and store it.
static void cpy_1_f32_f16(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    sycl::half *dsti = (sycl::half *)cdsti;

    // explicit vector convert with the default (automatic) rounding mode
    *dsti = sycl::vec<float, 1>(*xi)
                .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
}
|
|
||||||
|
|
||||||
// Copy a single fp16 (sycl::half) value between raw byte buffers.
static void cpy_1_f16_f16(const char * cxi, char * cdsti) {
    const sycl::half *xi = (const sycl::half *)cxi;
    sycl::half *dsti = (sycl::half *)cdsti;

    *dsti = *xi;
}
|
|
||||||
|
|
||||||
// Convert a single fp16 (sycl::half) value to fp32 and store it
// (widening, via the implicit half-to-float conversion).
static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
    const sycl::half *xi = (const sycl::half *)cxi;
    float * dsti = (float *) cdsti;

    *dsti = *xi;
}
|
|
||||||
|
|
||||||
// Copy a single int16 value between raw byte buffers.
static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
    *(int16_t *) cdsti = *(const int16_t *) cxi;
}
|
|
||||||
|
|
||||||
// Copy a single int32 value between raw byte buffers.
static void cpy_1_i32_i32(const char * cxi, char * cdsti) {
    *(int32_t *) cdsti = *(const int32_t *) cxi;
}
|
|
||||||
|
|
||||||
// Element-wise copy kernel: each work-item copies one element through cpy_1.
// Both tensors may be non-contiguous; strides (nbXX) are in bytes.
template <cpy_kernel_t cpy_1>
static void cpy_f32_f16(const char * cx, char * cdst, const int ne,
                        const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                        const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                        const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
    // flat element index handled by this work-item
    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                  item_ct1.get_local_id(2);

    if (i >= ne) {
        return;
    }

    // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
    // then combine those indices with the corresponding byte offsets to get the total offsets
    const int i03 = i/(ne00 * ne01 * ne02);
    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;

    const int i13 = i/(ne10 * ne11 * ne12);
    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;

    cpy_1(cx + x_offset, cdst + dst_offset);
}
|
|
||||||
|
|
||||||
// Quantize one block of QK8_0 floats into a block_q8_0:
// signed 8-bit quants plus a per-block scale d = amax / 127.
static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    block_q8_0 * dsti = (block_q8_0 *) cdsti;

    float amax = 0.0f; // absolute max

    for (int j = 0; j < QK8_0; j++) {
        const float v = xi[j];
        amax = sycl::fmax(amax, sycl::fabs((float)v));
    }

    // id is the inverse scale; 0 when the whole block is zero,
    // so every quant below rounds to 0
    const float d = amax / ((1 << 7) - 1);
    const float id = d ? 1.0f/d : 0.0f;

    dsti->d = d;

    for (int j = 0; j < QK8_0; ++j) {
        const float x0 = xi[j]*id;

        // round-to-nearest to the int8 quant
        dsti->qs[j] = sycl::round((float)x0);
    }
}
|
|
||||||
|
|
||||||
// Quantize one block of QK4_0 floats into a block_q4_0: 4-bit quants packed
// two per byte, plus a per-block scale derived from the signed value with
// the largest magnitude.
static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    block_q4_0 * dsti = (block_q4_0 *) cdsti;

    float amax = 0.0f; // largest |x| seen in the block
    float vmax = 0.0f; // signed value achieving that magnitude

    for (int j = 0; j < QK4_0; ++j) {
        const float v = xi[j];
        if (amax < sycl::fabs((float)v)) {
            amax = sycl::fabs((float)v);
            vmax = v;
        }
    }

    // scale so that vmax maps to -8, the most negative 4-bit level
    // (after the +8 offset applied below); id = inverse scale, 0 for an
    // all-zero block
    const float d = vmax / -8;
    const float id = d ? 1.0f/d : 0.0f;

    dsti->d = d;

    for (int j = 0; j < QK4_0/2; ++j) {
        const float x0 = xi[0 + j]*id;
        const float x1 = xi[QK4_0/2 + j]*id;

        // shift into [0,16) with +8 (+0.5 for rounding), clamp to 15
        const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 8.5f));
        const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 8.5f));

        // low nibble: first half of the block; high nibble: second half
        dsti->qs[j] = xi0;
        dsti->qs[j] |= xi1 << 4;
    }
}
|
|
||||||
|
|
||||||
// Quantize one block of QK4_1 floats into a block_q4_1: min/max (affine)
// quantization — 4-bit quants packed two per byte, plus per-block scale d
// and offset vmin stored in dm.
static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    block_q4_1 * dsti = (block_q4_1 *) cdsti;

    float vmin = FLT_MAX;
    float vmax = -FLT_MAX;

    // find the block's value range [vmin, vmax]
    for (int j = 0; j < QK4_1; ++j) {
        const float v = xi[j];

        if (v < vmin) vmin = v;
        if (v > vmax) vmax = v;
    }

    // map [vmin, vmax] onto the 16 quant levels; id = inverse scale,
    // 0 when the block is constant (d == 0)
    const float d  = (vmax - vmin) / ((1 << 4) - 1);
    const float id = d ? 1.0f/d : 0.0f;

    dsti->dm.x() = d;
    dsti->dm.y() = vmin;

    for (int j = 0; j < QK4_1/2; ++j) {
        const float x0 = (xi[0 + j] - vmin)*id;
        const float x1 = (xi[QK4_1/2 + j] - vmin)*id;

        // +0.5 for round-to-nearest, clamp to the 4-bit range [0,15]
        const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 0.5f));
        const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 0.5f));

        // low nibble: first half of the block; high nibble: second half
        dsti->qs[j] = xi0;
        dsti->qs[j] |= xi1 << 4;
    }
}
|
|
||||||
|
|
||||||
// Block-quantizing copy kernel: one work-item per block of qk consecutive
// source floats. cpy_blck quantizes those qk floats into one destination
// block (q8_0 / q4_0 / q4_1, ...).
template <cpy_kernel_t cpy_blck, int qk>
static void cpy_f32_q(const char * cx, char * cdst, const int ne,
                      const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                      const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                      const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
    // linear index of the first source element of this work-item's block
    const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2)) *
                  qk;

    if (i >= ne) {
        return;
    }

    // decompose i into source indices and combine with byte strides
    const int i03 = i/(ne00 * ne01 * ne02);
    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;

    // destination offset: nb10 is the stride of one quantized block,
    // hence the division of the innermost index by qk
    const int i13 = i/(ne10 * ne11 * ne12);
    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
    const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;

    cpy_blck(cx + x_offset, cdst + dst_offset);
}
|
|
||||||
|
|
||||||
static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
|
static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
|
||||||
const sycl::nd_item<3> &item_ct1) {
|
const sycl::nd_item<3> &item_ct1) {
|
||||||
const int row = item_ct1.get_group(1);
|
const int row = item_ct1.get_group(1);
|
||||||
|
@ -1903,231 +1714,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void
|
|
||||||
ggml_cpy_f16_f32_sycl(const char *cx, char *cdst, const int ne, const int ne00,
|
|
||||||
const int ne01, const int ne02, const int nb00,
|
|
||||||
const int nb01, const int nb02, const int nb03,
|
|
||||||
const int ne10, const int ne11, const int ne12,
|
|
||||||
const int nb10, const int nb11, const int nb12,
|
|
||||||
const int nb13, queue_ptr stream) {
|
|
||||||
|
|
||||||
const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
|
|
||||||
{
|
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
|
||||||
{sycl::aspect::fp16});
|
|
||||||
|
|
||||||
stream->parallel_for(
|
|
||||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
|
||||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
|
||||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) {
|
|
||||||
cpy_f32_f16<cpy_1_f16_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00,
|
|
||||||
nb01, nb02, nb03, ne10, ne11, ne12,
|
|
||||||
nb10, nb11, nb12, nb13, item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_cpy_f32_f32_sycl(const char *cx, char *cdst, const int ne,
|
|
||||||
const int ne00, const int ne01,
|
|
||||||
const int ne02, const int nb00,
|
|
||||||
const int nb01, const int nb02,
|
|
||||||
const int nb03, const int ne10,
|
|
||||||
const int ne11, const int ne12,
|
|
||||||
const int nb10, const int nb11,
|
|
||||||
const int nb12, const int nb13,
|
|
||||||
queue_ptr stream) {
|
|
||||||
|
|
||||||
const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
|
|
||||||
{
|
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
|
||||||
{sycl::aspect::fp16});
|
|
||||||
|
|
||||||
stream->parallel_for(
|
|
||||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
|
||||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
|
||||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) {
|
|
||||||
cpy_f32_f16<cpy_1_f32_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
|
||||||
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
|
||||||
item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_cpy_f32_f16_sycl(const char *cx, char *cdst, const int ne,
|
|
||||||
const int ne00, const int ne01,
|
|
||||||
const int ne02, const int nb00,
|
|
||||||
const int nb01, const int nb02,
|
|
||||||
const int nb03, const int ne10,
|
|
||||||
const int ne11, const int ne12,
|
|
||||||
const int nb10, const int nb11,
|
|
||||||
const int nb12, const int nb13,
|
|
||||||
queue_ptr stream) {
|
|
||||||
|
|
||||||
const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
|
|
||||||
{
|
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
|
||||||
{sycl::aspect::fp16});
|
|
||||||
|
|
||||||
stream->parallel_for(
|
|
||||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
|
||||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
|
||||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) {
|
|
||||||
cpy_f32_f16<cpy_1_f32_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
|
||||||
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
|
||||||
item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_cpy_f32_q8_0_sycl(const char *cx, char *cdst, const int ne,
|
|
||||||
const int ne00, const int ne01,
|
|
||||||
const int ne02, const int nb00,
|
|
||||||
const int nb01, const int nb02,
|
|
||||||
const int nb03, const int ne10,
|
|
||||||
const int ne11, const int ne12,
|
|
||||||
const int nb10, const int nb11,
|
|
||||||
const int nb12, const int nb13,
|
|
||||||
queue_ptr stream) {
|
|
||||||
|
|
||||||
GGML_ASSERT(ne % QK8_0 == 0);
|
|
||||||
const int num_blocks = ne / QK8_0;
|
|
||||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
|
|
||||||
sycl::range<3>(1, 1, 1)),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) {
|
|
||||||
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(
|
|
||||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
|
||||||
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
|
||||||
item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_cpy_f32_q4_0_sycl(const char *cx, char *cdst, const int ne,
|
|
||||||
const int ne00, const int ne01,
|
|
||||||
const int ne02, const int nb00,
|
|
||||||
const int nb01, const int nb02,
|
|
||||||
const int nb03, const int ne10,
|
|
||||||
const int ne11, const int ne12,
|
|
||||||
const int nb10, const int nb11,
|
|
||||||
const int nb12, const int nb13,
|
|
||||||
queue_ptr stream) {
|
|
||||||
|
|
||||||
GGML_ASSERT(ne % QK4_0 == 0);
|
|
||||||
const int num_blocks = ne / QK4_0;
|
|
||||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
|
|
||||||
sycl::range<3>(1, 1, 1)),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) {
|
|
||||||
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(
|
|
||||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
|
||||||
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
|
||||||
item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_cpy_f32_q4_1_sycl(const char *cx, char *cdst, const int ne,
|
|
||||||
const int ne00, const int ne01,
|
|
||||||
const int ne02, const int nb00,
|
|
||||||
const int nb01, const int nb02,
|
|
||||||
const int nb03, const int ne10,
|
|
||||||
const int ne11, const int ne12,
|
|
||||||
const int nb10, const int nb11,
|
|
||||||
const int nb12, const int nb13,
|
|
||||||
queue_ptr stream) {
|
|
||||||
|
|
||||||
GGML_ASSERT(ne % QK4_1 == 0);
|
|
||||||
const int num_blocks = ne / QK4_1;
|
|
||||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
|
|
||||||
sycl::range<3>(1, 1, 1)),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) {
|
|
||||||
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(
|
|
||||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
|
||||||
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
|
||||||
item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_cpy_f16_f16_sycl(const char *cx, char *cdst, const int ne,
|
|
||||||
const int ne00, const int ne01,
|
|
||||||
const int ne02, const int nb00,
|
|
||||||
const int nb01, const int nb02,
|
|
||||||
const int nb03, const int ne10,
|
|
||||||
const int ne11, const int ne12,
|
|
||||||
const int nb10, const int nb11,
|
|
||||||
const int nb12, const int nb13,
|
|
||||||
queue_ptr stream) {
|
|
||||||
|
|
||||||
const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
|
|
||||||
{
|
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
|
||||||
{sycl::aspect::fp16});
|
|
||||||
|
|
||||||
stream->parallel_for(
|
|
||||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
|
||||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
|
||||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) {
|
|
||||||
cpy_f32_f16<cpy_1_f16_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
|
||||||
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
|
||||||
item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_cpy_i16_i16_sycl(const char *cx, char *cdst, const int ne,
|
|
||||||
const int ne00, const int ne01,
|
|
||||||
const int ne02, const int nb00,
|
|
||||||
const int nb01, const int nb02,
|
|
||||||
const int nb03, const int ne10,
|
|
||||||
const int ne11, const int ne12,
|
|
||||||
const int nb10, const int nb11,
|
|
||||||
const int nb12, const int nb13,
|
|
||||||
queue_ptr stream) {
|
|
||||||
|
|
||||||
const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
|
|
||||||
{
|
|
||||||
// dpct::has_capability_or_fail(stream->get_device(),
|
|
||||||
// {sycl::aspect::fp16});
|
|
||||||
|
|
||||||
stream->parallel_for(
|
|
||||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
|
||||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
|
||||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) {
|
|
||||||
cpy_f32_f16<cpy_1_i16_i16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
|
||||||
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
|
||||||
item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_cpy_i32_i32_sycl(const char *cx, char *cdst, const int ne,
|
|
||||||
const int ne00, const int ne01,
|
|
||||||
const int ne02, const int nb00,
|
|
||||||
const int nb01, const int nb02,
|
|
||||||
const int nb03, const int ne10,
|
|
||||||
const int ne11, const int ne12,
|
|
||||||
const int nb10, const int nb11,
|
|
||||||
const int nb12, const int nb13,
|
|
||||||
queue_ptr stream) {
|
|
||||||
|
|
||||||
const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
|
|
||||||
{
|
|
||||||
// dpct::has_capability_or_fail(stream->get_device(),
|
|
||||||
// {sycl::aspect::fp16});
|
|
||||||
|
|
||||||
stream->parallel_for(
|
|
||||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
|
||||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
|
||||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) {
|
|
||||||
cpy_f32_f16<cpy_1_i32_i32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
|
||||||
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
|
||||||
item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void scale_f32_sycl(const float *x, float *dst, const float scale,
|
static void scale_f32_sycl(const float *x, float *dst, const float scale,
|
||||||
const int k, queue_ptr stream) {
|
const int k, queue_ptr stream) {
|
||||||
|
@ -3645,58 +3232,6 @@ static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_clamp);
|
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_clamp);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy/convert tensor src0 into src1's buffer, dispatching on the
// (src0->type, src1->type) pair to the matching SYCL launcher.
// Unsupported pairs abort. NOTE(review): the dst argument is unused;
// ggml_sycl_dup passes nullptr for it.
static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
                          ggml_tensor *dst) try {
    const int64_t ne = ggml_nelements(src0);
    GGML_ASSERT(ne == ggml_nelements(src1));

    // the launchers take int parameters, so both tensors must fit in INT_MAX bytes
    GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
    GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);

    // declares the ne0x/nb0x (src0) and ne1x/nb1x (src1) locals used below
    GGML_TENSOR_BINARY_OP_LOCALS01;

    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
    queue_ptr main_stream = ctx.stream();

    char * src0_ddc = (char *) src0->data;
    char * src1_ddc = (char *) src1->data;

    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_f32_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
        ggml_cpy_f32_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
        ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
        ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
        ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_f16_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
        ggml_cpy_f16_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) {
        ggml_cpy_i16_i16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
        ggml_cpy_i32_i32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else {
        GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__,
                ggml_type_name(src0->type), ggml_type_name(src1->type));
        GGML_ABORT("fatal error");
    }
    GGML_UNUSED(dst);
}
catch (sycl::exception const &exc) {
    // function-level try: report any SYCL failure and terminate
    std::cerr << exc.what() << "Exception caught at file:" << __FILE__
              << ", line:" << __LINE__ << std::endl;
    std::exit(1);
}
|
|
||||||
|
|
||||||
// DUP/CONT: duplicate src0 into dst's buffer by delegating to the copy op.
static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    // TODO: why do we pass dst as src1 here?
    ggml_sycl_cpy(ctx, dst->src[0], dst, nullptr);
}
|
|
||||||
|
|
||||||
static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
|
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
|
||||||
}
|
}
|
||||||
|
@ -3893,7 +3428,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
|
||||||
ggml_sycl_clamp(ctx, dst);
|
ggml_sycl_clamp(ctx, dst);
|
||||||
break;
|
break;
|
||||||
case GGML_OP_CPY:
|
case GGML_OP_CPY:
|
||||||
ggml_sycl_cpy(ctx, dst->src[0], dst->src[1], dst);
|
ggml_sycl_cpy(ctx, dst->src[0], dst->src[1]);
|
||||||
break;
|
break;
|
||||||
case GGML_OP_CONT:
|
case GGML_OP_CONT:
|
||||||
ggml_sycl_dup(ctx, dst);
|
ggml_sycl_dup(ctx, dst);
|
||||||
|
@ -4407,6 +3942,30 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
return false;
|
return false;
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_CONCAT:
|
case GGML_OP_CONCAT:
|
||||||
|
|
|
@ -8460,7 +8460,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
case GGML_UNARY_OP_RELU:
|
case GGML_UNARY_OP_RELU:
|
||||||
case GGML_UNARY_OP_TANH:
|
case GGML_UNARY_OP_TANH:
|
||||||
case GGML_UNARY_OP_SIGMOID:
|
case GGML_UNARY_OP_SIGMOID:
|
||||||
return ggml_is_contiguous(op->src[0]);
|
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -8661,19 +8661,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
return ggml_is_contiguous(op->src[0]);
|
return ggml_is_contiguous(op->src[0]);
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
case GGML_OP_ACC:
|
|
||||||
case GGML_OP_SUB:
|
case GGML_OP_SUB:
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
case GGML_OP_DIV:
|
case GGML_OP_DIV:
|
||||||
case GGML_OP_CONCAT:
|
|
||||||
case GGML_OP_SILU_BACK:
|
case GGML_OP_SILU_BACK:
|
||||||
case GGML_OP_RMS_NORM_BACK:
|
case GGML_OP_RMS_NORM_BACK:
|
||||||
case GGML_OP_UPSCALE:
|
|
||||||
case GGML_OP_SCALE:
|
|
||||||
case GGML_OP_SQR:
|
case GGML_OP_SQR:
|
||||||
case GGML_OP_SIN:
|
case GGML_OP_SIN:
|
||||||
case GGML_OP_COS:
|
case GGML_OP_COS:
|
||||||
case GGML_OP_CLAMP:
|
case GGML_OP_CLAMP:
|
||||||
|
return op->src[0]->type == GGML_TYPE_F32;
|
||||||
|
case GGML_OP_ACC:
|
||||||
|
case GGML_OP_CONCAT:
|
||||||
|
case GGML_OP_UPSCALE:
|
||||||
|
case GGML_OP_SCALE:
|
||||||
case GGML_OP_PAD:
|
case GGML_OP_PAD:
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue