diff --git a/common/arg.cpp b/common/arg.cpp index 5f3b6ee1b..b9b02be52 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -814,13 +814,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_env("LLAMA_ARG_FLASH_ATTN")); add_opt(common_arg( {"-p", "--prompt"}, "PROMPT", - ex == LLAMA_EXAMPLE_MAIN - ? "prompt to start generation with\nif -cnv is set, this will be used as system prompt" - : "prompt to start generation with", + "prompt to start generation with; for system message, use -sys", [](common_params & params, const std::string & value) { params.prompt = value; } ).set_excludes({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"-sys", "--system-prompt"}, "PROMPT", + "system prompt to use with model (if applicable, depending on chat template)", + [](common_params & params, const std::string & value) { + params.system_prompt = value; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); add_opt(common_arg( {"--no-perf"}, string_format("disable internal libllama performance timings (default: %s)", params.no_perf ? "true" : "false"), @@ -2448,6 +2453,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.vocoder.use_guide_tokens = true; } ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--tts-speaker-file"}, "FNAME", + "speaker file path for audio generation", + [](common_params & params, const std::string & value) { + params.vocoder.speaker_file = value; + } + ).set_examples({LLAMA_EXAMPLE_TTS})); // model-specific add_opt(common_arg( diff --git a/common/common.h b/common/common.h index f4aa70846..adb8310c5 100644 --- a/common/common.h +++ b/common/common.h @@ -196,6 +196,8 @@ struct common_params_vocoder { std::string model = ""; // model path // NOLINT std::string model_url = ""; // model url to download // NOLINT + std::string speaker_file = ""; // speaker file path // NOLINT + bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT }; @@ -257,6 +259,7 @@ struct common_params { std::string hf_repo = ""; // HF repo // NOLINT std::string hf_file = ""; // HF file // NOLINT std::string prompt = ""; // NOLINT + std::string system_prompt = ""; // NOLINT std::string prompt_file = ""; // store the external prompt file name // NOLINT std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT std::string input_prefix = ""; // string to prefix user inputs with // NOLINT diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 30a9df4c3..bf30752bc 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -32,8 +32,6 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif -static const char * DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant"; - static llama_context ** g_ctx; static llama_model ** g_model; static common_sampler ** g_smpl; @@ -220,6 +218,10 @@ int main(int argc, char ** argv) { // print chat template example in conversation mode if (params.conversation_mode) { if (params.enable_chat_template) { + if (!params.prompt.empty()) { + LOG_WRN("*** User-specified prompt in conversation mode will be ignored, did you mean to set --system-prompt (-sys) instead?\n"); + } + LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.get(), params.use_jinja).c_str()); } else { LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__); @@ -264,6 +266,7 @@ int main(int argc, char ** argv) { std::vector embd_inp; + 
bool waiting_for_first_input = params.conversation_mode && params.enable_chat_template && params.system_prompt.empty(); auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) { common_chat_msg new_msg; new_msg.role = role; @@ -275,11 +278,20 @@ int main(int argc, char ** argv) { }; { - auto prompt = (params.conversation_mode && params.enable_chat_template) - // format the system prompt in conversation mode (fallback to default if empty) - ? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt) + std::string prompt; + + if (params.conversation_mode && params.enable_chat_template) { + // format the system prompt in conversation mode (will use template default if empty) + prompt = params.system_prompt; + + if (!prompt.empty()) { + prompt = chat_add_and_format("system", prompt); + } + } else { // otherwise use the prompt as is - : params.prompt; + prompt = params.prompt; + } + if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) { LOG_DBG("tokenize the prompt\n"); embd_inp = common_tokenize(ctx, prompt, true, true); @@ -293,7 +305,7 @@ int main(int argc, char ** argv) { } // Should not run without any tokens - if (embd_inp.empty()) { + if (!params.conversation_mode && embd_inp.empty()) { if (add_bos) { embd_inp.push_back(llama_vocab_bos(vocab)); LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str()); @@ -477,8 +489,8 @@ int main(int argc, char ** argv) { LOG_INF( " - Press Ctrl+C to interject at any time.\n"); #endif LOG_INF( "%s", control_message); - if (params.conversation_mode && params.enable_chat_template && params.prompt.empty()) { - LOG_INF( " - Using default system message. To change it, set a different value via -p PROMPT or -f FILE argument.\n"); + if (params.conversation_mode && params.enable_chat_template && params.system_prompt.empty()) { + LOG_INF( " - Not using system message. 
To change it, set a different value via -sys PROMPT\n"); } LOG_INF("\n"); @@ -774,7 +786,7 @@ int main(int argc, char ** argv) { } // deal with end of generation tokens in interactive mode - if (llama_vocab_is_eog(vocab, common_sampler_last(smpl))) { + if (!waiting_for_first_input && llama_vocab_is_eog(vocab, common_sampler_last(smpl))) { LOG_DBG("found an EOG token\n"); if (params.interactive) { @@ -794,12 +806,12 @@ int main(int argc, char ** argv) { } // if current token is not EOG, we add it to current assistant message - if (params.conversation_mode) { + if (params.conversation_mode && !waiting_for_first_input) { const auto id = common_sampler_last(smpl); assistant_ss << common_token_to_piece(ctx, id, false); } - if (n_past > 0 && is_interacting) { + if ((n_past > 0 || waiting_for_first_input) && is_interacting) { LOG_DBG("waiting for user input\n"); if (params.conversation_mode) { @@ -889,11 +901,12 @@ int main(int argc, char ** argv) { input_echo = false; // do not echo this again } - if (n_past > 0) { + if (n_past > 0 || waiting_for_first_input) { if (is_interacting) { common_sampler_reset(smpl); } is_interacting = false; + waiting_for_first_input = false; } } diff --git a/examples/server/public/index.html.gz b/examples/server/public/index.html.gz index e6a22a4e3..c7a3c426b 100644 Binary files a/examples/server/public/index.html.gz and b/examples/server/public/index.html.gz differ diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index af1dcb5b9..491cb3a5d 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -144,6 +144,7 @@ def test_apply_chat_template(): @pytest.mark.parametrize("response_format,n_predicted,re_content", [ ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""), ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"), + ({"type": "json_schema", "json_schema": {"schema": {"const": "foooooo"}}}, 10, "\"foooooo\""), ({"type": "json_object"}, 10, "(\\{|John)+"), ({"type": "sound"}, 0, None), # invalid response format (expected to fail) diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index a82504235..f32a439f6 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -26,7 +26,10 @@ from re import RegexFlag import wget -DEFAULT_HTTP_TIMEOUT = 12 if "LLAMA_SANITIZE" not in os.environ else 30 +DEFAULT_HTTP_TIMEOUT = 12 + +if "LLAMA_SANITIZE" in os.environ or "GITHUB_ACTION" in os.environ: + DEFAULT_HTTP_TIMEOUT = 30 class ServerResponse: diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 6830c2e1a..144d914c2 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -590,8 +590,8 @@ static json oaicompat_completion_params_parse( if (response_type == "json_object") { json_schema = json_value(response_format, "schema", json::object()); } else if (response_type == "json_schema") { - json json_schema = json_value(response_format, "json_schema", json::object()); - json_schema = json_value(json_schema, "schema", json::object()); + auto schema_wrapper = json_value(response_format, "json_schema", json::object()); + json_schema = json_value(schema_wrapper, "schema", json::object()); } else if (!response_type.empty() && response_type != "text") { throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); } diff --git 
a/examples/server/webui/src/components/ChatScreen.tsx b/examples/server/webui/src/components/ChatScreen.tsx index d7a246cf6..79de30532 100644 --- a/examples/server/webui/src/components/ChatScreen.tsx +++ b/examples/server/webui/src/components/ChatScreen.tsx @@ -2,7 +2,7 @@ import { useEffect, useMemo, useRef, useState } from 'react'; import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context'; import ChatMessage from './ChatMessage'; import { CanvasType, Message, PendingMessage } from '../utils/types'; -import { classNames, throttle } from '../utils/misc'; +import { classNames, cleanCurrentUrl, throttle } from '../utils/misc'; import CanvasPyInterpreter from './CanvasPyInterpreter'; import StorageUtils from '../utils/storage'; import { useVSCodeContext } from '../utils/llama-vscode'; @@ -18,6 +18,24 @@ export interface MessageDisplay { isPending?: boolean; } +/** + * If the current URL contains "?m=...", prefill the message input with the value. + * If the current URL contains "?q=...", prefill and SEND the message. + */ +const prefilledMsg = { + content() { + const url = new URL(window.location.href); + return url.searchParams.get('m') ?? url.searchParams.get('q') ?? ''; + }, + shouldSend() { + const url = new URL(window.location.href); + return url.searchParams.has('q'); + }, + clear() { + cleanCurrentUrl(['m', 'q']); + }, +}; + function getListMessageDisplay( msgs: Readonly, leafNodeId: Message['id'] @@ -81,7 +99,7 @@ export default function ChatScreen() { canvasData, replaceMessageAndGenerate, } = useAppContext(); - const [inputMsg, setInputMsg] = useState(''); + const [inputMsg, setInputMsg] = useState(prefilledMsg.content()); const inputRef = useRef(null); const { extraContext, clearExtraContext } = useVSCodeContext( @@ -172,6 +190,22 @@ export default function ChatScreen() { const hasCanvas = !!canvasData; + useEffect(() => { + if (prefilledMsg.shouldSend()) { + // send the prefilled message if needed + sendNewMessage(); + } else { + // otherwise, focus on the input and move the cursor to the end + if (inputRef.current) { + inputRef.current.focus(); + inputRef.current.selectionStart = inputRef.current.value.length; + } + } + prefilledMsg.clear(); + // no need to keep track of sendNewMessage + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [inputRef]); + // due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. 
appears once in the saved conversation and once in the pendingMsg) const pendingMsgDisplay: MessageDisplay[] = pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id diff --git a/examples/server/webui/src/components/SettingDialog.tsx b/examples/server/webui/src/components/SettingDialog.tsx index 592b93fa3..b65e73ae1 100644 --- a/examples/server/webui/src/components/SettingDialog.tsx +++ b/examples/server/webui/src/components/SettingDialog.tsx @@ -148,13 +148,13 @@ const SETTING_SECTIONS: SettingSection[] = [ fields: [ { type: SettingInputType.CHECKBOX, - label: 'Expand though process by default for generating message', + label: 'Expand thought process by default when generating messages', key: 'showThoughtInProgress', }, { type: SettingInputType.CHECKBOX, label: - 'Exclude thought process when sending request to API (Recommended for DeepSeek-R1)', + 'Exclude thought process when sending requests to API (Recommended for DeepSeek-R1)', key: 'excludeThoughtOnReq', }, ], @@ -247,7 +247,7 @@ const SETTING_SECTIONS: SettingSection[] = [ This feature uses{' '} pyodide, downloaded from CDN. To use this feature, ask the LLM to generate - python code inside a markdown code block. You will see a "Run" + Python code inside a Markdown code block. You will see a "Run" button on the code block, near the "Copy" button. @@ -274,7 +274,7 @@ export default function SettingDialog({ ); const resetConfig = () => { - if (window.confirm('Are you sure to reset all settings?')) { + if (window.confirm('Are you sure you want to reset all settings?')) { setLocalConfig(CONFIG_DEFAULT); } }; @@ -296,9 +296,9 @@ export default function SettingDialog({ return; } } else if (mustBeNumeric) { - const trimedValue = value.toString().trim(); - const numVal = Number(trimedValue); - if (isNaN(numVal) || !isNumeric(numVal) || trimedValue.length === 0) { + const trimmedValue = value.toString().trim(); + const numVal = Number(trimmedValue); + if (isNaN(numVal) || !isNumeric(numVal) || trimmedValue.length === 0) { alert(`Value for ${key} must be numeric`); return; } diff --git a/examples/server/webui/src/utils/misc.ts b/examples/server/webui/src/utils/misc.ts index d46322862..87f55b2af 100644 --- a/examples/server/webui/src/utils/misc.ts +++ b/examples/server/webui/src/utils/misc.ts @@ -118,3 +118,11 @@ export const throttle = ( }, delay); }; }; + +export const cleanCurrentUrl = (removeQueryParams: string[]) => { + const url = new URL(window.location.href); + removeQueryParams.forEach((param) => { + url.searchParams.delete(param); + }); + window.history.replaceState({}, '', url.toString()); +}; diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp index 35ab2bd43..ee09294b6 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -3,6 +3,7 @@ #include "sampling.h" #include "log.h" #include "llama.h" +#include "json.hpp" #define _USE_MATH_DEFINES // For M_PI on MSVC @@ -16,6 +17,13 @@ #include #include +using json = nlohmann::ordered_json; + +enum outetts_version { + OUTETTS_V0_2, + OUTETTS_V0_3, +}; + // // Terminal utils // @@ -371,7 +379,7 @@ static std::string replace_numbers_with_words(const std::string & input_text) { } // Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39 -static std::string process_text(const std::string & text) { +static std::string process_text(const std::string & text, const outetts_version tts_version = OUTETTS_V0_2) { // For now I skipped text romanization as I am unsure how to handle // uroman and MeCab 
implementations in C++ @@ -401,7 +409,8 @@ static std::string process_text(const std::string & text) { if (c == ' ') { prompt_clean += "<|text_sep|>"; */ - processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>"); + std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>"; + processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), separator); return processed_text; } @@ -425,8 +434,8 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) { prompt_add(prompt, vocab, "<|im_start|>\n", true, true); } -static std::vector prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) { - const std::string& delimiter = "<|text_sep|>"; +static std::vector prepare_guide_tokens(const llama_vocab * vocab, const std::string & str, const outetts_version tts_version = OUTETTS_V0_2) { + const std::string& delimiter = (tts_version == OUTETTS_V0_3 ? "<|space|>" : "<|text_sep|>"); std::vector result; size_t start = 0; @@ -452,6 +461,78 @@ static std::vector prepare_guide_tokens(const llama_vocab * vocab, return result; } +static json speaker_from_file(const std::string & speaker_file) { + std::ifstream file(speaker_file); + if (!file) { + LOG_ERR("%s: Failed to open file '%s' for reading\n", __func__, speaker_file.c_str()); + return json(); + } + + json speaker = json::parse(file); + return speaker; +} + +static outetts_version get_tts_version(llama_model *model, json speaker = json::object()) { + if (speaker.contains("version")) { + std::string version = speaker["version"].get(); + if (version == "0.2") { + return OUTETTS_V0_2; + } else if (version == "0.3") { + return OUTETTS_V0_3; + } else { + LOG_ERR("%s: Unsupported speaker version '%s'\n", __func__, version.c_str()); + } + } + + // Also could get version from model itself + const char *chat_template = llama_model_chat_template(model, nullptr); + if (chat_template && std::string(chat_template) == "outetts-0.3") { + return OUTETTS_V0_3; + } + + // Use 0.2 as the default version + return OUTETTS_V0_2; +} + +static std::string audio_text_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) { + std::string audio_text = "<|text_start|>"; + + if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) { + std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>"; + for (const auto &word : speaker["words"]) { + audio_text += word["word"].get() + separator; + } + } + + return audio_text; +} + +static std::string audio_data_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) { + std::string audio_data = "<|audio_start|>\n"; + + if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) { + std::string code_start = (tts_version == OUTETTS_V0_3) ? "" : "<|code_start|>"; + std::string code_end = (tts_version == OUTETTS_V0_3) ? 
"<|space|>" : "<|code_end|>"; + for (const auto &word : speaker["words"]) { + std::string word_text = word["word"].get(); + double duration = word["duration"].get(); + std::vector codes = word["codes"].get>(); + + // Create the audio output entry + std::ostringstream word_entry; + word_entry << word_text << "<|t_" << std::fixed << std::setprecision(2) + << duration << "|>" + code_start; + for (const auto &Code : codes) { + word_entry << "<|" << Code << "|>"; + } + word_entry << code_end << "\n"; + audio_data += word_entry.str(); + } + } + + return audio_data; +} + int main(int argc, char ** argv) { common_params params; @@ -523,34 +604,9 @@ int main(int argc, char ** argv) { std::vector codes; std::vector guide_tokens; - // process prompt and generate voice codes - { - LOG_INF("%s: constructing prompt ..\n", __func__); - - std::vector prompt_inp; - - prompt_init(prompt_inp, vocab); - - prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true); - - // convert the input text into the necessary format expected by OuteTTS - { - std::string prompt_clean = process_text(params.prompt); - if (params.vocoder.use_guide_tokens) { - guide_tokens = prepare_guide_tokens(vocab, prompt_clean); - } - - LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str()); - - prompt_add(prompt_inp, vocab, prompt_clean, false, true); - } - - prompt_add(prompt_inp, vocab, "<|text_end|>\n", false, true); - - // disabled to save time on tokenizing each time - // TODO: load voices from the json files -#if 0 - const std::string voice_data = R"(<|audio_start|> + // the default speaker profile is from: https://github.com/edwko/OuteTTS/blob/main/outetts/version/v1/default_speakers/en_male_1.json + std::string audio_text = "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>"; + std::string audio_data = R"(<|audio_start|> the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|> overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|> package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|> @@ -582,117 +638,170 @@ 
it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|>< looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|> lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)"; - auto tmp = common_tokenize(vocab, voice_data, false, true); - printf("\n\n"); - for (int i = 0; i < tmp.size(); ++i) { - printf("%d, ", tmp[i]); + // audio data for 0.3 version + outetts_version tts_version = get_tts_version(model_ttc); + if (tts_version == OUTETTS_V0_3) { + audio_text = std::regex_replace(audio_text, std::regex(R"(<\|text_sep\|>)"), "<|space|>"); + audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_start\|>)"), ""); + audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_end\|>)"), "<|space|>"); + } + + // load speaker if given + if (!params.vocoder.speaker_file.empty()) { + LOG_INF("%s: loading speaker ..\n", __func__); + json speaker = speaker_from_file(params.vocoder.speaker_file); + if (speaker.empty()) { + LOG_ERR("%s: Failed to load speaker file '%s'\n", __func__, params.vocoder.speaker_file.c_str()); + return 1; } - printf("\n\n"); + audio_text = audio_text_from_speaker(speaker, tts_version); + audio_data = audio_data_from_speaker(speaker, tts_version); + } + + // process prompt and generate voice codes + { + LOG_INF("%s: constructing prompt ..\n", __func__); + + std::vector prompt_inp; + + prompt_init(prompt_inp, vocab); + + prompt_add(prompt_inp, vocab, audio_text, false, true); + + // convert the input text into the necessary format expected by OuteTTS + { + std::string prompt_clean = process_text(params.prompt, tts_version); + if (params.vocoder.use_guide_tokens) { + guide_tokens = prepare_guide_tokens(vocab, prompt_clean, tts_version); + } + + LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str()); + + prompt_add(prompt_inp, vocab, prompt_clean, false, true); + } + + prompt_add(prompt_inp, vocab, "<|text_end|>\n", false, true); + + if (!params.vocoder.speaker_file.empty()) { + prompt_add(prompt_inp, vocab, audio_data, false, true); + } else { + // disabled to save time on tokenizing each time +#if 1 + const std::string voice_data = audio_data; + + auto tmp = common_tokenize(vocab, voice_data, false, true); + printf("\n\n"); + for (size_t i = 0; i < tmp.size(); ++i) { + printf("%d, ", tmp[i]); + } + printf("\n\n"); + prompt_add(prompt_inp, tmp); #else - prompt_add(prompt_inp, llama_tokens { - 151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585, - 152460, 153375, 151670, 198, 74455, 155808, 151669, 151799, - 151873, 151863, 152446, 152372, 152204, 152728, 152229, 152470, - 151970, 153413, 152419, 153334, 153289, 153374, 153199, 152040, - 153260, 152721, 152680, 153297, 152419, 153248, 152400, 152691, - 153368, 153437, 151670, 198, 1722, 155828, 151669, 152607, - 152256, 152991, 152299, 152688, 153163, 153016, 152789, 153198, - 152712, 151911, 153107, 152623, 152170, 152395, 152852, 152207, - 152461, 153321, 153309, 151750, 152137, 153340, 152573, 152267, - 153347, 151789, 152681, 153339, 151992, 152512, 151751, 152179, - 153434, 153180, 152900, 153440, 152474, 153122, 153129, 151904, - 
152311, 151670, 198, 1499, 155791, 151669, 152276, 152454, - 153354, 152544, 153204, 153272, 152708, 153433, 152319, 153226, - 153043, 152325, 153267, 152622, 151670, 198, 4250, 155797, - 151669, 153454, 153342, 151989, 152458, 153420, 152303, 152271, - 152827, 153036, 153196, 151708, 153263, 152561, 153207, 152213, - 152112, 153204, 151722, 152542, 151670, 198, 19789, 155796, - 151669, 153353, 153182, 152345, 152471, 152477, 153014, 152002, - 152191, 151734, 152312, 152810, 152237, 153224, 153169, 153224, - 152244, 153387, 153404, 151670, 198, 16069, 155811, 151669, - 152265, 151946, 151808, 152412, 152363, 152305, 153156, 152733, - 152810, 153157, 152016, 152100, 152069, 153234, 152317, 152589, - 152707, 153121, 153341, 152159, 152114, 153156, 153001, 153504, - 153376, 152272, 152433, 152325, 151941, 151670, 198, 285, - 155788, 151669, 152238, 152255, 153427, 152318, 153009, 152381, - 152474, 152680, 152157, 153255, 152324, 151682, 151670, 198, - 32955, 155804, 151669, 153490, 153419, 152364, 152405, 152682, - 152206, 152078, 153369, 152725, 153193, 153027, 152946, 152488, - 153070, 151883, 152890, 152489, 153144, 153375, 152358, 151685, - 152494, 152117, 152740, 151670, 198, 37448, 480, 155840, 151669, - 151902, 152720, 153377, 152027, 152378, 152821, 153207, 153459, - 153028, 153068, 152507, 153255, 152158, 152921, 151958, 152609, - 152748, 152822, 152286, 151714, 152730, 152377, 152353, 152470, - 152606, 152162, 152186, 153071, 152244, 153118, 153375, 153018, - 152712, 153098, 152976, 152336, 151843, 153202, 152297, 151736, - 153380, 153502, 152702, 152115, 153181, 152735, 153277, 153457, - 152393, 153112, 152595, 151670, 198, 19098, 155808, 151669, - 152464, 153452, 152595, 153312, 151937, 151933, 153197, 152239, - 153163, 152922, 153402, 152034, 152591, 153438, 152215, 151673, - 152005, 151785, 152642, 151924, 153278, 151805, 151974, 153482, - 152718, 152862, 153347, 151670, 198, 72, 155780, 151669, 151795, - 152111, 152746, 152377, 153471, 152309, 151670, 198, 19016, - 155788, 151669, 153181, 152271, 152190, 152842, 152224, 152701, - 152939, 152536, 152091, 151815, 152733, 151672, 151670, 198, - 14689, 155788, 151669, 152291, 152072, 152942, 151734, 153042, - 153504, 152589, 153333, 151839, 151941, 153038, 153180, 151670, - 198, 36996, 8303, 155832, 151669, 152231, 152256, 152835, - 152801, 152985, 153400, 152393, 152818, 152765, 152249, 152600, - 151699, 152302, 152752, 153018, 153009, 151992, 153054, 152847, - 153354, 153228, 152662, 153355, 152532, 153393, 151782, 152458, - 152048, 152757, 152428, 153195, 151906, 153006, 153178, 153250, - 152331, 152284, 152780, 153138, 153319, 151980, 153142, 152418, - 152228, 152733, 151670, 198, 9096, 155801, 151669, 151698, - 153321, 152217, 153039, 152935, 153400, 152122, 152531, 153106, - 152169, 152892, 152957, 151851, 152427, 152826, 152451, 151851, - 152901, 152885, 152594, 153446, 153080, 151670, 198, 14689, - 155795, 151669, 152658, 151700, 153321, 152450, 152530, 153191, - 151673, 151690, 151698, 152714, 152846, 152981, 153171, 153384, - 153364, 153188, 153246, 151670, 198, 1055, 155779, 151669, - 151869, 152388, 152711, 153334, 151736, 151670, 198, 1782, - 155780, 151669, 153483, 153240, 152241, 152558, 152697, 153046, - 151670, 198, 5804, 1363, 155820, 151669, 152941, 152764, 152605, - 153034, 153434, 153372, 153347, 151887, 152453, 152758, 152133, - 152510, 152694, 152431, 152321, 153088, 152676, 152223, 152581, - 152459, 152015, 152502, 153063, 152712, 153294, 153451, 153032, - 152903, 152859, 152989, 151748, 152669, 152661, 
152650, 152409, - 151861, 151670, 198, 300, 7973, 155828, 151669, 153095, 152469, - 152988, 152894, 151819, 152391, 153019, 152058, 153062, 153230, - 151826, 152112, 152306, 152264, 152769, 153390, 152384, 152435, - 152790, 153393, 152983, 152540, 152252, 152034, 153107, 152540, - 151919, 151893, 152558, 152817, 152946, 152956, 152129, 152715, - 153131, 153490, 151734, 152271, 152707, 151734, 153321, 152450, - 151670, 198, 8088, 155792, 151669, 152452, 153497, 153353, - 152679, 152533, 152382, 152374, 152611, 153341, 153163, 152285, - 153411, 152495, 153141, 152320, 151670, 198, 1199, 155781, - 151669, 151764, 152360, 153295, 152634, 153342, 152199, 152271, - 151670, 198, 43366, 155799, 151669, 152308, 151682, 152889, - 152016, 152385, 152629, 152495, 151826, 153321, 152958, 152180, - 151886, 153432, 152922, 152128, 153024, 153040, 152593, 152287, - 151677, 151670, 198, 53660, 155808, 151669, 151727, 152092, - 152680, 153331, 151699, 152316, 152938, 152289, 152433, 153384, - 151781, 153137, 153259, 152175, 153213, 152291, 151869, 152691, - 152489, 151941, 152049, 152034, 153053, 152179, 153160, 151676, - 153367, 151670, 198, 268, 4123, 480, 155821, 151669, 152350, - 152173, 152536, 151991, 151960, 153144, 153013, 152358, 152234, - 153135, 152291, 153235, 152143, 152583, 152402, 153483, 152678, - 152192, 152533, 152946, 151797, 153103, 152310, 152293, 151825, - 152548, 153442, 152109, 152659, 153325, 152781, 152570, 152957, - 151752, 152265, 153381, 152515, 151670, 198, 437, 155787, - 151669, 152957, 152659, 151975, 152709, 152402, 152836, 152174, - 151792, 153409, 153327, 152990, 151670, 198, 275, 155781, - 151669, 152520, 153038, 152067, 153273, 153185, 152265, 152974, - 151670, 198, 94273, 155799, 151669, 152953, 152938, 153427, - 152244, 151920, 153423, 152929, 152367, 153052, 152129, 152331, - 152257, 152987, 152777, 153448, 152408, 151696, 152408, 152326, - 152699, 151670, 198, 385, 16239, 155828, 151669, 152306, 152268, - 153438, 153228, 152978, 152957, 153153, 153393, 152795, 152110, - 152918, 152923, 152467, 152331, 153053, 153330, 151889, 153444, - 152234, 152624, 151779, 152801, 152784, 152139, 152222, 152751, - 152512, 153287, 153141, 153052, 151840, 152589, 152508, 153499, - 152109, 152255, 151739, 152267, 152759, 153318, 153165, 153349, - 151670,}); + prompt_add(prompt_inp, llama_tokens { + 151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585, + 152460, 153375, 151670, 198, 74455, 155808, 151669, 151799, + 151873, 151863, 152446, 152372, 152204, 152728, 152229, 152470, + 151970, 153413, 152419, 153334, 153289, 153374, 153199, 152040, + 153260, 152721, 152680, 153297, 152419, 153248, 152400, 152691, + 153368, 153437, 151670, 198, 1722, 155828, 151669, 152607, + 152256, 152991, 152299, 152688, 153163, 153016, 152789, 153198, + 152712, 151911, 153107, 152623, 152170, 152395, 152852, 152207, + 152461, 153321, 153309, 151750, 152137, 153340, 152573, 152267, + 153347, 151789, 152681, 153339, 151992, 152512, 151751, 152179, + 153434, 153180, 152900, 153440, 152474, 153122, 153129, 151904, + 152311, 151670, 198, 1499, 155791, 151669, 152276, 152454, + 153354, 152544, 153204, 153272, 152708, 153433, 152319, 153226, + 153043, 152325, 153267, 152622, 151670, 198, 4250, 155797, + 151669, 153454, 153342, 151989, 152458, 153420, 152303, 152271, + 152827, 153036, 153196, 151708, 153263, 152561, 153207, 152213, + 152112, 153204, 151722, 152542, 151670, 198, 19789, 155796, + 151669, 153353, 153182, 152345, 152471, 152477, 153014, 152002, + 152191, 151734, 152312, 152810, 152237, 
153224, 153169, 153224, + 152244, 153387, 153404, 151670, 198, 16069, 155811, 151669, + 152265, 151946, 151808, 152412, 152363, 152305, 153156, 152733, + 152810, 153157, 152016, 152100, 152069, 153234, 152317, 152589, + 152707, 153121, 153341, 152159, 152114, 153156, 153001, 153504, + 153376, 152272, 152433, 152325, 151941, 151670, 198, 285, + 155788, 151669, 152238, 152255, 153427, 152318, 153009, 152381, + 152474, 152680, 152157, 153255, 152324, 151682, 151670, 198, + 32955, 155804, 151669, 153490, 153419, 152364, 152405, 152682, + 152206, 152078, 153369, 152725, 153193, 153027, 152946, 152488, + 153070, 151883, 152890, 152489, 153144, 153375, 152358, 151685, + 152494, 152117, 152740, 151670, 198, 37448, 480, 155840, 151669, + 151902, 152720, 153377, 152027, 152378, 152821, 153207, 153459, + 153028, 153068, 152507, 153255, 152158, 152921, 151958, 152609, + 152748, 152822, 152286, 151714, 152730, 152377, 152353, 152470, + 152606, 152162, 152186, 153071, 152244, 153118, 153375, 153018, + 152712, 153098, 152976, 152336, 151843, 153202, 152297, 151736, + 153380, 153502, 152702, 152115, 153181, 152735, 153277, 153457, + 152393, 153112, 152595, 151670, 198, 19098, 155808, 151669, + 152464, 153452, 152595, 153312, 151937, 151933, 153197, 152239, + 153163, 152922, 153402, 152034, 152591, 153438, 152215, 151673, + 152005, 151785, 152642, 151924, 153278, 151805, 151974, 153482, + 152718, 152862, 153347, 151670, 198, 72, 155780, 151669, 151795, + 152111, 152746, 152377, 153471, 152309, 151670, 198, 19016, + 155788, 151669, 153181, 152271, 152190, 152842, 152224, 152701, + 152939, 152536, 152091, 151815, 152733, 151672, 151670, 198, + 14689, 155788, 151669, 152291, 152072, 152942, 151734, 153042, + 153504, 152589, 153333, 151839, 151941, 153038, 153180, 151670, + 198, 36996, 8303, 155832, 151669, 152231, 152256, 152835, + 152801, 152985, 153400, 152393, 152818, 152765, 152249, 152600, + 151699, 152302, 152752, 153018, 153009, 151992, 153054, 152847, + 153354, 153228, 152662, 153355, 152532, 153393, 151782, 152458, + 152048, 152757, 152428, 153195, 151906, 153006, 153178, 153250, + 152331, 152284, 152780, 153138, 153319, 151980, 153142, 152418, + 152228, 152733, 151670, 198, 9096, 155801, 151669, 151698, + 153321, 152217, 153039, 152935, 153400, 152122, 152531, 153106, + 152169, 152892, 152957, 151851, 152427, 152826, 152451, 151851, + 152901, 152885, 152594, 153446, 153080, 151670, 198, 14689, + 155795, 151669, 152658, 151700, 153321, 152450, 152530, 153191, + 151673, 151690, 151698, 152714, 152846, 152981, 153171, 153384, + 153364, 153188, 153246, 151670, 198, 1055, 155779, 151669, + 151869, 152388, 152711, 153334, 151736, 151670, 198, 1782, + 155780, 151669, 153483, 153240, 152241, 152558, 152697, 153046, + 151670, 198, 5804, 1363, 155820, 151669, 152941, 152764, 152605, + 153034, 153434, 153372, 153347, 151887, 152453, 152758, 152133, + 152510, 152694, 152431, 152321, 153088, 152676, 152223, 152581, + 152459, 152015, 152502, 153063, 152712, 153294, 153451, 153032, + 152903, 152859, 152989, 151748, 152669, 152661, 152650, 152409, + 151861, 151670, 198, 300, 7973, 155828, 151669, 153095, 152469, + 152988, 152894, 151819, 152391, 153019, 152058, 153062, 153230, + 151826, 152112, 152306, 152264, 152769, 153390, 152384, 152435, + 152790, 153393, 152983, 152540, 152252, 152034, 153107, 152540, + 151919, 151893, 152558, 152817, 152946, 152956, 152129, 152715, + 153131, 153490, 151734, 152271, 152707, 151734, 153321, 152450, + 151670, 198, 8088, 155792, 151669, 152452, 153497, 153353, + 152679, 152533, 
152382, 152374, 152611, 153341, 153163, 152285, + 153411, 152495, 153141, 152320, 151670, 198, 1199, 155781, + 151669, 151764, 152360, 153295, 152634, 153342, 152199, 152271, + 151670, 198, 43366, 155799, 151669, 152308, 151682, 152889, + 152016, 152385, 152629, 152495, 151826, 153321, 152958, 152180, + 151886, 153432, 152922, 152128, 153024, 153040, 152593, 152287, + 151677, 151670, 198, 53660, 155808, 151669, 151727, 152092, + 152680, 153331, 151699, 152316, 152938, 152289, 152433, 153384, + 151781, 153137, 153259, 152175, 153213, 152291, 151869, 152691, + 152489, 151941, 152049, 152034, 153053, 152179, 153160, 151676, + 153367, 151670, 198, 268, 4123, 480, 155821, 151669, 152350, + 152173, 152536, 151991, 151960, 153144, 153013, 152358, 152234, + 153135, 152291, 153235, 152143, 152583, 152402, 153483, 152678, + 152192, 152533, 152946, 151797, 153103, 152310, 152293, 151825, + 152548, 153442, 152109, 152659, 153325, 152781, 152570, 152957, + 151752, 152265, 153381, 152515, 151670, 198, 437, 155787, + 151669, 152957, 152659, 151975, 152709, 152402, 152836, 152174, + 151792, 153409, 153327, 152990, 151670, 198, 275, 155781, + 151669, 152520, 153038, 152067, 153273, 153185, 152265, 152974, + 151670, 198, 94273, 155799, 151669, 152953, 152938, 153427, + 152244, 151920, 153423, 152929, 152367, 153052, 152129, 152331, + 152257, 152987, 152777, 153448, 152408, 151696, 152408, 152326, + 152699, 151670, 198, 385, 16239, 155828, 151669, 152306, 152268, + 153438, 153228, 152978, 152957, 153153, 153393, 152795, 152110, + 152918, 152923, 152467, 152331, 153053, 153330, 151889, 153444, + 152234, 152624, 151779, 152801, 152784, 152139, 152222, 152751, + 152512, 153287, 153141, 153052, 151840, 152589, 152508, 153499, + 152109, 152255, 151739, 152267, 152759, 153318, 153165, 153349, + 151670,}); #endif + } // print the prompt token-by-token diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 955a38d9f..2f56607ac 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -2,10 +2,8 @@ #include "ggml-backend.h" #include "ggml-impl.h" #include -#include #include #include -#include #include #include #include @@ -72,14 +70,15 @@ # pragma clang diagnostic ignored "-Wdeprecated-declarations" #endif -static std::wstring utf8_to_utf16(const std::string & str) { - std::wstring_convert> converter; - return converter.from_bytes(str); -} +namespace fs = std::filesystem; -static std::string utf16_to_utf8(const std::wstring & str) { - std::wstring_convert> converter; - return converter.to_bytes(str); +static std::string path_str(const fs::path & path) { + std::string u8path; + try { + u8path = path.u8string(); + } catch (...) 
{ + } + return u8path; } #if defined(__clang__) @@ -96,12 +95,12 @@ struct dl_handle_deleter { } }; -static dl_handle * dl_load_library(const std::wstring & path) { +static dl_handle * dl_load_library(const fs::path & path) { // suppress error dialogs for missing DLLs DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); - HMODULE handle = LoadLibraryW(path.c_str()); + HMODULE handle = LoadLibraryW(path.wstring().c_str()); SetErrorMode(old_mode); @@ -129,8 +128,8 @@ struct dl_handle_deleter { } }; -static void * dl_load_library(const std::wstring & path) { - dl_handle * handle = dlopen(utf16_to_utf8(path).c_str(), RTLD_NOW | RTLD_LOCAL); +static void * dl_load_library(const fs::path & path) { + dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL); return handle; } @@ -217,11 +216,11 @@ struct ggml_backend_registry { devices.push_back(device); } - ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) { + ggml_backend_reg_t load_backend(const fs::path & path, bool silent) { dl_handle_ptr handle { dl_load_library(path) }; if (!handle) { if (!silent) { - GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(path).c_str()); + GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(path).c_str()); } return nullptr; } @@ -229,7 +228,7 @@ struct ggml_backend_registry { auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); if (score_fn && score_fn() == 0) { if (!silent) { - GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, utf16_to_utf8(path).c_str()); + GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path_str(path).c_str()); } return nullptr; } @@ -237,7 +236,7 @@ struct ggml_backend_registry { auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init"); if (!backend_init_fn) { if (!silent) { - GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, utf16_to_utf8(path).c_str()); + GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path_str(path).c_str()); } return nullptr; } @@ -246,16 +245,17 @@ struct ggml_backend_registry { if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) { if (!silent) { if (!reg) { - GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, utf16_to_utf8(path).c_str()); + GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", + __func__, path_str(path).c_str()); } else { GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n", - __func__, utf16_to_utf8(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION); + __func__, path_str(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION); } } return nullptr; } - GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str()); + GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_str(path).c_str()); register_backend(reg, std::move(handle)); @@ -391,14 +391,14 @@ ggml_backend_t ggml_backend_init_best(void) { // Dynamic loading ggml_backend_reg_t ggml_backend_load(const char * path) { - return get_reg().load_backend(utf8_to_utf16(path), false); + return get_reg().load_backend(path, false); } void ggml_backend_unload(ggml_backend_reg_t reg) { get_reg().unload_backend(reg, true); } -static std::wstring get_executable_path() { 
+static fs::path get_executable_path() { #if defined(__APPLE__) // get executable path std::vector path; @@ -416,7 +416,7 @@ static std::wstring get_executable_path() { if (last_slash != std::string::npos) { base_path = base_path.substr(0, last_slash); } - return utf8_to_utf16(base_path + "/"); + return base_path + "/"; #elif defined(__linux__) || defined(__FreeBSD__) std::string base_path = "."; std::vector path(1024); @@ -442,7 +442,7 @@ static std::wstring get_executable_path() { path.resize(path.size() * 2); } - return utf8_to_utf16(base_path + "/"); + return base_path + "/"; #elif defined(_WIN32) std::vector path(MAX_PATH); DWORD len = GetModuleFileNameW(NULL, path.data(), path.size()); @@ -462,74 +462,69 @@ static std::wstring get_executable_path() { return L""; //fix for freebsd compile } -static std::wstring backend_filename_prefix() { +static fs::path backend_filename_prefix() { #ifdef _WIN32 - return L"ggml-"; + return fs::u8path("ggml-"); #else - return L"libggml-"; + return fs::u8path("libggml-"); #endif } -static std::wstring backend_filename_suffix() { +static fs::path backend_filename_extension() { #ifdef _WIN32 - return L".dll"; + return fs::u8path(".dll"); #else - return L".so"; -#endif -} - -static std::wstring path_separator() { -#ifdef _WIN32 - return L"\\"; -#else - return L"/"; + return fs::u8path(".so"); #endif } static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, const char * user_search_path) { // enumerate all the files that match [lib]ggml-name-*.[so|dll] in the search paths - // TODO: search system paths - std::wstring file_prefix = backend_filename_prefix() + utf8_to_utf16(name) + L"-"; - std::vector search_paths; + const fs::path name_path = fs::u8path(name); + const fs::path file_prefix = backend_filename_prefix().native() + name_path.native() + fs::u8path("-").native(); + const fs::path file_extension = backend_filename_extension(); + + std::vector search_paths; if (user_search_path == nullptr) { - search_paths.push_back(L"." 
+ path_separator()); + // default search paths: executable directory, current directory search_paths.push_back(get_executable_path()); + search_paths.push_back(fs::current_path()); } else { - search_paths.push_back(utf8_to_utf16(user_search_path) + path_separator()); + search_paths.push_back(user_search_path); } int best_score = 0; - std::wstring best_path; + fs::path best_path; - namespace fs = std::filesystem; for (const auto & search_path : search_paths) { if (!fs::exists(search_path)) { + GGML_LOG_DEBUG("%s: search path %s does not exist\n", __func__, path_str(search_path).c_str()); continue; } fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied); for (const auto & entry : dir_it) { if (entry.is_regular_file()) { - std::wstring filename = entry.path().filename().wstring(); - std::wstring ext = entry.path().extension().wstring(); - if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { - dl_handle_ptr handle { dl_load_library(entry.path().wstring()) }; + auto filename = entry.path().filename().native(); + auto ext = entry.path().extension().native(); + if (filename.find(file_prefix) == 0 && ext == file_extension) { + dl_handle_ptr handle { dl_load_library(entry) }; if (!handle && !silent) { - GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); + GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(entry.path()).c_str()); } if (handle) { auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); if (score_fn) { int s = score_fn(); #ifndef NDEBUG - GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s); + GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_str(entry.path()).c_str(), s); #endif if (s > best_score) { best_score = s; - best_path = entry.path().wstring(); + best_path = entry.path(); } } else { if (!silent) { - GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); + GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, path_str(entry.path()).c_str()); } } } @@ -541,7 +536,8 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, if (best_score == 0) { // try to load the base backend for (const auto & search_path : search_paths) { - std::wstring path = search_path + backend_filename_prefix() + utf8_to_utf16(name) + backend_filename_suffix(); + fs::path filename = backend_filename_prefix().native() + name_path.native() + backend_filename_extension().native(); + fs::path path = search_path.native() + filename.native(); if (fs::exists(path)) { return get_reg().load_backend(path, silent); } diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 6bb8bb00e..8bf676ca7 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1420,15 +1420,41 @@ inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } +inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) { + for (int i = 
0; i < n; ++i) { + z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) + GGML_FP16_TO_FP32(y[i])); + } +} inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } +inline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) { + for (int i = 0; i < n; ++i) { + z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) - GGML_FP16_TO_FP32(y[i])); + } +} inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } +inline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(-GGML_FP16_TO_FP32(x[i])); + } +} + inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } +inline static void ggml_vec_mul_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) { + for (int i = 0; i < n; ++i) { + z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) * GGML_FP16_TO_FP32(y[i])); + } +} inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } +inline static void ggml_vec_div_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) { + for (int i = 0; i < n; ++i) { + z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) / GGML_FP16_TO_FP32(y[i])); + } +} static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) { assert(nrc == 1); @@ -1815,22 +1841,107 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } +inline static void ggml_vec_sqr_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + float v = GGML_FP16_TO_FP32(x[i]); + y[i] = GGML_FP32_TO_FP16(v*v); + } +} inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } +inline static void ggml_vec_sqrt_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(sqrtf(GGML_FP16_TO_FP32(x[i]))); + } +} inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); } +inline static void ggml_vec_log_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + y[i] = 
GGML_FP32_TO_FP16(logf(GGML_FP16_TO_FP32(x[i]))); + } +} inline static void ggml_vec_sin_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]); } +inline static void ggml_vec_sin_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(sinf(GGML_FP16_TO_FP32(x[i]))); + } +} inline static void ggml_vec_cos_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]); } +inline static void ggml_vec_cos_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(cosf(GGML_FP16_TO_FP32(x[i]))); + } +} inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } +inline static void ggml_vec_abs_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(fabsf(GGML_FP16_TO_FP32(x[i]))); + } +} inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } +inline static void ggml_vec_sgn_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + float v = GGML_FP16_TO_FP32(x[i]); + y[i] = GGML_FP32_TO_FP16((v > 0.f) ? 1.f : ((v < 0.f) ? -1.f : 0.f)); + } +} inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } +inline static void ggml_vec_step_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16((GGML_FP16_TO_FP32(x[i]) > 0.f) ? 1.f : 0.f); + } +} inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); } +inline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(tanhf(GGML_FP16_TO_FP32(x[i]))); + } +} inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); } +inline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(expm1f(GGML_FP16_TO_FP32(x[i]))); + } +} inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } +inline static void ggml_vec_relu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + float v = GGML_FP16_TO_FP32(x[i]); + y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v : 0.f); + } +} inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); } +inline static void ggml_vec_leaky_relu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const float ns) { + for (int i = 0; i < n; ++i) { + float v = GGML_FP16_TO_FP32(x[i]); + y[i] = GGML_FP32_TO_FP16(((v > 0.f) ? v : 0.f) + ns * ((v < 0.0f) ? 
v : 0.f)); + } +} inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); } +inline static void ggml_vec_sigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(1.f / (1.f + expf(-GGML_FP16_TO_FP32(x[i])))); + } +} // TODO: optimize performance inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } +inline static void ggml_vec_hardswish_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + float v = GGML_FP16_TO_FP32(x[i]); + y[i] = GGML_FP32_TO_FP16(v * fminf(1.0f, fmaxf(0.0f, (v + 3.0f) / 6.0f))); + } +} inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } +inline static void ggml_vec_hardsigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(fminf(1.0f, fmaxf(0.0f, (GGML_FP16_TO_FP32(x[i]) + 3.0f) / 6.0f))); + } +} inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); } +inline static void ggml_vec_exp_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(expf(GGML_FP16_TO_FP32(x[i]))); + } +} static const float GELU_COEF_A = 0.044715f; static const float GELU_QUICK_COEF = -1.702f; @@ -1898,10 +2009,21 @@ inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * } #endif +inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + float v = GGML_FP16_TO_FP32(x[i]); + y[i] = GGML_FP32_TO_FP16(v*(1.0f/(1.0f+expf(GELU_QUICK_COEF*v)))); + } +} + // Sigmoid Linear Unit (SiLU) function inline static float ggml_silu_f32(float x) { return x/(1.0f + expf(-x)); } +inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) { + float v = GGML_FP16_TO_FP32(x); + return GGML_FP32_TO_FP16(v/(1.0f + expf(-v))); +} #if __FINITE_MATH_ONLY__ #error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix" @@ -2125,6 +2247,12 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) { } } +inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_silu_f16(x[i]); + } +} + static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) { int i = 0; ggml_float sum = 0; @@ -2196,12 +2324,24 @@ inline static float ggml_silu_backward_f32(float x, float dy) { return dy*s*(1.0f + x*(1.0f - s)); } +inline static ggml_fp16_t ggml_silu_backward_f16(ggml_fp16_t x, ggml_fp16_t dy) { + const float v = GGML_FP16_TO_FP32(x); + const float s = 1.0f/(1.0f + expf(-v)); + return GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(dy)*s*(1.0f + v*(1.0f - s))); +} + inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) { for (int i = 0; i < n; ++i) { dx[i] = ggml_silu_backward_f32(x[i], dy[i]); } } +inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, const ggml_fp16_t * x, const ggml_fp16_t * dy) { + for (int i = 0; i < n; ++i) { + dx[i] = 
ggml_silu_backward_f16(x[i], dy[i]); + } +} + inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE ggml_float sum = 0.0; @@ -4397,7 +4537,7 @@ static void ggml_compute_forward_add_f16_f16( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); const int ith = params->ith; const int nth = params->nth; @@ -4422,17 +4562,22 @@ static void ggml_compute_forward_add_f16_f16( if (nb10 == sizeof(ggml_fp16_t)) { for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(src1_ptr[i])); + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0; r < nr0; ++r) { + ggml_vec_add_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); } } } @@ -5220,6 +5365,62 @@ static void ggml_compute_forward_sub_f32( } } +static void ggml_compute_forward_sub_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F16); + + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(ggml_fp16_t)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + ggml_fp16_t * src1_ptr = 
(ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0; r < nr0; ++r) { + ggml_vec_sub_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); + } + } + } else { + // src1 is not contiguous + GGML_ABORT("unimplemented error"); + } +} + static void ggml_compute_forward_sub( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -5231,6 +5432,10 @@ static void ggml_compute_forward_sub( { ggml_compute_forward_sub_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_sub_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -5323,6 +5528,55 @@ static void ggml_compute_forward_mul_f32( } } +static void ggml_compute_forward_mul_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F16); + + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + if (nb10 == sizeof(ggml_fp16_t)) { + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0 ; r < nr0; ++r) { + ggml_vec_mul_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); + } + } + } else { + // src1 is not contiguous + GGML_ABORT("unimplemented error"); + } +} + static void ggml_compute_forward_mul( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -5330,13 +5584,17 @@ static void ggml_compute_forward_mul( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(src1->type == GGML_TYPE_F32 && "only f32 src1 supported for now"); + GGML_ASSERT((src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16) && "only f32/f16 src1 supported for now"); switch (src0->type) { case GGML_TYPE_F32: { ggml_compute_forward_mul_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_mul_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -5417,6 +5675,55 @@ static void ggml_compute_forward_div_f32( } } +static void ggml_compute_forward_div_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F16); + 
GGML_ASSERT(dst->type == GGML_TYPE_F16); + + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + if (nb10 == sizeof(ggml_fp16_t)) { + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0; r < nr0; ++r) { + ggml_vec_div_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); + } + } + } else { + // src1 is not contiguous + GGML_ABORT("unimplemented error"); + } +} + static void ggml_compute_forward_div( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -5428,6 +5735,10 @@ static void ggml_compute_forward_div( { ggml_compute_forward_div_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_div_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -5462,6 +5773,31 @@ static void ggml_compute_forward_sqr_f32( } } +static void ggml_compute_forward_sqr_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(ggml_fp16_t)); + assert(src0->nb[0] == sizeof(ggml_fp16_t)); + + for (int i = 0; i < n; i++) { + ggml_vec_sqr_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + static void ggml_compute_forward_sqr( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -5473,6 +5809,10 @@ static void ggml_compute_forward_sqr( { ggml_compute_forward_sqr_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_sqr_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -5507,6 +5847,31 @@ static void ggml_compute_forward_sqrt_f32( } } +static void ggml_compute_forward_sqrt_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(ggml_fp16_t)); + assert(src0->nb[0] == sizeof(ggml_fp16_t)); + + for (int i = 0; i < n; i++) { + ggml_vec_sqrt_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + static void ggml_compute_forward_sqrt( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -5518,6 +5883,10 @@ static void ggml_compute_forward_sqrt( { ggml_compute_forward_sqrt_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_sqrt_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -5552,6 +5921,31 @@ static void ggml_compute_forward_log_f32( } } +static void ggml_compute_forward_log_f16( + const struct 
ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + GGML_ASSERT( dst->nb[0] == sizeof(ggml_fp16_t)); + GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t)); + + for (int i = 0; i < n; i++) { + ggml_vec_log_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + static void ggml_compute_forward_log( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -5563,6 +5957,10 @@ static void ggml_compute_forward_log( { ggml_compute_forward_log_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_log_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -5597,6 +5995,31 @@ static void ggml_compute_forward_sin_f32( } } +static void ggml_compute_forward_sin_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + GGML_ASSERT( dst->nb[0] == sizeof(ggml_fp16_t)); + GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t)); + + for (int i = 0; i < n; i++) { + ggml_vec_sin_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + static void ggml_compute_forward_sin( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -5608,6 +6031,10 @@ static void ggml_compute_forward_sin( { ggml_compute_forward_sin_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_sin_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -5642,6 +6069,31 @@ static void ggml_compute_forward_cos_f32( } } +static void ggml_compute_forward_cos_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + GGML_ASSERT( dst->nb[0] == sizeof(ggml_fp16_t)); + GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t)); + + for (int i = 0; i < n; i++) { + ggml_vec_cos_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + static void ggml_compute_forward_cos( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -5653,6 +6105,10 @@ static void ggml_compute_forward_cos( { ggml_compute_forward_cos_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_cos_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -6310,6 +6766,30 @@ static void ggml_compute_forward_abs_f32( } } +static void ggml_compute_forward_abs_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_abs_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), + (ggml_fp16_t *) 
((char *) src0->data + i*(src0->nb[1]))); + } +} + static void ggml_compute_forward_abs( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -6321,6 +6801,10 @@ static void ggml_compute_forward_abs( { ggml_compute_forward_abs_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_abs_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -6354,6 +6838,30 @@ static void ggml_compute_forward_sgn_f32( } } +static void ggml_compute_forward_sgn_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_sgn_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + static void ggml_compute_forward_sgn( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -6365,6 +6873,10 @@ static void ggml_compute_forward_sgn( { ggml_compute_forward_sgn_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_sgn_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -6398,6 +6910,30 @@ static void ggml_compute_forward_neg_f32( } } +static void ggml_compute_forward_neg_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_neg_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + static void ggml_compute_forward_neg( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -6409,6 +6945,10 @@ static void ggml_compute_forward_neg( { ggml_compute_forward_neg_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_neg_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -6442,6 +6982,30 @@ static void ggml_compute_forward_step_f32( } } +static void ggml_compute_forward_step_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_step_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + static void ggml_compute_forward_step( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -6453,6 +7017,10 @@ static void ggml_compute_forward_step( { ggml_compute_forward_step_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_step_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -6486,6 +7054,30 @@ static void ggml_compute_forward_tanh_f32( } } +static void ggml_compute_forward_tanh_f16( + const struct ggml_compute_params * params, + struct 
ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_tanh_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + static void ggml_compute_forward_tanh( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -6497,6 +7089,10 @@ static void ggml_compute_forward_tanh( { ggml_compute_forward_tanh_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_tanh_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -6530,6 +7126,30 @@ static void ggml_compute_forward_elu_f32( } } +static void ggml_compute_forward_elu_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_elu_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + static void ggml_compute_forward_elu( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -6541,6 +7161,10 @@ static void ggml_compute_forward_elu( { ggml_compute_forward_elu_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_elu_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -6574,6 +7198,30 @@ static void ggml_compute_forward_relu_f32( } } +static void ggml_compute_forward_relu_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_relu_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + static void ggml_compute_forward_relu( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -6585,6 +7233,10 @@ static void ggml_compute_forward_relu( { ggml_compute_forward_relu_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_relu_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -6618,6 +7270,30 @@ static void ggml_compute_forward_sigmoid_f32( } } +static void ggml_compute_forward_sigmoid_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_sigmoid_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + static void ggml_compute_forward_sigmoid( const struct 
ggml_compute_params * params, struct ggml_tensor * dst) { @@ -6629,6 +7305,10 @@ static void ggml_compute_forward_sigmoid( { ggml_compute_forward_sigmoid_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_sigmoid_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -6677,6 +7357,46 @@ static void ggml_compute_forward_gelu_f32( } } +static void ggml_compute_forward_gelu_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_gelu_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + static void ggml_compute_forward_gelu( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -6688,6 +7408,10 @@ static void ggml_compute_forward_gelu( { ggml_compute_forward_gelu_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_gelu_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -6736,6 +7460,46 @@ static void ggml_compute_forward_gelu_quick_f32( } } +static void ggml_compute_forward_gelu_quick_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_gelu_quick_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + static void ggml_compute_forward_gelu_quick( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -6747,6 +7511,10 @@ static void ggml_compute_forward_gelu_quick( { ggml_compute_forward_gelu_quick_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_gelu_quick_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -6795,6 +7563,46 @@ static void ggml_compute_forward_silu_f32( } } +static void ggml_compute_forward_silu_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + 
assert(ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_silu_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + static void ggml_compute_forward_silu( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -6806,6 +7614,10 @@ static void ggml_compute_forward_silu( { ggml_compute_forward_silu_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_silu_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -6844,6 +7656,36 @@ static void ggml_compute_forward_leaky_relu_f32( } } +static void ggml_compute_forward_leaky_relu_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + float negative_slope; + memcpy(&negative_slope, dst->op_params, sizeof(float)); + + assert(dst->nb[0] == sizeof(ggml_fp16_t)); + assert(src0->nb[0] == sizeof(ggml_fp16_t)); + + for (int i = 0; i < n; i++) { + ggml_vec_leaky_relu_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])), negative_slope); + } +} + static void ggml_compute_forward_leaky_relu( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -6855,6 +7697,10 @@ static void ggml_compute_forward_leaky_relu( { ggml_compute_forward_leaky_relu_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_leaky_relu_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -6907,6 +7753,50 @@ static void ggml_compute_forward_silu_back_f32( } } +static void ggml_compute_forward_silu_back_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * grad = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + assert(ggml_is_contiguous_1(grad)); + assert(ggml_is_contiguous_1(src1)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src1, dst)); + assert(ggml_are_same_shape(src1, grad)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1->ne[0]; + const int nr = ggml_nrows(src1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_silu_backward_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])), + (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1]))); + + #ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = 
GGML_FP16_TO_FP32(x);
+            UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+    #endif
+    }
+}
+
 static void ggml_compute_forward_silu_back(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -6918,6 +7808,10 @@ static void ggml_compute_forward_silu_back(
             {
                 ggml_compute_forward_silu_back_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_silu_back_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -6925,7 +7819,6 @@ static void ggml_compute_forward_silu_back(
     }
 }
 
-
 static void ggml_compute_forward_hardswish_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -6949,6 +7842,31 @@ static void ggml_compute_forward_hardswish_f32(
                 (float *) ((char *) src0->data + i*(src0->nb[1])));
     }
 }
+
+static void ggml_compute_forward_hardswish_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_hardswish_f16(nc,
+                (ggml_fp16_t *) ((char *)  dst->data + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
 static void ggml_compute_forward_hardswish(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -6960,6 +7878,10 @@ static void ggml_compute_forward_hardswish(
             {
                 ggml_compute_forward_hardswish_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_hardswish_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -6991,6 +7913,30 @@ static void ggml_compute_forward_hardsigmoid_f32(
     }
 }
 
+static void ggml_compute_forward_hardsigmoid_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_hardsigmoid_f16(nc,
+                (ggml_fp16_t *) ((char *)  dst->data + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
 static void ggml_compute_forward_hardsigmoid(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -7002,6 +7948,10 @@ static void ggml_compute_forward_hardsigmoid(
             {
                 ggml_compute_forward_hardsigmoid_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_hardsigmoid_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -7033,6 +7983,30 @@ static void ggml_compute_forward_exp_f32(
     }
 }
 
+static void ggml_compute_forward_exp_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_exp_f16(nc,
+                (ggml_fp16_t *) ((char *)  dst->data + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
 static void ggml_compute_forward_exp(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
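The F16 kernels above all follow the same convert-compute-convert pattern: widen each half to float, run the existing float math, then narrow the result back to half. A minimal standalone sketch of that pattern, illustrative only and not part of the patch; it uses the public ggml_fp16_to_fp32()/ggml_fp32_to_fp16() helpers from ggml.h rather than the internal GGML_FP16_TO_FP32/GGML_FP32_TO_FP16 macros, and vec_exp_f16_sketch is a made-up name:

#include <math.h>
#include "ggml.h"

// Same shape as the ggml_vec_*_f16 helpers: convert to f32, apply the op, convert back.
static void vec_exp_f16_sketch(int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = ggml_fp32_to_fp16(expf(ggml_fp16_to_fp32(x[i])));
    }
}

Keeping the arithmetic in FP32 this way avoids duplicating the math per type while still storing the tensors in half precision.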
@@ -7044,6 +8018,10 @@ static void ggml_compute_forward_exp(
             {
                 ggml_compute_forward_exp_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_exp_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -9337,6 +10315,43 @@ static void ggml_compute_forward_clamp_f32(
     }
 }
 
+static void ggml_compute_forward_clamp_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    float min;
+    float max;
+    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    for (int j = ith; j < n; j += nth) {
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *)  dst->data + j*nb1);
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+
+        for (int i = 0; i < nc; i++) {
+            float v = GGML_FP16_TO_FP32(src0_ptr[i]);
+            dst_ptr[i] = GGML_FP32_TO_FP16(MAX(MIN(v, max), min));
+        }
+    }
+}
+
 static void ggml_compute_forward_clamp(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -9349,6 +10364,9 @@ static void ggml_compute_forward_clamp(
                 ggml_compute_forward_clamp_f32(params, dst);
             } break;
         case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_clamp_f16(params, dst);
+            } break;
         case GGML_TYPE_BF16:
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
index 66685fd16..4dff5c67e 100644
--- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
@@ -190,10 +190,11 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
 }
 }  // namespace ggml::cpu::kleidiai
 
-static void ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+GGML_API enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
     tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
 
     GGML_UNUSED(buffer);
+    return GGML_STATUS_SUCCESS;
 }
 
 static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu
index ce4b9cfb5..e1fbf0e13 100644
--- a/ggml/src/ggml-cuda/binbcast.cu
+++ b/ggml/src/ggml-cuda/binbcast.cu
@@ -294,11 +294,13 @@ static void ggml_cuda_op_bin_bcast(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {
 
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
 
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
         op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
-    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (half *) dst_dd, stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
         op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
         op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
diff --git a/ggml/src/ggml-cuda/clamp.cu b/ggml/src/ggml-cuda/clamp.cu
index 8009a3e3d..fe415e7f7 100644
--- a/ggml/src/ggml-cuda/clamp.cu
+++ b/ggml/src/ggml-cuda/clamp.cu
@@ -1,34 +1,45 @@
 #include "clamp.cuh"
 
-static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
+static __device__ __forceinline__ float op_clamp(float x, float min, float max) {
+    return fminf(fmaxf(x, min), max);
+}
+
+template <class T>
+static __global__ void op_clamp_kernel(const T * x, T * dst, const T min, const T max, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;
     }
 
-    dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
+    dst[i] = (T)op_clamp((float)x[i], (float)min, (float)max);
 }
 
-static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) {
+template <class T>
+static void clamp_cuda(const T * x, T * dst, const T min, const T max, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
-    clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
+    op_clamp_kernel<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
 }
 
 void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
     float min;
     float max;
     memcpy(&min, dst->op_params, sizeof(float));
     memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
 
-    clamp_f32_cuda(src0_d, dst_d, min, max, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        clamp_cuda((const half *)src0_d, (half *)dst_d, (half)min, (half)max, ggml_nelements(src0), stream);
+    } else {
+        clamp_cuda((const float *)src0_d, (float *)dst_d, (float)min, (float)max, ggml_nelements(src0), stream);
+    }
 }
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index adf0d3ecb..1832314ec 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -62,6 +62,7 @@
 #define GGML_CUDA_CC_RDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
 #define GGML_CUDA_CC_RDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
 
+#define GGML_CUDA_CC_IS_AMD(cc)   (cc >= GGML_CUDA_CC_OFFSET_AMD)
 #define GGML_CUDA_CC_IS_RDNA(cc)  (cc >= GGML_CUDA_CC_RDNA1)
 #define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
 #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
@@ -196,6 +197,10 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
+#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
+#define FP16_MMA_AVAILABLE
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) &&
(defined(CDNA) || defined(RDNA3)) + #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING #define NEW_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING @@ -223,12 +228,18 @@ static bool fast_fp16_hardware_available(const int cc) { // Any FP16 tensor core instructions are available for ggml code. static bool fp16_mma_available(const int cc) { - return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA; +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) + return false; +#else + return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA || + GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3; +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) } // To be used for feature selection of external libraries, e.g. cuBLAS. static bool fp16_mma_hardware_available(const int cc) { - return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA; + return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA || + GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3; } // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 7b9566fb4..46de14093 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -57,12 +57,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); GGML_UNUSED(Q_v); T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -70,7 +71,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( const int shift = k_KQ & (QI8_1/2); const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; - const int u = Q_q8[k_KQ_0/WARP_SIZE]; + const int u = Q_q8[k_KQ_0/warp_size]; const int sumi = ggml_cuda_dp4a(v, u, 0); @@ -78,14 +79,14 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; - const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size]; sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */); } else #endif // FP16_AVAILABLE { const float2 * Q_ds = (const float2 *) Q_ds_v; - sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y)); + sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y)); } } @@ -97,12 +98,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); GGML_UNUSED(Q_v); T sum = 0.0f; #pragma unroll - for 
(int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -110,7 +112,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( const int shift = k_KQ & (QI8_1/2); const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; - const int u = Q_q8[k_KQ_0/WARP_SIZE]; + const int u = Q_q8[k_KQ_0/warp_size]; const int sumi = ggml_cuda_dp4a(v, u, 0); @@ -118,7 +120,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; - const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size]; const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1); sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled)); } else @@ -126,8 +128,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( { const float2 * Q_ds = (const float2 *) Q_ds_v; - const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi; - const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1; + const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi; + const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1; sum += (T) (sumid4d8 + m4s8scaled); } @@ -141,12 +143,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); GGML_UNUSED(Q_v); T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -161,7 +164,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( v |= (vh << 18) & 0x00100000; // 2 -> 20 v |= (vh << 25) & 0x10000000; // 3 -> 28 - const int u = Q_q8[k_KQ_0/WARP_SIZE]; + const int u = Q_q8[k_KQ_0/warp_size]; const int sumi = ggml_cuda_dp4a(v, u, 0); @@ -169,14 +172,14 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; - const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size]; sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */; } else #endif // FP16_AVAILABLE { const float2 * Q_ds = (const float2 *) Q_ds_v; - sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y)); + sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y)); } } @@ -188,12 +191,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); GGML_UNUSED(Q_v); T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; 
k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -208,7 +212,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( v |= (vh << 18) & 0x00100000; // 2 -> 20 v |= (vh << 25) & 0x10000000; // 3 -> 28 - const int u = Q_q8[k_KQ_0/WARP_SIZE]; + const int u = Q_q8[k_KQ_0/warp_size]; const int sumi = ggml_cuda_dp4a(v, u, 0); @@ -216,7 +220,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; - const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size]; const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1); sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled)); } else @@ -224,8 +228,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( { const float2 * Q_ds = (const float2 *) Q_ds_v; - const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi; - const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1; + const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi; + const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1; sum += (T) (sumid5d8 + m5s8scaled); } @@ -239,12 +243,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); GGML_UNUSED(Q_v); T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_0; @@ -255,13 +260,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( T Q_d; if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; - Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]); + Q_d = __low2half(Q_ds[k_KQ_0/warp_size]); } else { const float2 * Q_ds = (const float2 *) Q_ds_v; - Q_d = Q_ds[k_KQ_0/WARP_SIZE].x; + Q_d = Q_ds[k_KQ_0/warp_size].x; } - sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d); + sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d); } return sum; @@ -272,6 +277,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { const half2 * K_h2 = (const half2 *) K_c; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); GGML_UNUSED(Q_q8); GGML_UNUSED(Q_ds_v); @@ -282,11 +288,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( half2 sum2 = make_half2(0.0f, 0.0f); #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const half2 K_ik = K_h2[k_KQ]; - sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE]; + sum2 += K_ik * Q_h2[k_KQ_0/warp_size]; } return __low2half(sum2) + __high2half(sum2); @@ -298,12 +304,12 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( float sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += 
warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const half2 K_ik = K_h2[k_KQ]; - sum += __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x; - sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y; + sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x; + sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y; } return sum; @@ -698,6 +704,8 @@ void launch_fattn( GGML_ASSERT(Q->ne[3] == 1); + const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size; + ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); const int id = ggml_cuda_get_device(); @@ -750,7 +758,7 @@ void launch_fattn( const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; - const dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 block_dim(warp_size, nwarps, 1); dim3 blocks_num; if (parallel_blocks == 0) { // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. @@ -796,6 +804,8 @@ void launch_fattn( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + GGML_ASSERT(block_dim.x % warp_size == 0); + GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size); fattn_kernel<<>>( (const char *) Q->data, K_data, diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 8828652fb..622cf2857 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -7,14 +7,19 @@ #include "fattn-wmma-f16.cuh" #ifdef FP16_MMA_AVAILABLE +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) #include +namespace wmma = nvcuda::wmma; +#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE) +#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers +#include +namespace wmma = rocwmma; +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) #endif // FP16_MMA_AVAILABLE // D == head size, VKQ_stride == num VKQ rows calculated in parallel: template -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(nwarps*WARP_SIZE, 1) -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -51,7 +56,7 @@ static __global__ void flash_attn_ext_f16( const int ne1, const int ne2, const int ne3) { -#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -60,6 +65,8 @@ static __global__ void flash_attn_ext_f16( //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. @@ -68,11 +75,11 @@ static __global__ void flash_attn_ext_f16( constexpr int frag_m = ncols == 8 ? 32 : 16; constexpr int frag_n = ncols == 8 ? 
8 : 16; static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); - typedef nvcuda::wmma::fragment frag_a_K; - typedef nvcuda::wmma::fragment frag_a_V; - typedef nvcuda::wmma::fragment frag_b; - typedef nvcuda::wmma::fragment frag_c_KQ; - typedef nvcuda::wmma::fragment frag_c_VKQ; + typedef wmma::fragment frag_a_K; + typedef wmma::fragment frag_a_V; + typedef wmma::fragment frag_b; + typedef wmma::fragment frag_c_KQ; + typedef wmma::fragment frag_c_VKQ; constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. @@ -132,9 +139,9 @@ static __global__ void flash_attn_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < D/2; i0 += warp_size) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { + if (i0 + warp_size > D/2 && i >= D/2) { break; } VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); @@ -146,9 +153,9 @@ static __global__ void flash_attn_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < D; i0 += warp_size) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { + if (i0 + warp_size > D && i >= D) { break; } KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; @@ -162,7 +169,7 @@ static __global__ void flash_attn_ext_f16( for (int i0 = 0; i0 < D; i0 += 16) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); } } @@ -176,20 +183,20 @@ static __global__ void flash_attn_ext_f16( frag_c_KQ KQ_c[ncols/frag_n]; #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); + wmma::fill_fragment(KQ_c[j], static_cast(0.0f)); } #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { frag_a_K K_a; - nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); + wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); + wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); } } #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); + wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major); } } @@ -202,27 +209,27 @@ static __global__ void flash_attn_ext_f16( const int j = j0 + threadIdx.y; if (std::is_same::value) { - float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; + float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; + KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k]; if (use_logit_softcap) { - KQ_f_tmp[k0/WARP_SIZE] = 
logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]); + KQ_f_tmp[k0/warp_size] = logit_softcap*tanhf(KQ_f_tmp[k0/warp_size]); } } float KQ_max_new = KQ_max_f[j0/nwarps]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; - KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); + KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]); } - KQ_max_new = warp_reduce_max(KQ_max_new); + KQ_max_new = warp_reduce_max(KQ_max_new); const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; KQ_max_scale_f[j0/nwarps] = expf(diff); @@ -233,48 +240,48 @@ static __global__ void flash_attn_ext_f16( float KQ_rowsum_add = 0.0f; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) { const int k = k0 + threadIdx.x; - const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; - KQ_f_tmp[k0/WARP_SIZE] = expf(diff); + const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps]; + KQ_f_tmp[k0/warp_size] = expf(diff); if (diff <= SOFTMAX_FTZ_THRESHOLD) { - KQ_f_tmp[k0/WARP_SIZE] = 0.0f; + KQ_f_tmp[k0/warp_size] = 0.0f; } - KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; - KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; + KQ_rowsum_add += KQ_f_tmp[k0/warp_size]; + KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/warp_size]; } - KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); // Scale previous KQ_rowsum to account for a potential increase in KQ_max: KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; } else { - half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k]; if (use_logit_softcap) { // There is no dedicated tangens hyperbolicus function for half2. - KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f)); - KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f)) - /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f)); + KQ2_tmp[k0/warp_size] = h2exp(KQ2_tmp[k0/warp_size]*make_half2(2.0f, 2.0f)); + KQ2_tmp[k0/warp_size] = (KQ2_tmp[k0/warp_size] - make_half2(1.0f, 1.0f)) + /(KQ2_tmp[k0/warp_size] + make_half2(1.0f, 1.0f)); - KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2; + KQ2_tmp[k0/warp_size] *= logit_softcap_2; } } half2 KQ_max_new = KQ_max_h2[j0/nwarps]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); + KQ2_tmp[k0/warp_size] += mask ? 
slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]); } - KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; KQ_max_scale_h2[j0/nwarps] = h2exp(diff); const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); @@ -283,17 +290,17 @@ static __global__ void flash_attn_ext_f16( half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) { const int k = k0 + threadIdx.x; - const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; - KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); + const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps]; + KQ2_tmp[k0/warp_size] = h2exp(diff); const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; - KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; - KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; + *((uint32_t *) &KQ2_tmp[k0/warp_size]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/warp_size]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/warp_size]; } - KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); // Scale previous KQ_rowsum to account for a potential increase in KQ_max: KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; @@ -308,7 +315,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { const int k = k0 + (threadIdx.y % VKQ_ratio)*16; - nvcuda::wmma::load_matrix_sync( + wmma::load_matrix_sync( KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], KQ + j0*(kqar*kqs_padded) + k, kqar*kqs_padded); @@ -320,7 +327,7 @@ static __global__ void flash_attn_ext_f16( for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); + wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast(0.0f)); } #pragma unroll @@ -328,10 +335,10 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; frag_a_V v_a; - nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); + wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); + wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); } } } @@ -343,10 +350,10 @@ static __global__ void flash_attn_ext_f16( for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::store_matrix_sync( + wmma::store_matrix_sync( KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], - D_padded, nvcuda::wmma::mem_col_major); + D_padded, wmma::mem_col_major); } } @@ -364,9 +371,9 @@ static __global__ void flash_attn_ext_f16( } #pragma 
unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < D/2; i0 += warp_size) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { + if (i0 + warp_size > D/2 && i >= D/2) { break; } @@ -398,9 +405,9 @@ static __global__ void flash_attn_ext_f16( } #pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < D; i0 += warp_size) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { + if (i0 + warp_size > D && i >= D) { break; } float dst_val = VKQ[j_VKQ*D_padded + i]; @@ -425,7 +432,7 @@ static __global__ void flash_attn_ext_f16( } #else NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) } constexpr int get_max_power_of_2(int x) { @@ -515,6 +522,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten const ggml_tensor * Q = dst->src[0]; const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); + const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size; if (prec != GGML_PREC_DEFAULT) { if (Q->ne[1] <= 32 || Q->ne[0] > 128) { @@ -571,7 +579,8 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten return; } - if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + if (Q->ne[1] <= 8 && Q->ne[0] % warp_size == 0) { constexpr int cols_per_block = 8; switch (Q->ne[0]) { case 64: @@ -592,6 +601,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten } return; } +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) if (Q->ne[1] <= 32) { constexpr int cols_per_block = 16; diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index b1becccb4..24f973056 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -250,10 +250,18 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); - // On AMD the tile kernels perform poorly, use the vec kernel instead: if (cc >= GGML_CUDA_CC_OFFSET_AMD) { +#if defined(GGML_HIP_ROCWMMA_FATTN) + if (fp16_mma_available(cc)) { + ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); + return; + } +#endif // defined(GGML_HIP_ROCWMMA_FATTN) + + // On AMD the tile kernels perform poorly, use the vec kernel instead: if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); } else { @@ -291,7 +299,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const int gqa_ratio = Q->ne[2] / K->ne[2]; const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask; - if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) { + if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0 && !mma_fast_for_bs1) { if (prec == GGML_PREC_DEFAULT) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); return; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 98afd04cf..bcb5f30ab 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu 
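// Sketch (illustrative only; all names are taken from the fattn.cu hunk above): the reworked
// flash-attention dispatch queries the warp size from the device at runtime instead of relying on
// the compile-time WARP_SIZE constant (AMD wavefronts are 64 lanes wide, NVIDIA warps are 32), and
// rocWMMA builds now route AMD GPUs with FP16 MMA support to the WMMA kernel that was previously
// restricted to Volta. The resulting order of checks is roughly:

const int cc        = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;

if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
#if defined(GGML_HIP_ROCWMMA_FATTN)
    if (fp16_mma_available(cc)) {
        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); // WMMA kernel via rocWMMA
        return;
    }
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
    // otherwise keep preferring the vec kernels over the tile kernels on AMD
}

// ... and further down, the batch-size-1 fast path is gated on the runtime warp size:
if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0 && !mma_fast_for_bs1) {
    // vec kernel for single-token decode
}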
@@ -2152,6 +2152,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg break; case GGML_OP_UNARY: switch (ggml_get_unary_op(dst)) { + case GGML_UNARY_OP_ABS: + ggml_cuda_op_abs(ctx, dst); + break; + case GGML_UNARY_OP_SGN: + ggml_cuda_op_sgn(ctx, dst); + break; case GGML_UNARY_OP_NEG: ggml_cuda_op_neg(ctx, dst); break; @@ -2249,6 +2255,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CLAMP: ggml_cuda_op_clamp(ctx, dst); break; + case GGML_OP_LOG: + ggml_cuda_op_log(ctx, dst); + break; case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -2967,6 +2976,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SGN: case GGML_UNARY_OP_NEG: case GGML_UNARY_OP_STEP: case GGML_UNARY_OP_GELU: @@ -3149,7 +3160,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return false; } break; case GGML_OP_SILU_BACK: - return ggml_is_contiguous(op->src[0]); + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; break; case GGML_OP_NORM: case GGML_OP_RMS_NORM: @@ -3173,6 +3184,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_CLAMP: + case GGML_OP_LOG: return true; case GGML_OP_CONT: return op->src[0]->type != GGML_TYPE_BF16; diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 6b21f407d..ec5773e01 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -1,305 +1,213 @@ #include "unary.cuh" -static __global__ void neg_f32(const float * x, float * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - - dst[i] = -x[i]; +static __device__ __forceinline__ float op_abs(float x) { + return fabsf(x); } -static __global__ void step_f32(const float * x, float * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - - dst[i] = x[i] > 0.0f; +static __device__ __forceinline__ float op_sgn(float x) { + return (x > 0.f ? 1.f : ((x < 0.f ? 
-1.f : 0.f))); } -static __global__ void gelu_f32(const float * x, float * dst, const int k) { +static __device__ __forceinline__ float op_neg(float x) { + return -x; +} + +static __device__ __forceinline__ float op_step(float x) { + return x > 0.0f; +} + +static __device__ __forceinline__ float op_gelu(float x) { const float GELU_COEF_A = 0.044715f; const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - const int i = blockDim.x*blockIdx.x + threadIdx.x; - if (i >= k) { - return; - } - - float xi = x[i]; - dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi))); + return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } -static __global__ void gelu_quick_f32(const float * x, float * dst, int k) { +static __device__ __forceinline__ float op_gelu_quick(float x) { const float GELU_QUICK_COEF = -1.702f; - const int i = blockDim.x*blockIdx.x + threadIdx.x; - if (i >= k) { - return; - } - dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i]))); + + return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x))); } -static __global__ void silu_f32(const float * x, float * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - dst[i] = x[i] / (1.0f + expf(-x[i])); +static __device__ __forceinline__ float op_silu(float x) { + return x / (1.0f + expf(-x)); } -static __global__ void silu_back_f32( - const float * grad, const float * xf, float * dst, const int k) { +static __device__ __forceinline__ float op_tanh(float x) { + return tanhf(x); +} + +static __device__ __forceinline__ float op_relu(float x) { + return fmaxf(x, 0); +} + +static __device__ __forceinline__ float op_sigmoid(float x) { + return 1.0f / (1.0f + expf(-x)); +} + +static __device__ __forceinline__ float op_hardsigmoid(float x) { + return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f)); +} + +static __device__ __forceinline__ float op_hardswish(float x) { + return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f)); +} + +static __device__ __forceinline__ float op_exp(float x) { + return expf(x); +} + +static __device__ __forceinline__ float op_sqr(float x) { + return x * x; +} + +static __device__ __forceinline__ float op_sqrt(float x) { + return sqrtf(x); +} + +static __device__ __forceinline__ float op_sin(float x) { + return sinf(x); +} + +static __device__ __forceinline__ float op_cos(float x) { + return cosf(x); +} + +static __device__ __forceinline__ float op_log(float x) { + return logf(x); +} + +template +static __global__ void unary_op_kernel(const T * x, T * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; } - const float xfi = xf[i]; - const float s = 1.0f / (1.0f + expf(-xfi)); - dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s)); + dst[i] = (T)op((float)x[i]); } -static __global__ void tanh_f32(const float * x, float * dst, int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - if (i >= k) { - return; - } - dst[i] = tanhf(x[i]); -} - -static __global__ void relu_f32(const float * x, float * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - dst[i] = fmaxf(x[i], 0); -} - -static __global__ void sigmoid_f32(const float * x, float * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - dst[i] = 1.0f / (1.0f + expf(-x[i])); -} - -static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if 
(i >= k) { - return; - } - dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); -} - -static __global__ void hardswish_f32(const float * x, float * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); -} - -static __global__ void exp_f32(const float * x, float * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - dst[i] = expf(x[i]); -} - -static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - if (i >= k) { - return; - } - dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope; -} - -static __global__ void sqr_f32(const float * x, float * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - dst[i] = x[i] * x[i]; -} - -static __global__ void sqrt_f32(const float * x, float * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - dst[i] = sqrtf(x[i]); -} - -static __global__ void sin_f32(const float * x, float * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - dst[i] = sinf(x[i]); -} - -static __global__ void cos_f32(const float * x, float * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - dst[i] = cosf(x[i]); -} - -static void neg_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { +template +static void unary_cuda(const T * x, T * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE; - neg_f32<<>>(x, dst, k); + unary_op_kernel<<>>(x, dst, k); } -static void step_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_STEP_BLOCK_SIZE - 1) / CUDA_STEP_BLOCK_SIZE; - step_f32<<>>(x, dst, k); +template +void ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const void * src0_d = src0->data; + void * dst_d = dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); + + if (src0->type == GGML_TYPE_F16) { + unary_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream); + } else { + unary_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream); + } } -static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; - gelu_f32<<>>(x, dst, k); +void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); } -static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; - gelu_quick_f32<<>>(x, dst, k); -} - -static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; - silu_f32<<>>(x, dst, k); -} - -static void silu_back_f32_cuda(const float * grad, const float * 
x, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; - silu_back_f32<<>>(grad, x, dst, k); -} - -static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE; - tanh_f32<<>>(x, dst, k); -} - -static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; - relu_f32<<>>(x, dst, k); -} - -static void sigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE; - sigmoid_f32<<>>(x, dst, k); -} - -static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE; - hardsigmoid_f32<<>>(x, dst, k); -} - -static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE; - hardswish_f32<<>>(x, dst, k); -} - -static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE; - exp_f32<<>>(x, dst, k); -} - -static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) { - const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; - leaky_relu_f32<<>>(x, dst, k, negative_slope); -} - -static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE; - sqr_f32<<>>(x, dst, k); -} - -static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_SQRT_BLOCK_SIZE - 1) / CUDA_SQRT_BLOCK_SIZE; - sqrt_f32<<>>(x, dst, k); -} - -static void sin_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_SIN_BLOCK_SIZE - 1) / CUDA_SIN_BLOCK_SIZE; - sin_f32<<>>(x, dst, k); -} - -static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE; - cos_f32<<>>(x, dst, k); +void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); } void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); - - GGML_ASSERT(ggml_is_contiguous(src0)); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - neg_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); + ggml_cuda_op_unary(ctx, dst); } void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); - - GGML_ASSERT(ggml_is_contiguous(src0)); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - step_f32_cuda(src0_d, dst_d, ggml_nelements(src0), 
stream); + ggml_cuda_op_unary(ctx, dst); } void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); + ggml_cuda_op_unary(ctx, dst); +} - GGML_ASSERT(ggml_is_contiguous(src0)); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - gelu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); +void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); } void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); + ggml_cuda_op_unary(ctx, dst); +} - GGML_ASSERT(ggml_is_contiguous(src0)); +void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); +void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} - silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); +void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +/* silu_back */ + +static __device__ __forceinline__ float op_silu_back(float grad, float x) { + const float s = 1.0f / (1.0f + expf(-x)); + return grad * s * (1.0f + x * (1.0f - s)); +} + +template +static __global__ void silu_back_kernel(const T * grad, const T * xf, T * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = (T)op_silu_back((float)grad[i], (float)xf[i]); +} + +template +static void silu_back_cuda(const T * grad, const T * x, T * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; + silu_back_kernel<<>>(grad, x, dst, k); } void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -314,179 +222,58 @@ void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type 
== dst->type); - silu_back_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(src0), stream); + if (src0->type == GGML_TYPE_F16) { + silu_back_cuda((const half *)src0_d, (const half *)src1_d, (half *)dst_d, ggml_nelements(src0), stream); + } else { + silu_back_cuda((const float*)src0_d, (const float*)src1_d, (float *)dst_d, ggml_nelements(src0), stream); + } } -void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); +/* leaky relu */ - GGML_ASSERT(ggml_is_contiguous(src0)); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - gelu_quick_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); +static __device__ __forceinline__ float op_leaky_relu(float x, const float negative_slope) { + return fmaxf(x, 0) + fminf(x, 0.0f) * negative_slope; } -void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); +template +static __global__ void leaky_relu_kernel(const T * x, T * dst, const int k, const float negative_slope) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; - GGML_ASSERT(ggml_is_contiguous(src0)); + if (i >= k) { + return; + } - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - tanh_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); + dst[i] = (T)op_leaky_relu((float)x[i], negative_slope); } -void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); - - GGML_ASSERT(ggml_is_contiguous(src0)); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); -} - -void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); - - GGML_ASSERT(ggml_is_contiguous(src0)); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - sigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); -} - -void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); - - GGML_ASSERT(ggml_is_contiguous(src0)); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - hardsigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); -} - -void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); - - GGML_ASSERT(ggml_is_contiguous(src0)); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); -} - -void 
ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); - - GGML_ASSERT(ggml_is_contiguous(src0)); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); +template +static void leaky_relu_cuda(const T * x, T * dst, const int k, const float negative_slope, cudaStream_t stream) { + const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; + leaky_relu_kernel<<>>(x, dst, k, negative_slope); } void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; + const void * src0_d = src0->data; + void * dst_d = dst->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); float negative_slope; memcpy(&negative_slope, dst->op_params, sizeof(float)); - leaky_relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), negative_slope, stream); -} - -void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); - - GGML_ASSERT(ggml_is_contiguous(src0)); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); -} - -void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); - - GGML_ASSERT(ggml_is_contiguous(src0)); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); -} - -void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); - - GGML_ASSERT(ggml_is_contiguous(src0)); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - sin_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); -} - -void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); - - GGML_ASSERT(ggml_is_contiguous(src0)); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - cos_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); + if (src0->type == GGML_TYPE_F16) { + leaky_relu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), negative_slope, stream); + } else { + leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream); + } } diff --git 
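// Sketch of the pattern behind the rewritten unary.cu above (the angle-bracketed template
// parameters are stripped in the diff text, so the exact signature shown here is inferred from the
// kernel body rather than quoted): a single kernel templated on a device-side float -> float
// functor, computing in fp32 and storing in the tensor's own type, so each activation becomes a
// one-line wrapper and fp16 tensors are covered by instantiating the same kernel with T = half.

template <float (*op)(float), typename T>
static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = (T)op((float)x[i]); // op is e.g. op_abs, op_sgn, op_log, ...
}

// For the SILU backward pass the same scheme applies to a two-argument functor: with
// s = 1/(1 + exp(-x)), d/dx [x*s] = s + x*s*(1 - s) = s*(1 + x*(1 - s)), which is the factor
// op_silu_back(grad, x) multiplies the incoming gradient by.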
a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index e7f62643a..940a1feed 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -16,6 +16,10 @@ #define CUDA_SIN_BLOCK_SIZE 256 #define CUDA_COS_BLOCK_SIZE 256 +void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst); @@ -49,3 +53,5 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index cf52fa336..d8b9b0fb3 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1200,7 +1200,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_ELU: - return ggml_is_contiguous(op->src[0]); + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; } @@ -1210,21 +1210,26 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_TRANSPOSE: case GGML_OP_PERMUTE: case GGML_OP_CONCAT: + return true; case GGML_OP_ADD: case GGML_OP_SUB: - case GGML_OP_ACC: case GGML_OP_MUL: case GGML_OP_DIV: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_ACC: case GGML_OP_REPEAT: case GGML_OP_SCALE: - case GGML_OP_CLAMP: case GGML_OP_CONV_TRANSPOSE_1D: return true; + case GGML_OP_CLAMP: + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: - return ggml_is_contiguous(op->src[0]); + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_LOG: + return false; // TODO: implement case GGML_OP_SUM_ROWS: case GGML_OP_SOFT_MAX: case GGML_OP_GROUP_NORM: @@ -1254,10 +1259,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_UPSCALE: case GGML_OP_PAD: case GGML_OP_PAD_REFLECT_1D: - case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: case GGML_OP_LEAKY_RELU: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_ARANGE: return true; case GGML_OP_FLASH_ATTN_EXT: if (op->src[1]->type != op->src[2]->type) { diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index b1df4e5db..577ff51fd 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -29,6 +29,7 @@ #include "wkv6.hpp" #include "outprod.hpp" #include "element_wise.hpp" +#include "cpy.hpp" #include "gla.hpp" #endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 7c503a1b1..a92988b7d 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -34,6 +34,7 @@ #pragma clang diagnostic ignored "-Wnested-anon-types" #include "ggml-common.h" #pragma clang diagnostic pop +#include "ggml-impl.h" void* ggml_sycl_host_malloc(size_t size); void ggml_sycl_host_free(void* ptr); diff --git a/ggml/src/ggml-sycl/cpy.cpp b/ggml/src/ggml-sycl/cpy.cpp new file mode 100644 index 000000000..5a2314589 --- /dev/null +++ 
b/ggml/src/ggml-sycl/cpy.cpp @@ -0,0 +1,701 @@ +#include "cpy.hpp" + +#include + +#include "dequantize.hpp" + +static __dpct_inline__ int best_index_int8(int n, const int8_t * val, float x) { + if (x <= val[0]) { + return 0; + } + if (x >= val[n - 1]) { + return n - 1; + } + int ml = 0, mu = n - 1; + while (mu - ml > 1) { + int mav = (ml + mu) / 2; + if (x < val[mav]) { + mu = mav; + } else { + ml = mav; + } + } + return x - val[mu - 1] < val[mu] - x ? mu - 1 : mu; +} + +static void cpy_1_f32_f32(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + float * dsti = (float *) cdsti; + + *dsti = *xi; +} + +static void cpy_1_f32_f16(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + sycl::half * dsti = (sycl::half *) cdsti; + + *dsti = sycl::vec(*xi).convert()[0]; +} + +static void cpy_1_f16_f16(const char * cxi, char * cdsti) { + const sycl::half * xi = (const sycl::half *) cxi; + sycl::half * dsti = (sycl::half *) cdsti; + + *dsti = *xi; +} + +static void cpy_1_f16_f32(const char * cxi, char * cdsti) { + const sycl::half * xi = (const sycl::half *) cxi; + float * dsti = (float *) cdsti; + + *dsti = *xi; +} + +static void cpy_1_i16_i16(const char * cxi, char * cdsti) { + const int16_t * xi = (const int16_t *) cxi; + int16_t * dsti = (int16_t *) cdsti; + + *dsti = *xi; +} + +static void cpy_1_i32_i32(const char * cxi, char * cdsti) { + const int32_t * xi = (const int32_t *) cxi; + int32_t * dsti = (int32_t *) cdsti; + + *dsti = *xi; +} + +template +static void cpy_f32_f16(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, + const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, + const sycl::nd_item<3> & item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); + + if (i >= ne) { + return; + } + + // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor + // then combine those indices with the corresponding byte offsets to get the total offsets + const int i03 = i / (ne00 * ne01 * ne02); + const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); + const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00; + const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00; + const int x_offset = i00 * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03; + + const int i13 = i / (ne10 * ne11 * ne12); + const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11); + const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10; + const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10; + const int dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13; + + cpy_1(cx + x_offset, cdst + dst_offset); +} + +static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q8_0 * dsti = (block_q8_0 *) cdsti; + + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = xi[j]; + amax = sycl::fmax(amax, sycl::fabs((float) v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 
1.0f / d : 0.0f; + + dsti->d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = xi[j] * id; + + dsti->qs[j] = sycl::round((float) x0); + } +} + +static void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { + float * cdstf = (float *) (cdsti); + + for (int j = 0; j < QK8_0; j += 2) { + dfloat2 dq; + dequantize_q8_0(cxi, 0, j, dq); + *(cdstf + j) = dq.x(); + *(cdstf + j + 1) = dq.y(); + } +} + +static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q4_0 * dsti = (block_q4_0 *) cdsti; + + float amax = 0.0f; + float vmax = 0.0f; + + for (int j = 0; j < QK4_0; ++j) { + const float v = xi[j]; + if (amax < sycl::fabs((float) v)) { + amax = sycl::fabs((float) v); + vmax = v; + } + } + + const float d = vmax / -8; + const float id = d ? 1.0f / d : 0.0f; + + dsti->d = d; + + for (int j = 0; j < QK4_0 / 2; ++j) { + const float x0 = xi[0 + j] * id; + const float x1 = xi[QK4_0 / 2 + j] * id; + + const uint8_t xi0 = dpct::min(15, (int8_t) (x0 + 8.5f)); + const uint8_t xi1 = dpct::min(15, (int8_t) (x1 + 8.5f)); + + dsti->qs[j] = xi0; + dsti->qs[j] |= xi1 << 4; + } +} + +static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q4_1 * dsti = (block_q4_1 *) cdsti; + + float vmin = FLT_MAX; + float vmax = -FLT_MAX; + + for (int j = 0; j < QK4_1; ++j) { + const float v = xi[j]; + + if (v < vmin) { + vmin = v; + } + if (v > vmax) { + vmax = v; + } + } + + const float d = (vmax - vmin) / ((1 << 4) - 1); + const float id = d ? 1.0f / d : 0.0f; + + dsti->dm.x() = d; + dsti->dm.y() = vmin; + + for (int j = 0; j < QK4_1 / 2; ++j) { + const float x0 = (xi[0 + j] - vmin) * id; + const float x1 = (xi[QK4_1 / 2 + j] - vmin) * id; + + const uint8_t xi0 = dpct::min(15, (int8_t) (x0 + 0.5f)); + const uint8_t xi1 = dpct::min(15, (int8_t) (x1 + 0.5f)); + + dsti->qs[j] = xi0; + dsti->qs[j] |= xi1 << 4; + } +} + +static void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q5_0 * dsti = (block_q5_0 *) cdsti; + + float amax = 0.0f; + float vmax = 0.0f; + + for (int j = 0; j < QK5_0; ++j) { + const float v = xi[j]; + if (amax < sycl::fabs((float) v)) { + amax = sycl::fabs((float) v); + vmax = v; + } + } + + const float d = vmax / -16; + const float id = d ? 1.0f / d : 0.0f; + + dsti->d = d; + + uint32_t qh = 0; + for (int j = 0; j < QK5_0 / 2; ++j) { + const float x0 = xi[0 + j] * id; + const float x1 = xi[QK5_0 / 2 + j] * id; + + const uint8_t xi0 = dpct::min(31, (int8_t) (x0 + 16.5f)); + const uint8_t xi1 = dpct::min(31, (int8_t) (x1 + 16.5f)); + + dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0 / 2); + } + memcpy(dsti->qh, &qh, sizeof(qh)); +} + +static void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q5_1 * dsti = (block_q5_1 *) cdsti; + + float min = xi[0]; + float max = xi[0]; + + for (int j = 1; j < QK5_1; ++j) { + const float v = xi[j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = d ? 
1.0f / d : 0.0f; + + dsti->dm.x() = d; + dsti->dm.y() = min; + + uint32_t qh = 0; + for (int j = 0; j < QK5_1 / 2; ++j) { + const float x0 = (xi[0 + j] - min) * id; + const float x1 = (xi[QK5_1 / 2 + j] - min) * id; + + const uint8_t xi0 = (uint8_t) (x0 + 0.5f); + const uint8_t xi1 = (uint8_t) (x1 + 0.5f); + + dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1 / 2); + } + memcpy(dsti->qh, &qh, sizeof(qh)); +} + +static void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_iq4_nl * dsti = (block_iq4_nl *) cdsti; + + float amax = 0.0f; + float vmax = 0.0f; + + for (int j = 0; j < QK4_NL; ++j) { + const float v = xi[j]; + if (amax < sycl::fabs((float) v)) { + amax = sycl::fabs((float) v); + vmax = v; + } + } + + float d = vmax / kvalues_iq4nl[0]; + const float id = d ? 1.0f / d : 0.0f; + + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < QK4_NL / 2; ++j) { + const float x0 = xi[0 + j] * id; + const float x1 = xi[QK4_NL / 2 + j] * id; + const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0); + const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1); + dsti->qs[j] = xi0 | (xi1 << 4); + const float v0 = kvalues_iq4nl[xi0]; + const float v1 = kvalues_iq4nl[xi1]; + const float w0 = xi[0 + j] * xi[0 + j]; + const float w1 = xi[QK4_NL / 2 + j] * xi[QK4_NL / 2 + j]; + sumqx += w0 * v0 * xi[j] + w1 * v1 * xi[QK4_NL / 2 + j]; + sumq2 += w0 * v0 * v0 + w1 * v1 * v1; + } + + dsti->d = sumq2 > 0 ? sumqx / sumq2 : d; +} + +template static void cpy_blck_q_f32(const char * cxi, char * cdsti) { + float * cdstf = (float *) (cdsti); + + for (int j = 0; j < qk / 2; j++) { + dfloat2 dq; + dequant(cxi, 0, j, dq); + *(cdstf + j) = dq.x(); + *(cdstf + j + qk / 2) = dq.y(); + } +} + +template +static void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, + const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, + const sycl::nd_item<3> & item_ct1) { + const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk; + + if (i >= ne) { + return; + } + + const int i03 = i / (ne00 * ne01 * ne02); + const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); + const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00; + const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00; + const int x_offset = i00 * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03; + + const int i13 = i / (ne10 * ne11 * ne12); + const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11); + const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10; + const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10; + const int dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13; + + cpy_blck(cx + x_offset, cdst + dst_offset); +} + +template +static void cpy_q_f32(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, + const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, + const sycl::nd_item<3> & item_ct1) { + const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk; + + if (i >= ne) { + return; + } + + const int 
i03 = i / (ne00 * ne01 * ne02); + const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); + const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00; + const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00; + const int x_offset = (i00 / qk) * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03; + + const int i13 = i / (ne10 * ne11 * ne12); + const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11); + const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10; + const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10; + const int dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13; + + cpy_blck(cx + x_offset, cdst + dst_offset); +} + +static void ggml_cpy_f16_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; + { + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13, item_ct1); + }); + } +} + +static void ggml_cpy_f32_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; + { + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13, item_ct1); + }); + } +} + +static void ggml_cpy_f32_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; + { + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13, item_ct1); + }); + } +} + +static void ggml_cpy_f32_q8_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, 
queue_ptr stream) { + GGML_ASSERT(ne % QK8_0 == 0); + const int num_blocks = ne / QK8_0; + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + +static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = ne; + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + cpy_q_f32(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + +static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + GGML_ASSERT(ne % QK4_0 == 0); + const int num_blocks = ne / QK4_0; + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + +static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = ne; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_f32, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, + nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, + item_ct1); + }); +} + +static void ggml_cpy_f32_q4_1_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + GGML_ASSERT(ne % QK4_1 == 0); + const int num_blocks = ne / QK4_1; + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + +static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = ne; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_f32, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, + nb03, ne10, ne11, 
ne12, nb10, nb11, nb12, nb13, + item_ct1); + }); +} + +static void ggml_cpy_f32_q5_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + GGML_ASSERT(ne % QK5_0 == 0); + const int num_blocks = ne / QK5_0; + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + +static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = ne; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_f32, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, + nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, + item_ct1); + }); +} + +static void ggml_cpy_f32_q5_1_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + GGML_ASSERT(ne % QK5_1 == 0); + const int num_blocks = ne / QK5_1; + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + +static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = ne; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_f32, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, + nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, + item_ct1); + }); +} + +static void ggml_cpy_f32_iq4_nl_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + GGML_ASSERT(ne % QK4_NL == 0); + const int num_blocks = ne / QK4_NL; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { + cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, + ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + +static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int 
ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; + { + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13, item_ct1); + }); + } +} + +static void ggml_cpy_i16_i16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; + { + // dpct::has_capability_or_fail(stream->get_device(), + // {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13, item_ct1); + }); + } +} + +static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; + { + // dpct::has_capability_or_fail(stream->get_device(), + // {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13, item_ct1); + }); + } +} + +void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1) try { + const int64_t ne = ggml_nelements(src0); + GGML_ASSERT(ne == ggml_nelements(src1)); + + GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); + GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); + + GGML_TENSOR_BINARY_OP_LOCALS01; + + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + queue_ptr main_stream = ctx.stream(); + + char * src0_ddc = (char *) src0->data; + char * src1_ddc = (char *) src1->data; + GGML_SYCL_DEBUG("[SYCL] %s: Tensor supplied: %s to %s\n", __func__, ggml_type_name(src0->type), + ggml_type_name(src1->type)); + + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { + ggml_cpy_f32_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { + ggml_cpy_f32_f16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { + ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, 
ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { + ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { + ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { + ggml_cpy_f16_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { + ggml_cpy_f16_f16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) { + ggml_cpy_i16_i16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) { + ggml_cpy_i32_i32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q4_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q4_1_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q8_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { + ggml_cpy_f32_q5_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q5_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { + ggml_cpy_f32_q5_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q5_1_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, + nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { + ggml_cpy_f32_iq4_nl_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13, main_stream); + } else { + GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), + ggml_type_name(src1->type)); + GGML_ABORT("fatal error"); + } +} catch (const sycl::exception & exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +void 
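// A minimal host-side sketch of how a graph reaches the F32 -> Q8_0 branch of
// the dispatcher above: ggml_cpy() creates a GGML_OP_CPY node whose source is
// F32 and whose destination view is Q8_0. Assumes linking against ggml; running
// the graph on the SYCL backend (scheduler/backend setup) is omitted here.
#include "ggml.h"
#include <cstdio>

int main() {
    ggml_init_params ip = {
        /*.mem_size   =*/ 16u*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    ggml_context * ctx = ggml_init(ip);

    ggml_tensor * src = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,  4096); // 4096 % QK8_0 == 0
    ggml_tensor * dst = ggml_new_tensor_1d(ctx, GGML_TYPE_Q8_0, 4096);
    ggml_tensor * cpy = ggml_cpy(ctx, src, dst);  // GGML_OP_CPY node, f32 -> q8_0

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, cpy);

    std::printf("cpy node: %s -> %s\n", ggml_type_name(src->type), ggml_type_name(cpy->type));
    ggml_free(ctx);
    return 0;
}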
ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + // TODO: why do we pass dst as src1 here? + GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__); + ggml_sycl_cpy(ctx, dst->src[0], dst); + GGML_SYCL_DEBUG("[SYCL] call %s done\n", __func__); +} diff --git a/ggml/src/ggml-sycl/cpy.hpp b/ggml/src/ggml-sycl/cpy.hpp new file mode 100644 index 000000000..0a0f561d2 --- /dev/null +++ b/ggml/src/ggml-sycl/cpy.hpp @@ -0,0 +1,11 @@ +#ifndef GGML_SYCL_CPY_HPP +#define GGML_SYCL_CPY_HPP + +#include "common.hpp" + +typedef void (*cpy_kernel_t)(const char * cx, char * cdst); + +void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1); +void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_CPY_HPP diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index d804e6606..e7304947f 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1285,8 +1285,6 @@ std::unique_ptr ggml_backend_sycl_context::new_pool_for_device(q // struct ggml_sycl_pool_vmm : public ggml_sycl_pool /// kernels - -typedef void (*cpy_kernel_t)(const char * cx, char * cdst); typedef void (*ggml_sycl_op_mul_mat_t)( ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, @@ -1468,193 +1466,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous } } -static void cpy_1_f32_f32(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - float * dsti = (float *) cdsti; - - *dsti = *xi; -} - -static void cpy_1_f32_f16(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - sycl::half *dsti = (sycl::half *)cdsti; - - *dsti = sycl::vec(*xi) - .convert()[0]; -} - -static void cpy_1_f16_f16(const char * cxi, char * cdsti) { - const sycl::half *xi = (const sycl::half *)cxi; - sycl::half *dsti = (sycl::half *)cdsti; - - *dsti = *xi; -} - -static void cpy_1_f16_f32(const char * cxi, char * cdsti) { - const sycl::half *xi = (const sycl::half *)cxi; - float * dsti = (float *) cdsti; - - *dsti = *xi; -} - -static void cpy_1_i16_i16(const char * cxi, char * cdsti) { - const int16_t *xi = (const int16_t *)cxi; - int16_t *dsti = (int16_t *)cdsti; - - *dsti = *xi; -} - -static void cpy_1_i32_i32(const char * cxi, char * cdsti) { - const int32_t *xi = (const int32_t *)cxi; - int32_t *dsti = (int32_t *)cdsti; - - *dsti = *xi; -} - -template -static void cpy_f32_f16(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= ne) { - return; - } - - // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor - // then combine those indices with the corresponding byte offsets to get the total offsets - const int i03 = i/(ne00 * ne01 * ne02); - const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); - const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; - const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; - const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; - - const int i13 = i/(ne10 * ne11 * ne12); - const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); - const int i11 = (i - 
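// Reference (plain C++) of the index arithmetic in the element-wise copy kernel
// shown here: a flat element index is decomposed into (i3,i2,i1,i0) using the
// extents, then recombined with the byte strides into an offset. The same
// routine is applied once with the source ne/nb and once with the destination
// ne/nb. Sketch only; the device kernel does this per work-item.
#include <cassert>
#include <cstdint>

static int64_t flat_to_byte_offset(int64_t i,
                                   int64_t ne0, int64_t ne1, int64_t ne2,
                                   int64_t nb0, int64_t nb1, int64_t nb2, int64_t nb3) {
    const int64_t i3 = i / (ne0*ne1*ne2);
    const int64_t i2 = (i - i3*ne0*ne1*ne2) / (ne0*ne1);
    const int64_t i1 = (i - i3*ne0*ne1*ne2 - i2*ne1*ne0) / ne0;
    const int64_t i0 =  i - i3*ne0*ne1*ne2 - i2*ne1*ne0 - i1*ne0;
    return i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
}

int main() {
    // contiguous 4x3 f32 tensor: ne = {4,3,1}, nb = {4,16,48,48} bytes
    assert(flat_to_byte_offset(0, 4,3,1, 4,16,48,48) ==  0);
    assert(flat_to_byte_offset(5, 4,3,1, 4,16,48,48) == 20); // i1 = 1, i0 = 1
    return 0;
}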
i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; - const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; - const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13; - - cpy_1(cx + x_offset, cdst + dst_offset); -} - -static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_q8_0 * dsti = (block_q8_0 *) cdsti; - - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - const float v = xi[j]; - amax = sycl::fmax(amax, sycl::fabs((float)v)); - } - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - dsti->d = d; - - for (int j = 0; j < QK8_0; ++j) { - const float x0 = xi[j]*id; - - dsti->qs[j] = sycl::round((float)x0); - } -} - -static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_q4_0 * dsti = (block_q4_0 *) cdsti; - - float amax = 0.0f; - float vmax = 0.0f; - - for (int j = 0; j < QK4_0; ++j) { - const float v = xi[j]; - if (amax < sycl::fabs((float)v)) { - amax = sycl::fabs((float)v); - vmax = v; - } - } - - const float d = vmax / -8; - const float id = d ? 1.0f/d : 0.0f; - - dsti->d = d; - - for (int j = 0; j < QK4_0/2; ++j) { - const float x0 = xi[0 + j]*id; - const float x1 = xi[QK4_0/2 + j]*id; - - const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 8.5f)); - const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 8.5f)); - - dsti->qs[j] = xi0; - dsti->qs[j] |= xi1 << 4; - } -} - -static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_q4_1 * dsti = (block_q4_1 *) cdsti; - - float vmin = FLT_MAX; - float vmax = -FLT_MAX; - - for (int j = 0; j < QK4_1; ++j) { - const float v = xi[j]; - - if (v < vmin) vmin = v; - if (v > vmax) vmax = v; - } - - const float d = (vmax - vmin) / ((1 << 4) - 1); - const float id = d ? 
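// Host-side reference of the Q8_0 block quantization performed by the removed
// cpy_blck_f32_q8_0 kernel: per 32-value block, scale by the absolute maximum
// and store rounded int8 quants. The struct layout here is illustrative; the
// real block_q8_0 stores the scale as ggml_half, not float.
#include <cmath>
#include <cstdint>
#include <cstdio>

constexpr int QK8_0_REF = 32;

struct block_q8_0_ref {
    float  d;               // scale (illustrative: float instead of ggml_half)
    int8_t qs[QK8_0_REF];   // quants
};

static void quantize_block_q8_0_ref(const float * x, block_q8_0_ref * dst) {
    float amax = 0.0f;
    for (int j = 0; j < QK8_0_REF; ++j) {
        amax = std::fmax(amax, std::fabs(x[j]));
    }
    const float d  = amax / 127.0f;            // (1 << 7) - 1
    const float id = d != 0.0f ? 1.0f/d : 0.0f;

    dst->d = d;
    for (int j = 0; j < QK8_0_REF; ++j) {
        dst->qs[j] = (int8_t) std::lround(x[j] * id);
    }
}

int main() {
    float x[QK8_0_REF];
    for (int j = 0; j < QK8_0_REF; ++j) x[j] = 0.01f * (j - 16);
    block_q8_0_ref b;
    quantize_block_q8_0_ref(x, &b);
    std::printf("d=%f q[0]=%d q[31]=%d\n", b.d, b.qs[0], b.qs[31]);
    return 0;
}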
1.0f/d : 0.0f; - - dsti->dm.x() = d; - dsti->dm.y() = vmin; - - for (int j = 0; j < QK4_1/2; ++j) { - const float x0 = (xi[0 + j] - vmin)*id; - const float x1 = (xi[QK4_1/2 + j] - vmin)*id; - - const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 0.5f)); - const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 0.5f)); - - dsti->qs[j] = xi0; - dsti->qs[j] |= xi1 << 4; - } -} - -template -static void cpy_f32_q(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) { - const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2)) * - qk; - - if (i >= ne) { - return; - } - - const int i03 = i/(ne00 * ne01 * ne02); - const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); - const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; - const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; - const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; - - const int i13 = i/(ne10 * ne11 * ne12); - const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); - const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; - const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; - const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13; - - cpy_blck(cx + x_offset, cdst + dst_offset); -} - static void k_sum_rows_f32(const float * x, float * dst, const int ncols, const sycl::nd_item<3> &item_ct1) { const int row = item_ct1.get_group(1); @@ -1903,231 +1714,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl( } } -static void -ggml_cpy_f16_f32_sycl(const char *cx, char *cdst, const int ne, const int ne00, - const int ne01, const int ne02, const int nb00, - const int nb01, const int nb02, const int nb03, - const int ne10, const int ne11, const int ne12, - const int nb10, const int nb11, const int nb12, - const int nb13, queue_ptr stream) { - const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, - nb01, nb02, nb03, ne10, ne11, ne12, - nb10, nb11, nb12, nb13, item_ct1); - }); - } -} - -static void ggml_cpy_f32_f32_sycl(const char *cx, char *cdst, const int ne, - const int ne00, const int ne01, - const int ne02, const int nb00, - const int nb01, const int nb02, - const int nb03, const int ne10, - const int ne11, const int ne12, - const int nb10, const int nb11, - const int nb12, const int nb13, - queue_ptr stream) { - - const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, - nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, - item_ct1); - }); - } -} - -static void ggml_cpy_f32_f16_sycl(const char *cx, char *cdst, const int ne, - const int ne00, const int ne01, - const 
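// The removed cpy_f32_q template assigns one work-item per block of qk source
// elements: the source offset is computed element-wise, while the destination
// advances one block (nb10 bytes) per qk elements along dim 0, i.e. (i10/qk)*nb10.
// For a contiguous 1-D tensor that reduces to the plain C++ sketch below
// (the 34-byte block size matches Q8_0: 2-byte scale + 32 quants).
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t qk = 32, nb00 = 4 /* sizeof(float) */, nb10 = 34 /* sizeof(block_q8_0) */;
    for (int64_t k = 0; k < 4; ++k) {
        const int64_t i          = k * qk;          // first source element handled by work-item k
        const int64_t src_offset = i * nb00;        // byte offset of that float in src
        const int64_t dst_offset = (i / qk) * nb10; // byte offset of destination block k
        std::printf("block %lld: src byte %lld -> dst byte %lld\n",
                    (long long) k, (long long) src_offset, (long long) dst_offset);
    }
    return 0;
}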
int ne02, const int nb00, - const int nb01, const int nb02, - const int nb03, const int ne10, - const int ne11, const int ne12, - const int nb10, const int nb11, - const int nb12, const int nb13, - queue_ptr stream) { - - const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, - nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, - item_ct1); - }); - } -} - -static void ggml_cpy_f32_q8_0_sycl(const char *cx, char *cdst, const int ne, - const int ne00, const int ne01, - const int ne02, const int nb00, - const int nb01, const int nb02, - const int nb03, const int ne10, - const int ne11, const int ne12, - const int nb10, const int nb11, - const int nb12, const int nb13, - queue_ptr stream) { - - GGML_ASSERT(ne % QK8_0 == 0); - const int num_blocks = ne / QK8_0; - stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), - sycl::range<3>(1, 1, 1)), - [=](sycl::nd_item<3> item_ct1) { - cpy_f32_q( - cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, - nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, - item_ct1); - }); -} - -static void ggml_cpy_f32_q4_0_sycl(const char *cx, char *cdst, const int ne, - const int ne00, const int ne01, - const int ne02, const int nb00, - const int nb01, const int nb02, - const int nb03, const int ne10, - const int ne11, const int ne12, - const int nb10, const int nb11, - const int nb12, const int nb13, - queue_ptr stream) { - - GGML_ASSERT(ne % QK4_0 == 0); - const int num_blocks = ne / QK4_0; - stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), - sycl::range<3>(1, 1, 1)), - [=](sycl::nd_item<3> item_ct1) { - cpy_f32_q( - cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, - nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, - item_ct1); - }); -} - -static void ggml_cpy_f32_q4_1_sycl(const char *cx, char *cdst, const int ne, - const int ne00, const int ne01, - const int ne02, const int nb00, - const int nb01, const int nb02, - const int nb03, const int ne10, - const int ne11, const int ne12, - const int nb10, const int nb11, - const int nb12, const int nb13, - queue_ptr stream) { - - GGML_ASSERT(ne % QK4_1 == 0); - const int num_blocks = ne / QK4_1; - stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), - sycl::range<3>(1, 1, 1)), - [=](sycl::nd_item<3> item_ct1) { - cpy_f32_q( - cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, - nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, - item_ct1); - }); -} - -static void ggml_cpy_f16_f16_sycl(const char *cx, char *cdst, const int ne, - const int ne00, const int ne01, - const int ne02, const int nb00, - const int nb01, const int nb02, - const int nb03, const int ne10, - const int ne11, const int ne12, - const int nb10, const int nb11, - const int nb12, const int nb13, - queue_ptr stream) { - - const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, 
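// The f16 launchers above call dpct::has_capability_or_fail(..., {sycl::aspect::fp16})
// before enqueueing fp16 work. In plain SYCL 2020 the equivalent query can be
// written with device::has(); a small sketch (needs a SYCL compiler, e.g. icpx -fsycl),
// shown only to illustrate the check, not as the backend's code.
#include <sycl/sycl.hpp>
#include <iostream>

int main() {
    sycl::queue q;  // default-selected device
    const bool fp16_ok = q.get_device().has(sycl::aspect::fp16);
    std::cout << q.get_device().get_info<sycl::info::device::name>()
              << (fp16_ok ? " supports" : " does not support") << " fp16 kernels\n";
    return 0;
}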
nb02, - nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, - item_ct1); - }); - } -} - -static void ggml_cpy_i16_i16_sycl(const char *cx, char *cdst, const int ne, - const int ne00, const int ne01, - const int ne02, const int nb00, - const int nb01, const int nb02, - const int nb03, const int ne10, - const int ne11, const int ne12, - const int nb10, const int nb11, - const int nb12, const int nb13, - queue_ptr stream) { - - const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; - { - // dpct::has_capability_or_fail(stream->get_device(), - // {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, - nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, - item_ct1); - }); - } -} - -static void ggml_cpy_i32_i32_sycl(const char *cx, char *cdst, const int ne, - const int ne00, const int ne01, - const int ne02, const int nb00, - const int nb01, const int nb02, - const int nb03, const int ne10, - const int ne11, const int ne12, - const int nb10, const int nb11, - const int nb12, const int nb13, - queue_ptr stream) { - - const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; - { - // dpct::has_capability_or_fail(stream->get_device(), - // {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, - nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, - item_ct1); - }); - } -} static void scale_f32_sycl(const float *x, float *dst, const float scale, const int k, queue_ptr stream) { @@ -3645,58 +3232,6 @@ static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_clamp); } -static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst) try { - const int64_t ne = ggml_nelements(src0); - GGML_ASSERT(ne == ggml_nelements(src1)); - - GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); - GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); - - GGML_TENSOR_BINARY_OP_LOCALS01; - - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - queue_ptr main_stream = ctx.stream(); - - char * src0_ddc = (char *) src0->data; - char * src1_ddc = (char *) src1->data; - - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_f32_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f32_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { - ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { - ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); - } else if (src0->type == GGML_TYPE_F32 && src1->type == 
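// ggml_sycl_cpy uses GGML_TENSOR_BINARY_OP_LOCALS01 to pull the extents and
// byte strides of src0/src1 into locals named ne0x/nb0x and ne1x/nb1x before
// passing them to the launchers. A hedged sketch of the equivalent explicit
// reads (the macro's exact expansion may differ; this only illustrates the
// naming convention used throughout these kernels).
#include "ggml.h"

static void show_locals(const ggml_tensor * src0, const ggml_tensor * src1) {
    const int64_t ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2];
    const size_t  nb00 = src0->nb[0], nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3];

    const int64_t ne10 = src1->ne[0], ne11 = src1->ne[1], ne12 = src1->ne[2];
    const size_t  nb10 = src1->nb[0], nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3];

    (void) ne00; (void) ne01; (void) ne02;
    (void) nb00; (void) nb01; (void) nb02; (void) nb03;
    (void) ne10; (void) ne11; (void) ne12;
    (void) nb10; (void) nb11; (void) nb12; (void) nb13;
}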
GGML_TYPE_Q4_1) { - ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); - } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - ggml_cpy_f16_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); - } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f16_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); - } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) { - ggml_cpy_i16_i16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); - } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) { - ggml_cpy_i32_i32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); - } else { - GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__, - ggml_type_name(src0->type), ggml_type_name(src1->type)); - GGML_ABORT("fatal error"); - } - GGML_UNUSED(dst); -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - // TODO: why do we pass dst as src1 here? - ggml_sycl_cpy(ctx, dst->src[0], dst, nullptr); -} - static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf); } @@ -3893,7 +3428,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens ggml_sycl_clamp(ctx, dst); break; case GGML_OP_CPY: - ggml_sycl_cpy(ctx, dst->src[0], dst->src[1], dst); + ggml_sycl_cpy(ctx, dst->src[0], dst->src[1]); break; case GGML_OP_CONT: ggml_sycl_dup(ctx, dst); @@ -4407,6 +3942,30 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { return true; } + if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) { + return true; + } + if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) { + return true; + } + if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) { + return true; + } return false; } break; case GGML_OP_CONCAT: diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 0bcb2fe4b..102edb834 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -8460,7 +8460,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_SIGMOID: - return ggml_is_contiguous(op->src[0]); + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; } @@ -8661,19 
+8661,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_RMS_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
-        case GGML_OP_ACC:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
-        case GGML_OP_CONCAT:
         case GGML_OP_SILU_BACK:
         case GGML_OP_RMS_NORM_BACK:
-        case GGML_OP_UPSCALE:
-        case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
+            return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_ACC:
+        case GGML_OP_CONCAT:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_SCALE:
         case GGML_OP_PAD:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
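// The SYCL supports_op additions earlier in this patch advertise the newly
// handled CPY type pairs one if-statement at a time, matching the branches of
// ggml_sycl_cpy, while the Vulkan hunk above instead narrows several
// element-wise ops to F32 inputs. An illustrative, table-driven way to express
// such a pair set (not the backend's actual code; only the pairs added in this
// patch are listed):
#include "ggml.h"
#include <set>
#include <utility>

static bool cpy_pair_added_in_this_patch(ggml_type src, ggml_type dst) {
    static const std::set<std::pair<ggml_type, ggml_type>> pairs = {
        {GGML_TYPE_Q8_0, GGML_TYPE_F32}, {GGML_TYPE_Q4_0, GGML_TYPE_F32},
        {GGML_TYPE_Q4_1, GGML_TYPE_F32}, {GGML_TYPE_F32,  GGML_TYPE_Q5_0},
        {GGML_TYPE_Q5_0, GGML_TYPE_F32}, {GGML_TYPE_F32,  GGML_TYPE_Q5_1},
        {GGML_TYPE_Q5_1, GGML_TYPE_F32}, {GGML_TYPE_F32,  GGML_TYPE_IQ4_NL},
    };
    return pairs.count({src, dst}) > 0;
}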