diff --git a/common/chat.cpp b/common/chat.cpp
index fa926bf82..32aa205e9 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -631,6 +631,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
+ case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: return "DeepSeek V3.1";
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
@@ -698,11 +699,13 @@ static void parse_json_tool_calls(
size_t from = std::string::npos;
auto first = true;
while (true) {
+ auto start_pos = builder.pos();
auto res = function_regex_start_only && first
? builder.try_consume_regex(*function_regex_start_only)
: function_regex
? builder.try_find_regex(*function_regex, from)
: std::nullopt;
+
if (res) {
std::string name;
if (get_function_name) {
@@ -737,6 +740,8 @@ static void parse_json_tool_calls(
return;
}
throw common_chat_msg_partial_exception("incomplete tool call");
+ } else {
+ builder.move_to(start_pos);
}
break;
}
@@ -1388,6 +1393,71 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
}
return data;
}
+
+static common_chat_params common_chat_params_init_deepseek_v3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
+ common_chat_params data;
+
+ // Pass thinking context for DeepSeek V3.1 template
+ json additional_context = {
+ {"thinking", inputs.enable_thinking},
+ };
+
+ auto prompt = apply(tmpl, inputs,
+ /* messages_override= */ inputs.messages,
+ /* tools_override= */ std::nullopt,
+ additional_context);
+ data.prompt = prompt;
+ data.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
+ if (string_ends_with(data.prompt, "<think>")) {
+ if (!inputs.enable_thinking) {
+ data.prompt += "</think>";
+ } else {
+ data.thinking_forced_open = true;
+ }
+ }
+ if (inputs.tools.is_array() && !inputs.tools.empty()) {
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ std::vector<std::string> tool_rules;
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ std::string name = function.at("name");
+ auto parameters = function.at("parameters");
+ builder.resolve_refs(parameters);
+ tool_rules.push_back(builder.add_rule(name + "-call",
+ "( \"<|tool▁call▁begin|>\" )? \"" + name + "<|tool▁sep|>"
+ "\" " + builder.add_schema(name + "-args", parameters) + " "
+ "\"<|tool▁call▁end|>\""));
+ });
+ // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
+ // so we accept common variants (then it's all constrained)
+ builder.add_rule("root",
+ std::string(data.thinking_forced_open ? "( \"\" space )? " : "") +
+ "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" | \"<|tool▁calls|>\" ) "
+ "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
+ "\"<|tool▁calls▁end|>\""
+ " space");
+ data.grammar_triggers.push_back({
+ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
+ // If thinking_forced_open, then we capture the </think> tag in the grammar,
+ // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
+ std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") +
+ "(<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)[\\s\\S]*"
+ });
+ data.preserved_tokens = {
+ "",
+ "",
+ "<|tool▁calls▁begin|>",
+ "<|tool▁call▁begin|>",
+ "<|tool▁sep|>",
+ "<|tool▁call▁end|>",
+ "<|tool▁calls▁end|>",
+ };
+ });
+ }
+ return data;
+}
+
static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("", "");
if (!builder.syntax().parse_tool_calls) {
@@ -1409,6 +1479,66 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
tool_calls_end);
}
+static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) {
+ static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)");
+
+ static const common_regex close_regex("(?:[\\s]*)?<|tool▁call▁end|>");
+ static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)");
+ static const common_regex tool_calls_end("<|tool▁calls▁end|>");
+
+ if (!builder.syntax().parse_tool_calls) {
+ LOG_DBG("%s: not parse_tool_calls\n", __func__);
+ builder.add_content(builder.consume_rest());
+ return;
+ }
+
+ LOG_DBG("%s: parse_tool_calls\n", __func__);
+
+ parse_json_tool_calls(
+ builder,
+ /* block_open= */ tool_calls_begin,
+ /* function_regex_start_only= */ std::nullopt,
+ function_regex,
+ close_regex,
+ tool_calls_end);
+}
+
+static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
+ // DeepSeek V3.1 outputs reasoning content between "<think>" and "</think>" tags, followed by regular content
+ // First try to parse using the standard reasoning parsing method
+ LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str());
+
+ auto start_pos = builder.pos();
+ auto found_end_think = builder.try_find_literal("</think>");
+ builder.move_to(start_pos);
+
+ if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) {
+ LOG_DBG("%s: no end_think, not partial, adding content\n", __func__);
+ common_chat_parse_deepseek_v3_1_content(builder);
+ } else if (builder.try_parse_reasoning("<think>", "</think>")) {
+ // If reasoning was parsed successfully, the remaining content is regular content
+ LOG_DBG("%s: parsed reasoning, adding content\n", __func__);
+ // <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nJSON\n```<|tool▁call▁end|><|tool▁calls▁end|>
+ common_chat_parse_deepseek_v3_1_content(builder);
+ } else {
+ if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) {
+ LOG_DBG("%s: reasoning_format none, adding content\n", __func__);
+ common_chat_parse_deepseek_v3_1_content(builder);
+ return;
+ }
+ // If no reasoning tags found, check if we should treat everything as reasoning
+ if (builder.syntax().thinking_forced_open) {
+ // If thinking is forced open but no tags found, treat everything as reasoning
+ LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__);
+ builder.add_reasoning_content(builder.consume_rest());
+ } else {
+ LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__);
+ // <|tool▁call▁begin|>NAME<|tool▁sep|>JSON<|tool▁call▁end|>
+ common_chat_parse_deepseek_v3_1_content(builder);
+ }
+ }
+}
+
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
auto prompt = apply(tmpl, inputs);
@@ -2365,6 +2495,12 @@ static common_chat_params common_chat_templates_apply_jinja(
}
}
+ // DeepSeek V3.1: detect based on specific patterns in the template
+ if (src.find("message['prefix'] is defined and message['prefix'] and thinking") != std::string::npos &&
+ params.json_schema.is_null()) {
+ return common_chat_params_init_deepseek_v3_1(tmpl, params);
+ }
+
// DeepSeek R1: use handler in all cases except json schema (thinking / tools).
if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_deepseek_r1(tmpl, params);
@@ -2537,6 +2673,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
common_chat_parse_deepseek_r1(builder);
break;
+ case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1:
+ common_chat_parse_deepseek_v3_1(builder);
+ break;
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
common_chat_parse_functionary_v3_2(builder);
break;
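Reviewer note (not part of the patch): a minimal sketch of how the new format is driven through the public common_chat_parse() API from common/chat.h. The model output string and the expected fields below are illustrative, and the common_chat_syntax member names are assumed from that header.

    #include "chat.h"

    static void example_parse_deepseek_v3_1() {
        // optional reasoning closed by </think>, then a tool-call block
        const std::string text =
            "Need the current weather first.</think>"
            "<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>"
            "{\"location\": \"Paris\"}<|tool▁call▁end|><|tool▁calls▁end|>";

        common_chat_syntax syntax;
        syntax.format               = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
        syntax.reasoning_format     = COMMON_REASONING_FORMAT_DEEPSEEK;
        syntax.thinking_forced_open = true; // the prompt already ended with <think>

        const common_chat_msg msg = common_chat_parse(text, /* is_partial= */ false, syntax);
        // expected: msg.reasoning_content == "Need the current weather first."
        //           msg.tool_calls[0].name == "get_weather"
        //           msg.tool_calls[0].arguments == "{\"location\": \"Paris\"}"
    }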
diff --git a/common/chat.h b/common/chat.h
index 41851022d..5170fc14f 100644
--- a/common/chat.h
+++ b/common/chat.h
@@ -107,6 +107,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
+ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_GRANITE,
diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp
index 637891f50..182c78754 100644
--- a/common/json-schema-to-grammar.cpp
+++ b/common/json-schema-to-grammar.cpp
@@ -843,9 +843,10 @@ public:
_build_object_rule(
properties, required, name,
schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
- } else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) {
+ } else if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) {
std::unordered_set<std::string> required;
std::vector<std::pair<std::string, json>> properties;
+ std::map<std::string, size_t> enum_values;
std::string hybrid_name = name;
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
if (comp_schema.contains("$ref")) {
@@ -857,6 +858,14 @@ public:
required.insert(prop.key());
}
}
+ } else if (comp_schema.contains("enum")) {
+ for (const auto & v : comp_schema["enum"]) {
+ const auto rule = _generate_constant_rule(v);
+ if (enum_values.find(rule) == enum_values.end()) {
+ enum_values[rule] = 0;
+ }
+ enum_values[rule] += 1;
+ }
} else {
// todo warning
}
@@ -870,6 +879,17 @@ public:
add_component(t, true);
}
}
+ if (!enum_values.empty()) {
+ std::vector<std::string> enum_intersection;
+ for (const auto & p : enum_values) {
+ if (p.second == schema["allOf"].size()) {
+ enum_intersection.push_back(p.first);
+ }
+ }
+ if (!enum_intersection.empty()) {
+ return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space");
+ }
+ }
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
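Reviewer note (not part of the patch): a sketch of what the new allOf/enum branch produces, going through the json_schema_to_grammar() helper declared in common/json-schema-to-grammar.h. The schema and the commented grammar are illustrative.

    #include "json-schema-to-grammar.h"
    #include <nlohmann/json.hpp>
    #include <iostream>

    int main() {
        // two enum components under allOf; only values present in every component survive
        const auto schema = nlohmann::ordered_json::parse(R"({
            "type": "string",
            "allOf": [
                { "enum": ["red", "green", "blue"] },
                { "enum": ["green", "blue", "yellow"] }
            ]
        })");

        // with this change the root rule collapses to the intersection, roughly:
        //   root ::= ("\"blue\"" | "\"green\"") space
        // instead of falling back to an empty object rule
        std::cout << json_schema_to_grammar(schema) << std::endl;
        return 0;
    }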
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 62a546ee2..bbc21813f 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -6701,6 +6701,8 @@ class T5Model(TextModel):
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
self.gguf_writer.add_block_count(self.hparams["num_layers"])
+ if (dec_n_layer := self.hparams.get("num_decoder_layers")) is not None:
+ self.gguf_writer.add_decoder_block_count(dec_n_layer)
self.gguf_writer.add_head_count(self.hparams["num_heads"])
self.gguf_writer.add_key_length(self.hparams["d_kv"])
self.gguf_writer.add_value_length(self.hparams["d_kv"])
diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h
index 1a78935aa..9edd48513 100644
--- a/ggml/include/ggml-cpu.h
+++ b/ggml/include/ggml-cpu.h
@@ -134,6 +134,7 @@ extern "C" {
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
+ GGML_BACKEND_API void ggml_cpu_fp32_to_i32 (const float *, int32_t *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);
diff --git a/ggml/include/ggml-metal.h b/ggml/include/ggml-metal.h
index 68df4c4e7..ba8ab3649 100644
--- a/ggml/include/ggml-metal.h
+++ b/ggml/include/ggml-metal.h
@@ -43,14 +43,8 @@ GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void);
GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend);
-GGML_DEPRECATED(
- GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
- "obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713");
-
GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
-GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
-
// helper to check if the device supports a specific family
// ideally, the user code should be doing these checks
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index e78f2355c..083b3a1aa 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -1429,6 +1429,7 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
+ // note: casting from f32 to i32 will discard the fractional part
GGML_API struct ggml_tensor * ggml_cast(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -1553,7 +1554,11 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
- // supports 3D: a->ne[2] == b->ne[1]
+ // supports 4D a:
+ // a [n_embd, ne1, ne2, ne3]
+ // b I32 [n_rows, ne2, ne3, 1]
+ //
+ // return [n_embd, n_rows, ne2, ne3]
GGML_API struct ggml_tensor * ggml_get_rows(
struct ggml_context * ctx,
struct ggml_tensor * a, // data
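Reviewer note (not part of the patch): a small CPU-side sketch exercising the two notes above (F32 -> I32 cast truncation and 4D ggml_get_rows), using only public ggml.h / ggml-cpu.h entry points; shapes follow the updated comment.

    #include "ggml.h"
    #include "ggml-cpu.h"

    static void example_cast_and_4d_get_rows(void) {
        struct ggml_init_params ip = { /*.mem_size =*/ 32u*1024u*1024u, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
        struct ggml_context * ctx = ggml_init(ip);

        // a: [n_embd=8, ne1=32, ne2=2, ne3=3], b: I32 [n_rows=4, ne2=2, ne3=3]
        struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 8, 32, 2, 3);
        struct ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_I32, 4, 2, 3);
        for (int i = 0; i < 4*2*3; i++) {
            ((int32_t *) b->data)[i] = i % 32; // row indices into the matching [ne2, ne3] slice of a
        }

        struct ggml_tensor * rows = ggml_get_rows(ctx, a, b);         // -> [8, 4, 2, 3]
        struct ggml_tensor * ai32 = ggml_cast(ctx, a, GGML_TYPE_I32); // fractional part discarded

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, rows);
        ggml_build_forward_expand(gf, ai32);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

        ggml_free(ctx);
    }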
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
index c36c12d65..2db5c4e0f 100644
--- a/ggml/src/ggml-backend-impl.h
+++ b/ggml/src/ggml-backend-impl.h
@@ -114,6 +114,9 @@ extern "C" {
void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event);
// wait for an event on on a different stream
void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event);
+
+ // (optional) sort/optimize the nodes in the graph
+ void (*optimize_graph) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
};
struct ggml_backend {
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 3a13f8447..c944a8fa6 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -463,6 +463,13 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)
backend->iface.event_wait(backend, event);
}
+static void ggml_backend_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ GGML_ASSERT(backend);
+ if (backend->iface.optimize_graph != NULL) {
+ backend->iface.optimize_graph(backend, cgraph);
+ }
+}
+
// Backend device
const char * ggml_backend_dev_name(ggml_backend_dev_t device) {
@@ -1304,6 +1311,10 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra
struct ggml_backend_sched_split * split = &sched->splits[i];
split->graph = ggml_graph_view(graph, split->i_start, split->i_end);
+ // Optimize this split of the graph. This needs to happen before we make graph_copy,
+ // so they are in sync.
+ ggml_backend_optimize_graph(sched->backends[split->backend_id], &split->graph);
+
// add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
for (int j = 0; j < split->n_inputs; j++) {
assert(graph_copy->size > (graph_copy->n_nodes + 1));
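Reviewer note (not part of the patch): a sketch of what a backend-side implementation of the new optional hook could look like; the body and the wiring comment are hypothetical.

    #include "ggml-backend-impl.h"

    // Called once per split by ggml_backend_sched_split_graph(), before the split is
    // copied, so any in-place reordering stays in sync with the graph copy.
    static void example_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
        GGML_UNUSED(backend);
        GGML_UNUSED(cgraph);
        // a real backend would reorder or annotate cgraph->nodes here,
        // e.g. to group operations it can fuse into a single kernel launch
    }

    // and in the backend's interface table:
    //   /* .optimize_graph = */ example_optimize_graph,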
diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp
index aeac2e574..cdfc5a9bc 100644
--- a/ggml/src/ggml-blas/ggml-blas.cpp
+++ b/ggml/src/ggml-blas/ggml-blas.cpp
@@ -270,6 +270,7 @@ static struct ggml_backend_i blas_backend_i = {
/* .graph_compute = */ ggml_backend_blas_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
+ /* .optimize_graph = */ NULL,
};
static ggml_guid_t ggml_backend_blas_guid(void) {
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 463f947dd..ead03916f 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -377,6 +377,9 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
},
+ [GGML_TYPE_I32] = {
+ .from_float = (ggml_from_float_t) ggml_cpu_fp32_to_i32,
+ },
};
const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
@@ -3543,7 +3546,10 @@ struct ggml_cplan ggml_graph_plan(
if (ggml_is_quantized(node->type) ||
// F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
(node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
- (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
+ (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16) ||
+ // conversion between F32 and I32
+ (node->src[0]->type == GGML_TYPE_F32 && node->src[1] && node->src[1]->type == GGML_TYPE_I32) ||
+ (node->src[0]->type == GGML_TYPE_I32 && node->src[1] && node->src[1]->type == GGML_TYPE_F32)) {
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
}
} break;
@@ -4124,6 +4130,13 @@ void ggml_cpu_fp32_to_bf16(const float * x, ggml_bf16_t * y, int64_t n) {
}
}
+void ggml_cpu_fp32_to_i32(const float * x, int32_t * y, int64_t n) {
+ int64_t i = 0;
+ for (; i < n; ++i) {
+ y[i] = x[i];
+ }
+}
+
void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) {
int64_t i = 0;
#if defined(__AVX2__)
diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp
index 890564c8c..eb60df7fe 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.cpp
+++ b/ggml/src/ggml-cpu/ggml-cpu.cpp
@@ -190,6 +190,7 @@ static const struct ggml_backend_i ggml_backend_cpu_i = {
/* .graph_compute = */ ggml_backend_cpu_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
+ /* .optimize_graph = */ NULL,
};
static ggml_guid_t ggml_backend_cpu_guid(void) {
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 0bb767e01..212e52ef6 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -776,6 +776,24 @@ static void ggml_compute_forward_dup_f32(
id += ne00 * (ne01 - ir1);
}
}
+ } else if (dst->type == GGML_TYPE_I32) {
+ size_t id = 0;
+ int32_t * dst_ptr = (int32_t *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += ne00 * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = *src0_ptr;
+ id++;
+ }
+ }
+ id += ne00 * (ne01 - ir1);
+ }
+ }
} else {
GGML_ABORT("fatal error"); // TODO: implement
}
@@ -947,6 +965,144 @@ static void ggml_compute_forward_dup_f32(
}
}
}
+ } else if (dst->type == GGML_TYPE_I32) {
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ i10 += ne00 * ir0;
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+ *(int32_t *) dst_ptr = *(const float *) src0_ptr;
+
+ if (++i10 == ne0) {
+ i10 = 0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ i10 += ne00 * (ne01 - ir1);
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ } else {
+ GGML_ABORT("fatal error"); // TODO: implement
+ }
+}
+
+static void ggml_compute_forward_dup_i32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ const int ith = params->ith; // thread index
+ const int nth = params->nth; // number of threads
+
+ // parallelize by rows
+ const int nr = ne01;
+ // number of rows per thread
+ const int dr = (nr + nth - 1) / nth;
+ // row range for this thread
+ const int ir0 = dr * ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ // dst counters
+
+ int64_t i10 = 0;
+ int64_t i11 = 0;
+ int64_t i12 = 0;
+ int64_t i13 = 0;
+
+ // TODO: not optimal, but works
+ if (dst->type == GGML_TYPE_F32) {
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ i10 += ne00 * ir0;
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+ *(float *) dst_ptr = *(const int32_t *) src0_ptr;
+
+ if (++i10 == ne0) {
+ i10 = 0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ i10 += ne00 * (ne01 - ir1);
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
} else {
GGML_ABORT("fatal error"); // TODO: implement
}
@@ -1177,6 +1333,10 @@ void ggml_compute_forward_dup(
{
ggml_compute_forward_dup_f32(params, dst);
} break;
+ case GGML_TYPE_I32:
+ {
+ ggml_compute_forward_dup_i32(params, dst);
+ } break;
default:
{
if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
@@ -8438,6 +8598,7 @@ static void ggml_compute_forward_timestep_embedding_f32(
embed_data[j + half] = sinf(arg);
}
if (dim % 2 != 0 && ith == 0) {
+ embed_data[2 * half] = 0.f;
embed_data[dim] = 0.f;
}
}
diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu
index 1c7656634..725e1a81a 100644
--- a/ggml/src/ggml-cuda/binbcast.cu
+++ b/ggml/src/ggml-cuda/binbcast.cu
@@ -23,28 +23,44 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
return a / b;
}
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
+static __global__ void k_bin_bcast(const src0_t * src0,
+ const src1_t * src1,
+ dst_t * dst,
+ const int ne0,
+ const int ne1,
+ const int ne2,
+ const uint3 ne3,
+ const uint3 ne10,
+ const uint3 ne11,
+ const uint3 ne12,
+ const uint3 ne13,
+ /*int s0, */ const int s1,
+ const int s2,
+ const int s3,
+ /*int s00,*/ const int s01,
+ const int s02,
+ const int s03,
+ /*int s10,*/ const int s11,
+ const int s12,
+ const int s13,
+ src1_ptrs... src1s) {
+ const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
+ const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
+ const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
+ const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);
-
-template
-static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
- const int ne0, const int ne1, const int ne2, const int ne3,
- const int ne10, const int ne11, const int ne12, const int ne13,
- /*int s0, */ const int s1, const int s2, const int s3,
- /*int s00,*/ const int s01, const int s02, const int s03,
- /*int s10,*/ const int s11, const int s12, const int s13,
- src1_ptrs... src1s) {
- const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
- const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
- const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
- const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;
-
- if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+ if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) {
return;
}
- const int i11 = i1 % ne11;
- const int i12 = i2 % ne12;
- const int i13 = i3 % ne13;
+ const uint32_t i11 = fastmodulo(i1, ne11);
+ const uint32_t i12 = fastmodulo(i2, ne12);
+ const uint32_t i13 = fastmodulo(i3, ne13);
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -53,8 +69,8 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;
- for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
- const int i10 = i0 % ne10;
+ for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
+ const uint32_t i10 = fastmodulo(i0, ne10);
float result = src0_row ? (float) src0_row[i0] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
@@ -67,28 +83,48 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
}
}
-template
-static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
- const int ne0, const int ne1, const int ne2,const int ne3,
- const int ne10, const int ne11, const int ne12, const int ne13,
- /*int s0, */ const int s1, const int s2, const int s3,
- /*int s00,*/ const int s01, const int s02, const int s03,
- /*int s10,*/ const int s11, const int s12, const int s13,
- src1_ptrs ... src1s) {
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
+static __global__ void k_bin_bcast_unravel(const src0_t * src0,
+ const src1_t * src1,
+ dst_t * dst,
+ const uint3 ne0,
+ const uint3 ne1,
+ const uint3 ne2,
+ const uint32_t ne3,
+ const uint3 prod_012,
+ const uint3 prod_01,
+ const uint3 ne10,
+ const uint3 ne11,
+ const uint3 ne12,
+ const uint3 ne13,
+ /*int s0, */ const int s1,
+ const int s2,
+ const int s3,
+ /*int s00,*/ const int s01,
+ const int s02,
+ const int s03,
+ /*int s10,*/ const int s11,
+ const int s12,
+ const int s13,
+ src1_ptrs... src1s) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
- const int i3 = i/(ne2*ne1*ne0);
- const int i2 = (i/(ne1*ne0)) % ne2;
- const int i1 = (i/ne0) % ne1;
- const int i0 = i % ne0;
+ const uint32_t i3 = fastdiv(i, prod_012);
+ const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
+ const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
+ const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;
- if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+ if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
return;
}
- const int i11 = i1 % ne11;
- const int i12 = i2 % ne12;
- const int i13 = i3 % ne13;
+ const int i11 = fastmodulo(i1, ne11);
+ const int i12 = fastmodulo(i2, ne12);
+ const int i13 = fastmodulo(i3, ne13);
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -97,7 +133,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t *
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;
- const int i10 = i0 % ne10;
+ const int i10 = fastmodulo(i0, ne10);
float result = src0_row ? (float) src0_row[i0] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
@@ -170,11 +206,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
//int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
//int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
- int64_t ne10 = cne1[0];
- int64_t ne11 = cne1[1];
- int64_t ne12 = cne1[2];
- int64_t ne13 = cne1[3];
-
size_t nb0 = cnb[0];
size_t nb1 = cnb[1];
size_t nb2 = cnb[2];
@@ -233,48 +264,51 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
block_dims.y = std::min(ne1, block_size / block_dims.x);
block_dims.z = std::min(std::min(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);
- dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x,
- (ne1 + block_dims.y - 1) / block_dims.y,
+ dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y,
(ne2 * ne3 + block_dims.z - 1) / block_dims.z);
+ const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]);
+ const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]);
+ const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
+ const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
+
if (block_nums.z > 65535) {
- int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
+ int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
+ const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
+ const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
+ const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
+ const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
+ const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);
+
if constexpr (sizeof...(I) > 0) {
- k_bin_bcast_unravel
- <<>>(src0_dd, src1_dd, dst_dd,
- ne0, ne1, ne2, ne3,
- ne10, ne11, ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00,*/ s01, s02, s03,
- /* s10,*/ s11, s12,s13,
- (const src1_t *) dst->src[I + 1]->data...);
+ k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
+ ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00,*/ s01, s02, s03,
+ /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast_unravel<bin_op>
- <<>>(src0_dd, src1_dd, dst_dd,
- ne0, ne1, ne2, ne3,
- ne10, ne11, ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00,*/ s01, s02, s03,
- /* s10,*/ s11, s12,s13);
+ <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
+ ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00,*/ s01, s02, s03,
+ /* s10,*/ s11, s12, s13);
}
} else {
+ const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
if constexpr (sizeof...(I) > 0) {
- k_bin_bcast
- <<>>(src0_dd, src1_dd, dst_dd,
- ne0, ne1, ne2, ne3,
- ne10, ne11, ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00,*/ s01, s02, s03,
- /* s10,*/ s11, s12,s13,
- (const src1_t *) dst->src[I + 1]->data...);
+ k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00,*/ s01, s02, s03,
+ /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
- k_bin_bcast
- <<>>(src0_dd, src1_dd, dst_dd,
- ne0, ne1, ne2, ne3,
- ne10, ne11, ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00,*/ s01, s02, s03,
- /* s10,*/ s11, s12,s13);
+ k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00,*/ s01, s02, s03,
+ /* s10,*/ s11, s12, s13);
}
}
}
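Reviewer note (not part of the patch): the fastdiv/fastmodulo helpers and init_fastdiv_values() referenced above live in common.cuh. Below is a host-side C++ sketch of the underlying trick (assuming the packing x = magic multiplier, y = shift, z = original divisor, which matches how ne3.z is used in the kernel); it shows why one uint3 per divisor is precomputed on the host instead of issuing an integer division per thread.

    #include <cstdint>

    struct fastdiv_vals { uint32_t mp, L, d; }; // mirrors the uint3 packing

    static fastdiv_vals example_init_fastdiv(uint32_t d) {
        uint32_t L = 0; // smallest L with 2^L >= d
        while (L < 32 && (uint64_t(1) << L) < d) L++;
        const uint32_t mp = (uint32_t) (((uint64_t(1) << 32) * ((uint64_t(1) << L) - d)) / d + 1);
        return { mp, L, d };
    }

    static uint32_t example_fastdiv(uint32_t n, fastdiv_vals f) {
        // high 32 bits of n*mp, plus n, shifted: a multiply instead of a hardware divide
        const uint32_t hi = (uint32_t) (((uint64_t) n * f.mp) >> 32);
        return (hi + n) >> f.L;
    }

    static uint32_t example_fastmodulo(uint32_t n, fastdiv_vals f) {
        return n - example_fastdiv(n, f) * f.d;
    }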
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 2e272da43..c205bb0f9 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -550,6 +550,31 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
#endif // defined(GGML_USE_HIP)
}
+static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) {
+ acc += v*u;
+}
+
+static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v, const float2 u) {
+ acc += v.x*u.x;
+ acc += v.y*u.y;
+}
+
+static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
+#if defined(GGML_USE_HIP) && defined(GCN)
+ asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
+#else
+#ifdef FAST_FP16_AVAILABLE
+ const float2 tmp = __half22float2(v*u);
+ acc += tmp.x + tmp.y;
+#else
+ const float2 tmpv = __half22float2(v);
+ const float2 tmpu = __half22float2(u);
+ acc += tmpv.x * tmpu.x;
+ acc += tmpv.y * tmpu.y;
+#endif // FAST_FP16_AVAILABLE
+#endif // defined(GGML_USE_HIP) && defined(GCN)
+}
+
static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#if CUDART_VERSION >= 12080
const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh
index c62e8a1b1..ef9e12995 100644
--- a/ggml/src/ggml-cuda/convert.cuh
+++ b/ggml/src/ggml-cuda/convert.cuh
@@ -38,6 +38,8 @@ template
return __float2bfloat16(float(x));
} else if constexpr(std::is_same_v) {
return __bfloat162float(x);
+ } else if constexpr(std::is_same_v<dst_t, int32_t>) {
+ return int32_t(x);
} else {
return float(x);
}
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index c40db08ce..8567c3d5a 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -374,6 +374,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
+ ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+ } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
+ ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu
index fb2163acd..64f7d4a1a 100644
--- a/ggml/src/ggml-cuda/fattn-tile.cu
+++ b/ggml/src/ggml-cuda/fattn-tile.cu
@@ -304,12 +304,7 @@ static __global__ void flash_attn_tile(
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
-#ifdef FAST_FP16_AVAILABLE
- const float2 tmp = __half22float2(K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps]);
- sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += tmp.x + tmp.y;
-#else
- sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps];
-#endif // FAST_FP16_AVAILABLE
+ ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size], Q_k[j_KQ_0/nwarps]);
}
}
}
diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu
index 83d02474f..2fab33243 100644
--- a/ggml/src/ggml-cuda/getrows.cu
+++ b/ggml/src/ggml-cuda/getrows.cu
@@ -2,39 +2,39 @@
#include "dequantize.cuh"
#include "convert.cuh"
-#define MAX_GRIDDIM_Y 65535
-
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void k_get_rows(
const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
- /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
+ /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
- for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
- // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
- const int i10 = blockIdx.x;
- const int i11 = blockIdx.z / ne12;
- const int i12 = blockIdx.z % ne12;
+ for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
+ for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
+ // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
+ const int i10 = blockIdx.x;
+ const int i11 = z / ne12; // TODO fastdiv
+ const int i12 = z % ne12;
- const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
- dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
- const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+ const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
- const int ib = i00/qk; // block index
- const int iqs = (i00%qk)/qr; // quant index
- const int iybs = i00 - i00%qk; // dst block start index
- const int y_offset = qr == 1 ? 1 : qk/2;
+ const int ib = i00/qk; // block index
+ const int iqs = (i00%qk)/qr; // quant index
+ const int iybs = i00 - i00%qk; // dst block start index
+ const int y_offset = qr == 1 ? 1 : qk/2;
- // dequantize
- float2 v;
- dequantize_kernel(src0_row, ib, iqs, v);
+ // dequantize
+ float2 v;
+ dequantize_kernel(src0_row, ib, iqs, v);
- dst_row[iybs + iqs + 0] = ggml_cuda_cast(v.x);
- dst_row[iybs + iqs + y_offset] = ggml_cuda_cast(v.y);
+ dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);
+ dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
+ }
}
}
@@ -42,27 +42,29 @@ template
static __global__ void k_get_rows_float(
const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
- /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
+ /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
- for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
- // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
- const int i10 = blockIdx.x;
- const int i11 = blockIdx.z / ne12;
- const int i12 = blockIdx.z % ne12;
+ for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
+ for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
+ // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
+ const int i10 = blockIdx.x;
+ const int i11 = z / ne12; // TODO fastdiv
+ const int i12 = z % ne12;
- if (i00 >= ne00) {
- return;
+ if (i00 >= ne00) {
+ return;
+ }
+
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+ const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
+
+ dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
}
-
- const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
- dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
- const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
-
- dst_row[i00] = ggml_cuda_cast(src0_row[i00]);
}
}
@@ -98,7 +100,7 @@ static void get_rows_cuda_q(
cudaStream_t stream) {
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
- const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12);
+ const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX));
// strides in elements
// const size_t s0 = nb0 / sizeof(dst_t);
@@ -116,7 +118,7 @@ static void get_rows_cuda_q(
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
src0_d, src1_d, dst_d,
ne00, /*ne01, ne02, ne03,*/
- /*ne10, ne11,*/ ne12, /*ne13,*/
+ /*ne10,*/ ne11, ne12, /*ne13,*/
/* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/);
@@ -131,7 +133,7 @@ static void get_rows_cuda_float(
cudaStream_t stream) {
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
- const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12);
+ const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX));
// strides in elements
// const size_t s0 = nb0 / sizeof(dst_t);
@@ -147,7 +149,7 @@ static void get_rows_cuda_float(
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
src0_d, src1_d, dst_d,
ne00, /*ne01, ne02, ne03,*/
- /*ne10, ne11,*/ ne12, /*ne13,*/
+ /*ne10,*/ ne11, ne12, /*ne13,*/
/* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index d9bd4a9f1..391ac4b7a 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2118,6 +2118,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
return;
}
+
+ if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2])) {
+ ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
+ return;
+ }
}
cudaStream_t stream = ctx.stream();
@@ -3148,6 +3153,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
/* .graph_compute = */ ggml_backend_cuda_graph_compute,
/* .event_record = */ ggml_backend_cuda_event_record,
/* .event_wait = */ ggml_backend_cuda_event_wait,
+ /* .optimize_graph = */ NULL,
};
static ggml_guid_t ggml_backend_cuda_guid() {
@@ -3474,6 +3480,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
return true;
}
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
return true;
}
@@ -3587,9 +3599,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_GROUP_NORM:
+ case GGML_OP_PAD:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_UPSCALE:
- case GGML_OP_PAD:
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index 667deb9c6..c1f24243f 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -1,3 +1,4 @@
+#pragma once
// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
// The documentation for the PTX instructions can be found under:
diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu
index cfa5c5cce..16331e9ec 100644
--- a/ggml/src/ggml-cuda/mmf.cu
+++ b/ggml/src/ggml-cuda/mmf.cu
@@ -1,343 +1,12 @@
#include "ggml.h"
-#include "common.cuh"
-#include "mma.cuh"
#include "mmf.cuh"
-using namespace ggml_cuda_mma;
-
-#define MMF_ROWS_PER_BLOCK 32
-
-template
-__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
-static __global__ void mul_mat_f(
- const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
- const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst,
- const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
- const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
- typedef tile<16, 8, T> tile_A;
- typedef tile< 8, 8, T> tile_B;
- typedef tile<16, 8, float> tile_C;
-
- constexpr int warp_size = ggml_cuda_get_physical_warp_size();
- constexpr int tile_k_padded = warp_size + 4;
- constexpr int ntA = rows_per_block / tile_A::I;
- constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
-
- const int row0 = blockIdx.x * rows_per_block;
- const int channel_dst = blockIdx.y;
- const int channel_x = channel_dst / channel_ratio;
- const int channel_y = channel_dst;
- const int sample_dst = blockIdx.z;
- const int sample_x = sample_dst / sample_ratio;
- const int sample_y = sample_dst;
-
- x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ;
- y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
- dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
-
- const float2 * y2 = (const float2 *) y;
-
- extern __shared__ char data_mmv[];
-
- tile_C C[ntA][ntB];
-
- T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded);
-
- for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
- tile_A A[ntA][warp_size / tile_A::J];
-#pragma unroll
- for (int itA = 0; itA < ntA; ++itA) {
-#pragma unroll
- for (int i = 0; i < tile_A::I; ++i) {
- tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
- }
-#pragma unroll
- for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
- load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
- }
- }
-
-#pragma unroll
- for (int itB = 0; itB < ntB; ++itB) {
- if constexpr (std::is_same_v) {
-#pragma unroll
- for (int j0 = 0; j0 < tile_B::I; ++j0) {
- const int j = j0 + itB*tile_B::I;
-
- tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
- }
- } else if constexpr (std::is_same_v || std::is_same_v) {
-#pragma unroll
- for (int j0 = 0; j0 < tile_B::I; ++j0) {
- const int j = j0 + itB*tile_B::I;
-
- const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
- tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
- }
- } else {
- static_assert(std::is_same_v, "unsupported type");
- }
-#pragma unroll
- for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
- tile_B B;
- load_ldmatrix(B, tile_xy + k0, tile_k_padded);
-#pragma unroll
- for (int itA = 0; itA < ntA; ++itA) {
- mma(C[itA][itB], A[itA][k0/tile_B::J], B);
- }
- }
- }
- }
-
- float * buf_iw = (float *) data_mmv;
- constexpr int kiw = nwarps*rows_per_block + 4;
-
- if (nwarps > 1) {
- __syncthreads();
- }
-#pragma unroll
- for (int itB = 0; itB < ntB; ++itB) {
-#pragma unroll
- for (int itA = 0; itA < ntA; ++itA) {
-#pragma unroll
- for (int l = 0; l < tile_C::ne; ++l) {
- const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
- const int j = itB*tile_C::J + tile_C::get_j(l);
- buf_iw[j*kiw + i] = C[itA][itB].x[l];
- }
- }
- }
-
- if (nwarps > 1) {
- __syncthreads();
- }
-
-#pragma unroll
- for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
- const int j = j0 + threadIdx.y;
-
- if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
- return;
- }
-
- float sum = 0.0f;
- static_assert(rows_per_block == warp_size, "need loop/check");
-#pragma unroll
- for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
- const int i = i0 + threadIdx.x;
-
- sum += buf_iw[j*kiw + i];
- }
- dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
- }
-#else
- GGML_UNUSED_VARS(x, y, ids, dst,
- ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
- NO_DEVICE_CODE;
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
-}
-
-template
-static void mul_mat_f_cuda(
- const T * x, const float * y, const int32_t * ids, float * dst,
- const int64_t ncols_x, const int64_t nrows_x,
- const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
- const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
- const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
- const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
- cudaStream_t stream) {
- typedef tile<16, 8, T> tile_A;
- typedef tile< 8, 8, T> tile_B;
-
- GGML_ASSERT(!ids && "mul_mat_id not implemented");
-
- GGML_ASSERT(ncols_x % 2 == 0);
- GGML_ASSERT(stride_row % 2 == 0);
- GGML_ASSERT(stride_col_y % 2 == 0);
- GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
- GGML_ASSERT( nsamples_dst % nsamples_x == 0);
- const int64_t channel_ratio = nchannels_dst / nchannels_x;
- const int64_t sample_ratio = nsamples_dst / nsamples_x;
-
- const int device = ggml_cuda_get_device();
- const int warp_size = ggml_cuda_info().devices[device].warp_size;
-
- int64_t nwarps_best = 1;
- int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
- int64_t max_block_size = 256;
- for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
- const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
- if (niter < niter_best) {
- niter_best = niter;
- nwarps_best = nwarps;
- }
- }
-
- constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
- const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
- const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
- const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
- const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst);
- const dim3 block_dims(warp_size, nwarps_best, 1);
- switch (nwarps_best) {
- case 1: {
- mul_mat_f<<>>
- (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
- } break;
- case 2: {
- mul_mat_f<<>>
- (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
- } break;
- case 3: {
- mul_mat_f<<>>
- (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
- } break;
- case 4: {
- mul_mat_f<<>>
- (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
- } break;
- case 5: {
- mul_mat_f<<>>
- (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
- } break;
- case 6: {
- mul_mat_f<<>>
- (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
- } break;
- case 7: {
- mul_mat_f<<>>
- (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
- } break;
- case 8: {
- mul_mat_f<<>>
- (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
- } break;
- default: {
- GGML_ABORT("fatal error");
- } break;
- }
-}
-
-template
-static void mul_mat_f_switch_cols_per_block(
- const T * x, const float * y, const int32_t * ids, float * dst,
- const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
- const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
- const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
- const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
- const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
- cudaStream_t stream) {
- switch (ncols_dst) {
- case 1: {
- mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 2: {
- mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 3: {
- mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 4: {
- mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 5: {
- mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 6: {
- mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 7: {
- mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 8: {
- mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 9: {
- mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 10: {
- mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 11: {
- mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 12: {
- mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 13: {
- mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 14: {
- mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 15: {
- mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 16: {
- mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- default: {
- GGML_ABORT("fatal error");
- } break;
- }
-}
-
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
GGML_TENSOR_BINARY_OP_LOCALS;
const size_t ts_src0 = ggml_type_size(src0->type);
@@ -365,55 +34,72 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
const int64_t s13 = src1->nb[3] / ts_src1;
const int64_t s3 = dst->nb[3] / ts_dst;
+ const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
+ const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
+
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
const int64_t ncols_dst = ids ? ne2 : ne1;
- const int64_t nchannels_y = ids ? ne11 : ne12;
- const int64_t nchannels_dst = ids ? ne1 : ne2;
- const int64_t stride_channel_dst = ids ? s1 : s2;
- const int64_t stride_channel_y = ids ? s11 : s12;
+ const int64_t nchannels_dst = ids ? ne1 : ne2;
- GGML_ASSERT(!ids || ncols_dst == 1);
+ const int64_t stride_col_dst = ids ? s2 : s1;
+ const int64_t stride_col_y = ids ? s12 : s11;
+ const int64_t stride_channel_dst = ids ? s1 : s2;
+
+ int64_t stride_channel_y = ids ? s11 : s12;
+ int64_t nchannels_y = ids ? ne11 : ne12;
+
+ // mul_mat_id: handle broadcast of a single src1 channel across the selected experts
+ if (ids && nchannels_y == 1) {
+ stride_channel_y = 0;
+ nchannels_y = ids->ne[0];
+ }
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
constexpr int vals_per_T = 1;
mul_mat_f_switch_cols_per_block(
- src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
- ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
- ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+ src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
+ ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+ ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break;
case GGML_TYPE_F16: {
const half2 * src0_d = (const half2 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
- src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
- ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
- ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+ src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
+ ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+ ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break;
case GGML_TYPE_BF16: {
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
- src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
- ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
- ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+ src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
+ ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+ ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break;
default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
}
}
-bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) {
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols) {
+
+ if (ggml_is_quantized(type)) {
+ return false;
+ }
+
if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
return false;
}
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
return false;
}
- if (ne11 > 16) {
+ if (src1_ncols > 16) {
return false;
}
+
switch (type) {
case GGML_TYPE_F32:
return ampere_mma_available(cc);
diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
index 785f9f211..bf724bc57 100644
--- a/ggml/src/ggml-cuda/mmf.cuh
+++ b/ggml/src/ggml-cuda/mmf.cuh
@@ -1,5 +1,473 @@
+#pragma once
+
+#include "mma.cuh"
#include "common.cuh"
+using namespace ggml_cuda_mma;
+
+#define MMF_ROWS_PER_BLOCK 32
+
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
-bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, int64_t ne11);
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols);
+
+template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
+__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
+static __global__ void mul_mat_f(
+ const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
+ const int ncols, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
+ const int stride_col_id, const int stride_row_id,
+ const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+ const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+ typedef tile<16, 8, T> tile_A;
+ typedef tile< 8, 8, T> tile_B;
+ typedef tile<16, 8, float> tile_C;
+
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+ constexpr int tile_k_padded = warp_size + 4;
+ constexpr int ntA = rows_per_block / tile_A::I;
+ constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
+
+ const int row0 = blockIdx.x * rows_per_block;
+
+ const int expert_idx = has_ids ? blockIdx.y : 0;
+ const int channel_dst = has_ids ? 0 : blockIdx.y;
+
+ const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio);
+ const int channel_y = channel_dst;
+ const int sample_dst = blockIdx.z;
+ const int sample_x = sample_dst / sample_ratio;
+ const int sample_y = sample_dst;
+
+ x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ;
+ y += int64_t(sample_y) *stride_sample_y + (has_ids ? 0 : channel_y *stride_channel_y);
+ dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);
+
+ const float2 * y2 = (const float2 *) y;
+
+ extern __shared__ char data_mmv[];
+
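+ // dynamic shared memory layout: when ids are present, the first GGML_PAD(cols_per_block, 16) ints hold the column -> expert-slot map, followed by the compute scratch buffer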
+ char * shmem_base = data_mmv;
+ int * slot_map = (int *) shmem_base;
+ char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base;
+
+ tile_C C[ntA][ntB];
+
+ T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
+
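+ // mul_mat_id: for each output column j, find the slot in ids that routes to this block's expert; if no column uses this expert, the whole block can exit early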
+ if constexpr (has_ids) {
+ __shared__ int has_any;
+ if (threadIdx.y == 0) {
+ int local_has_any = 0;
+ for (int j = threadIdx.x; j < cols_per_block; j += warp_size) {
+ int slot = -1;
+ for (int k = 0; k < nchannels_dst; ++k) {
+ const int idv = ids[j*stride_row_id + k*stride_col_id];
+ if (idv == expert_idx) {
+ slot = k;
+ break;
+ }
+ }
+ if (j < cols_per_block) {
+ local_has_any |= (slot >= 0);
+ slot_map[j] = slot;
+ }
+ }
+ has_any = warp_reduce_any(local_has_any);
+ }
+ __syncthreads();
+ if (has_any == 0) {
+ return;
+ }
+ }
+
+ for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
+ tile_A A[ntA][warp_size / tile_A::J];
+#pragma unroll
+ for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+ for (int i = 0; i < tile_A::I; ++i) {
+ tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
+ }
+#pragma unroll
+ for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
+ load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
+ }
+ }
+
+#pragma unroll
+ for (int itB = 0; itB < ntB; ++itB) {
+ if constexpr (std::is_same_v<T, float>) {
+#pragma unroll
+ for (int j0 = 0; j0 < tile_B::I; ++j0) {
+ const int j = j0 + itB*tile_B::I;
+
+ if constexpr (!has_ids) {
+ tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
+ } else {
+ float val = 0.0f;
+ if (j < cols_per_block) {
+ const int slot = slot_map[j];
+ if (slot >= 0) {
+ val = y[slot*stride_channel_y + j*stride_col_y + col];
+ }
+ }
+ tile_xy[j0*tile_k_padded + threadIdx.x] = val;
+ }
+ }
+ } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
+#pragma unroll
+ for (int j0 = 0; j0 < tile_B::I; ++j0) {
+ const int j = j0 + itB*tile_B::I;
+
+ if constexpr (!has_ids) {
+ const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
+ tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+ } else {
+ float2 tmp = make_float2(0.0f, 0.0f);
+ if (j < cols_per_block) {
+ const int slot = slot_map[j];
+ if (slot >= 0) {
+ const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y);
+ tmp = y2_slot[j*stride_col_y + col];
+ }
+ }
+ tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+ }
+ }
+ } else {
+ static_assert(std::is_same_v<T, void>, "unsupported type");
+ }
+#pragma unroll
+ for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
+ tile_B B;
+ load_ldmatrix(B, tile_xy + k0, tile_k_padded);
+#pragma unroll
+ for (int itA = 0; itA < ntA; ++itA) {
+ mma(C[itA][itB], A[itA][k0/tile_B::J], B);
+ }
+ }
+ }
+ }
+
+ float * buf_iw = (float *) compute_base;
+ constexpr int kiw = nwarps*rows_per_block + 4;
+
+ if (nwarps > 1) {
+ __syncthreads();
+ }
+#pragma unroll
+ for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+ for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
+ const int j = itB*tile_C::J + tile_C::get_j(l);
+ buf_iw[j*kiw + i] = C[itA][itB].x[l];
+ }
+ }
+ }
+
+ if (nwarps > 1) {
+ __syncthreads();
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
+ return;
+ }
+
+ float sum = 0.0f;
+ static_assert(rows_per_block == warp_size, "need loop/check");
+#pragma unroll
+ for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
+ const int i = i0 + threadIdx.x;
+
+ sum += buf_iw[j*kiw + i];
+ }
+
+ if constexpr (!has_ids) {
+ dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
+ } else {
+ const int slot = (j < cols_per_block) ? slot_map[j] : -1;
+ if (slot >= 0) {
+ dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
+ }
+ }
+ }
+#else
+ GGML_UNUSED_VARS(x, y, ids, dst,
+ ncols, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ NO_DEVICE_CODE;
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+}
+
+template <typename T, int cols_per_block, int nwarps>
+static inline void mul_mat_f_switch_ids(
+ const T * x, const float * y, const int32_t * ids, float * dst,
+ const int64_t ncols_x, const int64_t nchannels_dst,
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+ const int64_t stride_col_id, const int64_t stride_row_id,
+ const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+ const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+ const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
+ if (ids) {
+ mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+ (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ } else {
+ mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+ (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ }
+}
+
+template <typename T, int cols_per_block>
+void mul_mat_f_cuda(
+ const T * x, const float * y, const int32_t * ids, float * dst,
+ const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+ const int64_t stride_col_id, const int64_t stride_row_id,
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+ cudaStream_t stream) {
+ typedef tile<16, 8, T> tile_A;
+ typedef tile< 8, 8, T> tile_B;
+
+ GGML_ASSERT(ncols_x % 2 == 0);
+ GGML_ASSERT(stride_row % 2 == 0);
+ GGML_ASSERT(stride_col_y % 2 == 0);
+ GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
+ GGML_ASSERT( nsamples_dst % nsamples_x == 0);
+ const int64_t channel_ratio = nchannels_dst / nchannels_x;
+ const int64_t sample_ratio = nsamples_dst / nsamples_x;
+
+ const int device = ggml_cuda_get_device();
+ const int warp_size = ggml_cuda_info().devices[device].warp_size;
+
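+ // pick the warp count (up to 256 threads per block) that minimizes the number of iterations over ncols_x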
+ int64_t nwarps_best = 1;
+ int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
+ int64_t max_block_size = 256;
+ for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
+ const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
+ if (niter < niter_best) {
+ niter_best = niter;
+ nwarps_best = nwarps;
+ }
+ }
+
+ constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
+ const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
+ const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
+ const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
+ const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
+ const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
+ const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present
+
+ const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
+ const dim3 block_dims(warp_size, nwarps_best, 1);
+
+ switch (nwarps_best) {
+ case 1: {
+ mul_mat_f_switch_ids<T, cols_per_block, 1>(
+ x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ } break;
+ case 2: {
+ mul_mat_f_switch_ids<T, cols_per_block, 2>(
+ x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ } break;
+ case 3: {
+ mul_mat_f_switch_ids<T, cols_per_block, 3>(
+ x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ } break;
+ case 4: {
+ mul_mat_f_switch_ids<T, cols_per_block, 4>(
+ x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ } break;
+ case 5: {
+ mul_mat_f_switch_ids<T, cols_per_block, 5>(
+ x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ } break;
+ case 6: {
+ mul_mat_f_switch_ids<T, cols_per_block, 6>(
+ x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ } break;
+ case 7: {
+ mul_mat_f_switch_ids<T, cols_per_block, 7>(
+ x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ } break;
+ case 8: {
+ mul_mat_f_switch_ids<T, cols_per_block, 8>(
+ x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ } break;
+ default: {
+ GGML_ABORT("fatal error");
+ } break;
+ }
+
+ GGML_UNUSED_VARS(nchannels_y);
+}
+
+template <typename T>
+static void mul_mat_f_switch_cols_per_block(
+ const T * x, const float * y, const int32_t * ids, float * dst,
+ const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+ const int64_t stride_col_id, const int stride_row_id,
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+ cudaStream_t stream) {
+ switch (ncols_dst) {
+ case 1: {
+ mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 2: {
+ mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 3: {
+ mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 4: {
+ mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 5: {
+ mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 6: {
+ mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 7: {
+ mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 8: {
+ mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 9: {
+ mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 10: {
+ mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 11: {
+ mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 12: {
+ mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 13: {
+ mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 14: {
+ mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 15: {
+ mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 16: {
+ mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ default: {
+ GGML_ABORT("fatal error");
+ } break;
+ }
+}
+
+#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
+ template void mul_mat_f_cuda<T, ncols_dst>( \
+ const T * x, const float * y, const int32_t * ids, float * dst, \
+ const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
+ const int64_t stride_col_id, const int64_t stride_row_id, \
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
+ cudaStream_t stream);
+
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#define DECL_MMF_CASE_EXTERN(ncols_dst) \
+ extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
+ extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
+ extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
+
+#define DECL_MMF_CASE(ncols_dst) \
+ DECL_MMF_CASE_HELPER(float, ncols_dst) \
+ DECL_MMF_CASE_HELPER(half2, ncols_dst) \
+ DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
+
+DECL_MMF_CASE_EXTERN(1);
+DECL_MMF_CASE_EXTERN(2);
+DECL_MMF_CASE_EXTERN(3);
+DECL_MMF_CASE_EXTERN(4);
+DECL_MMF_CASE_EXTERN(5);
+DECL_MMF_CASE_EXTERN(6);
+DECL_MMF_CASE_EXTERN(7);
+DECL_MMF_CASE_EXTERN(8);
+DECL_MMF_CASE_EXTERN(9);
+DECL_MMF_CASE_EXTERN(10);
+DECL_MMF_CASE_EXTERN(11);
+DECL_MMF_CASE_EXTERN(12);
+DECL_MMF_CASE_EXTERN(13);
+DECL_MMF_CASE_EXTERN(14);
+DECL_MMF_CASE_EXTERN(15);
+DECL_MMF_CASE_EXTERN(16);
+#else
+#define DECL_MMF_CASE(ncols_dst)
+#endif
diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
index 3428113dc..da2d7b7c3 100755
--- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
+++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -24,7 +24,7 @@ TYPES_MMQ = [
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
- "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
+ "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4"
]
SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
@@ -34,6 +34,13 @@ SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do
DECL_MMQ_CASE({type});
"""
+SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE({type});
+"""
+
def get_short_name(long_quant_name):
return long_quant_name.replace("GGML_TYPE_", "").lower()
@@ -76,3 +83,7 @@ for ncols in [8, 16, 32, 64]:
for type in TYPES_MMQ:
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
f.write(SOURCE_MMQ.format(type=type))
+
+for type in range(1, 17):
+ with open(f"mmf-instance-ncols_{type}.cu", "w") as f:
+ f.write(SOURCE_MMF.format(type=type))
diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu
new file mode 100644
index 000000000..f594d5d51
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(1);
diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu
new file mode 100644
index 000000000..9cc677254
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(10);
diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu
new file mode 100644
index 000000000..317f487d7
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(11);
diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu
new file mode 100644
index 000000000..dc0033227
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(12);
diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu
new file mode 100644
index 000000000..078210175
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(13);
diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu
new file mode 100644
index 000000000..a23ad6ae2
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(14);
diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu
new file mode 100644
index 000000000..0fe3f7821
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(15);
diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu
new file mode 100644
index 000000000..544086375
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(16);
diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu
new file mode 100644
index 000000000..3b901797c
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(2);
diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu
new file mode 100644
index 000000000..56e940bba
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(3);
diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu
new file mode 100644
index 000000000..a7665d49d
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(4);
diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu
new file mode 100644
index 000000000..3a1dff258
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(5);
diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu
new file mode 100644
index 000000000..400fb7c66
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(6);
diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu
new file mode 100644
index 000000000..954a1c7e0
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(7);
diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu
new file mode 100644
index 000000000..f1bd09c94
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(8);
diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu
new file mode 100644
index 000000000..1255ac2af
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(9);
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index b9d363944..651943fa9 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -20,8 +20,8 @@
#define N_R0_Q5_1 4
#define N_SG_Q5_1 2
-#define N_R0_Q8_0 4
-#define N_SG_Q8_0 2
+#define N_R0_Q8_0 2
+#define N_SG_Q8_0 4
#define N_R0_MXFP4 2
#define N_SG_MXFP4 2
@@ -68,6 +68,11 @@
#define N_R0_IQ4_XS 2
#define N_SG_IQ4_XS 2
+// function constants offsets
+#define FC_FLASH_ATTN_EXT 100
+#define FC_FLASH_ATTN_EXT_VEC 200
+#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
+
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
@@ -236,9 +241,11 @@ typedef struct {
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
+ int32_t ns10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
+ int32_t ns20;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
@@ -258,10 +265,43 @@ typedef struct {
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
+typedef struct {
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne11;
+ int32_t ne_12_2; // assume K and V are same shape
+ int32_t ne_12_3;
+ int32_t ns10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ int32_t ns20;
+ uint64_t nb21;
+ uint64_t nb22;
+ uint64_t nb23;
+ int32_t ne32;
+ int32_t ne33;
+ uint64_t nb31;
+ uint64_t nb32;
+ uint64_t nb33;
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ float scale;
+ float max_bias;
+ float m0;
+ float m1;
+ int32_t n_head_log2;
+ float logit_softcap;
+} ggml_metal_kargs_flash_attn_ext_vec;
+
typedef struct {
int32_t nrows;
- int32_t ne20;
-} ggml_metal_kargs_flash_attn_ext_reduce;
+} ggml_metal_kargs_flash_attn_ext_vec_reduce;
typedef struct {
int32_t ne00;
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 132feb15b..7adb1b2a1 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -48,6 +48,11 @@ static struct ggml_backend_metal_device_context {
int mtl_device_ref_count;
id<MTLLibrary> mtl_library;
+ // a single global queue shared by all Metal backends
+ // technically not needed for devices with unified memory, but enables support for discrete GPUs
+ // ref: https://github.com/ggml-org/llama.cpp/pull/15906
+ id<MTLCommandQueue> mtl_queue;
+
NSLock * mtl_lock;
bool has_simdgroup_reduction;
@@ -56,6 +61,7 @@ static struct ggml_backend_metal_device_context {
bool has_bfloat;
bool use_bfloat;
bool use_fusion;
+ bool use_shared_buffers;
int debug_fusion;
@@ -69,6 +75,7 @@ static struct ggml_backend_metal_device_context {
/*.mtl_device =*/ nil,
/*.mtl_device_ref_count =*/ 0,
/*.mtl_library =*/ nil,
+ /*.mtl_queue =*/ nil,
/*.mtl_lock =*/ nil,
/*.has_simdgroup_reduction =*/ false,
/*.has_simdgroup_mm =*/ false,
@@ -76,6 +83,7 @@ static struct ggml_backend_metal_device_context {
/*.has_bfloat =*/ false,
/*.use_bfloat =*/ false,
/*.use_fusion =*/ true,
+ /*.use_shared_buffers =*/ true,
/*.debug_fusion =*/ 0,
/*.fuse_cnt =*/ { 0 },
/*.max_size =*/ 0,
@@ -94,6 +102,11 @@ static id ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
ctx->mtl_device = MTLCreateSystemDefaultDevice();
if (ctx->mtl_device) {
+ ctx->mtl_queue = [ctx->mtl_device newCommandQueue];
+ if (ctx->mtl_queue == nil) {
+ GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
+ }
+
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
@@ -118,6 +131,12 @@ static id ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
ctx->debug_fusion = val ? atoi(val) : 0;
}
+ ctx->use_shared_buffers = ctx->mtl_device.hasUnifiedMemory;
+
+ if (getenv("GGML_METAL_SHARED_BUFFERS_DISABLE") != NULL) {
+ ctx->use_shared_buffers = false;
+ }
+
memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
ctx->max_size = ctx->mtl_device.maxBufferLength;
@@ -161,6 +180,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
ctx->mtl_library = nil;
}
+ if (ctx->mtl_queue) {
+ [ctx->mtl_queue release];
+ ctx->mtl_queue = nil;
+ }
+
if (ctx->mtl_device) {
[ctx->mtl_device release];
ctx->mtl_device = nil;
@@ -174,6 +198,19 @@ struct ggml_metal_kernel {
id<MTLComputePipelineState> pipeline;
};
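+// wrapper object so that inference-time compiled kernels can be stored in an NSMutableDictionary and their pipeline released on dealloc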
+@interface ggml_metal_kernel_wrapper : NSObject
+
+@property (nonatomic, assign) struct ggml_metal_kernel kernel;
+
+@end
+
+@implementation ggml_metal_kernel_wrapper
+- (void) dealloc {
+ [_kernel.pipeline release];
+ [super dealloc];
+}
+@end
+
enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ADD,
GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
@@ -454,128 +491,6 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE,
- GGML_METAL_KERNEL_TYPE_SET_I32,
- GGML_METAL_KERNEL_TYPE_SET_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
@@ -583,6 +498,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_I32,
+ GGML_METAL_KERNEL_TYPE_CPY_I32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
@@ -878,12 +795,16 @@ struct ggml_metal_command_buffer {
struct ggml_backend_metal_context {
id<MTLDevice> device;
- id<MTLCommandQueue> queue;
+ id<MTLCommandQueue> queue; // currently a pointer to the device queue, but might become separate queue [TAG_QUEUE_PER_BACKEND]
dispatch_queue_t d_queue;
+ // the set of pre-compiled kernels for this context
struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
+ // additional, inference-time compiled kernels
+ NSMutableDictionary * kernels_ext;
+
// capture state
bool capture_next_compute;
bool capture_started;
@@ -904,6 +825,12 @@ struct ggml_backend_metal_context {
// n_cb command buffers + 1 used by the main thread
struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
+ // extra command buffers for things like getting, setting and copying tensors
+ NSMutableArray * cmd_bufs_ext;
+
+ // the last command buffer queued into the Metal queue with operations relevant to the current Metal backend
+ id<MTLCommandBuffer> cmd_buf_last;
+
// abort ggml_metal_graph_compute if callback returns true
ggml_abort_callback abort_callback;
void * abort_callback_data;
@@ -949,6 +876,8 @@ static void * ggml_metal_host_malloc(size_t n) {
// - if not found, load the source and compile it
// - if that fails, return NULL
static id<MTLLibrary> ggml_metal_load_library(id<MTLDevice> device, bool use_bfloat) {
+ const int64_t t_start = ggml_time_us();
+
id metal_library = nil;
NSError * error = nil;
NSString * src = nil;
@@ -1072,6 +1001,8 @@ static id ggml_metal_load_library(id device, bool use_bfl
[src release];
#endif // GGML_METAL_EMBED_LIBRARY
+ GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
+
return metal_library;
}
@@ -1096,7 +1027,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
ctx->device = device;
- ctx->queue = [device newCommandQueue];
+
+ // TODO: question - would it be better to have one queue for the backend and one queue for the device?
+ // the graph encoders and async ops would use the backend queue while the sync ops would use the device queue?
+ //ctx->queue = [device newCommandQueue]; [TAG_QUEUE_PER_BACKEND]
+ ctx->queue = ctx_dev->mtl_queue;
if (ctx->queue == nil) {
GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
return NULL;
@@ -1155,6 +1090,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_LOG_INFO("%s: has residency sets = %s\n", __func__, ctx_dev->has_residency_sets ? "true" : "false");
GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
+ GGML_LOG_INFO("%s: use fusion = %s\n", __func__, ctx_dev->use_fusion ? "true" : "false");
+ GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, ctx_dev->use_shared_buffers ? "true" : "false");
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
ctx->capture_next_compute = false;
@@ -1170,6 +1107,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
ctx->cmd_bufs[i].mem_pool->device = device;
}
+ ctx->cmd_bufs_ext = [[NSMutableArray alloc] init];
+
+ ctx->cmd_buf_last = nil;
+
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
if (@available(macOS 10.12, iOS 16.0, *)) {
GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6);
@@ -1269,7 +1210,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
@@ -1487,128 +1428,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40, flash_attn_ext_f16_h40, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40, flash_attn_ext_bf16_h40, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40, flash_attn_ext_q4_0_h40, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40, flash_attn_ext_q4_1_h40, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40, flash_attn_ext_q5_0_h40, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40, flash_attn_ext_q5_1_h40, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40, flash_attn_ext_q8_0_h40, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, flash_attn_ext_vec_f16_h192, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, flash_attn_ext_vec_bf16_h192, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, flash_attn_ext_vec_q4_0_h192, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, flash_attn_ext_vec_q4_1_h192, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, flash_attn_ext_vec_q5_0_h192, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, flash_attn_ext_vec_q5_1_h192, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, flash_attn_ext_vec_q8_0_h192, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, flash_attn_ext_vec_f16_hk192_hv128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, flash_attn_ext_vec_bf16_hk192_hv128, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, flash_attn_ext_vec_q4_0_hk192_hv128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, flash_attn_ext_vec_q4_1_hk192_hv128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, flash_attn_ext_vec_q5_0_hk192_hv128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, flash_attn_ext_vec_q5_1_hk192_hv128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, flash_attn_ext_vec_q8_0_hk192_hv128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, flash_attn_ext_reduce, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
@@ -1616,6 +1435,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_I32, cpy_f32_i32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_I32_F32, cpy_i32_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
@@ -1651,9 +1472,219 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
}
+ ctx->kernels_ext = [[NSMutableDictionary alloc] init];
+
return ctx;
}
+static id<MTLComputePipelineState> ggml_metal_get_kernel(struct ggml_backend_metal_context * ctx, const char * name) {
+ NSString * key = [NSString stringWithUTF8String:name];
+
+ ggml_metal_kernel_wrapper * obj = [ctx->kernels_ext objectForKey:key];
+ if (obj) {
+ return obj.kernel.pipeline;
+ }
+
+ return nil;
+}
+
+static id<MTLComputePipelineState> ggml_metal_compile_kernel(ggml_backend_t backend, const char * base, const char * name, MTLFunctionConstantValues * cv) {
+ struct ggml_backend_metal_context * ctx = backend->context;
+ struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+ id<MTLComputePipelineState> res = nil;
+
+ @autoreleasepool {
+ NSError * error = nil;
+
+ NSString * base_func = [NSString stringWithUTF8String:base];
+
+ GGML_LOG_DEBUG("%s: compiling kernel: base = '%s', name = '%s'\n", __func__, base, name);
+
+ // TODO: make sure it is thread-safe to compile kernels in parallel
+ id<MTLFunction> metal_function = [ctx_dev->mtl_library newFunctionWithName:base_func constantValues:cv error:&error];
+ if (!metal_function) {
+ GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+
+ return nil;
+ }
+
+ struct ggml_metal_kernel kernel = {
+ /*.pipeline =*/ [ctx_dev->mtl_device newComputePipelineStateWithFunction:metal_function error:&error],
+ };
+
+ ggml_metal_kernel_wrapper * obj = [[ggml_metal_kernel_wrapper alloc] init];
+ obj.kernel = kernel;
+
+ res = obj.kernel.pipeline;
+
+ NSString * key = [NSString stringWithUTF8String:name];
+ [ctx->kernels_ext setObject:obj forKey:key];
+
+ GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) kernel.pipeline,
+ (int) kernel.pipeline.maxTotalThreadsPerThreadgroup,
+ (int) kernel.pipeline.threadExecutionWidth);
+ }
+
+ return res;
+}
+
+static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext(
+ ggml_backend_t backend, struct ggml_tensor * op,
+ bool has_mask,
+ bool has_sinks,
+ bool has_bias,
+ bool has_scap,
+ int32_t nsg) {
+ struct ggml_backend_metal_context * ctx = backend->context;
+
+ char base[256];
+ char name[256];
+
+ @autoreleasepool {
+ MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];
+
+ const int32_t dk = (int32_t) op->src[1]->ne[0];
+ const int32_t dv = (int32_t) op->src[2]->ne[0];
+
+ const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
+ const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
+
+ snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
+ "flash_attn_ext",
+ ggml_type_name(op->src[1]->type),
+ dk,
+ dv);
+
+ snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
+ "flash_attn_ext",
+ ggml_type_name(op->src[1]->type),
+ dk,
+ dv,
+ has_mask,
+ has_sinks,
+ has_bias,
+ has_scap,
+ ns10,
+ ns20,
+ nsg);
+
+ id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name);
+ if (res) {
+ // kernel found
+ return res;
+ }
+
+ cv = [[MTLFunctionConstantValues alloc] init];
+
+ [cv setConstantValue:&has_mask type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 0];
+ [cv setConstantValue:&has_sinks type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 1];
+ [cv setConstantValue:&has_bias type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 2];
+ [cv setConstantValue:&has_scap type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 3];
+
+ [cv setConstantValue:&ns10 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 20];
+ [cv setConstantValue:&ns20 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 21];
+ [cv setConstantValue:&nsg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 22];
+
+ return ggml_metal_compile_kernel(backend, base, name, cv);
+ }
+}
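For reference, a worked example of the naming scheme used above (the concrete values are illustrative, not taken from a real run): with f16 K/V data, dk = dv = 128, a mask present, no sinks, no bias, no softcap, contiguous K/V rows (ns10 = ns20 = 128) and nsg = 4, the two snprintf calls produce

    base = "kernel_flash_attn_ext_f16_dk128_dv128"
    name = "kernel_flash_attn_ext_f16_dk128_dv128_mask=1_sinks=0_bias=0_scap=0_ns10=128_ns20=128_nsg=4"

The base string selects the Metal function to specialize via function constants, while the full name keys the compiled pipeline in ctx->kernels_ext.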
+
+static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext_vec(
+ ggml_backend_t backend, struct ggml_tensor * op,
+ bool has_mask,
+ bool has_sinks,
+ bool has_bias,
+ bool has_scap,
+ int32_t nsg,
+ int32_t nwg) {
+ struct ggml_backend_metal_context * ctx = backend->context;
+
+ char base[256];
+ char name[256];
+
+ @autoreleasepool {
+ MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];
+
+ const int32_t dk = (int32_t) op->src[1]->ne[0];
+ const int32_t dv = (int32_t) op->src[2]->ne[0];
+
+ const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
+ const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
+
+ snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
+ "flash_attn_ext_vec",
+ ggml_type_name(op->src[1]->type),
+ dk,
+ dv);
+
+ snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
+ "flash_attn_ext_vec",
+ ggml_type_name(op->src[1]->type),
+ dk,
+ dv,
+ has_mask,
+ has_sinks,
+ has_bias,
+ has_scap,
+ ns10,
+ ns20,
+ nsg, nwg);
+
+ id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name);
+ if (res) {
+ // kernel found
+ return res;
+ }
+
+ cv = [[MTLFunctionConstantValues alloc] init];
+
+ [cv setConstantValue:&has_mask type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 0];
+ [cv setConstantValue:&has_sinks type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 1];
+ [cv setConstantValue:&has_bias type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 2];
+ [cv setConstantValue:&has_scap type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 3];
+
+ [cv setConstantValue:&ns10 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 20];
+ [cv setConstantValue:&ns20 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 21];
+ [cv setConstantValue:&nsg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 22];
+ [cv setConstantValue:&nwg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 23];
+
+ return ggml_metal_compile_kernel(backend, base, name, cv);
+ }
+}
+
+static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext_vec_reduce(
+ ggml_backend_t backend, struct ggml_tensor * op,
+ int32_t dv,
+ int32_t nwg) {
+ struct ggml_backend_metal_context * ctx = backend->context;
+
+ char base[256];
+ char name[256];
+
+ @autoreleasepool {
+ MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];
+
+ snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
+ snprintf(name, 256, "kernel_flash_attn_ext_vec_reduce_dv=%d_nwg=%d", dv, nwg);
+
+ id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name);
+ if (res) {
+ // kernel found
+ return res;
+ }
+
+ cv = [[MTLFunctionConstantValues alloc] init];
+
+ [cv setConstantValue:&dv type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC_REDUCE + 0];
+ [cv setConstantValue:&nwg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC_REDUCE + 1];
+
+ return ggml_metal_compile_kernel(backend, base, name, cv);
+ }
+
+ GGML_UNUSED(op);
+}
+
static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
GGML_LOG_INFO("%s: deallocating\n", __func__);
@@ -1661,16 +1692,26 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
[ctx->kernels[i].pipeline release];
}
+ if (ctx->kernels_ext) {
+ [ctx->kernels_ext release];
+ ctx->kernels_ext = nil;
+ }
+
Block_release(ctx->encode_async);
- [ctx->queue release];
+ //[ctx->queue release]; // [TAG_QUEUE_PER_BACKEND]
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
- // ctx->cmd_bufs[i].obj is auto released
+ if (ctx->cmd_bufs[i].obj) {
+ [ctx->cmd_bufs[i].obj release];
+ }
ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
}
+ [ctx->cmd_bufs_ext removeAllObjects];
+ [ctx->cmd_bufs_ext release];
+
dispatch_release(ctx->d_queue);
free(ctx);
@@ -1688,14 +1729,21 @@ struct ggml_backend_metal_buffer {
struct ggml_backend_metal_buffer_context {
void * all_data;
size_t all_size;
- bool owned;
+
+ // if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host
+ bool is_shared;
// multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
int n_buffers;
struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
// optional MTLResidencySet
+ // note: cannot explicitly use "id<MTLResidencySet>" here because it is not available on certain OSes
id rset;
+
+ // pointers to global device objects
+ id<MTLDevice> device;
+ id<MTLCommandQueue> queue;
};
// rset init
@@ -1761,7 +1809,7 @@ static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
// Metal buffer based on the host memory pointer
//
-static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) {
+static id<MTLBuffer> ggml_metal_get_buffer(const struct ggml_tensor * t, size_t * offs) {
//GGML_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
const int64_t tsize = ggml_nbytes(t);
@@ -1945,6 +1993,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_I32:
return true;
default:
return false;
@@ -1977,16 +2026,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
default:
return false;
}
- default:
- return false;
- };
- }
- case GGML_OP_SET:
- {
- switch (op->src[0]->type) {
- case GGML_TYPE_F32:
case GGML_TYPE_I32:
- return true;
+ return op->type == GGML_TYPE_F32;
default:
return false;
};
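As a quick illustration of what the relaxed GGML_OP_CPY support above enables (a sketch under assumed tensor sizes, not part of the patch): an I32 source can now be offloaded, but only when the destination is F32. The ctx0 context below is a hypothetical ggml_context created with the usual ggml API.

    // illustrative only: I32 -> F32 copy, dispatched through GGML_METAL_KERNEL_TYPE_CPY_I32_F32
    struct ggml_tensor * src = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 16);
    struct ggml_tensor * dst = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 16);
    struct ggml_tensor * cur = ggml_cpy(ctx0, src, dst);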
@@ -3765,6 +3806,7 @@ static int ggml_metal_encode_node(
{
nsg = N_SG_Q8_0;
nr0 = N_R0_Q8_0;
+ smem = 32*sizeof(float)*N_R0_Q8_0;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
} break;
case GGML_TYPE_MXFP4:
@@ -3901,7 +3943,12 @@ static int ggml_metal_encode_node(
if (smem > 0) {
[encoder setThreadgroupMemoryLength:smem atIndex:0];
}
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+
+ if (src0t == GGML_TYPE_Q8_0) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0 - 1)/(nr0), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ } else {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ }
}
} break;
case GGML_OP_MUL_MAT_ID:
@@ -4122,6 +4169,7 @@ static int ggml_metal_encode_node(
{
nsg = N_SG_Q8_0;
nr0 = N_R0_Q8_0;
+ smem = 32*sizeof(float)*N_R0_Q8_0;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
} break;
case GGML_TYPE_MXFP4:
@@ -4267,7 +4315,12 @@ static int ggml_metal_encode_node(
if (smem > 0) {
[encoder setThreadgroupMemoryLength:smem atIndex:0];
}
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+
+ if (src0t == GGML_TYPE_Q8_0) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ } else {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ }
}
} break;
case GGML_OP_GET_ROWS:
@@ -5118,6 +5171,7 @@ static int ggml_metal_encode_node(
float scale;
float max_bias;
float logit_softcap;
+
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
@@ -5126,398 +5180,24 @@ static int ggml_metal_encode_node(
scale /= logit_softcap;
}
+ const bool has_mask = src3 != NULL;
+ const bool has_sinks = src4 != NULL;
+ const bool has_bias = max_bias != 0.0f;
+ const bool has_scap = logit_softcap != 0.0f;
+
const uint32_t n_head = src0->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
- id<MTLComputePipelineState> pipeline = nil;
-
- bool use_vec_kernel = false;
+ GGML_ASSERT(ne01 < 65536);
// use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
- if (ne01 >= 20 || (ne00 == 40 || ne00 == 80 || ne00 == 112)) {
- switch (src1->type) {
- case GGML_TYPE_F16:
- {
- if (ne00 == 192 && ne20 == 128) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
- } else if (ne00 == 576 && ne20 == 512) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
- } else {
- switch (ne00) {
- case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40 ].pipeline; break;
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
- case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- }
- } break;
- case GGML_TYPE_BF16:
- {
- if (ne00 == 192 && ne20 == 128) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
- } else if (ne00 == 576 && ne20 == 512) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
- } else {
- switch (ne00) {
- case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40 ].pipeline; break;
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
- case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- }
- } break;
- case GGML_TYPE_Q4_0:
- {
- if (ne00 == 192 && ne20 == 128) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
- } else if (ne00 == 576 && ne20 == 512) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
- } else {
- switch (ne00) {
- case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40 ].pipeline; break;
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
- case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- }
- } break;
- case GGML_TYPE_Q4_1:
- {
- if (ne00 == 192 && ne20 == 128) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
- } else if (ne00 == 576 && ne20 == 512) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
- } else {
- switch (ne00) {
- case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40 ].pipeline; break;
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
- case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- }
- } break;
- case GGML_TYPE_Q5_0:
- {
- if (ne00 == 192 && ne20 == 128) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
- } else if (ne00 == 576 && ne20 == 512) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
- } else {
- switch (ne00) {
- case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40 ].pipeline; break;
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
- case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- }
- } break;
- case GGML_TYPE_Q5_1:
- {
- if (ne00 == 192 && ne20 == 128) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
- } else if (ne00 == 576 && ne20 == 512) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
- } else {
- switch (ne00) {
- case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40 ].pipeline; break;
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
- case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- }
- } break;
- case GGML_TYPE_Q8_0:
- {
- if (ne00 == 192 && ne20 == 128) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
- } else if (ne00 == 576 && ne20 == 512) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
- } else {
- switch (ne00) {
- case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40 ].pipeline; break;
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
- case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- }
- } break;
- default:
- {
- GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
- GGML_LOG_ERROR("add template specialization for this type\n");
- GGML_ABORT("add template specialization for this type");
- }
- }
- } else {
- use_vec_kernel = true;
-
- switch (ne00) {
- case 64:
- {
- switch (src1->type) {
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
- GGML_LOG_ERROR("add template specialization for this type\n");
- GGML_ABORT("add template specialization for this type");
- }
- }
- } break;
- case 96:
- {
- switch (src1->type) {
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
- GGML_LOG_ERROR("add template specialization for this type\n");
- GGML_ABORT("add template specialization for this type");
- }
- }
- } break;
- case 128:
- {
- switch (src1->type) {
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
- GGML_LOG_ERROR("add template specialization for this type\n");
- GGML_ABORT("add template specialization for this type");
- }
- }
- } break;
- case 192:
- {
- if (ne20 == 128) {
- switch (src1->type) {
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
- GGML_LOG_ERROR("add template specialization for this type\n");
- GGML_ABORT("add template specialization for this type");
- }
- }
- } else {
- switch (src1->type) {
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
- GGML_LOG_ERROR("add template specialization for this type\n");
- GGML_ABORT("add template specialization for this type");
- }
- }
- }
- } break;
- case 256:
- {
- switch (src1->type) {
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
- GGML_LOG_ERROR("add template specialization for this type\n");
- GGML_ABORT("add template specialization for this type");
- }
- }
- } break;
- case 576:
- {
- if (ne20 == 512) {
- switch (src1->type) {
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
- GGML_LOG_ERROR("add template specialization for this type\n");
- GGML_ABORT("add template specialization for this type");
- }
- }
- } else {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne20);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- } break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- }
-
- ggml_metal_kargs_flash_attn_ext args = {
- /*.ne01 =*/ ne01,
- /*.ne02 =*/ ne02,
- /*.ne03 =*/ ne03,
- /*.nb01 =*/ nb01,
- /*.nb02 =*/ nb02,
- /*.nb03 =*/ nb03,
- /*.ne11 =*/ ne11,
- /*.ne_12_2 =*/ ne12,
- /*.ne_12_3 =*/ ne13,
- /*.nb11 =*/ nb11,
- /*.nb12 =*/ nb12,
- /*.nb13 =*/ nb13,
- /*.nb21 =*/ nb21,
- /*.nb22 =*/ nb22,
- /*.nb23 =*/ nb23,
- /*.ne32 =*/ ne32,
- /*.ne33 =*/ ne33,
- /*.nb31 =*/ nb31,
- /*.nb32 =*/ nb32,
- /*.nb33 =*/ nb33,
- /*.ne1 =*/ ne1,
- /*.ne2 =*/ ne2,
- /*.ne3 =*/ ne3,
- /*.scale =*/ scale,
- /*.max_bias =*/ max_bias,
- /*.m0 =*/ m0,
- /*.m1 =*/ m1,
- /*.n_head_log2 =*/ n_head_log2,
- /*.logit_softcap =*/ logit_softcap,
- };
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBytes:&args length:sizeof(args) atIndex:0];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
- if (id_src3) {
- [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
- } else {
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
- }
- if (id_src4) {
- [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
- } else {
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
- }
-
- if (!use_vec_kernel) {
+ if (ne01 >= 20 || (ne00 % 32 != 0)) {
// half8x8 kernel
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
- const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+ const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 8 == 0);
@@ -5525,34 +5205,90 @@ static int ggml_metal_encode_node(
const int is_q = ggml_is_quantized(src1->type) ? 1 : 0;
- // 2*(2*ncpsg + nqptg)*(nsg)
- // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
+ // 2*(2*ncpsg)
+ // ncpsg soft_max values + ncpsg mask values
//
// 16*32*(nsg)
// the shared memory needed for the simdgroups to load the KV cache
 // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
//
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*GGML_PAD(ne20, 64) + 2*(2*ncpsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
- int64_t nsgmax = 2;
-
- while (true) {
- const size_t smem = FATTN_SMEM(nsgmax);
- if (smem > device.maxThreadgroupMemoryLength/2) {
- break;
- }
- nsgmax *= 2;
- }
- nsgmax /= 2;
+ //int64_t nsgmax = 4;
+ //
+ //if (is_q) {
+ // nsgmax = 2;
+ // while (true) {
+ // const size_t smem = FATTN_SMEM(nsgmax);
+ // if (smem > device.maxThreadgroupMemoryLength/2) {
+ // break;
+ // }
+ // nsgmax *= 2;
+ // }
+ // nsgmax /= 2;
+ //}
// simdgroups per threadgroup (a.k.a. warps)
- const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
+ //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
+ int32_t nsg = 4;
const size_t smem = FATTN_SMEM(nsg);
+ ggml_metal_kargs_flash_attn_ext args = {
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne11 =*/ ne11,
+ /*.ne_12_2 =*/ ne12,
+ /*.ne_12_3 =*/ ne13,
+ /*.ns10 =*/ nb11/nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ns20 =*/ nb21/nb20,
+ /*.nb21 =*/ nb21,
+ /*.nb22 =*/ nb22,
+ /*.nb23 =*/ nb23,
+ /*.ne32 =*/ ne32,
+ /*.ne33 =*/ ne33,
+ /*.nb31 =*/ nb31,
+ /*.nb32 =*/ nb32,
+ /*.nb33 =*/ nb33,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.scale =*/ scale,
+ /*.max_bias =*/ max_bias,
+ /*.m0 =*/ m0,
+ /*.m1 =*/ m1,
+ /*.n_head_log2 =*/ n_head_log2,
+ /*.logit_softcap =*/ logit_softcap,
+ };
+
+ id<MTLComputePipelineState> pipeline = ggml_metal_get_pipeline_flash_attn_ext(backend, node, has_mask, has_sinks, has_bias, has_scap, nsg);
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+ if (id_src3) {
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
+ }
+ if (id_src4) {
+ [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
+ }
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
- //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
+ //printf("smem: %zu, max: %zu, nsg = %d, ne02 = %d, ne12 = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, ne02, ne12);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
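To make the FATTN_SMEM bound above concrete, a small worked example (head sizes assumed for illustration): with nqptg = 8, ncpsg = 64 and ne00 = ne20 = 128, the macro evaluates to

    FATTN_SMEM(4) = GGML_PAD((8*(128 + 2*128 + 2*(2*64)) + 0)*2, 16) = 10240 bytes   // f16/bf16 K/V (is_q = 0)
    FATTN_SMEM(4) = GGML_PAD((8*640 + 16*32*4)*2, 16)                = 14336 bytes   // quantized K/V (is_q = 1)

per threadgroup, which stays well under the 32 KiB of threadgroup memory available on recent Apple GPUs.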
@@ -5561,7 +5297,7 @@ static int ggml_metal_encode_node(
// half4x4 kernel
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
- const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable
+ const int64_t nkpsg = 1*ncpsg;
GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 1 == 0);
@@ -5574,8 +5310,7 @@ static int ggml_metal_encode_node(
// ne20*(nsg)
// each simdgroup has a full f32 head vector in shared mem to accumulate results
//
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
-//#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16))
int64_t nsgmax = 2;
while (true) {
@@ -5589,7 +5324,8 @@ static int ggml_metal_encode_node(
nsgmax /= 2;
// simdgroups per threadgroup (a.k.a. warps)
- const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
+ //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
+ const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32)));
int64_t nsg = 1;
while (nsg <= nsgt) {
@@ -5599,28 +5335,86 @@ static int ggml_metal_encode_node(
// workgroups
// each workgroup handles nsg*nkpsg cache values
- uint16_t nwg = 1;
- if (4*nsg*nkpsg >= ne11) {
- const size_t smem = FATTN_SMEM(nsg);
+ int32_t nwg = 1;
+ if (false) {
+ // for small KV caches, we could launch a single workgroup and write the results directly to dst
+ // however, this does not lead to significant improvement, so disabled
+ nwg = 1;
+ nsg = 4;
+ } else {
+ nwg = 32;
+ nsg = 1;
+ while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) {
+ nsg *= 2;
+ }
+ }
- //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
- GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+ ggml_metal_kargs_flash_attn_ext_vec args = {
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne11 =*/ ne11,
+ /*.ne_12_2 =*/ ne12,
+ /*.ne_12_3 =*/ ne13,
+ /*.ns10 =*/ nb11/nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ns20 =*/ nb21/nb20,
+ /*.nb21 =*/ nb21,
+ /*.nb22 =*/ nb22,
+ /*.nb23 =*/ nb23,
+ /*.ne32 =*/ ne32,
+ /*.ne33 =*/ ne33,
+ /*.nb31 =*/ nb31,
+ /*.nb32 =*/ nb32,
+ /*.nb33 =*/ nb33,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.scale =*/ scale,
+ /*.max_bias =*/ max_bias,
+ /*.m0 =*/ m0,
+ /*.m1 =*/ m1,
+ /*.n_head_log2 =*/ n_head_log2,
+ /*.logit_softcap =*/ logit_softcap,
+ };
+ id<MTLComputePipelineState> pipeline = ggml_metal_get_pipeline_flash_attn_ext_vec(backend, node, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
+
+ GGML_ASSERT(nsg*32 <= (int) pipeline.maxTotalThreadsPerThreadgroup);
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+ if (id_src3) {
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
+ }
+ if (id_src4) {
+ [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
+ }
+
+ const size_t smem = FATTN_SMEM(nsg);
+
+ //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
+ GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+
+ if (nwg == 1) {
// using 1 workgroup -> write the result directly into dst
- [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
- [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} else {
- nwg = 32;
- nsg = MIN(4, nsg);
-
- const size_t smem = FATTN_SMEM(nsg);
-
- //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
- GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
-
// sanity checks
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
@@ -5640,20 +5434,18 @@ static int ggml_metal_encode_node(
//printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
//printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
- [encoder setBuffer:h_tmp offset:0 atIndex:6];
- [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
+ [encoder setBuffer:h_tmp offset:0 atIndex:6];
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
// reduce the results from the workgroups
{
- ggml_metal_kargs_flash_attn_ext_reduce args0 = {
+ ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
nrows,
- ne20,
};
- id<MTLComputePipelineState> pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline;
+ id<MTLComputePipelineState> pipeline0 = ggml_metal_get_pipeline_flash_attn_ext_vec_reduce(backend, node, ne20, nwg);
[encoder setComputePipelineState:pipeline0];
[encoder setBytes:&args0 length:sizeof(args0) atIndex:0];
@@ -5661,7 +5453,7 @@ static int ggml_metal_encode_node(
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
//printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*nwg, 1, 1)];
}
}
#undef FATTN_SMEM
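For the vec path, the nsg/nwg selection loop above behaves as follows (KV lengths assumed for illustration): with ncpsg = 32, hence nkpsg = 32, and nwg = 32, a KV length of ne11 = 8192 doubles nsg twice (2*32*1*32 = 2048 < 8192, then 2*32*2*32 = 4096 < 8192) and stops at the nsg < 4 cap, giving nsg = 4; a shorter cache of ne11 = 2048 fails the first check (2048 < 2048 is false) and keeps nsg = 1. Each workgroup then covers nsg*nkpsg cache values, and the per-workgroup partial results are combined by the vec_reduce kernel dispatched above.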
@@ -5680,6 +5472,7 @@ static int ggml_metal_encode_node(
switch (dstt) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
+ case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_I32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
@@ -5691,6 +5484,13 @@ static int ggml_metal_encode_node(
default: GGML_ABORT("not implemented");
};
} break;
+ case GGML_TYPE_I32:
+ {
+ switch (dstt) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_I32_F32].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ };
+ } break;
case GGML_TYPE_F16:
{
switch (dstt) {
@@ -5807,68 +5607,6 @@ static int ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
} break;
- case GGML_OP_SET:
- {
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
- GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
-
- // src0 and dst as viewed during set
- const size_t dst_nb0 = ggml_element_size(src0);
-
- const size_t dst_nb1 = ((int32_t *) dst->op_params)[0];
- const size_t dst_nb2 = ((int32_t *) dst->op_params)[1];
- const size_t dst_nb3 = ((int32_t *) dst->op_params)[2];
- const size_t offset = ((int32_t *) dst->op_params)[3];
- const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
-
- if (!inplace) {
- memcpy(((char *) dst->data), ((char *) src0->data), ggml_nbytes(dst));
- }
-
- const int im0 = (ne10 == 0 ? 0 : ne10-1);
- const int im1 = (ne11 == 0 ? 0 : ne11-1);
- const int im2 = (ne12 == 0 ? 0 : ne12-1);
- const int im3 = (ne13 == 0 ? 0 : ne13-1);
-
- GGML_ASSERT(offset + im0*dst_nb0 + im1*dst_nb1 + im2*dst_nb2 + im3*dst_nb3 <= ggml_nbytes(dst));
-
- id pipeline = nil;
-
- switch (src0t) {
- case GGML_TYPE_F32:
- GGML_ASSERT(nb10 == sizeof(float));
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break;
- case GGML_TYPE_I32:
- GGML_ASSERT(nb10 == sizeof(int32_t));
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break;
- default: GGML_ABORT("fatal error");
- }
-
- ggml_metal_kargs_set args = {
- /*.ne10 =*/ ne10,
- /*.ne11 =*/ ne11,
- /*.ne12 =*/ ne12,
- /*.nb10 =*/ nb10,
- /*.nb11 =*/ nb11,
- /*.nb12 =*/ nb12,
- /*.nb13 =*/ nb13,
- /*.nb1 =*/ dst_nb1,
- /*.nb2 =*/ dst_nb2,
- /*.nb3 =*/ dst_nb3,
- /*.offs =*/ offset,
- /*.inplace =*/ inplace,
- };
-
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10);
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBytes:&args length:sizeof(args) atIndex:0];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
case GGML_OP_POOL_2D:
{
GGML_ASSERT(ggml_is_contiguous(src0));
@@ -5997,6 +5735,12 @@ static enum ggml_status ggml_metal_graph_compute(
if (should_capture) {
ctx->capture_next_compute = false;
+ // make sure all previous computations have finished before starting the capture
+ if (ctx->cmd_buf_last) {
+ [ctx->cmd_buf_last waitUntilCompleted];
+ ctx->cmd_buf_last = nil;
+ }
+
if (!ctx->capture_started) {
// create capture scope
ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device];
@@ -6019,78 +5763,103 @@ static enum ggml_status ggml_metal_graph_compute(
// the main thread commits the first few commands immediately
// cmd_buf[n_cb]
{
- id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+ // cannot use commandBufferWithUnretainedReferences because the buffers from the memory pool can get destroyed
+ // TODO: when the memory pools are removed, we can again use commandBufferWithUnretainedReferences
+ // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2334215009
+ //id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+ id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
+ [cmd_buf retain];
+
ctx->cmd_bufs[n_cb].obj = cmd_buf;
[cmd_buf enqueue];
+
ctx->encode_async(n_cb);
}
- // prepare the rest of the command buffers asynchronously
+ // remember the command buffer for the next iteration
+ ctx->cmd_buf_last = ctx->cmd_bufs[n_cb].obj;
+
+ // prepare the rest of the command buffers asynchronously (optional)
// cmd_buf[0.. n_cb)
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
- id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+ //id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+ id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
+ [cmd_buf retain];
+
+ if (ctx->cmd_bufs[cb_idx].obj) {
+ [ctx->cmd_bufs[cb_idx].obj release];
+ }
ctx->cmd_bufs[cb_idx].obj = cmd_buf;
// always enqueue the first two command buffers
// enqueue all of the command buffers if we don't need to abort
if (cb_idx < 2 || ctx->abort_callback == NULL) {
[cmd_buf enqueue];
+
+ // update the pointer to the last queued command buffer
+ // this is needed to implement synchronize()
+ ctx->cmd_buf_last = cmd_buf;
}
}
dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
- // wait for completion and check status of each command buffer
- // needed to detect if the device ran out-of-memory for example (#1881)
- {
- id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
- [cmd_buf waitUntilCompleted];
-
- MTLCommandBufferStatus status = [cmd_buf status];
- if (status != MTLCommandBufferStatusCompleted) {
- GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
- if (status == MTLCommandBufferStatusError) {
- GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
- }
-
- return GGML_STATUS_FAILED;
- }
- }
-
- for (int i = 0; i < n_cb; ++i) {
- id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
- [cmd_buf waitUntilCompleted];
-
- MTLCommandBufferStatus status = [cmd_buf status];
- if (status != MTLCommandBufferStatusCompleted) {
- GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
- if (status == MTLCommandBufferStatusError) {
- GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
- }
-
- return GGML_STATUS_FAILED;
- }
-
- id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
- if (!next_buffer) {
- continue;
- }
-
- const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
- if (next_queued) {
- continue;
- }
-
- if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
- GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
- return GGML_STATUS_ABORTED;
- }
-
- [next_buffer commit];
- }
+ // for debugging: block until graph is computed
+ //[ctx->cmd_buf_last waitUntilCompleted];
+ // enter here only when capturing in order to wait for all computation to finish
+ // otherwise, we leave the graph to compute asynchronously
if (!should_capture && ctx->capture_started) {
+ // wait for completion and check status of each command buffer
+ // needed to detect if the device ran out-of-memory for example (#1881)
+ {
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+ [cmd_buf waitUntilCompleted];
+
+ MTLCommandBufferStatus status = [cmd_buf status];
+ if (status != MTLCommandBufferStatusCompleted) {
+ GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
+ if (status == MTLCommandBufferStatusError) {
+ GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+ }
+
+ return GGML_STATUS_FAILED;
+ }
+ }
+
+ for (int i = 0; i < n_cb; ++i) {
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
+ [cmd_buf waitUntilCompleted];
+
+ MTLCommandBufferStatus status = [cmd_buf status];
+ if (status != MTLCommandBufferStatusCompleted) {
+ GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
+ if (status == MTLCommandBufferStatusError) {
+ GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+ }
+
+ return GGML_STATUS_FAILED;
+ }
+
+ id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
+ if (!next_buffer) {
+ continue;
+ }
+
+ const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
+ if (next_queued) {
+ continue;
+ }
+
+ if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
+ GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
+ return GGML_STATUS_ABORTED;
+ }
+
+ [next_buffer commit];
+ }
+
[ctx->capture_scope endScope];
[[MTLCaptureManager sharedCaptureManager] stopCapture];
}
@@ -6100,10 +5869,12 @@ static enum ggml_status ggml_metal_graph_compute(
}
////////////////////////////////////////////////////////////////////////////////
-
// backend interface
+////////////////////////////////////////////////////////////////////////////////
-static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+// shared buffer
+
+static void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t buffer) {
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
for (int i = 0; i < ctx->n_buffers; i++) {
@@ -6112,7 +5883,9 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
ggml_backend_metal_buffer_rset_free(ctx);
- if (ctx->owned) {
+ GGML_ASSERT(ctx->is_shared);
+
+ {
#if TARGET_OS_OSX
vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size);
#else
@@ -6123,66 +5896,254 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
free(ctx);
}
-static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
+static void * ggml_backend_metal_buffer_shared_get_base(ggml_backend_buffer_t buffer) {
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
return ctx->all_data;
}
-static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
- memset((char *)tensor->data + offset, value, size);
-
- GGML_UNUSED(buffer);
-}
-
-static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- memcpy((char *)tensor->data + offset, data, size);
-
- GGML_UNUSED(buffer);
-}
-
-static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- memcpy(data, (const char *)tensor->data + offset, size);
-
- GGML_UNUSED(buffer);
-}
-
-static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
- if (ggml_backend_buffer_is_host(src->buffer)) {
- memcpy(dst->data, src->data, ggml_nbytes(src));
- return true;
- }
- return false;
-
- GGML_UNUSED(buffer);
-}
-
-static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+static void ggml_backend_metal_buffer_shared_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+ GGML_ASSERT(ctx->is_shared);
+
+ memset((char *)tensor->data + offset, value, size);
+}
+
+static void ggml_backend_metal_buffer_shared_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+ GGML_ASSERT(ctx->is_shared);
+
+ memcpy((char *)tensor->data + offset, data, size);
+}
+
+static void ggml_backend_metal_buffer_shared_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+ GGML_ASSERT(ctx->is_shared);
+
+ memcpy(data, (const char *)tensor->data + offset, size);
+}
+
+static bool ggml_backend_metal_buffer_shared_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+ GGML_UNUSED(buffer);
+ GGML_UNUSED(src);
+ GGML_UNUSED(dst);
+
+ return false;
+}
+
+static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+ GGML_ASSERT(ctx->is_shared);
+
memset(ctx->all_data, value, ctx->all_size);
}
-static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
- /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
- /* .get_base = */ ggml_backend_metal_buffer_get_base,
+static struct ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = {
+ /* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer,
+ /* .get_base = */ ggml_backend_metal_buffer_shared_get_base,
/* .init_tensor = */ NULL,
- /* .memset_tensor = */ ggml_backend_metal_buffer_memset_tensor,
- /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
- /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
- /* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor,
- /* .clear = */ ggml_backend_metal_buffer_clear,
+ /* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor,
+ /* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor,
+ /* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor,
+ /* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor,
+ /* .clear = */ ggml_backend_metal_buffer_shared_clear,
/* .reset = */ NULL,
};
-// default buffer type
+// private buffer
-static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
- return "Metal";
+static void ggml_backend_metal_buffer_private_free_buffer(ggml_backend_buffer_t buffer) {
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
- GGML_UNUSED(buft);
+ for (int i = 0; i < ctx->n_buffers; i++) {
+ [ctx->buffers[i].metal release];
+ }
+
+ ggml_backend_metal_buffer_rset_free(ctx);
+
+ GGML_ASSERT(!ctx->is_shared);
+
+ free(ctx);
}
+static void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) {
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+ return ctx->all_data;
+}
+
+static void ggml_backend_metal_buffer_private_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+ GGML_ASSERT(!ctx->is_shared);
+
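+ // private (GPU-only) buffers are not CPU-addressable, so the memset is performed with a GPU blit fill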
+ @autoreleasepool {
+ // dst
+ size_t buf_dst_offset = 0;
+ id<MTLBuffer> buf_dst = ggml_metal_get_buffer(tensor, &buf_dst_offset);
+
+ buf_dst_offset += offset;
+
+ id<MTLCommandQueue> queue = ctx->queue;
+ id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
+
+ {
+ id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
+
+ [encoder fillBuffer:buf_dst
+ range:NSMakeRange(buf_dst_offset, size)
+ value:value];
+
+ [encoder endEncoding];
+ }
+
+ [cmd_buf commit];
+ [cmd_buf waitUntilCompleted];
+ }
+}
+
+static void ggml_backend_metal_buffer_private_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+ GGML_ASSERT(!ctx->is_shared);
+
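+ // stage the host data in a shared Metal buffer that wraps it, then blit it into the private destination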
+ @autoreleasepool {
+ // src
+ void * data_ptr = (void *)(uintptr_t) data; // "const cast" the src data
+ id<MTLBuffer> buf_src = [ctx->device newBufferWithBytesNoCopy:data_ptr
+ length:size
+ options:MTLResourceStorageModeShared
+ deallocator:nil];
+
+ // dst
+ size_t buf_dst_offset = 0;
+ id<MTLBuffer> buf_dst = ggml_metal_get_buffer(tensor, &buf_dst_offset);
+
+ buf_dst_offset += offset;
+
+ // note: for experimentation purposes, here we use a semaphore to wait for the copy to complete
+ // this is an alternative to waitUntilCompleted; it should be faster, but doesn't seem to make much difference
+ dispatch_semaphore_t completion_semaphore = dispatch_semaphore_create(0);
+
+ id<MTLCommandQueue> queue = ctx->queue;
+ id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
+
+ {
+ id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
+
+ [encoder copyFromBuffer:buf_src
+ sourceOffset:0
+ toBuffer:buf_dst
+ destinationOffset:buf_dst_offset
+ size:size];
+
+ [encoder endEncoding];
+ }
+
+ [cmd_buf addCompletedHandler:^(id<MTLCommandBuffer> cb) {
+ // TODO: can check for errors here
+ GGML_UNUSED(cb);
+
+ dispatch_semaphore_signal(completion_semaphore);
+ }];
+
+ [cmd_buf commit];
+
+ dispatch_semaphore_wait(completion_semaphore, DISPATCH_TIME_FOREVER);
+ //[cmd_buf waitUntilCompleted];
+ }
+}
+
+static void ggml_backend_metal_buffer_private_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+ GGML_ASSERT(!ctx->is_shared);
+
+ @autoreleasepool {
+ // src
+ size_t buf_src_offset = 0;
+ id<MTLBuffer> buf_src = ggml_metal_get_buffer(tensor, &buf_src_offset);
+
+ buf_src_offset += offset;
+
+ // dst
+ id<MTLBuffer> buf_dst = [ctx->device newBufferWithBytesNoCopy:data
+ length:size
+ options:MTLResourceStorageModeShared
+ deallocator:nil];
+
+ id<MTLCommandQueue> queue = ctx->queue;
+ id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
+
+ {
+ id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
+
+ [encoder copyFromBuffer:buf_src
+ sourceOffset:buf_src_offset
+ toBuffer:buf_dst
+ destinationOffset:0
+ size:size];
+
+ [encoder endEncoding];
+ }
+
+ [cmd_buf commit];
+ [cmd_buf waitUntilCompleted];
+ }
+}
+
+static bool ggml_backend_metal_buffer_private_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+ GGML_UNUSED(buffer);
+ GGML_UNUSED(src);
+ GGML_UNUSED(dst);
+
+ return false;
+}
+
+static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+ GGML_ASSERT(!ctx->is_shared);
+
+ @autoreleasepool {
+ id<MTLCommandQueue> queue = ctx->queue;
+ id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
+
+ {
+ id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
+
+ [encoder fillBuffer:ctx->buffers[0].metal
+ range:NSMakeRange(0, ctx->buffers[0].size)
+ value:value];
+
+ [encoder endEncoding];
+ }
+
+ [cmd_buf commit];
+ [cmd_buf waitUntilCompleted];
+ }
+}
+
+static struct ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = {
+ /* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer,
+ /* .get_base = */ ggml_backend_metal_buffer_private_get_base,
+ /* .init_tensor = */ NULL,
+ /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor,
+ /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor,
+ /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor,
+ /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor,
+ /* .clear = */ ggml_backend_metal_buffer_private_clear,
+ /* .reset = */ NULL,
+};
+
+//
+// buffer types
+//
+
static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
#ifndef GGML_METAL_NDEBUG
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
@@ -6208,7 +6169,8 @@ static void ggml_backend_metal_log_allocated_size(id device, size_t s
GGML_UNUSED(size_aligned);
}
-static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+// common method for allocating shared or private Metal buffers
+static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size, bool shared) {
struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
const size_t size_page = sysconf(_SC_PAGESIZE);
@@ -6224,22 +6186,40 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
id<MTLDevice> device = ctx_dev->mtl_device;
- ctx->all_data = ggml_metal_host_malloc(size_aligned);
+ // allocate shared buffer if the device supports it and it is required by the buffer type
+ if (ctx_dev->use_shared_buffers && shared) {
+ ctx->all_data = ggml_metal_host_malloc(size_aligned);
+ ctx->is_shared = true;
+ } else {
+ // dummy, non-NULL value - we'll populate this after creating the Metal buffer below
+ ctx->all_data = (void *) 0x000000400ULL;
+ ctx->is_shared = false;
+ }
ctx->all_size = size_aligned;
- ctx->owned = true;
+
+ ctx->device = device;
+ ctx->queue = ctx_dev->mtl_queue;
+
ctx->n_buffers = 1;
if (ctx->all_data != NULL) {
- ctx->buffers[0].data = ctx->all_data;
ctx->buffers[0].size = size;
ctx->buffers[0].metal = nil;
if (size_aligned > 0) {
- ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
- length:size_aligned
- options:MTLResourceStorageModeShared
- deallocator:nil];
+ if (ctx_dev->use_shared_buffers) {
+ ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
+ length:size_aligned
+ options:MTLResourceStorageModeShared
+ deallocator:nil];
+ } else {
+ ctx->buffers[0].metal = [device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
+
+ ctx->all_data = (void *) (ctx->buffers[0].metal.gpuAddress);
+ }
}
+
+ ctx->buffers[0].data = ctx->all_data;
}
if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
@@ -6256,36 +6236,50 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
//ggml_backend_metal_log_allocated_size(device, size_aligned);
- return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
+ struct ggml_backend_buffer_i buf_i = ctx->is_shared ? ggml_backend_metal_buffer_shared_i : ggml_backend_metal_buffer_private_i;
+
+ return ggml_backend_buffer_init(buft, buf_i, ctx, size);
}
-static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+// default (shared) buffer type
+
+static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) {
+ return "Metal";
+
+ GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_t ggml_backend_metal_buffer_type_shared_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true);
+}
+
+static size_t ggml_backend_metal_buffer_type_shared_get_alignment(ggml_backend_buffer_type_t buft) {
return 32;
GGML_UNUSED(buft);
}
-static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
+static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_buffer_type_t buft) {
const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size;
return max_size;
}
-static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
- return true;
+static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_type_t buft) {
+ return false;
GGML_UNUSED(buft);
}
-ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
+static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) {
static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
/* .iface = */ {
- /* .get_name = */ ggml_backend_metal_buffer_type_get_name,
- /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
- /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size,
+ /* .get_name = */ ggml_backend_metal_buffer_type_shared_get_name,
+ /* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment,
+ /* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size,
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
- /* .is_host = */ ggml_backend_metal_buffer_type_is_host,
+ /* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host,
},
/* .device = */ &g_ggml_backend_metal_device,
/* .context = */ NULL,
@@ -6294,116 +6288,101 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
return &ggml_backend_buffer_type_metal;
}
-static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) {
- return "Metal_Mapped";
+// default (private) buffer type
+
+static const char * ggml_backend_metal_buffer_type_private_get_name(ggml_backend_buffer_type_t buft) {
+ return "Metal_Private";
GGML_UNUSED(buft);
}
-static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) {
- static struct ggml_backend_buffer_type ggml_backend_buffer_from_ptr_type_metal = {
+static ggml_backend_buffer_t ggml_backend_metal_buffer_type_private_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, false);
+}
+
+static size_t ggml_backend_metal_buffer_type_private_get_alignment(ggml_backend_buffer_type_t buft) {
+ return 32;
+
+ GGML_UNUSED(buft);
+}
+
+static size_t ggml_backend_metal_buffer_type_private_get_max_size(ggml_backend_buffer_type_t buft) {
+ const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size;
+
+ return max_size;
+}
+
+static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_type_t buft) {
+ return false;
+
+ GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(void) {
+ static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
/* .iface = */ {
- /* .get_name = */ ggml_backend_metal_buffer_from_ptr_type_get_name,
- /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
- /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size,
+ /* .get_name = */ ggml_backend_metal_buffer_type_private_get_name,
+ /* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment,
+ /* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size,
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
- /* .is_host = */ ggml_backend_metal_buffer_type_is_host,
+ /* .is_host = */ ggml_backend_metal_buffer_type_private_is_host,
},
/* .device = */ &g_ggml_backend_metal_device,
/* .context = */ NULL,
};
- return &ggml_backend_buffer_from_ptr_type_metal;
+ return &ggml_backend_buffer_type_metal;
}
-// TODO: obsoleted by ggml_backend_metal_device_buffer_from_ptr
-ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
- struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
+// mapped buffer type
- ctx->all_data = data;
- ctx->all_size = size;
- ctx->owned = false;
- ctx->n_buffers = 0;
+static const char * ggml_backend_metal_buffer_type_mapped_get_name(ggml_backend_buffer_type_t buft) {
+ return "Metal_Mapped";
- const size_t size_page = sysconf(_SC_PAGESIZE);
+ GGML_UNUSED(buft);
+}
- // page-align the data ptr
- {
- const uintptr_t offs = (uintptr_t) data % size_page;
- data = (void *) ((char *) data - offs);
- size += offs;
- }
+static ggml_backend_buffer_t ggml_backend_metal_buffer_type_mapped_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ // for mapped buffers, prefer shared memory
+ return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true);
+}
- size_t size_aligned = size;
- if ((size_aligned % size_page) != 0) {
- size_aligned += (size_page - (size_aligned % size_page));
- }
+static size_t ggml_backend_metal_buffer_type_mapped_get_alignment(ggml_backend_buffer_type_t buft) {
+ return 32;
- struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
+ GGML_UNUSED(buft);
+}
- GGML_ASSERT(ctx_dev->mtl_device != nil);
+static size_t ggml_backend_metal_buffer_type_mapped_get_max_size(ggml_backend_buffer_type_t buft) {
+ const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size;
- id<MTLDevice> device = ctx_dev->mtl_device;
+ return max_size;
+}
- // the buffer fits into the max buffer size allowed by the device
- if (size_aligned <= device.maxBufferLength) {
- ctx->buffers[ctx->n_buffers].data = data;
- ctx->buffers[ctx->n_buffers].size = size;
- ctx->buffers[ctx->n_buffers].metal = nil;
+static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_type_t buft) {
+ return false;
- if (size_aligned > 0) {
- ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
+ GGML_UNUSED(buft);
+}
- if (ctx->buffers[ctx->n_buffers].metal == nil) {
- GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
- return false;
- }
- }
+static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(void) {
+ // note: not obvious, but this buffer type still needs to implement .alloc_buffer:
+ // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099
+ static struct ggml_backend_buffer_type ggml_backend_buffer_type_mapped_metal = {
+ /* .iface = */ {
+ /* .get_name = */ ggml_backend_metal_buffer_type_mapped_get_name,
+ /* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment,
+ /* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size,
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
+ /* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host,
+ },
+ /* .device = */ &g_ggml_backend_metal_device,
+ /* .context = */ NULL,
+ };
- ggml_backend_metal_log_allocated_size(device, size_aligned);
-
- ++ctx->n_buffers;
- } else {
- // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
- // one of the views
- const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
- const size_t size_step = device.maxBufferLength - size_ovlp;
- const size_t size_view = device.maxBufferLength;
-
- for (size_t i = 0; i < size; i += size_step) {
- const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
-
- ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
- ctx->buffers[ctx->n_buffers].size = size_step_aligned;
- ctx->buffers[ctx->n_buffers].metal = nil;
-
- if (size_step_aligned > 0) {
- ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
-
- if (ctx->buffers[ctx->n_buffers].metal == nil) {
- GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
- return false;
- }
- }
-
- ggml_backend_metal_log_allocated_size(device, size_step_aligned);
-
- if (i + size_step < size) {
- GGML_LOG_INFO("\n");
- }
-
- ++ctx->n_buffers;
- }
- }
-
- if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
- GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
- free(ctx);
- return NULL;
- }
-
- return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
+ return &ggml_backend_buffer_type_mapped_metal;
}
// backend
@@ -6422,6 +6401,137 @@ static void ggml_backend_metal_free(ggml_backend_t backend) {
free(backend);
}
+static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
+ struct ggml_backend_metal_context * ctx = backend->context;
+
+ // wait for any backend operations to finish
+ if (ctx->cmd_buf_last) {
+ [ctx->cmd_buf_last waitUntilCompleted];
+ ctx->cmd_buf_last = nil;
+ }
+
+ // release any completed command buffers
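+ // (these were created and retained by the async set/get tensor calls below)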
+ if (ctx->cmd_bufs_ext.count > 0) {
+ for (size_t i = 0; i < ctx->cmd_bufs_ext.count; ++i) {
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs_ext[i];
+
+ MTLCommandBufferStatus status = [cmd_buf status];
+ if (status != MTLCommandBufferStatusCompleted) {
+ GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, (int) i, (int) status);
+ if (status == MTLCommandBufferStatusError) {
+ GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+ }
+ GGML_ABORT("fatal error");
+ }
+
+ [cmd_buf release];
+ }
+
+ [ctx->cmd_bufs_ext removeAllObjects];
+ }
+}
+
+static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ struct ggml_backend_metal_context * ctx = backend->context;
+ struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+ @autoreleasepool {
+ id<MTLDevice> device = ctx_dev->mtl_device;
+
+ // wrap the source data into a Metal buffer
+ id<MTLBuffer> buf_src = [device newBufferWithBytes:data
+ length:size
+ options:MTLResourceStorageModeShared];
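+
+ // note: newBufferWithBytes copies the host data, so the caller's pointer does not have to remain valid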
+
+ size_t buf_dst_offset = 0;
+ id<MTLBuffer> buf_dst = ggml_metal_get_buffer(tensor, &buf_dst_offset);
+
+ if (buf_dst == nil) {
+ GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
+ }
+
+ buf_dst_offset += offset;
+
+ // queue the copy operation into the queue of the Metal context
+ // this will be queued at the end, after any currently ongoing GPU operations
+ id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+ id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
+
+ [encoder copyFromBuffer:buf_src
+ sourceOffset:0
+ toBuffer:buf_dst
+ destinationOffset:buf_dst_offset
+ size:size];
+
+ [encoder endEncoding];
+ [cmd_buf commit];
+
+ // do not wait here for completion
+ //[cmd_buf waitUntilCompleted];
+
+ // instead, remember a reference to the command buffer and wait for it later if needed
+ [ctx->cmd_bufs_ext addObject:cmd_buf];
+ ctx->cmd_buf_last = cmd_buf;
+
+ [cmd_buf retain];
+ }
+}
+
+static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ struct ggml_backend_metal_context * ctx = backend->context;
+ struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+ @autoreleasepool {
+ id<MTLDevice> device = ctx_dev->mtl_device;
+
+ id<MTLBuffer> buf_dst = [device newBufferWithBytesNoCopy:data
+ length:size
+ options:MTLResourceStorageModeShared
+ deallocator:nil];
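+
+ // note: this wraps the destination pointer directly (no copy), so it must remain valid until the
+ // copy has completed, i.e. until the next synchronize()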
+
+ size_t buf_src_offset = 0;
+ id<MTLBuffer> buf_src = ggml_metal_get_buffer(tensor, &buf_src_offset);
+
+ if (buf_src == nil) {
+ GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
+ }
+
+ buf_src_offset += offset;
+
+ // queue the copy operation into the queue of the Metal context
+ // this will be queued at the end, after any currently ongoing GPU operations
+ id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+ id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
+
+ [encoder copyFromBuffer:buf_src
+ sourceOffset:buf_src_offset
+ toBuffer:buf_dst
+ destinationOffset:0
+ size:size];
+
+ [encoder endEncoding];
+ [cmd_buf commit];
+
+ // do not wait here for completion
+ //[cmd_buf waitUntilCompleted];
+
+ // instead, remember a reference to the command buffer and wait for it later if needed
+ [ctx->cmd_bufs_ext addObject:cmd_buf];
+ ctx->cmd_buf_last = cmd_buf;
+
+ [cmd_buf retain];
+ }
+}
+
+static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+ return false;
+
+ GGML_UNUSED(backend_src);
+ GGML_UNUSED(backend_dst);
+ GGML_UNUSED(src);
+ GGML_UNUSED(dst);
+}
+
static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
return ggml_metal_graph_compute(backend, cgraph);
}
@@ -6452,7 +6562,10 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
- id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
+ struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
+
+ ggml_metal_mem_pool_reset(mem_pool);
id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
@@ -6466,9 +6579,6 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
const bool should_capture = ctx->capture_next_compute;
- struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
- ggml_metal_mem_pool_reset(mem_pool);
-
for (int idx = node_start; idx < node_end;) {
if (should_capture) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
@@ -6502,17 +6612,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
static struct ggml_backend_i ggml_backend_metal_i = {
/* .get_name = */ ggml_backend_metal_name,
/* .free = */ ggml_backend_metal_free,
- /* .set_tensor_async = */ NULL,
- /* .get_tensor_async = */ NULL,
- /* .cpy_tensor_async = */ NULL,
- /* .synchronize = */ NULL,
+ /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
+ /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
+ /* .cpy_tensor_async = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups
+ /* .synchronize = */ ggml_backend_metal_synchronize,
/* .graph_plan_create = */ NULL,
/* .graph_plan_free = */ NULL,
/* .graph_plan_update = */ NULL,
/* .graph_plan_compute = */ NULL,
/* .graph_compute = */ ggml_backend_metal_graph_compute,
+
+ // the events API is needed only for multi-GPU setups, so likely no need to implement it for Metal
+ // in any case, these docs seem relevant if we ever decide to implement it:
+ // https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
+ /* .optimize_graph = */ NULL,
};
static ggml_guid_t ggml_backend_metal_guid(void) {
@@ -6613,7 +6728,7 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct g
props->type = ggml_backend_metal_device_get_type(dev);
ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = (struct ggml_backend_dev_caps) {
- /* .async = */ false,
+ /* .async = */ true,
/* .host_buffer = */ false,
/* .buffer_from_host_ptr = */ true,
/* .events = */ false,
@@ -6644,17 +6759,19 @@ static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, con
}
static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) {
- return ggml_backend_metal_buffer_type();
+ struct ggml_backend_metal_device_context * ctx_dev = dev->context;
- GGML_UNUSED(dev);
+ return ctx_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared() : ggml_backend_metal_buffer_type_private();
}
-static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
ctx->all_data = ptr;
ctx->all_size = size;
- ctx->owned = false;
+
+ ctx->is_shared = true;
+
ctx->n_buffers = 0;
const size_t size_page = sysconf(_SC_PAGESIZE);
@@ -6677,6 +6794,9 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
id<MTLDevice> device = ctx_dev->mtl_device;
+ ctx->device = device;
+ ctx->queue = ctx_dev->mtl_queue;
+
// the buffer fits into the max buffer size allowed by the device
if (size_aligned <= device.maxBufferLength) {
ctx->buffers[ctx->n_buffers].data = ptr;
@@ -6734,7 +6854,7 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
return NULL;
}
- return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
+ return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(), ggml_backend_metal_buffer_shared_i, ctx, size);
}
static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
@@ -6745,14 +6865,30 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
return
- buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
- buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
+ buft->iface.get_name == ggml_backend_metal_buffer_type_shared_get_name ||
+ buft->iface.get_name == ggml_backend_metal_buffer_type_private_get_name ||
+ buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name;
GGML_UNUSED(dev);
}
+static int64_t get_op_batch_size(const struct ggml_tensor * op) {
+ switch (op->op) {
+ case GGML_OP_MUL_MAT:
+ return op->ne[1];
+ case GGML_OP_MUL_MAT_ID:
+ return op->ne[2];
+ default:
+ return ggml_nrows(op);
+ }
+}
+
static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
- return false;
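+ // heuristic: only offload large batched matrix multiplications, where the GPU compute outweighs
+ // the overhead of transferring the data to the device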
+ const int min_batch_size = 32;
+
+ return (op->op == GGML_OP_MUL_MAT ||
+ op->op == GGML_OP_MUL_MAT_ID) &&
+ get_op_batch_size(op) >= min_batch_size;
GGML_UNUSED(dev);
GGML_UNUSED(op);
@@ -6767,7 +6903,7 @@ static struct ggml_backend_device_i ggml_backend_metal_device_i = {
/* .init_backend = */ ggml_backend_metal_device_init,
/* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type,
/* .get_host_buffer_type = */ NULL,
- /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_from_ptr,
+ /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_mapped,
/* .supports_op = */ ggml_backend_metal_device_supports_op,
/* .supports_buft = */ ggml_backend_metal_device_supports_buft,
/* .offload_op = */ ggml_backend_metal_device_offload_op,
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 2d56c6267..157d0cc6d 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -15,6 +15,10 @@ using namespace metal;
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
+#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1))
+
+#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x)
+
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
@@ -2755,7 +2759,47 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
}
-template
+template
+static inline void helper_mv_reduce_and_write(
+ device float * dst_f32,
+ float sumf[NR0],
+ const int r0,
+ const int ne01,
+ ushort tiisg,
+ ushort sgitg,
+ threadgroup char * shmem) {
+ threadgroup float * shmem_f32[NR0];
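+
+ // one NW-wide row of partial sums per output row; slot [sgitg] holds this simdgroup's partial sum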
+
+ for (short row = 0; row < NR0; ++row) {
+ shmem_f32[row] = (threadgroup float *) shmem + NW*row;
+
+ if (sgitg == 0) {
+ shmem_f32[row][tiisg] = 0.0f;
+ }
+
+ sumf[row] = simd_sum(sumf[row]);
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (short row = 0; row < NR0; ++row) {
+ if (tiisg == 0) {
+ shmem_f32[row][sgitg] = sumf[row];
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (short row = 0; row < NR0 && r0 + row < ne01; ++row) {
+ float tot = simd_sum(shmem_f32[row][tiisg]);
+
+ if (tiisg == 0 && sgitg == 0) {
+ dst_f32[r0 + row] = tot;
+ }
+ }
+}
+
+template
void mul_vec_q_n_f32_impl(
args_t args,
device const char * src0,
@@ -2765,45 +2809,51 @@ void mul_vec_q_n_f32_impl(
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ constexpr short NQ = 16;
+
const int nb = args.ne00/QK4_0;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
-
- const int first_row = (r0 * nsg + sgitg) * nr0;
+ const int r0 = (tgpig.x*NSG + sgitg)*NR0;
+ //const int r0 = tgpig.x*NR0;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
- //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+ //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
//device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
// pointers to src0 rows
- device const block_q_type * ax[nr0];
- for (int row = 0; row < nr0; ++row) {
- const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ device const block_q_type * ax[NR0];
+ FOR_UNROLL (int row = 0; row < NR0; ++row) {
+ const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
}
+ float sumf[NR0] = {0.f};
+
+ const short ix = (tiisg/(NW/NQ));
+ const short il = (tiisg%(NW/NQ))*8;
+
+ //const int ib0 = sgitg*NQ + ix;
+ const int ib0 = ix;
+
float yl[16]; // src1 vector cache
- float sumf[nr0] = {0.f};
- const short ix = (tiisg/2);
- const short il = (tiisg%2)*8;
-
- device const float * yb = y + ix*QK4_0 + il;
+ //device const float * yb = y + ix*QK4_0 + il;
+ device const float * yb = y + ib0*QK4_0 + il;
// each thread in a SIMD group deals with half a block.
- for (int ib = ix; ib < nb; ib += nw/2) {
+ //for (int ib = ib0; ib < nb; ib += NSG*NQ) {
+ for (int ib = ib0; ib < nb; ib += NQ) {
float sumy[2] = { 0.f, 0.f };
-#pragma unroll
- for (short i = 0; i < 8; i += 2) {
+ FOR_UNROLL (short i = 0; i < 8; i += 2) {
sumy[0] += yb[i + 0] + yb[i + 1];
yl[i + 0] = yb[i + 0];
yl[i + 1] = yb[i + 1]/256.f;
@@ -2813,21 +2863,23 @@ void mul_vec_q_n_f32_impl(
yl[i + 9] = yb[i + 17]/4096.f;
}
-#pragma unroll
- for (short row = 0; row < nr0; row++) {
+ FOR_UNROLL (short row = 0; row < NR0; row++) {
sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
}
yb += QK4_0 * 16;
+ //yb += NSG*NQ*QK4_0;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
- for (int row = 0; row < nr0; ++row) {
+ //helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
+
+ for (int row = 0; row < NR0; ++row) {
const float tot = simd_sum(sumf[row]);
- if (tiisg == 0 && first_row + row < args.ne01) {
- dst_f32[first_row + row] = tot;
+ if (tiisg == 0 && r0 + row < args.ne01) {
+ dst_f32[r0 + row] = tot;
}
}
}
@@ -2837,10 +2889,11 @@ kernel void kernel_mul_mv_q4_0_f32(
device const char * src0,
device const char * src1,
device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q4_1_f32(
@@ -2848,10 +2901,11 @@ kernel void kernel_mul_mv_q4_1_f32(
device const char * src0,
device const char * src1,
device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q5_0_f32(
@@ -2859,10 +2913,11 @@ kernel void kernel_mul_mv_q5_0_f32(
device const char * src0,
device const char * src1,
device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q5_1_f32(
@@ -2870,15 +2925,14 @@ kernel void kernel_mul_mv_q5_1_f32(
device const char * src0,
device const char * src1,
device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
-#define NB_Q8_0 8
-
-template
+template
void kernel_mul_mv_q8_0_f32_impl(
args_t args,
device const char * src0,
@@ -2888,66 +2942,65 @@ void kernel_mul_mv_q8_0_f32_impl(
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ constexpr short NQ = 8;
+
const int nb = args.ne00/QK8_0;
- const int r0 = tgpig.x;
+ const int r0 = tgpig.x*NR0;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr0;
-
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
- //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+ //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
//device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
// pointers to src0 rows
- device const block_q8_0 * ax[nr0];
- for (int row = 0; row < nr0; ++row) {
- const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ device const block_q8_0 * ax[NR0];
+ FOR_UNROLL (short row = 0; row < NR0; ++row) {
+ const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
}
- float yl[NB_Q8_0];
- float sumf[nr0] = { 0.f };
+ float sumf[NR0] = { 0.f };
- const short ix = tiisg/4;
- const short il = tiisg%4;
+ const short ix = tiisg/(NW/NQ);
+ const short il = tiisg%(NW/NQ);
- device const float * yb = y + ix*QK8_0 + il*NB_Q8_0;
+ const int ib0 = sgitg*NQ + ix;
- // each thread in a SIMD group deals with NB_Q8_0 quants at a time
- for (int ib = ix; ib < nb; ib += nw/4) {
- for (short i = 0; i < NB_Q8_0; ++i) {
+ float yl[NQ];
+
+ device const float * yb = y + ib0*QK8_0 + il*NQ;
+
+ // each thread in a SIMD group deals with NQ quants at a time
+ for (int ib = ib0; ib < nb; ib += NSG*NQ) {
+ for (short i = 0; i < NQ; ++i) {
yl[i] = yb[i];
}
- for (short row = 0; row < nr0; row++) {
- device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
+ for (short row = 0; row < NR0; row++) {
+ device const int8_t * qs = ax[row][ib].qs + il*NQ;
+
float sumq = 0.f;
- for (short iq = 0; iq < NB_Q8_0; ++iq) {
- sumq += qs[iq] * yl[iq];
+ FOR_UNROLL (short i = 0; i < NQ; ++i) {
+ sumq += qs[i] * yl[i];
}
+
sumf[row] += sumq*ax[row][ib].d;
}
- yb += nw*NB_Q8_0;
+ yb += NSG*NQ*QK8_0;
}
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < nr0; ++row) {
- const float tot = simd_sum(sumf[row]);
-
- if (tiisg == 0 && first_row + row < args.ne01) {
- dst_f32[first_row + row] = tot;
- }
- }
+ helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
}
[[host_name("kernel_mul_mv_q8_0_f32")]]
@@ -2956,10 +3009,11 @@ kernel void kernel_mul_mv_q8_0_f32(
device const char * src0,
device const char * src1,
device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
// mat-vec kernel processing in chunks of float4
@@ -4197,6 +4251,19 @@ kernel void kernel_leaky_relu_f32(
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * args.slope;
}
+constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]];
+constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];
+constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]];
+constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]];
+
+//constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
+//constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]];
+//constant float FC_flash_attn_ext_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT + 12)]];
+
+constant int32_t FC_flash_attn_ext_ns10 [[function_constant(FC_FLASH_ATTN_EXT + 20)]];
+constant int32_t FC_flash_attn_ext_ns20 [[function_constant(FC_FLASH_ATTN_EXT + 21)]];
+constant int32_t FC_flash_attn_ext_nsg [[function_constant(FC_FLASH_ATTN_EXT + 22)]];
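+
+// note: function constants are baked in when the pipeline is created, so branches guarded by them are compiled out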
+
// ref: https://arxiv.org/pdf/2307.08691.pdf
template<
typename q_t, // query types in shared memory
@@ -4211,6 +4278,7 @@ template<
typename qk_t, // Q*K types
typename qk8x8_t,
typename s_t, // soft-max types
+ typename s2_t,
typename s8x8_t,
typename o_t, // attention accumulation types
typename o4_t,
@@ -4221,12 +4289,12 @@ template<
typename vd4x4_t, // value type in device memory
short nl_v,
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
- short DK, // K head size
- short DV, // V head size
- short Q = 8, // queries per threadgroup
- short KV = 8, // key/value processed per each simdgroup
- short C = 32> // cache items per threadgroup
-kernel void kernel_flash_attn_ext(
+ short DK, // K head size
+ short DV, // V head size
+ short Q, // queries per threadgroup
+ short C, // cache items per threadgroup
+ short NSG> // number of simd groups
+void kernel_flash_attn_ext_impl(
constant ggml_metal_kargs_flash_attn_ext & args,
device const char * q,
device const char * k,
@@ -4234,46 +4302,85 @@ kernel void kernel_flash_attn_ext(
device const char * mask,
device const char * sinks,
device char * dst,
- threadgroup half * shmem_f16 [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort3 ntg[[threads_per_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- const short nsg = ntg.y; // number of simdgroups
+ threadgroup half * shmem_f16,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+ const ushort iq3 = tgpig[2];
+ const ushort iq2 = tgpig[1];
+ const ushort iq1 = tgpig[0]*Q;
- const int iq3 = tgpig[2];
- const int iq2 = tgpig[1];
- const int iq1 = tgpig[0]*Q;
+#define NS10 (FC_flash_attn_ext_ns10)
+#define NS20 (FC_flash_attn_ext_ns20)
+
+ // note: I had some concerns that using this instead of the ugly macros above was affecting performance
+ // need to re-check carefully and, if no regressions are observed, remove the macros
+ // the concern is that using const variables may require extra registers - not sure if the compiler
+ // is clever enough to avoid this. unfortunately, using constexpr is not possible with FC
+ //const short NS10 = FC_flash_attn_ext_ns10;
+ //const short NS20 = FC_flash_attn_ext_ns20;
+
+ constexpr short KV = 8;
constexpr short DK4 = DK/4;
constexpr short DK8 = DK/8;
constexpr short DK16 = DK/16;
constexpr short DV4 = DV/4;
- constexpr short DV8 = DV/8;
+ //constexpr short DV8 = DV/8;
constexpr short DV16 = DV/16;
+ constexpr short PV = PAD2(DV, 64);
+ constexpr short PV4 = PV/4;
+ constexpr short PV8 = PV/8;
+ //constexpr short PV16 = PV/16;
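+
+ // PV is DV padded up to a multiple of 64 (see PAD2); the per-query O rows in shared memory use this padded stride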
+
constexpr short NW = N_SIMDWIDTH;
- constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
+ constexpr short NQ = Q/NSG;
+ constexpr short SH = 2*C; // shared memory per simdgroup (s_t == float)
- const short TS = nsg*SH; // shared memory size per query in (s_t == float)
- const short T = 2*DK + 2*TS; // shared memory size per query in (half)
+ constexpr short TS = 2*SH;
+ constexpr short T = DK + 2*PV; // shared memory size per query in (half)
- threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix
+ threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*T); // holds the query data
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*T); // same as above but in q4_t
+ threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*T + Q*DK); // the result for all queries in 8x8 matrices (the O matrix from the paper)
+ threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*T + Q*DK);
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + Q*T); // scratch buffer for attention, mask and diagonal matrix
+ threadgroup s2_t * ss2 = (threadgroup s2_t *) (shmem_f16 + Q*T); // same as above but in s2_t
- threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
- threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
+ threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load K in shared memory
+ threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in k4x4_t
- threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
- threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
+ threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load V in shared memory
+ threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in v4x4_t
- // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
- o8x8_t lo[DV8];
+ // mask storage in shared mem
+ threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C);
+
+ // per-query mask pointers
+ device const half2 * pm2[NQ];
+
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+ const short j = jj*NSG + sgitg;
+
+ pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
+ }
+
+ {
+ q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;
+
+ const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+ const short ikv3 = iq3/(args.ne03/args.ne_12_3);
+
+ k += ikv2*args.nb12 + ikv3*args.nb13;
+ v += ikv2*args.nb22 + ikv3*args.nb23;
+ }
// load heads from Q to shared memory
- for (short j = sgitg; j < Q; j += nsg) {
- device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+ const short j = jj*NSG + sgitg;
+
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + j*args.nb01);
for (short i = tiisg; i < DK4; i += NW) {
if (iq1 + j < args.ne01) {
@@ -4284,43 +4391,30 @@ kernel void kernel_flash_attn_ext(
}
}
- // zero out lo
- for (short i = 0; i < DV8; ++i) {
- lo[i] = make_filled_simdgroup_matrix((o_t) 0.0f);
- }
+ // zero out
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+ const short j = jj*NSG + sgitg;
+
+ for (short i = tiisg; i < DV4; i += NW) {
+ so4[j*PV4 + i] = 0;
+ }
- // zero out shared memory SH
- for (short j = 0; j < Q; ++j) {
for (short i = tiisg; i < SH; i += NW) {
- ss[j*TS + i] = 0.0f;
+ ss[j*SH + i] = 0.0f;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
+ float S[NQ] = { [0 ... NQ-1] = 0.0f };
+
{
- float S[Q] = { [0 ... Q-1] = 0.0f };
- float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 };
-
- // thread indices inside the simdgroup
- // TODO: see if we can utilize quad-group functions for better performance
- // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3)
- const short tx = tiisg%4;
- const short ty = tiisg/4;
-
- // broadcast kv
- //const short rk2 = args.ne02/args.ne12;
- //const short rk3 = args.ne03/args.ne13;
-
- const short ikv2 = iq2/(args.ne02/args.ne_12_2);
- const short ikv3 = iq3/(args.ne03/args.ne_12_3);
-
- const bool has_mask = mask != q;
+ float M[NQ] = { [0 ... NQ-1] = -FLT_MAX/2 };
float slope = 1.0f;
// ALiBi
- if (args.max_bias > 0.0f) {
+ if (FC_flash_attn_ext_has_bias) {
const short h = iq2;
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
@@ -4331,177 +4425,277 @@ kernel void kernel_flash_attn_ext(
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
- for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
- const int ic = ic0 + C*sgitg;
- if (ic >= args.ne11) {
- break;
- }
+ for (int ic = 0; ic < args.ne11; ic += C) {
+ // read the mask into shared mem
+ if (FC_flash_attn_ext_has_mask) {
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+ const short j = jj*NSG + sgitg;
- if (has_mask) {
- // used to detect blocks full of -INF
- float smax = -INFINITY;
-
- // load the mask in shared memory
- #pragma unroll(Q)
- for (short j = 0; j < Q; ++j) {
- device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
-
- const float m = pm[ic + tiisg];
-
- ss[j*TS + C + tiisg] = m;
- smax = max(smax, m);
+ sm2[j*SH + tiisg] = pm2[jj][tiisg];
+ pm2[jj] += NW;
}
- smax = simd_max(smax);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // used to detect blocks full of -INF
+ // skip only when the entire threadgroup is masked
+ half2 smax2(-MAXHALF/2, -MAXHALF/2);
+
+ FOR_UNROLL (short j = 0; j < Q; ++j) {
+ smax2 = max(smax2, sm2[j*SH + tiisg]);
+ }
+
+ smax2 = simd_max(smax2);
+
+ if (max(smax2[0], smax2[1]) <= -MAXHALF/2) {
+ // this barrier is important
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- if (smax == -INFINITY) {
continue;
}
}
// Q*K^T
- {
- for (short cc = 0; cc < C/8; ++cc) {
+ // this is compile-time check, so it does not have runtime overhead
+ if (is_same::value) {
+ // we can read directly from global memory
+ device const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11);
+ threadgroup const q_t * pq = sq;
+ threadgroup s_t * ps = ss;
+
+ pk += sgitg*(8*NS10);
+ ps += sgitg*(8*1);
+
+ static_assert((C/8) % NSG == 0, "");
+
+ constexpr short NC = (C/8)/NSG;
+
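+                // the C/8 8x8 blocks of the Q*K^T scores for this KV tile are split across the NSG simdgroups - NC blocks each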
+                // TODO: unrolling this loop hurts performance for large contexts - not sure why
+ for (short cc = 0; cc < NC; ++cc) {
                qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
- // this is compile-time check, so it does not have runtime overhead
-            if (is_same<kd4x4_t, k4x4_t>::value) {
- // we can read directly from global memory
- device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
+ if (DK8 % 16 != 0) {
+ k8x8_t mk;
+ q8x8_t mq;
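+                    // walk the head dimension one 8x8 block of Q and K at a time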
- #pragma unroll(DK8)
- for (short i = 0; i < DK8; ++i) {
- k8x8_t mk;
- simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
+ FOR_UNROLL (short i = 0; i < DK8; ++i) {
+ simdgroup_barrier(mem_flags::mem_none);
+
+ simdgroup_load(mk, pk, NS10, 0, true);
+ simdgroup_load(mq, pq, DK);
+
+ simdgroup_barrier(mem_flags::mem_none);
- q8x8_t mq;
- simdgroup_load(mq, sq + i*8, DK);
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+
+ pk += 8;
+ pq += 8;
}
} else {
- for (short ii = 0; ii < DK16; ii += 4) {
- device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
+ k8x8_t mk[2];
+ q8x8_t mq[2];
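+                    // consume two 8x8 blocks of Q and K per iteration (16 elements along the head dimension)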
- if (DK16%4 == 0) {
- // the head is evenly divisible by 4*16 = 64, so no need for bound checks
- {
- k4x4_t tmp;
- deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
- sk4x4[4*ty + tx] = tmp;
- }
+ FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
+ simdgroup_barrier(mem_flags::mem_none);
- simdgroup_barrier(mem_flags::mem_threadgroup);
+ simdgroup_load(mk[0], pk + 0*8, NS10, 0, true);
+ simdgroup_load(mk[1], pk + 1*8, NS10, 0, true);
- #pragma unroll(4)
- for (short k = 0; k < 4; ++k) {
- k8x8_t mk;
- q8x8_t mq;
+ simdgroup_load(mq[0], pq + 0*8, DK);
+ simdgroup_load(mq[1], pq + 1*8, DK);
- simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
- simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
- simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+ simdgroup_barrier(mem_flags::mem_none);
- simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
- simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
- simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
- }
- } else {
- if (ii + tx < DK16) {
- k4x4_t tmp;
- deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
- sk4x4[4*ty + tx] = tmp;
- }
+ simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk);
+ simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk);
- simdgroup_barrier(mem_flags::mem_threadgroup);
+ pk += 16;
+ pq += 16;
+ }
+ }
- for (short k = 0; k < 4 && ii + k < DK16; ++k) {
- k8x8_t mk;
- q8x8_t mq;
+ simdgroup_store(mqk, ps, SH, 0, false);
- simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
- simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
- simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+ pk += 8*(NSG*NS10 - DK8);
+ pq += 8*(NSG*0 - DK8);
+ ps += 8*(NSG);
+ }
+ } else {
+ // TODO: this is the quantized K cache branch - not optimized yet
+ for (short ccc = 0; ccc < (C/8)/NSG; ++ccc) {
+ const short cc = ccc*NSG + sgitg;
- simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
- simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
- simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
- }
+ const short tx = tiisg%4;
+ const short ty = tiisg/4;
+
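+                // dequantize K into shared memory in 4x4 blocks (sk4x4) and multiply from there as 8x8 simdgroup tiles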
+                qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
+
+ for (short ii = 0; ii < DK16; ii += 4) {
+ device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11));
+
+ if (DK16%4 == 0) {
+ // the head is evenly divisible by 4*16 = 64, so no need for bound checks
+ {
+ k4x4_t tmp;
+ deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+ sk4x4[4*ty + tx] = tmp;
+ }
+
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
+ FOR_UNROLL (short k = 0; k < 4; ++k) {
+ k8x8_t mk;
+ q8x8_t mq;
+
+ simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
+ simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+
+ simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
+ simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+ }
+ } else {
+ if (ii + tx < DK16) {
+ k4x4_t tmp;
+ deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+ sk4x4[4*ty + tx] = tmp;
+ }
+
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (short k = 0; k < 4 && ii + k < DK16; ++k) {
+ k8x8_t mk;
+ q8x8_t mq;
+
+ simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
+ simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+
+ simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
+ simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
}
}
}
- // cast qk_t -> s_t
- //s8x8_t mqks(1.0f);
- //simdgroup_multiply(mqks, mqk, mqks);
- //simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
-
- simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
+ simdgroup_store(mqk, ss + 8*cc, SH, 0, false);
}
}
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
// online softmax
- {
- for (ushort j = 0; j < Q; ++j) {
- const float m = M[j];
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+ const short j = jj*NSG + sgitg;
- // scale and apply the logitcap / mask
- float s = ss[j*TS + tiisg]*args.scale;
+ const float m = M[jj];
- if (args.logit_softcap != 0.0f) {
- s = args.logit_softcap*precise::tanh(s);
+ // scale and apply the logitcap / mask
+ float2 s2 = ss2[j*SH/2 + tiisg]*args.scale;
+
+ if (FC_flash_attn_ext_has_scap) {
+ s2 = args.logit_softcap*precise::tanh(s2);
+ }
+
+ // mqk = mqk + slope*mask
+ if (FC_flash_attn_ext_has_bias) {
+ s2 += s2_t(sm2[j*SH + tiisg])*slope;
+ } else {
+ s2 += s2_t(sm2[j*SH + tiisg]);
+ }
+
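+                // standard online-softmax update for row j:
+                //   M_new = max(M_old, max(s))  - running maximum of the logits
+                //   ms    = exp(M_old - M_new)  - rescales the previous denominator and output rows
+                //   vs    = exp(s - M_new)      - probabilities of the current KV block
+                //   S     = S*ms + sum(vs)      - running denominator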
+ M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));
+
+ const float ms = exp(m - M[jj]);
+ const float2 vs2 = exp(s2 - M[jj]);
+
+ S[jj] = S[jj]*ms + simd_sum(vs2[0] + vs2[1]);
+
+ // the P matrix from the paper (Q rows, C columns)
+ ss2[j*SH/2 + tiisg] = vs2;
+
+ if (DV4 % NW == 0) {
+ FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
+ const short i = ii*NW + tiisg;
+
+ so4[j*PV4 + i] *= ms;
}
-
- // mqk = mqk + mask*slope
- s += slope*ss[j*TS + C + tiisg];
-
- M[j] = simd_max(max(M[j], s));
-
- const float ms = exp(m - M[j]);
- const float vs = exp(s - M[j]);
-
- S[j] = S[j]*ms + simd_sum(vs);
-
- // the P matrix from the paper (Q rows, C columns)
- ss[j*TS + tiisg] = vs;
-
- // create a QxQ diagonal matrix for rescaling the output
- if (tiisg == j) {
- ss[j*TS + 2*C + j] = ms;
+ } else {
+ for (short i = tiisg; i < DV4; i += NW) {
+ so4[j*PV4 + i] *= ms;
}
}
}
- // O = diag(ms)*O
- {
- s8x8_t ms;
- simdgroup_load(ms, ss + 2*C, TS, 0, false);
-
- #pragma unroll(DV8)
- for (short i = 0; i < DV8; ++i) {
- simdgroup_multiply(lo[i], ms, lo[i]);
- }
- }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
// O = O + (Q*K^T)*V
{
- for (short cc = 0; cc < C/8; ++cc) {
- s8x8_t vs;
- simdgroup_load(vs, ss + 8*cc, TS, 0, false);
+ // we can read directly from global memory
+            if (is_same<vd4x4_t, v4x4_t>::value) {
+ static_assert(PV8 % NSG == 0, "");
-                if (is_same<vd4x4_t, v4x4_t>::value) {
- // we can read directly from global memory
- device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
+ constexpr short NO = PV8/NSG;
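+                // each simdgroup accumulates NO of the PV8 8x8 output blocks; the accumulators are
+                // staged in shared memory (so) between iterations over the KV cache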
- #pragma unroll(DV8)
- for (short i = 0; i < DV8; ++i) {
- v8x8_t mv;
- simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
+ o8x8_t lo[NO];
- simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]);
+ {
+ auto sot = so + 8*sgitg;
+
+ FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
+ simdgroup_load(lo[ii], sot, PV, 0, false);
+
+ sot += 8*NSG;
}
- } else {
- for (short ii = 0; ii < DV16; ii += 4) {
- device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
+ }
+
+ {
+ auto sst = ss;
+
+ device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21);
+
+ pv += 8*sgitg;
+
+ FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {
+ s8x8_t vs;
+ simdgroup_load(vs, sst, SH, 0, false);
+
+ FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
+ v8x8_t mv;
+
+ simdgroup_load(mv, pv, NS20, 0, false);
+ simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]);
+
+ pv += 8*NSG;
+ }
+
+ pv += 8*(NS20 - NO*NSG);
+ sst += 8;
+ }
+ }
+
+ {
+ auto sot = so + 8*sgitg;
+
+ FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
+ simdgroup_store(lo[ii], sot, PV, 0, false);
+
+ sot += 8*NSG;
+ }
+ }
+ } else {
+ // TODO: this is the quantized V cache branch - not optimized yet
+
+ const short tx = tiisg%4;
+ const short ty = tiisg/4;
+
+ for (short cc = 0; cc < C/8; ++cc) {
+ s8x8_t vs;
+ simdgroup_load(vs, ss + 8*cc, SH, 0, false);
+
+ for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) {
+ device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21));
if (DV16%4 == 0) {
// no need for bound checks
@@ -4513,15 +4707,20 @@ kernel void kernel_flash_attn_ext(
simdgroup_barrier(mem_flags::mem_threadgroup);
- #pragma unroll(4)
- for (short k = 0; k < 4; ++k) {
- v8x8_t mv;
+ FOR_UNROLL (short k = 0; k < 4; ++k) {
+ v8x8_t mv[2];
+ o8x8_t lo[2];
- simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
+ simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false);
+ simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false);
+ simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
+ simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
- simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
+ simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]);
+ simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]);
+
+ simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
+ simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
}
} else {
if (ii + tx < DV16) {
@@ -4533,243 +4732,249 @@ kernel void kernel_flash_attn_ext(
simdgroup_barrier(mem_flags::mem_threadgroup);
for (short k = 0; k < 4 && ii + k < DV16; ++k) {
- v8x8_t mv;
+ v8x8_t mv[2];
+ o8x8_t lo[2];
- simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
+ simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false);
+ simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false);
+ simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
+ simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
- simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
+ simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]);
+ simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]);
+
+ simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
+ simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
}
}
}
}
}
}
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
}
- if (sinks != q && sgitg == 0) {
- for (ushort j = 0; j < Q; ++j) {
- const float m = M[j];
+ if (FC_flash_attn_ext_has_sinks) {
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+ const short j = jj*NSG + sgitg;
+
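+            // the attention sink acts as one extra logit per head: fold it into the running
+            // max/denominator and rescale the already accumulated output rows accordingly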
+ const float m = M[jj];
const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
- M[j] = simd_max(max(M[j], s));
+ M[jj] = simd_max(max(M[jj], s));
- const float ms = exp(m - M[j]);
- const float vs = exp(s - M[j]);
+ const float ms = exp(m - M[jj]);
+ const float vs = exp(s - M[jj]);
- S[j] = S[j]*ms + simd_sum(vs);
+ S[jj] = S[jj]*ms + simd_sum(vs);
- if (tiisg == j) {
- ss[j*TS + 2*C + j] = ms;
- }
- }
-
- // O = diag(ms)*O
- {
- s8x8_t ms;
- simdgroup_load(ms, ss + 2*C, TS, 0, false);
-
- #pragma unroll(DV8)
- for (short i = 0; i < DV8; ++i) {
- simdgroup_multiply(lo[i], ms, lo[i]);
+ for (short i = tiisg; i < DV4; i += NW) {
+ so4[j*PV4 + i] *= ms;
}
}
}
-
- // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
- for (short j = tiisg; j < Q; j += NW) {
- ss[j*TS + 0] = S[j];
- ss[j*TS + 1] = M[j];
- }
}
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- threadgroup float * so = (threadgroup float *) (shmem_f16 + 0*DK); // reuse query data for accumulation
- threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK);
-
- // store result to shared memory in F32
- if (sgitg == 0) {
- for (short i = 0; i < DV8; ++i) {
- //simdgroup_store(lo[i], so + i*8, DV, 0, false);
- simdgroup_float8x8 t(1.0f);
- simdgroup_multiply(t, lo[i], t);
- simdgroup_store(t, so + i*8, DV, 0, false);
+ // store to global memory
+ for (short jj = 0; jj < NQ; ++jj) {
+ const short j = jj*NSG + sgitg;
+ if (iq1 + j >= args.ne01) {
+ break;
}
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- // reduce the warps sequentially
- for (ushort sg = 1; sg < nsg; ++sg) {
- if (sgitg == sg) {
- for (short j = tiisg; j < Q; j += NW) {
- const float S0 = ss[j*TS - 1*SH + 0];
- const float S1 = ss[j*TS + 0];
-
- const float M0 = ss[j*TS - 1*SH + 1];
- const float M1 = ss[j*TS + 1];
-
- const float M = max(M0, M1);
-
- float ms0 = exp(M0 - M);
- float ms1 = exp(M1 - M);
-
- const float S = S0*ms0 + S1*ms1;
-
- ss[j*TS + 0] = S;
- ss[j*TS + 1] = M;
-
- ss[j*TS + 2*C + j - 1*SH] = ms0;
- ss[j*TS + 2*C + j ] = ms1;
- }
-
- //simdgroup_barrier(mem_flags::mem_threadgroup);
-
- // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
- {
- s8x8_t ms0;
- s8x8_t ms1;
-
- simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false);
- simdgroup_load(ms1, ss + 2*C, TS, 0, false);
-
- #pragma unroll(DV8)
- for (short i = 0; i < DV8; ++i) {
- simdgroup_float8x8 t;
-
- simdgroup_load (t, so + i*8, DV, 0, false);
- simdgroup_multiply(t, ms0, t);
-
- simdgroup_multiply_accumulate(t, ms1, lo[i], t);
- simdgroup_store(t, so + i*8, DV, 0, false);
- }
- }
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
-
- threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK);
-
- // final rescale with 1/S and store to global memory
- for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
- const float S = 1.0f/sf[j*TS + 0];
device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
- for (short i = tiisg; i < DV4; i += NW) {
- dst4[i] = (float4) so4[j*DV4 + i]*S;
+ const float scale = 1.0f/S[jj];
+
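+        // final softmax normalization: scale the accumulated V rows by 1/S before writing them out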
+ if (DV4 % NW == 0) {
+ FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
+ const short i = ii*NW + tiisg;
+
+ dst4[i] = (float4) so4[j*PV4 + i]*scale;
+ }
+ } else {
+ for (short i = tiisg; i < DV4; i += NW) {
+ dst4[i] = (float4) so4[j*PV4 + i]*scale;
+ }
}
}
+
+#undef NS10
+#undef NS20
+}
+
+template<
+ typename q_t, // query types in shared memory
+ typename q4_t,
+ typename q8x8_t,
+ typename k_t, // key types in shared memory
+ typename k4x4_t,
+ typename k8x8_t,
+ typename v_t, // value types in shared memory
+ typename v4x4_t,
+ typename v8x8_t,
+ typename qk_t, // Q*K types
+ typename qk8x8_t,
+ typename s_t, // soft-max types
+ typename s2_t,
+ typename s8x8_t,
+ typename o_t, // attention accumulation types
+ typename o4_t,
+ typename o8x8_t,
+ typename kd4x4_t, // key type in device memory
+ short nl_k,
+ void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
+ typename vd4x4_t, // value type in device memory
+ short nl_v,
+ void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
+ short DK, // K head size
+ short DV, // V head size
+ short Q = 8, // queries per threadgroup
+ short C = 64> // cache items per threadgroup
+kernel void kernel_flash_attn_ext(
+ constant ggml_metal_kargs_flash_attn_ext & args,
+ device const char * q,
+ device const char * k,
+ device const char * v,
+ device const char * mask,
+ device const char * sinks,
+ device char * dst,
+ threadgroup half * shmem_f16 [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+#define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C
+#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
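+    // dispatch on FC_flash_attn_ext_nsg (simdgroups per threadgroup) to a specialization of the implementation kernel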
+ switch (FC_flash_attn_ext_nsg) {
+ // note: disabled cases to reduce library load time
+    //case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
+    //case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
+    case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
+ }
+#undef FWD_TMPL
+#undef FWD_ARGS
}
// TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as
// template to be able to explore different combinations
//
#define FA_TYPES \
- float, float4, simdgroup_float8x8, \
+ half, half4, simdgroup_half8x8, \
half, half4x4, simdgroup_half8x8, \
half, half4x4, simdgroup_half8x8, \
float, simdgroup_float8x8, \
- float, simdgroup_float8x8, \
- half, half4, simdgroup_half8x8
- //float, float4, simdgroup_float8x8
+ float, float2, simdgroup_float8x8, \
+ float, float4, simdgroup_float8x8
+ //half, half4, simdgroup_half8x8
#define FA_TYPES_BF \
bfloat, bfloat4, simdgroup_bfloat8x8, \
bfloat, bfloat4x4, simdgroup_bfloat8x8, \
bfloat, bfloat4x4, simdgroup_bfloat8x8, \
float, simdgroup_float8x8, \
- float, simdgroup_float8x8, \
+ float, float2, simdgroup_float8x8, \
half, half4, simdgroup_half8x8
//float, float4, simdgroup_float8x8
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
-template [[host_name("kernel_flash_attn_ext_f16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_bf16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext