merge checkpoint 2 - functional merge without q4_0_4_4 (need regen shaders)

commit de64b9198c by Concedo, 2024-12-13 17:04:19 +08:00
218 changed files with 175,736 additions and 49,778 deletions


@@ -434,11 +434,11 @@ add_library(ggml
             ggml/src/ggml-quants.h
             ggml/src/ggml-cpu/llamafile/sgemm.cpp
             ggml/src/ggml-cpu/llamafile/sgemm.h
-            ggml/src/ggml-aarch64.c
-            ggml/src/ggml-aarch64.h
+            ggml/src/ggml-cpu/ggml-cpu-traits.cpp
+            ggml/src/ggml-cpu/ggml-cpu-traits.h
             ggml/src/ggml-threading.cpp
             ggml/src/ggml-cpu/ggml-cpu.cpp
-            ggml/src/ggml-cpu/ggml-cpu-aarch64.c
+            ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
             ggml/src/ggml-cpu/ggml-cpu-aarch64.h
             ggml/src/ggml-cpu/ggml-cpu-quants.c
             ggml/src/ggml-cpu/ggml-cpu-quants.h

CODEOWNERS (new file)

@@ -0,0 +1,3 @@
+# collaborators can optionally add themselves here to indicate their availability for reviewing related PRs
+
+ci/ @ggerganov


@@ -92,9 +92,9 @@ endif
 CUBLASLD_FLAGS =
 CUBLAS_OBJS =
 
-OBJS_FULL += ggml-alloc.o ggml-aarch64.o ggml-quants.o ggml-cpu-quants.o ggml-cpu-aarch64.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o sgemm.o common.o sampling.o
-OBJS_SIMPLE += ggml-alloc.o ggml-aarch64.o ggml-quants_noavx2.o ggml-cpu-quants_noavx2.o ggml-cpu-aarch64_noavx2.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o sgemm_noavx2.o common.o sampling.o
-OBJS_FAILSAFE += ggml-alloc.o ggml-aarch64.o ggml-quants_failsafe.o ggml-cpu-quants_failsafe.o ggml-cpu-aarch64_failsafe.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o sgemm_failsafe.o common.o sampling.o
+OBJS_FULL += ggml-alloc.o ggml-cpu-traits.o ggml-quants.o ggml-cpu-quants.o ggml-cpu-aarch64.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o sgemm.o common.o sampling.o
+OBJS_SIMPLE += ggml-alloc.o ggml-cpu-traits.o ggml-quants_noavx2.o ggml-cpu-quants_noavx2.o ggml-cpu-aarch64_noavx2.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o sgemm_noavx2.o common.o sampling.o
+OBJS_FAILSAFE += ggml-alloc.o ggml-cpu-traits.o ggml-quants_failsafe.o ggml-cpu-quants_failsafe.o ggml-cpu-aarch64_failsafe.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o sgemm_failsafe.o common.o sampling.o
 
 # OS specific
 # TODO: support Windows
@@ -488,12 +488,12 @@ ggml-cpu-quants_failsafe.o: ggml/src/ggml-cpu/ggml-cpu-quants.c ggml/include/ggm
 	$(CC) $(CFLAGS) $(NONECFLAGS) -c $< -o $@
 
 #aarch64
-ggml-cpu-aarch64.o: ggml/src/ggml-cpu/ggml-cpu-aarch64.c ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-aarch64.h
-	$(CC) $(CFLAGS) $(FULLCFLAGS) -c $< -o $@
-ggml-cpu-aarch64_noavx2.o: ggml/src/ggml-cpu/ggml-cpu-aarch64.c ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-aarch64.h
-	$(CC) $(CFLAGS) $(SIMPLECFLAGS) -c $< -o $@
-ggml-cpu-aarch64_failsafe.o: ggml/src/ggml-cpu/ggml-cpu-aarch64.c ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-aarch64.h
-	$(CC) $(CFLAGS) $(NONECFLAGS) -c $< -o $@
+ggml-cpu-aarch64.o: ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-aarch64.h
+	$(CXX) $(CXXFLAGS) $(FULLCFLAGS) -c $< -o $@
+ggml-cpu-aarch64_noavx2.o: ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-aarch64.h
+	$(CXX) $(CXXFLAGS) $(SIMPLECFLAGS) -c $< -o $@
+ggml-cpu-aarch64_failsafe.o: ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-aarch64.h
+	$(CXX) $(CXXFLAGS) $(NONECFLAGS) -c $< -o $@
 #sgemm
 sgemm.o: ggml/src/ggml-cpu/llamafile/sgemm.cpp ggml/src/ggml-cpu/llamafile/sgemm.h ggml/include/ggml.h
@@ -512,8 +512,8 @@ unicode.o: src/unicode.cpp src/unicode.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 unicode-data.o: src/unicode-data.cpp src/unicode-data.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
-ggml-aarch64.o: ggml/src/ggml-aarch64.c ggml/include/ggml.h ggml/src/ggml-aarch64.h ggml/src/ggml-common.h
-	$(CC) $(CFLAGS) -c $< -o $@
+ggml-cpu-traits.o: ggml/src/ggml-cpu/ggml-cpu-traits.cpp ggml/src/ggml-cpu/ggml-cpu-traits.h ggml/include/ggml.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
 ggml-threading.o: ggml/src/ggml-threading.cpp ggml/include/ggml.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 ggml-cpu-cpp.o: ggml/src/ggml-cpu/ggml-cpu.cpp ggml/include/ggml.h ggml/src/ggml-common.h

Sources/llama/llama.h (new file)

@@ -0,0 +1,4 @@
+#pragma once
+
+#include <llama.h>
+


@@ -0,0 +1,5 @@
+module llama [system] {
+    header "llama.h"
+    link "llama"
+    export *
+}


@@ -0,0 +1,11 @@
+set( CMAKE_SYSTEM_NAME Windows )
+set( CMAKE_SYSTEM_PROCESSOR x86_64 )
+
+set( CMAKE_C_COMPILER clang )
+set( CMAKE_CXX_COMPILER clang++ )
+
+set( arch_c_flags "-march=native" )
+
+set( CMAKE_C_FLAGS_INIT "${arch_c_flags}" )
+set( CMAKE_CXX_FLAGS_INIT "${arch_c_flags}" )


@@ -592,7 +592,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.ctx_shift = false;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
     add_opt(common_arg(
         {"--chunks"}, "N",
         string_format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
@@ -1712,6 +1712,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.public_path = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_STATIC_PATH"));
+    add_opt(common_arg(
+        {"--no-webui"},
+        string_format("Disable the Web UI (default: %s)", params.webui ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.webui = false;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_WEBUI"));
     add_opt(common_arg(
         {"--embedding", "--embeddings"},
         string_format("restrict to only support embedding use case; use only with dedicated embedding models (default: %s)", params.embedding ? "enabled" : "disabled"),


@@ -62,6 +62,10 @@ struct common_speculative * common_speculative_init(
 }
 
 void common_speculative_free(struct common_speculative * spec) {
+    if (spec == nullptr) {
+        return;
+    }
+
     common_sampler_free(spec->smpl);
 
     llama_batch_free(spec->batch);


@@ -1992,6 +1992,14 @@ class Qwen2Model(Model):
         except FileNotFoundError:
             self._set_vocab_gpt2()
 
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
+            if self.hparams["rope_scaling"].get("type") == "yarn":
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+                self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
+
 
 @Model.register("Qwen2MoeForCausalLM")
 class Qwen2MoeModel(Model):


@@ -54,8 +54,6 @@ As the models are currently fully loaded into memory, you will need adequate dis
 
 Several quantization methods are supported. They differ in the resulting model disk size and inference speed.
 
-The quantization formats `Q4_0_4_4`, `Q4_0_4_8` and `Q4_0_8_8` are block interleaved variants of the `Q4_0` format, providing a data layout that is better suited for specific implementations of optimized mulmat kernels. Since these formats differ only in data layout, they have the same quantized size as the `Q4_0` format.
-
 *(outdated)*
 
 | Model | Measure | F16 | Q4_0 | Q4_1 | Q5_0 | Q5_1 | Q8_0 |


@@ -49,9 +49,6 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "Q5_K_M",   LLAMA_FTYPE_MOSTLY_Q5_K_M,   " 5.33G, +0.0569 ppl @ Llama-3-8B",  },
     { "Q6_K",     LLAMA_FTYPE_MOSTLY_Q6_K,     " 6.14G, +0.0217 ppl @ Llama-3-8B",  },
     { "Q8_0",     LLAMA_FTYPE_MOSTLY_Q8_0,     " 7.96G, +0.0026 ppl @ Llama-3-8B",  },
-    { "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B",  },
-    { "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B",  },
-    { "Q4_0_8_8", LLAMA_FTYPE_MOSTLY_Q4_0_8_8, " 4.34G, +0.4685 ppl @ Llama-3-8B",  },
     { "F16",      LLAMA_FTYPE_MOSTLY_F16,      "14.00G, +0.0020 ppl @ Mistral-7B",  },
     { "BF16",     LLAMA_FTYPE_MOSTLY_BF16,     "14.00G, -0.0050 ppl @ Mistral-7B",  },
     { "F32",      LLAMA_FTYPE_ALL_F32,         "26.00G              @ 7B",          },

File diff suppressed because it is too large.


@@ -1,4 +1,5 @@
 import pytest
+import requests
 from utils import *
 
 server = ServerPreset.tinyllama2()
@@ -22,7 +23,12 @@ def test_server_props():
     server.start()
     res = server.make_request("GET", "/props")
     assert res.status_code == 200
+    assert ".gguf" in res.body["model_path"]
     assert res.body["total_slots"] == server.n_slots
+    default_val = res.body["default_generation_settings"]
+    assert server.n_ctx is not None and server.n_slots is not None
+    assert default_val["n_ctx"] == server.n_ctx / server.n_slots
+    assert default_val["params"]["seed"] == server.seed
 
 
 def test_server_models():
@@ -33,6 +39,31 @@ def test_server_models():
     assert len(res.body["data"]) == 1
     assert res.body["data"][0]["id"] == server.model_alias
 
+
+def test_server_slots():
+    global server
+
+    # without slots endpoint enabled, this should return error
+    server.server_slots = False
+    server.start()
+    res = server.make_request("GET", "/slots")
+    assert res.status_code == 501 # ERROR_TYPE_NOT_SUPPORTED
+    assert "error" in res.body
+    server.stop()
+
+    # with slots endpoint enabled, this should return slots info
+    server.server_slots = True
+    server.n_slots = 2
+    server.start()
+    res = server.make_request("GET", "/slots")
+    assert res.status_code == 200
+    assert len(res.body) == server.n_slots
+    assert server.n_ctx is not None and server.n_slots is not None
+    assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots
+    assert "params" in res.body[0]
+    assert res.body[0]["params"]["seed"] == server.seed
+
+
 def test_load_split_model():
     global server
     server.model_hf_repo = "ggml-org/models"
@@ -46,3 +77,20 @@ def test_load_split_model():
     })
     assert res.status_code == 200
     assert match_regex("(little|girl)+", res.body["content"])
+
+
+def test_no_webui():
+    global server
+    # default: webui enabled
+    server.start()
+    url = f"http://{server.server_host}:{server.server_port}"
+    res = requests.get(url)
+    assert res.status_code == 200
+    assert "<html>" in res.text
+    server.stop()
+
+    # with --no-webui
+    server.no_webui = True
+    server.start()
+    res = requests.get(url)
+    assert res.status_code == 404


@@ -30,6 +30,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
         ],
     })
     assert res.status_code == 200
+    assert "cmpl" in res.body["id"]  # make sure the completion id has the expected format
    assert res.body["model"] == model if model is not None else server.model_alias
    assert res.body["usage"]["prompt_tokens"] == n_prompt
    assert res.body["usage"]["completion_tokens"] == n_predicted
@@ -59,9 +60,13 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
         "stream": True,
     })
     content = ""
+    last_cmpl_id = None
     for data in res:
         choice = data["choices"][0]
         assert "gpt-3.5" in data["model"]  # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
+        if last_cmpl_id is None:
+            last_cmpl_id = data["id"]
+        assert last_cmpl_id == data["id"]  # make sure the completion id is the same for all events in the stream
         if choice["finish_reason"] in ["stop", "length"]:
             assert data["usage"]["prompt_tokens"] == n_prompt
             assert data["usage"]["completion_tokens"] == n_predicted


@@ -42,10 +42,16 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
     })
     content = ""
     for data in res:
+        assert "stop" in data and type(data["stop"]) == bool
         if data["stop"]:
             assert data["timings"]["prompt_n"] == n_prompt
             assert data["timings"]["predicted_n"] == n_predicted
             assert data["truncated"] == truncated
+            assert data["stop_type"] == "limit"
+            assert "generation_settings" in data
+            assert server.n_predict is not None
+            assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict)
+            assert data["generation_settings"]["seed"] == server.seed
             assert match_regex(re_content, content)
         else:
             content += data["content"]


@@ -13,28 +13,28 @@ def test_infill_without_input_extra():
     global server
     server.start()
     res = server.make_request("POST", "/infill", data={
-        "prompt": "Complete this",
-        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n    int n_threads = llama_",
+        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
+        "prompt": "    int n_threads = llama_",
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])
+    assert match_regex("(Ann|small|shiny)+", res.body["content"])
 
 
 def test_infill_with_input_extra():
     global server
     server.start()
     res = server.make_request("POST", "/infill", data={
-        "prompt": "Complete this",
         "input_extra": [{
             "filename": "llama.h",
             "text": "LLAMA_API int32_t llama_n_threads();\n"
         }],
-        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n    int n_threads = llama_",
+        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
+        "prompt": "    int n_threads = llama_",
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"])
+    assert match_regex("(Dad|excited|park)+", res.body["content"])
 
 
 @pytest.mark.parametrize("input_extra", [
@@ -48,10 +48,30 @@ def test_invalid_input_extra_req(input_extra):
     global server
     server.start()
     res = server.make_request("POST", "/infill", data={
-        "prompt": "Complete this",
         "input_extra": [input_extra],
-        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n    int n_threads = llama_",
+        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
+        "prompt": "    int n_threads = llama_",
         "input_suffix": "}\n",
     })
     assert res.status_code == 400
     assert "error" in res.body
+
+
+@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
+def test_with_qwen_model():
+    global server
+    server.model_file = None
+    server.model_hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-IQ3_XXS-GGUF"
+    server.model_hf_file = "qwen2.5-coder-1.5b-iq3_xxs-imat.gguf"
+    server.start(timeout_seconds=600)
+    res = server.make_request("POST", "/infill", data={
+        "input_extra": [{
+            "filename": "llama.h",
+            "text": "LLAMA_API int32_t llama_n_threads();\n"
+        }],
+        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
+        "prompt": "    int n_threads = llama_",
+        "input_suffix": "}\n",
+    })
+    assert res.status_code == 200
+    assert res.body["content"] == "n_threads();\n    printf(\"Number of threads: %d\\n\", n_threads);\n    return 0;\n"


@@ -64,6 +64,7 @@ class ServerProcess:
     server_embeddings: bool | None = False
     server_reranking: bool | None = False
     server_metrics: bool | None = False
+    server_slots: bool | None = False
     draft: int | None = None
     api_key: str | None = None
     response_format: str | None = None
@@ -71,6 +72,7 @@ class ServerProcess:
     disable_ctx_shift: int | None = False
     draft_min: int | None = None
     draft_max: int | None = None
+    no_webui: bool | None = None
 
     # session variables
     process: subprocess.Popen | None = None
@@ -91,7 +93,6 @@ class ServerProcess:
         else:
             server_path = "../../../build/bin/llama-server"
         server_args = [
-            "--slots",  # requires to get slot status via /slots endpoint
             "--host",
             self.server_host,
             "--port",
@@ -129,6 +130,8 @@ class ServerProcess:
             server_args.append("--reranking")
         if self.server_metrics:
             server_args.append("--metrics")
+        if self.server_slots:
+            server_args.append("--slots")
         if self.model_alias:
             server_args.extend(["--alias", self.model_alias])
         if self.n_ctx:
@@ -156,6 +159,8 @@ class ServerProcess:
             server_args.extend(["--draft-max", self.draft_max])
         if self.draft_min:
             server_args.extend(["--draft-min", self.draft_min])
+        if self.no_webui:
+            server_args.append("--no-webui")
 
         args = [str(arg) for arg in [server_path, *server_args]]
         print(f"bench: starting server with: {' '.join(args)}")
@@ -181,7 +186,7 @@ class ServerProcess:
        start_time = time.time()
        while time.time() - start_time < timeout_seconds:
            try:
-                response = self.make_request("GET", "/slots", headers={
+                response = self.make_request("GET", "/health", headers={
                    "Authorization": f"Bearer {self.api_key}" if self.api_key else None
                })
                if response.status_code == 200:
@@ -224,7 +229,7 @@ class ServerProcess:
        result.headers = dict(response.headers)
        result.status_code = response.status_code
        result.body = response.json() if parse_body else None
-        print("Response from server", result.body)
+        print("Response from server", json.dumps(result.body, indent=2))
        return result

    def make_stream_request(
@@ -245,7 +250,7 @@ class ServerProcess:
                    break
                elif line.startswith('data: '):
                    data = json.loads(line[6:])
-                    print("Partial response from server", data)
+                    print("Partial response from server", json.dumps(data, indent=2))
                    yield data

@@ -369,3 +374,6 @@ def match_regex(regex: str, text: str) -> bool:
        ).search(text)
        is not None
    )
+
+
+def is_slow_test_allowed():
+    return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"


@@ -164,6 +164,9 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
     } else {
         throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
     }
+    if (result.empty()) {
+        throw std::runtime_error("\"prompt\" must not be empty");
+    }
     return result;
 }
 
@@ -327,12 +330,12 @@ static std::string llama_get_chat_template(const struct llama_model * model) {
     std::string template_key = "tokenizer.chat_template";
     // call with NULL buffer to get the total size of the string
     int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
-    if (res < 0) {
+    if (res < 2) {
         return "";
     } else {
         std::vector<char> model_template(res, 0);
         llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
-        return std::string(model_template.data(), model_template.size());
+        return std::string(model_template.data(), model_template.size() - 1);
     }
 }
 
@@ -496,8 +499,6 @@ static json oaicompat_completion_params_parse(
     const std::string & chat_template) {
     json llama_params;
 
-    llama_params["__oaicompat"] = true;
-
     // Apply chat template to the list of messages
     llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
 
@@ -648,3 +649,18 @@ static json format_detokenized_response(const std::string & content) {
         {"content", content}
     };
 }
+
+static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) {
+    json data = json::array();
+    for (const auto & lb : logit_bias) {
+        data.push_back(json{
+            {"bias", lb.bias},
+            {"token", lb.token},
+        });
+    }
+
+    return data;
+}
+
+static std::string safe_json_to_str(json data) {
+    return data.dump(-1, ' ', false, json::error_handler_t::replace);
+}
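
Note on `safe_json_to_str` above: `json::dump()` throws by default when a string contains bytes that are not valid UTF-8, which generated text occasionally is; `error_handler_t::replace` substitutes U+FFFD instead of throwing. A minimal standalone sketch, assuming only the single-header nlohmann/json library:

#include <nlohmann/json.hpp>
#include <cstdio>
#include <string>

int main() {
    using json = nlohmann::json;

    // "\xC3" alone is a truncated UTF-8 sequence; dump() would throw on it
    // with the default (strict) error handler.
    json j = { {"content", std::string("caf\xC3")} };

    // error_handler_t::replace swaps the invalid byte for U+FFFD instead.
    std::string s = j.dump(-1, ' ', false, json::error_handler_t::replace);
    printf("%s\n", s.c_str());
    return 0;
}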


@ -228,6 +228,7 @@ extern "C" {
GGML_API void ggml_backend_unload(ggml_backend_reg_t reg); GGML_API void ggml_backend_unload(ggml_backend_reg_t reg);
// Load all known backends from dynamic libraries // Load all known backends from dynamic libraries
GGML_API void ggml_backend_load_all(void); GGML_API void ggml_backend_load_all(void);
GGML_API void ggml_backend_load_all_from_path(const char * dir_path);
// //
// Backend scheduler // Backend scheduler
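
A sketch of how the new entry point can be used; per the implementation further down, a null argument is equivalent to `ggml_backend_load_all()`, while a non-null directory restricts the backend search to that path (the directory below is illustrative only):

#include "ggml-backend.h"

int main() {
    // Search a caller-supplied directory for the dynamic backend libraries
    // ([lib]ggml-<name>-*.so / .dll) instead of "./" and the executable's path.
    ggml_backend_load_all_from_path("/opt/myapp/ggml-backends");  // hypothetical path

    // The default behavior is still available:
    // ggml_backend_load_all();
    return 0;
}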


@ -103,24 +103,14 @@ extern "C" {
// Internal types and functions exposed for tests and benchmarks // Internal types and functions exposed for tests and benchmarks
typedef void (*ggml_from_float_to_mat_t)
(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs);
typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx, typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
const void * GGML_RESTRICT y, size_t by, int nrc); const void * GGML_RESTRICT y, size_t by, int nrc);
typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
const void * GGML_RESTRICT y, int nr, int nc);
typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
const void * GGML_RESTRICT y, int nr, int nc);
struct ggml_type_traits_cpu { struct ggml_type_traits_cpu {
ggml_from_float_t from_float; ggml_from_float_t from_float;
ggml_from_float_to_mat_t from_float_to_mat;
ggml_vec_dot_t vec_dot; ggml_vec_dot_t vec_dot;
enum ggml_type vec_dot_type; enum ggml_type vec_dot_type;
int64_t nrows; // number of rows to process simultaneously int64_t nrows; // number of rows to process simultaneously
int64_t ncols; // number of columns to process simultaneously
ggml_gemv_t gemv;
ggml_gemm_t gemm;
}; };
GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type); GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type);
@ -140,13 +130,6 @@ extern "C" {
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
#ifdef GGML_USE_CPU_HBM
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
#endif
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void);
GGML_BACKEND_API bool ggml_backend_cpu_buft_is_aarch64(ggml_backend_buffer_type_t buft);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif


@ -390,15 +390,15 @@ extern "C" {
GGML_TYPE_F64 = 28, GGML_TYPE_F64 = 28,
GGML_TYPE_IQ1_M = 29, GGML_TYPE_IQ1_M = 29,
GGML_TYPE_BF16 = 30, GGML_TYPE_BF16 = 30,
GGML_TYPE_Q4_0_4_4 = 31, // GGML_TYPE_Q4_0_4_4 = 31, support has been removed from gguf files
GGML_TYPE_Q4_0_4_8 = 32, // GGML_TYPE_Q4_0_4_8 = 32,
GGML_TYPE_Q4_0_8_8 = 33, // GGML_TYPE_Q4_0_8_8 = 33,
GGML_TYPE_TQ1_0 = 34, GGML_TYPE_TQ1_0 = 34,
GGML_TYPE_TQ2_0 = 35, GGML_TYPE_TQ2_0 = 35,
GGML_TYPE_IQ4_NL_4_4 = 36, // GGML_TYPE_IQ4_NL_4_4 = 36,
// GGML_TYPE_IQ4_NL_4_8 = 37, // GGML_TYPE_IQ4_NL_4_8 = 37,
// GGML_TYPE_IQ4_NL_8_8 = 38, // GGML_TYPE_IQ4_NL_8_8 = 38,
GGML_TYPE_COUNT, GGML_TYPE_COUNT = 39,
}; };
// precision // precision
@ -439,9 +439,6 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors
GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors
GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
}; };
// available tensor operations: // available tensor operations:
@ -2211,11 +2208,19 @@ extern "C" {
GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx); GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data); GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data);
#ifdef __cplusplus #ifdef __cplusplus
// restrict not standard in C++ // restrict not standard in C++
#define GGML_RESTRICT # if defined(__GNUC__)
# define GGML_RESTRICT __restrict__
# elif defined(__clang__)
# define GGML_RESTRICT __restrict
# elif defined(_MSC_VER)
# define GGML_RESTRICT __restrict
# else
# define GGML_RESTRICT
# endif
#else #else
#define GGML_RESTRICT restrict # define GGML_RESTRICT restrict
#endif #endif
typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
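
The practical effect of the `GGML_RESTRICT` change: under C++ the macro previously expanded to nothing, dropping the no-alias hint entirely. A condensed sketch of the same compiler dispatch and what the hint enables (a standalone illustration, not the ggml header itself):

#if defined(__GNUC__) || defined(__clang__)
#    define GGML_RESTRICT __restrict__
#elif defined(_MSC_VER)
#    define GGML_RESTRICT __restrict
#else
#    define GGML_RESTRICT
#endif

// With the no-alias hint the compiler may assume writes through y never
// touch x, so it can keep x[i] in registers and vectorize the loop freely.
static void vec_scale(const float * GGML_RESTRICT x, float * GGML_RESTRICT y,
                      int n, float s) {
    for (int i = 0; i < n; ++i) {
        y[i] = x[i] * s;
    }
}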


@@ -1,129 +0,0 @@
-#define GGML_COMMON_DECL_C
-#include "ggml-common.h"
-
-#include "ggml-aarch64.h"
-#include "ggml-impl.h"
-#include "ggml-quants.h"
-
-#include <assert.h>
-
-#define UNUSED GGML_UNUSED
-
-static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
-    block_q4_0x4 out;
-
-    for (int i = 0; i < 4; i++) {
-        out.d[i] = in[i].d;
-    }
-
-    const int end = QK4_0 * 2 / blck_size_interleave;
-
-    if (blck_size_interleave == 8) {
-        const uint64_t xor_mask = 0x8888888888888888ULL;
-        for (int i = 0; i < end; ++i) {
-            int src_id = i % 4;
-            int src_offset = (i / 4) * blck_size_interleave;
-            int dst_offset = i * blck_size_interleave;
-
-            uint64_t elems;
-            // Using memcpy to avoid unaligned memory accesses
-            memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
-            elems ^= xor_mask;
-            memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
-        }
-    } else if (blck_size_interleave == 4) {
-        const uint32_t xor_mask = 0x88888888;
-        for (int i = 0; i < end; ++i) {
-            int src_id = i % 4;
-            int src_offset = (i / 4) * blck_size_interleave;
-            int dst_offset = i * blck_size_interleave;
-
-            uint32_t elems;
-            memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t));
-            elems ^= xor_mask;
-            memcpy(&out.qs[dst_offset], &elems, sizeof(uint32_t));
-        }
-    } else {
-        GGML_ASSERT(false);
-    }
-
-    return out;
-}
-
-// interleave 8 block_q4_0s in blocks of blck_size_interleave
-// returns an interleaved block_q4_0x8
-// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
-// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
-static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) {
-    block_q4_0x8 out;
-
-    for (int i = 0; i < 8; i++) {
-        out.d[i] = in[i].d;
-    }
-
-    const int end = QK4_0 * 4 / blck_size_interleave;
-    const uint64_t xor_mask = 0x8888888888888888ULL;
-
-    for (int i = 0; i < end; ++i) {
-        int src_id = i % 8;
-        int src_offset = (i / 8) * blck_size_interleave;
-        int dst_offset = i * blck_size_interleave;
-
-        uint64_t elems;
-        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
-        elems ^= xor_mask;
-        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
-    }
-
-    return out;
-}
-
-static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) {
-    assert(n_per_row % QK4_0 == 0);
-    const int nb = n_per_row / QK4_0;
-
-    void * out_ptr = NULL;
-    if (nrows_interleaved == 8) {
-        out_ptr = (block_q4_0x8 *) dst;
-    }
-    else if (nrows_interleaved == 4) {
-        out_ptr = (block_q4_0x4 *) dst;
-    }
-    assert(nrows_interleaved <= 8);
-    block_q4_0 dst_tmp[8];
-
-    for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) {
-        for (int64_t x = 0; x < nb; x++) {
-            for (int i = 0; i < nrows_interleaved; i++ ) {
-                quantize_row_q4_0_ref(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0);
-            }
-            if (nrows_interleaved == 8) {
-                *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave);
-                out_ptr = (block_q4_0x8 *) out_ptr + 1;
-            }
-            else if (nrows_interleaved == 4) {
-                *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave);
-                out_ptr = (block_q4_0x4 *) out_ptr + 1;
-            }
-        }
-    }
-
-    return ((nrow * n_per_row) / QK4_0 * sizeof(block_q4_0));
-}
-
-size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
-    UNUSED(quant_weights);
-    return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4);
-}
-
-size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
-    UNUSED(quant_weights);
-    return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8);
-}
-
-size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
-    UNUSED(quant_weights);
-    return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8);
-}


@@ -1,19 +0,0 @@
-#pragma once
-
-#include "ggml.h"
-
-// GGML internal header
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
-size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
-size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
-size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
-
-#ifdef __cplusplus
-}
-#endif


@@ -449,14 +449,26 @@ static std::string backend_filename_suffix() {
 #endif
 }
 
-static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent) {
-    // enumerate all the files that match [lib]ggml-name-*.[so|dll] in the search paths
-    // TODO: search system paths
+static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, const char * user_search_path) {
     //not available as we don't want c++17
     printf("\nggml_backend_load_best NOT AVAILABLE!\n");
     return nullptr;
+    // // enumerate all the files that match [lib]ggml-name-*.[so|dll] in the search paths
+    // // TODO: search system paths
-    // std::vector<std::string> search_paths = { "./", get_executable_path() };
     // std::string file_prefix = backend_filename_prefix() + name + "-";
+    // std::vector<std::string> search_paths;
+    // if (user_search_path == nullptr) {
+    //     search_paths.push_back("./");
+    //     search_paths.push_back(get_executable_path());
+    // } else {
+    // #if defined(_WIN32)
+    //     search_paths.push_back(std::string(user_search_path) + "\\");
+    // #else
+    //     search_paths.push_back(std::string(user_search_path) + "/");
+    // #endif
+    // }
 
     // int best_score = 0;
     // std::string best_path;
@@ -512,21 +524,25 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent)
 }
 
 void ggml_backend_load_all() {
+    ggml_backend_load_all_from_path(nullptr);
+}
+
+void ggml_backend_load_all_from_path(const char * dir_path) {
 #ifdef NDEBUG
     bool silent = true;
 #else
     bool silent = false;
 #endif
 
-    ggml_backend_load_best("blas", silent);
-    ggml_backend_load_best("cann", silent);
-    ggml_backend_load_best("cuda", silent);
-    ggml_backend_load_best("hip", silent);
-    ggml_backend_load_best("kompute", silent);
-    ggml_backend_load_best("metal", silent);
-    ggml_backend_load_best("rpc", silent);
-    ggml_backend_load_best("sycl", silent);
-    ggml_backend_load_best("vulkan", silent);
-    ggml_backend_load_best("musa", silent);
-    ggml_backend_load_best("cpu", silent);
+    ggml_backend_load_best("blas", silent, dir_path);
+    ggml_backend_load_best("cann", silent, dir_path);
+    ggml_backend_load_best("cuda", silent, dir_path);
+    ggml_backend_load_best("hip", silent, dir_path);
+    ggml_backend_load_best("kompute", silent, dir_path);
+    ggml_backend_load_best("metal", silent, dir_path);
+    ggml_backend_load_best("rpc", silent, dir_path);
+    ggml_backend_load_best("sycl", silent, dir_path);
+    ggml_backend_load_best("vulkan", silent, dir_path);
+    ggml_backend_load_best("musa", silent, dir_path);
+    ggml_backend_load_best("cpu", silent, dir_path);
 }


@@ -2089,7 +2089,7 @@ static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, con
 
 static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
     /* .get_name         = */ ggml_backend_cann_reg_get_name,
     /* .get_device_count = */ ggml_backend_cann_reg_get_device_count,
-    /* .get_device_get   = */ ggml_backend_cann_reg_get_device,
+    /* .get_device       = */ ggml_backend_cann_reg_get_device,
     /* .get_proc_address = */ ggml_backend_cann_reg_get_proc_address,
 };


@@ -6,7 +6,20 @@
 typedef uint16_t ggml_half;
 typedef uint32_t ggml_half2;
 
-#define GGML_COMMON_AGGR
+#define GGML_COMMON_AGGR_U
+#define GGML_COMMON_AGGR_S
+
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_CPP)
+#include <cstdint>
+
+typedef uint16_t ggml_half;
+typedef uint32_t ggml_half2;
+
+// std-c++ allow anonymous unions but some compiler warn on it
+#define GGML_COMMON_AGGR_U data
+// std-c++ do not allow it.
+#define GGML_COMMON_AGGR_S data
 
 #define GGML_COMMON_DECL
 #elif defined(GGML_COMMON_DECL_METAL)
@@ -15,7 +28,8 @@ typedef uint32_t ggml_half2;
 typedef half  ggml_half;
 typedef half2 ggml_half2;
 
-#define GGML_COMMON_AGGR
+#define GGML_COMMON_AGGR_U
+#define GGML_COMMON_AGGR_S
 
 #define GGML_COMMON_DECL
 #elif defined(GGML_COMMON_DECL_CUDA)
@@ -29,7 +43,8 @@ typedef half2 ggml_half2;
 typedef half  ggml_half;
 typedef half2 ggml_half2;
 
-#define GGML_COMMON_AGGR data
+#define GGML_COMMON_AGGR_U
+#define GGML_COMMON_AGGR_S data
 
 #define GGML_COMMON_DECL
 #elif defined(GGML_COMMON_DECL_HIP)
@@ -39,7 +54,8 @@ typedef half2 ggml_half2;
 typedef half  ggml_half;
 typedef half2 ggml_half2;
 
-#define GGML_COMMON_AGGR data
+#define GGML_COMMON_AGGR_U
+#define GGML_COMMON_AGGR_S data
 
 #define GGML_COMMON_DECL
 #elif defined(GGML_COMMON_DECL_SYCL)
@@ -49,7 +65,8 @@ typedef half2 ggml_half2;
 typedef sycl::half  ggml_half;
 typedef sycl::half2 ggml_half2;
 
-#define GGML_COMMON_AGGR data
+#define GGML_COMMON_AGGR_U
+#define GGML_COMMON_AGGR_S data
 
 #define GGML_COMMON_DECL
 #endif
@@ -154,9 +171,9 @@ typedef struct {
         struct {
             ggml_half d; // delta
             ggml_half m; // min
-        } GGML_COMMON_AGGR;
+        } GGML_COMMON_AGGR_S;
         ggml_half2 dm;
-    };
+    } GGML_COMMON_AGGR_U;
     uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;
 static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
@@ -175,9 +192,9 @@ typedef struct {
         struct {
             ggml_half d; // delta
             ggml_half m; // min
-        } GGML_COMMON_AGGR;
+        } GGML_COMMON_AGGR_S;
         ggml_half2 dm;
-    };
+    } GGML_COMMON_AGGR_U;
     uint8_t qh[4];         // 5-th bit of quants
     uint8_t qs[QK5_1 / 2]; // nibbles / quants
 } block_q5_1;
@@ -196,37 +213,13 @@ typedef struct {
         struct {
             ggml_half d; // delta
             ggml_half s; // d * sum(qs[i])
-        } GGML_COMMON_AGGR;
+        } GGML_COMMON_AGGR_S;
         ggml_half2 ds;
-    };
+    } GGML_COMMON_AGGR_U;
     int8_t qs[QK8_1]; // quants
 } block_q8_1;
 static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding");
 
-typedef struct {
-    ggml_half d[4];        // deltas for 4 q4_0 blocks
-    uint8_t qs[QK4_0 * 2]; // nibbles / quants for 4 q4_0 blocks
-} block_q4_0x4;
-static_assert(sizeof(block_q4_0x4) == 4 * sizeof(ggml_half) + QK4_0 * 2, "wrong q4_0x4 block size/padding");
-
-typedef struct {
-    ggml_half d[8];        // deltas for 8 q4_0 blocks
-    uint8_t qs[QK4_0 * 4]; // nibbles / quants for 8 q4_0 blocks
-} block_q4_0x8;
-static_assert(sizeof(block_q4_0x8) == 8 * sizeof(ggml_half) + QK4_0 * 4, "wrong q4_0x8 block size/padding");
-
-typedef struct {
-    ggml_half d[4];       // deltas for 4 q8_0 blocks
-    int8_t qs[QK8_0 * 4]; // quants for 4 q8_0 blocks
-} block_q8_0x4;
-static_assert(sizeof(block_q8_0x4) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong q8_0x4 block size/padding");
-
-typedef struct {
-    ggml_half d[8];       // deltas for 8 q8_0 blocks
-    int8_t qs[QK8_0 * 8]; // quants for 8 q8_0 blocks
-} block_q8_0x8;
-static_assert(sizeof(block_q8_0x8) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong q8_0x8 block size/padding");
-
 //
 // Ternary quantization
 //
@@ -261,9 +254,9 @@ typedef struct {
         struct {
             ggml_half d;    // super-block scale for quantized scales
             ggml_half dmin; // super-block scale for quantized mins
-        } GGML_COMMON_AGGR;
+        } GGML_COMMON_AGGR_S;
         ggml_half2 dm;
-    };
+    } GGML_COMMON_AGGR_U;
 } block_q2_K;
 static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
 
@@ -288,9 +281,9 @@ typedef struct {
         struct {
             ggml_half d;    // super-block scale for quantized scales
             ggml_half dmin; // super-block scale for quantized mins
-        } GGML_COMMON_AGGR;
+        } GGML_COMMON_AGGR_S;
         ggml_half2 dm;
-    };
+    } GGML_COMMON_AGGR_U;
     uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
     uint8_t qs[QK_K/2];           // 4--bit quants
 } block_q4_K;
@@ -305,9 +298,9 @@ typedef struct {
         struct {
             ggml_half d;    // super-block scale for quantized scales
             ggml_half dmin; // super-block scale for quantized mins
-        } GGML_COMMON_AGGR;
+        } GGML_COMMON_AGGR_S;
         ggml_half2 dm;
-    };
+    } GGML_COMMON_AGGR_U;
     uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
     uint8_t qh[QK_K/8];           // quants, high bit
     uint8_t qs[QK_K/2];           // quants, low 4 bits
@@ -418,12 +411,6 @@ typedef struct {
 } block_iq4_xs;
 static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
 
-typedef struct {
-    ggml_half d[4];           // deltas for 4 iq4_nl blocks
-    uint8_t   qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
-} block_iq4_nlx4;
-static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
-
 #endif // GGML_COMMON_DECL
 #endif // GGML_COMMON_DECL
 
@@ -437,6 +424,13 @@ static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wro
 #define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
 #define GGML_TABLE_END() };
 
+#define GGML_COMMON_IMPL
+#elif defined(GGML_COMMON_IMPL_CPP)
+#include <cstdint>
+
+#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
+#define GGML_TABLE_END() };
+
 #define GGML_COMMON_IMPL
 #elif defined(GGML_COMMON_IMPL_METAL)
 #include <metal_stdlib>
@@ -479,7 +473,7 @@ GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)
     240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
 GGML_TABLE_END()
 
-//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+//#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics
 GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
     0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
     0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
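
What the two split macros buy: in the C build both expand to nothing, so the union and inner struct stay anonymous; the new GGML_COMMON_DECL_CPP branch names both members `data`, since standard C++ forbids anonymous structs and some compilers warn on anonymous unions. A sketch of how `block_q4_1` expands in C++ mode, derived from the macro definitions above (the access paths in the comments follow from that expansion):

#include <cstdint>

typedef uint16_t ggml_half;
typedef uint32_t ggml_half2;
#define QK4_1 32

// C++ expansion: GGML_COMMON_AGGR_S -> data, GGML_COMMON_AGGR_U -> data.
typedef struct {
    union {
        struct {
            ggml_half d; // delta
            ggml_half m; // min
        } data;          // named: C++ forbids anonymous structs
        ggml_half2 dm;
    } data;              // named to silence anonymous-union warnings
    uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;

// The memory layout is unchanged, only the access path differs:
//   C:   blk.d, blk.m, blk.dm
//   C++: blk.data.data.d, blk.data.data.m, blk.data.dm
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2,
              "wrong q4_1 block size/padding");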


@@ -5,6 +5,7 @@
 #include "ggml-backend.h"
 #include "ggml-impl.h"
 #include "ggml-cpu.h"
+#include "ggml-cpu-traits.h"
 
 #if defined(__gnu_linux__)
 #include <sys/syscall.h>
@@ -17,31 +18,65 @@
 
 #if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
 
+// AMX type_trais
+namespace ggml::cpu::amx {
+
+class tensor_traits : public ggml::cpu::tensor_traits {
+    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
+        size = ggml_backend_amx_desired_wsize(op);
+        return true;
+    }
+
+    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
+        if (op->op == GGML_OP_MUL_MAT) {
+            ggml_backend_amx_mul_mat(params, op);
+            return true;
+        }
+        return false;
+    }
+};
+
+static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
+    static tensor_traits traits;
+    return &traits;
+}
+
+}  // namespace ggml::cpu::amx
+
 // AMX buffer interface
 static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     free(buffer->context);
 }
 
 static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {
-    return (void *)(buffer->context);
+    return (void *) (buffer->context);
 }
 
-static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
-    memset((char *)tensor->data + offset, value, size);
+static void ggml_backend_amx_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    tensor->extra = (void *) ggml::cpu::amx::get_tensor_traits(buffer, tensor);
 
     GGML_UNUSED(buffer);
 }
 
-static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
+                                                  uint8_t value, size_t offset, size_t size) {
+    memset((char *) tensor->data + offset, value, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
+                                               const void * data, size_t offset, size_t size) {
     if (qtype_has_amx_kernels(tensor->type)) {
+        GGML_LOG_DEBUG("%s: amx repack tensor %s of type %s\n", __func__, tensor->name, ggml_type_name(tensor->type));
         ggml_backend_amx_convert_weight(tensor, data, offset, size);
     } else {
-        memcpy((char *)tensor->data + offset, data, size);
+        memcpy((char *) tensor->data + offset, data, size);
     }
 
     GGML_UNUSED(buffer);
 }
 
+/*
+// need to figure what we need to do with buffer->extra.
 static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
     GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
     memcpy(data, (const char *)tensor->data + offset, size);
@@ -62,6 +97,7 @@ static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
 
     GGML_UNUSED(buffer);
 }
+*/
 
 static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
     memset(buffer->context, value, buffer->size);
@@ -70,13 +106,13 @@ static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
 
 static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_amx_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_amx_buffer_get_base,
-    /* .init_tensor     = */ NULL, // no initialization required
+    /* .init_tensor     = */ ggml_backend_amx_buffer_init_tensor,
     /* .memset_tensor   = */ ggml_backend_amx_buffer_memset_tensor,
     /* .set_tensor      = */ ggml_backend_amx_buffer_set_tensor,
-    /* .get_tensor      = */ ggml_backend_amx_buffer_get_tensor,
-    /* .cpy_tensor      = */ ggml_backend_amx_buffer_cpy_tensor,
+    /* .get_tensor      = */ nullptr,
+    /* .cpy_tensor      = */ nullptr,
     /* .clear           = */ ggml_backend_amx_buffer_clear,
-    /* .reset           = */ NULL,
+    /* .reset           = */ nullptr,
 };
 
 static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
@@ -101,18 +137,48 @@ static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_typ
     GGML_UNUSED(buft);
 }
 
-static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
+namespace ggml::cpu::amx {
+class extra_buffer_type : ggml::cpu::extra_buffer_type {
+    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
+        // handle only 2d gemm for now
+        auto is_contiguous_2d = [](const struct ggml_tensor * t) {
+            return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
+        };
+
+        if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) &&  // src0 must be contiguous
+            is_contiguous_2d(op->src[1]) &&                               // src1 must be contiguous
+            op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() &&
+            op->ne[0] % (TILE_N * 2) == 0 &&                              // out_features is 32x
+            (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) {
+            // src1 must be host buffer
+            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+                return false;
+            }
+            // src1 must be float32
+            if (op->src[1]->type == GGML_TYPE_F32) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
+        if (op->op == GGML_OP_MUL_MAT && op->src[0]->buffer &&
+            op->src[0]->buffer->buft == ggml_backend_amx_buffer_type()) {
+            return (ggml::cpu::tensor_traits *) op->src[0]->extra;
+        }
+
+        return nullptr;
+    }
+};
+}  // namespace ggml::cpu::amx
+
+static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
     return ggml_backend_amx_get_alloc_size(tensor);
 
     GGML_UNUSED(buft);
 }
 
-static bool ggml_backend_amx_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
-    return false;
-
-    GGML_UNUSED(buft);
-}
-
 #define ARCH_GET_XCOMP_PERM     0x1022
 #define ARCH_REQ_XCOMP_PERM     0x1023
 #define XFEATURE_XTILECFG       17
@@ -129,68 +195,26 @@ static bool ggml_amx_init() {
     return true;
 #endif
 }
 
 ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() {
     static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = {
         /* .iface = */ {
             /* .get_name         = */ ggml_backend_amx_buffer_type_get_name,
             /* .alloc_buffer     = */ ggml_backend_amx_buffer_type_alloc_buffer,
             /* .get_alignment    = */ ggml_backend_amx_buffer_type_get_alignment,
-            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
+            /* .get_max_size     = */ nullptr, // defaults to SIZE_MAX
             /* .get_alloc_size   = */ ggml_backend_amx_buffer_type_get_alloc_size,
-            /* .is_host          = */ ggml_backend_amx_buffer_type_is_host,
+            /* .is_host          = */ nullptr,
         },
         /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
-        /* .context = */ NULL,
+        /* .context = */ new ggml::cpu::amx::extra_buffer_type(),
     };
 
     if (!ggml_amx_init()) {
-        return NULL;
+        return nullptr;
     }
     return &ggml_backend_buffer_type_amx;
 }
 
-bool ggml_backend_amx_buft_is_amx(ggml_backend_buffer_type_t buft) {
-    return buft->iface.get_name == ggml_backend_amx_buffer_type_get_name;
-}
-
-bool ggml_backend_amx_device_supports_op(const struct ggml_tensor * op) {
-    // handle only 2d gemm for now
-    auto is_contiguous_2d = [](const struct ggml_tensor * t) {
-        return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
-    };
-
-    switch (op->op) {
-        case GGML_OP_NONE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_VIEW:
-        case GGML_OP_PERMUTE:
-        case GGML_OP_TRANSPOSE:
-            return true;
-
-        case GGML_OP_MUL_MAT: {
-            const struct ggml_tensor * src0 = op->src[0];
-            const struct ggml_tensor * src1 = op->src[1];
-
-            const enum ggml_type type = src0->type;
-            const int64_t ne0 = op->ne[0];
-
-            // amx kernels enables for Q4_0, Q4_1, Q8_0, F16
-            // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
-            bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
-
-            bool can_use_amx =
-                is_contiguous_2d(src0) &&       // src0 must be contiguous
-                is_contiguous_2d(src1) &&       // src1 must be contiguous
-                src1->type == GGML_TYPE_F32 &&  // src1 must be float32
-                has_amx_kernels &&              // with amx kernel impls
-                ne0 % (TILE_N * 2) == 0;        // out_features is 32x
-
-            return can_use_amx;
-        }
-        default:
-            return false;
-    }
-}
-
 #endif // defined(__AMX_INT8__) && defined(__AVX512VNNI__)


@@ -1,20 +1,8 @@
 #include "ggml-backend.h"
 #include "ggml-cpu-impl.h"
 
-#ifdef __cplusplus
-extern "C" {
-#endif
+// GGML internal header
 
 #if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
 
 ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void);
-bool ggml_backend_amx_buft_is_amx(ggml_backend_buffer_type_t buft);
-bool ggml_backend_amx_device_supports_op(const struct ggml_tensor * op);
-void ggml_backend_amx_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst);
-#endif
 
-#ifdef __cplusplus
-}
 #endif

View file

@@ -7,7 +7,7 @@
#include <memory>
#include <type_traits>
#if defined(GGML_USE_OPENMP)
#include <omp.h>
#endif
@@ -56,11 +56,11 @@ inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
}
template <typename func_t>
inline void parallel_for(int n, const func_t& f) {
#if defined(GGML_USE_OPENMP)
#pragma omp parallel
{
int nth = omp_get_num_threads();
int ith = omp_get_thread_num();
int tbegin, tend;
balance211(n, nth, ith, tbegin, tend);
@@ -68,8 +68,6 @@ inline void parallel_for(int nth, int n, const func_t& f) {
}
#else
f(0, n);
#endif
}
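// [editor's sketch, not part of this diff] Using the reworked parallel_for:
// the caller no longer passes a thread count; with GGML_USE_OPENMP the count
// comes from the OpenMP runtime, otherwise the lambda runs once over [0, n).
// balance211 hands each thread a near-even [tbegin, tend) slice.
static void scale_buffer_example(float * v, int n) {
    parallel_for(n, [&](int begin, int end) {
        for (int i = begin; i < end; ++i) {
            v[i] *= 0.5f; // per-element work inside this thread's slice
        }
    });
}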
@@ -91,10 +89,3 @@ inline bool qtype_has_amx_kernels(const enum ggml_type type) {
(type == GGML_TYPE_Q6_K) ||
(type == GGML_TYPE_IQ4_XS);
}
// ggml backend context
struct ggml_backend_amx_context {
int n_threads = GGML_DEFAULT_N_THREADS;
std::unique_ptr<char[]> work_data;
size_t work_size = 0;
};

View file

@@ -18,10 +18,6 @@
#include <unistd.h>
#endif
#if defined(_OPENMP)
#include <omp.h>
#endif
#if (defined(_WIN32) || defined(_WIN64))
#define RESTRICT __restrict
#else
@@ -1382,13 +1378,13 @@ struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K>
#define PACKED_INDEX(n, k, KB, tile_size) (n * KB + k) * tile_size
template<typename TB, int BLOCK_K>
void convert_B_packed_format(void * RESTRICT packed_B, const TB * RESTRICT B, int N, int K) {
const int NB = N / TILE_N;
const int KB = K / BLOCK_K;
const int TILE_SIZE = get_tile_size<TB>();
// parallel on NB should be enough
parallel_for(NB, [&](int begin, int end) {
for (int n = begin; n < end; ++n) {
for (int k = 0; k < KB; ++k) {
int n0 = n * TILE_N;
@@ -2334,15 +2330,8 @@ void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * d
const int K = tensor->ne[0]; // ne0: in_features
const int N = tensor->ne[1]; // ne1: out_features
#if defined(_OPENMP)
// the buffer ctx is not initialized when .set_tensor is called
int n_threads = omp_get_num_threads();
#else
int n_threads = 1;
#endif
GGML_DISPATCH_QTYPES(TYPE, [&] {
convert_B_packed_format<type, blck_size>((void *)((char *)tensor->data + offset), (const type *)data, N, K);
});
}
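// [editor's note, not part of this diff] PACKED_INDEX addressing, spelled out:
// tile (n, k) of the packed weight starts (n * KB + k) * tile_size bytes into
// the buffer, i.e. the KB tiles of each column group are laid out back to back.
static inline size_t packed_tile_offset(int n, int k, int KB, size_t tile_size) {
    return ((size_t) n * KB + k) * tile_size; // same arithmetic as PACKED_INDEX
}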

View file

@@ -1,16 +1,10 @@
#pragma once
#include "common.h"
size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst);
#ifdef __cplusplus
extern "C" {
#endif
size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor);
void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void ggml_backend_amx_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
#ifdef __cplusplus
}
#endif

View file

@@ -1,20 +1,72 @@
#define GGML_COMMON_IMPL_CPP
#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"
#include "ggml-backend-impl.h"
#include "ggml-quants.h"
#include "ggml-impl.h"
#include "ggml-cpu.h"
#include "ggml-cpu-impl.h"
#include "ggml-cpu-traits.h"
#include <cmath>
#include <cstring>
#include <cassert>
#include <cfloat>
#include <cstdlib> // for qsort
#include <cstdio> // for GGML_ASSERT
#include "ggml-cpu-aarch64.h"
// TODO: move to include file?
// template <int K> constexpr int QK_0() {
// if constexpr (K == 4) {
// return QK4_0;
// }
// if constexpr (K == 8) {
// return QK8_0;
// }
// return -1;
// }
template <int K> struct QK_0_Helper {
static constexpr int value = -1;
};
// Specialization for K == 4
template <> struct QK_0_Helper<4> {
static constexpr int value = QK4_0;
};
// Specialization for K == 8
template <> struct QK_0_Helper<8> {
static constexpr int value = QK8_0;
};
// Access the value using QK_0<K>::value
template <int K> constexpr int QK_0() {
return QK_0_Helper<K>::value;
}
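// [editor's note, not part of this diff] Compile-time sanity check of the
// helper: QK_0<K>() folds to the matching block size constant (both are 32
// here), and unsupported K values yield -1.
static_assert(QK_0<4>() == QK4_0, "QK_0<4> must equal QK4_0");
static_assert(QK_0<8>() == QK8_0, "QK_0<8> must equal QK8_0");
static_assert(QK_0<2>() == -1, "unsupported K falls back to -1");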
template <int K, int N> struct block {
ggml_half d[N]; // deltas for N qK_0 blocks
int8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
};
// control size
static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
using block_q4_0x4 = block<4, 4>;
using block_q4_0x8 = block<4, 8>;
using block_q8_0x4 = block<8, 4>;
using block_q8_0x8 = block<8, 8>;
struct block_iq4_nlx4 {
ggml_half d[4]; // deltas for 4 iq4_nl blocks
uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
};
static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
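// [editor's note, not part of this diff] The size arithmetic behind the
// asserts above, with QK4_0 == QK8_0 == QK4_NL == 32 and sizeof(ggml_half) == 2:
//   block<4, 4>: 4*2 + 32*4*4/8 =  8 +  64 =  72 bytes
//   block<4, 8>: 8*2 + 32*8*4/8 = 16 + 128 = 144 bytes
//   block<8, 4>: 4*2 + 32*4*8/8 =  8 + 128 = 136 bytes
//   block<8, 8>: 8*2 + 32*8*8/8 = 16 + 256 = 272 bytes
//   block_iq4_nlx4: 4*2 + 32*2  =  8 +  64 =  72 bytes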
#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Woverlength-strings"
#elif defined(_MSC_VER)
@@ -185,12 +237,12 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(QK8_0 == 32);
assert(k % QK8_0 == 0);
const int nb = k / QK8_0;
block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
#if defined(__ARM_NEON)
float32x4_t srcv[4][8];
@@ -279,12 +331,12 @@ static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int6
#endif
}
static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(QK8_0 == 32);
assert(k % QK8_0 == 0);
const int nb = k / QK8_0;
block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
#if defined(__ARM_NEON)
float32x4_t srcv[4][8];
@@ -494,7 +546,7 @@ static void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int6
#endif
}
static void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
assert(nrow == 4);
UNUSED(nrow);
if (blck_size_interleave == 4) {
@@ -506,7 +558,7 @@ void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nro
}
}
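// [editor's sketch, not part of this diff] Reference scalar form of what the
// SIMD paths above implement: each run of QK8_0 = 32 floats becomes one f16
// delta d = amax/127 plus 32 int8 quants q = round(x/d); quantize_mat_q8_0
// does this for 4 rows at once and interleaves the results into block_q8_0x4.
// Hypothetical helper, not part of the file:
static void quantize_row_q8_0_scalar(const float * x, ggml_half * d, int8_t * qs) {
    float amax = 0.0f;
    for (int i = 0; i < QK8_0; ++i) {
        amax = fmaxf(amax, fabsf(x[i]));
    }
    const float delta = amax / 127.0f;
    const float id = delta != 0.0f ? 1.0f / delta : 0.0f;
    *d = GGML_FP32_TO_FP16(delta);
    for (int i = 0; i < QK8_0; ++i) {
        qs[i] = (int8_t) roundf(x[i] * id);
    }
}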
static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int ncols_interleaved = 4;
@@ -591,7 +643,7 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
}
}
static void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int ncols_interleaved = 4;
@@ -701,7 +753,7 @@ void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
}
}
static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int ncols_interleaved = 8;
@@ -974,7 +1026,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
}
}
static void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int ncols_interleaved = 4;
@@ -1070,7 +1122,7 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
}
}
static void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int ncols_interleaved = 4;
@@ -1586,7 +1638,7 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
}
}
static void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int ncols_interleaved = 4;
@@ -2040,7 +2092,7 @@ void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
}
}
static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int ncols_interleaved = 8;
@@ -2560,31 +2612,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
// Shuffle pattern one - right side input
const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
// Shuffle pattern two - right side input
const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
// Scale values - Load the weight scale values of two block_q4_0x8
const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
@@ -2618,31 +2670,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
// Shuffle pattern one - left side input
const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
// Shuffle pattern two - left side input
const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
// The values arranged in the shuffle patterns are combined with a dot product within each 32-bit lane: corresponding bytes are multiplied and the products accumulated into 32-bit integers
// Resembles MMLAs into 2x2 matrices as in the ARM version
@@ -2671,10 +2723,10 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
// Straighten out to make 4 row vectors
__m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
__m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
__m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
__m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
// Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
@@ -2753,31 +2805,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
// Shuffle pattern one - right side input
const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
// Shuffle pattern two - right side input
const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
// Scale values - Load the weight scale values of two block_q4_0x8
@@ -2809,31 +2861,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
// Shuffle pattern one - left side input
const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
// Shuffle pattern two - left side input
const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
// The values arranged in the shuffle patterns are combined with a dot product within each 32-bit lane: corresponding bytes are multiplied and the products accumulated into 32-bit integers
// Resembles MMLAs into 2x2 matrices as in the ARM version
@@ -2862,10 +2914,10 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
// Straighten out to make 4 row vectors
__m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
__m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
__m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
__m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
// Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
@@ -3460,7 +3512,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
}
}
static void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int ncols_interleaved = 4;
@@ -3571,7 +3623,6 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
}
}
// FIXME: this code is duplicated from ggml-aarch64.c
static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
block_q4_0x4 out;
@@ -3641,20 +3692,20 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in
return out;
}
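// [editor's sketch, not part of this diff] Shape of the repacking done by
// make_block_q4_0x4 above, simplified: take one block_q4_0 from each of 4
// consecutive rows, put their deltas side by side, then interleave the quant
// bytes in chunks of blck_size_interleave (fixed to 4 bytes here).
static block_q4_0x4 interleave_x4_sketch(const block_q4_0 * in) {
    block_q4_0x4 out;
    for (int i = 0; i < 4; ++i) {
        out.d[i] = in[i].d;
    }
    const int end = QK4_0 * 2 / 4; // 16 four-byte chunks across the 4 source blocks
    for (int i = 0; i < end; ++i) {
        memcpy(&out.qs[i * 4], &in[i % 4].qs[(i / 4) * 4], 4);
    }
    return out;
}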
static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
constexpr int nrows_interleaved = 4;
block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
const block_q4_0 * src = (const block_q4_0 *)data;
block_q4_0 dst_tmp[4];
int nrow = ggml_nrows(t);
int nblocks = t->ne[0] / QK4_0;
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
return -1;
}
@@ -3672,20 +3723,20 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
GGML_UNUSED(data_size);
}
static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
GGML_ASSERT(interleave_block == 8);
constexpr int nrows_interleaved = 8;
block_q4_0x8 * dst = (block_q4_0x8*)t->data;
const block_q4_0 * src = (const block_q4_0*) data;
block_q4_0 dst_tmp[8];
int nrow = ggml_nrows(t);
int nblocks = t->ne[0] / QK4_0;
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
return -1;
}
@@ -3712,16 +3763,18 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s
const int end = QK4_NL * 2 / blck_size_interleave;
// TODO: this branch seems wrong
//if (blck_size_interleave == 8) {
// for (int i = 0; i < end; ++i) {
// int src_id = i % 4;
// int src_offset = (i / 4) * blck_size_interleave;
// int dst_offset = i * blck_size_interleave;
// // Using memcpy to avoid unaligned memory accesses
// memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
// }
//} else
if (blck_size_interleave == 4) {
for (int i = 0; i < end; ++i) {
int src_id = i % 4;
int src_offset = (i / 4) * blck_size_interleave;
@@ -3736,20 +3789,21 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s
return out;
}
static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
//GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
GGML_ASSERT(interleave_block == 4);
block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
const block_iq4_nl * src = (const block_iq4_nl *)data;
block_iq4_nl dst_tmp[4];
int nrow = ggml_nrows(t);
int nrows_interleaved = 4;
int nblocks = t->ne[0] / QK4_0;
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
return -1;
}
@@ -3767,57 +3821,457 @@ static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_b
GGML_UNUSED(data_size);
}
// Prepare for optimized kernels if applicable
void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * restrict data, size_t data_size) {
if (cur->type == repack_type) {
memcpy(cur->data, data, data_size);
return;
}
if (cur->type == GGML_TYPE_Q4_0) {
switch (repack_type) {
case GGML_TYPE_Q4_0_8_8:
repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
break;
case GGML_TYPE_Q4_0_4_8:
repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
break;
case GGML_TYPE_Q4_0_4_4:
repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
break;
default:
GGML_ABORT("Unsupported type");
}
} else if (cur->type == GGML_TYPE_IQ4_NL) {
switch (repack_type) {
case GGML_TYPE_IQ4_NL_4_4:
repack_iq4_nl_to_iq4_nl_4_bl(cur, 4, data, data_size);
break;
default:
GGML_ABORT("Unsupported type");
}
} else {
GGML_ABORT("Unsupported type");
}
}
enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur) {
namespace ggml::cpu::aarch64 {
// repack
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
int repack(struct ggml_tensor *, const void *, size_t);
// TODO: generalise.
template <> int repack<block_q4_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size);
}
template <> int repack<block_q4_0, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size);
}
template <> int repack<block_q4_0, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
}
template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
}
// TODO: needs to be revisited
//template <> int repack<block_iq4_nl, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
// return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
//}
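As an aside for readers of this hunk: the pattern above, a templated function that is declared but only ever defined through explicit specializations per (block type, interleave size, column count), turns every missing kernel into a link-time error rather than a runtime fallback. A minimal self-contained C++ sketch of the same idiom (names are illustrative, not ggml's):

    #include <cstdio>

    struct block_a {}; // stand-in for block_q4_0 / block_iq4_nl

    // primary template declared, never defined: only explicitly
    // specialized (BLOCK, INTER, COLS) combinations can link
    template <typename BLOCK, int INTER, int COLS>
    int repack_demo(const char * src);

    template <> int repack_demo<block_a, 4, 4>(const char * src) {
        std::printf("4x4 path: %s\n", src); return 0;
    }
    template <> int repack_demo<block_a, 8, 8>(const char * src) {
        std::printf("8x8 path: %s\n", src); return 0;
    }

    int main() {
        repack_demo<block_a, 4, 4>("q4_0"); // resolved at compile time
        repack_demo<block_a, 8, 8>("q4_0");
        // repack_demo<block_a, 8, 4>("q4_0"); // would fail to link: no specialization
    }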
// gemv
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
void gemv(int, float *, size_t, const void *, const void *, int, int);
template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <>
void gemv<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
// gemm
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
void gemm(int, float *, size_t, const void *, const void *, int, int);
template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <>
void gemm<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
class tensor_traits_base : public ggml::cpu::tensor_traits {
public:
virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
};
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
// not really a GGML_TYPE_Q8_0, but same size.
switch (op->op) {
case GGML_OP_MUL_MAT:
size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1]));
return true;
case GGML_OP_MUL_MAT_ID:
size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1]));
size = GGML_PAD(size, sizeof(int64_t)); // + padding for next block.
size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
return true;
default:
// GGML_ABORT("fatal error");
break;
}
return false;
}
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
switch (op->op) {
case GGML_OP_MUL_MAT:
forward_mul_mat(params, op);
return true;
case GGML_OP_MUL_MAT_ID:
forward_mul_mat_id(params, op);
return true;
default:
// GGML_ABORT("fatal error");
break;
}
return false;
}
void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
const ggml_tensor * src0 = op->src[0];
const ggml_tensor * src1 = op->src[1];
ggml_tensor * dst = op;
GGML_TENSOR_BINARY_OP_LOCALS
const int ith = params->ith;
const int nth = params->nth;
GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
// GGML_ASSERT(ggml_n_dims(op->src[1]) == 2);
char * wdata = static_cast<char *>(params->wdata);
const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10);
assert(params->wsize >= nbw1 * ne11);
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
int64_t i11_processed = 0;
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
INTER_SIZE);
}
i11_processed = ne11 - ne11 % 4;
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
}
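// note: groups of 4 rows were quantized interleaved via quantize_mat_q8_0 for the gemm
// kernels; the remaining ne11 % 4 rows used the plain row-wise from_float above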
ggml_barrier(params->threadpool);
const void * src1_wdata = params->wdata;
const size_t src1_col_stride = ggml_row_size(GGML_TYPE_Q8_0, ne10);
int64_t src0_start = (ith * ne01) / nth;
int64_t src0_end = ((ith + 1) * ne01) / nth;
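// round this thread's [src0_start, src0_end) range up to a multiple of NB_COLS:
// the repacked layout interleaves NB_COLS rows, so a thread must own whole groups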
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
if (src0_start >= src0_end) {
return;
}
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
if (ne11 > 3) {
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01,
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
}
for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01,
(const char *) src1_wdata + (src1_col_stride * iter), 1,
src0_end - src0_start);
}
}
void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) {
const ggml_tensor * src0 = op->src[0];
const ggml_tensor * src1 = op->src[1];
const ggml_tensor * ids = op->src[2];
ggml_tensor * dst = op;
GGML_TENSOR_BINARY_OP_LOCALS
const int ith = params->ith;
const int nth = params->nth;
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(src0->type));
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(ne03 == 1);
GGML_ASSERT(ne13 == 1);
GGML_ASSERT(ne3 == 1);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
// row groups
const int n_ids = ids->ne[0]; // n_expert_used
const int n_as = ne02; // n_expert
const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10);
const size_t nbw2 = nbw1*ne11;
const size_t nbw3 = nbw2*ne12;
struct mmid_row_mapping {
int32_t i1;
int32_t i2;
};
GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
n_as * ne12 * sizeof(mmid_row_mapping)));
auto wdata = (char *) params->wdata;
auto wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
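// wdata layout: [src1 quantized to q8_0, padded to int64_t] [matrix_row_counts: n_as] [matrix_rows: n_as * ne12]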
// src1: float32 => block_q8_0
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
(void *) (wdata + i12 * nbw2 + i11 * nbw1),
ne10);
}
}
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]
if (ith == 0) {
// initialize matrix_row_counts
memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
// group rows by src0 matrix
for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
for (int32_t id = 0; id < n_ids; ++id) {
const int32_t i02 =
*(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
GGML_ASSERT(i02 >= 0 && i02 < n_as);
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
matrix_row_counts[i02] += 1;
}
}
}
ggml_barrier(params->threadpool);
// compute each matrix multiplication in sequence
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
const int64_t cne1 = matrix_row_counts[cur_a];
if (cne1 == 0) {
continue;
}
auto src0_cur = (const char *) src0->data + cur_a*nb02;
//const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = cne1; // src1 rows
int64_t src0_cur_start = (ith * ne01) / nth;
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
src0_cur_start =
(src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
if (src0_cur_start >= src0_cur_end) return;
for (int ir1 = 0; ir1 < nr1; ir1++) {
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
const int id = row_mapping.i1; // selected expert index
const int64_t i11 = id % ne11;
const int64_t i12 = row_mapping.i2; // row index in src1
const int64_t i1 = id; // selected expert index
const int64_t i2 = i12; // row
auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
ne01, src0_cur + src0_cur_start * nb01,
src1_col, 1, src0_cur_end - src0_cur_start);
}
}
#undef MMID_MATRIX_ROW
}
int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
(int) NB_COLS, (int) INTER_SIZE);
return ggml::cpu::aarch64::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
}
};
// instance for Q4
static const tensor_traits<block_q4_0, 4, 4> q4_0_4x4_q8_0;
static const tensor_traits<block_q4_0, 8, 4> q4_0_4x8_q8_0;
static const tensor_traits<block_q4_0, 8, 8> q4_0_8x8_q8_0;
// instance for IQ4
static const tensor_traits<block_iq4_nl, 4, 4> iq4_nl_4x4_q8_0;
} // namespace ggml::cpu::aarch64
static const ggml::cpu::tensor_traits * ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur) {
if (cur->type == GGML_TYPE_Q4_0) { if (cur->type == GGML_TYPE_Q4_0) {
// TODO: enable for AVX2 - currently disabled due to bad gemv performance if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
if (/* ggml_cpu_has_avx2() || */ (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) { if (cur->ne[1] % 8 == 0) {
return GGML_TYPE_Q4_0_8_8; return &ggml::cpu::aarch64::q4_0_8x8_q8_0;
}
} }
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
return GGML_TYPE_Q4_0_4_8; if (cur->ne[1] % 4 == 0) {
return &ggml::cpu::aarch64::q4_0_4x8_q8_0;
}
} }
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
return GGML_TYPE_Q4_0_4_4; if (cur->ne[1] % 4 == 0) {
return &ggml::cpu::aarch64::q4_0_4x4_q8_0;
}
} }
} else if (cur->type == GGML_TYPE_IQ4_NL) { } else if (cur->type == GGML_TYPE_IQ4_NL) {
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
return GGML_TYPE_IQ4_NL_4_4; if (cur->ne[1] % 4 == 0) {
return &ggml::cpu::aarch64::iq4_nl_4x4_q8_0;
}
} }
} }
return cur->type; return nullptr;
}
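In words: the repacked traits are only chosen when the CPU feature is present and the row count divides evenly by the interleave width; otherwise the function returns nullptr and the tensor keeps its plain layout. A standalone mirror of that cascade, with hypothetical boolean inputs standing in for the ggml_cpu_has_* probes:

    // illustrative only: flags stand in for ggml_cpu_has_sve()/has_neon()/has_dotprod() etc.
    const void * pick_q4_0_traits(bool sve_i8mm_q8, bool neon_i8mm, bool neon_dotprod, long nrows) {
        static const int q8x8 = 0, q4x8 = 1, q4x4 = 2; // stand-ins for the trait singletons
        if (sve_i8mm_q8  && nrows % 8 == 0) return &q8x8;
        if (neon_i8mm    && nrows % 4 == 0) return &q4x8;
        if (neon_dotprod && nrows % 4 == 0) return &q4x4;
        return nullptr; // caller keeps the generic Q4_0 path
    }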
static void ggml_backend_cpu_aarch64_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
tensor->extra = (void *) const_cast<ggml::cpu::tensor_traits *>(ggml_aarch64_get_optimal_repack_type(tensor));
GGML_UNUSED(buffer);
}
static void ggml_backend_cpu_aarch64_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
const void * data, size_t offset, size_t size) {
GGML_ASSERT(offset == 0);
GGML_ASSERT(size == ggml_nbytes(tensor));
auto tensor_traits = (ggml::cpu::aarch64::tensor_traits_base *) tensor->extra;
auto OK = tensor_traits->repack(tensor, data, size);
GGML_ASSERT(OK == 0);
GGML_UNUSED(buffer);
}
static const char * ggml_backend_cpu_aarch64_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
return "CPU_AARCH64";
GGML_UNUSED(buft);
}
static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
if (buffer == nullptr) {
return nullptr;
}
buffer->buft = buft;
buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor;
buffer->iface.set_tensor = ggml_backend_cpu_aarch64_buffer_set_tensor;
return buffer;
}
static size_t ggml_backend_cpu_aarch64_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
return TENSOR_ALIGNMENT;
GGML_UNUSED(buft);
}
namespace ggml::cpu::aarch64 {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
if ( op->op == GGML_OP_MUL_MAT &&
op->src[0]->buffer &&
(ggml_n_dims(op->src[0]) == 2) &&
op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type() &&
ggml_aarch64_get_optimal_repack_type(op->src[0])
) {
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false;
}
if (op->src[1]->type == GGML_TYPE_F32) {
return true;
}
//if (op->src[1]->type == GGML_TYPE_Q8_0) {
// return true;
//}
// may be possible if Q8_0 is packed...
} else if (op->op == GGML_OP_MUL_MAT_ID
&& op->src[0]->buffer
&& (ggml_n_dims(op->src[0]) == 3)
&& op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type()
&& ggml_aarch64_get_optimal_repack_type(op->src[0])
) {
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false;
}
if (op->src[1]->type == GGML_TYPE_F32) {
return true;
}
//if (op->src[1]->type == GGML_TYPE_Q8_0) {
// return true;
//}
}
return false;
}
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type()) {
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
}
}
return nullptr;
}
};
} // namespace ggml::cpu::aarch64
ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void) {
static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_aarch64 = {
/* .iface = */ {
/* .get_name = */ ggml_backend_cpu_aarch64_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_cpu_aarch64_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_cpu_aarch64_buffer_type_get_alignment,
/* .get_max_size = */ nullptr, // defaults to SIZE_MAX
/* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
/* .is_host = */ nullptr,
},
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
/* .context = */ new ggml::cpu::aarch64::extra_buffer_type(),
};
return &ggml_backend_cpu_buffer_type_aarch64;
} }
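For context, a caller opts into this machinery simply by allocating weights from the CPU_AARCH64 buffer type; init_tensor stashes the traits in tensor->extra and set_tensor performs the repack on upload. A hedged sketch, assuming the internal headers are on the include path (upload_repacked is an illustrative name; the ggml calls are the public API):

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"
    #include "ggml-cpu-aarch64.h"

    static ggml_backend_buffer_t upload_repacked(struct ggml_context * ctx, struct ggml_tensor * w, const void * w_data) {
        ggml_backend_buffer_type_t buft = ggml_backend_cpu_aarch64_buffer_type();
        // allocates via the plain CPU buffer type, but with the aarch64 init/set hooks installed
        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
        // triggers ggml_backend_cpu_aarch64_buffer_set_tensor(), i.e. the repack
        ggml_backend_tensor_set(w, w_data, 0, ggml_nbytes(w));
        return buf;
    }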
View file
@ -1,32 +1,8 @@
#pragma once #pragma once
#include "ggml-cpu-traits.h"
#include "ggml.h" #include "ggml.h"
// GGML internal header // GGML internal header
#ifdef __cplusplus ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void);
extern "C" {
#endif
// Quantization
void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave);
// GEMV
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
// GEMM
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * data, size_t data_size);
enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur);
#ifdef __cplusplus
}
#endif
View file
@ -0,0 +1,55 @@
#ifdef GGML_USE_CPU_HBM
#include "ggml-backend.h"
#include "ggml-backend-impl.h"
#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "ggml-cpu-hbm.h"
// buffer type HBM
#include <hbwmalloc.h>
static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
return "CPU_HBM";
GGML_UNUSED(buft);
}
static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) {
hbw_free(buffer->context);
}
static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
size_t size) {
void * ptr;
int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size);
if (result != 0) {
GGML_LOG_ERROR("failed to allocate HBM buffer of size %zu\n", size);
return NULL;
}
ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
buffer->buft = buft;
buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer;
return buffer;
}
ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = {
/* .iface = */ {
/* .get_name = */ ggml_backend_cpu_hbm_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment,
/* .get_max_size = */ nullptr, // defaults to SIZE_MAX
/* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
/* .is_host = */ ggml_backend_cpu_buffer_type_is_host,
},
/* .context = */ nullptr,
};
return &ggml_backend_cpu_buffer_type_hbm;
}
#endif
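Usage is symmetric with the other CPU buffer types: when built with GGML_USE_CPU_HBM and linked against libmemkind, a caller can request high-bandwidth memory explicitly. A minimal hedged sketch (alloc_hbm is an illustrative name):

    #include "ggml-backend.h"
    #include "ggml-cpu-hbm.h"

    // allocate a 64 MiB buffer backed by hbw_posix_memalign()
    static ggml_backend_buffer_t alloc_hbm(void) {
        ggml_backend_buffer_type_t buft = ggml_backend_cpu_hbm_buffer_type();
        return ggml_backend_buft_alloc_buffer(buft, 64u * 1024 * 1024);
    }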
View file
@ -0,0 +1,8 @@
#pragma once
#include "ggml-backend.h"
#include "ggml.h"
// GGML CPU internal header
ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
View file
@ -0,0 +1,36 @@
#include "ggml-cpu-traits.h"
#include "ggml-backend-impl.h"
#include "ggml-backend.h"
namespace ggml::cpu {
tensor_traits::~tensor_traits() {}
extra_buffer_type::~extra_buffer_type() {}
} // namespace ggml::cpu
bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) {
for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
if (extra && extra->context) {
auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context;
auto tensor_traits = buf_extra->get_tensor_traits(op);
if (tensor_traits && tensor_traits->compute_forward(params, op)) {
return true;
}
}
}
return false;
}
bool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size) {
for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
if (extra && extra->context) {
auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context;
auto tensor_traits = buf_extra->get_tensor_traits(op);
if (tensor_traits && tensor_traits->work_size(n_threads, op, *size)) {
return true;
}
}
}
return false;
}
View file
@ -0,0 +1,38 @@
#pragma once
#include "ggml-backend-impl.h"
#include "ggml-cpu-impl.h"
#include "ggml.h"
#ifdef __cplusplus
# include <vector>
extern "C" {
#endif
// return true if the op is handled by an extra "accelerator"
bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op);
bool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size);
#ifdef __cplusplus
}
namespace ggml::cpu {
// register in tensor->extra
class tensor_traits {
public:
virtual ~tensor_traits();
virtual bool work_size(int n_threads, const struct ggml_tensor * op, size_t & size) = 0;
virtual bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) = 0;
};
class extra_buffer_type {
public:
virtual ~extra_buffer_type();
virtual bool supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) = 0;
virtual tensor_traits * get_tensor_traits(const struct ggml_tensor * op) = 0;
};
} // namespace ggml::cpu
// implemented in ggml-cpu.cpp.
std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffers_type();
#endif
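These two small interfaces are the whole extension surface: an accelerator implements tensor_traits for its per-tensor kernels and extra_buffer_type to advertise which ops it claims. A skeleton of a hypothetical, deliberately no-op implementation, sketched against the declarations above:

    namespace ggml::cpu::demo {

    class my_traits : public ggml::cpu::tensor_traits {
        bool work_size(int /*n_threads*/, const struct ggml_tensor * /*op*/, size_t & size) override {
            size = 0;     // no scratch buffer needed
            return true;
        }
        bool compute_forward(struct ggml_compute_params * /*params*/, struct ggml_tensor * /*op*/) override {
            return false; // decline, fall back to the generic CPU path
        }
    };

    class my_buffer_type : public ggml::cpu::extra_buffer_type {
        bool supports_op(ggml_backend_dev_t /*dev*/, const struct ggml_tensor * /*op*/) override {
            return false; // a real accelerator would check op->op, types and buffers here
        }
        tensor_traits * get_tensor_traits(const struct ggml_tensor * /*op*/) override {
            return nullptr; // a real one returns the traits stashed in tensor->extra
        }
    };

    } // namespace ggml::cpu::demo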
View file
@ -3,7 +3,7 @@
#include "ggml-backend-impl.h" #include "ggml-backend-impl.h"
#include "ggml-backend.h" #include "ggml-backend.h"
#include "ggml-cpu-aarch64.h" #include "ggml-cpu-traits.h"
#include "ggml-cpu-impl.h" #include "ggml-cpu-impl.h"
#include "ggml-cpu.h" #include "ggml-cpu.h"
#include "ggml-impl.h" #include "ggml-impl.h"
@ -227,10 +227,6 @@ typedef void * thread_ret_t;
typedef pthread_t ggml_thread_t; typedef pthread_t ggml_thread_t;
#ifdef GGML_USE_CPU_HBM
#include <hbwmalloc.h>
#endif
#if defined(__APPLE__) #if defined(__APPLE__)
#include <unistd.h> #include <unistd.h>
#include <mach/mach.h> #include <mach/mach.h>
@ -304,7 +300,6 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
}, },
[GGML_TYPE_Q8_0] = { [GGML_TYPE_Q8_0] = {
.from_float = quantize_row_q8_0, .from_float = quantize_row_q8_0,
.from_float_to_mat = quantize_mat_q8_0,
.vec_dot = ggml_vec_dot_q8_0_q8_0, .vec_dot = ggml_vec_dot_q8_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0, .vec_dot_type = GGML_TYPE_Q8_0,
#if defined (__ARM_FEATURE_MATMUL_INT8) #if defined (__ARM_FEATURE_MATMUL_INT8)
@ -412,33 +407,6 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_BF16, .vec_dot_type = GGML_TYPE_BF16,
.nrows = 1, .nrows = 1,
}, },
[GGML_TYPE_Q4_0_4_4] = {
.from_float = NULL,
.vec_dot = NULL,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
.ncols = 4,
.gemv = ggml_gemv_q4_0_4x4_q8_0,
.gemm = ggml_gemm_q4_0_4x4_q8_0,
},
[GGML_TYPE_Q4_0_4_8] = {
.from_float = NULL,
.vec_dot = NULL,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
.ncols = 4,
.gemv = ggml_gemv_q4_0_4x8_q8_0,
.gemm = ggml_gemm_q4_0_4x8_q8_0,
},
[GGML_TYPE_Q4_0_8_8] = {
.from_float = NULL,
.vec_dot = NULL,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
.ncols = 8,
.gemv = ggml_gemv_q4_0_8x8_q8_0,
.gemm = ggml_gemm_q4_0_8x8_q8_0,
},
[GGML_TYPE_TQ1_0] = { [GGML_TYPE_TQ1_0] = {
.from_float = quantize_row_tq1_0, .from_float = quantize_row_tq1_0,
.vec_dot = ggml_vec_dot_tq1_0_q8_K, .vec_dot = ggml_vec_dot_tq1_0_q8_K,
@ -451,15 +419,6 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1, .nrows = 1,
}, },
[GGML_TYPE_IQ4_NL_4_4] = {
.from_float = NULL,
.vec_dot = NULL,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
.ncols = 4,
.gemv = ggml_gemv_iq4_nl_4x4_q8_0,
.gemm = ggml_gemm_iq4_nl_4x4_q8_0,
},
}; };
const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) { const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
@ -4525,9 +4484,6 @@ static void ggml_compute_forward_add(
case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
{ {
ggml_compute_forward_add_q_f32(params, dst); ggml_compute_forward_add_q_f32(params, dst);
} break; } break;
@ -4905,9 +4861,6 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
{ {
ggml_compute_forward_add1_q_f32(params, dst); ggml_compute_forward_add1_q_f32(params, dst);
} break; } break;
@ -5035,9 +4988,6 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
default: default:
{ {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
@ -7465,27 +7415,9 @@ static void ggml_compute_forward_mul_mat(
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
enum ggml_type type = src0->type; enum ggml_type const vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
if (src0->buffer && ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) {
type = (enum ggml_type)(intptr_t)src0->extra;
}
#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
if (src0->buffer && ggml_backend_amx_buft_is_amx(src0->buffer->buft)) {
ggml_backend_amx_mul_mat(params, dst);
return;
}
#endif
enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float; ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
ggml_from_float_to_mat_t const from_float_to_mat = type_traits_cpu[vec_dot_type].from_float_to_mat; int64_t const vec_dot_num_rows = type_traits_cpu[src0->type].nrows;
int64_t const vec_dot_num_rows = type_traits_cpu[type].nrows;
int64_t const matmul_num_cols = type_traits_cpu[type].ncols;
int64_t const blck_size_interleave = ggml_get_type_traits(type)->blck_size_interleave;
ggml_gemv_t const gemv = type_traits_cpu[type].gemv;
ggml_gemm_t const gemm = type_traits_cpu[type].gemm;
GGML_ASSERT(ne0 == ne01); GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11); GGML_ASSERT(ne1 == ne11);
@ -7493,7 +7425,7 @@ static void ggml_compute_forward_mul_mat(
GGML_ASSERT(ne3 == ne13); GGML_ASSERT(ne3 == ne13);
// we don't support permuted src0 or src1 // we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type)); GGML_ASSERT(nb00 == ggml_type_size(src0->type));
GGML_ASSERT(nb10 == ggml_type_size(src1->type)); GGML_ASSERT(nb10 == ggml_type_size(src1->type));
// dst cannot be transposed or permuted // dst cannot be transposed or permuted
@ -7505,6 +7437,7 @@ static void ggml_compute_forward_mul_mat(
// nb01 >= nb00 - src0 is not transposed // nb01 >= nb00 - src0 is not transposed
// compute by src0 rows // compute by src0 rows
// TODO: extract to "extra_op"
#if defined(GGML_USE_CLBLAST) #if defined(GGML_USE_CLBLAST)
if (ggml_cl_can_mul_mat(src0, src1, dst)) { if (ggml_cl_can_mul_mat(src0, src1, dst)) {
if (params->ith == 0) { if (params->ith == 0) {
@ -7513,6 +7446,7 @@ static void ggml_compute_forward_mul_mat(
return; return;
} }
#endif #endif
#if GGML_USE_LLAMAFILE #if GGML_USE_LLAMAFILE
// broadcast factors // broadcast factors
const int64_t r2 = ne12 / ne02; const int64_t r2 = ne12 / ne02;
@ -7523,15 +7457,15 @@ static void ggml_compute_forward_mul_mat(
if (src1_cont) { if (src1_cont) {
for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i13 = 0; i13 < ne13; i13++)
for (int64_t i12 = 0; i12 < ne12; i12++) for (int64_t i12 = 0; i12 < ne12; i12++)
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type), if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
nb01/ggml_type_size(type), nb01/ggml_type_size(src0->type),
(const char *)src1->data + i12*nb12 + i13*nb13, (const char *)src1->data + i12*nb12 + i13*nb13,
nb11/ggml_type_size(src1->type), nb11/ggml_type_size(src1->type),
(char *)dst->data + i12*nb2 + i13*nb3, (char *)dst->data + i12*nb2 + i13*nb3,
nb1/ggml_type_size(dst->type), nb1/ggml_type_size(dst->type),
ith, nth, ith, nth,
type, src0->type,
src1->type, src1->type,
dst->type)) dst->type))
goto UseGgmlGemm1; goto UseGgmlGemm1;
@ -7552,19 +7486,10 @@ UseGgmlGemm1:;
for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i12 = 0; i12 < ne12; ++i12) {
int64_t i11_processed = 0; for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
4, ne10, blck_size_interleave);
}
i11_processed = ne11 - ne11 % 4;
}
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
ne10); ne10);
} }
} }
} }
@ -7584,15 +7509,15 @@ UseGgmlGemm1:;
for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i13 = 0; i13 < ne13; i13++)
for (int64_t i12 = 0; i12 < ne12; i12++) for (int64_t i12 = 0; i12 < ne12; i12++)
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type), if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
nb01/ggml_type_size(type), nb01/ggml_type_size(src0->type),
(const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
row_size/ggml_type_size(vec_dot_type), row_size/ggml_type_size(vec_dot_type),
(char *)dst->data + i12*nb2 + i13*nb3, (char *)dst->data + i12*nb2 + i13*nb3,
nb1/ggml_type_size(dst->type), nb1/ggml_type_size(dst->type),
ith, nth, ith, nth,
type, src0->type,
vec_dot_type, vec_dot_type,
dst->type)) dst->type))
goto UseGgmlGemm2; goto UseGgmlGemm2;
@ -7634,28 +7559,6 @@ UseGgmlGemm2:;
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
if ((ggml_n_dims(src0) == 2) && gemv) {
const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t src1_col_stride = ggml_is_contiguous(src1) || src1->type != vec_dot_type ? ggml_row_size(vec_dot_type, ne10) : nb11;
int64_t src0_start = (ith * ne01) / nth;
int64_t src0_end = ((ith + 1) * ne01) / nth;
src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start;
src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end;
if (src0_start >= src0_end) return;
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
if (gemm && (ne11 > 3)) {
gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
}
for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) {
gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
src0_end - src0_start);
}
return;
}
// The first chunk comes from our thread_id, the rest will get auto-assigned. // The first chunk comes from our thread_id, the rest will get auto-assigned.
int current_chunk = ith; int current_chunk = ith;
@ -7678,7 +7581,7 @@ UseGgmlGemm2:;
num_rows_per_vec_dot = 1; num_rows_per_vec_dot = 1;
} }
ggml_compute_forward_mul_mat_one_chunk(params, dst, type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
if (nth >= nchunk0 * nchunk1) { if (nth >= nchunk0 * nchunk1) {
break; break;
@ -7710,8 +7613,6 @@ static void ggml_compute_forward_mul_mat_id(
ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot; ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type; enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float; ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
int64_t const matmul_num_cols = type_traits_cpu[type].ncols;
ggml_gemv_t const gemv = type_traits_cpu[type].gemv;
// we don't support permuted src0 or src1 // we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type)); GGML_ASSERT(nb00 == ggml_type_size(type));
@ -7797,34 +7698,6 @@ static void ggml_compute_forward_mul_mat_id(
const int64_t nr0 = ne01; // src0 rows const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = cne1; // src1 rows const int64_t nr1 = cne1; // src1 rows
if (((ggml_n_dims(src0) - 1) == 2) && gemv) {
int64_t src0_cur_start = (ith * ne01) / nth;
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start;
src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end;
if (src0_cur_start >= src0_cur_end) return;
for (int ir1 = 0; ir1 < nr1; ir1++) {
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
const int id = row_mapping.i1; // selected expert index
const int64_t i11 = id % ne11;
const int64_t i12 = row_mapping.i2; // row index in src1
const int64_t i1 = id; // selected expert index
const int64_t i2 = i12; // row
const char * src1_col = (const char *) wdata +
(src1_cont || src1->type != vec_dot_type
? (i11 + i12 * ne11) * row_size
: (i11 * nb11 + i12 * nb12));
gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
(const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);
}
continue;
}
// distribute the thread work across the inner or outer loop based on which one is larger // distribute the thread work across the inner or outer loop based on which one is larger
const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
@ -8132,9 +8005,6 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
{ {
ggml_compute_forward_out_prod_q_f32(params, dst); ggml_compute_forward_out_prod_q_f32(params, dst);
} break; } break;
@ -8397,9 +8267,6 @@ static void ggml_compute_forward_set(
case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
default: default:
{ {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
@ -8661,9 +8528,6 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
{ {
ggml_compute_forward_get_rows_q(params, dst); ggml_compute_forward_get_rows_q(params, dst);
} break; } break;
@ -9253,10 +9117,6 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q8_K: case GGML_TYPE_Q8_K:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
case GGML_TYPE_IQ4_NL_4_4:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -12462,6 +12322,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
return; return;
} }
// extra_buffer op?
if (ggml_cpu_extra_compute_forward(params, tensor)) return;
switch (tensor->op) { switch (tensor->op) {
case GGML_OP_DUP: case GGML_OP_DUP:
{ {
@ -13409,151 +13272,148 @@ struct ggml_cplan ggml_graph_plan(
size_t cur = 0; size_t cur = 0;
switch (node->op) { if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {
case GGML_OP_CPY:
case GGML_OP_DUP: switch (node->op) {
{ case GGML_OP_CPY:
if (ggml_is_quantized(node->type) || case GGML_OP_DUP:
// F16 -> BF16 and BF16 -> F16 copies go through intermediate F32 {
(node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) || if (ggml_is_quantized(node->type) ||
(node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) { // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
} (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
} break; cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
case GGML_OP_ADD: }
case GGML_OP_ADD1: } break;
{ case GGML_OP_ADD:
if (ggml_is_quantized(node->src[0]->type)) { case GGML_OP_ADD1:
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; {
} if (ggml_is_quantized(node->src[0]->type)) {
} break; cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
case GGML_OP_ACC: }
{ } break;
if (ggml_is_quantized(node->src[0]->type)) { case GGML_OP_ACC:
cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks; {
} if (ggml_is_quantized(node->src[0]->type)) {
} break; cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
case GGML_OP_COUNT_EQUAL: }
{ } break;
cur = ggml_type_size(node->type)*n_tasks; case GGML_OP_COUNT_EQUAL:
} break; {
case GGML_OP_MUL_MAT: cur = ggml_type_size(node->type)*n_tasks;
{ } break;
#if defined(__AMX_INT8__) && defined(__AVX512VNNI__) case GGML_OP_MUL_MAT:
if (node->src[0]->buffer && ggml_backend_amx_buft_is_amx(node->src[0]->buffer->buft)) { {
cur = ggml_backend_amx_desired_wsize(node); const enum ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type;
}
#endif
const enum ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type;
#if defined(GGML_USE_CLBLAST) #if defined(GGML_USE_CLBLAST)
if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) { if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) {
cur = ggml_cl_mul_mat_get_wsize(node->src[0], node->src[1], node); cur = ggml_cl_mul_mat_get_wsize(node->src[0], node->src[1], node);
} else } else
#endif #endif
if (node->src[1]->type != vec_dot_type) { if (node->src[1]->type != vec_dot_type) {
size_t cur2 = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1])); size_t cur2 = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));
cur = MAX(cur, cur2); cur = MAX(cur, cur2);
} }
} break; } break;
case GGML_OP_MUL_MAT_ID: case GGML_OP_MUL_MAT_ID:
{ {
cur = 0; cur = 0;
const struct ggml_tensor * src0 = node->src[0]; const struct ggml_tensor * src0 = node->src[0];
const struct ggml_tensor * src1 = node->src[1]; const struct ggml_tensor * src1 = node->src[1];
const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type; const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
if (src1->type != vec_dot_type) { if (src1->type != vec_dot_type) {
cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)); cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
} }
const int n_as = src0->ne[2]; const int n_as = src0->ne[2];
cur += GGML_PAD(cur, sizeof(int64_t)); // align cur += GGML_PAD(cur, sizeof(int64_t)); // align
cur += n_as * sizeof(int64_t); // matrix_row_counts cur += n_as * sizeof(int64_t); // matrix_row_counts
cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
} break; } break;
case GGML_OP_OUT_PROD: case GGML_OP_OUT_PROD:
{ {
if (ggml_is_quantized(node->src[0]->type)) { if (ggml_is_quantized(node->src[0]->type)) {
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
} }
} break; } break;
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
case GGML_OP_ROPE: case GGML_OP_ROPE:
{ {
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
} break; } break;
case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_1D:
{ {
GGML_ASSERT(node->src[0]->ne[3] == 1); GGML_ASSERT(node->src[0]->ne[3] == 1);
GGML_ASSERT(node->src[1]->ne[2] == 1); GGML_ASSERT(node->src[1]->ne[2] == 1);
GGML_ASSERT(node->src[1]->ne[3] == 1); GGML_ASSERT(node->src[1]->ne[3] == 1);
const int64_t ne00 = node->src[0]->ne[0]; // K const int64_t ne00 = node->src[0]->ne[0]; // K
const int64_t ne01 = node->src[0]->ne[1]; // Cout const int64_t ne01 = node->src[0]->ne[1]; // Cout
const int64_t ne02 = node->src[0]->ne[2]; // Cin const int64_t ne02 = node->src[0]->ne[2]; // Cin
const int64_t ne10 = node->src[1]->ne[0]; // L
const int64_t ne11 = node->src[1]->ne[1]; // Cin
const int64_t ne10 = node->src[1]->ne[0]; // L if ((node->src[0]->type == GGML_TYPE_F16 ||
const int64_t ne11 = node->src[1]->ne[1]; // Cin node->src[0]->type == GGML_TYPE_BF16) &&
node->src[1]->type == GGML_TYPE_F32) {
cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
cur += sizeof(ggml_fp16_t)*ne10*ne11;
} else if (node->src[0]->type == GGML_TYPE_F32 &&
node->src[1]->type == GGML_TYPE_F32) {
cur += sizeof(float)*ne00*ne01*ne02;
cur += sizeof(float)*ne10*ne11;
} else {
GGML_ABORT("fatal error");
}
} break;
case GGML_OP_CONV_TRANSPOSE_2D:
{
const int64_t ne00 = node->src[0]->ne[0]; // W
const int64_t ne01 = node->src[0]->ne[1]; // H
const int64_t ne02 = node->src[0]->ne[2]; // Channels Out
const int64_t ne03 = node->src[0]->ne[3]; // Channels In
if ((node->src[0]->type == GGML_TYPE_F16 || const int64_t ne10 = node->src[1]->ne[0]; // W
node->src[0]->type == GGML_TYPE_BF16) && const int64_t ne11 = node->src[1]->ne[1]; // H
node->src[1]->type == GGML_TYPE_F32) { const int64_t ne12 = node->src[1]->ne[2]; // Channels In
cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
cur += sizeof(ggml_fp16_t)*ne10*ne11; cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
} else if (node->src[0]->type == GGML_TYPE_F32 && cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
node->src[1]->type == GGML_TYPE_F32) { } break;
cur += sizeof(float)*ne00*ne01*ne02; case GGML_OP_FLASH_ATTN_EXT:
cur += sizeof(float)*ne10*ne11; {
} else { const int64_t ne00 = node->src[0]->ne[0]; // D
cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
const int64_t D = node->src[0]->ne[0];
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
if (node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
} else if (node->src[1]->type == GGML_TYPE_F16) {
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
} else if (node->src[1]->type == GGML_TYPE_BF16) {
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
}
} break;
case GGML_OP_CROSS_ENTROPY_LOSS:
{
cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
} break;
case GGML_OP_COUNT:
{
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
} break; default:
case GGML_OP_CONV_TRANSPOSE_2D: break;
{ }
const int64_t ne00 = node->src[0]->ne[0]; // W
const int64_t ne01 = node->src[0]->ne[1]; // H
const int64_t ne02 = node->src[0]->ne[2]; // Channels Out
const int64_t ne03 = node->src[0]->ne[3]; // Channels In
const int64_t ne10 = node->src[1]->ne[0]; // W
const int64_t ne11 = node->src[1]->ne[1]; // H
const int64_t ne12 = node->src[1]->ne[2]; // Channels In
cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
const int64_t ne00 = node->src[0]->ne[0]; // D
cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
const int64_t D = node->src[0]->ne[0];
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
if (node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
} else if (node->src[1]->type == GGML_TYPE_F16) {
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
} else if (node->src[1]->type == GGML_TYPE_BF16) {
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
}
} break;
case GGML_OP_CROSS_ENTROPY_LOSS:
{
cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
} break;
case GGML_OP_COUNT:
{
GGML_ABORT("fatal error");
}
default:
break;
} }
work_size = MAX(work_size, cur); work_size = MAX(work_size, cur);
View file
@ -2,12 +2,18 @@
#include "ggml-backend-impl.h" #include "ggml-backend-impl.h"
#include "ggml-cpu.h" #include "ggml-cpu.h"
#include "ggml-cpu-aarch64.h" #include "ggml-cpu-aarch64.h"
#include "ggml-cpu-traits.h"
#include "ggml-impl.h" #include "ggml-impl.h"
#include "amx/amx.h" #include "amx/amx.h"
#include <cctype> #include <cctype>
#include <string> #include <string>
#include <vector> #include <vector>
#ifdef GGML_USE_CPU_HBM
#include "ggml-cpu-hbm.h"
#endif
#if defined(__APPLE__) #if defined(__APPLE__)
#include <sys/types.h> #include <sys/types.h>
#include <sys/sysctl.h> #include <sys/sysctl.h>
@ -23,115 +29,7 @@
// ggml-backend interface // ggml-backend interface
#ifdef GGML_USE_CPU_HBM std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type() {
// buffer type HBM
#include <hbwmalloc.h>
static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
return "CPU_HBM";
GGML_UNUSED(buft);
}
static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) {
hbw_free(buffer->context);
}
static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
void * ptr;
int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size);
if (result != 0) {
GGML_LOG_ERROR("failed to allocate HBM buffer of size %zu\n", size);
return NULL;
}
ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
buffer->buft = buft;
buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer;
return buffer;
}
ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = {
/* .iface = */ {
/* .get_name = */ ggml_backend_cpu_hbm_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment,
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
/* .is_host = */ ggml_backend_cpu_buffer_type_is_host,
},
/* .context = */ NULL,
};
return &ggml_backend_cpu_buffer_type_hbm;
}
#endif
// buffer type AARCH64
static void ggml_backend_cpu_aarch64_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
tensor->extra = (void *)ggml_aarch64_get_optimal_repack_type(tensor); // NOLINT
GGML_UNUSED(buffer);
}
static void ggml_backend_cpu_aarch64_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
GGML_ASSERT(offset == 0);
GGML_ASSERT(size == ggml_nbytes(tensor));
enum ggml_type repack_type = (enum ggml_type)(intptr_t)tensor->extra;
ggml_aarch64_repack_tensor(tensor, repack_type, data, size);
GGML_UNUSED(buffer);
}
static const char * ggml_backend_cpu_aarch64_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
return "CPU_AARCH64";
GGML_UNUSED(buft);
}
static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
auto * buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
if (buffer == NULL) {
return NULL;
}
buffer->buft = buft;
buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor;
buffer->iface.set_tensor = ggml_backend_cpu_aarch64_buffer_set_tensor;
return buffer;
}
ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void) {
static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_aarch64 = {
/* .iface = */ {
/* .get_name = */ ggml_backend_cpu_aarch64_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_cpu_aarch64_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
/* .is_host = */ NULL,
},
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
/* .context = */ NULL,
};
return &ggml_backend_cpu_buffer_type_aarch64;
}
bool ggml_backend_cpu_buft_is_aarch64(ggml_backend_buffer_type_t buft) {
return buft == ggml_backend_cpu_aarch64_buffer_type();
}
static ggml_backend_buffer_type_t * ggml_backend_cpu_get_extra_bufts(ggml_backend_dev_t device) {
static std::vector<ggml_backend_buffer_type_t> bufts = []() { static std::vector<ggml_backend_buffer_type_t> bufts = []() {
std::vector<ggml_backend_buffer_type_t> bufts; std::vector<ggml_backend_buffer_type_t> bufts;
@ -152,11 +50,22 @@ static ggml_backend_buffer_type_t * ggml_backend_cpu_get_extra_bufts(ggml_backen
return bufts; return bufts;
}(); }();
return bufts.data(); return bufts;
}
static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_type(ggml_backend_dev_t device) {
return ggml_backend_cpu_get_extra_buffers_type().data();
GGML_UNUSED(device); GGML_UNUSED(device);
} }
static bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) {
for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
if (extra && extra == buft) return true;
}
return false;
}
// CPU backend - backend (stream) // CPU backend - backend (stream)
struct ggml_backend_cpu_context { struct ggml_backend_cpu_context {
@ -465,25 +374,19 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
return true; return true;
} }
if (src0 && src0->buffer && ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) { // extra_buffer_op?
if (op->op != GGML_OP_MUL_MAT || src0->type == ggml_aarch64_get_optimal_repack_type(src0)) { for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
return false; if (extra) {
auto buf_extra = (ggml::cpu::extra_buffer_type*) extra->context;
if (buf_extra && buf_extra->supports_op(dev, op)) {
return true;
}
} }
} }
#if defined(__AMX_INT8__) && defined(__AVX512VNNI__) // the other cases need a host buffer.
if (src0 && src0->buffer && ggml_backend_amx_buft_is_amx(src0->buffer->buft)) { for (int i = 0; i < GGML_MAX_SRC; i++) {
return ggml_backend_amx_device_supports_op(op); if (op->src[i] && op->src[i]->buffer && !ggml_backend_buft_is_host(op->src[i]->buffer->buft)) {
}
for (int i = 1; i < GGML_MAX_SRC; i++) {
if (op->src[i] && op->src[i]->buffer && ggml_backend_amx_buft_is_amx(op->src[i]->buffer->buft)) {
return false;
}
}
#endif
for (int i = 1; i < GGML_MAX_SRC; i++) {
if (op->src[i] && op->src[i]->buffer && ggml_backend_cpu_buft_is_aarch64(op->src[i]->buffer->buft)) {
return false; return false;
} }
} }
@ -506,19 +409,10 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
default: default:
return true; return true;
} }
GGML_UNUSED(dev);
} }
static bool ggml_backend_cpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { static bool ggml_backend_cpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
bool supported = ggml_backend_buft_is_host(buft) || ggml_backend_cpu_buft_is_aarch64(buft); return ggml_backend_buft_is_host(buft) || ggml_backend_cpu_is_extra_buffer_type(buft);
#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
supported = supported || ggml_backend_amx_buft_is_amx(buft);
#endif
return supported;
GGML_UNUSED(dev); GGML_UNUSED(dev);
} }
@ -666,10 +560,12 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const char * name) { static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const char * name) {
if (strcmp(name, "ggml_backend_set_n_threads") == 0) { if (strcmp(name, "ggml_backend_set_n_threads") == 0) {
return (void *)ggml_backend_cpu_set_n_threads; ggml_backend_set_n_threads_t fct = ggml_backend_cpu_set_n_threads;
return (void *)fct;
} }
if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) { if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) {
return (void *)ggml_backend_cpu_get_extra_bufts; ggml_backend_dev_get_extra_bufts_t fct = ggml_backend_cpu_device_get_extra_buffers_type;
return (void *)fct;
} }
if (strcmp(name, "ggml_backend_get_features") == 0) { if (strcmp(name, "ggml_backend_get_features") == 0) {
return (void *)ggml_backend_cpu_get_features; return (void *)ggml_backend_cpu_get_features;
View file
@ -41,28 +41,28 @@
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed) #define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons #define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
#define CC_PASCAL 600 #define GGML_CUDA_CC_PASCAL 600
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products #define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define CC_VOLTA 700 #define GGML_CUDA_CC_VOLTA 700
#define CC_TURING 750 #define GGML_CUDA_CC_TURING 750
#define CC_AMPERE 800 #define GGML_CUDA_CC_AMPERE 800
#define CC_OFFSET_AMD 1000000 #define GGML_CUDA_CC_OFFSET_AMD 1000000
// GCN/CDNA, wave size is 64 // GCN/CDNA, wave size is 64
#define CC_GCN4 (CC_OFFSET_AMD + 803) // Tonga, Fiji, Polaris, minimum for fast fp16 #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 803) // Tonga, Fiji, Polaris, minimum for fast fp16
#define CC_VEGA (CC_OFFSET_AMD + 900) // Vega56/64, minimum for fp16 dual issue #define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 900) // Vega56/64, minimum for fp16 dual issue
#define CC_VEGA20 (CC_OFFSET_AMD + 906) // MI50/Radeon VII, minimum for dp4a #define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 906) // MI50/Radeon VII, minimum for dp4a
#define CC_CDNA (CC_OFFSET_AMD + 908) // MI100, minimum for MFMA, acc registers #define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 908) // MI100, minimum for MFMA, acc registers
#define CC_CDNA2 (CC_OFFSET_AMD + 910) // MI210, minimum acc register renaming #define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 910) // MI210, minimum acc register renaming
#define CC_CDNA3 (CC_OFFSET_AMD + 942) // MI300 #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 942) // MI300
// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32 // RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32
#define CC_RDNA1 (CC_OFFSET_AMD + 1010) // RX 5000 #define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 1010) // RX 5000
#define CC_RDNA2 (CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a #define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a
#define CC_RDNA3 (CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA #define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA
#define CC_QY1 210 #define GGML_CUDA_CC_QY1 210
#define CC_QY2 220 #define GGML_CUDA_CC_QY2 220
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
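The GGML_CUDA_ prefix on all compute-capability macros exists to avoid collisions with identically named symbols in third-party headers; the CUB include further down previously had to be ordered around exactly such a CC_PASCAL clash. A toy reproduction of the failure mode, not code from the tree:

#define CC_PASCAL 600              // old, unprefixed project macro
// #include <some_vendor_header.h> // hypothetically also defines CC_PASCAL:
// #define CC_PASCAL 1             //   -> redefinition, or silent meaning change
#define GGML_CUDA_CC_PASCAL 600    // namespaced macro cannot collide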
@ -131,36 +131,36 @@ typedef float dfloat; // dequantize float
typedef float2 dfloat2; typedef float2 dfloat2;
#endif // GGML_CUDA_F16 #endif // GGML_CUDA_F16
#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#define FP16_AVAILABLE #define FP16_AVAILABLE
#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
#define FAST_FP16_AVAILABLE #define FAST_FP16_AVAILABLE
#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#define FP16_MMA_AVAILABLE #define FP16_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define INT8_MMA_AVAILABLE #define INT8_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1) #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
#define FLASH_ATTN_AVAILABLE #define FLASH_ATTN_AVAILABLE
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1) #endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
static constexpr bool fast_fp16_available(const int cc) { static constexpr bool fast_fp16_available(const int cc) {
return cc >= CC_PASCAL && cc != 610; return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
} }
static constexpr bool fp16_mma_available(const int cc) { static constexpr bool fp16_mma_available(const int cc) {
return cc < CC_OFFSET_AMD && cc >= CC_VOLTA; return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
} }
static constexpr bool int8_mma_available(const int cc) { static constexpr bool int8_mma_available(const int cc) {
return cc < CC_OFFSET_AMD && cc >= CC_TURING; return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
} }
[[noreturn]] [[noreturn]]
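The constexpr helpers above centralize the architecture checks so call sites read as intent rather than raw numbers. A hedged usage sketch; dispatch() and the branch bodies are hypothetical:

static void dispatch(const int cc) {
    if (int8_mma_available(cc)) {
        // Turing+ NVIDIA: int8 tensor-core path
    } else if (fast_fp16_available(cc)) {
        // Pascal+ (except 610): fp16 arithmetic path
    } else {
        // generic fp32 fallback
    }
}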
@ -187,7 +187,7 @@ static __device__ void no_device_code(
#endif // __CUDA_ARCH__ #endif // __CUDA_ARCH__
static __device__ __forceinline__ int warp_reduce_sum(int x) { static __device__ __forceinline__ int warp_reduce_sum(int x) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
return __reduce_add_sync(0xffffffff, x); return __reduce_add_sync(0xffffffff, x);
#else #else
#pragma unroll #pragma unroll
@ -195,7 +195,7 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {
x += __shfl_xor_sync(0xffffffff, x, offset, 32); x += __shfl_xor_sync(0xffffffff, x, offset, 32);
} }
return x; return x;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
} }
static __device__ __forceinline__ float warp_reduce_sum(float x) { static __device__ __forceinline__ float warp_reduce_sum(float x) {
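On Ampere and newer, __reduce_add_sync performs the whole warp reduction in one instruction; the fallback is a butterfly of __shfl_xor_sync steps. A self-contained sketch of the fallback pattern, as a standalone demo kernel:

// After 5 XOR-shuffle steps, every lane of the warp holds the sum of all
// 32 lane values (a butterfly reduction).
__global__ void warp_sum_demo(const int * in, int * out) {
    int x = in[threadIdx.x];                        // one value per lane
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
    }
    if (threadIdx.x == 0) {
        *out = x;                                   // lane 0 writes the total
    }
}
// launch: warp_sum_demo<<<1, 32>>>(d_in, d_out);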
@ -284,7 +284,7 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
} }
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#pragma unroll #pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) { for (int offset = 16; offset > 0; offset >>= 1) {
x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32)); x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
@ -293,7 +293,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#else #else
GGML_UNUSED(x); GGML_UNUSED(x);
NO_DEVICE_CODE; NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
} }
#if CUDART_VERSION < CUDART_HMASK #if CUDART_VERSION < CUDART_HMASK
@ -333,13 +333,13 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#if __CUDA_ARCH__ >= MIN_CC_DP4A #if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
return __dp4a(a, b, c); return __dp4a(a, b, c);
#else // __CUDA_ARCH__ >= MIN_CC_DP4A #else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
const int8_t * a8 = (const int8_t *) &a; const int8_t * a8 = (const int8_t *) &a;
const int8_t * b8 = (const int8_t *) &b; const int8_t * b8 = (const int8_t *) &b;
return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3]; return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
} }
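For reference, __dp4a treats each 32-bit operand as four packed signed bytes, and the scalar fallback above computes the same thing. A worked example:

// a packs bytes {1, 2, 3, 4}, b packs bytes {5, 6, 7, 8} (little-endian):
// dp4a(a, b, 10) = 10 + 1*5 + 2*6 + 3*7 + 4*8 = 10 + 5 + 12 + 21 + 32 = 80
const int a = 0x04030201;
const int b = 0x08070605;
const int c = ggml_cuda_dp4a(a, b, 10);   // 80 on every code path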

View file

@ -26,7 +26,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
template <bool need_check> template <bool need_check>
static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) { static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
#if __CUDA_ARCH__ >= CC_PASCAL #if __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE; constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
const int64_t i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x; const int64_t i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
@ -64,7 +64,7 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
GGML_UNUSED(y); GGML_UNUSED(y);
GGML_UNUSED(k); GGML_UNUSED(k);
NO_DEVICE_CODE; NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= CC_PASCAL #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
} }
template<typename dst_t> template<typename dst_t>
@ -599,7 +599,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>; return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) { if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_PASCAL) {
return dequantize_block_q8_0_f16_cuda; return dequantize_block_q8_0_f16_cuda;
} }
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>; return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;

View file

@ -304,7 +304,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
// On AMD the tile kernels perform poorly, use the vec kernel instead: // On AMD the tile kernels perform poorly, use the vec kernel instead:
if (cc >= CC_OFFSET_AMD) { if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
} else { } else {

View file

@ -180,7 +180,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.devices[id].smpb = prop.sharedMemPerBlock; info.devices[id].smpb = prop.sharedMemPerBlock;
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
info.devices[id].smpbo = prop.sharedMemPerBlock; info.devices[id].smpbo = prop.sharedMemPerBlock;
info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD; info.devices[id].cc = 100*prop.major + 10*prop.minor + GGML_CUDA_CC_OFFSET_AMD;
#else #else
info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = 100*prop.major + 10*prop.minor; info.devices[id].cc = 100*prop.major + 10*prop.minor;
@ -1082,7 +1082,7 @@ static void ggml_cuda_op_mul_mat_cublas(
const int compute_capability = ggml_cuda_info().devices[id].cc; const int compute_capability = ggml_cuda_info().devices[id].cc;
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { if (compute_capability >= GGML_CUDA_CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id)); ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
if (src0->type != GGML_TYPE_F16) { if (src0->type != GGML_TYPE_F16) {
@ -1109,7 +1109,7 @@ static void ggml_cuda_op_mul_mat_cublas(
const half beta_f16 = 0.0f; const half beta_f16 = 0.0f;
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F; cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) { if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
cu_compute_type = CUBLAS_COMPUTE_32F; cu_compute_type = CUBLAS_COMPUTE_32F;
} }
@ -1613,7 +1613,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F; cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
cudaDataType_t cu_data_type = CUDA_R_16F; cudaDataType_t cu_data_type = CUDA_R_16F;
if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) { if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
cu_compute_type = CUBLAS_COMPUTE_32F; cu_compute_type = CUBLAS_COMPUTE_32F;
} }
@ -2362,7 +2362,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
std::vector<void *> ggml_cuda_cpy_fn_ptrs; std::vector<void *> ggml_cuda_cpy_fn_ptrs;
if (cuda_ctx->cuda_graph->graph == nullptr) { if (cuda_ctx->cuda_graph->graph == nullptr) {
if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) { if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
#ifndef NDEBUG #ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
@ -3033,7 +3033,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return true; return true;
} }
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc; const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; return cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
} }
case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS:
case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
@ -3216,7 +3216,7 @@ static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, con
static const ggml_backend_reg_i ggml_backend_cuda_reg_interface = { static const ggml_backend_reg_i ggml_backend_cuda_reg_interface = {
/* .get_name = */ ggml_backend_cuda_reg_get_name, /* .get_name = */ ggml_backend_cuda_reg_get_name,
/* .get_device_count = */ ggml_backend_cuda_reg_get_device_count, /* .get_device_count = */ ggml_backend_cuda_reg_get_device_count,
/* .get_device_get = */ ggml_backend_cuda_reg_get_device, /* .get_device = */ ggml_backend_cuda_reg_get_device,
/* .get_proc_address = */ ggml_backend_cuda_reg_get_proc_address, /* .get_proc_address = */ ggml_backend_cuda_reg_get_proc_address,
}; };

View file

@ -171,7 +171,7 @@ struct mma_int_C_I16J8 {
__device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) { __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
#ifdef INT8_MMA_AVAILABLE #ifdef INT8_MMA_AVAILABLE
#if __CUDA_ARCH__ >= CC_AMPERE #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0])); : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
@ -183,7 +183,7 @@ struct mma_int_C_I16J8 {
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[2]), "+r"(x[3]) : "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[1]), "r"(mma_B.x[0])); : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
#endif // __CUDA_ARCH__ >= CC_AMPERE #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else #else
GGML_UNUSED(mma_A); GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B); GGML_UNUSED(mma_B);
@ -193,7 +193,7 @@ struct mma_int_C_I16J8 {
__device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) { __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
#ifdef INT8_MMA_AVAILABLE #ifdef INT8_MMA_AVAILABLE
#if __CUDA_ARCH__ >= CC_AMPERE #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1])); : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
@ -211,7 +211,7 @@ struct mma_int_C_I16J8 {
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[2]), "+r"(x[3]) : "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[3]), "r"(mma_B.x[1])); : "r"(mma_A.x[3]), "r"(mma_B.x[1]));
#endif // __CUDA_ARCH__ >= CC_AMPERE #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else #else
GGML_UNUSED(mma_A); GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B); GGML_UNUSED(mma_B);

View file

@ -27,7 +27,7 @@ void ggml_cuda_op_mul_mat_q(
// The stream-k decomposition is only faster for recent NVIDIA GPUs. // The stream-k decomposition is only faster for recent NVIDIA GPUs.
// Also its fixup needs to allocate a temporary buffer in the memory pool. // Also its fixup needs to allocate a temporary buffer in the memory pool.
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer. // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
const bool use_stream_k = compute_capability >= CC_VOLTA && compute_capability < CC_OFFSET_AMD && src1_ncols == ne11; const bool use_stream_k = compute_capability >= GGML_CUDA_CC_VOLTA && compute_capability < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k}; const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
switch (src0->type) { switch (src0->type) {
@ -138,13 +138,17 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return true; return true;
} }
if (cc < MIN_CC_DP4A) { if (cc < GGML_CUDA_CC_DP4A) {
return false; return false;
} }
if (cc < CC_OFFSET_AMD) { #ifdef GGML_CUDA_FORCE_MMQ
return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; return true;
#endif //GGML_CUDA_FORCE_MMQ
if (cc < GGML_CUDA_CC_OFFSET_AMD) {
return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
} }
return (cc < CC_RDNA3 && cc != CC_CDNA && cc != CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; return (cc < GGML_CUDA_CC_RDNA3 && cc != GGML_CUDA_CC_CDNA && cc != GGML_CUDA_CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
} }
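The new GGML_CUDA_FORCE_MMQ early return makes the build flag win over the batch-size heuristics (after the dp4a gate). A hedged usage sketch, using only names that appear in this diff:

// Gate the kernel choice on the heuristic above (illustrative fragment):
const int cc = ggml_cuda_info().devices[ctx.device].cc;
if (ggml_cuda_should_use_mmq(src0->type, cc, ne11)) {
    // quantized matrix multiplication kernel (MMQ)
} else {
    // dequantize + cuBLAS path
}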

View file

@ -90,9 +90,9 @@ struct tile_x_sizes {
static constexpr int get_mmq_x_max_host(const int cc) { static constexpr int get_mmq_x_max_host(const int cc) {
return int8_mma_available(cc) ? 128 : return int8_mma_available(cc) ? 128 :
#ifdef GGML_CUDA_FORCE_MMQ #ifdef GGML_CUDA_FORCE_MMQ
cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64; cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128 : 64;
#else #else
cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64; cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
#endif // GGML_CUDA_FORCE_MMQ #endif // GGML_CUDA_FORCE_MMQ
} }
@ -105,23 +105,23 @@ static constexpr __device__ int get_mmq_x_max_device() {
return 128; return 128;
#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#if __CUDA_ARCH__ >= CC_VOLTA #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#ifdef GGML_CUDA_FORCE_MMQ #ifdef GGML_CUDA_FORCE_MMQ
return MMQ_DP4A_MAX_BATCH_SIZE; return MMQ_DP4A_MAX_BATCH_SIZE;
#else // GGML_CUDA_FORCE_MMQ #else // GGML_CUDA_FORCE_MMQ
return 128; return 128;
#endif // GGML_CUDA_FORCE_MMQ #endif // GGML_CUDA_FORCE_MMQ
#else // __CUDA_ARCH__ >= CC_VOLTA #else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
return 64; return 64;
#endif // __CUDA_ARCH__ >= CC_VOLTA #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#endif // INT8_MMA_AVAILABLE #endif // INT8_MMA_AVAILABLE
} }
static constexpr int get_mmq_y_host(const int cc) { static constexpr int get_mmq_y_host(const int cc) {
return cc >= CC_OFFSET_AMD ? (cc == CC_RDNA1 ? 64 : 128) : (cc >= CC_VOLTA ? 128 : 64); return cc >= GGML_CUDA_CC_OFFSET_AMD ? (cc == GGML_CUDA_CC_RDNA1 ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64);
} }
static constexpr __device__ int get_mmq_y_device() { static constexpr __device__ int get_mmq_y_device() {
@ -132,11 +132,11 @@ static constexpr __device__ int get_mmq_y_device() {
return 128; return 128;
#endif // defined RDNA1 #endif // defined RDNA1
#else #else
#if __CUDA_ARCH__ >= CC_VOLTA #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
return 128; return 128;
#else #else
return 64; return 64;
#endif // __CUDA_ARCH__ >= CC_VOLTA #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
} }
@ -2575,11 +2575,11 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
__launch_bounds__(WARP_SIZE*nwarps, 2) __launch_bounds__(WARP_SIZE*nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) #endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
#else #else
#if __CUDA_ARCH__ >= CC_VOLTA #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
__launch_bounds__(WARP_SIZE*nwarps, 1) __launch_bounds__(WARP_SIZE*nwarps, 1)
#else #else
__launch_bounds__(WARP_SIZE*nwarps, 2) __launch_bounds__(WARP_SIZE*nwarps, 2)
#endif // __CUDA_ARCH__ >= CC_VOLTA #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
static __global__ void mul_mat_q( static __global__ void mul_mat_q(
const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup, const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
@ -2595,7 +2595,7 @@ static __global__ void mul_mat_q(
constexpr int mmq_y = get_mmq_y_device(); constexpr int mmq_y = get_mmq_y_device();
// On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
{ {
constexpr bool fixup = false; constexpr bool fixup = false;
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup> mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
@ -2603,7 +2603,7 @@ static __global__ void mul_mat_q(
blockIdx.x, blockIdx.y, 0, ne00/qk); blockIdx.x, blockIdx.y, 0, ne00/qk);
return; return;
} }
#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
const int64_t blocks_per_ne00 = ne00 / qk; const int64_t blocks_per_ne00 = ne00 / qk;
constexpr int blocks_per_iter = MMQ_ITER_K / qk; constexpr int blocks_per_iter = MMQ_ITER_K / qk;
@ -2826,7 +2826,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
const int mmq_x_max = get_mmq_x_max_host(cc); const int mmq_x_max = get_mmq_x_max_host(cc);
const int mmq_y = get_mmq_y_host(cc); const int mmq_y = get_mmq_y_host(cc);
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y; const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD; const bool use_stream_k = cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
int mmq_x_best = 0; int mmq_x_best = 0;
int nparts_best = INT_MAX; int nparts_best = INT_MAX;

View file

@ -57,7 +57,7 @@ static __global__ void mul_mat_vec(
if (block_size > WARP_SIZE) { if (block_size > WARP_SIZE) {
buf_iw[tid/WARP_SIZE] = sumf; buf_iw[tid/WARP_SIZE] = sumf;
__syncthreads(); __syncthreads();
if (tid > WARP_SIZE) { if (tid >= WARP_SIZE) {
return; return;
} }
sumf = buf_iw[tid]; sumf = buf_iw[tid];
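The `>` to `>=` change is an off-by-one fix: once the per-warp partial sums are staged in buf_iw, only threads 0..WARP_SIZE-1 may continue, so thread WARP_SIZE itself must also return instead of reading past the staged slots. A hedged standalone sketch of the two-stage pattern:

// Two-stage block reduction (assumes blockDim.x is a multiple of 32 and
// at most 32 warps). Stage 1: each warp reduces; stage 2: warp 0 finishes.
__device__ float block_reduce_sum(float v) {
    __shared__ float buf[32];                  // one slot per warp
    const int tid = threadIdx.x;
    for (int offset = 16; offset > 0; offset >>= 1) {
        v += __shfl_xor_sync(0xffffffff, v, offset, 32);
    }
    if (blockDim.x > 32) {
        if (tid % 32 == 0) {
            buf[tid / 32] = v;                 // warp leaders stage partials
        }
        __syncthreads();
        if (tid >= 32) {                       // '>' here would let thread 32
            return 0.0f;                       // read a stale/foreign slot
        }
        v = (tid < blockDim.x / 32) ? buf[tid] : 0.0f;
        for (int offset = 16; offset > 0; offset >>= 1) {
            v += __shfl_xor_sync(0xffffffff, v, offset, 32);
        }
    }
    return v;                                  // total is valid in thread 0
}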

View file

@ -142,7 +142,7 @@ static void mul_mat_vec_q_cuda(
int64_t nwarps = 1; int64_t nwarps = 1;
int64_t rows_per_cuda_block = 1; int64_t rows_per_cuda_block = 1;
if (ggml_cuda_info().devices[id].cc < CC_CDNA || ggml_cuda_info().devices[id].cc == CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_CDNA || ggml_cuda_info().devices[id].cc == GGML_CUDA_CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA
switch(ncols_y) { switch(ncols_y) {
case 1: case 1:
nwarps = 4; nwarps = 4;

View file

@ -3,8 +3,6 @@
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700 #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
#ifdef USE_CUB #ifdef USE_CUB
// On Windows CUB uses libraries with variables called CC_PASCAL which conflict with the define in common.cuh.
// For this reason CUB must be included BEFORE anything else.
#include <cub/cub.cuh> #include <cub/cub.cuh>
using namespace cub; using namespace cub;
#endif // USE_CUB #endif // USE_CUB

View file

@ -5220,15 +5220,6 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{ {
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
} break; } break;
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
{
VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x4, data, nbytes / sizeof(block_q4_0x4), 4);
} break;
case GGML_TYPE_Q4_0_8_8:
{
VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x8, data, nbytes / sizeof(block_q4_0x8), 8);
} break;
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:

View file

@ -4630,7 +4630,7 @@ static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, cons
static const ggml_backend_reg_i ggml_backend_sycl_reg_interface = { static const ggml_backend_reg_i ggml_backend_sycl_reg_interface = {
/* .get_name = */ ggml_backend_sycl_reg_get_name, /* .get_name = */ ggml_backend_sycl_reg_get_name,
/* .get_device_count = */ ggml_backend_sycl_reg_get_device_count, /* .get_device_count = */ ggml_backend_sycl_reg_get_device_count,
/* .get_device_get = */ ggml_backend_sycl_reg_get_device, /* .get_device = */ ggml_backend_sycl_reg_get_device,
/* .get_proc_address = */ ggml_backend_sycl_reg_get_proc_address, /* .get_proc_address = */ ggml_backend_sycl_reg_get_proc_address,
}; };

File diff suppressed because it is too large

View file

@ -1,6 +1,11 @@
#version 450 #version 450
#extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_spirv_intrinsics: enable
#if RTE16
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
#endif
layout (push_constant) uniform parameter layout (push_constant) uniform parameter
{ {

View file

@ -2,8 +2,6 @@
#extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_8bit_storage : require #extension GL_EXT_shader_8bit_storage : require
#define K_QUANTS_PER_ITERATION 2
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
#define EXPERT_COUNT 8 #define EXPERT_COUNT 8
#endif #endif

View file

@ -3,9 +3,11 @@
#include "mul_mat_vec_base.comp" #include "mul_mat_vec_base.comp"
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32]; layout (constant_id = 0) const uint BLOCK_SIZE = 32;
shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() { void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z; const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
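Replacing the hard-coded local_size_x = 32 with local_size_x_id and a constant_id lets the host pick the workgroup size per device at pipeline-creation time, instead of compiling one shader variant per size. A hedged host-side sketch in plain Vulkan (the actual loader in ggml-vulkan differs):

#include <vulkan/vulkan.h>

// Bind specialization constant 0 (BLOCK_SIZE) for pipeline creation.
VkSpecializationInfo make_block_size_info(const uint32_t * block_size) {
    static const VkSpecializationMapEntry entry = {
        /* constantID = */ 0,                  // matches constant_id = 0
        /* offset     = */ 0,
        /* size       = */ sizeof(uint32_t),
    };
    VkSpecializationInfo info = {};
    info.mapEntryCount = 1;
    info.pMapEntries   = &entry;
    info.dataSize      = sizeof(uint32_t);
    info.pData         = block_size;           // caller keeps this alive
    return info;
}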
@ -20,22 +22,25 @@ void main() {
const uint num_blocks_per_row = p.ncols / QUANT_K; const uint num_blocks_per_row = p.ncols / QUANT_K;
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 // 16 threads are used to process each block
const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...15
const uint ix = tid/16;
const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 const uint step = 8;
const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = tid - step*v_im; // 0...15 or 0...7 const uint v_in = itid - step*v_im; // 0...15 or 0...7
const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15 const uint l0 = 2*v_in; // 0...15
const uint q_offset = 32*v_im + l0; const uint q_offset = 32*v_im + l0;
const uint s_offset = 8*v_im; const uint s_offset = 8*v_im;
const uint y_offset = 128*v_im + l0; const uint y_offset = 128*v_im + l0;
FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y_idx = i * QUANT_K + y_offset; const uint y_idx = i * QUANT_K + y_offset;
f16vec2 d = data_a[ib0 + i].d; f16vec2 d = data_a[ib0 + i].d;
@ -71,7 +76,7 @@ void main() {
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { [[unroll]] for (int l = 0; l < 2; ++l) {
sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3), sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3), fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3), fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
@ -96,7 +101,7 @@ void main() {
// sum up partial sums and write back result // sum up partial sums and write back result
barrier(); barrier();
[[unroll]] for (uint s = 16; s > 0; s >>= 1) { [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
if (tid < s) { if (tid < s) {
tmp[tid] += tmp[tid + s]; tmp[tid] += tmp[tid + s];
} }
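The reduction loop's fixed `s = 16` start is generalized to `gl_WorkGroupSize.x/2` so the same tree reduction works for any specialized block size. The equivalent CUDA-side pattern, as a self-contained sketch (assumes a power-of-two block size):

__global__ void tree_reduce(float * out, const float * in) {
    extern __shared__ float tmp[];           // blockDim.x floats at launch
    const int tid = threadIdx.x;
    tmp[tid] = in[tid];
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {   // mirrors the shader loop
        if (tid < s) {
            tmp[tid] += tmp[tid + s];
        }
        __syncthreads();                     // every halving needs a barrier
    }
    if (tid == 0) {
        *out = tmp[0];                       // block total
    }
}
// launch: tree_reduce<<<1, BLOCK_SIZE, BLOCK_SIZE * sizeof(float)>>>(d_out, d_in);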

View file

@ -3,9 +3,11 @@
#include "mul_mat_vec_base.comp" #include "mul_mat_vec_base.comp"
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32]; layout (constant_id = 0) const uint BLOCK_SIZE = 32;
shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() { void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z; const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
@ -20,17 +22,20 @@ void main() {
const uint num_blocks_per_row = p.ncols / QUANT_K; const uint num_blocks_per_row = p.ncols / QUANT_K;
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 // 16 threads are used to process each block
const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...15
const uint ix = tid/16;
const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 const uint step = 8;
const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = tid - step*v_im; // 0...15 or 0...7 const uint v_in = itid - step*v_im; // 0...15 or 0...7
const uint8_t m = uint8_t(1 << (4 * v_im)); const uint8_t m = uint8_t(1 << (4 * v_im));
const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15 const uint l0 = 2*v_in; // 0...15
const uint q_offset = 32*v_im + l0; const uint q_offset = 32*v_im + l0;
const uint y_offset = 128*v_im + l0; const uint y_offset = 128*v_im + l0;
@ -38,7 +43,7 @@ void main() {
const uint s_shift = 4 * v_im; const uint s_shift = 4 * v_im;
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y_idx = i * QUANT_K + y_offset; const uint y_idx = i * QUANT_K + y_offset;
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
@ -66,7 +71,7 @@ void main() {
u8vec2 s10 = unpack8(s10_16); u8vec2 s10 = unpack8(s10_16);
FLOAT_TYPE sum = FLOAT_TYPE(0.0); FLOAT_TYPE sum = FLOAT_TYPE(0.0);
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { [[unroll]] for (int l = 0; l < 2; ++l) {
sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)), sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)), fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)), fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)),
@ -83,7 +88,7 @@ void main() {
// sum up partial sums and write back result // sum up partial sums and write back result
barrier(); barrier();
[[unroll]] for (uint s = 16; s > 0; s >>= 1) { [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
if (tid < s) { if (tid < s) {
tmp[tid] += tmp[tid + s]; tmp[tid] += tmp[tid + s];
} }

View file

@ -4,11 +4,12 @@
#include "mul_mat_vec_base.comp" #include "mul_mat_vec_base.comp"
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32]; layout (constant_id = 0) const uint BLOCK_SIZE = 32;
shared FLOAT_TYPE tmp[BLOCK_SIZE];
// This shader assumes K_QUANTS_PER_ITERATION == 2 for alignment of loads
void main() { void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z; const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
@ -22,14 +23,17 @@ void main() {
const uint num_blocks_per_row = p.ncols / QUANT_K; const uint num_blocks_per_row = p.ncols / QUANT_K;
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 // 16 threads are used to process each block
const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...15
const uint ix = tid/16;
const uint step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 const uint step = 4;
const uint il = tid/step; // 0...3 const uint il = itid/step; // 0...3
const uint ir = tid - step*il; // 0...7 or 0...3 const uint ir = itid - step*il; // 0...7 or 0...3
const uint n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 const uint n = 4;
const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const uint v_in = il % 2; const uint v_in = il % 2;
@ -40,7 +44,7 @@ void main() {
FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y1_idx = i * QUANT_K + y_offset; const uint y1_idx = i * QUANT_K + y_offset;
const uint y2_idx = y1_idx + 128; const uint y2_idx = y1_idx + 128;
@ -115,7 +119,7 @@ void main() {
// sum up partial sums and write back result // sum up partial sums and write back result
barrier(); barrier();
[[unroll]] for (uint s = 16; s > 0; s >>= 1) { [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
if (tid < s) { if (tid < s) {
tmp[tid] += tmp[tid + s]; tmp[tid] += tmp[tid + s];
} }

View file

@ -4,9 +4,11 @@
#include "mul_mat_vec_base.comp" #include "mul_mat_vec_base.comp"
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32]; layout (constant_id = 0) const uint BLOCK_SIZE = 32;
shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() { void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z; const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
@ -21,11 +23,14 @@ void main() {
const uint num_blocks_per_row = p.ncols / QUANT_K; const uint num_blocks_per_row = p.ncols / QUANT_K;
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
const uint tid = gl_LocalInvocationID.x/2; // 0...31 or 0...16 // 16 threads are used to process each block
const uint ix = gl_LocalInvocationID.x%2; // 0 or 0, 1 const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...15
const uint ix = tid/16;
const uint il = tid/4; // 0...3 const uint il = itid/4; // 0...3
const uint ir = tid - 4*il; // 0...7 or 0...3 const uint ir = itid - 4*il; // 0...7 or 0...3
const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const uint v_in = il % 2; const uint v_in = il % 2;
@ -36,7 +41,7 @@ void main() {
FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) { [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y1_idx = i * QUANT_K + y_offset; const uint y1_idx = i * QUANT_K + y_offset;
const uint y2_idx = y1_idx + 128; const uint y2_idx = y1_idx + 128;
@ -143,7 +148,7 @@ void main() {
// sum up partial sums and write back result // sum up partial sums and write back result
barrier(); barrier();
[[unroll]] for (uint s = 16; s > 0; s >>= 1) { [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
if (tid < s) { if (tid < s) {
tmp[tid] += tmp[tid + s]; tmp[tid] += tmp[tid + s];
} }

View file

@ -7,6 +7,12 @@
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif #endif
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_shader_subgroup_basic : enable
#endif
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#endif #endif
@ -57,6 +63,7 @@ layout (push_constant) uniform parameter
#endif #endif
} p; } p;
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
layout (constant_id = 1) const uint BM = 64; layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64; layout (constant_id = 2) const uint BN = 64;
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
@ -65,13 +72,26 @@ layout (constant_id = 5) const uint WN = 32;
layout (constant_id = 6) const uint WMITER = 2; layout (constant_id = 6) const uint WMITER = 2;
layout (constant_id = 7) const uint TM = 4; layout (constant_id = 7) const uint TM = 4;
layout (constant_id = 8) const uint TN = 2; layout (constant_id = 8) const uint TN = 2;
layout (constant_id = 9) const uint WARP = 32; layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
layout (constant_id = 10) const uint WARP = 32;
shared FLOAT_TYPE buf_a[BM * (BK+1)]; #ifdef COOPMAT
shared FLOAT_TYPE buf_b[BN * (BK+1)]; #define SHMEM_STRIDE (BK + 8)
#else
#define SHMEM_STRIDE (BK + 1)
#endif
shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
shared u16vec2 row_ids[3072]; shared u16vec2 row_ids[3072];
#endif // MUL_MAT_ID
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif #endif
void main() { void main() {
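The COOPMAT path maps each subgroup tile onto hardware matrix units via coopMatLoad/coopMatMulAdd. The CUDA analogue is the wmma API; a hedged sketch of the same accumulate-over-K pattern using standard 16x16x16 fp16 tiles (not the shader's actual TM/TN/TK):

#include <mma.h>
using namespace nvcuda;

// One warp accumulates one 16x16 tile of C += A*B over the K dimension.
__device__ void mma_tile(const half * A, const half * B, float * C,
                         int lda, int ldb, int ldc, int K) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc;
    wmma::fill_fragment(acc, 0.0f);
    for (int k = 0; k < K; k += 16) {
        wmma::load_matrix_sync(a, A + k, lda);   // like coopMatLoad(cache_a, ...)
        wmma::load_matrix_sync(b, B + k, ldb);   // like coopMatLoad(cache_b, ...)
        wmma::mma_sync(acc, a, b, acc);          // like coopMatMulAdd(...)
    }
    wmma::store_matrix_sync(C, acc, ldc, wmma::mem_row_major);
}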
@ -98,17 +118,32 @@ void main() {
const uint ik = gl_WorkGroupID.x / blocks_m; const uint ik = gl_WorkGroupID.x / blocks_m;
const uint ic = gl_WorkGroupID.y; const uint ic = gl_WorkGroupID.y;
const uint warp_i = gl_LocalInvocationID.x / WARP;
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER; const uint WSUBM = WM / WMITER;
const uint WSUBN = WN / WNITER; const uint WSUBN = WN / WNITER;
#ifdef COOPMAT
const uint warp_i = gl_SubgroupID;
const uint tiw = gl_SubgroupInvocationID;
const uint cms_per_row = WM / TM;
const uint cms_per_col = WN / TN;
const uint storestride = WARP / TM;
const uint store_r = tiw % TM;
const uint store_c = tiw / TM;
#else
const uint warp_i = gl_LocalInvocationID.x / WARP;
const uint tiw = gl_LocalInvocationID.x % WARP; const uint tiw = gl_LocalInvocationID.x % WARP;
const uint tiwr = tiw % (WSUBM / TM); const uint tiwr = tiw % (WSUBM / TM);
const uint tiwc = tiw / (WSUBM / TM); const uint tiwc = tiw / (WSUBM / TM);
#endif
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
@ -156,21 +191,31 @@ void main() {
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
#endif #endif
float sums[WMITER * TM * WNITER * TN]; #ifdef COOPMAT
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
}
#else
ACC_TYPE sums[WMITER * TM * WNITER * TN];
FLOAT_TYPE cache_a[WMITER * TM]; FLOAT_TYPE cache_a[WMITER * TM];
FLOAT_TYPE cache_b[WNITER * TN]; FLOAT_TYPE cache_b[WNITER * TN];
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = 0.0f; sums[i] = ACC_TYPE(0.0f);
} }
#endif
[[unroll]] for (uint block = start_k; block < end_k; block += BK) { for (uint block = start_k; block < end_k; block += BK) {
[[unroll]] for (uint l = 0; l < BM; l += loadstride_a) { [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
#if defined(DATA_A_F32) || defined(DATA_A_F16) #if defined(DATA_A_F32) || defined(DATA_A_F16)
#if LOAD_VEC_A == 8 #if LOAD_VEC_A == 8
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x); buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
@ -181,21 +226,21 @@ void main() {
buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
#elif LOAD_VEC_A == 4 #elif LOAD_VEC_A == 4
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x); buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
#else #else
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
} else { } else {
buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f); buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
} }
#endif #endif
#elif defined(DATA_A_Q4_0) #elif defined(DATA_A_Q4_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -208,7 +253,7 @@ void main() {
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q4_1) #elif defined(DATA_A_Q4_1)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -222,7 +267,7 @@ void main() {
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q5_0) #elif defined(DATA_A_Q5_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -237,7 +282,7 @@ void main() {
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q5_1) #elif defined(DATA_A_Q5_1)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -253,7 +298,7 @@ void main() {
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q8_0) #elif defined(DATA_A_Q8_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = (idx & 0xF) * 2; const uint iqs = (idx & 0xF) * 2;
@ -265,7 +310,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q2_K) #elif defined(DATA_A_Q2_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -284,7 +329,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q3_K) #elif defined(DATA_A_Q3_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -308,7 +353,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
#elif defined(DATA_A_Q4_K) #elif defined(DATA_A_Q4_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -320,15 +365,20 @@ void main() {
const vec2 loadd = vec2(data_a[ib].d); const vec2 loadd = vec2(data_a[ib].d);
uint8_t sc; const uint scidx0 = (is < 4) ? is : (is + 4);
uint8_t mbyte; const uint scidx1 = (is < 4) ? is : (is - 4);
if (is < 4) { const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
sc = uint8_t(data_a[ib].scales[is ] & 63); const uint scidxshift1 = (is < 4) ? 0 : 2;
mbyte = uint8_t(data_a[ib].scales[is + 4] & 63); const uint mbidx0 = is + 4;
} else { const uint mbidx1 = (is < 4) ? is + 4 : is;
sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4)); const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4)); const uint mbidxshift0 = (is < 4) ? 0 : 4;
} const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const float d = loadd.x * sc; const float d = loadd.x * sc;
const float m = -loadd.y * mbyte; const float m = -loadd.y * mbyte;
@ -336,7 +386,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
#elif defined(DATA_A_Q5_K) #elif defined(DATA_A_Q5_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -351,15 +401,20 @@ void main() {
const vec2 loadd = vec2(data_a[ib].d); const vec2 loadd = vec2(data_a[ib].d);
uint8_t sc; const uint scidx0 = (is < 4) ? is : (is + 4);
uint8_t mbyte; const uint scidx1 = (is < 4) ? is : (is - 4);
if (is < 4) { const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
sc = uint8_t(data_a[ib].scales[is ] & 63); const uint scidxshift1 = (is < 4) ? 0 : 2;
mbyte = uint8_t(data_a[ib].scales[is + 4] & 63); const uint mbidx0 = is + 4;
} else { const uint mbidx1 = (is < 4) ? is + 4 : is;
sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4)); const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4)); const uint mbidxshift0 = (is < 4) ? 0 : 4;
} const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const float d = loadd.x * sc; const float d = loadd.x * sc;
const float m = -loadd.y * mbyte; const float m = -loadd.y * mbyte;
@ -367,7 +422,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
#elif defined(DATA_A_Q6_K) #elif defined(DATA_A_Q6_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -386,7 +441,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
#elif defined(DATA_A_IQ4_NL) #elif defined(DATA_A_IQ4_NL)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -407,7 +462,7 @@ void main() {
#else #else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif #endif
const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B; const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
@@ -423,24 +478,24 @@ void main() {
#else #else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif #endif
const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B; const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x); buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y); buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z); buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w); buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
#elif !MUL_MAT_ID #elif !MUL_MAT_ID
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
} else { } else {
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f); buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
} }
#else #else
const uint row_i = ic * BN + loadc_b + l; const uint row_i = ic * BN + loadc_b + l;
if (row_i < _ne1) { if (row_i < _ne1) {
const u16vec2 row_idx = row_ids[row_i]; const u16vec2 row_idx = row_ids[row_i];
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
} else { } else {
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f); buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
} }
#endif #endif
} }
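A recurring change in the loads above is that the shared-tile row stride, previously hardcoded as (BK+1), becomes the named constant SHMEM_STRIDE; the coopmat path can then pick whatever padding or alignment it needs. The extra element over BK is the classic bank-conflict padding, which the literal encoded implicitly. A minimal C++ model of the bank mapping (our reading of the rationale, assuming a 32-bank shared memory):

#include <cstdio>

int main() {
    constexpr int BK = 32, BANKS = 32, ROWS = 8, COL = 5;
    for (int stride : {BK, BK + 1}) {
        printf("stride %2d, banks hit by column %d:", stride, COL);
        for (int row = 0; row < ROWS; ++row) {
            // with stride == BANKS every row of a column hits the same bank;
            // stride BANKS+1 rotates the column across banks
            printf(" %2d", (row * stride + COL) % BANKS);
        }
        printf("\n");
    }
    return 0;
}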
@@ -450,16 +505,30 @@ void main() {
pos_a += BK / LOAD_VEC_A; pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B; pos_b += BK / LOAD_VEC_B;
for (uint i = 0; i < BK; i++) { #ifdef COOPMAT
[[unroll]] for (uint i = 0; i < BK; i += TK) {
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
// Load from shared into cache
coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
}
}
}
#else
[[unroll]] for (uint i = 0; i < BK; i++) {
// Load from shared into cache // Load from shared into cache
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint j = 0; j < TM; j++) { [[unroll]] for (uint j = 0; j < TM; j++) {
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i]; cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
} }
} }
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint j = 0; j < TN; j++) { [[unroll]] for (uint j = 0; j < TN; j++) {
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i]; cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
} }
} }
@@ -468,12 +537,13 @@ void main() {
[[unroll]] for (uint cc = 0; cc < TN; cc++) { [[unroll]] for (uint cc = 0; cc < TN; cc++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) { [[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
sums[sums_idx] = fma(float(cache_a[wsir * TM + cr]), float(cache_b[wsic * TN + cc]), sums[sums_idx]); sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]);
} }
} }
} }
} }
} }
#endif
barrier(); barrier();
} }
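On the new COOPMAT path the scalar FMA loop is replaced by subgroup-wide matrix multiply-accumulates: i advances by TK per step, and each coopMatMulAdd call computes, per accumulator tile,

$$ \text{sums}_{rc} \leftarrow A_r B_c + \text{sums}_{rc}, \qquad A_r \in \mathbb{R}^{TM \times TK},\; B_c \in \mathbb{R}^{TK \times TN} $$

with the A tiles loaded row-major from buf_a and the B tiles column-major from buf_b, so each warp covers cms_per_row × cms_per_col output tiles per iteration.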
@@ -485,6 +555,54 @@ void main() {
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif #endif
#ifdef COOPMAT
#ifdef MUL_MAT_ID
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
const uint row_i = dc + cm_col * TN + col + store_c;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
#else
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
if (is_aligned && is_in_bounds) {
// Full coopMat is within bounds and stride_d is aligned with 16B
coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
} else if (is_in_bounds) {
// Full coopMat is within bounds, but stride_d is not aligned
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
// Partial coopMat is within bounds
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
}
}
#endif // MUL_MAT_ID
#else
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
@@ -496,7 +614,7 @@ void main() {
if (row_i >= _ne1) break; if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i]; const u16vec2 row_idx = row_ids[row_i];
#endif #endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) { [[unroll]] for (uint cr = 0; cr < TM; cr++) {
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
@@ -504,9 +622,10 @@ void main() {
if (dr_warp + cr < p.M && dc_warp + cc < p.N) { if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
} }
#endif #endif // MUL_MAT_ID
} }
} }
} }
} }
#endif // COOPMAT
} }
View file
@@ -1,6 +1,11 @@
#include "types.comp" #include "types.comp"
#extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_spirv_intrinsics: enable
#if RTE16
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
#endif
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
View file
@@ -16,6 +16,5 @@ void main() {
if (i >= p.KX) { if (i >= p.KX) {
return; return;
} }
data_d[i] = D_TYPE(1. - 2. / (exp(2.*data_a[i]) + 1.));
data_d[i] = D_TYPE(tanh(data_a[i]));
} }
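The two expressions are mathematically identical; the removed line is just the expanded form of tanh:

$$ \tanh(x) = \frac{e^{2x} - 1}{e^{2x} + 1} = 1 - \frac{2}{e^{2x} + 1} $$

Calling the GLSL built-in states the intent directly and lets the driver substitute its own (often faster) expansion.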
View file
@@ -0,0 +1,7 @@
#version 460
#extension GL_NV_cooperative_matrix2 : require
void main()
{
}
View file
@@ -61,6 +61,7 @@ const std::vector<std::string> type_names = {
"iq4_nl" "iq4_nl"
}; };
namespace {
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) { void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32 #ifdef _WIN32
HANDLE stdout_read, stdout_write; HANDLE stdout_read, stdout_write;
@@ -199,17 +200,20 @@ static uint32_t compile_count = 0;
static std::mutex compile_count_mutex; static std::mutex compile_count_mutex;
static std::condition_variable compile_count_cond; static std::condition_variable compile_count_cond;
void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) { void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
std::string out_fname = join_paths(output_dir, name + ".spv"); std::string out_fname = join_paths(output_dir, name + ".spv");
std::string in_path = join_paths(input_dir, in_fname); std::string in_path = join_paths(input_dir, in_fname);
std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2"; std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";
// disable spirv-opt for coopmat shaders; see https://github.com/ggerganov/llama.cpp/issues/10734
std::string opt_level = coopmat ? "" : "-O";
#ifdef _WIN32 #ifdef _WIN32
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, "-O", "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""}; std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
#else #else
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, "-O", in_path, "-o", out_fname}; std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_fname};
#endif #endif
#ifdef GGML_VULKAN_SHADER_DEBUG_INFO #ifdef GGML_VULKAN_SHADER_DEBUG_INFO
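With the new opt_level handling, coopmat compiles simply drop -O from the command line while everything else stays the same. An illustrative invocation (file names hypothetical; all flags appear in the code above) for a coopmat variant would be:

glslc -fshader-stage=compute --target-env=vulkan1.2 mul_mm.comp -o mul_mm_f16_coopmat.spv

whereas non-coopmat variants keep -O between the target-env flag and the input path.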
@@ -259,7 +263,7 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s
} }
static std::vector<std::future<void>> compiles; static std::vector<std::future<void>> compiles;
void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) { void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
{ {
// wait until fewer than N compiles are in progress. // wait until fewer than N compiles are in progress.
// 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
@@ -270,10 +274,10 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
} }
compile_count++; compile_count++;
} }
compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat2, f16acc)); compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
} }
void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) { void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
@@ -292,14 +296,20 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) {
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
if (coopmat) {
base_dict["COOPMAT"] = "1";
}
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
// Shaders with f16 B_TYPE // Shaders with f16 B_TYPE
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
for (const auto& tname : type_names) { for (const auto& tname : type_names) {
std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string data_a_key = "DATA_A_" + to_uppercase(tname);
@@ -308,12 +318,12 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) {
// For aligned matmul loads // For aligned matmul loads
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2"; std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
if (tname != "f16" && tname != "f32") { if (tname != "f16" && tname != "f32") {
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
} }
} }
} }
@@ -323,28 +333,27 @@ void process_shaders() {
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}}; std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
// matmul // matmul
for (const auto& fp16 : {false, true}) { for (const auto& matmul_id : {false, true}) {
for (const auto& matmul_id : {false, true}) { // No coopmats
for (const auto& coopmat2 : {false, true}) { // fp32
for (const auto& f16acc : {false, true}) { matmul_shaders(false, matmul_id, false, false, false);
#if !defined(VK_NV_cooperative_matrix2)
if (coopmat2) { // fp16, fp32acc and fp16acc
continue; matmul_shaders(true, matmul_id, false, false, false);
} matmul_shaders(true, matmul_id, false, false, true);
// Coopmat, fp32acc and fp16acc
matmul_shaders(true, matmul_id, true, false, false);
matmul_shaders(true, matmul_id, true, false, true);
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
// Coopmat2, fp32acc and fp16acc
matmul_shaders(true, matmul_id, false, true, false);
matmul_shaders(true, matmul_id, false, true, true);
#endif #endif
if (coopmat2 && !fp16) {
continue;
}
if (!coopmat2 && f16acc) {
continue;
}
matmul_shaders(fp16, matmul_id, coopmat2, f16acc);
}
}
}
} }
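The old nested loops over fp16/coopmat2/f16acc with skip conditions are replaced by an explicit list, which makes the set of built variants auditable at a glance. A stand-alone C++ sketch (ours, mirroring the suffix rule in string_to_spv_func above; the base name is hypothetical) that prints the variant names requested for each matmul_id value:

#include <cstdio>
#include <string>
#include <vector>

struct Variant { bool fp16, coopmat, coopmat2, f16acc; };

int main() {
    std::vector<Variant> variants = {
        {false, false, false, false}, // fp32
        {true,  false, false, false}, // fp16, fp32 acc
        {true,  false, false, true }, // fp16, fp16 acc
        {true,  true,  false, false}, // coopmat, fp32 acc
        {true,  true,  false, true }, // coopmat, fp16 acc
#ifdef GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT
        {true,  false, true,  false}, // coopmat2, fp32 acc
        {true,  false, true,  true }, // coopmat2, fp16 acc
#endif
    };
    for (const auto & v : variants) {
        // same suffix composition as string_to_spv_func
        std::string suffix = std::string(v.f16acc ? "_f16acc" : "")
                           + (v.coopmat  ? "_coopmat" : "")
                           + (v.coopmat2 ? "_cm2" : (v.fp16 ? "" : "_fp32"));
        printf("mul_mm%s\n", suffix.c_str());
    }
    return 0;
}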
#if defined(VK_NV_cooperative_matrix2) #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
// flash attention // flash attention
for (const auto& f16acc : {false, true}) { for (const auto& f16acc : {false, true}) {
std::string acctype = f16acc ? "float16_t" : "float"; std::string acctype = f16acc ? "float16_t" : "float";
@@ -356,11 +365,11 @@ void process_shaders() {
if (tname == "f16") { if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, true, f16acc); merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
} else { } else {
std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, true, f16acc); merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
} }
} }
} }
@@ -453,9 +462,11 @@ void process_shaders() {
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
@@ -463,6 +474,7 @@ void process_shaders() {
string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
@@ -525,6 +537,7 @@ void write_output_files() {
fclose(hdr); fclose(hdr);
fclose(src); fclose(src);
} }
}
int main(int argc, char** argv) { int main(int argc, char** argv) {
std::map<std::string, std::string> args; std::map<std::string, std::string> args;
View file
@@ -9,7 +9,10 @@
// FIXME: required here for quantization functions // FIXME: required here for quantization functions
#include "ggml-quants.h" #include "ggml-quants.h"
#include "ggml-aarch64.h"
#ifdef GGML_USE_CPU_HBM
#include <hbwmalloc.h>
#endif
#if defined(_MSC_VER) || defined(__MINGW32__) #if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW #include <malloc.h> // using malloc.h with MSC/MINGW
@@ -789,32 +792,23 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
.to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row, .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row,
.from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row_ref, .from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row_ref,
}, },
[GGML_TYPE_Q4_0_4_4] = { [31] = { // GGML_TYPE_Q4_0_4_4
.type_name = "q4_0_4x4", .type_name = "TYPE_Q4_0_4_4 REMOVED, use Q4_0 with runtime repacking",
.blck_size = QK4_0, .blck_size = 0,
.blck_size_interleave = 4, .type_size = 0,
.type_size = sizeof(block_q4_0), .is_quantized = false,
.is_quantized = true,
.to_float = NULL,
.from_float_ref = NULL,
}, },
[GGML_TYPE_Q4_0_4_8] = { [32] = { // GGML_TYPE_Q4_0_4_8
.type_name = "q4_0_4x8", .type_name = "TYPE_Q4_0_4_8 REMOVED, use Q4_0 with runtime repacking",
.blck_size = QK4_0, .blck_size = 0,
.blck_size_interleave = 8, .type_size = 0,
.type_size = sizeof(block_q4_0), .is_quantized = false,
.is_quantized = true,
.to_float = NULL,
.from_float_ref = NULL,
}, },
[GGML_TYPE_Q4_0_8_8] = { [33] = { // GGML_TYPE_Q4_0_8_8
.type_name = "q4_0_8x8", .type_name = "TYPE_Q4_0_8_8 REMOVED, use Q4_0 with runtime repacking",
.blck_size = QK4_0, .blck_size = 0,
.blck_size_interleave = 8, .type_size = 0,
.type_size = sizeof(block_q4_0), .is_quantized = false,
.is_quantized = true,
.to_float = NULL,
.from_float_ref = NULL,
}, },
[GGML_TYPE_TQ1_0] = { [GGML_TYPE_TQ1_0] = {
.type_name = "tq1_0", .type_name = "tq1_0",
@@ -832,14 +826,23 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
.to_float = (ggml_to_float_t) dequantize_row_tq2_0, .to_float = (ggml_to_float_t) dequantize_row_tq2_0,
.from_float_ref = (ggml_from_float_t) quantize_row_tq2_0_ref, .from_float_ref = (ggml_from_float_t) quantize_row_tq2_0_ref,
}, },
[GGML_TYPE_IQ4_NL_4_4] = { [36] = { // GGML_TYPE_IQ4_NL_4_4
.type_name = "iq4_nl_4x4", .type_name = "TYPE_IQ4_NL_4_4 REMOVED, use IQ4_NL with runtime repacking",
.blck_size = QK4_NL, .blck_size = 0,
.blck_size_interleave = 4, .type_size = 0,
.type_size = sizeof(block_iq4_nl), .is_quantized = false,
.is_quantized = true, },
.to_float = NULL, [37] = { // GGML_TYPE_IQ4_NL_4_8
.from_float_ref = NULL, .type_name = "TYPE_IQ4_NL_4_8 REMOVED, use IQ4_NL with runtime repacking",
.blck_size = 0,
.type_size = 0,
.is_quantized = false,
},
[38] = { // GGML_TYPE_IQ4_NL_8_8
.type_name = "TYPE_IQ4_NL_8_8 REMOVED, use IQ4_NL with runtime repacking",
.blck_size = 0,
.type_size = 0,
.is_quantized = false,
}, },
}; };
@@ -1271,9 +1274,6 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break; case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break;
case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break; case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break;
case GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = GGML_TYPE_Q4_0_4_4; break;
case GGML_FTYPE_MOSTLY_Q4_0_4_8: wtype = GGML_TYPE_Q4_0_4_8; break;
case GGML_FTYPE_MOSTLY_Q4_0_8_8: wtype = GGML_TYPE_Q4_0_8_8; break;
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
} }
@@ -6305,9 +6305,6 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_8_8: result = quantize_q4_0_8x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_F16: case GGML_TYPE_F16:
{ {
size_t elemsize = sizeof(ggml_fp16_t); size_t elemsize = sizeof(ggml_fp16_t);
@@ -6883,7 +6880,16 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
(int64_t) info->ne[2] * (int64_t) info->ne[2] *
(int64_t) info->ne[3]; (int64_t) info->ne[3];
if (ggml_blck_size(info->type) == 0 || ne % ggml_blck_size(info->type) != 0) { if (ggml_blck_size(info->type) == 0) {
// support for this tensor type has been removed:
fprintf(stderr, "%s: tensor '%s' of type %d: %s\n",
__func__, info->name.data, (int) info->type, ggml_type_name(info->type));
fclose(file);
gguf_free(ctx);
return NULL;
}
if (ne % ggml_blck_size(info->type) != 0) {
fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n", fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n",
__func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type)); __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type));
fclose(file); fclose(file);
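The stub entries above keep the numeric slots of the removed packed types but advertise a block size of 0, which is exactly what the loader now checks before the divisibility test. A small illustrative helper (ours, not part of the commit) built on the public ggml API:

#include "ggml.h"

// Removed types keep their enum values but report blck_size == 0,
// so a caller can reject them before doing any block-size arithmetic.
bool ggml_type_is_supported(enum ggml_type type) {
    return ggml_blck_size(type) != 0;
}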
View file
@@ -761,6 +761,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT, MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_K,
@@ -1432,9 +1433,6 @@ class GGMLQuantizationType(IntEnum):
F64 = 28 F64 = 28
IQ1_M = 29 IQ1_M = 29
BF16 = 30 BF16 = 30
Q4_0_4_4 = 31
Q4_0_4_8 = 32
Q4_0_8_8 = 33
TQ1_0 = 34 TQ1_0 = 34
TQ2_0 = 35 TQ2_0 = 35
@@ -1478,9 +1476,9 @@ class LlamaFileType(IntEnum):
MOSTLY_IQ4_XS = 30 # except 1d tensors MOSTLY_IQ4_XS = 30 # except 1d tensors
MOSTLY_IQ1_M = 31 # except 1d tensors MOSTLY_IQ1_M = 31 # except 1d tensors
MOSTLY_BF16 = 32 # except 1d tensors MOSTLY_BF16 = 32 # except 1d tensors
MOSTLY_Q4_0_4_4 = 33 # except 1d tensors # MOSTLY_Q4_0_4_4 = 33 # removed from gguf files, use Q4_0 and runtime repack
MOSTLY_Q4_0_4_8 = 34 # except 1d tensors # MOSTLY_Q4_0_4_8 = 34 # removed from gguf files, use Q4_0 and runtime repack
MOSTLY_Q4_0_8_8 = 35 # except 1d tensors # MOSTLY_Q4_0_8_8 = 35 # removed from gguf files, use Q4_0 and runtime repack
MOSTLY_TQ1_0 = 36 # except 1d tensors MOSTLY_TQ1_0 = 36 # except 1d tensors
MOSTLY_TQ2_0 = 37 # except 1d tensors MOSTLY_TQ2_0 = 37 # except 1d tensors
@@ -1556,9 +1554,6 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
GGMLQuantizationType.F64: (1, 8), GGMLQuantizationType.F64: (1, 8),
GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32), GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32),
GGMLQuantizationType.BF16: (1, 2), GGMLQuantizationType.BF16: (1, 2),
GGMLQuantizationType.Q4_0_4_4:(32, 2 + 16),
GGMLQuantizationType.Q4_0_4_8:(32, 2 + 16),
GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16),
GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13), GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13),
GGMLQuantizationType.TQ2_0: (256, 2 + 64), GGMLQuantizationType.TQ2_0: (256, 2 + 64),
} }
View file
@@ -172,9 +172,9 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors //LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // removed from gguf files, use Q4_0 and runtime repack
LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors //LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // removed from gguf files, use Q4_0 and runtime repack
LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors //LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // removed from gguf files, use Q4_0 and runtime repack
LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors
LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors
View file
@@ -68,11 +68,10 @@
#define STDMETHODCALLTYPE #define STDMETHODCALLTYPE
#define STDMETHODIMP_(type) type STDMETHODCALLTYPE #define STDMETHODIMP_(type) type STDMETHODCALLTYPE
#define STDMETHODIMP STDMETHODIMP_(HRESULT) #define STDMETHODIMP STDMETHODIMP_(HRESULT)
#define STDMETHOD_(type,name) virtual STDMETHODIMP_(type) name #define STDMETHOD_(type, name) virtual STDMETHODIMP_(type) name
#define STDMETHOD(name) STDMETHOD_(HRESULT, name) #define STDMETHOD(name) STDMETHOD_(HRESULT, name)
#define EXTERN_C extern "C" #define EXTERN_C extern "C"
#define UNREFERENCED_PARAMETER(P) (void)(P) #define UNREFERENCED_PARAMETER(P) (void)(P)
#define RtlEqualMemory(Destination, Source, Length) \ #define RtlEqualMemory(Destination, Source, Length) \
@@ -127,7 +126,7 @@
// Used by HRESULT <--> WIN32 error code conversion // Used by HRESULT <--> WIN32 error code conversion
#define SEVERITY_ERROR 1 #define SEVERITY_ERROR 1
#define FACILITY_WIN32 7 #define FACILITY_WIN32 7
#define HRESULT_CODE(hr) ((hr)&0xFFFF) #define HRESULT_CODE(hr) ((hr) & 0xFFFF)
#define MAKE_HRESULT(severity, facility, code) \ #define MAKE_HRESULT(severity, facility, code) \
((HRESULT)(((unsigned long)(severity) << 31) | \ ((HRESULT)(((unsigned long)(severity) << 31) | \
((unsigned long)(facility) << 16) | ((unsigned long)(code)))) ((unsigned long)(facility) << 16) | ((unsigned long)(code))))
@@ -239,7 +238,7 @@
#define HRESULT_FROM_WIN32(x) \ #define HRESULT_FROM_WIN32(x) \
(HRESULT)(x) <= 0 ? (HRESULT)(x) \ (HRESULT)(x) <= 0 ? (HRESULT)(x) \
: (HRESULT)(((x)&0x0000FFFF) | (7 << 16) | 0x80000000) : (HRESULT)(((x) & 0x0000FFFF) | (7 << 16) | 0x80000000)
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// //
@@ -251,89 +250,30 @@
#define _In_opt_ #define _In_opt_
#define _In_opt_count_(size) #define _In_opt_count_(size)
#define _In_opt_z_ #define _In_opt_z_
#define _In_reads_(size)
#define _In_reads_bytes_(size)
#define _In_reads_bytes_opt_(size)
#define _In_reads_opt_(size)
#define _In_reads_to_ptr_(ptr)
#define _In_count_(size) #define _In_count_(size)
#define _In_range_(lb, ub)
#define _In_bytecount_(size) #define _In_bytecount_(size)
#define _In_opt_bytecount_(size)
#define _In_NLS_string_(size)
#define __in_bcount(size)
#define _Out_ #define _Out_
#define _Out_bytecap_(nbytes) #define _Out_opt_
#define _Out_writes_to_(a, b)
#define _Out_writes_to_opt_(a, b)
#define _Outptr_ #define _Outptr_
#define _Outptr_opt_ #define _Outptr_opt_
#define _Outptr_opt_result_z_
#define _Out_opt_
#define _Out_writes_(size)
#define _Out_write_bytes_(size)
#define _Out_writes_z_(size)
#define _Out_writes_all_(size)
#define _Out_writes_bytes_(size)
#define _Outref_result_buffer_(size)
#define _Outptr_result_buffer_(size)
#define _Out_cap_(size)
#define _Out_cap_x_(size)
#define _Out_range_(lb, ub)
#define _Outptr_result_z_ #define _Outptr_result_z_
#define _Outptr_result_buffer_maybenull_(ptr) #define _Outptr_opt_result_z_
#define _Outptr_result_maybenull_ #define _Outptr_result_maybenull_
#define _Outptr_result_nullonfailure_ #define _Outptr_result_nullonfailure_
#define _Outptr_result_buffer_maybenull_(ptr)
#define __out_ecount_part(a, b) #define _Outptr_result_buffer_(ptr)
#define _Inout_
#define _Inout_z_
#define _Inout_opt_
#define _Inout_cap_(size)
#define _Inout_count_(size)
#define _Inout_count_c_(size)
#define _Inout_opt_count_c_(size)
#define _Inout_bytecount_c_(size)
#define _Inout_opt_bytecount_c_(size)
#define _Ret_maybenull_
#define _Ret_notnull_
#define _Ret_opt_
#define _Use_decl_annotations_
#define __analysis_assume(expr)
#define _Analysis_assume_(expr)
#define _Analysis_assume_nullterminated_(x)
#define _Success_(expr)
#define __inexpressible_readableTo(size)
#define __inexpressible_writableTo(size)
#define _Printf_format_string_
#define _Null_terminated_
#define _Field_size_(size)
#define _Field_size_full_(size)
#define _Field_size_opt_(size)
#define _Post_writable_byte_size_(size)
#define _Post_readable_byte_size_(size)
#define __drv_allocatesMem(mem)
#define _COM_Outptr_ #define _COM_Outptr_
#define _COM_Outptr_opt_ #define _COM_Outptr_opt_
#define _COM_Outptr_result_maybenull_ #define _COM_Outptr_result_maybenull_
#define _COM_Outptr_opt_result_maybenull_ #define _COM_Outptr_opt_result_maybenull_
#define _Null_
#define _Notnull_
#define _Maybenull_
#define THIS_ #define THIS_
#define THIS #define THIS
#define PURE = 0 #define PURE = 0
#define _Outptr_result_bytebuffer_(size) #define _Maybenull_
#define __debugbreak() #define __debugbreak()
@@ -620,17 +560,18 @@ template <typename T> inline void **IID_PPV_ARGS_Helper(T **pp) {
#endif // __EMULATE_UUID #endif // __EMULATE_UUID
// Needed for d3d headers, but fail to create actual interfaces // Needed for d3d headers, but fail to create actual interfaces
#define DEFINE_GUID(name, l, w1, w2, b1, b2, b3, b4, b5, b6, b7, b8) const GUID name = { l, w1, w2, { b1, b2, b3, b4, b5, b6, b7, b8 } } #define DEFINE_GUID(name, l, w1, w2, b1, b2, b3, b4, b5, b6, b7, b8) \
const GUID name = {l, w1, w2, {b1, b2, b3, b4, b5, b6, b7, b8}}
#define DECLSPEC_UUID(x) #define DECLSPEC_UUID(x)
#define MIDL_INTERFACE(x) struct DECLSPEC_UUID(x) #define MIDL_INTERFACE(x) struct DECLSPEC_UUID(x)
#define DECLARE_INTERFACE(iface) struct iface #define DECLARE_INTERFACE(iface) struct iface
#define DECLARE_INTERFACE_(iface, parent) DECLARE_INTERFACE(iface) : parent #define DECLARE_INTERFACE_(iface, parent) DECLARE_INTERFACE(iface) : parent
//===--------------------- COM Interfaces ---------------------------------===// //===--------------------- COM Interfaces ---------------------------------===//
CROSS_PLATFORM_UUIDOF(IUnknown, "00000000-0000-0000-C000-000000000046") CROSS_PLATFORM_UUIDOF(IUnknown, "00000000-0000-0000-C000-000000000046")
struct IUnknown { struct IUnknown {
IUnknown() {}; IUnknown(){};
virtual HRESULT QueryInterface(REFIID riid, void **ppvObject) = 0; virtual HRESULT QueryInterface(REFIID riid, void **ppvObject) = 0;
virtual ULONG AddRef() = 0; virtual ULONG AddRef() = 0;
virtual ULONG Release() = 0; virtual ULONG Release() = 0;
@@ -644,10 +585,10 @@ struct INoMarshal : public IUnknown {};
CROSS_PLATFORM_UUIDOF(IMalloc, "00000002-0000-0000-C000-000000000046") CROSS_PLATFORM_UUIDOF(IMalloc, "00000002-0000-0000-C000-000000000046")
struct IMalloc : public IUnknown { struct IMalloc : public IUnknown {
virtual void *Alloc(size_t size) = 0; virtual void *Alloc(SIZE_T size) = 0;
virtual void *Realloc(void *ptr, size_t size) = 0; virtual void *Realloc(void *ptr, SIZE_T size) = 0;
virtual void Free(void *ptr) = 0; virtual void Free(void *ptr) = 0;
virtual size_t GetSize(void *pv) = 0; virtual SIZE_T GetSize(void *pv) = 0;
virtual int DidAlloc(void *pv) = 0; virtual int DidAlloc(void *pv) = 0;
virtual void HeapMinimize(void) = 0; virtual void HeapMinimize(void) = 0;
}; };
@@ -684,8 +625,10 @@ struct IStream : public ISequentialStream {
// These don't need stub implementations as they come from the DirectX Headers // These don't need stub implementations as they come from the DirectX Headers
// They still need the __uuidof() though // They still need the __uuidof() though
CROSS_PLATFORM_UUIDOF(ID3D12LibraryReflection, "8E349D19-54DB-4A56-9DC9-119D87BDB804") CROSS_PLATFORM_UUIDOF(ID3D12LibraryReflection,
CROSS_PLATFORM_UUIDOF(ID3D12ShaderReflection, "5A58797D-A72C-478D-8BA2-EFC6B0EFE88E") "8E349D19-54DB-4A56-9DC9-119D87BDB804")
CROSS_PLATFORM_UUIDOF(ID3D12ShaderReflection,
"5A58797D-A72C-478D-8BA2-EFC6B0EFE88E")
//===--------------------- COM Pointer Types ------------------------------===// //===--------------------- COM Pointer Types ------------------------------===//
@@ -817,6 +760,14 @@ public:
return *this; return *this;
} }
// NOTE: This conversion constructor is not part of the official CComPtr spec;
// however, it is needed to convert CComPtr<Q> to CComPtr<T> where T derives
// from Q on Clang. MSVC compiles this conversion as first a call to
// CComPtr<Q>::operator T*, followed by CComPtr<T>(T*), but Clang fails to
// compile with error: no viable conversion from 'CComPtr<Q>' to 'CComPtr<T>'.
template <typename Q>
CComPtr(const CComPtr<Q> &lp) throw() : CComPtrBase<T>(lp.p) {}
T *operator=(const CComPtr<T> &lp) throw() { T *operator=(const CComPtr<T> &lp) throw() {
if (*this != lp) { if (*this != lp) {
CComPtr(lp).Swap(*this); CComPtr(lp).Swap(*this);
@@ -952,38 +903,49 @@ void SysFreeString(BSTR bstrString);
// Allocate string with length prefix // Allocate string with length prefix
BSTR SysAllocStringLen(const OLECHAR *strIn, UINT ui); BSTR SysAllocStringLen(const OLECHAR *strIn, UINT ui);
//===--------------------------- BSTR Length ------------------------------===//
unsigned int SysStringLen(const BSTR bstrString);
//===--------------------- UTF-8 Related Types ----------------------------===// //===--------------------- UTF-8 Related Types ----------------------------===//
// Code Page // Code Page
#define CP_ACP 0 #define CP_ACP 0
#define CP_UTF8 65001 // UTF-8 translation. #define CP_UTF8 65001 // UTF-8 translation.
// Convert Windows codepage value to locale string // RAII style mechanism for setting/unsetting a locale for the specified Windows
const char *CPToLocale(uint32_t CodePage); // codepage
class ScopedLocale {
const char *m_prevLocale;
public:
explicit ScopedLocale(uint32_t codePage)
: m_prevLocale(setlocale(LC_ALL, nullptr)) {
assert((codePage == CP_UTF8) &&
"Support for Linux only handles UTF8 code pages");
setlocale(LC_ALL, "en_US.UTF-8");
}
~ScopedLocale() {
if (m_prevLocale != nullptr) {
setlocale(LC_ALL, m_prevLocale);
}
}
};
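ScopedLocale folds the save/set/restore dance that CW2AEX and CA2WEX previously did by hand into an RAII guard, so the previous locale is restored on every exit path. A usage sketch (ours; assumes this adapter header is included so ScopedLocale and CP_UTF8 are in scope):

#include <cstddef>
#include <cwchar>

size_t narrow_to_utf8(char * dst, const wchar_t * src, size_t cap) {
    ScopedLocale locale(CP_UTF8);        // LC_ALL -> en_US.UTF-8 for this scope
    return std::wcstombs(dst, src, cap); // conversion sees the UTF-8 locale;
                                         // the old locale is restored on return
}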
// The t_nBufferLength parameter is part of the published interface, but not // The t_nBufferLength parameter is part of the published interface, but not
// used here. // used here.
template <int t_nBufferLength = 128> class CW2AEX { template <int t_nBufferLength = 128> class CW2AEX {
public: public:
CW2AEX(LPCWSTR psz, UINT nCodePage = CP_UTF8) { CW2AEX(LPCWSTR psz) {
const char *locale = CPToLocale(nCodePage); ScopedLocale locale(CP_UTF8);
if (locale == nullptr) {
// Current Implementation only supports CP_UTF8, and CP_ACP
assert(false && "CW2AEX implementation for Linux only handles "
"UTF8 and ACP code pages");
return;
}
if (!psz) { if (!psz) {
m_psz = NULL; m_psz = NULL;
return; return;
} }
locale = setlocale(LC_ALL, locale);
int len = (wcslen(psz) + 1) * 4; int len = (wcslen(psz) + 1) * 4;
m_psz = new char[len]; m_psz = new char[len];
std::wcstombs(m_psz, psz, len); std::wcstombs(m_psz, psz, len);
setlocale(LC_ALL, locale);
} }
~CW2AEX() { delete[] m_psz; } ~CW2AEX() { delete[] m_psz; }
@@ -998,25 +960,17 @@ typedef CW2AEX<> CW2A;
// used here. // used here.
template <int t_nBufferLength = 128> class CA2WEX { template <int t_nBufferLength = 128> class CA2WEX {
public: public:
CA2WEX(LPCSTR psz, UINT nCodePage = CP_UTF8) { CA2WEX(LPCSTR psz) {
const char *locale = CPToLocale(nCodePage); ScopedLocale locale(CP_UTF8);
if (locale == nullptr) {
// Current Implementation only supports CP_UTF8, and CP_ACP
assert(false && "CA2WEX implementation for Linux only handles "
"UTF8 and ACP code pages");
return;
}
if (!psz) { if (!psz) {
m_psz = NULL; m_psz = NULL;
return; return;
} }
locale = setlocale(LC_ALL, locale);
int len = strlen(psz) + 1; int len = strlen(psz) + 1;
m_psz = new wchar_t[len]; m_psz = new wchar_t[len];
std::mbstowcs(m_psz, psz, len); std::mbstowcs(m_psz, psz, len);
setlocale(LC_ALL, locale);
} }
~CA2WEX() { delete[] m_psz; } ~CA2WEX() { delete[] m_psz; }
@@ -1040,50 +994,34 @@ private:
HANDLE m_h; HANDLE m_h;
}; };
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// CComBSTR // CComBSTR
class CComBSTR class CComBSTR {
{
public: public:
BSTR m_str; BSTR m_str;
CComBSTR() : m_str(nullptr) {}; CComBSTR() : m_str(nullptr){};
CComBSTR(_In_ int nSize, LPCWSTR sz); CComBSTR(int nSize, LPCWSTR sz);
~CComBSTR() throw() { ~CComBSTR() throw() { SysFreeString(m_str); }
SysFreeString(m_str); unsigned int Length() const throw() { return SysStringLen(m_str); }
} operator BSTR() const throw() { return m_str; }
operator BSTR() const throw() bool operator==(const CComBSTR &bstrSrc) const throw();
{
return m_str;
}
bool operator==(_In_ const CComBSTR& bstrSrc) const throw(); BSTR *operator&() throw() { return &m_str; }
BSTR* operator&() throw() BSTR Detach() throw() {
{ BSTR s = m_str;
return &m_str; m_str = NULL;
} return s;
}
BSTR Detach() throw()
{
BSTR s = m_str;
m_str = NULL;
return s;
}
void Empty() throw() {
SysFreeString(m_str);
m_str = NULL;
}
}; };
#endif // __cplusplus
#endif // _WIN32
#ifdef __cplusplus
#include <string>
#include <vector>
//===--------- Convert argv to wchar ----------------===// //===--------- Convert argv to wchar ----------------===//
class WArgV { class WArgV {
std::vector<std::wstring> WStringVector; std::vector<std::wstring> WStringVector;
@@ -1091,9 +1029,11 @@ class WArgV {
public: public:
WArgV(int argc, const char **argv); WArgV(int argc, const char **argv);
WArgV(int argc, const wchar_t **argv); const wchar_t **argv() { return WCharPtrVector.data(); }
const wchar_t **argv() { return WCharPtrVector.data();}
}; };
#endif
#endif // __cplusplus
#endif // _WIN32
#endif // LLVM_SUPPORT_WIN_ADAPTER_H #endif // LLVM_SUPPORT_WIN_ADAPTER_H
File diff suppressed because it is too large
View file
@@ -1,623 +0,0 @@
//
// Copyright (C) 2002-2005 3Dlabs Inc. Ltd.
// Copyright (C) 2012-2013 LunarG, Inc.
// Copyright (C) 2017 ARM Limited.
// Modifications Copyright (C) 2020 Advanced Micro Devices, Inc. All rights reserved.
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following
// disclaimer in the documentation and/or other materials provided
// with the distribution.
//
// Neither the name of 3Dlabs Inc. Ltd. nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
#ifndef _BASICTYPES_INCLUDED_
#define _BASICTYPES_INCLUDED_
namespace glslang {
//
// Basic type. Arrays, vectors, sampler details, etc., are orthogonal to this.
//
enum TBasicType {
EbtVoid,
EbtFloat,
EbtDouble,
EbtFloat16,
EbtInt8,
EbtUint8,
EbtInt16,
EbtUint16,
EbtInt,
EbtUint,
EbtInt64,
EbtUint64,
EbtBool,
EbtAtomicUint,
EbtSampler,
EbtStruct,
EbtBlock,
EbtAccStruct,
EbtReference,
EbtRayQuery,
EbtHitObjectNV,
#ifndef GLSLANG_WEB
// SPIR-V type defined by spirv_type
EbtSpirvType,
#endif
// HLSL types that live only temporarily.
EbtString,
EbtNumTypes
};
//
// Storage qualifiers. Should align with different kinds of storage or
// resource or GLSL storage qualifier. Expansion is deprecated.
//
// N.B.: You probably DON'T want to add anything here, but rather just add it
// to the built-in variables. See the comment above TBuiltInVariable.
//
// A new built-in variable will normally be an existing qualifier, like 'in', 'out', etc.
// DO NOT follow the design pattern of, say EvqInstanceId, etc.
//
enum TStorageQualifier {
EvqTemporary, // For temporaries (within a function), read/write
EvqGlobal, // For globals read/write
EvqConst, // User-defined constant values, will be semantically constant and constant folded
EvqVaryingIn, // pipeline input, read only, also supercategory for all built-ins not included in this enum (see TBuiltInVariable)
EvqVaryingOut, // pipeline output, read/write, also supercategory for all built-ins not included in this enum (see TBuiltInVariable)
EvqUniform, // read only, shared with app
EvqBuffer, // read/write, shared with app
EvqShared, // compute shader's read/write 'shared' qualifier
#ifndef GLSLANG_WEB
EvqSpirvStorageClass, // spirv_storage_class
#endif
EvqPayload,
EvqPayloadIn,
EvqHitAttr,
EvqCallableData,
EvqCallableDataIn,
EvqHitObjectAttrNV,
EvqtaskPayloadSharedEXT,
// parameters
EvqIn, // also, for 'in' in the grammar before we know if it's a pipeline input or an 'in' parameter
EvqOut, // also, for 'out' in the grammar before we know if it's a pipeline output or an 'out' parameter
EvqInOut,
EvqConstReadOnly, // input; also other read-only types having neither a constant value nor constant-value semantics
// built-ins read by vertex shader
EvqVertexId,
EvqInstanceId,
// built-ins written by vertex shader
EvqPosition,
EvqPointSize,
EvqClipVertex,
// built-ins read by fragment shader
EvqFace,
EvqFragCoord,
EvqPointCoord,
// built-ins written by fragment shader
EvqFragColor,
EvqFragDepth,
EvqFragStencil,
EvqTileImageEXT,
// end of list
EvqLast
};
//
// Subcategories of the TStorageQualifier, simply to give a direct mapping
// between built-in variable names and an numerical value (the enum).
//
// For backward compatibility, there is some redundancy between the
// TStorageQualifier and these. Existing members should both be maintained accurately.
// However, any new built-in variable (and any existing non-redundant one)
// must follow the pattern that the specific built-in is here, and only its
// general qualifier is in TStorageQualifier.
//
// Something like gl_Position, which is sometimes 'in' and sometimes 'out'
// shows up as two different built-in variables in a single stage, but
// only has a single enum in TBuiltInVariable, so both the
// TStorageQualifier and the TBuitinVariable are needed to distinguish
// between them.
//
enum TBuiltInVariable {
EbvNone,
EbvNumWorkGroups,
EbvWorkGroupSize,
EbvWorkGroupId,
EbvLocalInvocationId,
EbvGlobalInvocationId,
EbvLocalInvocationIndex,
EbvNumSubgroups,
EbvSubgroupID,
EbvSubGroupSize,
EbvSubGroupInvocation,
EbvSubGroupEqMask,
EbvSubGroupGeMask,
EbvSubGroupGtMask,
EbvSubGroupLeMask,
EbvSubGroupLtMask,
EbvSubgroupSize2,
EbvSubgroupInvocation2,
EbvSubgroupEqMask2,
EbvSubgroupGeMask2,
EbvSubgroupGtMask2,
EbvSubgroupLeMask2,
EbvSubgroupLtMask2,
EbvVertexId,
EbvInstanceId,
EbvVertexIndex,
EbvInstanceIndex,
EbvBaseVertex,
EbvBaseInstance,
EbvDrawId,
EbvPosition,
EbvPointSize,
EbvClipVertex,
EbvClipDistance,
EbvCullDistance,
EbvNormal,
EbvVertex,
EbvMultiTexCoord0,
EbvMultiTexCoord1,
EbvMultiTexCoord2,
EbvMultiTexCoord3,
EbvMultiTexCoord4,
EbvMultiTexCoord5,
EbvMultiTexCoord6,
EbvMultiTexCoord7,
EbvFrontColor,
EbvBackColor,
EbvFrontSecondaryColor,
EbvBackSecondaryColor,
EbvTexCoord,
EbvFogFragCoord,
EbvInvocationId,
EbvPrimitiveId,
EbvLayer,
EbvViewportIndex,
EbvPatchVertices,
EbvTessLevelOuter,
EbvTessLevelInner,
EbvBoundingBox,
EbvTessCoord,
EbvColor,
EbvSecondaryColor,
EbvFace,
EbvFragCoord,
EbvPointCoord,
EbvFragColor,
EbvFragData,
EbvFragDepth,
EbvFragStencilRef,
EbvSampleId,
EbvSamplePosition,
EbvSampleMask,
EbvHelperInvocation,
EbvBaryCoordNoPersp,
EbvBaryCoordNoPerspCentroid,
EbvBaryCoordNoPerspSample,
EbvBaryCoordSmooth,
EbvBaryCoordSmoothCentroid,
EbvBaryCoordSmoothSample,
EbvBaryCoordPullModel,
EbvViewIndex,
EbvDeviceIndex,
EbvShadingRateKHR,
EbvPrimitiveShadingRateKHR,
EbvFragSizeEXT,
EbvFragInvocationCountEXT,
EbvSecondaryFragDataEXT,
EbvSecondaryFragColorEXT,
EbvViewportMaskNV,
EbvSecondaryPositionNV,
EbvSecondaryViewportMaskNV,
EbvPositionPerViewNV,
EbvViewportMaskPerViewNV,
EbvFragFullyCoveredNV,
EbvFragmentSizeNV,
EbvInvocationsPerPixelNV,
// ray tracing
EbvLaunchId,
EbvLaunchSize,
EbvInstanceCustomIndex,
EbvGeometryIndex,
EbvWorldRayOrigin,
EbvWorldRayDirection,
EbvObjectRayOrigin,
EbvObjectRayDirection,
EbvRayTmin,
EbvRayTmax,
EbvCullMask,
EbvHitT,
EbvHitKind,
EbvObjectToWorld,
EbvObjectToWorld3x4,
EbvWorldToObject,
EbvWorldToObject3x4,
EbvIncomingRayFlags,
EbvCurrentRayTimeNV,
// barycentrics
EbvBaryCoordNV,
EbvBaryCoordNoPerspNV,
EbvBaryCoordEXT,
EbvBaryCoordNoPerspEXT,
// mesh shaders
EbvTaskCountNV,
EbvPrimitiveCountNV,
EbvPrimitiveIndicesNV,
EbvClipDistancePerViewNV,
EbvCullDistancePerViewNV,
EbvLayerPerViewNV,
EbvMeshViewCountNV,
EbvMeshViewIndicesNV,
//GL_EXT_mesh_shader
EbvPrimitivePointIndicesEXT,
EbvPrimitiveLineIndicesEXT,
EbvPrimitiveTriangleIndicesEXT,
EbvCullPrimitiveEXT,
// sm builtins
EbvWarpsPerSM,
EbvSMCount,
EbvWarpID,
EbvSMID,
// HLSL built-ins that live only temporarily, until they get remapped
// to one of the above.
EbvFragDepthGreater,
EbvFragDepthLesser,
EbvGsOutputStream,
EbvOutputPatch,
EbvInputPatch,
// structbuffer types
EbvAppendConsume, // no need to differentiate append and consume
EbvRWStructuredBuffer,
EbvStructuredBuffer,
EbvByteAddressBuffer,
EbvRWByteAddressBuffer,
// ARM specific core builtins
EbvCoreCountARM,
EbvCoreIDARM,
EbvCoreMaxIDARM,
EbvWarpIDARM,
EbvWarpMaxIDARM,
EbvPositionFetch,
EbvLast
};
// In this enum, order matters; users can assume higher precision is a bigger value
// and EpqNone is 0.
enum TPrecisionQualifier {
EpqNone = 0,
EpqLow,
EpqMedium,
EpqHigh
};
#ifdef GLSLANG_WEB
__inline const char* GetStorageQualifierString(TStorageQualifier q) { return ""; }
__inline const char* GetPrecisionQualifierString(TPrecisionQualifier p) { return ""; }
#else
// These will show up in error messages
__inline const char* GetStorageQualifierString(TStorageQualifier q)
{
switch (q) {
case EvqTemporary: return "temp"; break;
case EvqGlobal: return "global"; break;
case EvqConst: return "const"; break;
case EvqConstReadOnly: return "const (read only)"; break;
#ifndef GLSLANG_WEB
case EvqSpirvStorageClass: return "spirv_storage_class"; break;
#endif
case EvqVaryingIn: return "in"; break;
case EvqVaryingOut: return "out"; break;
case EvqUniform: return "uniform"; break;
case EvqBuffer: return "buffer"; break;
case EvqShared: return "shared"; break;
case EvqIn: return "in"; break;
case EvqOut: return "out"; break;
case EvqInOut: return "inout"; break;
case EvqVertexId: return "gl_VertexId"; break;
case EvqInstanceId: return "gl_InstanceId"; break;
case EvqPosition: return "gl_Position"; break;
case EvqPointSize: return "gl_PointSize"; break;
case EvqClipVertex: return "gl_ClipVertex"; break;
case EvqFace: return "gl_FrontFacing"; break;
case EvqFragCoord: return "gl_FragCoord"; break;
case EvqPointCoord: return "gl_PointCoord"; break;
case EvqFragColor: return "fragColor"; break;
case EvqFragDepth: return "gl_FragDepth"; break;
case EvqFragStencil: return "gl_FragStencilRefARB"; break;
case EvqPayload: return "rayPayloadNV"; break;
case EvqPayloadIn: return "rayPayloadInNV"; break;
case EvqHitAttr: return "hitAttributeNV"; break;
case EvqCallableData: return "callableDataNV"; break;
case EvqCallableDataIn: return "callableDataInNV"; break;
case EvqtaskPayloadSharedEXT: return "taskPayloadSharedEXT"; break;
case EvqHitObjectAttrNV:return "hitObjectAttributeNV"; break;
default: return "unknown qualifier";
}
}
__inline const char* GetBuiltInVariableString(TBuiltInVariable v)
{
switch (v) {
case EbvNone: return "";
case EbvNumWorkGroups: return "NumWorkGroups";
case EbvWorkGroupSize: return "WorkGroupSize";
case EbvWorkGroupId: return "WorkGroupID";
case EbvLocalInvocationId: return "LocalInvocationID";
case EbvGlobalInvocationId: return "GlobalInvocationID";
case EbvLocalInvocationIndex: return "LocalInvocationIndex";
case EbvNumSubgroups: return "NumSubgroups";
case EbvSubgroupID: return "SubgroupID";
case EbvSubGroupSize: return "SubGroupSize";
case EbvSubGroupInvocation: return "SubGroupInvocation";
case EbvSubGroupEqMask: return "SubGroupEqMask";
case EbvSubGroupGeMask: return "SubGroupGeMask";
case EbvSubGroupGtMask: return "SubGroupGtMask";
case EbvSubGroupLeMask: return "SubGroupLeMask";
case EbvSubGroupLtMask: return "SubGroupLtMask";
case EbvSubgroupSize2: return "SubgroupSize";
case EbvSubgroupInvocation2: return "SubgroupInvocationID";
case EbvSubgroupEqMask2: return "SubgroupEqMask";
case EbvSubgroupGeMask2: return "SubgroupGeMask";
case EbvSubgroupGtMask2: return "SubgroupGtMask";
case EbvSubgroupLeMask2: return "SubgroupLeMask";
case EbvSubgroupLtMask2: return "SubgroupLtMask";
case EbvVertexId: return "VertexId";
case EbvInstanceId: return "InstanceId";
case EbvVertexIndex: return "VertexIndex";
case EbvInstanceIndex: return "InstanceIndex";
case EbvBaseVertex: return "BaseVertex";
case EbvBaseInstance: return "BaseInstance";
case EbvDrawId: return "DrawId";
case EbvPosition: return "Position";
case EbvPointSize: return "PointSize";
case EbvClipVertex: return "ClipVertex";
case EbvClipDistance: return "ClipDistance";
case EbvCullDistance: return "CullDistance";
case EbvNormal: return "Normal";
case EbvVertex: return "Vertex";
case EbvMultiTexCoord0: return "MultiTexCoord0";
case EbvMultiTexCoord1: return "MultiTexCoord1";
case EbvMultiTexCoord2: return "MultiTexCoord2";
case EbvMultiTexCoord3: return "MultiTexCoord3";
case EbvMultiTexCoord4: return "MultiTexCoord4";
case EbvMultiTexCoord5: return "MultiTexCoord5";
case EbvMultiTexCoord6: return "MultiTexCoord6";
case EbvMultiTexCoord7: return "MultiTexCoord7";
case EbvFrontColor: return "FrontColor";
case EbvBackColor: return "BackColor";
case EbvFrontSecondaryColor: return "FrontSecondaryColor";
case EbvBackSecondaryColor: return "BackSecondaryColor";
case EbvTexCoord: return "TexCoord";
case EbvFogFragCoord: return "FogFragCoord";
case EbvInvocationId: return "InvocationID";
case EbvPrimitiveId: return "PrimitiveID";
case EbvLayer: return "Layer";
case EbvViewportIndex: return "ViewportIndex";
case EbvPatchVertices: return "PatchVertices";
case EbvTessLevelOuter: return "TessLevelOuter";
case EbvTessLevelInner: return "TessLevelInner";
case EbvBoundingBox: return "BoundingBox";
case EbvTessCoord: return "TessCoord";
case EbvColor: return "Color";
case EbvSecondaryColor: return "SecondaryColor";
case EbvFace: return "Face";
case EbvFragCoord: return "FragCoord";
case EbvPointCoord: return "PointCoord";
case EbvFragColor: return "FragColor";
case EbvFragData: return "FragData";
case EbvFragDepth: return "FragDepth";
case EbvFragStencilRef: return "FragStencilRef";
case EbvSampleId: return "SampleId";
case EbvSamplePosition: return "SamplePosition";
case EbvSampleMask: return "SampleMaskIn";
case EbvHelperInvocation: return "HelperInvocation";
case EbvBaryCoordNoPersp: return "BaryCoordNoPersp";
case EbvBaryCoordNoPerspCentroid: return "BaryCoordNoPerspCentroid";
case EbvBaryCoordNoPerspSample: return "BaryCoordNoPerspSample";
case EbvBaryCoordSmooth: return "BaryCoordSmooth";
case EbvBaryCoordSmoothCentroid: return "BaryCoordSmoothCentroid";
case EbvBaryCoordSmoothSample: return "BaryCoordSmoothSample";
case EbvBaryCoordPullModel: return "BaryCoordPullModel";
case EbvViewIndex: return "ViewIndex";
case EbvDeviceIndex: return "DeviceIndex";
case EbvFragSizeEXT: return "FragSizeEXT";
case EbvFragInvocationCountEXT: return "FragInvocationCountEXT";
case EbvSecondaryFragDataEXT: return "SecondaryFragDataEXT";
case EbvSecondaryFragColorEXT: return "SecondaryFragColorEXT";
case EbvViewportMaskNV: return "ViewportMaskNV";
case EbvSecondaryPositionNV: return "SecondaryPositionNV";
case EbvSecondaryViewportMaskNV: return "SecondaryViewportMaskNV";
case EbvPositionPerViewNV: return "PositionPerViewNV";
case EbvViewportMaskPerViewNV: return "ViewportMaskPerViewNV";
case EbvFragFullyCoveredNV: return "FragFullyCoveredNV";
case EbvFragmentSizeNV: return "FragmentSizeNV";
case EbvInvocationsPerPixelNV: return "InvocationsPerPixelNV";
case EbvLaunchId: return "LaunchIdNV";
case EbvLaunchSize: return "LaunchSizeNV";
case EbvInstanceCustomIndex: return "InstanceCustomIndexNV";
case EbvGeometryIndex: return "GeometryIndexEXT";
case EbvWorldRayOrigin: return "WorldRayOriginNV";
case EbvWorldRayDirection: return "WorldRayDirectionNV";
case EbvObjectRayOrigin: return "ObjectRayOriginNV";
case EbvObjectRayDirection: return "ObjectRayDirectionNV";
case EbvRayTmin: return "ObjectRayTminNV";
case EbvRayTmax: return "ObjectRayTmaxNV";
case EbvHitT: return "HitTNV";
case EbvHitKind: return "HitKindNV";
case EbvIncomingRayFlags: return "IncomingRayFlagsNV";
case EbvObjectToWorld: return "ObjectToWorldNV";
case EbvWorldToObject: return "WorldToObjectNV";
case EbvCurrentRayTimeNV: return "CurrentRayTimeNV";
case EbvBaryCoordEXT:
case EbvBaryCoordNV: return "BaryCoordKHR";
case EbvBaryCoordNoPerspEXT:
case EbvBaryCoordNoPerspNV: return "BaryCoordNoPerspKHR";
case EbvTaskCountNV: return "TaskCountNV";
case EbvPrimitiveCountNV: return "PrimitiveCountNV";
case EbvPrimitiveIndicesNV: return "PrimitiveIndicesNV";
case EbvClipDistancePerViewNV: return "ClipDistancePerViewNV";
case EbvCullDistancePerViewNV: return "CullDistancePerViewNV";
case EbvLayerPerViewNV: return "LayerPerViewNV";
case EbvMeshViewCountNV: return "MeshViewCountNV";
case EbvMeshViewIndicesNV: return "MeshViewIndicesNV";
// GL_EXT_mesh_shader
case EbvPrimitivePointIndicesEXT: return "PrimitivePointIndicesEXT";
case EbvPrimitiveLineIndicesEXT: return "PrimitiveLineIndicesEXT";
case EbvPrimitiveTriangleIndicesEXT: return "PrimitiveTriangleIndicesEXT";
case EbvCullPrimitiveEXT: return "CullPrimitiveEXT";
case EbvWarpsPerSM: return "WarpsPerSMNV";
case EbvSMCount: return "SMCountNV";
case EbvWarpID: return "WarpIDNV";
case EbvSMID: return "SMIDNV";
case EbvShadingRateKHR: return "ShadingRateKHR";
case EbvPrimitiveShadingRateKHR: return "PrimitiveShadingRateKHR";
default: return "unknown built-in variable";
}
}
__inline const char* GetPrecisionQualifierString(TPrecisionQualifier p)
{
switch (p) {
case EpqNone: return ""; break;
case EpqLow: return "lowp"; break;
case EpqMedium: return "mediump"; break;
case EpqHigh: return "highp"; break;
default: return "unknown precision qualifier";
}
}
#endif
__inline bool isTypeSignedInt(TBasicType type)
{
switch (type) {
case EbtInt8:
case EbtInt16:
case EbtInt:
case EbtInt64:
return true;
default:
return false;
}
}
__inline bool isTypeUnsignedInt(TBasicType type)
{
switch (type) {
case EbtUint8:
case EbtUint16:
case EbtUint:
case EbtUint64:
return true;
default:
return false;
}
}
__inline bool isTypeInt(TBasicType type)
{
return isTypeSignedInt(type) || isTypeUnsignedInt(type);
}
__inline bool isTypeFloat(TBasicType type)
{
switch (type) {
case EbtFloat:
case EbtDouble:
case EbtFloat16:
return true;
default:
return false;
}
}
__inline int getTypeRank(TBasicType type)
{
int res = -1;
switch(type) {
case EbtInt8:
case EbtUint8:
res = 0;
break;
case EbtInt16:
case EbtUint16:
res = 1;
break;
case EbtInt:
case EbtUint:
res = 2;
break;
case EbtInt64:
case EbtUint64:
res = 3;
break;
default:
assert(false);
break;
}
return res;
}
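// Illustrative sketch (not part of the original header): getTypeRank()
// supports width-based promotion decisions between integer types.
// widerInt is a hypothetical helper, not a glslang API.
__inline TBasicType widerInt(TBasicType a, TBasicType b)
{
    assert(isTypeInt(a) && isTypeInt(b));
    return getTypeRank(a) >= getTypeRank(b) ? a : b; // keep the wider type
}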
} // end namespace glslang
#endif // _BASICTYPES_INCLUDED_

View file

@ -1,325 +0,0 @@
//
// Copyright (C) 2002-2005 3Dlabs Inc. Ltd.
// Copyright (C) 2012-2013 LunarG, Inc.
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following
// disclaimer in the documentation and/or other materials provided
// with the distribution.
//
// Neither the name of 3Dlabs Inc. Ltd. nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
#ifndef _COMMON_INCLUDED_
#define _COMMON_INCLUDED_
#include <algorithm>
#include <cassert>
#ifdef _MSC_VER
#include <cfloat>
#else
#include <cmath>
#endif
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <list>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#if defined(__ANDROID__)
#include <sstream>
namespace std {
template<typename T>
std::string to_string(const T& val) {
std::ostringstream os;
os << val;
return os.str();
}
}
#endif
#if defined(MINGW_HAS_SECURE_API) && MINGW_HAS_SECURE_API
#include <basetsd.h>
#ifndef snprintf
#define snprintf sprintf_s
#endif
#define safe_vsprintf(buf,max,format,args) vsnprintf_s((buf), (max), (max), (format), (args))
#elif defined (solaris)
#define safe_vsprintf(buf,max,format,args) vsnprintf((buf), (max), (format), (args))
#include <sys/int_types.h>
#define UINT_PTR uintptr_t
#else
#define safe_vsprintf(buf,max,format,args) vsnprintf((buf), (max), (format), (args))
#include <stdint.h>
#define UINT_PTR uintptr_t
#endif
#if defined(_MSC_VER)
#define strdup _strdup
#endif
/* windows only pragma */
#ifdef _MSC_VER
#pragma warning(disable : 4786) // Don't warn about too long identifiers
#pragma warning(disable : 4514) // unused inline method
#pragma warning(disable : 4201) // nameless union
#endif
#include "PoolAlloc.h"
//
// Put POOL_ALLOCATOR_NEW_DELETE in base classes to make them use this scheme.
//
#define POOL_ALLOCATOR_NEW_DELETE(A) \
void* operator new(size_t s) { return (A).allocate(s); } \
void* operator new(size_t, void *_Where) { return (_Where); } \
void operator delete(void*) { } \
void operator delete(void *, void *) { } \
void* operator new[](size_t s) { return (A).allocate(s); } \
void* operator new[](size_t, void *_Where) { return (_Where); } \
void operator delete[](void*) { } \
void operator delete[](void *, void *) { }
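// Illustrative sketch (not part of the original header): a class opts in to
// pool allocation by expanding the macro inside its body. Instances created
// with new then come from the thread's pool and are reclaimed all at once
// when the pool is popped; individual delete is a no-op. TExampleNode is a
// hypothetical name.
class TExampleNode {
public:
    POOL_ALLOCATOR_NEW_DELETE(glslang::GetThreadPoolAllocator())
    int value = 0;
};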
namespace glslang {
//
// Pool version of string.
//
typedef pool_allocator<char> TStringAllocator;
typedef std::basic_string <char, std::char_traits<char>, TStringAllocator> TString;
} // end namespace glslang
// Repackage the std::hash for use by unordered map/set with a TString key.
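// (The specialization below implements the 32-bit FNV-1a hash over the
// string's bytes.)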
namespace std {
template<> struct hash<glslang::TString> {
std::size_t operator()(const glslang::TString& s) const
{
const unsigned _FNV_offset_basis = 2166136261U;
const unsigned _FNV_prime = 16777619U;
unsigned _Val = _FNV_offset_basis;
size_t _Count = s.size();
const char* _First = s.c_str();
for (size_t _Next = 0; _Next < _Count; ++_Next)
{
_Val ^= (unsigned)_First[_Next];
_Val *= _FNV_prime;
}
return _Val;
}
};
}
namespace glslang {
inline TString* NewPoolTString(const char* s)
{
void* memory = GetThreadPoolAllocator().allocate(sizeof(TString));
return new(memory) TString(s);
}
template<class T> inline T* NewPoolObject(T*)
{
return new(GetThreadPoolAllocator().allocate(sizeof(T))) T;
}
template<class T> inline T* NewPoolObject(T, int instances)
{
return new(GetThreadPoolAllocator().allocate(instances * sizeof(T))) T[instances];
}
//
// Pool allocator versions of vectors, lists, and maps
//
template <class T> class TVector : public std::vector<T, pool_allocator<T> > {
public:
POOL_ALLOCATOR_NEW_DELETE(GetThreadPoolAllocator())
typedef typename std::vector<T, pool_allocator<T> >::size_type size_type;
TVector() : std::vector<T, pool_allocator<T> >() {}
TVector(const pool_allocator<T>& a) : std::vector<T, pool_allocator<T> >(a) {}
TVector(size_type i) : std::vector<T, pool_allocator<T> >(i) {}
TVector(size_type i, const T& val) : std::vector<T, pool_allocator<T> >(i, val) {}
};
template <class T> class TList : public std::list<T, pool_allocator<T> > {
};
template <class K, class D, class CMP = std::less<K> >
class TMap : public std::map<K, D, CMP, pool_allocator<std::pair<K const, D> > > {
};
template <class K, class D, class HASH = std::hash<K>, class PRED = std::equal_to<K> >
class TUnorderedMap : public std::unordered_map<K, D, HASH, PRED, pool_allocator<std::pair<K const, D> > > {
};
template <class K, class CMP = std::less<K> >
class TSet : public std::set<K, CMP, pool_allocator<K> > {
};
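// Illustrative sketch (not part of the original header): the pool-backed
// containers are drop-in replacements for their std counterparts; their
// storage comes from the thread pool and is released when the pool is
// popped. poolContainerExample is a hypothetical function.
inline void poolContainerExample()
{
    TVector<int> sizes(4, 0);        // four zero-initialized ints from the pool
    TMap<TString, int> counts;       // pool-allocated map with pool-string keys
    counts[TString("main")] = (int)sizes.size();
}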
//
// Persistent string memory. Should only be used for strings that survive
// across compiles/links.
//
typedef std::basic_string<char> TPersistString;
//
// templatized min and max functions.
//
template <class T> T Min(const T a, const T b) { return a < b ? a : b; }
template <class T> T Max(const T a, const T b) { return a > b ? a : b; }
//
// Create a TString object from an integer.
//
#if defined(_MSC_VER) || (defined(MINGW_HAS_SECURE_API) && MINGW_HAS_SECURE_API)
inline const TString String(const int i, const int base = 10)
{
char text[16]; // 32 bit ints are at most 10 digits in base 10
_itoa_s(i, text, sizeof(text), base);
return text;
}
#else
inline const TString String(const int i, const int /*base*/ = 10)
{
char text[16]; // 32 bit ints are at most 10 digits in base 10
// we assume base 10 for all cases
snprintf(text, sizeof(text), "%d", i);
return text;
}
#endif
struct TSourceLoc {
void init()
{
name = nullptr; string = 0; line = 0; column = 0;
}
void init(int stringNum) { init(); string = stringNum; }
// Returns the name if it exists. Otherwise, returns the string number.
std::string getStringNameOrNum(bool quoteStringName = true) const
{
if (name != nullptr) {
TString qstr = quoteStringName ? ("\"" + *name + "\"") : *name;
std::string ret_str(qstr.c_str());
return ret_str;
}
return std::to_string((long long)string);
}
const char* getFilename() const
{
if (name == nullptr)
return nullptr;
return name->c_str();
}
const char* getFilenameStr() const { return name == nullptr ? "" : name->c_str(); }
TString* name; // descriptive name for this string, when a textual name is available, otherwise nullptr
int string;
int line;
int column;
};
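// Illustrative sketch (not part of the original header): a location is
// typically initialized with a string number and may gain a file name
// later; getStringNameOrNum() then prefers the quoted name in diagnostics.
// describeLoc is a hypothetical function.
inline std::string describeLoc()
{
    TSourceLoc loc;
    loc.init(2);                       // source string #2, no name yet
    loc.name = NewPoolTString("shader.frag");
    return loc.getStringNameOrNum();   // yields "\"shader.frag\""
}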
class TPragmaTable : public TMap<TString, TString> {
public:
POOL_ALLOCATOR_NEW_DELETE(GetThreadPoolAllocator())
};
const int MaxTokenLength = 1024;
template <class T> bool IsPow2(T powerOf2)
{
if (powerOf2 <= 0)
return false;
return (powerOf2 & (powerOf2 - 1)) == 0;
}
// Round number up to a multiple of the given powerOf2, which is not
// an exponent; it is the actual alignment value, and it must be a power of 2.
template <class T> void RoundToPow2(T& number, int powerOf2)
{
assert(IsPow2(powerOf2));
number = (number + powerOf2 - 1) & ~(powerOf2 - 1);
}
template <class T> bool IsMultipleOfPow2(T number, int powerOf2)
{
assert(IsPow2(powerOf2));
return ! (number & (powerOf2 - 1));
}
// Returns log2 of an integer power of 2.
// T should be integral.
template <class T> int IntLog2(T n)
{
assert(IsPow2(n));
int result = 0;
while ((T(1) << result) != n) {
result++;
}
return result;
}
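// Illustrative sketch (not part of the original header): the power-of-two
// helpers above compose as follows; the values are examples only.
inline void pow2Example()
{
    int offset = 13;
    RoundToPow2(offset, 8);              // rounds 13 up to 16
    assert(IsMultipleOfPow2(offset, 8)); // 16 is a multiple of 8
    assert(IntLog2(16) == 4);
}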
inline bool IsInfinity(double x) {
#ifdef _MSC_VER
switch (_fpclass(x)) {
case _FPCLASS_NINF:
case _FPCLASS_PINF:
return true;
default:
return false;
}
#else
return std::isinf(x);
#endif
}
inline bool IsNan(double x) {
#ifdef _MSC_VER
switch (_fpclass(x)) {
case _FPCLASS_SNAN:
case _FPCLASS_QNAN:
return true;
default:
return false;
}
#else
return std::isnan(x);
#endif
}
} // end namespace glslang
#endif // _COMMON_INCLUDED_

View file

@ -1,974 +0,0 @@
//
// Copyright (C) 2002-2005 3Dlabs Inc. Ltd.
// Copyright (C) 2013 LunarG, Inc.
// Copyright (C) 2017 ARM Limited.
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following
// disclaimer in the documentation and/or other materials provided
// with the distribution.
//
// Neither the name of 3Dlabs Inc. Ltd. nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
#ifndef _CONSTANT_UNION_INCLUDED_
#define _CONSTANT_UNION_INCLUDED_
#include "../Include/Common.h"
#include "../Include/BaseTypes.h"
namespace glslang {
class TConstUnion {
public:
POOL_ALLOCATOR_NEW_DELETE(GetThreadPoolAllocator())
TConstUnion() : iConst(0), type(EbtInt) { }
void setI8Const(signed char i)
{
i8Const = i;
type = EbtInt8;
}
void setU8Const(unsigned char u)
{
u8Const = u;
type = EbtUint8;
}
void setI16Const(signed short i)
{
i16Const = i;
type = EbtInt16;
}
void setU16Const(unsigned short u)
{
u16Const = u;
type = EbtUint16;
}
void setIConst(int i)
{
iConst = i;
type = EbtInt;
}
void setUConst(unsigned int u)
{
uConst = u;
type = EbtUint;
}
void setI64Const(long long i64)
{
i64Const = i64;
type = EbtInt64;
}
void setU64Const(unsigned long long u64)
{
u64Const = u64;
type = EbtUint64;
}
void setDConst(double d)
{
dConst = d;
type = EbtDouble;
}
void setBConst(bool b)
{
bConst = b;
type = EbtBool;
}
void setSConst(const TString* s)
{
sConst = s;
type = EbtString;
}
signed char getI8Const() const { return i8Const; }
unsigned char getU8Const() const { return u8Const; }
signed short getI16Const() const { return i16Const; }
unsigned short getU16Const() const { return u16Const; }
int getIConst() const { return iConst; }
unsigned int getUConst() const { return uConst; }
long long getI64Const() const { return i64Const; }
unsigned long long getU64Const() const { return u64Const; }
double getDConst() const { return dConst; }
bool getBConst() const { return bConst; }
const TString* getSConst() const { return sConst; }
bool operator==(const signed char i) const
{
if (i == i8Const)
return true;
return false;
}
bool operator==(const unsigned char u) const
{
if (u == u8Const)
return true;
return false;
}
bool operator==(const signed short i) const
{
if (i == i16Const)
return true;
return false;
}
bool operator==(const unsigned short u) const
{
if (u == u16Const)
return true;
return false;
}
bool operator==(const int i) const
{
if (i == iConst)
return true;
return false;
}
bool operator==(const unsigned int u) const
{
if (u == uConst)
return true;
return false;
}
bool operator==(const long long i64) const
{
if (i64 == i64Const)
return true;
return false;
}
bool operator==(const unsigned long long u64) const
{
if (u64 == u64Const)
return true;
return false;
}
bool operator==(const double d) const
{
if (d == dConst)
return true;
return false;
}
bool operator==(const bool b) const
{
if (b == bConst)
return true;
return false;
}
bool operator==(const TConstUnion& constant) const
{
if (constant.type != type)
return false;
switch (type) {
case EbtInt:
if (constant.iConst == iConst)
return true;
break;
case EbtUint:
if (constant.uConst == uConst)
return true;
break;
case EbtBool:
if (constant.bConst == bConst)
return true;
break;
case EbtDouble:
if (constant.dConst == dConst)
return true;
break;
#ifndef GLSLANG_WEB
case EbtInt16:
if (constant.i16Const == i16Const)
return true;
break;
case EbtUint16:
if (constant.u16Const == u16Const)
return true;
break;
case EbtInt8:
if (constant.i8Const == i8Const)
return true;
break;
case EbtUint8:
if (constant.u8Const == u8Const)
return true;
break;
case EbtInt64:
if (constant.i64Const == i64Const)
return true;
break;
case EbtUint64:
if (constant.u64Const == u64Const)
return true;
break;
#endif
default:
assert(false && "Default missing");
}
return false;
}
bool operator!=(const signed char i) const
{
return !operator==(i);
}
bool operator!=(const unsigned char u) const
{
return !operator==(u);
}
bool operator!=(const signed short i) const
{
return !operator==(i);
}
bool operator!=(const unsigned short u) const
{
return !operator==(u);
}
bool operator!=(const int i) const
{
return !operator==(i);
}
bool operator!=(const unsigned int u) const
{
return !operator==(u);
}
bool operator!=(const long long i) const
{
return !operator==(i);
}
bool operator!=(const unsigned long long u) const
{
return !operator==(u);
}
bool operator!=(const float f) const
{
return !operator==(f);
}
bool operator!=(const bool b) const
{
return !operator==(b);
}
bool operator!=(const TConstUnion& constant) const
{
return !operator==(constant);
}
bool operator>(const TConstUnion& constant) const
{
assert(type == constant.type);
switch (type) {
case EbtInt:
if (iConst > constant.iConst)
return true;
return false;
case EbtUint:
if (uConst > constant.uConst)
return true;
return false;
case EbtDouble:
if (dConst > constant.dConst)
return true;
return false;
#ifndef GLSLANG_WEB
case EbtInt8:
if (i8Const > constant.i8Const)
return true;
return false;
case EbtUint8:
if (u8Const > constant.u8Const)
return true;
return false;
case EbtInt16:
if (i16Const > constant.i16Const)
return true;
return false;
case EbtUint16:
if (u16Const > constant.u16Const)
return true;
return false;
case EbtInt64:
if (i64Const > constant.i64Const)
return true;
return false;
case EbtUint64:
if (u64Const > constant.u64Const)
return true;
return false;
#endif
default:
assert(false && "Default missing");
return false;
}
}
bool operator<(const TConstUnion& constant) const
{
assert(type == constant.type);
switch (type) {
#ifndef GLSLANG_WEB
case EbtInt8:
if (i8Const < constant.i8Const)
return true;
return false;
case EbtUint8:
if (u8Const < constant.u8Const)
return true;
return false;
case EbtInt16:
if (i16Const < constant.i16Const)
return true;
return false;
case EbtUint16:
if (u16Const < constant.u16Const)
return true;
return false;
case EbtInt64:
if (i64Const < constant.i64Const)
return true;
return false;
case EbtUint64:
if (u64Const < constant.u64Const)
return true;
return false;
#endif
case EbtDouble:
if (dConst < constant.dConst)
return true;
return false;
case EbtInt:
if (iConst < constant.iConst)
return true;
return false;
case EbtUint:
if (uConst < constant.uConst)
return true;
return false;
default:
assert(false && "Default missing");
return false;
}
}
TConstUnion operator+(const TConstUnion& constant) const
{
TConstUnion returnValue;
assert(type == constant.type);
switch (type) {
case EbtInt: returnValue.setIConst(iConst + constant.iConst); break;
case EbtUint: returnValue.setUConst(uConst + constant.uConst); break;
case EbtDouble: returnValue.setDConst(dConst + constant.dConst); break;
#ifndef GLSLANG_WEB
case EbtInt8: returnValue.setI8Const(i8Const + constant.i8Const); break;
case EbtInt16: returnValue.setI16Const(i16Const + constant.i16Const); break;
case EbtInt64: returnValue.setI64Const(i64Const + constant.i64Const); break;
case EbtUint8: returnValue.setU8Const(u8Const + constant.u8Const); break;
case EbtUint16: returnValue.setU16Const(u16Const + constant.u16Const); break;
case EbtUint64: returnValue.setU64Const(u64Const + constant.u64Const); break;
#endif
default: assert(false && "Default missing");
}
return returnValue;
}
TConstUnion operator-(const TConstUnion& constant) const
{
TConstUnion returnValue;
assert(type == constant.type);
switch (type) {
case EbtInt: returnValue.setIConst(iConst - constant.iConst); break;
case EbtUint: returnValue.setUConst(uConst - constant.uConst); break;
case EbtDouble: returnValue.setDConst(dConst - constant.dConst); break;
#ifndef GLSLANG_WEB
case EbtInt8: returnValue.setI8Const(i8Const - constant.i8Const); break;
case EbtInt16: returnValue.setI16Const(i16Const - constant.i16Const); break;
case EbtInt64: returnValue.setI64Const(i64Const - constant.i64Const); break;
case EbtUint8: returnValue.setU8Const(u8Const - constant.u8Const); break;
case EbtUint16: returnValue.setU16Const(u16Const - constant.u16Const); break;
case EbtUint64: returnValue.setU64Const(u64Const - constant.u64Const); break;
#endif
default: assert(false && "Default missing");
}
return returnValue;
}
TConstUnion operator*(const TConstUnion& constant) const
{
TConstUnion returnValue;
assert(type == constant.type);
switch (type) {
case EbtInt: returnValue.setIConst(iConst * constant.iConst); break;
case EbtUint: returnValue.setUConst(uConst * constant.uConst); break;
case EbtDouble: returnValue.setDConst(dConst * constant.dConst); break;
#ifndef GLSLANG_WEB
case EbtInt8: returnValue.setI8Const(i8Const * constant.i8Const); break;
case EbtInt16: returnValue.setI16Const(i16Const * constant.i16Const); break;
case EbtInt64: returnValue.setI64Const(i64Const * constant.i64Const); break;
case EbtUint8: returnValue.setU8Const(u8Const * constant.u8Const); break;
case EbtUint16: returnValue.setU16Const(u16Const * constant.u16Const); break;
case EbtUint64: returnValue.setU64Const(u64Const * constant.u64Const); break;
#endif
default: assert(false && "Default missing");
}
return returnValue;
}
TConstUnion operator%(const TConstUnion& constant) const
{
TConstUnion returnValue;
assert(type == constant.type);
switch (type) {
case EbtInt: returnValue.setIConst(iConst % constant.iConst); break;
case EbtUint: returnValue.setUConst(uConst % constant.uConst); break;
#ifndef GLSLANG_WEB
case EbtInt8: returnValue.setI8Const(i8Const % constant.i8Const); break;
        case EbtInt16: returnValue.setI16Const(i16Const % constant.i16Const); break;
case EbtInt64: returnValue.setI64Const(i64Const % constant.i64Const); break;
case EbtUint8: returnValue.setU8Const(u8Const % constant.u8Const); break;
case EbtUint16: returnValue.setU16Const(u16Const % constant.u16Const); break;
case EbtUint64: returnValue.setU64Const(u64Const % constant.u64Const); break;
#endif
default: assert(false && "Default missing");
}
return returnValue;
}
TConstUnion operator>>(const TConstUnion& constant) const
{
TConstUnion returnValue;
switch (type) {
#ifndef GLSLANG_WEB
case EbtInt8:
switch (constant.type) {
case EbtInt8: returnValue.setI8Const(i8Const >> constant.i8Const); break;
case EbtUint8: returnValue.setI8Const(i8Const >> constant.u8Const); break;
case EbtInt16: returnValue.setI8Const(i8Const >> constant.i16Const); break;
case EbtUint16: returnValue.setI8Const(i8Const >> constant.u16Const); break;
case EbtInt: returnValue.setI8Const(i8Const >> constant.iConst); break;
case EbtUint: returnValue.setI8Const(i8Const >> constant.uConst); break;
case EbtInt64: returnValue.setI8Const(i8Const >> constant.i64Const); break;
case EbtUint64: returnValue.setI8Const(i8Const >> constant.u64Const); break;
default: assert(false && "Default missing");
}
break;
case EbtUint8:
switch (constant.type) {
case EbtInt8: returnValue.setU8Const(u8Const >> constant.i8Const); break;
case EbtUint8: returnValue.setU8Const(u8Const >> constant.u8Const); break;
case EbtInt16: returnValue.setU8Const(u8Const >> constant.i16Const); break;
case EbtUint16: returnValue.setU8Const(u8Const >> constant.u16Const); break;
case EbtInt: returnValue.setU8Const(u8Const >> constant.iConst); break;
case EbtUint: returnValue.setU8Const(u8Const >> constant.uConst); break;
case EbtInt64: returnValue.setU8Const(u8Const >> constant.i64Const); break;
case EbtUint64: returnValue.setU8Const(u8Const >> constant.u64Const); break;
default: assert(false && "Default missing");
}
break;
case EbtInt16:
switch (constant.type) {
case EbtInt8: returnValue.setI16Const(i16Const >> constant.i8Const); break;
case EbtUint8: returnValue.setI16Const(i16Const >> constant.u8Const); break;
case EbtInt16: returnValue.setI16Const(i16Const >> constant.i16Const); break;
case EbtUint16: returnValue.setI16Const(i16Const >> constant.u16Const); break;
case EbtInt: returnValue.setI16Const(i16Const >> constant.iConst); break;
case EbtUint: returnValue.setI16Const(i16Const >> constant.uConst); break;
case EbtInt64: returnValue.setI16Const(i16Const >> constant.i64Const); break;
case EbtUint64: returnValue.setI16Const(i16Const >> constant.u64Const); break;
default: assert(false && "Default missing");
}
break;
case EbtUint16:
switch (constant.type) {
case EbtInt8: returnValue.setU16Const(u16Const >> constant.i8Const); break;
case EbtUint8: returnValue.setU16Const(u16Const >> constant.u8Const); break;
case EbtInt16: returnValue.setU16Const(u16Const >> constant.i16Const); break;
case EbtUint16: returnValue.setU16Const(u16Const >> constant.u16Const); break;
case EbtInt: returnValue.setU16Const(u16Const >> constant.iConst); break;
case EbtUint: returnValue.setU16Const(u16Const >> constant.uConst); break;
case EbtInt64: returnValue.setU16Const(u16Const >> constant.i64Const); break;
case EbtUint64: returnValue.setU16Const(u16Const >> constant.u64Const); break;
default: assert(false && "Default missing");
}
break;
#endif
case EbtInt:
switch (constant.type) {
case EbtInt: returnValue.setIConst(iConst >> constant.iConst); break;
case EbtUint: returnValue.setIConst(iConst >> constant.uConst); break;
#ifndef GLSLANG_WEB
case EbtInt8: returnValue.setIConst(iConst >> constant.i8Const); break;
case EbtUint8: returnValue.setIConst(iConst >> constant.u8Const); break;
case EbtInt16: returnValue.setIConst(iConst >> constant.i16Const); break;
case EbtUint16: returnValue.setIConst(iConst >> constant.u16Const); break;
case EbtInt64: returnValue.setIConst(iConst >> constant.i64Const); break;
case EbtUint64: returnValue.setIConst(iConst >> constant.u64Const); break;
#endif
default: assert(false && "Default missing");
}
break;
case EbtUint:
switch (constant.type) {
case EbtInt: returnValue.setUConst(uConst >> constant.iConst); break;
case EbtUint: returnValue.setUConst(uConst >> constant.uConst); break;
#ifndef GLSLANG_WEB
case EbtInt8: returnValue.setUConst(uConst >> constant.i8Const); break;
case EbtUint8: returnValue.setUConst(uConst >> constant.u8Const); break;
case EbtInt16: returnValue.setUConst(uConst >> constant.i16Const); break;
case EbtUint16: returnValue.setUConst(uConst >> constant.u16Const); break;
case EbtInt64: returnValue.setUConst(uConst >> constant.i64Const); break;
case EbtUint64: returnValue.setUConst(uConst >> constant.u64Const); break;
#endif
default: assert(false && "Default missing");
}
break;
#ifndef GLSLANG_WEB
case EbtInt64:
switch (constant.type) {
case EbtInt8: returnValue.setI64Const(i64Const >> constant.i8Const); break;
case EbtUint8: returnValue.setI64Const(i64Const >> constant.u8Const); break;
case EbtInt16: returnValue.setI64Const(i64Const >> constant.i16Const); break;
case EbtUint16: returnValue.setI64Const(i64Const >> constant.u16Const); break;
case EbtInt: returnValue.setI64Const(i64Const >> constant.iConst); break;
case EbtUint: returnValue.setI64Const(i64Const >> constant.uConst); break;
case EbtInt64: returnValue.setI64Const(i64Const >> constant.i64Const); break;
case EbtUint64: returnValue.setI64Const(i64Const >> constant.u64Const); break;
default: assert(false && "Default missing");
}
break;
case EbtUint64:
switch (constant.type) {
case EbtInt8: returnValue.setU64Const(u64Const >> constant.i8Const); break;
case EbtUint8: returnValue.setU64Const(u64Const >> constant.u8Const); break;
case EbtInt16: returnValue.setU64Const(u64Const >> constant.i16Const); break;
case EbtUint16: returnValue.setU64Const(u64Const >> constant.u16Const); break;
case EbtInt: returnValue.setU64Const(u64Const >> constant.iConst); break;
case EbtUint: returnValue.setU64Const(u64Const >> constant.uConst); break;
case EbtInt64: returnValue.setU64Const(u64Const >> constant.i64Const); break;
case EbtUint64: returnValue.setU64Const(u64Const >> constant.u64Const); break;
default: assert(false && "Default missing");
}
break;
#endif
default: assert(false && "Default missing");
}
return returnValue;
}
TConstUnion operator<<(const TConstUnion& constant) const
{
TConstUnion returnValue;
switch (type) {
#ifndef GLSLANG_WEB
case EbtInt8:
switch (constant.type) {
case EbtInt8: returnValue.setI8Const(i8Const << constant.i8Const); break;
case EbtUint8: returnValue.setI8Const(i8Const << constant.u8Const); break;
case EbtInt16: returnValue.setI8Const(i8Const << constant.i16Const); break;
case EbtUint16: returnValue.setI8Const(i8Const << constant.u16Const); break;
case EbtInt: returnValue.setI8Const(i8Const << constant.iConst); break;
case EbtUint: returnValue.setI8Const(i8Const << constant.uConst); break;
case EbtInt64: returnValue.setI8Const(i8Const << constant.i64Const); break;
case EbtUint64: returnValue.setI8Const(i8Const << constant.u64Const); break;
default: assert(false && "Default missing");
}
break;
case EbtUint8:
switch (constant.type) {
case EbtInt8: returnValue.setU8Const(u8Const << constant.i8Const); break;
case EbtUint8: returnValue.setU8Const(u8Const << constant.u8Const); break;
case EbtInt16: returnValue.setU8Const(u8Const << constant.i16Const); break;
case EbtUint16: returnValue.setU8Const(u8Const << constant.u16Const); break;
case EbtInt: returnValue.setU8Const(u8Const << constant.iConst); break;
case EbtUint: returnValue.setU8Const(u8Const << constant.uConst); break;
case EbtInt64: returnValue.setU8Const(u8Const << constant.i64Const); break;
case EbtUint64: returnValue.setU8Const(u8Const << constant.u64Const); break;
default: assert(false && "Default missing");
}
break;
case EbtInt16:
switch (constant.type) {
case EbtInt8: returnValue.setI16Const(i16Const << constant.i8Const); break;
case EbtUint8: returnValue.setI16Const(i16Const << constant.u8Const); break;
case EbtInt16: returnValue.setI16Const(i16Const << constant.i16Const); break;
case EbtUint16: returnValue.setI16Const(i16Const << constant.u16Const); break;
case EbtInt: returnValue.setI16Const(i16Const << constant.iConst); break;
case EbtUint: returnValue.setI16Const(i16Const << constant.uConst); break;
case EbtInt64: returnValue.setI16Const(i16Const << constant.i64Const); break;
case EbtUint64: returnValue.setI16Const(i16Const << constant.u64Const); break;
default: assert(false && "Default missing");
}
break;
case EbtUint16:
switch (constant.type) {
case EbtInt8: returnValue.setU16Const(u16Const << constant.i8Const); break;
case EbtUint8: returnValue.setU16Const(u16Const << constant.u8Const); break;
case EbtInt16: returnValue.setU16Const(u16Const << constant.i16Const); break;
case EbtUint16: returnValue.setU16Const(u16Const << constant.u16Const); break;
case EbtInt: returnValue.setU16Const(u16Const << constant.iConst); break;
case EbtUint: returnValue.setU16Const(u16Const << constant.uConst); break;
case EbtInt64: returnValue.setU16Const(u16Const << constant.i64Const); break;
case EbtUint64: returnValue.setU16Const(u16Const << constant.u64Const); break;
default: assert(false && "Default missing");
}
break;
case EbtInt64:
switch (constant.type) {
case EbtInt8: returnValue.setI64Const(i64Const << constant.i8Const); break;
case EbtUint8: returnValue.setI64Const(i64Const << constant.u8Const); break;
case EbtInt16: returnValue.setI64Const(i64Const << constant.i16Const); break;
case EbtUint16: returnValue.setI64Const(i64Const << constant.u16Const); break;
case EbtInt: returnValue.setI64Const(i64Const << constant.iConst); break;
case EbtUint: returnValue.setI64Const(i64Const << constant.uConst); break;
case EbtInt64: returnValue.setI64Const(i64Const << constant.i64Const); break;
case EbtUint64: returnValue.setI64Const(i64Const << constant.u64Const); break;
default: assert(false && "Default missing");
}
break;
case EbtUint64:
switch (constant.type) {
case EbtInt8: returnValue.setU64Const(u64Const << constant.i8Const); break;
case EbtUint8: returnValue.setU64Const(u64Const << constant.u8Const); break;
case EbtInt16: returnValue.setU64Const(u64Const << constant.i16Const); break;
case EbtUint16: returnValue.setU64Const(u64Const << constant.u16Const); break;
case EbtInt: returnValue.setU64Const(u64Const << constant.iConst); break;
case EbtUint: returnValue.setU64Const(u64Const << constant.uConst); break;
case EbtInt64: returnValue.setU64Const(u64Const << constant.i64Const); break;
case EbtUint64: returnValue.setU64Const(u64Const << constant.u64Const); break;
default: assert(false && "Default missing");
}
break;
#endif
case EbtInt:
switch (constant.type) {
case EbtInt: returnValue.setIConst(iConst << constant.iConst); break;
case EbtUint: returnValue.setIConst(iConst << constant.uConst); break;
#ifndef GLSLANG_WEB
case EbtInt8: returnValue.setIConst(iConst << constant.i8Const); break;
case EbtUint8: returnValue.setIConst(iConst << constant.u8Const); break;
case EbtInt16: returnValue.setIConst(iConst << constant.i16Const); break;
case EbtUint16: returnValue.setIConst(iConst << constant.u16Const); break;
case EbtInt64: returnValue.setIConst(iConst << constant.i64Const); break;
case EbtUint64: returnValue.setIConst(iConst << constant.u64Const); break;
#endif
default: assert(false && "Default missing");
}
break;
case EbtUint:
switch (constant.type) {
case EbtInt: returnValue.setUConst(uConst << constant.iConst); break;
case EbtUint: returnValue.setUConst(uConst << constant.uConst); break;
#ifndef GLSLANG_WEB
case EbtInt8: returnValue.setUConst(uConst << constant.i8Const); break;
case EbtUint8: returnValue.setUConst(uConst << constant.u8Const); break;
case EbtInt16: returnValue.setUConst(uConst << constant.i16Const); break;
case EbtUint16: returnValue.setUConst(uConst << constant.u16Const); break;
case EbtInt64: returnValue.setUConst(uConst << constant.i64Const); break;
case EbtUint64: returnValue.setUConst(uConst << constant.u64Const); break;
#endif
default: assert(false && "Default missing");
}
break;
default: assert(false && "Default missing");
}
return returnValue;
}
TConstUnion operator&(const TConstUnion& constant) const
{
TConstUnion returnValue;
assert(type == constant.type);
switch (type) {
case EbtInt: returnValue.setIConst(iConst & constant.iConst); break;
case EbtUint: returnValue.setUConst(uConst & constant.uConst); break;
#ifndef GLSLANG_WEB
case EbtInt8: returnValue.setI8Const(i8Const & constant.i8Const); break;
case EbtUint8: returnValue.setU8Const(u8Const & constant.u8Const); break;
case EbtInt16: returnValue.setI16Const(i16Const & constant.i16Const); break;
case EbtUint16: returnValue.setU16Const(u16Const & constant.u16Const); break;
case EbtInt64: returnValue.setI64Const(i64Const & constant.i64Const); break;
case EbtUint64: returnValue.setU64Const(u64Const & constant.u64Const); break;
#endif
default: assert(false && "Default missing");
}
return returnValue;
}
TConstUnion operator|(const TConstUnion& constant) const
{
TConstUnion returnValue;
assert(type == constant.type);
switch (type) {
case EbtInt: returnValue.setIConst(iConst | constant.iConst); break;
case EbtUint: returnValue.setUConst(uConst | constant.uConst); break;
#ifndef GLSLANG_WEB
case EbtInt8: returnValue.setI8Const(i8Const | constant.i8Const); break;
case EbtUint8: returnValue.setU8Const(u8Const | constant.u8Const); break;
case EbtInt16: returnValue.setI16Const(i16Const | constant.i16Const); break;
case EbtUint16: returnValue.setU16Const(u16Const | constant.u16Const); break;
case EbtInt64: returnValue.setI64Const(i64Const | constant.i64Const); break;
case EbtUint64: returnValue.setU64Const(u64Const | constant.u64Const); break;
#endif
default: assert(false && "Default missing");
}
return returnValue;
}
TConstUnion operator^(const TConstUnion& constant) const
{
TConstUnion returnValue;
assert(type == constant.type);
switch (type) {
case EbtInt: returnValue.setIConst(iConst ^ constant.iConst); break;
case EbtUint: returnValue.setUConst(uConst ^ constant.uConst); break;
#ifndef GLSLANG_WEB
case EbtInt8: returnValue.setI8Const(i8Const ^ constant.i8Const); break;
case EbtUint8: returnValue.setU8Const(u8Const ^ constant.u8Const); break;
case EbtInt16: returnValue.setI16Const(i16Const ^ constant.i16Const); break;
case EbtUint16: returnValue.setU16Const(u16Const ^ constant.u16Const); break;
case EbtInt64: returnValue.setI64Const(i64Const ^ constant.i64Const); break;
case EbtUint64: returnValue.setU64Const(u64Const ^ constant.u64Const); break;
#endif
default: assert(false && "Default missing");
}
return returnValue;
}
TConstUnion operator~() const
{
TConstUnion returnValue;
switch (type) {
case EbtInt: returnValue.setIConst(~iConst); break;
case EbtUint: returnValue.setUConst(~uConst); break;
#ifndef GLSLANG_WEB
case EbtInt8: returnValue.setI8Const(~i8Const); break;
case EbtUint8: returnValue.setU8Const(~u8Const); break;
case EbtInt16: returnValue.setI16Const(~i16Const); break;
case EbtUint16: returnValue.setU16Const(~u16Const); break;
case EbtInt64: returnValue.setI64Const(~i64Const); break;
case EbtUint64: returnValue.setU64Const(~u64Const); break;
#endif
default: assert(false && "Default missing");
}
return returnValue;
}
TConstUnion operator&&(const TConstUnion& constant) const
{
TConstUnion returnValue;
assert(type == constant.type);
switch (type) {
case EbtBool: returnValue.setBConst(bConst && constant.bConst); break;
default: assert(false && "Default missing");
}
return returnValue;
}
TConstUnion operator||(const TConstUnion& constant) const
{
TConstUnion returnValue;
assert(type == constant.type);
switch (type) {
case EbtBool: returnValue.setBConst(bConst || constant.bConst); break;
default: assert(false && "Default missing");
}
return returnValue;
}
TBasicType getType() const { return type; }
private:
union {
signed char i8Const; // used for i8vec, scalar int8s
unsigned char u8Const; // used for u8vec, scalar uint8s
signed short i16Const; // used for i16vec, scalar int16s
unsigned short u16Const; // used for u16vec, scalar uint16s
int iConst; // used for ivec, scalar ints
unsigned int uConst; // used for uvec, scalar uints
long long i64Const; // used for i64vec, scalar int64s
unsigned long long u64Const; // used for u64vec, scalar uint64s
bool bConst; // used for bvec, scalar bools
double dConst; // used for vec, dvec, mat, dmat, scalar floats and doubles
const TString* sConst; // string constant
};
TBasicType type;
};
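// Illustrative sketch (not part of the original header): constant folding
// drives TConstUnion through its overloaded operators; the arithmetic
// operators assert that both operands carry the same basic type.
// foldAdd is a hypothetical function.
inline TConstUnion foldAdd()
{
    TConstUnion a, b;
    a.setIConst(40);
    b.setIConst(2);
    return a + b; // an EbtInt constant holding 42
}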
// Encapsulate having a pointer to an array of TConstUnion,
// which only needs to be allocated if its size is going to be
// bigger than 0.
//
// One convenience is being able to use [] to index directly into the
// array, instead of C++ treating the member as a pointer to a vector
// and requiring an extra dereference.
//
// General usage is that the size is known up front, and it is
// created once with the proper size.
//
class TConstUnionArray {
public:
POOL_ALLOCATOR_NEW_DELETE(GetThreadPoolAllocator())
TConstUnionArray() : unionArray(nullptr) { }
virtual ~TConstUnionArray() { }
explicit TConstUnionArray(int size)
{
if (size == 0)
unionArray = nullptr;
else
unionArray = new TConstUnionVector(size);
}
TConstUnionArray(const TConstUnionArray& a) = default;
TConstUnionArray(const TConstUnionArray& a, int start, int size)
{
unionArray = new TConstUnionVector(size);
for (int i = 0; i < size; ++i)
(*unionArray)[i] = a[start + i];
}
// Use this constructor for a smear operation
TConstUnionArray(int size, const TConstUnion& val)
{
unionArray = new TConstUnionVector(size, val);
}
int size() const { return unionArray ? (int)unionArray->size() : 0; }
TConstUnion& operator[](size_t index) { return (*unionArray)[index]; }
const TConstUnion& operator[](size_t index) const { return (*unionArray)[index]; }
bool operator==(const TConstUnionArray& rhs) const
{
// this includes the case that both are unallocated
if (unionArray == rhs.unionArray)
return true;
if (! unionArray || ! rhs.unionArray)
return false;
return *unionArray == *rhs.unionArray;
}
bool operator!=(const TConstUnionArray& rhs) const { return ! operator==(rhs); }
double dot(const TConstUnionArray& rhs)
{
assert(rhs.unionArray->size() == unionArray->size());
double sum = 0.0;
for (size_t comp = 0; comp < unionArray->size(); ++comp)
sum += (*this)[comp].getDConst() * rhs[comp].getDConst();
return sum;
}
bool empty() const { return unionArray == nullptr; }
protected:
typedef TVector<TConstUnion> TConstUnionVector;
TConstUnionVector* unionArray;
};
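// Illustrative sketch (not part of the original header): the "smear"
// constructor replicates one scalar across all components, and dot()
// folds a component-wise product, assuming double-typed components.
// dotExample is a hypothetical function.
inline double dotExample()
{
    TConstUnion half;
    half.setDConst(0.5);
    TConstUnionArray v(4, half); // four components, each 0.5
    return v.dot(v);             // 4 * (0.5 * 0.5) = 1.0
}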
} // end namespace glslang
#endif // _CONSTANT_UNION_INCLUDED_

View file

@ -1,318 +0,0 @@
//
// Copyright (C) 2002-2005 3Dlabs Inc. Ltd.
// Copyright (C) 2012-2013 LunarG, Inc.
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following
// disclaimer in the documentation and/or other materials provided
// with the distribution.
//
// Neither the name of 3Dlabs Inc. Ltd. nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
#ifndef _POOLALLOC_INCLUDED_
#define _POOLALLOC_INCLUDED_
#ifndef NDEBUG
# define GUARD_BLOCKS // define to enable guard block sanity checking
#endif
//
// This header defines an allocator that can be used to efficiently
// allocate a large number of small requests for heap memory, with the
// intention that they are not individually deallocated, but rather
// collectively deallocated at one time.
//
// This simultaneously
//
// * Makes each individual allocation much more efficient; the
// typical allocation is trivial.
// * Completely avoids the cost of doing individual deallocation.
// * Saves the trouble of tracking down and plugging a large class of leaks.
//
// Individual classes can use this allocator by supplying their own
// new and delete methods.
//
// STL containers can use this allocator by using the pool_allocator
// class as the allocator (second) template argument.
//
#include <cstddef>
#include <cstring>
#include <vector>
namespace glslang {
// If we are using guard blocks, we must track each individual
// allocation. If we aren't using guard blocks, these
// never get instantiated, so won't have any impact.
//
class TAllocation {
public:
TAllocation(size_t size, unsigned char* mem, TAllocation* prev = nullptr) :
size(size), mem(mem), prevAlloc(prev) {
// Allocations are bracketed:
// [allocationHeader][initialGuardBlock][userData][finalGuardBlock]
// This would be cleaner with if (guardBlockSize)..., but that
// makes the compiler print warnings about 0 length memsets,
// even with the if() protecting them.
# ifdef GUARD_BLOCKS
memset(preGuard(), guardBlockBeginVal, guardBlockSize);
memset(data(), userDataFill, size);
memset(postGuard(), guardBlockEndVal, guardBlockSize);
# endif
}
void check() const {
checkGuardBlock(preGuard(), guardBlockBeginVal, "before");
checkGuardBlock(postGuard(), guardBlockEndVal, "after");
}
void checkAllocList() const;
// Return total size needed to accommodate user buffer of 'size',
// plus our tracking data.
inline static size_t allocationSize(size_t size) {
return size + 2 * guardBlockSize + headerSize();
}
// Offset from surrounding buffer to get to user data buffer.
inline static unsigned char* offsetAllocation(unsigned char* m) {
return m + guardBlockSize + headerSize();
}
private:
void checkGuardBlock(unsigned char* blockMem, unsigned char val, const char* locText) const;
// Find offsets to pre and post guard blocks, and user data buffer
unsigned char* preGuard() const { return mem + headerSize(); }
unsigned char* data() const { return preGuard() + guardBlockSize; }
unsigned char* postGuard() const { return data() + size; }
size_t size; // size of the user data area
unsigned char* mem; // beginning of our allocation (pts to header)
TAllocation* prevAlloc; // prior allocation in the chain
const static unsigned char guardBlockBeginVal;
const static unsigned char guardBlockEndVal;
const static unsigned char userDataFill;
const static size_t guardBlockSize;
# ifdef GUARD_BLOCKS
inline static size_t headerSize() { return sizeof(TAllocation); }
# else
inline static size_t headerSize() { return 0; }
# endif
};
//
// There are several stacks. One is to track the user's pushing and
// popping, and is not yet implemented. The others are simply
// repositories of free pages or used pages.
//
// Page stacks are linked together with a simple header at the beginning
// of each allocation obtained from the underlying OS. Multi-page allocations
// are returned to the OS. Individual page allocations are kept for future
// re-use.
//
// The "page size" used is not, nor must it match, the underlying OS
// page size. But, having it be about that size or equal to a set of
// pages is likely most optimal.
//
class TPoolAllocator {
public:
TPoolAllocator(int growthIncrement = 8*1024, int allocationAlignment = 16);
//
// Don't call the destructor just to free up the memory, call pop()
//
~TPoolAllocator();
//
// Call push() to establish a new place to pop memory to. Does not
// have to be called to get things started.
//
void push();
//
// Call pop() to free all memory allocated since the last call to push(),
// or, if push() was never called, frees all memory since the first allocation.
//
void pop();
//
// Call popAll() to free all memory allocated.
//
void popAll();
//
// Call allocate() to actually acquire memory. Returns nullptr if no memory
// available, otherwise a properly aligned pointer to 'numBytes' of memory.
//
void* allocate(size_t numBytes);
//
// There is no deallocate. The point of this class is that
// deallocation can be skipped by the user of it, as the model
// of use is to simultaneously deallocate everything at once
// by calling pop(), and to not have to solve memory leak problems.
//
protected:
friend struct tHeader;
struct tHeader {
tHeader(tHeader* nextPage, size_t pageCount) :
#ifdef GUARD_BLOCKS
lastAllocation(nullptr),
#endif
nextPage(nextPage), pageCount(pageCount) { }
~tHeader() {
#ifdef GUARD_BLOCKS
if (lastAllocation)
lastAllocation->checkAllocList();
#endif
}
#ifdef GUARD_BLOCKS
TAllocation* lastAllocation;
#endif
tHeader* nextPage;
size_t pageCount;
};
struct tAllocState {
size_t offset;
tHeader* page;
};
typedef std::vector<tAllocState> tAllocStack;
// Track allocations if and only if we're using guard blocks
#ifndef GUARD_BLOCKS
void* initializeAllocation(tHeader*, unsigned char* memory, size_t) {
#else
void* initializeAllocation(tHeader* block, unsigned char* memory, size_t numBytes) {
new(memory) TAllocation(numBytes, memory, block->lastAllocation);
block->lastAllocation = reinterpret_cast<TAllocation*>(memory);
#endif
// This is optimized entirely away if GUARD_BLOCKS is not defined.
return TAllocation::offsetAllocation(memory);
}
size_t pageSize; // granularity of allocation from the OS
size_t alignment; // all returned allocations will be aligned at
// this granularity, which will be a power of 2
size_t alignmentMask;
size_t headerSkip; // amount of memory to skip to make room for the
// header (basically, size of header, rounded
// up to make it aligned)
size_t currentPageOffset; // next offset in top of inUseList to allocate from
tHeader* freeList; // list of popped memory
tHeader* inUseList; // list of all memory currently being used
tAllocStack stack; // stack of where to allocate from, to partition pool
int numCalls; // just an interesting statistic
size_t totalBytes; // just an interesting statistic
private:
TPoolAllocator& operator=(const TPoolAllocator&); // don't allow assignment operator
TPoolAllocator(const TPoolAllocator&); // don't allow default copy constructor
};
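// Illustrative sketch (not part of the original header): the intended
// lifecycle is bulk allocation between push()/pop() pairs; nothing is
// freed individually. poolLifecycleExample is a hypothetical function.
inline void poolLifecycleExample(TPoolAllocator& pool)
{
    pool.push();
    void* a = pool.allocate(64);  // fast pointer-bump allocation
    void* b = pool.allocate(128);
    (void)a; (void)b;
    pool.pop();                   // releases everything allocated since push()
}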
//
// There could potentially be many pools with pops happening at
// different times. But a simple use is to have a global pop
// with everyone using the same global allocator.
//
extern TPoolAllocator& GetThreadPoolAllocator();
void SetThreadPoolAllocator(TPoolAllocator* poolAllocator);
//
// This STL compatible allocator is intended to be used as the allocator
// parameter to templatized STL containers, like vector and map.
//
// It will use the pools for allocation, and not
// do any deallocation, but will still do destruction.
//
template<class T>
class pool_allocator {
public:
typedef size_t size_type;
typedef ptrdiff_t difference_type;
typedef T *pointer;
typedef const T *const_pointer;
typedef T& reference;
typedef const T& const_reference;
typedef T value_type;
template<class Other>
struct rebind {
typedef pool_allocator<Other> other;
};
pointer address(reference x) const { return &x; }
const_pointer address(const_reference x) const { return &x; }
pool_allocator() : allocator(GetThreadPoolAllocator()) { }
pool_allocator(TPoolAllocator& a) : allocator(a) { }
pool_allocator(const pool_allocator<T>& p) : allocator(p.allocator) { }
template<class Other>
pool_allocator(const pool_allocator<Other>& p) : allocator(p.getAllocator()) { }
pointer allocate(size_type n) {
return reinterpret_cast<pointer>(getAllocator().allocate(n * sizeof(T))); }
pointer allocate(size_type n, const void*) {
return reinterpret_cast<pointer>(getAllocator().allocate(n * sizeof(T))); }
void deallocate(void*, size_type) { }
void deallocate(pointer, size_type) { }
pointer _Charalloc(size_t n) {
return reinterpret_cast<pointer>(getAllocator().allocate(n)); }
void construct(pointer p, const T& val) { new ((void *)p) T(val); }
void destroy(pointer p) { p->T::~T(); }
bool operator==(const pool_allocator& rhs) const { return &getAllocator() == &rhs.getAllocator(); }
bool operator!=(const pool_allocator& rhs) const { return &getAllocator() != &rhs.getAllocator(); }
size_type max_size() const { return static_cast<size_type>(-1) / sizeof(T); }
size_type max_size(int size) const { return static_cast<size_type>(-1) / size; }
TPoolAllocator& getAllocator() const { return allocator; }
pool_allocator select_on_container_copy_construction() const { return pool_allocator{}; }
protected:
pool_allocator& operator=(const pool_allocator&) { return *this; }
TPoolAllocator& allocator;
};
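// Illustrative sketch (not part of the original header): pool_allocator
// plugs into any STL container as its allocator template argument; the
// container's storage then follows the pool's push/pop lifetime.
// stlPoolExample is a hypothetical function.
inline void stlPoolExample()
{
    std::vector<float, pool_allocator<float>> weights;
    weights.assign(8, 1.0f); // backing storage comes from the thread pool
}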
} // end namespace glslang
#endif // _POOLALLOC_INCLUDED_

View file

@ -1,128 +0,0 @@
//
// Copyright(C) 2021 Advanced Micro Devices, Inc.
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following
// disclaimer in the documentation and/or other materials provided
// with the distribution.
//
// Neither the name of 3Dlabs Inc. Ltd. nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
#pragma once
#ifndef GLSLANG_WEB
//
// GL_EXT_spirv_intrinsics
//
#include "Common.h"
namespace glslang {
class TIntermTyped;
class TIntermConstantUnion;
class TType;
// SPIR-V requirements
struct TSpirvRequirement {
POOL_ALLOCATOR_NEW_DELETE(GetThreadPoolAllocator())
// extension = [..]
TSet<TString> extensions;
// capability = [..]
TSet<int> capabilities;
};
// SPIR-V execution modes
struct TSpirvExecutionMode {
POOL_ALLOCATOR_NEW_DELETE(GetThreadPoolAllocator())
// spirv_execution_mode
TMap<int, TVector<const TIntermConstantUnion*>> modes;
// spirv_execution_mode_id
TMap<int, TVector<const TIntermTyped*> > modeIds;
};
// SPIR-V decorations
struct TSpirvDecorate {
POOL_ALLOCATOR_NEW_DELETE(GetThreadPoolAllocator())
// spirv_decorate
TMap<int, TVector<const TIntermConstantUnion*> > decorates;
// spirv_decorate_id
TMap<int, TVector<const TIntermTyped*>> decorateIds;
// spirv_decorate_string
TMap<int, TVector<const TIntermConstantUnion*> > decorateStrings;
};
// SPIR-V instruction
struct TSpirvInstruction {
POOL_ALLOCATOR_NEW_DELETE(GetThreadPoolAllocator())
TSpirvInstruction() { set = ""; id = -1; }
bool operator==(const TSpirvInstruction& rhs) const { return set == rhs.set && id == rhs.id; }
bool operator!=(const TSpirvInstruction& rhs) const { return !operator==(rhs); }
// spirv_instruction
TString set;
int id;
};
// SPIR-V type parameter
struct TSpirvTypeParameter {
POOL_ALLOCATOR_NEW_DELETE(GetThreadPoolAllocator())
TSpirvTypeParameter(const TIntermConstantUnion* arg) { constant = arg; }
bool operator==(const TSpirvTypeParameter& rhs) const { return constant == rhs.constant; }
bool operator!=(const TSpirvTypeParameter& rhs) const { return !operator==(rhs); }
const TIntermConstantUnion* constant;
};
typedef TVector<TSpirvTypeParameter> TSpirvTypeParameters;
// SPIR-V type
struct TSpirvType {
POOL_ALLOCATOR_NEW_DELETE(GetThreadPoolAllocator())
bool operator==(const TSpirvType& rhs) const
{
return spirvInst == rhs.spirvInst && typeParams == rhs.typeParams;
}
bool operator!=(const TSpirvType& rhs) const { return !operator==(rhs); }
// spirv_type
TSpirvInstruction spirvInst;
TSpirvTypeParameters typeParams;
};
} // end namespace glslang
#endif // GLSLANG_WEB

File diff suppressed because it is too large

View file

@ -1,352 +0,0 @@
//
// Copyright (C) 2002-2005 3Dlabs Inc. Ltd.
// Copyright (C) 2012-2013 LunarG, Inc.
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following
// disclaimer in the documentation and/or other materials provided
// with the distribution.
//
// Neither the name of 3Dlabs Inc. Ltd. nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
//
// Implement types for tracking GLSL arrays, arrays of arrays, etc.
//
#ifndef _ARRAYS_INCLUDED
#define _ARRAYS_INCLUDED
#include <algorithm>
namespace glslang {
// This is used to mean there is no size yet (unsized); it is waiting to get a size from somewhere else.
const int UnsizedArraySize = 0;
class TIntermTyped;
extern bool SameSpecializationConstants(TIntermTyped*, TIntermTyped*);
// Specialization constants need both a nominal size and a node that defines
// the specialization constant being used. Array types are the same when their
// size and specialization constant nodes are the same.
struct TArraySize {
unsigned int size;
TIntermTyped* node; // nullptr means no specialization constant node
bool operator==(const TArraySize& rhs) const
{
if (size != rhs.size)
return false;
if (node == nullptr || rhs.node == nullptr)
return node == rhs.node;
return SameSpecializationConstants(node, rhs.node);
}
};
//
// TSmallArrayVector is used as the container for the set of sizes in TArraySizes.
// It has generic-container semantics, while TArraySizes has array-of-array semantics.
// That is, TSmallArrayVector should be more focused on mechanism and TArraySizes on policy.
//
struct TSmallArrayVector {
//
// TODO: memory: TSmallArrayVector is intended to be smaller.
// Almost all arrays could be handled by two sizes each fitting
// in 16 bits, needing a real vector only in the cases where there
// are more than 3 sizes or a size needing more than 16 bits.
//
POOL_ALLOCATOR_NEW_DELETE(GetThreadPoolAllocator())
TSmallArrayVector() : sizes(nullptr) { }
virtual ~TSmallArrayVector() { dealloc(); }
// For breaking into two non-shared copies, independently modifiable.
TSmallArrayVector& operator=(const TSmallArrayVector& from)
{
if (from.sizes == nullptr)
sizes = nullptr;
else {
alloc();
*sizes = *from.sizes;
}
return *this;
}
int size() const
{
if (sizes == nullptr)
return 0;
return (int)sizes->size();
}
unsigned int frontSize() const
{
assert(sizes != nullptr && sizes->size() > 0);
return sizes->front().size;
}
TIntermTyped* frontNode() const
{
assert(sizes != nullptr && sizes->size() > 0);
return sizes->front().node;
}
void changeFront(unsigned int s)
{
assert(sizes != nullptr);
// this should only happen for implicitly sized arrays, not specialization constants
assert(sizes->front().node == nullptr);
sizes->front().size = s;
}
void push_back(unsigned int e, TIntermTyped* n)
{
alloc();
TArraySize pair = { e, n };
sizes->push_back(pair);
}
void push_back(const TSmallArrayVector& newDims)
{
alloc();
sizes->insert(sizes->end(), newDims.sizes->begin(), newDims.sizes->end());
}
void pop_front()
{
assert(sizes != nullptr && sizes->size() > 0);
if (sizes->size() == 1)
dealloc();
else
sizes->erase(sizes->begin());
}
// 'this' should currently not be holding anything, and copyNonFront
// will make it hold a copy of all but the first element of rhs.
// (This would be useful for making a type that is dereferenced by
// one dimension.)
void copyNonFront(const TSmallArrayVector& rhs)
{
assert(sizes == nullptr);
if (rhs.size() > 1) {
alloc();
sizes->insert(sizes->begin(), rhs.sizes->begin() + 1, rhs.sizes->end());
}
}
unsigned int getDimSize(int i) const
{
assert(sizes != nullptr && (int)sizes->size() > i);
return (*sizes)[i].size;
}
void setDimSize(int i, unsigned int size) const
{
assert(sizes != nullptr && (int)sizes->size() > i);
assert((*sizes)[i].node == nullptr);
(*sizes)[i].size = size;
}
TIntermTyped* getDimNode(int i) const
{
assert(sizes != nullptr && (int)sizes->size() > i);
return (*sizes)[i].node;
}
bool operator==(const TSmallArrayVector& rhs) const
{
if (sizes == nullptr && rhs.sizes == nullptr)
return true;
if (sizes == nullptr || rhs.sizes == nullptr)
return false;
return *sizes == *rhs.sizes;
}
bool operator!=(const TSmallArrayVector& rhs) const { return ! operator==(rhs); }
protected:
TSmallArrayVector(const TSmallArrayVector&);
void alloc()
{
if (sizes == nullptr)
sizes = new TVector<TArraySize>;
}
void dealloc()
{
delete sizes;
sizes = nullptr;
}
TVector<TArraySize>* sizes; // will either hold such a pointer, or in the future, hold the two array sizes
};
//
// Represent an array, or array of arrays, to arbitrary depth. This is not
// done through a hierarchy of types in a type tree; rather, all contiguous arrayness
// in the type hierarchy is localized into this single cumulative object.
//
// The arrayness in TType is a pointer, so that it can be non-allocated and zero
// for the vast majority of types that are non-array types.
//
// Order Policy: these are all identical:
// - left to right order within a contiguous set of ...[..][..][..]... in the source language
// - index order 0, 1, 2, ... within the 'sizes' member below
// - outer-most to inner-most
//
struct TArraySizes {
POOL_ALLOCATOR_NEW_DELETE(GetThreadPoolAllocator())
TArraySizes() : implicitArraySize(0), implicitlySized(true), variablyIndexed(false) { }
// For breaking into two non-shared copies, independently modifiable.
TArraySizes& operator=(const TArraySizes& from)
{
implicitArraySize = from.implicitArraySize;
variablyIndexed = from.variablyIndexed;
sizes = from.sizes;
implicitlySized = from.implicitlySized;
return *this;
}
// translate from array-of-array semantics to container semantics
int getNumDims() const { return sizes.size(); }
int getDimSize(int dim) const { return sizes.getDimSize(dim); }
TIntermTyped* getDimNode(int dim) const { return sizes.getDimNode(dim); }
void setDimSize(int dim, int size) { sizes.setDimSize(dim, size); }
int getOuterSize() const { return sizes.frontSize(); }
TIntermTyped* getOuterNode() const { return sizes.frontNode(); }
int getCumulativeSize() const
{
int size = 1;
for (int d = 0; d < sizes.size(); ++d) {
// this only makes sense in paths that have a known array size
assert(sizes.getDimSize(d) != UnsizedArraySize);
size *= sizes.getDimSize(d);
}
return size;
}
void addInnerSize() { addInnerSize((unsigned)UnsizedArraySize); }
void addInnerSize(int s) { addInnerSize((unsigned)s, nullptr); }
void addInnerSize(int s, TIntermTyped* n) { sizes.push_back((unsigned)s, n); }
void addInnerSize(TArraySize pair) {
sizes.push_back(pair.size, pair.node);
implicitlySized = false;
}
void addInnerSizes(const TArraySizes& s) { sizes.push_back(s.sizes); }
void changeOuterSize(int s) {
sizes.changeFront((unsigned)s);
implicitlySized = false;
}
int getImplicitSize() const { return implicitArraySize > 0 ? implicitArraySize : 1; }
void updateImplicitSize(int s) {
implicitArraySize = (std::max)(implicitArraySize, s);
}
bool isInnerUnsized() const
{
for (int d = 1; d < sizes.size(); ++d) {
if (sizes.getDimSize(d) == (unsigned)UnsizedArraySize)
return true;
}
return false;
}
bool clearInnerUnsized()
{
for (int d = 1; d < sizes.size(); ++d) {
if (sizes.getDimSize(d) == (unsigned)UnsizedArraySize)
setDimSize(d, 1);
}
return false;
}
bool isInnerSpecialization() const
{
for (int d = 1; d < sizes.size(); ++d) {
if (sizes.getDimNode(d) != nullptr)
return true;
}
return false;
}
bool isOuterSpecialization()
{
return sizes.getDimNode(0) != nullptr;
}
bool hasUnsized() const { return getOuterSize() == UnsizedArraySize || isInnerUnsized(); }
bool isSized() const { return getOuterSize() != UnsizedArraySize; }
bool isImplicitlySized() const { return implicitlySized; }
bool isDefaultImplicitlySized() const { return implicitlySized && implicitArraySize == 0; }
void setImplicitlySized(bool isImplicitSizing) { implicitlySized = isImplicitSizing; }
void dereference() { sizes.pop_front(); }
void copyDereferenced(const TArraySizes& rhs)
{
assert(sizes.size() == 0);
if (rhs.sizes.size() > 1)
sizes.copyNonFront(rhs.sizes);
}
bool sameInnerArrayness(const TArraySizes& rhs) const
{
if (sizes.size() != rhs.sizes.size())
return false;
for (int d = 1; d < sizes.size(); ++d) {
if (sizes.getDimSize(d) != rhs.sizes.getDimSize(d) ||
sizes.getDimNode(d) != rhs.sizes.getDimNode(d))
return false;
}
return true;
}
void setVariablyIndexed() { variablyIndexed = true; }
bool isVariablyIndexed() const { return variablyIndexed; }
bool operator==(const TArraySizes& rhs) const { return sizes == rhs.sizes; }
bool operator!=(const TArraySizes& rhs) const { return sizes != rhs.sizes; }
protected:
TSmallArrayVector sizes;
TArraySizes(const TArraySizes&);
// For tracking maximum referenced compile-time constant index.
// Applies only to the outer-most dimension. Potentially becomes
// the implicit size of the array, if not variably indexed and
// otherwise legal.
int implicitArraySize;
bool implicitlySized;
bool variablyIndexed; // true if array is indexed with a non compile-time constant
};
} // end namespace glslang
#endif // _ARRAYS_INCLUDED
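As a worked illustration of the order policy described above (a sketch against the methods in this header, not code from glslang itself): a source declaration like float a[2][3][4] stores its sizes outer-most first.

// Hypothetical use of TArraySizes for float a[2][3][4]:
glslang::TArraySizes sizes;
sizes.addInnerSize(2);                // outer-most dimension
sizes.addInnerSize(3);
sizes.addInnerSize(4);                // inner-most dimension
// sizes.getNumDims()        == 3
// sizes.getDimSize(0)       == 2    (index 0 is outer-most)
// sizes.getOuterSize()      == 2
// sizes.getCumulativeSize() == 2 * 3 * 4 == 24
// sizes.dereference();              // drops the outer dimension: a[i] is float[3][4]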

View file

@ -0,0 +1,297 @@
/**
This code is based on the glslang_c_interface implementation by Viktor Latypov
**/
/**
BSD 2-Clause License
Copyright (c) 2019, Viktor Latypov
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
**/
#ifndef GLSLANG_C_IFACE_H_INCLUDED
#define GLSLANG_C_IFACE_H_INCLUDED
#include <stdbool.h>
#include <stdlib.h>
#include "glslang_c_shader_types.h"
#include "visibility.h"
typedef struct glslang_shader_s glslang_shader_t;
typedef struct glslang_program_s glslang_program_t;
typedef struct glslang_mapper_s glslang_mapper_t;
typedef struct glslang_resolver_s glslang_resolver_t;
/* Version counterpart */
typedef struct glslang_version_s {
int major;
int minor;
int patch;
const char* flavor;
} glslang_version_t;
/* TLimits counterpart */
typedef struct glslang_limits_s {
bool non_inductive_for_loops;
bool while_loops;
bool do_while_loops;
bool general_uniform_indexing;
bool general_attribute_matrix_vector_indexing;
bool general_varying_indexing;
bool general_sampler_indexing;
bool general_variable_indexing;
bool general_constant_matrix_vector_indexing;
} glslang_limits_t;
/* TBuiltInResource counterpart */
typedef struct glslang_resource_s {
int max_lights;
int max_clip_planes;
int max_texture_units;
int max_texture_coords;
int max_vertex_attribs;
int max_vertex_uniform_components;
int max_varying_floats;
int max_vertex_texture_image_units;
int max_combined_texture_image_units;
int max_texture_image_units;
int max_fragment_uniform_components;
int max_draw_buffers;
int max_vertex_uniform_vectors;
int max_varying_vectors;
int max_fragment_uniform_vectors;
int max_vertex_output_vectors;
int max_fragment_input_vectors;
int min_program_texel_offset;
int max_program_texel_offset;
int max_clip_distances;
int max_compute_work_group_count_x;
int max_compute_work_group_count_y;
int max_compute_work_group_count_z;
int max_compute_work_group_size_x;
int max_compute_work_group_size_y;
int max_compute_work_group_size_z;
int max_compute_uniform_components;
int max_compute_texture_image_units;
int max_compute_image_uniforms;
int max_compute_atomic_counters;
int max_compute_atomic_counter_buffers;
int max_varying_components;
int max_vertex_output_components;
int max_geometry_input_components;
int max_geometry_output_components;
int max_fragment_input_components;
int max_image_units;
int max_combined_image_units_and_fragment_outputs;
int max_combined_shader_output_resources;
int max_image_samples;
int max_vertex_image_uniforms;
int max_tess_control_image_uniforms;
int max_tess_evaluation_image_uniforms;
int max_geometry_image_uniforms;
int max_fragment_image_uniforms;
int max_combined_image_uniforms;
int max_geometry_texture_image_units;
int max_geometry_output_vertices;
int max_geometry_total_output_components;
int max_geometry_uniform_components;
int max_geometry_varying_components;
int max_tess_control_input_components;
int max_tess_control_output_components;
int max_tess_control_texture_image_units;
int max_tess_control_uniform_components;
int max_tess_control_total_output_components;
int max_tess_evaluation_input_components;
int max_tess_evaluation_output_components;
int max_tess_evaluation_texture_image_units;
int max_tess_evaluation_uniform_components;
int max_tess_patch_components;
int max_patch_vertices;
int max_tess_gen_level;
int max_viewports;
int max_vertex_atomic_counters;
int max_tess_control_atomic_counters;
int max_tess_evaluation_atomic_counters;
int max_geometry_atomic_counters;
int max_fragment_atomic_counters;
int max_combined_atomic_counters;
int max_atomic_counter_bindings;
int max_vertex_atomic_counter_buffers;
int max_tess_control_atomic_counter_buffers;
int max_tess_evaluation_atomic_counter_buffers;
int max_geometry_atomic_counter_buffers;
int max_fragment_atomic_counter_buffers;
int max_combined_atomic_counter_buffers;
int max_atomic_counter_buffer_size;
int max_transform_feedback_buffers;
int max_transform_feedback_interleaved_components;
int max_cull_distances;
int max_combined_clip_and_cull_distances;
int max_samples;
int max_mesh_output_vertices_nv;
int max_mesh_output_primitives_nv;
int max_mesh_work_group_size_x_nv;
int max_mesh_work_group_size_y_nv;
int max_mesh_work_group_size_z_nv;
int max_task_work_group_size_x_nv;
int max_task_work_group_size_y_nv;
int max_task_work_group_size_z_nv;
int max_mesh_view_count_nv;
int max_mesh_output_vertices_ext;
int max_mesh_output_primitives_ext;
int max_mesh_work_group_size_x_ext;
int max_mesh_work_group_size_y_ext;
int max_mesh_work_group_size_z_ext;
int max_task_work_group_size_x_ext;
int max_task_work_group_size_y_ext;
int max_task_work_group_size_z_ext;
int max_mesh_view_count_ext;
union
{
int max_dual_source_draw_buffers_ext;
/* Incorrectly capitalized name retained for backward compatibility */
int maxDualSourceDrawBuffersEXT;
};
glslang_limits_t limits;
} glslang_resource_t;
/* Inclusion result structure allocated by C include_local/include_system callbacks */
typedef struct glsl_include_result_s {
/* Header file name or NULL if inclusion failed */
const char* header_name;
/* Header contents or NULL */
const char* header_data;
size_t header_length;
} glsl_include_result_t;
/* Callback for local file inclusion */
typedef glsl_include_result_t* (*glsl_include_local_func)(void* ctx, const char* header_name, const char* includer_name,
size_t include_depth);
/* Callback for system file inclusion */
typedef glsl_include_result_t* (*glsl_include_system_func)(void* ctx, const char* header_name,
const char* includer_name, size_t include_depth);
/* Callback for include result destruction */
typedef int (*glsl_free_include_result_func)(void* ctx, glsl_include_result_t* result);
/* Collection of callbacks for GLSL preprocessor */
typedef struct glsl_include_callbacks_s {
glsl_include_system_func include_system;
glsl_include_local_func include_local;
glsl_free_include_result_func free_include_result;
} glsl_include_callbacks_t;
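A sketch of what client-side include callbacks might look like; only the signatures and glsl_include_result_t come from this header, while load_file_contents is a hypothetical stand-in for the client's own file I/O.

/* Hypothetical local-include callback; uses <stdlib.h>, included above. */
static glsl_include_result_t* my_include_local(void* ctx, const char* header_name,
                                               const char* includer_name, size_t include_depth) {
    glsl_include_result_t* result = (glsl_include_result_t*)malloc(sizeof(*result));
    result->header_name = header_name;
    result->header_data = load_file_contents(header_name, &result->header_length); /* hypothetical */
    return result; /* released later through the free_include_result callback */
}

static int my_free_include_result(void* ctx, glsl_include_result_t* result) {
    /* The client also owns header_data and should release it here. */
    free(result);
    return 0;
}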
typedef struct glslang_input_s {
glslang_source_t language;
glslang_stage_t stage;
glslang_client_t client;
glslang_target_client_version_t client_version;
glslang_target_language_t target_language;
glslang_target_language_version_t target_language_version;
/** Shader source code */
const char* code;
int default_version;
glslang_profile_t default_profile;
int force_default_version_and_profile;
int forward_compatible;
glslang_messages_t messages;
const glslang_resource_t* resource;
glsl_include_callbacks_t callbacks;
void* callbacks_ctx;
} glslang_input_t;
/* SpvOptions counterpart */
typedef struct glslang_spv_options_s {
bool generate_debug_info;
bool strip_debug_info;
bool disable_optimizer;
bool optimize_size;
bool disassemble;
bool validate;
bool emit_nonsemantic_shader_debug_info;
bool emit_nonsemantic_shader_debug_source;
bool compile_only;
bool optimize_allow_expanded_id_bound;
} glslang_spv_options_t;
#ifdef __cplusplus
extern "C" {
#endif
GLSLANG_EXPORT void glslang_get_version(glslang_version_t* version);
GLSLANG_EXPORT int glslang_initialize_process(void);
GLSLANG_EXPORT void glslang_finalize_process(void);
GLSLANG_EXPORT glslang_shader_t* glslang_shader_create(const glslang_input_t* input);
GLSLANG_EXPORT void glslang_shader_delete(glslang_shader_t* shader);
GLSLANG_EXPORT void glslang_shader_set_preamble(glslang_shader_t* shader, const char* s);
GLSLANG_EXPORT void glslang_shader_shift_binding(glslang_shader_t* shader, glslang_resource_type_t res, unsigned int base);
GLSLANG_EXPORT void glslang_shader_shift_binding_for_set(glslang_shader_t* shader, glslang_resource_type_t res, unsigned int base, unsigned int set);
GLSLANG_EXPORT void glslang_shader_set_options(glslang_shader_t* shader, int options); // glslang_shader_options_t
GLSLANG_EXPORT void glslang_shader_set_glsl_version(glslang_shader_t* shader, int version);
GLSLANG_EXPORT void glslang_shader_set_default_uniform_block_set_and_binding(glslang_shader_t* shader, unsigned int set, unsigned int binding);
GLSLANG_EXPORT void glslang_shader_set_default_uniform_block_name(glslang_shader_t* shader, const char *name);
GLSLANG_EXPORT void glslang_shader_set_resource_set_binding(glslang_shader_t* shader, const char *const *bindings, unsigned int num_bindings);
GLSLANG_EXPORT int glslang_shader_preprocess(glslang_shader_t* shader, const glslang_input_t* input);
GLSLANG_EXPORT int glslang_shader_parse(glslang_shader_t* shader, const glslang_input_t* input);
GLSLANG_EXPORT const char* glslang_shader_get_preprocessed_code(glslang_shader_t* shader);
GLSLANG_EXPORT void glslang_shader_set_preprocessed_code(glslang_shader_t* shader, const char* code);
GLSLANG_EXPORT const char* glslang_shader_get_info_log(glslang_shader_t* shader);
GLSLANG_EXPORT const char* glslang_shader_get_info_debug_log(glslang_shader_t* shader);
GLSLANG_EXPORT glslang_program_t* glslang_program_create(void);
GLSLANG_EXPORT void glslang_program_delete(glslang_program_t* program);
GLSLANG_EXPORT void glslang_program_add_shader(glslang_program_t* program, glslang_shader_t* shader);
GLSLANG_EXPORT int glslang_program_link(glslang_program_t* program, int messages); // glslang_messages_t
GLSLANG_EXPORT void glslang_program_add_source_text(glslang_program_t* program, glslang_stage_t stage, const char* text, size_t len);
GLSLANG_EXPORT void glslang_program_set_source_file(glslang_program_t* program, glslang_stage_t stage, const char* file);
GLSLANG_EXPORT int glslang_program_map_io(glslang_program_t* program);
GLSLANG_EXPORT int glslang_program_map_io_with_resolver_and_mapper(glslang_program_t* program, glslang_resolver_t* resolver, glslang_mapper_t* mapper);
GLSLANG_EXPORT void glslang_program_SPIRV_generate(glslang_program_t* program, glslang_stage_t stage);
GLSLANG_EXPORT void glslang_program_SPIRV_generate_with_options(glslang_program_t* program, glslang_stage_t stage, glslang_spv_options_t* spv_options);
GLSLANG_EXPORT size_t glslang_program_SPIRV_get_size(glslang_program_t* program);
GLSLANG_EXPORT void glslang_program_SPIRV_get(glslang_program_t* program, unsigned int*);
GLSLANG_EXPORT unsigned int* glslang_program_SPIRV_get_ptr(glslang_program_t* program);
GLSLANG_EXPORT const char* glslang_program_SPIRV_get_messages(glslang_program_t* program);
GLSLANG_EXPORT const char* glslang_program_get_info_log(glslang_program_t* program);
GLSLANG_EXPORT const char* glslang_program_get_info_debug_log(glslang_program_t* program);
GLSLANG_EXPORT glslang_mapper_t* glslang_glsl_mapper_create();
GLSLANG_EXPORT void glslang_glsl_mapper_delete(glslang_mapper_t* mapper);
GLSLANG_EXPORT glslang_resolver_t* glslang_glsl_resolver_create(glslang_program_t* program, glslang_stage_t stage);
GLSLANG_EXPORT void glslang_glsl_resolver_delete(glslang_resolver_t* resolver);
#ifdef __cplusplus
}
#endif
#endif /* GLSLANG_C_IFACE_H_INCLUDED */
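Putting the declarations above together, a minimal GLSL-to-SPIR-V compile through this C interface might look like the sketch below (error reporting trimmed; vertex_source is a placeholder GLSL string, and glslang_default_resource() is declared in resource_limits_c.h later in this diff).

glslang_initialize_process();

glslang_input_t input = {};
input.language                = GLSLANG_SOURCE_GLSL;
input.stage                   = GLSLANG_STAGE_VERTEX;
input.client                  = GLSLANG_CLIENT_VULKAN;
input.client_version          = GLSLANG_TARGET_VULKAN_1_2;
input.target_language         = GLSLANG_TARGET_SPV;
input.target_language_version = GLSLANG_TARGET_SPV_1_5;
input.code                    = vertex_source;            // placeholder
input.default_version         = 450;
input.default_profile         = GLSLANG_NO_PROFILE;
input.messages                = GLSLANG_MSG_DEFAULT_BIT;
input.resource                = glslang_default_resource();

glslang_shader_t* shader = glslang_shader_create(&input);
if (glslang_shader_preprocess(shader, &input) && glslang_shader_parse(shader, &input)) {
    glslang_program_t* program = glslang_program_create();
    glslang_program_add_shader(program, shader);
    if (glslang_program_link(program, GLSLANG_MSG_SPV_RULES_BIT | GLSLANG_MSG_VULKAN_RULES_BIT)) {
        glslang_program_SPIRV_generate(program, GLSLANG_STAGE_VERTEX);
        size_t word_count   = glslang_program_SPIRV_get_size(program);
        unsigned int* spirv = glslang_program_SPIRV_get_ptr(program);
        // hand (spirv, word_count) to the consumer before deleting the program
    }
    glslang_program_delete(program);
}
glslang_shader_delete(shader);
glslang_finalize_process();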

View file

@ -0,0 +1,229 @@
/**
This code is based on the glslang_c_interface implementation by Viktor Latypov
**/
/**
BSD 2-Clause License
Copyright (c) 2019, Viktor Latypov
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
**/
#ifndef C_SHADER_TYPES_H_INCLUDED
#define C_SHADER_TYPES_H_INCLUDED
#define LAST_ELEMENT_MARKER(x) x
/* EShLanguage counterpart */
typedef enum {
GLSLANG_STAGE_VERTEX,
GLSLANG_STAGE_TESSCONTROL,
GLSLANG_STAGE_TESSEVALUATION,
GLSLANG_STAGE_GEOMETRY,
GLSLANG_STAGE_FRAGMENT,
GLSLANG_STAGE_COMPUTE,
GLSLANG_STAGE_RAYGEN,
GLSLANG_STAGE_RAYGEN_NV = GLSLANG_STAGE_RAYGEN,
GLSLANG_STAGE_INTERSECT,
GLSLANG_STAGE_INTERSECT_NV = GLSLANG_STAGE_INTERSECT,
GLSLANG_STAGE_ANYHIT,
GLSLANG_STAGE_ANYHIT_NV = GLSLANG_STAGE_ANYHIT,
GLSLANG_STAGE_CLOSESTHIT,
GLSLANG_STAGE_CLOSESTHIT_NV = GLSLANG_STAGE_CLOSESTHIT,
GLSLANG_STAGE_MISS,
GLSLANG_STAGE_MISS_NV = GLSLANG_STAGE_MISS,
GLSLANG_STAGE_CALLABLE,
GLSLANG_STAGE_CALLABLE_NV = GLSLANG_STAGE_CALLABLE,
GLSLANG_STAGE_TASK,
GLSLANG_STAGE_TASK_NV = GLSLANG_STAGE_TASK,
GLSLANG_STAGE_MESH,
GLSLANG_STAGE_MESH_NV = GLSLANG_STAGE_MESH,
LAST_ELEMENT_MARKER(GLSLANG_STAGE_COUNT),
} glslang_stage_t; // would be better as stage, but this is ancient now
/* EShLanguageMask counterpart */
typedef enum {
GLSLANG_STAGE_VERTEX_MASK = (1 << GLSLANG_STAGE_VERTEX),
GLSLANG_STAGE_TESSCONTROL_MASK = (1 << GLSLANG_STAGE_TESSCONTROL),
GLSLANG_STAGE_TESSEVALUATION_MASK = (1 << GLSLANG_STAGE_TESSEVALUATION),
GLSLANG_STAGE_GEOMETRY_MASK = (1 << GLSLANG_STAGE_GEOMETRY),
GLSLANG_STAGE_FRAGMENT_MASK = (1 << GLSLANG_STAGE_FRAGMENT),
GLSLANG_STAGE_COMPUTE_MASK = (1 << GLSLANG_STAGE_COMPUTE),
GLSLANG_STAGE_RAYGEN_MASK = (1 << GLSLANG_STAGE_RAYGEN),
GLSLANG_STAGE_RAYGEN_NV_MASK = GLSLANG_STAGE_RAYGEN_MASK,
GLSLANG_STAGE_INTERSECT_MASK = (1 << GLSLANG_STAGE_INTERSECT),
GLSLANG_STAGE_INTERSECT_NV_MASK = GLSLANG_STAGE_INTERSECT_MASK,
GLSLANG_STAGE_ANYHIT_MASK = (1 << GLSLANG_STAGE_ANYHIT),
GLSLANG_STAGE_ANYHIT_NV_MASK = GLSLANG_STAGE_ANYHIT_MASK,
GLSLANG_STAGE_CLOSESTHIT_MASK = (1 << GLSLANG_STAGE_CLOSESTHIT),
GLSLANG_STAGE_CLOSESTHIT_NV_MASK = GLSLANG_STAGE_CLOSESTHIT_MASK,
GLSLANG_STAGE_MISS_MASK = (1 << GLSLANG_STAGE_MISS),
GLSLANG_STAGE_MISS_NV_MASK = GLSLANG_STAGE_MISS_MASK,
GLSLANG_STAGE_CALLABLE_MASK = (1 << GLSLANG_STAGE_CALLABLE),
GLSLANG_STAGE_CALLABLE_NV_MASK = GLSLANG_STAGE_CALLABLE_MASK,
GLSLANG_STAGE_TASK_MASK = (1 << GLSLANG_STAGE_TASK),
GLSLANG_STAGE_TASK_NV_MASK = GLSLANG_STAGE_TASK_MASK,
GLSLANG_STAGE_MESH_MASK = (1 << GLSLANG_STAGE_MESH),
GLSLANG_STAGE_MESH_NV_MASK = GLSLANG_STAGE_MESH_MASK,
LAST_ELEMENT_MARKER(GLSLANG_STAGE_MASK_COUNT),
} glslang_stage_mask_t;
/* EShSource counterpart */
typedef enum {
GLSLANG_SOURCE_NONE,
GLSLANG_SOURCE_GLSL,
GLSLANG_SOURCE_HLSL,
LAST_ELEMENT_MARKER(GLSLANG_SOURCE_COUNT),
} glslang_source_t;
/* EShClient counterpart */
typedef enum {
GLSLANG_CLIENT_NONE,
GLSLANG_CLIENT_VULKAN,
GLSLANG_CLIENT_OPENGL,
LAST_ELEMENT_MARKER(GLSLANG_CLIENT_COUNT),
} glslang_client_t;
/* EShTargetLanguage counterpart */
typedef enum {
GLSLANG_TARGET_NONE,
GLSLANG_TARGET_SPV,
LAST_ELEMENT_MARKER(GLSLANG_TARGET_COUNT),
} glslang_target_language_t;
/* SH_TARGET_ClientVersion counterpart */
typedef enum {
GLSLANG_TARGET_VULKAN_1_0 = (1 << 22),
GLSLANG_TARGET_VULKAN_1_1 = (1 << 22) | (1 << 12),
GLSLANG_TARGET_VULKAN_1_2 = (1 << 22) | (2 << 12),
GLSLANG_TARGET_VULKAN_1_3 = (1 << 22) | (3 << 12),
GLSLANG_TARGET_OPENGL_450 = 450,
LAST_ELEMENT_MARKER(GLSLANG_TARGET_CLIENT_VERSION_COUNT = 5),
} glslang_target_client_version_t;
/* SH_TARGET_LanguageVersion counterpart */
typedef enum {
GLSLANG_TARGET_SPV_1_0 = (1 << 16),
GLSLANG_TARGET_SPV_1_1 = (1 << 16) | (1 << 8),
GLSLANG_TARGET_SPV_1_2 = (1 << 16) | (2 << 8),
GLSLANG_TARGET_SPV_1_3 = (1 << 16) | (3 << 8),
GLSLANG_TARGET_SPV_1_4 = (1 << 16) | (4 << 8),
GLSLANG_TARGET_SPV_1_5 = (1 << 16) | (5 << 8),
GLSLANG_TARGET_SPV_1_6 = (1 << 16) | (6 << 8),
LAST_ELEMENT_MARKER(GLSLANG_TARGET_LANGUAGE_VERSION_COUNT = 7),
} glslang_target_language_version_t;
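The packed enum values above follow the Vulkan and SPIR-V version encodings; a sketch of the decode is below (helper names are illustrative, not part of this header; note GLSLANG_TARGET_OPENGL_450 is a plain 450 and does not follow this packing).

// Client versions pack major into bits 22+ and minor into bits 12..21,
// e.g. GLSLANG_TARGET_VULKAN_1_2 == (1 << 22) | (2 << 12).
static inline int client_major(glslang_target_client_version_t v) { return ((int)v >> 22) & 0x7f; }
static inline int client_minor(glslang_target_client_version_t v) { return ((int)v >> 12) & 0x3ff; }

// SPIR-V versions pack major into bits 16+ and minor into bits 8..15,
// e.g. GLSLANG_TARGET_SPV_1_5 == (1 << 16) | (5 << 8).
static inline int spv_major(glslang_target_language_version_t v) { return (int)v >> 16; }
static inline int spv_minor(glslang_target_language_version_t v) { return ((int)v >> 8) & 0xff; }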
/* EShExecutable counterpart */
typedef enum { GLSLANG_EX_VERTEX_FRAGMENT, GLSLANG_EX_FRAGMENT } glslang_executable_t;
// EShOptimizationLevel counterpart
// This enum is not used in the current C interface, but could be added at a later date.
// GLSLANG_OPT_NONE is the current default.
typedef enum {
GLSLANG_OPT_NO_GENERATION,
GLSLANG_OPT_NONE,
GLSLANG_OPT_SIMPLE,
GLSLANG_OPT_FULL,
LAST_ELEMENT_MARKER(GLSLANG_OPT_LEVEL_COUNT),
} glslang_optimization_level_t;
/* EShTextureSamplerTransformMode counterpart */
typedef enum {
GLSLANG_TEX_SAMP_TRANS_KEEP,
GLSLANG_TEX_SAMP_TRANS_UPGRADE_TEXTURE_REMOVE_SAMPLER,
LAST_ELEMENT_MARKER(GLSLANG_TEX_SAMP_TRANS_COUNT),
} glslang_texture_sampler_transform_mode_t;
/* EShMessages counterpart */
typedef enum {
GLSLANG_MSG_DEFAULT_BIT = 0,
GLSLANG_MSG_RELAXED_ERRORS_BIT = (1 << 0),
GLSLANG_MSG_SUPPRESS_WARNINGS_BIT = (1 << 1),
GLSLANG_MSG_AST_BIT = (1 << 2),
GLSLANG_MSG_SPV_RULES_BIT = (1 << 3),
GLSLANG_MSG_VULKAN_RULES_BIT = (1 << 4),
GLSLANG_MSG_ONLY_PREPROCESSOR_BIT = (1 << 5),
GLSLANG_MSG_READ_HLSL_BIT = (1 << 6),
GLSLANG_MSG_CASCADING_ERRORS_BIT = (1 << 7),
GLSLANG_MSG_KEEP_UNCALLED_BIT = (1 << 8),
GLSLANG_MSG_HLSL_OFFSETS_BIT = (1 << 9),
GLSLANG_MSG_DEBUG_INFO_BIT = (1 << 10),
GLSLANG_MSG_HLSL_ENABLE_16BIT_TYPES_BIT = (1 << 11),
GLSLANG_MSG_HLSL_LEGALIZATION_BIT = (1 << 12),
GLSLANG_MSG_HLSL_DX9_COMPATIBLE_BIT = (1 << 13),
GLSLANG_MSG_BUILTIN_SYMBOL_TABLE_BIT = (1 << 14),
GLSLANG_MSG_ENHANCED = (1 << 15),
GLSLANG_MSG_ABSOLUTE_PATH = (1 << 16),
GLSLANG_MSG_DISPLAY_ERROR_COLUMN = (1 << 17),
LAST_ELEMENT_MARKER(GLSLANG_MSG_COUNT),
} glslang_messages_t;
/* EShReflectionOptions counterpart */
typedef enum {
GLSLANG_REFLECTION_DEFAULT_BIT = 0,
GLSLANG_REFLECTION_STRICT_ARRAY_SUFFIX_BIT = (1 << 0),
GLSLANG_REFLECTION_BASIC_ARRAY_SUFFIX_BIT = (1 << 1),
GLSLANG_REFLECTION_INTERMEDIATE_IOO_BIT = (1 << 2),
GLSLANG_REFLECTION_SEPARATE_BUFFERS_BIT = (1 << 3),
GLSLANG_REFLECTION_ALL_BLOCK_VARIABLES_BIT = (1 << 4),
GLSLANG_REFLECTION_UNWRAP_IO_BLOCKS_BIT = (1 << 5),
GLSLANG_REFLECTION_ALL_IO_VARIABLES_BIT = (1 << 6),
GLSLANG_REFLECTION_SHARED_STD140_SSBO_BIT = (1 << 7),
GLSLANG_REFLECTION_SHARED_STD140_UBO_BIT = (1 << 8),
LAST_ELEMENT_MARKER(GLSLANG_REFLECTION_COUNT),
} glslang_reflection_options_t;
/* EProfile counterpart (from Versions.h) */
typedef enum {
GLSLANG_BAD_PROFILE = 0,
GLSLANG_NO_PROFILE = (1 << 0),
GLSLANG_CORE_PROFILE = (1 << 1),
GLSLANG_COMPATIBILITY_PROFILE = (1 << 2),
GLSLANG_ES_PROFILE = (1 << 3),
LAST_ELEMENT_MARKER(GLSLANG_PROFILE_COUNT),
} glslang_profile_t;
/* Shader options */
typedef enum {
GLSLANG_SHADER_DEFAULT_BIT = 0,
GLSLANG_SHADER_AUTO_MAP_BINDINGS = (1 << 0),
GLSLANG_SHADER_AUTO_MAP_LOCATIONS = (1 << 1),
GLSLANG_SHADER_VULKAN_RULES_RELAXED = (1 << 2),
LAST_ELEMENT_MARKER(GLSLANG_SHADER_COUNT),
} glslang_shader_options_t;
/* TResourceType counterpart */
typedef enum {
GLSLANG_RESOURCE_TYPE_SAMPLER,
GLSLANG_RESOURCE_TYPE_TEXTURE,
GLSLANG_RESOURCE_TYPE_IMAGE,
GLSLANG_RESOURCE_TYPE_UBO,
GLSLANG_RESOURCE_TYPE_SSBO,
GLSLANG_RESOURCE_TYPE_UAV,
LAST_ELEMENT_MARKER(GLSLANG_RESOURCE_TYPE_COUNT),
} glslang_resource_type_t;
#undef LAST_ELEMENT_MARKER
#endif

File diff suppressed because it is too large

View file

@ -0,0 +1,54 @@
//
// Copyright (C) 2023 LunarG, Inc.
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following
// disclaimer in the documentation and/or other materials provided
// with the distribution.
//
// Neither the name of 3Dlabs Inc. Ltd. nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
#ifdef GLSLANG_IS_SHARED_LIBRARY
#ifdef _WIN32
#ifdef GLSLANG_EXPORTING
#define GLSLANG_EXPORT __declspec(dllexport)
#else
#define GLSLANG_EXPORT __declspec(dllimport)
#endif
#elif __GNUC__ >= 4
#define GLSLANG_EXPORT __attribute__((visibility("default")))
#endif
#endif // GLSLANG_IS_SHARED_LIBRARY
#ifndef GLSLANG_EXPORT
#define GLSLANG_EXPORT
#endif
// Symbols marked with this macro are only meant for public use by the test suite
// and do not appear in publicly installed headers. They are not considered to be
// part of the glslang library ABI.
#define GLSLANG_EXPORT_FOR_TESTS GLSLANG_EXPORT

View file

@ -1,7 +1,7 @@
// //
// Copyright (C) 2002-2005 3Dlabs Inc. Ltd. // Copyright (C) 2002-2005 3Dlabs Inc. Ltd.
// Copyright (C) 2012-2013 LunarG, Inc. // Copyright (C) 2012-2013 LunarG, Inc.
// Copyright (C) 2017 ARM Limited. // Copyright (C) 2017, 2022-2024 Arm Limited.
// Copyright (C) 2015-2018 Google, Inc. // Copyright (C) 2015-2018 Google, Inc.
// Modifications Copyright (C) 2020 Advanced Micro Devices, Inc. All rights reserved. // Modifications Copyright (C) 2020 Advanced Micro Devices, Inc. All rights reserved.
// //
@ -171,9 +171,11 @@ const char* const E_GL_KHR_shader_subgroup_arithmetic = "GL_KHR_shader_sub
const char* const E_GL_KHR_shader_subgroup_ballot = "GL_KHR_shader_subgroup_ballot"; const char* const E_GL_KHR_shader_subgroup_ballot = "GL_KHR_shader_subgroup_ballot";
const char* const E_GL_KHR_shader_subgroup_shuffle = "GL_KHR_shader_subgroup_shuffle"; const char* const E_GL_KHR_shader_subgroup_shuffle = "GL_KHR_shader_subgroup_shuffle";
const char* const E_GL_KHR_shader_subgroup_shuffle_relative = "GL_KHR_shader_subgroup_shuffle_relative"; const char* const E_GL_KHR_shader_subgroup_shuffle_relative = "GL_KHR_shader_subgroup_shuffle_relative";
const char* const E_GL_KHR_shader_subgroup_rotate = "GL_KHR_shader_subgroup_rotate";
const char* const E_GL_KHR_shader_subgroup_clustered = "GL_KHR_shader_subgroup_clustered"; const char* const E_GL_KHR_shader_subgroup_clustered = "GL_KHR_shader_subgroup_clustered";
const char* const E_GL_KHR_shader_subgroup_quad = "GL_KHR_shader_subgroup_quad"; const char* const E_GL_KHR_shader_subgroup_quad = "GL_KHR_shader_subgroup_quad";
const char* const E_GL_KHR_memory_scope_semantics = "GL_KHR_memory_scope_semantics"; const char* const E_GL_KHR_memory_scope_semantics = "GL_KHR_memory_scope_semantics";
const char* const E_GL_KHR_cooperative_matrix = "GL_KHR_cooperative_matrix";
const char* const E_GL_EXT_shader_atomic_int64 = "GL_EXT_shader_atomic_int64"; const char* const E_GL_EXT_shader_atomic_int64 = "GL_EXT_shader_atomic_int64";
@ -214,6 +216,13 @@ const char* const E_GL_EXT_spirv_intrinsics = "GL_EXT_spirv_intr
const char* const E_GL_EXT_fragment_shader_barycentric = "GL_EXT_fragment_shader_barycentric"; const char* const E_GL_EXT_fragment_shader_barycentric = "GL_EXT_fragment_shader_barycentric";
const char* const E_GL_EXT_mesh_shader = "GL_EXT_mesh_shader"; const char* const E_GL_EXT_mesh_shader = "GL_EXT_mesh_shader";
const char* const E_GL_EXT_opacity_micromap = "GL_EXT_opacity_micromap"; const char* const E_GL_EXT_opacity_micromap = "GL_EXT_opacity_micromap";
const char* const E_GL_EXT_shader_quad_control = "GL_EXT_shader_quad_control";
const char* const E_GL_EXT_draw_instanced = "GL_EXT_draw_instanced";
const char* const E_GL_EXT_texture_array = "GL_EXT_texture_array";
const char* const E_GL_EXT_maximal_reconvergence = "GL_EXT_maximal_reconvergence";
const char* const E_GL_EXT_expect_assume = "GL_EXT_expect_assume";
const char* const E_GL_EXT_control_flow_attributes2 = "GL_EXT_control_flow_attributes2";
const char* const E_GL_EXT_spec_constant_composites = "GL_EXT_spec_constant_composites";
// Arrays of extensions for the above viewportEXTs duplications // Arrays of extensions for the above viewportEXTs duplications
@ -234,6 +243,7 @@ const int Num_OVR_multiview_EXTs = sizeof(OVR_multiview_EXTs) / sizeof(OVR_multi
// #line and #include // #line and #include
const char* const E_GL_GOOGLE_cpp_style_line_directive = "GL_GOOGLE_cpp_style_line_directive"; const char* const E_GL_GOOGLE_cpp_style_line_directive = "GL_GOOGLE_cpp_style_line_directive";
const char* const E_GL_GOOGLE_include_directive = "GL_GOOGLE_include_directive"; const char* const E_GL_GOOGLE_include_directive = "GL_GOOGLE_include_directive";
const char* const E_GL_ARB_shading_language_include = "GL_ARB_shading_language_include";
const char* const E_GL_AMD_shader_ballot = "GL_AMD_shader_ballot"; const char* const E_GL_AMD_shader_ballot = "GL_AMD_shader_ballot";
const char* const E_GL_AMD_shader_trinary_minmax = "GL_AMD_shader_trinary_minmax"; const char* const E_GL_AMD_shader_trinary_minmax = "GL_AMD_shader_trinary_minmax";
@ -265,7 +275,13 @@ const char* const E_GL_NV_fragment_shader_barycentric = "GL_NV_fragmen
const char* const E_GL_NV_compute_shader_derivatives = "GL_NV_compute_shader_derivatives"; const char* const E_GL_NV_compute_shader_derivatives = "GL_NV_compute_shader_derivatives";
const char* const E_GL_NV_shader_texture_footprint = "GL_NV_shader_texture_footprint"; const char* const E_GL_NV_shader_texture_footprint = "GL_NV_shader_texture_footprint";
const char* const E_GL_NV_mesh_shader = "GL_NV_mesh_shader"; const char* const E_GL_NV_mesh_shader = "GL_NV_mesh_shader";
const char* const E_GL_NV_cooperative_matrix = "GL_NV_cooperative_matrix";
const char* const E_GL_NV_shader_sm_builtins = "GL_NV_shader_sm_builtins";
const char* const E_GL_NV_integer_cooperative_matrix = "GL_NV_integer_cooperative_matrix";
const char* const E_GL_NV_shader_invocation_reorder = "GL_NV_shader_invocation_reorder";
const char* const E_GL_EXT_ray_tracing_position_fetch = "GL_EXT_ray_tracing_position_fetch"; const char* const E_GL_EXT_ray_tracing_position_fetch = "GL_EXT_ray_tracing_position_fetch";
const char* const E_GL_NV_displacement_micromap = "GL_NV_displacement_micromap";
const char* const E_GL_NV_shader_atomic_fp16_vector = "GL_NV_shader_atomic_fp16_vector";
// ARM // ARM
const char* const E_GL_ARM_shader_core_builtins = "GL_ARM_shader_core_builtins"; const char* const E_GL_ARM_shader_core_builtins = "GL_ARM_shader_core_builtins";
@ -275,10 +291,9 @@ const char* const E_GL_ARM_shader_core_builtins = "GL_ARM_shader
const char* const viewportEXTs[] = { E_GL_ARB_shader_viewport_layer_array, E_GL_NV_viewport_array2 }; const char* const viewportEXTs[] = { E_GL_ARB_shader_viewport_layer_array, E_GL_NV_viewport_array2 };
const int Num_viewportEXTs = sizeof(viewportEXTs) / sizeof(viewportEXTs[0]); const int Num_viewportEXTs = sizeof(viewportEXTs) / sizeof(viewportEXTs[0]);
const char* const E_GL_NV_cooperative_matrix = "GL_NV_cooperative_matrix";
const char* const E_GL_NV_shader_sm_builtins = "GL_NV_shader_sm_builtins"; const char* const E_GL_QCOM_image_processing = "GL_QCOM_image_processing";
const char* const E_GL_NV_integer_cooperative_matrix = "GL_NV_integer_cooperative_matrix"; const char* const E_GL_QCOM_image_processing2 = "GL_QCOM_image_processing2";
const char* const E_GL_NV_shader_invocation_reorder = "GL_NV_shader_invocation_reorder";
// AEP // AEP
const char* const E_GL_ANDROID_extension_pack_es31a = "GL_ANDROID_extension_pack_es31a"; const char* const E_GL_ANDROID_extension_pack_es31a = "GL_ANDROID_extension_pack_es31a";
@ -330,6 +345,8 @@ const char* const E_GL_EXT_shader_atomic_float2 = "GL_EXT_shader_atomic_float2";
const char* const E_GL_EXT_shader_tile_image = "GL_EXT_shader_tile_image"; const char* const E_GL_EXT_shader_tile_image = "GL_EXT_shader_tile_image";
const char* const E_GL_EXT_texture_shadow_lod = "GL_EXT_texture_shadow_lod";
// Arrays of extensions for the above AEP duplications // Arrays of extensions for the above AEP duplications
const char* const AEP_geometry_shader[] = { E_GL_EXT_geometry_shader, E_GL_OES_geometry_shader }; const char* const AEP_geometry_shader[] = { E_GL_EXT_geometry_shader, E_GL_OES_geometry_shader };

File diff suppressed because it is too large

View file

@ -38,20 +38,21 @@
#include <string> #include <string>
#include "../Include/ResourceLimits.h" #include "../Include/ResourceLimits.h"
#include "../Include/visibility.h"
// Return pointer to user-writable Resource to pass through API in // Return pointer to user-writable Resource to pass through API in
// future-proof way. // future-proof way.
extern TBuiltInResource* GetResources(); GLSLANG_EXPORT extern TBuiltInResource* GetResources();
// These are the default resources for TBuiltInResources, used for both // These are the default resources for TBuiltInResources, used for both
// - parsing this string for the case where the user didn't supply one, // - parsing this string for the case where the user didn't supply one,
// - dumping out a template for user construction of a config file. // - dumping out a template for user construction of a config file.
extern const TBuiltInResource* GetDefaultResources(); GLSLANG_EXPORT extern const TBuiltInResource* GetDefaultResources();
// Returns the DefaultTBuiltInResource as a human-readable string. // Returns the DefaultTBuiltInResource as a human-readable string.
std::string GetDefaultTBuiltInResourceString(); GLSLANG_EXPORT std::string GetDefaultTBuiltInResourceString();
// Decodes the resource limits from |config| to |resources|. // Decodes the resource limits from |config| to |resources|.
void DecodeResourceLimits(TBuiltInResource* resources, char* config); GLSLANG_EXPORT void DecodeResourceLimits(TBuiltInResource* resources, char* config);
#endif // _STAND_ALONE_RESOURCE_LIMITS_INCLUDED_ #endif // _STAND_ALONE_RESOURCE_LIMITS_INCLUDED_

View file

@ -38,6 +38,7 @@
#define _COMPILER_INTERFACE_INCLUDED_ #define _COMPILER_INTERFACE_INCLUDED_
#include "../Include/ResourceLimits.h" #include "../Include/ResourceLimits.h"
#include "../Include/visibility.h"
#include "../MachineIndependent/Versions.h" #include "../MachineIndependent/Versions.h"
#include <cstring> #include <cstring>
@ -49,22 +50,6 @@
#define C_DECL #define C_DECL
#endif #endif
#ifdef GLSLANG_IS_SHARED_LIBRARY
#ifdef _WIN32
#ifdef GLSLANG_EXPORTING
#define GLSLANG_EXPORT __declspec(dllexport)
#else
#define GLSLANG_EXPORT __declspec(dllimport)
#endif
#elif __GNUC__ >= 4
#define GLSLANG_EXPORT __attribute__((visibility("default")))
#endif
#endif // GLSLANG_IS_SHARED_LIBRARY
#ifndef GLSLANG_EXPORT
#define GLSLANG_EXPORT
#endif
// //
// This is the platform independent interface between an OGL driver // This is the platform independent interface between an OGL driver
// and the shading language compiler/linker. // and the shading language compiler/linker.
@ -188,6 +173,21 @@ typedef enum {
LAST_ELEMENT_MARKER(EShTargetLanguageVersionCount = 7), LAST_ELEMENT_MARKER(EShTargetLanguageVersionCount = 7),
} EShTargetLanguageVersion; } EShTargetLanguageVersion;
//
// Following are a series of helper enums for managing layouts and qualifiers,
// used for TPublicType, TType, others.
//
enum TLayoutPacking {
ElpNone,
ElpShared, // default, but different than saying nothing
ElpStd140,
ElpStd430,
ElpPacked,
ElpScalar,
ElpCount // If expanding, see bitfield width below
};
struct TInputLanguage { struct TInputLanguage {
EShSource languageFamily; // redundant information with other input, this one overrides when not EShSourceNone EShSource languageFamily; // redundant information with other input, this one overrides when not EShSourceNone
EShLanguage stage; // redundant information with other input, this one overrides when not EShSourceNone EShLanguage stage; // redundant information with other input, this one overrides when not EShSourceNone
@ -252,23 +252,25 @@ typedef enum {
// Message choices for what errors and warnings are given. // Message choices for what errors and warnings are given.
// //
enum EShMessages : unsigned { enum EShMessages : unsigned {
EShMsgDefault = 0, // default is to give all required errors and extra warnings EShMsgDefault = 0, // default is to give all required errors and extra warnings
EShMsgRelaxedErrors = (1 << 0), // be liberal in accepting input EShMsgRelaxedErrors = (1 << 0), // be liberal in accepting input
EShMsgSuppressWarnings = (1 << 1), // suppress all warnings, except those required by the specification EShMsgSuppressWarnings = (1 << 1), // suppress all warnings, except those required by the specification
EShMsgAST = (1 << 2), // print the AST intermediate representation EShMsgAST = (1 << 2), // print the AST intermediate representation
EShMsgSpvRules = (1 << 3), // issue messages for SPIR-V generation EShMsgSpvRules = (1 << 3), // issue messages for SPIR-V generation
EShMsgVulkanRules = (1 << 4), // issue messages for Vulkan-requirements of GLSL for SPIR-V EShMsgVulkanRules = (1 << 4), // issue messages for Vulkan-requirements of GLSL for SPIR-V
EShMsgOnlyPreprocessor = (1 << 5), // only print out errors produced by the preprocessor EShMsgOnlyPreprocessor = (1 << 5), // only print out errors produced by the preprocessor
EShMsgReadHlsl = (1 << 6), // use HLSL parsing rules and semantics EShMsgReadHlsl = (1 << 6), // use HLSL parsing rules and semantics
EShMsgCascadingErrors = (1 << 7), // get cascading errors; risks error-recovery issues, instead of an early exit EShMsgCascadingErrors = (1 << 7), // get cascading errors; risks error-recovery issues, instead of an early exit
EShMsgKeepUncalled = (1 << 8), // for testing, don't eliminate uncalled functions EShMsgKeepUncalled = (1 << 8), // for testing, don't eliminate uncalled functions
EShMsgHlslOffsets = (1 << 9), // allow block offsets to follow HLSL rules instead of GLSL rules EShMsgHlslOffsets = (1 << 9), // allow block offsets to follow HLSL rules instead of GLSL rules
EShMsgDebugInfo = (1 << 10), // save debug information EShMsgDebugInfo = (1 << 10), // save debug information
EShMsgHlslEnable16BitTypes = (1 << 11), // enable use of 16-bit types in SPIR-V for HLSL EShMsgHlslEnable16BitTypes = (1 << 11), // enable use of 16-bit types in SPIR-V for HLSL
EShMsgHlslLegalization = (1 << 12), // enable HLSL Legalization messages EShMsgHlslLegalization = (1 << 12), // enable HLSL Legalization messages
EShMsgHlslDX9Compatible = (1 << 13), // enable HLSL DX9 compatible mode (for samplers and semantics) EShMsgHlslDX9Compatible = (1 << 13), // enable HLSL DX9 compatible mode (for samplers and semantics)
EShMsgBuiltinSymbolTable = (1 << 14), // print the builtin symbol table EShMsgBuiltinSymbolTable = (1 << 14), // print the builtin symbol table
EShMsgEnhanced = (1 << 15), // enhanced message readability EShMsgEnhanced = (1 << 15), // enhanced message readability
EShMsgAbsolutePath = (1 << 16), // Output absolute path for messages
EShMsgDisplayErrorColumn = (1 << 17), // Display error message column as well as line
LAST_ELEMENT_MARKER(EShMsgCount), LAST_ELEMENT_MARKER(EShMsgCount),
}; };
@ -318,8 +320,8 @@ typedef void* ShHandle;
// Driver calls these to create and destroy compiler/linker // Driver calls these to create and destroy compiler/linker
// objects. // objects.
// //
GLSLANG_EXPORT ShHandle ShConstructCompiler(const EShLanguage, int debugOptions); // one per shader GLSLANG_EXPORT ShHandle ShConstructCompiler(const EShLanguage, int /*debugOptions unused*/); // one per shader
GLSLANG_EXPORT ShHandle ShConstructLinker(const EShExecutable, int debugOptions); // one per shader pair GLSLANG_EXPORT ShHandle ShConstructLinker(const EShExecutable, int /*debugOptions unused*/); // one per shader pair
GLSLANG_EXPORT ShHandle ShConstructUniformMap(); // one per uniform namespace (currently entire program object) GLSLANG_EXPORT ShHandle ShConstructUniformMap(); // one per uniform namespace (currently entire program object)
GLSLANG_EXPORT void ShDestruct(ShHandle); GLSLANG_EXPORT void ShDestruct(ShHandle);
@ -330,18 +332,14 @@ GLSLANG_EXPORT void ShDestruct(ShHandle);
// The info-log should be written by ShCompile into // The info-log should be written by ShCompile into
// ShHandle, so it can answer future queries. // ShHandle, so it can answer future queries.
// //
GLSLANG_EXPORT int ShCompile( GLSLANG_EXPORT int ShCompile(const ShHandle, const char* const shaderStrings[], const int numStrings,
const ShHandle, const int* lengths, const EShOptimizationLevel, const TBuiltInResource* resources,
const char* const shaderStrings[], int, // debugOptions unused
const int numStrings, int defaultVersion = 110, // use 100 for ES environment, overridden by #version in shader
const int* lengths, bool forwardCompatible = false, // give errors for use of deprecated features
const EShOptimizationLevel, EShMessages messages = EShMsgDefault, // warnings and errors
const TBuiltInResource *resources, const char* fileName = nullptr
int debugOptions, );
int defaultVersion = 110, // use 100 for ES environment, overridden by #version in shader
bool forwardCompatible = false, // give errors for use of deprecated features
EShMessages messages = EShMsgDefault // warnings and errors
);
GLSLANG_EXPORT int ShLinkExt( GLSLANG_EXPORT int ShLinkExt(
const ShHandle, // linker object const ShHandle, // linker object
@ -417,6 +415,7 @@ GLSLANG_EXPORT int GetKhronosToolId();
class TIntermediate; class TIntermediate;
class TProgram; class TProgram;
class TPoolAllocator; class TPoolAllocator;
class TIoMapResolver;
// Call this exactly once per process before using anything else // Call this exactly once per process before using anything else
GLSLANG_EXPORT bool InitializeProcess(); GLSLANG_EXPORT bool InitializeProcess();
@ -512,6 +511,9 @@ public:
GLSLANG_EXPORT void setAtomicCounterBlockSet(unsigned int set); GLSLANG_EXPORT void setAtomicCounterBlockSet(unsigned int set);
GLSLANG_EXPORT void setAtomicCounterBlockBinding(unsigned int binding); GLSLANG_EXPORT void setAtomicCounterBlockBinding(unsigned int binding);
GLSLANG_EXPORT void addSourceText(const char* text, size_t len);
GLSLANG_EXPORT void setSourceFile(const char* file);
// For setting up the environment (cleared to nothingness in the constructor). // For setting up the environment (cleared to nothingness in the constructor).
// These must be called so that parsing is done for the right source language and // These must be called so that parsing is done for the right source language and
// target environment, either indirectly through TranslateEnvironment() based on // target environment, either indirectly through TranslateEnvironment() based on
@ -573,6 +575,9 @@ public:
void setEnvInputVulkanRulesRelaxed() { environment.input.vulkanRulesRelaxed = true; } void setEnvInputVulkanRulesRelaxed() { environment.input.vulkanRulesRelaxed = true; }
bool getEnvInputVulkanRulesRelaxed() const { return environment.input.vulkanRulesRelaxed; } bool getEnvInputVulkanRulesRelaxed() const { return environment.input.vulkanRulesRelaxed; }
void setCompileOnly() { compileOnly = true; }
bool getCompileOnly() const { return compileOnly; }
// Interface to #include handlers. // Interface to #include handlers.
// //
// To support #include, a client of Glslang does the following: // To support #include, a client of Glslang does the following:
@ -722,14 +727,15 @@ protected:
TEnvironment environment; TEnvironment environment;
// Indicates this shader is meant to be used without linking
bool compileOnly = false;
friend class TProgram; friend class TProgram;
private: private:
TShader& operator=(TShader&); TShader& operator=(TShader&);
}; };
#if !defined(GLSLANG_WEB)
// //
// A reflection database and its interface, consistent with the OpenGL API reflection queries. // A reflection database and its interface, consistent with the OpenGL API reflection queries.
// //
@ -744,6 +750,8 @@ public:
GLSLANG_EXPORT void dump() const; GLSLANG_EXPORT void dump() const;
static TObjectReflection badReflection() { return TObjectReflection(); } static TObjectReflection badReflection() { return TObjectReflection(); }
GLSLANG_EXPORT unsigned int layoutLocation() const;
std::string name; std::string name;
int offset; int offset;
int glDefineType; int glDefineType;
@ -846,7 +854,19 @@ public:
virtual void addStage(EShLanguage stage, TIntermediate& stageIntermediate) = 0; virtual void addStage(EShLanguage stage, TIntermediate& stageIntermediate) = 0;
}; };
#endif // !GLSLANG_WEB // I/O mapper
class TIoMapper {
public:
TIoMapper() {}
virtual ~TIoMapper() {}
// grow the reflection stage by stage
bool virtual addStage(EShLanguage, TIntermediate&, TInfoSink&, TIoMapResolver*);
bool virtual doMap(TIoMapResolver*, TInfoSink&) { return true; }
bool virtual setAutoPushConstantBlock(const char*, unsigned int, TLayoutPacking) { return false; }
};
// Get the default GLSL IO mapper
GLSLANG_EXPORT TIoMapper* GetGlslIoMapper();
// Make one TProgram per set of shaders that will get linked together. Add all // Make one TProgram per set of shaders that will get linked together. Add all
// the shaders that are to be linked together. After calling shader.parse() // the shaders that are to be linked together. After calling shader.parse()
@ -867,8 +887,6 @@ public:
TIntermediate* getIntermediate(EShLanguage stage) const { return intermediate[stage]; } TIntermediate* getIntermediate(EShLanguage stage) const { return intermediate[stage]; }
#if !defined(GLSLANG_WEB)
// Reflection Interface // Reflection Interface
// call first, to do liveness analysis, index mapping, etc.; returns false on failure // call first, to do liveness analysis, index mapping, etc.; returns false on failure
@ -957,11 +975,14 @@ public:
const TType *getAttributeTType(int index) const { return getPipeInput(index).getType(); } const TType *getAttributeTType(int index) const { return getPipeInput(index).getType(); }
GLSLANG_EXPORT void dumpReflection(); GLSLANG_EXPORT void dumpReflection();
// Get the IO resolver to use for mapIO
GLSLANG_EXPORT TIoMapResolver* getGlslIoResolver(EShLanguage stage);
// I/O mapping: apply base offsets and map live unbound variables // I/O mapping: apply base offsets and map live unbound variables
// If resolver is not provided it uses the previous approach // If resolver is not provided it uses the previous approach
// and respects auto assignment and offsets. // and respects auto assignment and offsets.
GLSLANG_EXPORT bool mapIO(TIoMapResolver* pResolver = nullptr, TIoMapper* pIoMapper = nullptr); GLSLANG_EXPORT bool mapIO(TIoMapResolver* pResolver = nullptr, TIoMapper* pIoMapper = nullptr);
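// A minimal usage sketch (editor's illustration, not part of this header):
// after a successful link, fetch the per-stage resolver and the GLSL mapper,
// then remap. GetGlslIoMapper() is assumed here to return a heap-allocated
// mapper owned by the caller.
//   glslang::TIoMapResolver* resolver = program.getGlslIoResolver(EShLangFragment);
//   std::unique_ptr<glslang::TIoMapper> mapper(glslang::GetGlslIoMapper());
//   bool remapped = program.mapIO(resolver, mapper.get());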
#endif // !GLSLANG_WEB
protected: protected:
GLSLANG_EXPORT bool linkStage(EShLanguage, EShMessages); GLSLANG_EXPORT bool linkStage(EShLanguage, EShMessages);
@@ -972,9 +993,7 @@ protected:
TIntermediate* intermediate[EShLangCount]; TIntermediate* intermediate[EShLangCount];
bool newedIntermediate[EShLangCount]; // track which intermediate were "new" versus reusing a singleton unit in a stage bool newedIntermediate[EShLangCount]; // track which intermediate were "new" versus reusing a singleton unit in a stage
TInfoSink* infoSink; TInfoSink* infoSink;
#if !defined(GLSLANG_WEB)
TReflection* reflection; TReflection* reflection;
#endif
bool linked; bool linked;
private: private:

View file

@@ -30,25 +30,26 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define _STAND_ALONE_RESOURCE_LIMITS_C_INCLUDED_ #define _STAND_ALONE_RESOURCE_LIMITS_C_INCLUDED_
#include "../Include/glslang_c_interface.h" #include "../Include/glslang_c_interface.h"
#include "../Include/visibility.h"
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
// Returns a struct that can be used to create custom resource values. // Returns a struct that can be used to create custom resource values.
glslang_resource_t* glslang_resource(void); GLSLANG_EXPORT glslang_resource_t* glslang_resource(void);
// These are the default resources for TBuiltInResources, used for both // These are the default resources for TBuiltInResources, used for both
// - parsing this string for the case where the user didn't supply one, // - parsing this string for the case where the user didn't supply one,
// - dumping out a template for user construction of a config file. // - dumping out a template for user construction of a config file.
const glslang_resource_t* glslang_default_resource(void); GLSLANG_EXPORT const glslang_resource_t* glslang_default_resource(void);
// Returns the DefaultTBuiltInResource as a human-readable string. // Returns the DefaultTBuiltInResource as a human-readable string.
// NOTE: User is responsible for freeing this string. // NOTE: User is responsible for freeing this string.
const char* glslang_default_resource_string(); GLSLANG_EXPORT const char* glslang_default_resource_string();
// Decodes the resource limits from |config| to |resources|. // Decodes the resource limits from |config| to |resources|.
void glslang_decode_resource_limits(glslang_resource_t* resources, char* config); GLSLANG_EXPORT void glslang_decode_resource_limits(glslang_resource_t* resources, char* config);
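// A short usage sketch (editor's illustration): copy the defaults, then decode
// overrides from a config string; the string is heap-allocated, so free it.
//   glslang_resource_t res = *glslang_default_resource();
//   char* cfg = (char*)glslang_default_resource_string();
//   glslang_decode_resource_limits(&res, cfg);
//   free(cfg);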
#ifdef __cplusplus #ifdef __cplusplus
} }

View file

@@ -35,27 +35,35 @@
#pragma once #pragma once
#if defined(_MSC_VER) && _MSC_VER >= 1900
#pragma warning(disable : 4464) // relative include path contains '..'
#endif
#include "SpvTools.h"
#include "glslang/Include/intermediate.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include "Logger.h" #include "Logger.h"
#include "glslang/Include/visibility.h"
namespace glslang { namespace glslang {
class TIntermediate;
void GetSpirvVersion(std::string&);
int GetSpirvGeneratorVersion();
void GlslangToSpv(const glslang::TIntermediate& intermediate, std::vector<unsigned int>& spirv,
                  SpvOptions* options = nullptr);
void GlslangToSpv(const glslang::TIntermediate& intermediate, std::vector<unsigned int>& spirv,
                  spv::SpvBuildLogger* logger, SpvOptions* options = nullptr);
void OutputSpvBin(const std::vector<unsigned int>& spirv, const char* baseName);
void OutputSpvHex(const std::vector<unsigned int>& spirv, const char* baseName, const char* varName);
struct SpvOptions {
    bool generateDebugInfo {false};
    bool stripDebugInfo {false};
    bool disableOptimizer {true};
    bool optimizeSize {false};
    bool disassemble {false};
    bool validate {false};
    bool emitNonSemanticShaderDebugInfo {false};
    bool emitNonSemanticShaderDebugSource{ false };
    bool compileOnly{false};
    bool optimizerAllowExpandedIDBound{false};
};
GLSLANG_EXPORT void GetSpirvVersion(std::string&);
GLSLANG_EXPORT int GetSpirvGeneratorVersion();
GLSLANG_EXPORT void GlslangToSpv(const glslang::TIntermediate& intermediate, std::vector<unsigned int>& spirv,
SpvOptions* options = nullptr);
GLSLANG_EXPORT void GlslangToSpv(const glslang::TIntermediate& intermediate, std::vector<unsigned int>& spirv,
spv::SpvBuildLogger* logger, SpvOptions* options = nullptr);
GLSLANG_EXPORT bool OutputSpvBin(const std::vector<unsigned int>& spirv, const char* baseName);
GLSLANG_EXPORT bool OutputSpvHex(const std::vector<unsigned int>& spirv, const char* baseName, const char* varName);
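// A minimal end-to-end sketch (editor's illustration; assumes `program` was
// parsed and linked elsewhere, and that logger.getAllMessages() gets checked):
//   glslang::SpvOptions opts;
//   opts.generateDebugInfo = true;
//   std::vector<unsigned int> spirv;
//   spv::SpvBuildLogger logger;
//   glslang::GlslangToSpv(*program.getIntermediate(EShLangVertex), spirv, &logger, &opts);
//   glslang::OutputSpvBin(spirv, "shader.spv"); // bool result is new in this signature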
} }

View file

@@ -37,23 +37,16 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "glslang/Include/visibility.h"
namespace spv { namespace spv {
// A class for holding all SPIR-V build status messages, including // A class for holding all SPIR-V build status messages, including
// missing/TBD functionalities, warnings, and errors. // missing/TBD functionalities, warnings, and errors.
class SpvBuildLogger { class GLSLANG_EXPORT SpvBuildLogger {
public: public:
SpvBuildLogger() {} SpvBuildLogger() {}
#ifdef GLSLANG_WEB
void tbdFunctionality(const std::string& f) { }
void missingFunctionality(const std::string& f) { }
void warning(const std::string& w) { }
void error(const std::string& e) { errors.push_back(e); }
std::string getAllMessages() { return ""; }
#else
// Registers a TBD functionality. // Registers a TBD functionality.
void tbdFunctionality(const std::string& f); void tbdFunctionality(const std::string& f);
// Registers a missing functionality. // Registers a missing functionality.
@@ -67,7 +60,6 @@ public:
// Returns all messages accumulated in the order of: // Returns all messages accumulated in the order of:
// TBD functionalities, missing functionalities, warnings, errors. // TBD functionalities, missing functionalities, warnings, errors.
std::string getAllMessages() const; std::string getAllMessages() const;
#endif
private: private:
SpvBuildLogger(const SpvBuildLogger&); SpvBuildLogger(const SpvBuildLogger&);

View file

@@ -41,6 +41,21 @@
#include <cstdlib> #include <cstdlib>
#include <exception> #include <exception>
#ifdef GLSLANG_IS_SHARED_LIBRARY
#ifdef _WIN32
#ifdef GLSLANG_EXPORTING
#define GLSLANG_EXPORT __declspec(dllexport)
#else
#define GLSLANG_EXPORT __declspec(dllimport)
#endif
#elif __GNUC__ >= 4
#define GLSLANG_EXPORT __attribute__((visibility("default")))
#endif
#endif // GLSLANG_IS_SHARED_LIBRARY
#ifndef GLSLANG_EXPORT
#define GLSLANG_EXPORT
#endif
namespace spv { namespace spv {
class spirvbin_base_t class spirvbin_base_t
@@ -77,12 +92,13 @@ public:
#include <cassert> #include <cassert>
#include "spirv.hpp" #include "spirv.hpp"
#include "spvIR.h"
namespace spv { namespace spv {
static inline constexpr Id NoResult = 0;
// class to hold SPIR-V binary data for remapping, DCE, and debug stripping // class to hold SPIR-V binary data for remapping, DCE, and debug stripping
class spirvbin_t : public spirvbin_base_t class GLSLANG_EXPORT spirvbin_t : public spirvbin_base_t
{ {
public: public:
spirvbin_t(int verbose = 0) : entryPoint(spv::NoResult), largestNewId(0), verbose(verbose), errorLatch(false) spirvbin_t(int verbose = 0) : entryPoint(spv::NoResult), largestNewId(0), verbose(verbose), errorLatch(false)

View file

@@ -44,66 +44,63 @@
#if ENABLE_OPT #if ENABLE_OPT
#include <vector> #include <vector>
#include <ostream> #include <ostream>
#include <unordered_set>
#include "spirv-tools/libspirv.h" #include "spirv-tools/libspirv.h"
#endif #endif
#include "glslang/MachineIndependent/localintermediate.h" #include "glslang/MachineIndependent/Versions.h"
#include "glslang/Include/visibility.h"
#include "GlslangToSpv.h"
#include "Logger.h" #include "Logger.h"
namespace glslang { namespace glslang {
struct SpvOptions {
bool generateDebugInfo {false};
bool stripDebugInfo {false};
bool disableOptimizer {true};
bool optimizeSize {false};
bool disassemble {false};
bool validate {false};
bool emitNonSemanticShaderDebugInfo {false};
bool emitNonSemanticShaderDebugSource{ false };
};
#if ENABLE_OPT #if ENABLE_OPT
class TIntermediate;
// Translate glslang's view of target versioning to what SPIRV-Tools uses. // Translate glslang's view of target versioning to what SPIRV-Tools uses.
spv_target_env MapToSpirvToolsEnv(const SpvVersion& spvVersion, spv::SpvBuildLogger* logger); GLSLANG_EXPORT spv_target_env MapToSpirvToolsEnv(const SpvVersion& spvVersion, spv::SpvBuildLogger* logger);
GLSLANG_EXPORT spv_target_env MapToSpirvToolsEnv(const glslang::TIntermediate& intermediate, spv::SpvBuildLogger* logger);
// Use the SPIRV-Tools disassembler to print SPIR-V using a SPV_ENV_UNIVERSAL_1_3 environment. // Use the SPIRV-Tools disassembler to print SPIR-V using a SPV_ENV_UNIVERSAL_1_3 environment.
void SpirvToolsDisassemble(std::ostream& out, const std::vector<unsigned int>& spirv); GLSLANG_EXPORT void SpirvToolsDisassemble(std::ostream& out, const std::vector<unsigned int>& spirv);
// Use the SPIRV-Tools disassembler to print SPIR-V with a provided SPIR-V environment. // Use the SPIRV-Tools disassembler to print SPIR-V with a provided SPIR-V environment.
void SpirvToolsDisassemble(std::ostream& out, const std::vector<unsigned int>& spirv, GLSLANG_EXPORT void SpirvToolsDisassemble(std::ostream& out, const std::vector<unsigned int>& spirv,
spv_target_env requested_context); spv_target_env requested_context);
// Apply the SPIRV-Tools validator to generated SPIR-V. // Apply the SPIRV-Tools validator to generated SPIR-V.
void SpirvToolsValidate(const glslang::TIntermediate& intermediate, std::vector<unsigned int>& spirv, GLSLANG_EXPORT void SpirvToolsValidate(const glslang::TIntermediate& intermediate, std::vector<unsigned int>& spirv,
spv::SpvBuildLogger*, bool prelegalization); spv::SpvBuildLogger*, bool prelegalization);
// Apply the SPIRV-Tools optimizer to generated SPIR-V. HLSL SPIR-V is legalized in the process. // Apply the SPIRV-Tools optimizer to generated SPIR-V. HLSL SPIR-V is legalized in the process.
void SpirvToolsTransform(const glslang::TIntermediate& intermediate, std::vector<unsigned int>& spirv, GLSLANG_EXPORT void SpirvToolsTransform(const glslang::TIntermediate& intermediate, std::vector<unsigned int>& spirv,
spv::SpvBuildLogger*, const SpvOptions*); spv::SpvBuildLogger*, const SpvOptions*);
// Apply the SPIRV-Tools EliminateDeadInputComponents pass to generated SPIR-V. Put result in |spirv|. // Apply the SPIRV-Tools EliminateDeadInputComponents pass to generated SPIR-V. Put result in |spirv|.
void SpirvToolsEliminateDeadInputComponents(spv_target_env target_env, std::vector<unsigned int>& spirv, GLSLANG_EXPORT void SpirvToolsEliminateDeadInputComponents(spv_target_env target_env, std::vector<unsigned int>& spirv,
spv::SpvBuildLogger*); spv::SpvBuildLogger*);
// Apply the SPIRV-Tools AnalyzeDeadOutputStores pass to generated SPIR-V. Put result in |live_locs|. // Apply the SPIRV-Tools AnalyzeDeadOutputStores pass to generated SPIR-V. Put result in |live_locs|.
// Return true if the result is valid. // Return true if the result is valid.
bool SpirvToolsAnalyzeDeadOutputStores(spv_target_env target_env, std::vector<unsigned int>& spirv, GLSLANG_EXPORT bool SpirvToolsAnalyzeDeadOutputStores(spv_target_env target_env, std::vector<unsigned int>& spirv,
std::unordered_set<uint32_t>* live_locs, std::unordered_set<uint32_t>* live_locs,
std::unordered_set<uint32_t>* live_builtins, spv::SpvBuildLogger*); std::unordered_set<uint32_t>* live_builtins,
spv::SpvBuildLogger*);
// Apply the SPIRV-Tools EliminateDeadOutputStores and AggressiveDeadCodeElimination passes to generated SPIR-V using // Apply the SPIRV-Tools EliminateDeadOutputStores and AggressiveDeadCodeElimination passes to generated SPIR-V using
// |live_locs|. Put result in |spirv|. // |live_locs|. Put result in |spirv|.
void SpirvToolsEliminateDeadOutputStores(spv_target_env target_env, std::vector<unsigned int>& spirv, GLSLANG_EXPORT void SpirvToolsEliminateDeadOutputStores(spv_target_env target_env, std::vector<unsigned int>& spirv,
std::unordered_set<uint32_t>* live_locs, std::unordered_set<uint32_t>* live_locs,
std::unordered_set<uint32_t>* live_builtins, spv::SpvBuildLogger*); std::unordered_set<uint32_t>* live_builtins,
spv::SpvBuildLogger*);
// Apply the SPIRV-Tools optimizer to strip debug info from SPIR-V. This is implicitly done by // Apply the SPIRV-Tools optimizer to strip debug info from SPIR-V. This is implicitly done by
// SpirvToolsTransform if spvOptions->stripDebugInfo is set, but can be called separately if // SpirvToolsTransform if spvOptions->stripDebugInfo is set, but can be called separately if
// optimization is disabled. // optimization is disabled.
void SpirvToolsStripDebugInfo(const glslang::TIntermediate& intermediate, GLSLANG_EXPORT void SpirvToolsStripDebugInfo(const glslang::TIntermediate& intermediate,
std::vector<unsigned int>& spirv, spv::SpvBuildLogger*); std::vector<unsigned int>& spirv, spv::SpvBuildLogger*);
#endif #endif

View file

@@ -0,0 +1,55 @@
//
// Copyright (C) 2014-2015 LunarG, Inc.
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following
// disclaimer in the documentation and/or other materials provided
// with the distribution.
//
// Neither the name of 3Dlabs Inc. Ltd. nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Disassembler for SPIR-V.
//
#pragma once
#ifndef disassembler_H
#define disassembler_H
#include <iostream>
#include <vector>
#include "glslang/Include/visibility.h"
namespace spv {
// disassemble with glslang custom disassembler
GLSLANG_EXPORT void Disassemble(std::ostream& out, const std::vector<unsigned int>&);
} // end namespace spv
#endif // disassembler_H
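For orientation, a minimal caller of the custom disassembler might look like the sketch below (editor's illustration; the include path follows the glslang source layout, and `spirv` stands for any valid SPIR-V binary):

    #include <iostream>
    #include <vector>
    #include "SPIRV/disassemble.h"

    void printSpirv(const std::vector<unsigned int>& spirv) {
        spv::Disassemble(std::cout, spirv); // writes human-readable SPIR-V to the stream
    }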

View file

@@ -1,19 +1,19 @@
// Copyright (c) 2014-2020 The Khronos Group Inc. // Copyright (c) 2014-2024 The Khronos Group Inc.
// //
// Permission is hereby granted, free of charge, to any person obtaining a copy // Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and/or associated documentation files (the "Materials"), // of this software and/or associated documentation files (the "Materials"),
// to deal in the Materials without restriction, including without limitation // to deal in the Materials without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense, // the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Materials, and to permit persons to whom the // and/or sell copies of the Materials, and to permit persons to whom the
// Materials are furnished to do so, subject to the following conditions: // Materials are furnished to do so, subject to the following conditions:
// //
// The above copyright notice and this permission notice shall be included in // The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Materials. // all copies or substantial portions of the Materials.
// //
// MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS KHRONOS // MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS KHRONOS
// STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS SPECIFICATIONS AND // STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS SPECIFICATIONS AND
// HEADER INFORMATION ARE LOCATED AT https://www.khronos.org/registry/ // HEADER INFORMATION ARE LOCATED AT https://www.khronos.org/registry/
// //
// THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS // THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
@@ -27,7 +27,7 @@
// Enumeration tokens for SPIR-V, in various styles: // Enumeration tokens for SPIR-V, in various styles:
// C, C++, C++11, JSON, Lua, Python, C#, D, Beef // C, C++, C++11, JSON, Lua, Python, C#, D, Beef
// //
// - C will have tokens with a "Spv" prefix, e.g.: SpvSourceLanguageGLSL // - C will have tokens with a "Spv" prefix, e.g.: SpvSourceLanguageGLSL
// - C++ will have tokens in the "spv" name space, e.g.: spv::SourceLanguageGLSL // - C++ will have tokens in the "spv" name space, e.g.: spv::SourceLanguageGLSL
// - C++11 will use enum classes in the spv namespace, e.g.: spv::SourceLanguage::GLSL // - C++11 will use enum classes in the spv namespace, e.g.: spv::SourceLanguage::GLSL
@@ -38,7 +38,7 @@
// - D will have tokens under the "spv" module, e.g: spv.SourceLanguage.GLSL // - D will have tokens under the "spv" module, e.g: spv.SourceLanguage.GLSL
// - Beef will use enum classes in the Specification class located in the "Spv" namespace, // - Beef will use enum classes in the Specification class located in the "Spv" namespace,
// e.g.: Spv.Specification.SourceLanguage.GLSL // e.g.: Spv.Specification.SourceLanguage.GLSL
// //
// Some tokens act like mask values, which can be OR'd together, // Some tokens act like mask values, which can be OR'd together,
// while others are mutually exclusive. The mask-like ones have // while others are mutually exclusive. The mask-like ones have
// "Mask" in their name, and a parallel enum that has the shift // "Mask" in their name, and a parallel enum that has the shift
@@ -174,6 +174,8 @@ enum ExecutionMode {
ExecutionModeStencilRefUnchangedBackAMD = 5082, ExecutionModeStencilRefUnchangedBackAMD = 5082,
ExecutionModeStencilRefGreaterBackAMD = 5083, ExecutionModeStencilRefGreaterBackAMD = 5083,
ExecutionModeStencilRefLessBackAMD = 5084, ExecutionModeStencilRefLessBackAMD = 5084,
ExecutionModeQuadDerivativesKHR = 5088,
ExecutionModeRequireFullQuadsKHR = 5089,
ExecutionModeOutputLinesEXT = 5269, ExecutionModeOutputLinesEXT = 5269,
ExecutionModeOutputLinesNV = 5269, ExecutionModeOutputLinesNV = 5269,
ExecutionModeOutputPrimitivesEXT = 5270, ExecutionModeOutputPrimitivesEXT = 5270,
@@ -198,6 +200,7 @@ enum ExecutionMode {
ExecutionModeNoGlobalOffsetINTEL = 5895, ExecutionModeNoGlobalOffsetINTEL = 5895,
ExecutionModeNumSIMDWorkitemsINTEL = 5896, ExecutionModeNumSIMDWorkitemsINTEL = 5896,
ExecutionModeSchedulerTargetFmaxMhzINTEL = 5903, ExecutionModeSchedulerTargetFmaxMhzINTEL = 5903,
ExecutionModeMaximallyReconvergesKHR = 6023,
ExecutionModeStreamingInterfaceINTEL = 6154, ExecutionModeStreamingInterfaceINTEL = 6154,
ExecutionModeNamedBarrierCountINTEL = 6417, ExecutionModeNamedBarrierCountINTEL = 6417,
ExecutionModeMax = 0x7fffffff, ExecutionModeMax = 0x7fffffff,
@@ -382,7 +385,7 @@ enum ImageOperandsShift {
ImageOperandsMax = 0x7fffffff, ImageOperandsMax = 0x7fffffff,
}; };
enum ImageOperandsMask { enum ImageOperandsMask : unsigned {
ImageOperandsMaskNone = 0, ImageOperandsMaskNone = 0,
ImageOperandsBiasMask = 0x00000001, ImageOperandsBiasMask = 0x00000001,
ImageOperandsLodMask = 0x00000002, ImageOperandsLodMask = 0x00000002,
@@ -417,7 +420,7 @@ enum FPFastMathModeShift {
FPFastMathModeMax = 0x7fffffff, FPFastMathModeMax = 0x7fffffff,
}; };
enum FPFastMathModeMask { enum FPFastMathModeMask : unsigned {
FPFastMathModeMaskNone = 0, FPFastMathModeMaskNone = 0,
FPFastMathModeNotNaNMask = 0x00000001, FPFastMathModeNotNaNMask = 0x00000001,
FPFastMathModeNotInfMask = 0x00000002, FPFastMathModeNotInfMask = 0x00000002,
@@ -513,6 +516,9 @@ enum Decoration {
DecorationMaxByteOffsetId = 47, DecorationMaxByteOffsetId = 47,
DecorationNoSignedWrap = 4469, DecorationNoSignedWrap = 4469,
DecorationNoUnsignedWrap = 4470, DecorationNoUnsignedWrap = 4470,
DecorationWeightTextureQCOM = 4487,
DecorationBlockMatchTextureQCOM = 4488,
DecorationBlockMatchSamplerQCOM = 4499,
DecorationExplicitInterpAMD = 4999, DecorationExplicitInterpAMD = 4999,
DecorationOverrideCoverageNV = 5248, DecorationOverrideCoverageNV = 5248,
DecorationPassthroughNV = 5250, DecorationPassthroughNV = 5250,
@@ -718,6 +724,8 @@ enum BuiltIn {
BuiltInHitKindNV = 5333, BuiltInHitKindNV = 5333,
BuiltInCurrentRayTimeNV = 5334, BuiltInCurrentRayTimeNV = 5334,
BuiltInHitTriangleVertexPositionsKHR = 5335, BuiltInHitTriangleVertexPositionsKHR = 5335,
BuiltInHitMicroTriangleVertexPositionsNV = 5337,
BuiltInHitMicroTriangleVertexBarycentricsNV = 5344,
BuiltInIncomingRayFlagsKHR = 5351, BuiltInIncomingRayFlagsKHR = 5351,
BuiltInIncomingRayFlagsNV = 5351, BuiltInIncomingRayFlagsNV = 5351,
BuiltInRayGeometryIndexKHR = 5352, BuiltInRayGeometryIndexKHR = 5352,
@@ -725,6 +733,8 @@ enum BuiltIn {
BuiltInSMCountNV = 5375, BuiltInSMCountNV = 5375,
BuiltInWarpIDNV = 5376, BuiltInWarpIDNV = 5376,
BuiltInSMIDNV = 5377, BuiltInSMIDNV = 5377,
BuiltInHitKindFrontFacingMicroTriangleNV = 5405,
BuiltInHitKindBackFacingMicroTriangleNV = 5406,
BuiltInCullMaskKHR = 6021, BuiltInCullMaskKHR = 6021,
BuiltInMax = 0x7fffffff, BuiltInMax = 0x7fffffff,
}; };
@@ -735,7 +745,7 @@ enum SelectionControlShift {
SelectionControlMax = 0x7fffffff, SelectionControlMax = 0x7fffffff,
}; };
enum SelectionControlMask { enum SelectionControlMask : unsigned {
SelectionControlMaskNone = 0, SelectionControlMaskNone = 0,
SelectionControlFlattenMask = 0x00000001, SelectionControlFlattenMask = 0x00000001,
SelectionControlDontFlattenMask = 0x00000002, SelectionControlDontFlattenMask = 0x00000002,
@@ -764,7 +774,7 @@ enum LoopControlShift {
LoopControlMax = 0x7fffffff, LoopControlMax = 0x7fffffff,
}; };
enum LoopControlMask { enum LoopControlMask : unsigned {
LoopControlMaskNone = 0, LoopControlMaskNone = 0,
LoopControlUnrollMask = 0x00000001, LoopControlUnrollMask = 0x00000001,
LoopControlDontUnrollMask = 0x00000002, LoopControlDontUnrollMask = 0x00000002,
@@ -796,7 +806,7 @@ enum FunctionControlShift {
FunctionControlMax = 0x7fffffff, FunctionControlMax = 0x7fffffff,
}; };
enum FunctionControlMask { enum FunctionControlMask : unsigned {
FunctionControlMaskNone = 0, FunctionControlMaskNone = 0,
FunctionControlInlineMask = 0x00000001, FunctionControlInlineMask = 0x00000001,
FunctionControlDontInlineMask = 0x00000002, FunctionControlDontInlineMask = 0x00000002,
@@ -826,7 +836,7 @@ enum MemorySemanticsShift {
MemorySemanticsMax = 0x7fffffff, MemorySemanticsMax = 0x7fffffff,
}; };
enum MemorySemanticsMask { enum MemorySemanticsMask : unsigned {
MemorySemanticsMaskNone = 0, MemorySemanticsMaskNone = 0,
MemorySemanticsAcquireMask = 0x00000002, MemorySemanticsAcquireMask = 0x00000002,
MemorySemanticsReleaseMask = 0x00000004, MemorySemanticsReleaseMask = 0x00000004,
@@ -862,7 +872,7 @@ enum MemoryAccessShift {
MemoryAccessMax = 0x7fffffff, MemoryAccessMax = 0x7fffffff,
}; };
enum MemoryAccessMask { enum MemoryAccessMask : unsigned {
MemoryAccessMaskNone = 0, MemoryAccessMaskNone = 0,
MemoryAccessVolatileMask = 0x00000001, MemoryAccessVolatileMask = 0x00000001,
MemoryAccessAlignedMask = 0x00000002, MemoryAccessAlignedMask = 0x00000002,
@@ -912,7 +922,7 @@ enum KernelProfilingInfoShift {
KernelProfilingInfoMax = 0x7fffffff, KernelProfilingInfoMax = 0x7fffffff,
}; };
enum KernelProfilingInfoMask { enum KernelProfilingInfoMask : unsigned {
KernelProfilingInfoMaskNone = 0, KernelProfilingInfoMaskNone = 0,
KernelProfilingInfoCmdExecTimeMask = 0x00000001, KernelProfilingInfoCmdExecTimeMask = 0x00000001,
}; };
@@ -992,6 +1002,7 @@ enum Capability {
CapabilityTileImageColorReadAccessEXT = 4166, CapabilityTileImageColorReadAccessEXT = 4166,
CapabilityTileImageDepthReadAccessEXT = 4167, CapabilityTileImageDepthReadAccessEXT = 4167,
CapabilityTileImageStencilReadAccessEXT = 4168, CapabilityTileImageStencilReadAccessEXT = 4168,
CapabilityCooperativeMatrixLayoutsARM = 4201,
CapabilityFragmentShadingRateKHR = 4422, CapabilityFragmentShadingRateKHR = 4422,
CapabilitySubgroupBallotKHR = 4423, CapabilitySubgroupBallotKHR = 4423,
CapabilityDrawParameters = 4427, CapabilityDrawParameters = 4427,
@@ -1023,6 +1034,10 @@ enum Capability {
CapabilityRayQueryKHR = 4472, CapabilityRayQueryKHR = 4472,
CapabilityRayTraversalPrimitiveCullingKHR = 4478, CapabilityRayTraversalPrimitiveCullingKHR = 4478,
CapabilityRayTracingKHR = 4479, CapabilityRayTracingKHR = 4479,
CapabilityTextureSampleWeightedQCOM = 4484,
CapabilityTextureBoxFilterQCOM = 4485,
CapabilityTextureBlockMatchQCOM = 4486,
CapabilityTextureBlockMatch2QCOM = 4498,
CapabilityFloat16ImageAMD = 5008, CapabilityFloat16ImageAMD = 5008,
CapabilityImageGatherBiasLodAMD = 5009, CapabilityImageGatherBiasLodAMD = 5009,
CapabilityFragmentMaskAMD = 5010, CapabilityFragmentMaskAMD = 5010,
@@ -1030,6 +1045,7 @@ enum Capability {
CapabilityImageReadWriteLodAMD = 5015, CapabilityImageReadWriteLodAMD = 5015,
CapabilityInt64ImageEXT = 5016, CapabilityInt64ImageEXT = 5016,
CapabilityShaderClockKHR = 5055, CapabilityShaderClockKHR = 5055,
CapabilityQuadControlKHR = 5087,
CapabilitySampleMaskOverrideCoverageNV = 5249, CapabilitySampleMaskOverrideCoverageNV = 5249,
CapabilityGeometryShaderPassthroughNV = 5251, CapabilityGeometryShaderPassthroughNV = 5251,
CapabilityShaderViewportIndexLayerEXT = 5254, CapabilityShaderViewportIndexLayerEXT = 5254,
@@ -1089,10 +1105,13 @@ enum Capability {
CapabilityFragmentShaderPixelInterlockEXT = 5378, CapabilityFragmentShaderPixelInterlockEXT = 5378,
CapabilityDemoteToHelperInvocation = 5379, CapabilityDemoteToHelperInvocation = 5379,
CapabilityDemoteToHelperInvocationEXT = 5379, CapabilityDemoteToHelperInvocationEXT = 5379,
CapabilityDisplacementMicromapNV = 5380,
CapabilityRayTracingOpacityMicromapEXT = 5381, CapabilityRayTracingOpacityMicromapEXT = 5381,
CapabilityShaderInvocationReorderNV = 5383, CapabilityShaderInvocationReorderNV = 5383,
CapabilityBindlessTextureNV = 5390, CapabilityBindlessTextureNV = 5390,
CapabilityRayQueryPositionFetchKHR = 5391, CapabilityRayQueryPositionFetchKHR = 5391,
CapabilityAtomicFloat16VectorNV = 5404,
CapabilityRayTracingDisplacementMicromapNV = 5409,
CapabilitySubgroupShuffleINTEL = 5568, CapabilitySubgroupShuffleINTEL = 5568,
CapabilitySubgroupBufferBlockIOINTEL = 5569, CapabilitySubgroupBufferBlockIOINTEL = 5569,
CapabilitySubgroupImageBlockIOINTEL = 5570, CapabilitySubgroupImageBlockIOINTEL = 5570,
@@ -1144,6 +1163,8 @@ enum Capability {
CapabilityDotProduct = 6019, CapabilityDotProduct = 6019,
CapabilityDotProductKHR = 6019, CapabilityDotProductKHR = 6019,
CapabilityRayCullMaskKHR = 6020, CapabilityRayCullMaskKHR = 6020,
CapabilityCooperativeMatrixKHR = 6022,
CapabilityReplicatedCompositesEXT = 6024,
CapabilityBitInstructions = 6025, CapabilityBitInstructions = 6025,
CapabilityGroupNonUniformRotateKHR = 6026, CapabilityGroupNonUniformRotateKHR = 6026,
CapabilityAtomicFloat32AddEXT = 6033, CapabilityAtomicFloat32AddEXT = 6033,
@@ -1173,7 +1194,7 @@ enum RayFlagsShift {
RayFlagsMax = 0x7fffffff, RayFlagsMax = 0x7fffffff,
}; };
enum RayFlagsMask { enum RayFlagsMask : unsigned {
RayFlagsMaskNone = 0, RayFlagsMaskNone = 0,
RayFlagsOpaqueKHRMask = 0x00000001, RayFlagsOpaqueKHRMask = 0x00000001,
RayFlagsNoOpaqueKHRMask = 0x00000002, RayFlagsNoOpaqueKHRMask = 0x00000002,
@@ -1215,7 +1236,7 @@ enum FragmentShadingRateShift {
FragmentShadingRateMax = 0x7fffffff, FragmentShadingRateMax = 0x7fffffff,
}; };
enum FragmentShadingRateMask { enum FragmentShadingRateMask : unsigned {
FragmentShadingRateMaskNone = 0, FragmentShadingRateMaskNone = 0,
FragmentShadingRateVertical2PixelsMask = 0x00000001, FragmentShadingRateVertical2PixelsMask = 0x00000001,
FragmentShadingRateVertical4PixelsMask = 0x00000002, FragmentShadingRateVertical4PixelsMask = 0x00000002,
@@ -1261,6 +1282,39 @@ enum PackedVectorFormat {
PackedVectorFormatMax = 0x7fffffff, PackedVectorFormatMax = 0x7fffffff,
}; };
enum CooperativeMatrixOperandsShift {
CooperativeMatrixOperandsMatrixASignedComponentsKHRShift = 0,
CooperativeMatrixOperandsMatrixBSignedComponentsKHRShift = 1,
CooperativeMatrixOperandsMatrixCSignedComponentsKHRShift = 2,
CooperativeMatrixOperandsMatrixResultSignedComponentsKHRShift = 3,
CooperativeMatrixOperandsSaturatingAccumulationKHRShift = 4,
CooperativeMatrixOperandsMax = 0x7fffffff,
};
enum CooperativeMatrixOperandsMask : unsigned {
CooperativeMatrixOperandsMaskNone = 0,
CooperativeMatrixOperandsMatrixASignedComponentsKHRMask = 0x00000001,
CooperativeMatrixOperandsMatrixBSignedComponentsKHRMask = 0x00000002,
CooperativeMatrixOperandsMatrixCSignedComponentsKHRMask = 0x00000004,
CooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask = 0x00000008,
CooperativeMatrixOperandsSaturatingAccumulationKHRMask = 0x00000010,
};
enum CooperativeMatrixLayout {
CooperativeMatrixLayoutRowMajorKHR = 0,
CooperativeMatrixLayoutColumnMajorKHR = 1,
CooperativeMatrixLayoutRowBlockedInterleavedARM = 4202,
CooperativeMatrixLayoutColumnBlockedInterleavedARM = 4203,
CooperativeMatrixLayoutMax = 0x7fffffff,
};
enum CooperativeMatrixUse {
CooperativeMatrixUseMatrixAKHR = 0,
CooperativeMatrixUseMatrixBKHR = 1,
CooperativeMatrixUseMatrixAccumulatorKHR = 2,
CooperativeMatrixUseMax = 0x7fffffff,
};
enum Op { enum Op {
OpNop = 0, OpNop = 0,
OpUndef = 1, OpUndef = 1,
@@ -1617,6 +1671,7 @@ enum Op {
OpSubgroupAllEqualKHR = 4430, OpSubgroupAllEqualKHR = 4430,
OpGroupNonUniformRotateKHR = 4431, OpGroupNonUniformRotateKHR = 4431,
OpSubgroupReadInvocationKHR = 4432, OpSubgroupReadInvocationKHR = 4432,
OpExtInstWithForwardRefsKHR = 4433,
OpTraceRayKHR = 4445, OpTraceRayKHR = 4445,
OpExecuteCallableKHR = 4446, OpExecuteCallableKHR = 4446,
OpConvertUToAccelerationStructureKHR = 4447, OpConvertUToAccelerationStructureKHR = 4447,
@@ -1634,6 +1689,14 @@ enum Op {
OpUDotAccSatKHR = 4454, OpUDotAccSatKHR = 4454,
OpSUDotAccSat = 4455, OpSUDotAccSat = 4455,
OpSUDotAccSatKHR = 4455, OpSUDotAccSatKHR = 4455,
OpTypeCooperativeMatrixKHR = 4456,
OpCooperativeMatrixLoadKHR = 4457,
OpCooperativeMatrixStoreKHR = 4458,
OpCooperativeMatrixMulAddKHR = 4459,
OpCooperativeMatrixLengthKHR = 4460,
OpConstantCompositeReplicateEXT = 4461,
OpSpecConstantCompositeReplicateEXT = 4462,
OpCompositeConstructReplicateEXT = 4463,
OpTypeRayQueryKHR = 4472, OpTypeRayQueryKHR = 4472,
OpRayQueryInitializeKHR = 4473, OpRayQueryInitializeKHR = 4473,
OpRayQueryTerminateKHR = 4474, OpRayQueryTerminateKHR = 4474,
@@ -1641,6 +1704,14 @@ enum Op {
OpRayQueryConfirmIntersectionKHR = 4476, OpRayQueryConfirmIntersectionKHR = 4476,
OpRayQueryProceedKHR = 4477, OpRayQueryProceedKHR = 4477,
OpRayQueryGetIntersectionTypeKHR = 4479, OpRayQueryGetIntersectionTypeKHR = 4479,
OpImageSampleWeightedQCOM = 4480,
OpImageBoxFilterQCOM = 4481,
OpImageBlockMatchSSDQCOM = 4482,
OpImageBlockMatchSADQCOM = 4483,
OpImageBlockMatchWindowSSDQCOM = 4500,
OpImageBlockMatchWindowSADQCOM = 4501,
OpImageBlockMatchGatherSSDQCOM = 4502,
OpImageBlockMatchGatherSADQCOM = 4503,
OpGroupIAddNonUniformAMD = 5000, OpGroupIAddNonUniformAMD = 5000,
OpGroupFAddNonUniformAMD = 5001, OpGroupFAddNonUniformAMD = 5001,
OpGroupFMinNonUniformAMD = 5002, OpGroupFMinNonUniformAMD = 5002,
@@ -1652,6 +1723,8 @@ enum Op {
OpFragmentMaskFetchAMD = 5011, OpFragmentMaskFetchAMD = 5011,
OpFragmentFetchAMD = 5012, OpFragmentFetchAMD = 5012,
OpReadClockKHR = 5056, OpReadClockKHR = 5056,
OpGroupNonUniformQuadAllKHR = 5110,
OpGroupNonUniformQuadAnyKHR = 5111,
OpHitObjectRecordHitMotionNV = 5249, OpHitObjectRecordHitMotionNV = 5249,
OpHitObjectRecordHitWithIndexMotionNV = 5250, OpHitObjectRecordHitWithIndexMotionNV = 5250,
OpHitObjectRecordMissMotionNV = 5251, OpHitObjectRecordMissMotionNV = 5251,
@@ -1690,6 +1763,8 @@ enum Op {
OpSetMeshOutputsEXT = 5295, OpSetMeshOutputsEXT = 5295,
OpGroupNonUniformPartitionNV = 5296, OpGroupNonUniformPartitionNV = 5296,
OpWritePackedPrimitiveIndices4x8NV = 5299, OpWritePackedPrimitiveIndices4x8NV = 5299,
OpFetchMicroTriangleVertexPositionNV = 5300,
OpFetchMicroTriangleVertexBarycentricNV = 5301,
OpReportIntersectionKHR = 5334, OpReportIntersectionKHR = 5334,
OpReportIntersectionNV = 5334, OpReportIntersectionNV = 5334,
OpIgnoreIntersectionNV = 5335, OpIgnoreIntersectionNV = 5335,
@@ -2324,6 +2399,7 @@ inline void HasResultAndType(Op opcode, bool *hasResult, bool *hasResultType) {
case OpPtrEqual: *hasResult = true; *hasResultType = true; break; case OpPtrEqual: *hasResult = true; *hasResultType = true; break;
case OpPtrNotEqual: *hasResult = true; *hasResultType = true; break; case OpPtrNotEqual: *hasResult = true; *hasResultType = true; break;
case OpPtrDiff: *hasResult = true; *hasResultType = true; break; case OpPtrDiff: *hasResult = true; *hasResultType = true; break;
case OpExtInstWithForwardRefsKHR: *hasResult = true; *hasResultType = true; break;
case OpColorAttachmentReadEXT: *hasResult = true; *hasResultType = true; break; case OpColorAttachmentReadEXT: *hasResult = true; *hasResultType = true; break;
case OpDepthAttachmentReadEXT: *hasResult = true; *hasResultType = true; break; case OpDepthAttachmentReadEXT: *hasResult = true; *hasResultType = true; break;
case OpStencilAttachmentReadEXT: *hasResult = true; *hasResultType = true; break; case OpStencilAttachmentReadEXT: *hasResult = true; *hasResultType = true; break;
@@ -2346,6 +2422,14 @@ inline void HasResultAndType(Op opcode, bool *hasResult, bool *hasResultType) {
case OpSDotAccSat: *hasResult = true; *hasResultType = true; break; case OpSDotAccSat: *hasResult = true; *hasResultType = true; break;
case OpUDotAccSat: *hasResult = true; *hasResultType = true; break; case OpUDotAccSat: *hasResult = true; *hasResultType = true; break;
case OpSUDotAccSat: *hasResult = true; *hasResultType = true; break; case OpSUDotAccSat: *hasResult = true; *hasResultType = true; break;
case OpTypeCooperativeMatrixKHR: *hasResult = true; *hasResultType = false; break;
case OpCooperativeMatrixLoadKHR: *hasResult = true; *hasResultType = true; break;
case OpCooperativeMatrixStoreKHR: *hasResult = false; *hasResultType = false; break;
case OpCooperativeMatrixMulAddKHR: *hasResult = true; *hasResultType = true; break;
case OpCooperativeMatrixLengthKHR: *hasResult = true; *hasResultType = true; break;
case OpConstantCompositeReplicateEXT: *hasResult = true; *hasResultType = true; break;
case OpSpecConstantCompositeReplicateEXT: *hasResult = true; *hasResultType = true; break;
case OpCompositeConstructReplicateEXT: *hasResult = true; *hasResultType = true; break;
case OpTypeRayQueryKHR: *hasResult = true; *hasResultType = false; break; case OpTypeRayQueryKHR: *hasResult = true; *hasResultType = false; break;
case OpRayQueryInitializeKHR: *hasResult = false; *hasResultType = false; break; case OpRayQueryInitializeKHR: *hasResult = false; *hasResultType = false; break;
case OpRayQueryTerminateKHR: *hasResult = false; *hasResultType = false; break; case OpRayQueryTerminateKHR: *hasResult = false; *hasResultType = false; break;
@@ -2353,6 +2437,14 @@ inline void HasResultAndType(Op opcode, bool *hasResult, bool *hasResultType) {
case OpRayQueryConfirmIntersectionKHR: *hasResult = false; *hasResultType = false; break; case OpRayQueryConfirmIntersectionKHR: *hasResult = false; *hasResultType = false; break;
case OpRayQueryProceedKHR: *hasResult = true; *hasResultType = true; break; case OpRayQueryProceedKHR: *hasResult = true; *hasResultType = true; break;
case OpRayQueryGetIntersectionTypeKHR: *hasResult = true; *hasResultType = true; break; case OpRayQueryGetIntersectionTypeKHR: *hasResult = true; *hasResultType = true; break;
case OpImageSampleWeightedQCOM: *hasResult = true; *hasResultType = true; break;
case OpImageBoxFilterQCOM: *hasResult = true; *hasResultType = true; break;
case OpImageBlockMatchSSDQCOM: *hasResult = true; *hasResultType = true; break;
case OpImageBlockMatchSADQCOM: *hasResult = true; *hasResultType = true; break;
case OpImageBlockMatchWindowSSDQCOM: *hasResult = true; *hasResultType = true; break;
case OpImageBlockMatchWindowSADQCOM: *hasResult = true; *hasResultType = true; break;
case OpImageBlockMatchGatherSSDQCOM: *hasResult = true; *hasResultType = true; break;
case OpImageBlockMatchGatherSADQCOM: *hasResult = true; *hasResultType = true; break;
case OpGroupIAddNonUniformAMD: *hasResult = true; *hasResultType = true; break; case OpGroupIAddNonUniformAMD: *hasResult = true; *hasResultType = true; break;
case OpGroupFAddNonUniformAMD: *hasResult = true; *hasResultType = true; break; case OpGroupFAddNonUniformAMD: *hasResult = true; *hasResultType = true; break;
case OpGroupFMinNonUniformAMD: *hasResult = true; *hasResultType = true; break; case OpGroupFMinNonUniformAMD: *hasResult = true; *hasResultType = true; break;
@@ -2364,6 +2456,8 @@ inline void HasResultAndType(Op opcode, bool *hasResult, bool *hasResultType) {
case OpFragmentMaskFetchAMD: *hasResult = true; *hasResultType = true; break; case OpFragmentMaskFetchAMD: *hasResult = true; *hasResultType = true; break;
case OpFragmentFetchAMD: *hasResult = true; *hasResultType = true; break; case OpFragmentFetchAMD: *hasResult = true; *hasResultType = true; break;
case OpReadClockKHR: *hasResult = true; *hasResultType = true; break; case OpReadClockKHR: *hasResult = true; *hasResultType = true; break;
case OpGroupNonUniformQuadAllKHR: *hasResult = true; *hasResultType = true; break;
case OpGroupNonUniformQuadAnyKHR: *hasResult = true; *hasResultType = true; break;
case OpHitObjectRecordHitMotionNV: *hasResult = false; *hasResultType = false; break; case OpHitObjectRecordHitMotionNV: *hasResult = false; *hasResultType = false; break;
case OpHitObjectRecordHitWithIndexMotionNV: *hasResult = false; *hasResultType = false; break; case OpHitObjectRecordHitWithIndexMotionNV: *hasResult = false; *hasResultType = false; break;
case OpHitObjectRecordMissMotionNV: *hasResult = false; *hasResultType = false; break; case OpHitObjectRecordMissMotionNV: *hasResult = false; *hasResultType = false; break;
@@ -2722,6 +2816,10 @@ inline FragmentShadingRateMask operator|(FragmentShadingRateMask a, FragmentShad
inline FragmentShadingRateMask operator&(FragmentShadingRateMask a, FragmentShadingRateMask b) { return FragmentShadingRateMask(unsigned(a) & unsigned(b)); } inline FragmentShadingRateMask operator&(FragmentShadingRateMask a, FragmentShadingRateMask b) { return FragmentShadingRateMask(unsigned(a) & unsigned(b)); }
inline FragmentShadingRateMask operator^(FragmentShadingRateMask a, FragmentShadingRateMask b) { return FragmentShadingRateMask(unsigned(a) ^ unsigned(b)); } inline FragmentShadingRateMask operator^(FragmentShadingRateMask a, FragmentShadingRateMask b) { return FragmentShadingRateMask(unsigned(a) ^ unsigned(b)); }
inline FragmentShadingRateMask operator~(FragmentShadingRateMask a) { return FragmentShadingRateMask(~unsigned(a)); } inline FragmentShadingRateMask operator~(FragmentShadingRateMask a) { return FragmentShadingRateMask(~unsigned(a)); }
inline CooperativeMatrixOperandsMask operator|(CooperativeMatrixOperandsMask a, CooperativeMatrixOperandsMask b) { return CooperativeMatrixOperandsMask(unsigned(a) | unsigned(b)); }
inline CooperativeMatrixOperandsMask operator&(CooperativeMatrixOperandsMask a, CooperativeMatrixOperandsMask b) { return CooperativeMatrixOperandsMask(unsigned(a) & unsigned(b)); }
inline CooperativeMatrixOperandsMask operator^(CooperativeMatrixOperandsMask a, CooperativeMatrixOperandsMask b) { return CooperativeMatrixOperandsMask(unsigned(a) ^ unsigned(b)); }
inline CooperativeMatrixOperandsMask operator~(CooperativeMatrixOperandsMask a) { return CooperativeMatrixOperandsMask(~unsigned(a)); }
} // end namespace spv } // end namespace spv
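The new CooperativeMatrixOperandsMask gets the same typed bitwise operators as the other masks above, so flags combine without casts; a small sketch (editor's illustration):

    spv::CooperativeMatrixOperandsMask ops =
        spv::CooperativeMatrixOperandsMatrixASignedComponentsKHRMask |
        spv::CooperativeMatrixOperandsMatrixBSignedComponentsKHRMask;
    bool saturating = (ops & spv::CooperativeMatrixOperandsSaturatingAccumulationKHRMask)
        != spv::CooperativeMatrixOperandsMaskNone;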

View file

@@ -1,520 +0,0 @@
//
// Copyright (C) 2014 LunarG, Inc.
// Copyright (C) 2015-2018 Google, Inc.
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following
// disclaimer in the documentation and/or other materials provided
// with the distribution.
//
// Neither the name of 3Dlabs Inc. Ltd. nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
// SPIRV-IR
//
// Simple in-memory representation (IR) of SPIRV. Just for holding
// each function's CFG of blocks. Has this hierarchy:
// - Module, which is a list of
// - Function, which is a list of
// - Block, which is a list of
// - Instruction
//
#pragma once
#ifndef spvIR_H
#define spvIR_H
#include "spirv.hpp"
#include <algorithm>
#include <cassert>
#include <functional>
#include <iostream>
#include <memory>
#include <vector>
#include <set>
namespace spv {
class Block;
class Function;
class Module;
const Id NoResult = 0;
const Id NoType = 0;
const Decoration NoPrecision = DecorationMax;
#ifdef __GNUC__
# define POTENTIALLY_UNUSED __attribute__((unused))
#else
# define POTENTIALLY_UNUSED
#endif
POTENTIALLY_UNUSED
const MemorySemanticsMask MemorySemanticsAllMemory =
(MemorySemanticsMask)(MemorySemanticsUniformMemoryMask |
MemorySemanticsWorkgroupMemoryMask |
MemorySemanticsAtomicCounterMemoryMask |
MemorySemanticsImageMemoryMask);
struct IdImmediate {
bool isId; // true if word is an Id, false if word is an immediate
unsigned word;
IdImmediate(bool i, unsigned w) : isId(i), word(w) {}
};
//
// SPIR-V IR instruction.
//
class Instruction {
public:
Instruction(Id resultId, Id typeId, Op opCode) : resultId(resultId), typeId(typeId), opCode(opCode), block(nullptr) { }
explicit Instruction(Op opCode) : resultId(NoResult), typeId(NoType), opCode(opCode), block(nullptr) { }
virtual ~Instruction() {}
void addIdOperand(Id id) {
operands.push_back(id);
idOperand.push_back(true);
}
void addImmediateOperand(unsigned int immediate) {
operands.push_back(immediate);
idOperand.push_back(false);
}
void setImmediateOperand(unsigned idx, unsigned int immediate) {
assert(!idOperand[idx]);
operands[idx] = immediate;
}
void addStringOperand(const char* str)
{
unsigned int word = 0;
unsigned int shiftAmount = 0;
char c;
do {
c = *(str++);
word |= ((unsigned int)c) << shiftAmount;
shiftAmount += 8;
if (shiftAmount == 32) {
addImmediateOperand(word);
word = 0;
shiftAmount = 0;
}
} while (c != 0);
// deal with partial last word
if (shiftAmount > 0) {
addImmediateOperand(word);
}
}
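    // Worked example (editor's note): addStringOperand("abc") packs bytes
    // little-endian, four per word, including the terminating NUL, so it emits
    // the single word 0x00636261 ('a' | 'b' << 8 | 'c' << 16 | 0 << 24); longer
    // strings flush full words inside the loop, and the shiftAmount check above
    // flushes any leftover partial word.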
bool isIdOperand(int op) const { return idOperand[op]; }
void setBlock(Block* b) { block = b; }
Block* getBlock() const { return block; }
Op getOpCode() const { return opCode; }
int getNumOperands() const
{
assert(operands.size() == idOperand.size());
return (int)operands.size();
}
Id getResultId() const { return resultId; }
Id getTypeId() const { return typeId; }
Id getIdOperand(int op) const {
assert(idOperand[op]);
return operands[op];
}
unsigned int getImmediateOperand(int op) const {
assert(!idOperand[op]);
return operands[op];
}
// Write out the binary form.
void dump(std::vector<unsigned int>& out) const
{
// Compute the wordCount
unsigned int wordCount = 1;
if (typeId)
++wordCount;
if (resultId)
++wordCount;
wordCount += (unsigned int)operands.size();
// Write out the beginning of the instruction
out.push_back(((wordCount) << WordCountShift) | opCode);
if (typeId)
out.push_back(typeId);
if (resultId)
out.push_back(resultId);
// Write out the operands
for (int op = 0; op < (int)operands.size(); ++op)
out.push_back(operands[op]);
}
protected:
Instruction(const Instruction&);
Id resultId;
Id typeId;
Op opCode;
std::vector<Id> operands; // operands, both <id> and immediates (both are unsigned int)
std::vector<bool> idOperand; // true for operands that are <id>, false for immediates
Block* block;
};
//
// SPIR-V IR block.
//
class Block {
public:
Block(Id id, Function& parent);
virtual ~Block()
{
}
Id getId() { return instructions.front()->getResultId(); }
Function& getParent() const { return parent; }
void addInstruction(std::unique_ptr<Instruction> inst);
void addPredecessor(Block* pred) { predecessors.push_back(pred); pred->successors.push_back(this);}
void addLocalVariable(std::unique_ptr<Instruction> inst) { localVariables.push_back(std::move(inst)); }
const std::vector<Block*>& getPredecessors() const { return predecessors; }
const std::vector<Block*>& getSuccessors() const { return successors; }
const std::vector<std::unique_ptr<Instruction> >& getInstructions() const {
return instructions;
}
const std::vector<std::unique_ptr<Instruction> >& getLocalVariables() const { return localVariables; }
void setUnreachable() { unreachable = true; }
bool isUnreachable() const { return unreachable; }
// Returns the block's merge instruction, if one exists (otherwise null).
const Instruction* getMergeInstruction() const {
if (instructions.size() < 2) return nullptr;
const Instruction* nextToLast = (instructions.cend() - 2)->get();
switch (nextToLast->getOpCode()) {
case OpSelectionMerge:
case OpLoopMerge:
return nextToLast;
default:
return nullptr;
}
return nullptr;
}
// Change this block into a canonical dead merge block. Delete instructions
// as necessary. A canonical dead merge block has only an OpLabel and an
// OpUnreachable.
void rewriteAsCanonicalUnreachableMerge() {
assert(localVariables.empty());
// Delete all instructions except for the label.
assert(instructions.size() > 0);
instructions.resize(1);
successors.clear();
addInstruction(std::unique_ptr<Instruction>(new Instruction(OpUnreachable)));
}
// Change this block into a canonical dead continue target branching to the
// given header ID. Delete instructions as necessary. A canonical dead continue
// target has only an OpLabel and an unconditional branch back to the corresponding
// header.
void rewriteAsCanonicalUnreachableContinue(Block* header) {
assert(localVariables.empty());
// Delete all instructions except for the label.
assert(instructions.size() > 0);
instructions.resize(1);
successors.clear();
// Add OpBranch back to the header.
assert(header != nullptr);
Instruction* branch = new Instruction(OpBranch);
branch->addIdOperand(header->getId());
addInstruction(std::unique_ptr<Instruction>(branch));
successors.push_back(header);
}
bool isTerminated() const
{
switch (instructions.back()->getOpCode()) {
case OpBranch:
case OpBranchConditional:
case OpSwitch:
case OpKill:
case OpTerminateInvocation:
case OpReturn:
case OpReturnValue:
case OpUnreachable:
return true;
default:
return false;
}
}
void dump(std::vector<unsigned int>& out) const
{
instructions[0]->dump(out);
for (int i = 0; i < (int)localVariables.size(); ++i)
localVariables[i]->dump(out);
for (int i = 1; i < (int)instructions.size(); ++i)
instructions[i]->dump(out);
}
protected:
Block(const Block&);
Block& operator=(Block&);
// To enforce keeping parent and ownership in sync:
friend Function;
std::vector<std::unique_ptr<Instruction> > instructions;
std::vector<Block*> predecessors, successors;
std::vector<std::unique_ptr<Instruction> > localVariables;
Function& parent;
// track whether this block is known to be unreachable (not necessarily
// true for all unreachable blocks, but should be set at least
// for the extraneous ones introduced by the builder).
bool unreachable;
};
// The different reasons for reaching a block in the inReadableOrder traversal.
enum ReachReason {
// Reachable from the entry block via transfers of control, i.e. branches.
ReachViaControlFlow = 0,
// A continue target that is not reachable via control flow.
ReachDeadContinue,
// A merge block that is not reachable via control flow.
ReachDeadMerge
};
// Traverses the control-flow graph rooted at root in an order suited for
// readable code generation. Invokes callback at every node in the traversal
// order. The callback arguments are:
// - the block,
// - the reason we reached the block,
// - if the reason was that block is an unreachable continue or unreachable merge block
// then the last parameter is the corresponding header block.
void inReadableOrder(Block* root, std::function<void(Block*, ReachReason, Block* header)> callback);
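// For example (editor's note), Function::dump() below serializes a function with
//   inReadableOrder(blocks[0], [&out](const Block* b, ReachReason, Block*) { b->dump(out); });
// so dead continue/merge targets are still visited, with their header passed to
// the callback.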
//
// SPIR-V IR Function.
//
class Function {
public:
Function(Id id, Id resultType, Id functionType, Id firstParam, Module& parent);
virtual ~Function()
{
for (int i = 0; i < (int)parameterInstructions.size(); ++i)
delete parameterInstructions[i];
for (int i = 0; i < (int)blocks.size(); ++i)
delete blocks[i];
}
Id getId() const { return functionInstruction.getResultId(); }
Id getParamId(int p) const { return parameterInstructions[p]->getResultId(); }
Id getParamType(int p) const { return parameterInstructions[p]->getTypeId(); }
void addBlock(Block* block) { blocks.push_back(block); }
void removeBlock(Block* block)
{
auto found = find(blocks.begin(), blocks.end(), block);
assert(found != blocks.end());
blocks.erase(found);
delete block;
}
Module& getParent() const { return parent; }
Block* getEntryBlock() const { return blocks.front(); }
Block* getLastBlock() const { return blocks.back(); }
const std::vector<Block*>& getBlocks() const { return blocks; }
void addLocalVariable(std::unique_ptr<Instruction> inst);
Id getReturnType() const { return functionInstruction.getTypeId(); }
Id getFuncId() const { return functionInstruction.getResultId(); }
void setReturnPrecision(Decoration precision)
{
if (precision == DecorationRelaxedPrecision)
reducedPrecisionReturn = true;
}
Decoration getReturnPrecision() const
{ return reducedPrecisionReturn ? DecorationRelaxedPrecision : NoPrecision; }
void setDebugLineInfo(Id fileName, int line, int column) {
lineInstruction = std::unique_ptr<Instruction>{new Instruction(OpLine)};
lineInstruction->addIdOperand(fileName);
lineInstruction->addImmediateOperand(line);
lineInstruction->addImmediateOperand(column);
}
bool hasDebugLineInfo() const { return lineInstruction != nullptr; }
void setImplicitThis() { implicitThis = true; }
bool hasImplicitThis() const { return implicitThis; }
void addParamPrecision(unsigned param, Decoration precision)
{
if (precision == DecorationRelaxedPrecision)
reducedPrecisionParams.insert(param);
}
Decoration getParamPrecision(unsigned param) const
{
return reducedPrecisionParams.find(param) != reducedPrecisionParams.end() ?
DecorationRelaxedPrecision : NoPrecision;
}
void dump(std::vector<unsigned int>& out) const
{
// OpLine
if (lineInstruction != nullptr) {
lineInstruction->dump(out);
}
// OpFunction
functionInstruction.dump(out);
// OpFunctionParameter
for (int p = 0; p < (int)parameterInstructions.size(); ++p)
parameterInstructions[p]->dump(out);
// Blocks
inReadableOrder(blocks[0], [&out](const Block* b, ReachReason, Block*) { b->dump(out); });
Instruction end(0, 0, OpFunctionEnd);
end.dump(out);
}
protected:
Function(const Function&);
Function& operator=(Function&);
Module& parent;
std::unique_ptr<Instruction> lineInstruction;
Instruction functionInstruction;
std::vector<Instruction*> parameterInstructions;
std::vector<Block*> blocks;
bool implicitThis; // true if this is a member function expecting to be passed a 'this' as the first argument
bool reducedPrecisionReturn;
std::set<int> reducedPrecisionParams; // list of parameter indexes that need a relaxed precision arg
};
//
// SPIR-V IR Module.
//
class Module {
public:
Module() {}
virtual ~Module()
{
// TODO delete things
}
void addFunction(Function *fun) { functions.push_back(fun); }
void mapInstruction(Instruction *instruction)
{
spv::Id resultId = instruction->getResultId();
// map the instruction's result id
if (resultId >= idToInstruction.size())
idToInstruction.resize(resultId + 16);
idToInstruction[resultId] = instruction;
}
Instruction* getInstruction(Id id) const { return idToInstruction[id]; }
const std::vector<Function*>& getFunctions() const { return functions; }
spv::Id getTypeId(Id resultId) const {
return idToInstruction[resultId] == nullptr ? NoType : idToInstruction[resultId]->getTypeId();
}
StorageClass getStorageClass(Id typeId) const
{
assert(idToInstruction[typeId]->getOpCode() == spv::OpTypePointer);
return (StorageClass)idToInstruction[typeId]->getImmediateOperand(0);
}
void dump(std::vector<unsigned int>& out) const
{
for (int f = 0; f < (int)functions.size(); ++f)
functions[f]->dump(out);
}
protected:
Module(const Module&);
std::vector<Function*> functions;
// map from result id to instruction having that result id
std::vector<Instruction*> idToInstruction;
// map from a result id to its type id
};
//
// Implementation (it's here due to circular type definitions).
//
// Add both
// - the OpFunction instruction
// - all the OpFunctionParameter instructions
__inline Function::Function(Id id, Id resultType, Id functionType, Id firstParamId, Module& parent)
: parent(parent), lineInstruction(nullptr),
functionInstruction(id, resultType, OpFunction), implicitThis(false),
reducedPrecisionReturn(false)
{
// OpFunction
functionInstruction.addImmediateOperand(FunctionControlMaskNone);
functionInstruction.addIdOperand(functionType);
parent.mapInstruction(&functionInstruction);
parent.addFunction(this);
// OpFunctionParameter
Instruction* typeInst = parent.getInstruction(functionType);
int numParams = typeInst->getNumOperands() - 1;
for (int p = 0; p < numParams; ++p) {
Instruction* param = new Instruction(firstParamId + p, typeInst->getIdOperand(p + 1), OpFunctionParameter);
parent.mapInstruction(param);
parameterInstructions.push_back(param);
}
}
__inline void Function::addLocalVariable(std::unique_ptr<Instruction> inst)
{
Instruction* raw_instruction = inst.get();
blocks[0]->addLocalVariable(std::move(inst));
parent.mapInstruction(raw_instruction);
}
__inline Block::Block(Id id, Function& parent) : parent(parent), unreachable(false)
{
instructions.push_back(std::unique_ptr<Instruction>(new Instruction(id, NoType, OpLabel)));
instructions.back()->setBlock(this);
parent.getParent().mapInstruction(instructions.back().get());
}
__inline void Block::addInstruction(std::unique_ptr<Instruction> inst)
{
Instruction* raw_instruction = inst.get();
instructions.push_back(std::move(inst));
raw_instruction->setBlock(this);
if (raw_instruction->getResultId())
parent.getParent().mapInstruction(raw_instruction);
}
} // end spv namespace
#endif // spvIR_H

View file

@@ -0,0 +1,62 @@
// Copyright (C) 2020 The Khronos Group Inc.
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following
// disclaimer in the documentation and/or other materials provided
// with the distribution.
//
// Neither the name of The Khronos Group Inc. nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
#ifndef GLSLANG_BUILD_INFO
#define GLSLANG_BUILD_INFO
#define GLSLANG_VERSION_MAJOR 15
#define GLSLANG_VERSION_MINOR 0
#define GLSLANG_VERSION_PATCH 0
#define GLSLANG_VERSION_FLAVOR ""
#define GLSLANG_VERSION_GREATER_THAN(major, minor, patch) \
((GLSLANG_VERSION_MAJOR) > (major) || ((major) == GLSLANG_VERSION_MAJOR && \
((GLSLANG_VERSION_MINOR) > (minor) || ((minor) == GLSLANG_VERSION_MINOR && \
(GLSLANG_VERSION_PATCH) > (patch)))))
#define GLSLANG_VERSION_GREATER_OR_EQUAL_TO(major, minor, patch) \
((GLSLANG_VERSION_MAJOR) > (major) || ((major) == GLSLANG_VERSION_MAJOR && \
((GLSLANG_VERSION_MINOR) > (minor) || ((minor) == GLSLANG_VERSION_MINOR && \
(GLSLANG_VERSION_PATCH >= (patch))))))
#define GLSLANG_VERSION_LESS_THAN(major, minor, patch) \
((GLSLANG_VERSION_MAJOR) < (major) || ((major) == GLSLANG_VERSION_MAJOR && \
((GLSLANG_VERSION_MINOR) < (minor) || ((minor) == GLSLANG_VERSION_MINOR && \
(GLSLANG_VERSION_PATCH) < (patch)))))
#define GLSLANG_VERSION_LESS_OR_EQUAL_TO(major, minor, patch) \
((GLSLANG_VERSION_MAJOR) < (major) || ((major) == GLSLANG_VERSION_MAJOR && \
((GLSLANG_VERSION_MINOR) < (minor) || ((minor) == GLSLANG_VERSION_MINOR && \
(GLSLANG_VERSION_PATCH <= (patch))))))
#endif // GLSLANG_BUILD_INFO
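A minimal usage sketch for the version-check macros above; the include path follows glslang's generated-header layout, and the guarded feature is hypothetical:

    #include "glslang/build_info.h"

    #if GLSLANG_VERSION_GREATER_OR_EQUAL_TO(15, 0, 0)
    // Rely on behavior introduced in glslang 15.
    #else
    // Fallback for older glslang releases.
    #endif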

View file

@@ -489,6 +489,12 @@ SHADERC_EXPORT void shaderc_compile_options_set_hlsl_functionality1(
SHADERC_EXPORT void shaderc_compile_options_set_hlsl_16bit_types( SHADERC_EXPORT void shaderc_compile_options_set_hlsl_16bit_types(
shaderc_compile_options_t options, bool enable); shaderc_compile_options_t options, bool enable);
// Enables or disables relaxed Vulkan rules.
//
// This allows most OpenGL shaders to compile under Vulkan semantics.
SHADERC_EXPORT void shaderc_compile_options_set_vulkan_rules_relaxed(
shaderc_compile_options_t options, bool enable);
// Sets whether the compiler should invert position.Y output in vertex shader. // Sets whether the compiler should invert position.Y output in vertex shader.
SHADERC_EXPORT void shaderc_compile_options_set_invert_y( SHADERC_EXPORT void shaderc_compile_options_set_invert_y(
shaderc_compile_options_t options, bool enable); shaderc_compile_options_t options, bool enable);

View file

@@ -353,12 +353,19 @@ class CompileOptions {
shaderc_compile_options_set_hlsl_16bit_types(options_, enable); shaderc_compile_options_set_hlsl_16bit_types(options_, enable);
} }
// Enables or disables relaxed Vulkan rules.
//
// This allows most OpenGL shaders to compile under Vulkan semantics.
void SetVulkanRulesRelaxed(bool enable) {
shaderc_compile_options_set_vulkan_rules_relaxed(options_, enable);
}
// Sets whether the compiler should invert position.Y output in vertex shader. // Sets whether the compiler should invert position.Y output in vertex shader.
void SetInvertY(bool enable) { void SetInvertY(bool enable) {
shaderc_compile_options_set_invert_y(options_, enable); shaderc_compile_options_set_invert_y(options_, enable);
} }
// Sets whether the compiler should generates code for max an min which, // Sets whether the compiler should generate code for max and min which,
// if given a NaN operand, will return the other operand. Similarly, the // if given a NaN operand, will return the other operand. Similarly, the
// clamp builtin will favour the non-NaN operands, as if clamp were // clamp builtin will favour the non-NaN operands, as if clamp were
// implemented as a composition of max and min. // implemented as a composition of max and min.
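A short sketch of the new option through the C++ wrapper above; the shader source and file name are placeholders:

    #include <shaderc/shaderc.hpp>

    bool compileRelaxed(const std::string& src)
    {
        shaderc::Compiler compiler;
        shaderc::CompileOptions options;
        options.SetVulkanRulesRelaxed(true);   // the option added in this diff

        shaderc::SpvCompilationResult result = compiler.CompileGlslToSpv(
            src, shaderc_glsl_vertex_shader, "shader.vert", options);
        return result.GetCompilationStatus() == shaderc_compilation_status_success;
    }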

View file

@@ -0,0 +1,134 @@
#ifndef SLANG_COM_HELPER_H
#define SLANG_COM_HELPER_H
/** \file slang-com-helper.h
*/
#include "slang.h"
#include <atomic>
/* !!!!!!!!!!!!!!!!!!!!! Macros to help checking SlangResult !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
/*! Set SLANG_HANDLE_RESULT_FAIL(x) to code to be executed whenever an error occurs, and is detected by one of the macros */
#ifndef SLANG_HANDLE_RESULT_FAIL
# define SLANG_HANDLE_RESULT_FAIL(x)
#endif
//! Helper macro that makes it easy to add result checking to calls in functions/methods that themselves return Result.
#define SLANG_RETURN_ON_FAIL(x) { SlangResult _res = (x); if (SLANG_FAILED(_res)) { SLANG_HANDLE_RESULT_FAIL(_res); return _res; } }
//! Helper macro that can be used to test the return value from a call, and will return in a void method/function
#define SLANG_RETURN_VOID_ON_FAIL(x) { SlangResult _res = (x); if (SLANG_FAILED(_res)) { SLANG_HANDLE_RESULT_FAIL(_res); return; } }
//! Helper macro that will return false on failure.
#define SLANG_RETURN_FALSE_ON_FAIL(x) { SlangResult _res = (x); if (SLANG_FAILED(_res)) { SLANG_HANDLE_RESULT_FAIL(_res); return false; } }
//! Helper macro that will return nullptr on failure.
#define SLANG_RETURN_NULL_ON_FAIL(x) { SlangResult _res = (x); if (SLANG_FAILED(_res)) { SLANG_HANDLE_RESULT_FAIL(_res); return nullptr; } }
//! Helper macro that will assert if the return code from a call is failure, also returns the failure.
#define SLANG_ASSERT_ON_FAIL(x) { SlangResult _res = (x); if (SLANG_FAILED(_res)) { assert(false); return _res; } }
//! Helper macro that will assert if the result from a call is a failure, also returns.
#define SLANG_ASSERT_VOID_ON_FAIL(x) { SlangResult _res = (x); if (SLANG_FAILED(_res)) { assert(false); return; } }
/* !!!!!!!!!!!!!!!!!!!!!!! C++ helpers !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
#if defined(__cplusplus)
namespace Slang {
// Alias SlangResult to Slang::Result
typedef SlangResult Result;
// Alias SlangUUID to Slang::Guid
typedef SlangUUID Guid;
} // namespace Slang
// Operator == and != for Guid/SlangUUID
SLANG_FORCE_INLINE bool operator==(const Slang::Guid& aIn, const Slang::Guid& bIn)
{
using namespace Slang;
// Use the largest type that honors the alignment of Guid
typedef uint32_t CmpType;
union GuidCompare
{
Guid guid;
CmpType data[sizeof(Guid) / sizeof(CmpType)];
};
// Type pun - so compiler can 'see' the pun and not break aliasing rules
const CmpType* a = reinterpret_cast<const GuidCompare&>(aIn).data;
const CmpType* b = reinterpret_cast<const GuidCompare&>(bIn).data;
// Make the guid comparison a single branch, by not using short circuit
return ((a[0] ^ b[0]) | (a[1] ^ b[1]) | (a[2] ^ b[2]) | (a[3] ^ b[3])) == 0;
}
SLANG_FORCE_INLINE bool operator!=(const Slang::Guid& a, const Slang::Guid& b)
{
return !(a == b);
}
/* !!!!!!!! Macros to simplify implementing COM interfaces !!!!!!!!!!!!!!!!!!!!!!!!!!!! */
/* Assumes underlying implementation has a member m_refCount that is initialized to 0 and can have ++ and -- operate on it.
For SLANG_IUNKNOWN_QUERY_INTERFACE to work - must have a method 'getInterface' that returns valid pointers for the Guid, or nullptr
if not found. */
#define SLANG_IUNKNOWN_QUERY_INTERFACE \
SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) SLANG_OVERRIDE \
{ \
ISlangUnknown* intf = getInterface(uuid); \
if (intf) \
{ \
addRef(); \
*outObject = intf; \
return SLANG_OK;\
} \
return SLANG_E_NO_INTERFACE;\
}
#define SLANG_IUNKNOWN_ADD_REF \
SLANG_NO_THROW uint32_t SLANG_MCALL addRef() \
{ \
return ++m_refCount; \
}
#define SLANG_IUNKNOWN_RELEASE \
SLANG_NO_THROW uint32_t SLANG_MCALL release() \
{ \
--m_refCount; \
if (m_refCount == 0) \
{ \
delete this; \
return 0; \
} \
return m_refCount; \
}
#define SLANG_IUNKNOWN_ALL \
SLANG_IUNKNOWN_QUERY_INTERFACE \
SLANG_IUNKNOWN_ADD_REF \
SLANG_IUNKNOWN_RELEASE
// ------------------------ RefObject IUnknown -----------------------------
#define SLANG_REF_OBJECT_IUNKNOWN_QUERY_INTERFACE \
SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) SLANG_OVERRIDE \
{ \
void* intf = getInterface(uuid); \
if (intf) \
{ \
addReference(); \
*outObject = intf; \
return SLANG_OK;\
} \
return SLANG_E_NO_INTERFACE;\
}
#define SLANG_REF_OBJECT_IUNKNOWN_ADD_REF SLANG_NO_THROW uint32_t SLANG_MCALL addRef() SLANG_OVERRIDE { return (uint32_t)addReference(); }
#define SLANG_REF_OBJECT_IUNKNOWN_RELEASE SLANG_NO_THROW uint32_t SLANG_MCALL release() SLANG_OVERRIDE { return (uint32_t)releaseReference(); }
# define SLANG_REF_OBJECT_IUNKNOWN_ALL \
SLANG_REF_OBJECT_IUNKNOWN_QUERY_INTERFACE \
SLANG_REF_OBJECT_IUNKNOWN_ADD_REF \
SLANG_REF_OBJECT_IUNKNOWN_RELEASE
#endif // defined(__cplusplus)
#endif
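To show how the pieces compose, a hedged sketch of a minimal COM object built from the macros above; MyObject and initDevice are hypothetical, and the guid comparison is elided:

    class MyObject : public ISlangUnknown
    {
    public:
        SLANG_IUNKNOWN_ALL   // expands queryInterface/addRef/release

        ISlangUnknown* getInterface(const Slang::Guid& guid)
        {
            // A real implementation compares guid against each supported interface.
            return static_cast<ISlangUnknown*>(this);
        }
    protected:
        std::atomic<uint32_t> m_refCount{0};
    };

    SlangResult initDevice();   // hypothetical helper returning a SlangResult

    // The result macros keep error handling linear:
    SlangResult initRenderer()
    {
        SLANG_RETURN_ON_FAIL(initDevice());
        return SLANG_OK;
    }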

View file

@@ -0,0 +1,160 @@
#ifndef SLANG_COM_PTR_H
#define SLANG_COM_PTR_H
#include "slang-com-helper.h"
#include <assert.h>
#include <cstddef>
namespace Slang {
/*! \brief ComPtr is a simple smart pointer that manages types which implement COM based interfaces.
\details A class that implements a COM interface must derive from the IUnknown interface or a type that matches
its layout exactly (such as ISlangUnknown). Trying to use this template with a class that doesn't follow
these rules will lead to undefined behavior.
This is a 'strong' pointer type, and will AddRef when a non-null pointer is set and Release when the pointer
leaves scope.
Using 'detach' allows a pointer to be removed from the management of the ComPtr.
To set the smart pointer to null, there is the method setNull, or alternatively just assign SLANG_NULL/nullptr.
One edge case using the template is that sometimes you want access as a pointer to a pointer. Sometimes this
is to write into the smart pointer, other times to pass as an array. To handle these different behaviors
there are the methods readRef and writeRef, which are used instead of the & (ref) operator. For example
\code
Void doSomething(ID3D12Resource** resources, IndexT numResources);
// ...
ComPtr<ID3D12Resource> resources[3];
doSomething(resources[0].readRef(), SLANG_COUNT_OF(resources));
\endcode
A more common scenario writing to the pointer
\code
IUnknown* unk = ...;
ComPtr<ID3D12Resource> resource;
Result res = unk->QueryInterface(resource.writeRef());
\endcode
*/
// Enum to force initializing as an attach (without adding a reference)
enum InitAttach
{
INIT_ATTACH
};
template <class T>
class ComPtr
{
public:
typedef T Type;
typedef ComPtr ThisType;
typedef ISlangUnknown* Ptr;
/// Constructors
/// Default Ctor. Sets to nullptr
SLANG_FORCE_INLINE ComPtr() :m_ptr(nullptr) {}
SLANG_FORCE_INLINE ComPtr(std::nullptr_t) : m_ptr(nullptr) {}
/// Sets, and ref counts.
SLANG_FORCE_INLINE explicit ComPtr(T* ptr) :m_ptr(ptr) { if (ptr) ((Ptr)ptr)->addRef(); }
/// The copy ctor
SLANG_FORCE_INLINE ComPtr(const ThisType& rhs) : m_ptr(rhs.m_ptr) { if (m_ptr) ((Ptr)m_ptr)->addRef(); }
/// Ctor without adding to ref count.
SLANG_FORCE_INLINE explicit ComPtr(InitAttach, T* ptr) :m_ptr(ptr) { }
/// Ctor without adding to ref count
SLANG_FORCE_INLINE ComPtr(InitAttach, const ThisType& rhs) : m_ptr(rhs.m_ptr) { }
#ifdef SLANG_HAS_MOVE_SEMANTICS
/// Move Ctor
SLANG_FORCE_INLINE ComPtr(ThisType&& rhs) : m_ptr(rhs.m_ptr) { rhs.m_ptr = nullptr; }
/// Move assign
SLANG_FORCE_INLINE ComPtr& operator=(ThisType&& rhs) { T* swap = m_ptr; m_ptr = rhs.m_ptr; rhs.m_ptr = swap; return *this; }
#endif
/// Destructor releases the pointer, assuming it is set
SLANG_FORCE_INLINE ~ComPtr() { if (m_ptr) ((Ptr)m_ptr)->release(); }
// !!! Operators !!!
/// Returns the dumb pointer
SLANG_FORCE_INLINE operator T *() const { return m_ptr; }
SLANG_FORCE_INLINE T& operator*() { return *m_ptr; }
/// For making method invocations through the smart pointer work through the dumb pointer
SLANG_FORCE_INLINE T* operator->() const { return m_ptr; }
/// Assign
SLANG_FORCE_INLINE const ThisType &operator=(const ThisType& rhs);
/// Assign from dumb ptr
SLANG_FORCE_INLINE T* operator=(T* in);
/// Get the pointer and don't ref
SLANG_FORCE_INLINE T* get() const { return m_ptr; }
/// Release the contained pointer (if any) and set it to nullptr
SLANG_FORCE_INLINE void setNull();
/// Detach
SLANG_FORCE_INLINE T* detach() { T* ptr = m_ptr; m_ptr = nullptr; return ptr; }
/// Set to a pointer without changing the ref count
SLANG_FORCE_INLINE void attach(T* in) { m_ptr = in; }
/// Get ready for writing (nulls contents)
SLANG_FORCE_INLINE T** writeRef() { setNull(); return &m_ptr; }
/// Get for read access
SLANG_FORCE_INLINE T*const* readRef() const { return &m_ptr; }
/// Swap
void swap(ThisType& rhs);
protected:
/// Gets the address of the dumb pointer.
// Disabled: use writeRef and readRef to get a reference based on usage.
#ifndef SLANG_COM_PTR_ENABLE_REF_OPERATOR
SLANG_FORCE_INLINE T** operator&() = delete;
#endif
T* m_ptr;
};
//----------------------------------------------------------------------------
template <typename T>
void ComPtr<T>::setNull()
{
if (m_ptr)
{
((Ptr)m_ptr)->release();
m_ptr = nullptr;
}
}
//----------------------------------------------------------------------------
template <typename T>
const ComPtr<T>& ComPtr<T>::operator=(const ThisType& rhs)
{
if (rhs.m_ptr) ((Ptr)rhs.m_ptr)->addRef();
if (m_ptr) ((Ptr)m_ptr)->release();
m_ptr = rhs.m_ptr;
return *this;
}
//----------------------------------------------------------------------------
template <typename T>
T* ComPtr<T>::operator=(T* ptr)
{
if (ptr) ((Ptr)ptr)->addRef();
if (m_ptr) ((Ptr)m_ptr)->release();
m_ptr = ptr;
return m_ptr;
}
//----------------------------------------------------------------------------
template <typename T>
void ComPtr<T>::swap(ThisType& rhs)
{
T* tmp = m_ptr;
m_ptr = rhs.m_ptr;
rhs.m_ptr = tmp;
}
} // namespace Slang
#endif // SLANG_COM_PTR_H
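A hedged lifetime sketch; createFoo is a hypothetical factory that follows COM out-parameter conventions:

    SlangResult createFoo(ISlangUnknown** outFoo);   // hypothetical

    void sketch()
    {
        Slang::ComPtr<ISlangUnknown> foo;
        // writeRef() releases any held pointer, then exposes the slot for the callee to fill.
        if (SLANG_FAILED(createFoo(foo.writeRef())))
            return;
        // detach() hands the reference to the caller without releasing it...
        ISlangUnknown* raw = foo.detach();
        // ...so it must now be released manually.
        raw->release();
    }   // a still-attached ComPtr would release() here instead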

View file

@@ -0,0 +1,55 @@
#ifndef SLANG_CPP_HOST_PRELUDE_H
#define SLANG_CPP_HOST_PRELUDE_H
#include <cstdio>
#include <cmath>
#include <cstring>
#define SLANG_COM_PTR_ENABLE_REF_OPERATOR 1
#include "../source/slang-rt/slang-rt.h"
#include "slang-com-ptr.h"
#include "slang-cpp-types.h"
#ifdef SLANG_LLVM
#include "slang-llvm.h"
#else // SLANG_LLVM
# if SLANG_GCC_FAMILY && __GNUC__ < 6
# include <cmath>
# define SLANG_PRELUDE_STD std::
# else
# include <math.h>
# define SLANG_PRELUDE_STD
# endif
# include <assert.h>
# include <stdlib.h>
# include <string.h>
# include <stdint.h>
#endif // SLANG_LLVM
#if defined(_MSC_VER)
# define SLANG_PRELUDE_SHARED_LIB_EXPORT __declspec(dllexport)
#else
# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__((__visibility__("default")))
//# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__ ((dllexport)) __attribute__((__visibility__("default")))
#endif
#ifdef __cplusplus
# define SLANG_PRELUDE_EXTERN_C extern "C"
# define SLANG_PRELUDE_EXTERN_C_START extern "C" {
# define SLANG_PRELUDE_EXTERN_C_END }
#else
# define SLANG_PRELUDE_EXTERN_C
# define SLANG_PRELUDE_EXTERN_C_START
# define SLANG_PRELUDE_EXTERN_C_END
#endif
#include "slang-cpp-scalar-intrinsics.h"
using namespace Slang;
template<typename TResult, typename... Args>
using Slang_FuncType = TResult(SLANG_MCALL *)(Args...);
#endif
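For reference, a one-line sketch of what the alias template above produces; BinOp is a hypothetical name:

    // Equivalent to: int (SLANG_MCALL *)(int, int)
    using BinOp = Slang_FuncType<int, int, int>;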

View file

@@ -0,0 +1,316 @@
#ifndef SLANG_CPP_PRELUDE_H
#define SLANG_CPP_PRELUDE_H
// Because the signature of isnan, isfinite, and isinf changed in C++, we use a macro
// to select the version in the std namespace.
// https://stackoverflow.com/questions/39130040/cmath-hides-isnan-in-math-h-in-c14-c11
#ifdef SLANG_LLVM
#include "slang-llvm.h"
#else // SLANG_LLVM
# if SLANG_GCC_FAMILY && __GNUC__ < 6
# include <cmath>
# define SLANG_PRELUDE_STD std::
# else
# include <math.h>
# define SLANG_PRELUDE_STD
# endif
# include <assert.h>
# include <stdlib.h>
# include <string.h>
# include <stdint.h>
#endif // SLANG_LLVM
#if defined(_MSC_VER)
# define SLANG_PRELUDE_SHARED_LIB_EXPORT __declspec(dllexport)
#else
# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__((__visibility__("default")))
//# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__ ((dllexport)) __attribute__((__visibility__("default")))
#endif
#ifdef __cplusplus
# define SLANG_PRELUDE_EXTERN_C extern "C"
# define SLANG_PRELUDE_EXTERN_C_START extern "C" {
# define SLANG_PRELUDE_EXTERN_C_END }
#else
# define SLANG_PRELUDE_EXTERN_C
# define SLANG_PRELUDE_EXTERN_C_START
# define SLANG_PRELUDE_EXTERN_C_END
#endif
#define SLANG_PRELUDE_EXPORT SLANG_PRELUDE_EXTERN_C SLANG_PRELUDE_SHARED_LIB_EXPORT
#define SLANG_PRELUDE_EXPORT_START SLANG_PRELUDE_EXTERN_C_START SLANG_PRELUDE_SHARED_LIB_EXPORT
#define SLANG_PRELUDE_EXPORT_END SLANG_PRELUDE_EXTERN_C_END
#ifndef INFINITY
// Must overflow for double
# define INFINITY float(1e+300 * 1e+300)
#endif
#ifndef SLANG_INFINITY
# define SLANG_INFINITY INFINITY
#endif
// Detect the compiler type
#ifndef SLANG_COMPILER
# define SLANG_COMPILER
/*
Compiler defines, see http://sourceforge.net/p/predef/wiki/Compilers/
NOTE that SLANG_VC holds the compiler version - not just 1 or 0
*/
# if defined(_MSC_VER)
# if _MSC_VER >= 1900
# define SLANG_VC 14
# elif _MSC_VER >= 1800
# define SLANG_VC 12
# elif _MSC_VER >= 1700
# define SLANG_VC 11
# elif _MSC_VER >= 1600
# define SLANG_VC 10
# elif _MSC_VER >= 1500
# define SLANG_VC 9
# else
# error "unknown version of Visual C++ compiler"
# endif
# elif defined(__clang__)
# define SLANG_CLANG 1
# elif defined(__SNC__)
# define SLANG_SNC 1
# elif defined(__ghs__)
# define SLANG_GHS 1
# elif defined(__GNUC__) /* note: __clang__, __SNC__, or __ghs__ imply __GNUC__ */
# define SLANG_GCC 1
# else
# error "unknown compiler"
# endif
/*
Any compilers not detected by the above logic are now explicitly zeroed out.
*/
# ifndef SLANG_VC
# define SLANG_VC 0
# endif
# ifndef SLANG_CLANG
# define SLANG_CLANG 0
# endif
# ifndef SLANG_SNC
# define SLANG_SNC 0
# endif
# ifndef SLANG_GHS
# define SLANG_GHS 0
# endif
# ifndef SLANG_GCC
# define SLANG_GCC 0
# endif
#endif /* SLANG_COMPILER */
/*
The following section attempts to detect the target platform being compiled for.
If an application defines `SLANG_PLATFORM` before including this header,
they take responsibility for setting any compiler-dependent macros
used later in the file.
Most applications should not need to touch this section.
*/
#ifndef SLANG_PLATFORM
# define SLANG_PLATFORM
/**
Operating system defines, see http://sourceforge.net/p/predef/wiki/OperatingSystems/
*/
# if defined(WINAPI_FAMILY) && WINAPI_FAMILY == WINAPI_PARTITION_APP
# define SLANG_WINRT 1 /* Windows Runtime, either on Windows RT or Windows 8 */
# elif defined(XBOXONE)
# define SLANG_XBOXONE 1
# elif defined(_WIN64) /* note: XBOXONE implies _WIN64 */
# define SLANG_WIN64 1
# elif defined(_M_PPC)
# define SLANG_X360 1
# elif defined(_WIN32) /* note: _M_PPC implies _WIN32 */
# define SLANG_WIN32 1
# elif defined(__ANDROID__)
# define SLANG_ANDROID 1
# elif defined(__linux__) || defined(__CYGWIN__) /* note: __ANDROID__ implies __linux__ */
# define SLANG_LINUX 1
# elif defined(__APPLE__) && !defined(SLANG_LLVM)
# include "TargetConditionals.h"
# if TARGET_OS_MAC
# define SLANG_OSX 1
# else
# define SLANG_IOS 1
# endif
# elif defined(__APPLE__)
// On `slang-llvm` we can't include "TargetConditionals.h" in general, so for now assume it's OSX.
# define SLANG_OSX 1
# elif defined(__CELLOS_LV2__)
# define SLANG_PS3 1
# elif defined(__ORBIS__)
# define SLANG_PS4 1
# elif defined(__SNC__) && defined(__arm__)
# define SLANG_PSP2 1
# elif defined(__ghs__)
# define SLANG_WIIU 1
# else
# error "unknown target platform"
# endif
/*
Any platforms not detected by the above logic are now explicitly zeroed out.
*/
# ifndef SLANG_WINRT
# define SLANG_WINRT 0
# endif
# ifndef SLANG_XBOXONE
# define SLANG_XBOXONE 0
# endif
# ifndef SLANG_WIN64
# define SLANG_WIN64 0
# endif
# ifndef SLANG_X360
# define SLANG_X360 0
# endif
# ifndef SLANG_WIN32
# define SLANG_WIN32 0
# endif
# ifndef SLANG_ANDROID
# define SLANG_ANDROID 0
# endif
# ifndef SLANG_LINUX
# define SLANG_LINUX 0
# endif
# ifndef SLANG_IOS
# define SLANG_IOS 0
# endif
# ifndef SLANG_OSX
# define SLANG_OSX 0
# endif
# ifndef SLANG_PS3
# define SLANG_PS3 0
# endif
# ifndef SLANG_PS4
# define SLANG_PS4 0
# endif
# ifndef SLANG_PSP2
# define SLANG_PSP2 0
# endif
# ifndef SLANG_WIIU
# define SLANG_WIIU 0
# endif
#endif /* SLANG_PLATFORM */
/* Shorthands for "families" of compilers/platforms */
#define SLANG_GCC_FAMILY (SLANG_CLANG || SLANG_SNC || SLANG_GHS || SLANG_GCC)
#define SLANG_WINDOWS_FAMILY (SLANG_WINRT || SLANG_WIN32 || SLANG_WIN64)
#define SLANG_MICROSOFT_FAMILY (SLANG_XBOXONE || SLANG_X360 || SLANG_WINDOWS_FAMILY)
#define SLANG_LINUX_FAMILY (SLANG_LINUX || SLANG_ANDROID)
#define SLANG_APPLE_FAMILY (SLANG_IOS || SLANG_OSX) /* equivalent to #if __APPLE__ */
#define SLANG_UNIX_FAMILY (SLANG_LINUX_FAMILY || SLANG_APPLE_FAMILY) /* shortcut for unix/posix platforms */
// GCC Specific
#if SLANG_GCC_FAMILY
# define SLANG_ALIGN_OF(T) __alignof__(T)
# define SLANG_BREAKPOINT(id) __builtin_trap()
// Use this macro instead of offsetof, because gcc produces a warning if offsetof is used on a
// non-POD type, even though it produces the correct result
# define SLANG_OFFSET_OF(T, ELEMENT) (size_t(&((T*)1)->ELEMENT) - 1)
#endif // SLANG_GCC_FAMILY
// Microsoft VC specific
#if SLANG_VC
# define SLANG_ALIGN_OF(T) __alignof(T)
# define SLANG_BREAKPOINT(id) __debugbreak();
#endif // SLANG_VC
// Default impls
#ifndef SLANG_OFFSET_OF
# define SLANG_OFFSET_OF(X, Y) offsetof(X, Y)
#endif
#ifndef SLANG_BREAKPOINT
// Make it crash with a write to 0!
# define SLANG_BREAKPOINT(id) (*((int*)0) = int(id));
#endif
// If slang.h has been included we don't need any of these definitions
#ifndef SLANG_H
/* Macro for declaring if a method is no throw. Should be set before the return parameter. */
#ifndef SLANG_NO_THROW
# if SLANG_WINDOWS_FAMILY && !defined(SLANG_DISABLE_EXCEPTIONS)
# define SLANG_NO_THROW __declspec(nothrow)
# endif
#endif
#ifndef SLANG_NO_THROW
# define SLANG_NO_THROW
#endif
/* The `SLANG_STDCALL` and `SLANG_MCALL` defines are used to set the calling
convention for interface methods.
*/
#ifndef SLANG_STDCALL
# if SLANG_MICROSOFT_FAMILY
# define SLANG_STDCALL __stdcall
# else
# define SLANG_STDCALL
# endif
#endif
#ifndef SLANG_MCALL
# define SLANG_MCALL SLANG_STDCALL
#endif
#ifndef SLANG_FORCE_INLINE
# define SLANG_FORCE_INLINE inline
#endif
// TODO(JS): Should these be in slang-cpp-types.h?
// They are more likely to clash with slang.h
struct SlangUUID
{
uint32_t data1;
uint16_t data2;
uint16_t data3;
uint8_t data4[8];
};
typedef int32_t SlangResult;
struct ISlangUnknown
{
virtual SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) = 0;
virtual SLANG_NO_THROW uint32_t SLANG_MCALL addRef() = 0;
virtual SLANG_NO_THROW uint32_t SLANG_MCALL release() = 0;
};
#define SLANG_COM_INTERFACE(a, b, c, d0, d1, d2, d3, d4, d5, d6, d7) \
public: \
SLANG_FORCE_INLINE static const SlangUUID& getTypeGuid() \
{ \
static const SlangUUID guid = { a, b, c, d0, d1, d2, d3, d4, d5, d6, d7 }; \
return guid; \
}
#endif // SLANG_H
// Includes
#include "slang-cpp-scalar-intrinsics.h"
#include "slang-cpp-types.h"
// TODO(JS): Hack! Output C++ code from slang can copy uninitialized variables.
#if defined(_MSC_VER)
# pragma warning(disable : 4700)
#endif
#ifndef SLANG_UNROLL
# define SLANG_UNROLL
#endif
#endif
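A small sketch of the portability macros in use, assuming the prelude above is included; Particle is hypothetical:

    struct Particle { float pos[3]; float vel[3]; };

    // Same spelling on every compiler family; the detection above picks the implementation.
    size_t velOffset = SLANG_OFFSET_OF(Particle, vel);   // 12 on typical targets
    size_t alignment = SLANG_ALIGN_OF(Particle);

    #if SLANG_WINDOWS_FAMILY
    // Windows-only path.
    #elif SLANG_UNIX_FAMILY
    // Linux/Android/Apple path.
    #endif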

View file

@@ -0,0 +1,498 @@
#ifndef SLANG_PRELUDE_SCALAR_INTRINSICS_H
#define SLANG_PRELUDE_SCALAR_INTRINSICS_H
#if !defined(SLANG_LLVM) && SLANG_PROCESSOR_X86_64 && SLANG_VC
// If we have visual studio and 64 bit processor, we can assume we have popcnt, and can include x86 intrinsics
# include <intrin.h>
#endif
#ifndef SLANG_FORCE_INLINE
# define SLANG_FORCE_INLINE inline
#endif
#ifdef SLANG_PRELUDE_NAMESPACE
namespace SLANG_PRELUDE_NAMESPACE {
#endif
#ifndef SLANG_PRELUDE_PI
# define SLANG_PRELUDE_PI 3.14159265358979323846
#endif
union Union32
{
uint32_t u;
int32_t i;
float f;
};
union Union64
{
uint64_t u;
int64_t i;
double d;
};
// 32 bit cast conversions
SLANG_FORCE_INLINE int32_t _bitCastFloatToInt(float f) { Union32 u; u.f = f; return u.i; }
SLANG_FORCE_INLINE float _bitCastIntToFloat(int32_t i) { Union32 u; u.i = i; return u.f; }
SLANG_FORCE_INLINE uint32_t _bitCastFloatToUInt(float f) { Union32 u; u.f = f; return u.u; }
SLANG_FORCE_INLINE float _bitCastUIntToFloat(uint32_t ui) { Union32 u; u.u = ui; return u.f; }
// ----------------------------- F16 -----------------------------------------
// This impl is based on FloatToHalf that is in Slang codebase
SLANG_FORCE_INLINE uint32_t f32tof16(const float value)
{
const uint32_t inBits = _bitCastFloatToUInt(value);
// bits initially set to just the sign bit
uint32_t bits = (inBits >> 16) & 0x8000;
// The mantissa can't be used as is, since it still holds the lowest bit, which is used for rounding.
uint32_t m = (inBits >> 12) & 0x07ff;
uint32_t e = (inBits >> 23) & 0xff;
if (e < 103)
{
// It's zero
return bits;
}
if (e == 0xff)
{
// Could be a NAN or INF. Is INF if *input* mantissa is 0.
// Remove last bit for rounding to make output mantissa.
m >>= 1;
// We *assume* the float16/float32 signaling-bit and remaining-bit
// semantics are the same. (The signaling bit convention is target specific!)
// Use of the non-signaling mantissa bits of a NaN is also target specific.
// If m is 0, it could be because the result is INF, but it could also be because all the
// bits that made the input a NaN were dropped, as f16 has fewer mantissa bits.
// To correct for this we make m non-zero if it is 0 and the input mantissa was not.
// This will (typically) produce a signaling NaN.
m += uint32_t(m == 0 && (inBits & 0x007fffffu));
// Combine for output
return (bits | 0x7c00u | m);
}
if (e > 142)
{
// INF.
return bits | 0x7c00u;
}
if (e < 113)
{
m |= 0x0800u;
bits |= (m >> (114 - e)) + ((m >> (113 - e)) & 1);
return bits;
}
bits |= ((e - 112) << 10) | (m >> 1);
bits += m & 1;
return bits;
}
static const float g_f16tof32Magic = _bitCastIntToFloat((127 + (127 - 15)) << 23);
SLANG_FORCE_INLINE float f16tof32(const uint32_t value)
{
const uint32_t sign = (value & 0x8000) << 16;
uint32_t exponent = (value & 0x7c00) >> 10;
uint32_t mantissa = (value & 0x03ff);
if (exponent == 0)
{
// If mantissa is 0 we are done, as output is 0.
// If it's not zero we must have a denormal.
if (mantissa)
{
// We have a denormal so use the magic to do exponent adjust
return _bitCastIntToFloat(sign | ((value & 0x7fff) << 13)) * g_f16tof32Magic;
}
}
else
{
// For NaN or INF the exponent is 0x1f on input.
// In that case we just need to set the exponent to 0xff on output,
// and the mantissa can stay the same. If it's 0 it's INF, else it's NaN, and we just copy the bits
//
// Else we need to correct the exponent in the normalized case.
exponent = (exponent == 0x1F) ? 0xff : (exponent + (-15 + 127));
}
return _bitCastUIntToFloat(sign | (exponent << 23) | (mantissa << 13));
}
// ----------------------------- F32 -----------------------------------------
// Helpers
SLANG_FORCE_INLINE float F32_calcSafeRadians(float radians);
#ifdef SLANG_LLVM
SLANG_PRELUDE_EXTERN_C_START
// Unary
float F32_ceil(float f);
float F32_floor(float f);
float F32_round(float f);
float F32_sin(float f);
float F32_cos(float f);
float F32_tan(float f);
float F32_asin(float f);
float F32_acos(float f);
float F32_atan(float f);
float F32_sinh(float f);
float F32_cosh(float f);
float F32_tanh(float f);
float F32_log2(float f);
float F32_log(float f);
float F32_log10(float f);
float F32_exp2(float f);
float F32_exp(float f);
float F32_abs(float f);
float F32_trunc(float f);
float F32_sqrt(float f);
bool F32_isnan(float f);
bool F32_isfinite(float f);
bool F32_isinf(float f);
// Binary
SLANG_FORCE_INLINE float F32_min(float a, float b) { return a < b ? a : b; }
SLANG_FORCE_INLINE float F32_max(float a, float b) { return a > b ? a : b; }
float F32_pow(float a, float b);
float F32_fmod(float a, float b);
float F32_remainder(float a, float b);
float F32_atan2(float a, float b);
float F32_frexp(float x, int* e);
float F32_modf(float x, float* ip);
// Ternary
SLANG_FORCE_INLINE float F32_fma(float a, float b, float c) { return a * b + c; }
SLANG_PRELUDE_EXTERN_C_END
#else
// Unary
SLANG_FORCE_INLINE float F32_ceil(float f) { return ::ceilf(f); }
SLANG_FORCE_INLINE float F32_floor(float f) { return ::floorf(f); }
SLANG_FORCE_INLINE float F32_round(float f) { return ::roundf(f); }
SLANG_FORCE_INLINE float F32_sin(float f) { return ::sinf(f); }
SLANG_FORCE_INLINE float F32_cos(float f) { return ::cosf(f); }
SLANG_FORCE_INLINE float F32_tan(float f) { return ::tanf(f); }
SLANG_FORCE_INLINE float F32_asin(float f) { return ::asinf(f); }
SLANG_FORCE_INLINE float F32_acos(float f) { return ::acosf(f); }
SLANG_FORCE_INLINE float F32_atan(float f) { return ::atanf(f); }
SLANG_FORCE_INLINE float F32_sinh(float f) { return ::sinhf(f); }
SLANG_FORCE_INLINE float F32_cosh(float f) { return ::coshf(f); }
SLANG_FORCE_INLINE float F32_tanh(float f) { return ::tanhf(f); }
SLANG_FORCE_INLINE float F32_log2(float f) { return ::log2f(f); }
SLANG_FORCE_INLINE float F32_log(float f) { return ::logf(f); }
SLANG_FORCE_INLINE float F32_log10(float f) { return ::log10f(f); }
SLANG_FORCE_INLINE float F32_exp2(float f) { return ::exp2f(f); }
SLANG_FORCE_INLINE float F32_exp(float f) { return ::expf(f); }
SLANG_FORCE_INLINE float F32_abs(float f) { return ::fabsf(f); }
SLANG_FORCE_INLINE float F32_trunc(float f) { return ::truncf(f); }
SLANG_FORCE_INLINE float F32_sqrt(float f) { return ::sqrtf(f); }
SLANG_FORCE_INLINE bool F32_isnan(float f) { return SLANG_PRELUDE_STD isnan(f); }
SLANG_FORCE_INLINE bool F32_isfinite(float f) { return SLANG_PRELUDE_STD isfinite(f); }
SLANG_FORCE_INLINE bool F32_isinf(float f) { return SLANG_PRELUDE_STD isinf(f); }
// Binary
SLANG_FORCE_INLINE float F32_min(float a, float b) { return ::fminf(a, b); }
SLANG_FORCE_INLINE float F32_max(float a, float b) { return ::fmaxf(a, b); }
SLANG_FORCE_INLINE float F32_pow(float a, float b) { return ::powf(a, b); }
SLANG_FORCE_INLINE float F32_fmod(float a, float b) { return ::fmodf(a, b); }
SLANG_FORCE_INLINE float F32_remainder(float a, float b) { return ::remainderf(a, b); }
SLANG_FORCE_INLINE float F32_atan2(float a, float b) { return float(::atan2(a, b)); }
SLANG_FORCE_INLINE float F32_frexp(float x, int* e) { return ::frexpf(x, e); }
SLANG_FORCE_INLINE float F32_modf(float x, float* ip)
{
return ::modff(x, ip);
}
// Ternary
SLANG_FORCE_INLINE float F32_fma(float a, float b, float c) { return ::fmaf(a, b, c); }
#endif
SLANG_FORCE_INLINE float F32_calcSafeRadians(float radians)
{
// Map 0..2pi radians onto the 0..1 cycle range
float a = radians * (1.0f / float(SLANG_PRELUDE_PI * 2));
// Get truncated fraction, as value in 0 - 1 range
a = a - F32_floor(a);
// Convert back to 0 - 2pi range
return (a * float(SLANG_PRELUDE_PI * 2));
}
SLANG_FORCE_INLINE float F32_rsqrt(float f) { return 1.0f / F32_sqrt(f); }
SLANG_FORCE_INLINE float F32_sign(float f) { return ( f == 0.0f) ? f : (( f < 0.0f) ? -1.0f : 1.0f); }
SLANG_FORCE_INLINE float F32_frac(float f) { return f - F32_floor(f); }
SLANG_FORCE_INLINE uint32_t F32_asuint(float f) { Union32 u; u.f = f; return u.u; }
SLANG_FORCE_INLINE int32_t F32_asint(float f) { Union32 u; u.f = f; return u.i; }
// ----------------------------- F64 -----------------------------------------
SLANG_FORCE_INLINE double F64_calcSafeRadians(double radians);
#ifdef SLANG_LLVM
SLANG_PRELUDE_EXTERN_C_START
// Unary
double F64_ceil(double f);
double F64_floor(double f);
double F64_round(double f);
double F64_sin(double f);
double F64_cos(double f);
double F64_tan(double f);
double F64_asin(double f);
double F64_acos(double f);
double F64_atan(double f);
double F64_sinh(double f);
double F64_cosh(double f);
double F64_tanh(double f);
double F64_log2(double f);
double F64_log(double f);
double F64_log10(double f);
double F64_exp2(double f);
double F64_exp(double f);
double F64_abs(double f);
double F64_trunc(double f);
double F64_sqrt(double f);
bool F64_isnan(double f);
bool F64_isfinite(double f);
bool F64_isinf(double f);
// Binary
SLANG_FORCE_INLINE double F64_min(double a, double b) { return a < b ? a : b; }
SLANG_FORCE_INLINE double F64_max(double a, double b) { return a > b ? a : b; }
double F64_pow(double a, double b);
double F64_fmod(double a, double b);
double F64_remainder(double a, double b);
double F64_atan2(double a, double b);
double F64_frexp(double x, int* e);
double F64_modf(double x, double* ip);
// Ternary
SLANG_FORCE_INLINE double F64_fma(double a, double b, double c) { return a * b + c; }
SLANG_PRELUDE_EXTERN_C_END
#else // SLANG_LLVM
// Unary
SLANG_FORCE_INLINE double F64_ceil(double f) { return ::ceil(f); }
SLANG_FORCE_INLINE double F64_floor(double f) { return ::floor(f); }
SLANG_FORCE_INLINE double F64_round(double f) { return ::round(f); }
SLANG_FORCE_INLINE double F64_sin(double f) { return ::sin(f); }
SLANG_FORCE_INLINE double F64_cos(double f) { return ::cos(f); }
SLANG_FORCE_INLINE double F64_tan(double f) { return ::tan(f); }
SLANG_FORCE_INLINE double F64_asin(double f) { return ::asin(f); }
SLANG_FORCE_INLINE double F64_acos(double f) { return ::acos(f); }
SLANG_FORCE_INLINE double F64_atan(double f) { return ::atan(f); }
SLANG_FORCE_INLINE double F64_sinh(double f) { return ::sinh(f); }
SLANG_FORCE_INLINE double F64_cosh(double f) { return ::cosh(f); }
SLANG_FORCE_INLINE double F64_tanh(double f) { return ::tanh(f); }
SLANG_FORCE_INLINE double F64_log2(double f) { return ::log2(f); }
SLANG_FORCE_INLINE double F64_log(double f) { return ::log(f); }
SLANG_FORCE_INLINE double F64_log10(double f) { return ::log10(f); }
SLANG_FORCE_INLINE double F64_exp2(double f) { return ::exp2(f); }
SLANG_FORCE_INLINE double F64_exp(double f) { return ::exp(f); }
SLANG_FORCE_INLINE double F64_abs(double f) { return ::fabs(f); }
SLANG_FORCE_INLINE double F64_trunc(double f) { return ::trunc(f); }
SLANG_FORCE_INLINE double F64_sqrt(double f) { return ::sqrt(f); }
SLANG_FORCE_INLINE bool F64_isnan(double f) { return SLANG_PRELUDE_STD isnan(f); }
SLANG_FORCE_INLINE bool F64_isfinite(double f) { return SLANG_PRELUDE_STD isfinite(f); }
SLANG_FORCE_INLINE bool F64_isinf(double f) { return SLANG_PRELUDE_STD isinf(f); }
// Binary
SLANG_FORCE_INLINE double F64_min(double a, double b) { return ::fmin(a, b); }
SLANG_FORCE_INLINE double F64_max(double a, double b) { return ::fmax(a, b); }
SLANG_FORCE_INLINE double F64_pow(double a, double b) { return ::pow(a, b); }
SLANG_FORCE_INLINE double F64_fmod(double a, double b) { return ::fmod(a, b); }
SLANG_FORCE_INLINE double F64_remainder(double a, double b) { return ::remainder(a, b); }
SLANG_FORCE_INLINE double F64_atan2(double a, double b) { return ::atan2(a, b); }
SLANG_FORCE_INLINE double F64_frexp(double x, int* e) { return ::frexp(x, e); }
SLANG_FORCE_INLINE double F64_modf(double x, double* ip)
{
return ::modf(x, ip);
}
// Ternary
SLANG_FORCE_INLINE double F64_fma(double a, double b, double c) { return ::fma(a, b, c); }
#endif // SLANG_LLVM
SLANG_FORCE_INLINE double F64_rsqrt(double f) { return 1.0 / F64_sqrt(f); }
SLANG_FORCE_INLINE double F64_sign(double f) { return (f == 0.0) ? f : ((f < 0.0) ? -1.0 : 1.0); }
SLANG_FORCE_INLINE double F64_frac(double f) { return f - F64_floor(f); }
SLANG_FORCE_INLINE void F64_asuint(double d, uint32_t* low, uint32_t* hi)
{
Union64 u;
u.d = d;
*low = uint32_t(u.u);
*hi = uint32_t(u.u >> 32);
}
SLANG_FORCE_INLINE void F64_asint(double d, int32_t* low, int32_t* hi)
{
Union64 u;
u.d = d;
*low = int32_t(u.u);
*hi = int32_t(u.u >> 32);
}
SLANG_FORCE_INLINE double F64_calcSafeRadians(double radians)
{
// Map 0..2pi radians onto the 0..1 cycle range
double a = radians * (1.0f / (SLANG_PRELUDE_PI * 2));
// Get truncated fraction, as value in 0 - 1 range
a = a - F64_floor(a);
// Convert back to 0 - 2pi range
return (a * (SLANG_PRELUDE_PI * 2));
}
// ----------------------------- I32 -----------------------------------------
SLANG_FORCE_INLINE int32_t I32_abs(int32_t f) { return (f < 0) ? -f : f; }
SLANG_FORCE_INLINE int32_t I32_min(int32_t a, int32_t b) { return a < b ? a : b; }
SLANG_FORCE_INLINE int32_t I32_max(int32_t a, int32_t b) { return a > b ? a : b; }
SLANG_FORCE_INLINE float I32_asfloat(int32_t x) { Union32 u; u.i = x; return u.f; }
SLANG_FORCE_INLINE uint32_t I32_asuint(int32_t x) { return uint32_t(x); }
SLANG_FORCE_INLINE double I32_asdouble(int32_t low, int32_t hi )
{
Union64 u;
u.u = (uint64_t(hi) << 32) | uint32_t(low);
return u.d;
}
// ----------------------------- U32 -----------------------------------------
SLANG_FORCE_INLINE uint32_t U32_abs(uint32_t f) { return f; }
SLANG_FORCE_INLINE uint32_t U32_min(uint32_t a, uint32_t b) { return a < b ? a : b; }
SLANG_FORCE_INLINE uint32_t U32_max(uint32_t a, uint32_t b) { return a > b ? a : b; }
SLANG_FORCE_INLINE float U32_asfloat(uint32_t x) { Union32 u; u.u = x; return u.f; }
SLANG_FORCE_INLINE uint32_t U32_asint(int32_t x) { return uint32_t(x); }
SLANG_FORCE_INLINE double U32_asdouble(uint32_t low, uint32_t hi)
{
Union64 u;
u.u = (uint64_t(hi) << 32) | low;
return u.d;
}
SLANG_FORCE_INLINE uint32_t U32_countbits(uint32_t v)
{
#if SLANG_GCC_FAMILY && !defined(SLANG_LLVM)
return __builtin_popcount(v);
#elif SLANG_PROCESSOR_X86_64 && SLANG_VC
return __popcnt(v);
#else
uint32_t c = 0;
while (v)
{
c++;
v &= v - 1;
}
return c;
#endif
}
// ----------------------------- U64 -----------------------------------------
SLANG_FORCE_INLINE uint64_t U64_abs(uint64_t f) { return f; }
SLANG_FORCE_INLINE uint64_t U64_min(uint64_t a, uint64_t b) { return a < b ? a : b; }
SLANG_FORCE_INLINE uint64_t U64_max(uint64_t a, uint64_t b) { return a > b ? a : b; }
// TODO(JS): We don't define countbits for 64bit in stdlib currently.
// It's not clear from documentation if it should return 32 or 64 bits, if it exists.
// 32 bits can always hold the result, and will be implicitly promoted.
SLANG_FORCE_INLINE uint32_t U64_countbits(uint64_t v)
{
#if SLANG_GCC_FAMILY && !defined(SLANG_LLVM)
return uint32_t(__builtin_popcountl(v));
#elif SLANG_PROCESSOR_X86_64 && SLANG_VC
return uint32_t(__popcnt64(v));
#else
uint32_t c = 0;
while (v)
{
c++;
v &= v - 1;
}
return c;
#endif
}
// ----------------------------- I64 -----------------------------------------
SLANG_FORCE_INLINE int64_t I64_abs(int64_t f) { return (f < 0) ? -f : f; }
SLANG_FORCE_INLINE int64_t I64_min(int64_t a, int64_t b) { return a < b ? a : b; }
SLANG_FORCE_INLINE int64_t I64_max(int64_t a, int64_t b) { return a > b ? a : b; }
// ----------------------------- Interlocked ---------------------------------
#if SLANG_LLVM
#else // SLANG_LLVM
# ifdef _WIN32
# include <intrin.h>
# endif
SLANG_FORCE_INLINE void InterlockedAdd(uint32_t* dest, uint32_t value, uint32_t* oldValue)
{
# ifdef _WIN32
*oldValue = _InterlockedExchangeAdd((long*)dest, (long)value);
# else
*oldValue = __sync_fetch_and_add(dest, value);
# endif
}
#endif // SLANG_LLVM
// ----------------------- fmod --------------------------
SLANG_FORCE_INLINE float _slang_fmod(float x, float y)
{
return F32_fmod(x, y);
}
SLANG_FORCE_INLINE double _slang_fmod(double x, double y)
{
return F64_fmod(x, y);
}
#ifdef SLANG_PRELUDE_NAMESPACE
}
#endif
#endif
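A round-trip sketch exercising the helpers above; the expected values are worked out by hand:

    void sketch()
    {
        uint32_t h = f32tof16(1.5f);            // 0x3e00: 1.5 is exactly representable in f16
        float back = f16tof32(h);               // 1.5f again

        uint32_t bits = U32_countbits(0xf0u);   // 4, via popcount or the portable loop

        uint32_t counter = 0, before = 0;
        InterlockedAdd(&counter, 5, &before);   // counter == 5, before == 0
    }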

View file

@@ -0,0 +1,578 @@
#ifndef SLANG_PRELUDE_CPP_TYPES_CORE_H
#define SLANG_PRELUDE_CPP_TYPES_CORE_H
#ifndef SLANG_PRELUDE_ASSERT
# ifdef SLANG_PRELUDE_ENABLE_ASSERT
# define SLANG_PRELUDE_ASSERT(VALUE) assert(VALUE)
# else
# define SLANG_PRELUDE_ASSERT(VALUE)
# endif
#endif
// Since we are using unsigned arithmetic, care is needed in this comparison.
// It is *assumed* that sizeInBytes >= elemSize, which means (sizeInBytes - elemSize) >= 0,
// so only a single test is needed.
// Asserts for bounds checking.
// It is assumed index/count are unsigned types.
#define SLANG_BOUND_ASSERT(index, count) SLANG_PRELUDE_ASSERT(index < count);
#define SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_PRELUDE_ASSERT(index <= (sizeInBytes - elemSize) && (index & 3) == 0);
// Macros to zero index if an access is out of range
#define SLANG_BOUND_ZERO_INDEX(index, count) index = (index < count) ? index : 0;
#define SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes) index = (index <= (sizeInBytes - elemSize)) ? index : 0;
// The 'FIX' macro define how the index is fixed. The default is to do nothing. If SLANG_ENABLE_BOUND_ZERO_INDEX
// the fix macro will zero the index, if out of range
#ifdef SLANG_ENABLE_BOUND_ZERO_INDEX
# define SLANG_BOUND_FIX(index, count) SLANG_BOUND_ZERO_INDEX(index, count)
# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count) SLANG_BOUND_ZERO_INDEX(index, count)
#else
# define SLANG_BOUND_FIX(index, count)
# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count)
#endif
#ifndef SLANG_BOUND_CHECK
# define SLANG_BOUND_CHECK(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX(index, count)
#endif
#ifndef SLANG_BOUND_CHECK_BYTE_ADDRESS
# define SLANG_BOUND_CHECK_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
#endif
#ifndef SLANG_BOUND_CHECK_FIXED_ARRAY
# define SLANG_BOUND_CHECK_FIXED_ARRAY(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX_FIXED_ARRAY(index, count)
#endif
struct TypeInfo
{
size_t typeSize;
};
template <typename T, size_t SIZE>
struct FixedArray
{
const T& operator[](size_t index) const { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; }
T& operator[](size_t index) { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; }
T m_data[SIZE];
};
// An array that has no specified size becomes an 'Array'. This stores the size so it can potentially
// do bounds checking.
template <typename T>
struct Array
{
const T& operator[](size_t index) const { SLANG_BOUND_CHECK(index, count); return data[index]; }
T& operator[](size_t index) { SLANG_BOUND_CHECK(index, count); return data[index]; }
T* data;
size_t count;
};
/* Constant buffers become a pointer to the contained type, so ConstantBuffer<T> becomes T* in C++ code.
*/
template <typename T, int COUNT>
struct Vector;
template <typename T>
struct Vector<T, 1>
{
T x;
const T& operator[](size_t /*index*/) const { return x; }
T& operator[](size_t /*index*/) { return x; }
operator T() const { return x; }
Vector() = default;
Vector(T scalar)
{
x = scalar;
}
template <typename U>
Vector(Vector<U, 1> other)
{
x = (T)other.x;
}
template <typename U, int otherSize>
Vector(Vector<U, otherSize> other)
{
int minSize = 1;
if (otherSize < minSize) minSize = otherSize;
for (int i = 0; i < minSize; i++)
(*this)[i] = (T)other[i];
}
};
template <typename T>
struct Vector<T, 2>
{
T x, y;
const T& operator[](size_t index) const { return index == 0 ? x : y; }
T& operator[](size_t index) { return index == 0 ? x : y; }
Vector() = default;
Vector(T scalar)
{
x = y = scalar;
}
Vector(T _x, T _y)
{
x = _x;
y = _y;
}
template <typename U>
Vector(Vector<U, 2> other)
{
x = (T)other.x;
y = (T)other.y;
}
template <typename U, int otherSize>
Vector(Vector<U, otherSize> other)
{
int minSize = 2;
if (otherSize < minSize) minSize = otherSize;
for (int i = 0; i < minSize; i++)
(*this)[i] = (T)other[i];
}
};
template <typename T>
struct Vector<T, 3>
{
T x, y, z;
const T& operator[](size_t index) const { return *((T*)(this) + index); }
T& operator[](size_t index) { return *((T*)(this) + index); }
Vector() = default;
Vector(T scalar)
{
x = y = z = scalar;
}
Vector(T _x, T _y, T _z)
{
x = _x;
y = _y;
z = _z;
}
template <typename U>
Vector(Vector<U, 3> other)
{
x = (T)other.x;
y = (T)other.y;
z = (T)other.z;
}
template <typename U, int otherSize>
Vector(Vector<U, otherSize> other)
{
int minSize = 3;
if (otherSize < minSize) minSize = otherSize;
for (int i = 0; i < minSize; i++)
(*this)[i] = (T)other[i];
}
};
template <typename T>
struct Vector<T, 4>
{
T x, y, z, w;
const T& operator[](size_t index) const { return *((T*)(this) + index); }
T& operator[](size_t index) { return *((T*)(this) + index); }
Vector() = default;
Vector(T scalar)
{
x = y = z = w = scalar;
}
Vector(T _x, T _y, T _z, T _w)
{
x = _x;
y = _y;
z = _z;
w = _w;
}
template <typename U, int otherSize>
Vector(Vector<U, otherSize> other)
{
int minSize = 4;
if (otherSize < minSize) minSize = otherSize;
for (int i = 0; i < minSize; i++)
(*this)[i] = (T)other[i];
}
};
template<typename T, int N>
SLANG_FORCE_INLINE Vector<T, N> _slang_select(Vector<bool, N> condition, Vector<T, N> v0, Vector<T, N> v1)
{
Vector<T, N> result;
for (int i = 0; i < N; i++)
{
result[i] = condition[i] ? v0[i] : v1[i];
}
return result;
}
template<typename T>
SLANG_FORCE_INLINE T _slang_select(bool condition, T v0, T v1)
{
return condition ? v0 : v1;
}
template<typename T, int N>
SLANG_FORCE_INLINE T _slang_vector_get_element(Vector<T, N> x, int index)
{
return x[index];
}
template<typename T, int N>
SLANG_FORCE_INLINE const T* _slang_vector_get_element_ptr(const Vector<T, N>* x, int index)
{
return &((*const_cast<Vector<T,N>*>(x))[index]);
}
template<typename T, int N>
SLANG_FORCE_INLINE T* _slang_vector_get_element_ptr(Vector<T, N>* x, int index)
{
return &((*x)[index]);
}
template<typename T, int n, typename OtherT, int m>
SLANG_FORCE_INLINE Vector<T, n> _slang_vector_reshape(const Vector<OtherT, m> other)
{
Vector<T, n> result;
for (int i = 0; i < n; i++)
{
OtherT otherElement = OtherT(0);
if (i < m)
otherElement = _slang_vector_get_element(other, i);
*_slang_vector_get_element_ptr(&result, i) = (T)otherElement;
}
return result;
}
typedef uint32_t uint;
#define SLANG_VECTOR_BINARY_OP(T, op) \
template<int n> \
SLANG_FORCE_INLINE Vector<T, n> operator op(const Vector<T, n>& thisVal, const Vector<T, n>& other) \
{ \
Vector<T, n> result;\
for (int i = 0; i < n; i++) \
result[i] = thisVal[i] op other[i]; \
return result;\
}
#define SLANG_VECTOR_BINARY_COMPARE_OP(T, op) \
template<int n> \
SLANG_FORCE_INLINE Vector<bool, n> operator op(const Vector<T, n>& thisVal, const Vector<T, n>& other) \
{ \
Vector<bool, n> result;\
for (int i = 0; i < n; i++) \
result[i] = thisVal[i] op other[i]; \
return result;\
}
#define SLANG_VECTOR_UNARY_OP(T, op) \
template<int n> \
SLANG_FORCE_INLINE Vector<T, n> operator op(const Vector<T, n>& thisVal) \
{ \
Vector<T, n> result;\
for (int i = 0; i < n; i++) \
result[i] = op thisVal[i]; \
return result;\
}
#define SLANG_INT_VECTOR_OPS(T) \
SLANG_VECTOR_BINARY_OP(T, +)\
SLANG_VECTOR_BINARY_OP(T, -)\
SLANG_VECTOR_BINARY_OP(T, *)\
SLANG_VECTOR_BINARY_OP(T, / )\
SLANG_VECTOR_BINARY_OP(T, &)\
SLANG_VECTOR_BINARY_OP(T, |)\
SLANG_VECTOR_BINARY_OP(T, &&)\
SLANG_VECTOR_BINARY_OP(T, ||)\
SLANG_VECTOR_BINARY_OP(T, ^)\
SLANG_VECTOR_BINARY_OP(T, %)\
SLANG_VECTOR_BINARY_OP(T, >>)\
SLANG_VECTOR_BINARY_OP(T, <<)\
SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\
SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\
SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\
SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\
SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\
SLANG_VECTOR_BINARY_COMPARE_OP(T, !=)\
SLANG_VECTOR_UNARY_OP(T, !)\
SLANG_VECTOR_UNARY_OP(T, ~)
#define SLANG_FLOAT_VECTOR_OPS(T) \
SLANG_VECTOR_BINARY_OP(T, +)\
SLANG_VECTOR_BINARY_OP(T, -)\
SLANG_VECTOR_BINARY_OP(T, *)\
SLANG_VECTOR_BINARY_OP(T, /)\
SLANG_VECTOR_UNARY_OP(T, -)\
SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\
SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\
SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\
SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\
SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\
SLANG_VECTOR_BINARY_COMPARE_OP(T, !=)
SLANG_INT_VECTOR_OPS(bool)
SLANG_INT_VECTOR_OPS(int)
SLANG_INT_VECTOR_OPS(int8_t)
SLANG_INT_VECTOR_OPS(int16_t)
SLANG_INT_VECTOR_OPS(int64_t)
SLANG_INT_VECTOR_OPS(uint)
SLANG_INT_VECTOR_OPS(uint8_t)
SLANG_INT_VECTOR_OPS(uint16_t)
SLANG_INT_VECTOR_OPS(uint64_t)
SLANG_FLOAT_VECTOR_OPS(float)
SLANG_FLOAT_VECTOR_OPS(double)
#define SLANG_VECTOR_INT_NEG_OP(T) \
template<int N>\
Vector<T, N> operator-(const Vector<T, N>& thisVal) \
{ \
Vector<T, N> result;\
for (int i = 0; i < N; i++) \
result[i] = 0 - thisVal[i]; \
return result;\
}
SLANG_VECTOR_INT_NEG_OP(int)
SLANG_VECTOR_INT_NEG_OP(int8_t)
SLANG_VECTOR_INT_NEG_OP(int16_t)
SLANG_VECTOR_INT_NEG_OP(int64_t)
SLANG_VECTOR_INT_NEG_OP(uint)
SLANG_VECTOR_INT_NEG_OP(uint8_t)
SLANG_VECTOR_INT_NEG_OP(uint16_t)
SLANG_VECTOR_INT_NEG_OP(uint64_t)
#define SLANG_FLOAT_VECTOR_MOD(T)\
template<int N> \
Vector<T, N> operator%(const Vector<T, N>& left, const Vector<T, N>& right) \
{\
Vector<T, N> result;\
for (int i = 0; i < N; i++) \
result[i] = _slang_fmod(left[i], right[i]); \
return result;\
}
SLANG_FLOAT_VECTOR_MOD(float)
SLANG_FLOAT_VECTOR_MOD(double)
#undef SLANG_FLOAT_VECTOR_MOD
#undef SLANG_VECTOR_BINARY_OP
#undef SLANG_VECTOR_UNARY_OP
#undef SLANG_INT_VECTOR_OPS
#undef SLANG_FLOAT_VECTOR_OPS
#undef SLANG_VECTOR_INT_NEG_OP
#undef SLANG_FLOAT_VECTOR_MOD
template <typename T, int ROWS, int COLS>
struct Matrix
{
Vector<T, COLS> rows[ROWS];
Vector<T, COLS>& operator[](size_t index) { return rows[index]; }
Matrix() = default;
Matrix(T scalar)
{
for (int i = 0; i < ROWS; i++)
rows[i] = Vector<T, COLS>(scalar);
}
Matrix(const Vector<T, COLS>& row0)
{
rows[0] = row0;
}
Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1)
{
rows[0] = row0;
rows[1] = row1;
}
Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1, const Vector<T, COLS>& row2)
{
rows[0] = row0;
rows[1] = row1;
rows[2] = row2;
}
Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1, const Vector<T, COLS>& row2, const Vector<T, COLS>& row3)
{
rows[0] = row0;
rows[1] = row1;
rows[2] = row2;
rows[3] = row3;
}
template<typename U, int otherRow, int otherCol>
Matrix(const Matrix<U, otherRow, otherCol>& other)
{
int minRow = ROWS;
int minCol = COLS;
if (minRow > otherRow) minRow = otherRow;
if (minCol > otherCol) minCol = otherCol;
for (int i = 0; i < minRow; i++)
for (int j = 0; j < minCol; j++)
rows[i][j] = (T)other.rows[i][j];
}
Matrix(T v0, T v1, T v2, T v3)
{
rows[0][0] = v0; rows[0][1] = v1;
rows[1][0] = v2; rows[1][1] = v3;
}
Matrix(T v0, T v1, T v2, T v3, T v4, T v5)
{
if (COLS == 3)
{
rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2;
rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5;
}
else
{
rows[0][0] = v0; rows[0][1] = v1;
rows[1][0] = v2; rows[1][1] = v3;
rows[2][0] = v4; rows[2][1] = v5;
}
}
Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7)
{
if (COLS == 4)
{
rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3;
rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7;
}
else
{
rows[0][0] = v0; rows[0][1] = v1;
rows[1][0] = v2; rows[1][1] = v3;
rows[2][0] = v4; rows[2][1] = v5;
rows[3][0] = v6; rows[3][1] = v7;
}
}
Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8)
{
rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2;
rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5;
rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8;
}
Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11)
{
if (COLS == 4)
{
rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3;
rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7;
rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11;
}
else
{
rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2;
rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5;
rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8;
rows[3][0] = v9; rows[3][1] = v10; rows[3][2] = v11;
}
}
Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11, T v12, T v13, T v14, T v15)
{
rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3;
rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7;
rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11;
rows[3][0] = v12; rows[3][1] = v13; rows[3][2] = v14; rows[3][3] = v15;
}
};
#define SLANG_MATRIX_BINARY_OP(T, op) \
template<int R, int C> \
Matrix<T, R, C> operator op(const Matrix<T, R, C>& thisVal, const Matrix<T, R, C>& other) \
{ \
Matrix<T, R, C> result;\
for (int i = 0; i < R; i++) \
for (int j = 0; j < C; j++) \
result.rows[i][j] = thisVal.rows[i][j] op other.rows[i][j]; \
return result;\
}
#define SLANG_MATRIX_UNARY_OP(T, op) \
template<int R, int C> \
Matrix<T, R, C> operator op(const Matrix<T, R, C>& thisVal) \
{ \
Matrix<T, R, C> result;\
for (int i = 0; i < R; i++) \
for (int j = 0; j < C; j++) \
result.rows[i][j] = op thisVal.rows[i][j]; \
return result;\
}
#define SLANG_INT_MATRIX_OPS(T) \
SLANG_MATRIX_BINARY_OP(T, +)\
SLANG_MATRIX_BINARY_OP(T, -)\
SLANG_MATRIX_BINARY_OP(T, *)\
SLANG_MATRIX_BINARY_OP(T, / )\
SLANG_MATRIX_BINARY_OP(T, &)\
SLANG_MATRIX_BINARY_OP(T, |)\
SLANG_MATRIX_BINARY_OP(T, &&)\
SLANG_MATRIX_BINARY_OP(T, ||)\
SLANG_MATRIX_BINARY_OP(T, ^)\
SLANG_MATRIX_BINARY_OP(T, %)\
SLANG_MATRIX_UNARY_OP(T, !)\
SLANG_MATRIX_UNARY_OP(T, ~)
#define SLANG_FLOAT_MATRIX_OPS(T) \
SLANG_MATRIX_BINARY_OP(T, +)\
SLANG_MATRIX_BINARY_OP(T, -)\
SLANG_MATRIX_BINARY_OP(T, *)\
SLANG_MATRIX_BINARY_OP(T, /)\
SLANG_MATRIX_UNARY_OP(T, -)
SLANG_INT_MATRIX_OPS(int)
SLANG_INT_MATRIX_OPS(int8_t)
SLANG_INT_MATRIX_OPS(int16_t)
SLANG_INT_MATRIX_OPS(int64_t)
SLANG_INT_MATRIX_OPS(uint)
SLANG_INT_MATRIX_OPS(uint8_t)
SLANG_INT_MATRIX_OPS(uint16_t)
SLANG_INT_MATRIX_OPS(uint64_t)
SLANG_FLOAT_MATRIX_OPS(float)
SLANG_FLOAT_MATRIX_OPS(double)
#define SLANG_MATRIX_INT_NEG_OP(T) \
template<int R, int C>\
SLANG_FORCE_INLINE Matrix<T, R, C> operator-(Matrix<T, R, C> thisVal) \
{ \
Matrix<T, R, C> result;\
for (int i = 0; i < R; i++) \
for (int j = 0; j < C; j++) \
result.rows[i][j] = 0 - thisVal.rows[i][j]; \
return result;\
}
SLANG_MATRIX_INT_NEG_OP(int)
SLANG_MATRIX_INT_NEG_OP(int8_t)
SLANG_MATRIX_INT_NEG_OP(int16_t)
SLANG_MATRIX_INT_NEG_OP(int64_t)
SLANG_MATRIX_INT_NEG_OP(uint)
SLANG_MATRIX_INT_NEG_OP(uint8_t)
SLANG_MATRIX_INT_NEG_OP(uint16_t)
SLANG_MATRIX_INT_NEG_OP(uint64_t)
#define SLANG_FLOAT_MATRIX_MOD(T)\
template<int R, int C> \
SLANG_FORCE_INLINE Matrix<T, R, C> operator%(Matrix<T, R, C> left, Matrix<T, R, C> right) \
{\
Matrix<T, R, C> result;\
for (int i = 0; i < R; i++) \
for (int j = 0; j < C; j++) \
result.rows[i][j] = _slang_fmod(left.rows[i][j], right.rows[i][j]); \
return result;\
}
SLANG_FLOAT_MATRIX_MOD(float)
SLANG_FLOAT_MATRIX_MOD(double)
#undef SLANG_FLOAT_MATRIX_MOD
#undef SLANG_MATRIX_BINARY_OP
#undef SLANG_MATRIX_UNARY_OP
#undef SLANG_INT_MATRIX_OPS
#undef SLANG_FLOAT_MATRIX_OPS
#undef SLANG_MATRIX_INT_NEG_OP
#undef SLANG_FLOAT_MATRIX_MOD
// NOTE: bit-casts by pointer punning; assumes sizeof(TResult) <= sizeof(TInput).
template<typename TResult, typename TInput>
TResult slang_bit_cast(TInput val)
{
    return *(TResult*)(&val);
}
#endif
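Finally, a short sketch of the generated elementwise operators; it assumes this header and the scalar intrinsics above are both included:

    void sketch()
    {
        FixedArray<int, 4> arr = {{1, 2, 3, 4}};
        int third = arr[2];                      // bounds-checked per the macros above

        Vector<float, 3> a(1.0f), b(0.5f);
        Vector<float, 3> sum = a + b;            // (1.5, 1.5, 1.5)
        Vector<bool, 3> mask = a > b;            // (true, true, true)
        Vector<float, 3> sel = _slang_select(mask, a, b);   // picks from a where mask is true

        Matrix<float, 2, 2> m(1.0f, 2.0f, 3.0f, 4.0f);     // rows (1, 2) and (3, 4)
        Matrix<float, 2, 2> twice = m + m;                  // elementwise add
    }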

Some files were not shown because too many files have changed in this diff