diff --git a/common/arg.cpp b/common/arg.cpp
index 5f7a2c3db..5746728c5 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1547,10 +1547,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL}));
add_opt(common_arg(
- {"-fa", "--flash-attn"},
- string_format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"),
- [](common_params & params) {
- params.flash_attn = true;
+ {"-fa", "--flash-attn"}, "FA",
+ string_format("set Flash Attention use ('on', 'off', or 'auto', default: '%s')", llama_flash_attn_type_name(params.flash_attn_type)),
+ [](common_params & params, const std::string & value) {
+ if (value == "on" || value == "enabled") {
+ params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
+ } else if (value == "off" || value == "disabled") {
+ params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+ } else if (value == "auto") {
+ params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
+ } else {
+ throw std::runtime_error(string_format("error: unkown value for --flash-attn: '%s'\n", value.c_str()));
+ }
}
).set_env("LLAMA_ARG_FLASH_ATTN"));
add_opt(common_arg(
@@ -3461,8 +3469,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-1.5b-q8_0.gguf";
params.port = 8012;
- params.n_gpu_layers = 99;
- params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
@@ -3477,8 +3483,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-3b-q8_0.gguf";
params.port = 8012;
- params.n_gpu_layers = 99;
- params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
@@ -3493,8 +3497,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
params.port = 8012;
- params.n_gpu_layers = 99;
- params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
@@ -3510,10 +3512,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
- params.speculative.n_gpu_layers = 99;
params.port = 8012;
- params.n_gpu_layers = 99;
- params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
@@ -3529,10 +3528,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
- params.speculative.n_gpu_layers = 99;
params.port = 8012;
- params.n_gpu_layers = 99;
- params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
@@ -3547,8 +3543,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model.hf_repo = "ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF";
params.model.hf_file = "qwen3-coder-30b-a3b-instruct-q8_0.gguf";
params.port = 8012;
- params.n_gpu_layers = 99;
- params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
diff --git a/common/chat.cpp b/common/chat.cpp
index 00715b0df..1200b738f 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -622,6 +622,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
+ case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
default:
throw std::runtime_error("Unknown chat format");
}
@@ -2059,6 +2060,94 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) {
}
}
+static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
+ // Parse thinking tags first - this handles the main reasoning content
+ builder.try_parse_reasoning("", "");
+
+ if (!builder.syntax().parse_tool_calls) {
+ builder.add_content(builder.consume_rest());
+ return;
+ }
+
+ // Parse tool calls - Seed-OSS uses format
+ static const common_regex tool_call_begin_regex("");
+ static const common_regex tool_call_end_regex("");
+ static const common_regex function_regex("]+)>");
+ static const common_regex param_regex("]+)>");
+
+ while (auto tool_res = builder.try_find_regex(tool_call_begin_regex)) {
+ builder.consume_spaces(); // Consume whitespace after
+
+ // Look for function call inside tool call, ignore any content before it
+ if (auto func_res = builder.try_find_regex(function_regex, std::string::npos, false)) {
+ auto function_name = builder.str(func_res->groups[1]);
+
+ // Parse Seed-OSS parameters value
+ json args = json::object();
+ // Parse all parameters
+ while (auto param_res = builder.try_find_regex(param_regex, std::string::npos, false)) {
+ // again, ignore noise around parameters
+ auto param_name = builder.str(param_res->groups[1]);
+ builder.move_to(param_res->groups[0].end);
+ builder.consume_spaces(); // Consume whitespace after parameter
+ auto savedPos = builder.pos();
+ if (auto param_parse = builder.try_find_literal("")) {
+ auto param = param_parse->prelude;
+ builder.move_to(savedPos);
+ try {
+ if (auto param_res = builder.try_consume_json()) {
+ args[param_name] = param_res->json;
+ } else {
+ args[param_name] = param;
+ }
+ } catch (json::exception &) {
+ args[param_name] = param;
+ }
+ } else {
+ throw common_chat_msg_partial_exception("Incomplete tool parameter");
+ }
+ }
+ // Look for closing function tag
+ auto end_func = builder.try_find_literal("");
+ if (end_func) {
+ builder.move_to(end_func->groups[0].end);
+ builder.consume_spaces(); // Consume whitespace after
+
+ // Add the tool call with parsed arguments, but only if we REALLY got the literal
+ auto eaten_fragment = builder.input().substr(end_func->groups[0].begin, end_func->groups[0].end);
+ auto funlen = std::string("").length();
+ if (eaten_fragment.length() >= funlen && eaten_fragment.substr(0, funlen) == std::string("")) {
+ if (!builder.add_tool_call(function_name, "", args.dump())) {
+ throw common_chat_msg_partial_exception("Incomplete tool call");
+ }
+ } else {
+ throw common_chat_msg_partial_exception("Incomplete tool call");
+ }
+ } else {
+ throw common_chat_msg_partial_exception("Incomplete tool call");
+ }
+ // Look for closing tool call tag
+ if (auto end_tool = builder.try_find_regex(tool_call_end_regex, std::string::npos, false)) {
+ builder.move_to(end_tool->groups[0].end);
+ builder.consume_spaces(); // Consume trailing whitespace after tool call
+ } else {
+ throw common_chat_msg_partial_exception("Incomplete tool call");
+ }
+ } else {
+ // No function found - don't consume content here, let it be handled at the end
+ break;
+ }
+ }
+
+ // Consume any remaining whitespace after all tool call processing
+ builder.consume_spaces();
+ auto remaining = builder.consume_rest();
+ // If there's any non-whitespace content remaining, add it as content
+ if (!string_strip(remaining).empty()) {
+ builder.add_content(remaining);
+ }
+}
+
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs);
@@ -2075,8 +2164,62 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
return data;
}
+static common_chat_params common_chat_params_init_seed_oss(
+ const common_chat_template & tmpl,
+ templates_params & params,
+ const common_chat_templates_inputs & inputs)
+{
+ common_chat_params data;
+ data.prompt = apply(tmpl, params);
+ data.format = COMMON_CHAT_FORMAT_SEED_OSS;
+ if (string_ends_with(data.prompt, "")) {
+ if (!inputs.enable_thinking) {
+ data.prompt += "";
+ } else {
+ data.thinking_forced_open = true;
+ }
+ }
+
+ if (params.tools.is_array() && !params.tools.empty()) {
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ std::vector tool_rules;
+ foreach_function(params.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ std::string name = function.at("name");
+ auto parameters = function.at("parameters");
+ builder.resolve_refs(parameters);
+
+ // Create rule for Seed-OSS function call format
+ std::string param_rules;
+ if (parameters.contains("properties")) {
+ for (const auto & [key, value] : parameters.at("properties").items()) {
+ param_rules += "\"\"" + builder.add_schema(name + "-arg-" + key, value) +
+ "\"\"";
+ }
+ }
+
+ tool_rules.push_back(builder.add_rule(name + "-call",
+ "\"\" space \"\" space " +
+ param_rules +
+ " \"\" space \"\""));
+ });
+
+ data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "" });
+
+ data.preserved_tokens = {
+ "", "", "", "",
+ "", "",
+ };
+
+ builder.add_rule("root", string_join(tool_rules, " | "));
+ });
+ }
+ return data;
+}
+
static common_chat_params common_chat_templates_apply_jinja(
- const struct common_chat_templates * tmpls,
+ const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs)
{
templates_params params;
@@ -2145,6 +2288,11 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_gpt_oss(tmpl, params);
}
+ // Seed-OSS
+ if (src.find("") != std::string::npos) {
+ return common_chat_params_init_seed_oss(tmpl, params, inputs);
+ }
+
// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -2303,6 +2451,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_GPT_OSS:
common_chat_parse_gpt_oss(builder);
break;
+ case COMMON_CHAT_FORMAT_SEED_OSS:
+ common_chat_parse_seed_oss(builder);
+ break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
diff --git a/common/chat.h b/common/chat.h
index d1e480c91..b09ff3b12 100644
--- a/common/chat.h
+++ b/common/chat.h
@@ -111,6 +111,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_GRANITE,
COMMON_CHAT_FORMAT_GPT_OSS,
+ COMMON_CHAT_FORMAT_SEED_OSS,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
diff --git a/common/common.cpp b/common/common.cpp
index 8d35106ba..0e5057942 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -909,7 +909,8 @@ struct common_init_result common_init_from_params(common_params & params) {
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
if (model == NULL) {
- LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
+ LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
+ __func__, params.model.path.c_str());
return iparams;
}
@@ -919,7 +920,8 @@ struct common_init_result common_init_from_params(common_params & params) {
llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) {
- LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
+ LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
+ __func__, params.model.path.c_str());
llama_model_free(model);
return iparams;
}
@@ -1165,10 +1167,10 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
cparams.pooling_type = params.pooling_type;
cparams.attention_type = params.attention_type;
+ cparams.flash_attn_type = params.flash_attn_type;
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.offload_kqv = !params.no_kv_offload;
- cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf;
cparams.op_offload = !params.no_op_offload;
cparams.swa_full = params.swa_full;
diff --git a/common/common.h b/common/common.h
index b4fe700c1..3a03cf0cd 100644
--- a/common/common.h
+++ b/common/common.h
@@ -308,6 +308,7 @@ struct common_params {
enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
+ enum llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // whether to use Flash Attention
struct common_params_sampling sampling;
struct common_params_speculative speculative;
@@ -371,7 +372,6 @@ struct common_params {
bool multiline_input = false; // reverse the usage of `\`
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly
- bool flash_attn = false; // flash attention
bool no_perf = false; // disable performance metrics
bool ctx_shift = false; // context shift on infinite text generation
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 6c8a03406..df37c4a6e 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -7546,9 +7546,13 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
]
# n_group and d_inner are used during reshape_tensors for mamba2
- self.d_model = self.find_hparam(["hidden_size", "d_model"])
- self.n_group = self.find_hparam(["n_groups"])
- self.d_inner = self.find_hparam(["expand"]) * self.d_model
+ # NOTE: Explicitly include hparam prefix prefix for d_model to
+ # disambiguate with top-level head_dim
+ # NOTE 2: If needed for future models, this can be isolated in a method
+ # to separate the prefix setting and teh keys used
+ self.d_model = self.find_hparam([f"{self.hparam_prefixes[0]}_head_dim", "hidden_size", "d_model"])
+ self.n_group = self.find_hparam(["n_groups", "num_groups"])
+ self.d_inner = self.find_hparam(["expand", "num_heads"]) * self.d_model
def get_attn_layers(self):
# Explicit list of layer type names
@@ -7609,12 +7613,12 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
## Mamba mixer params ##
self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
- self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
+ self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state", "state_dim", "ssm_state_size"]))
self.gguf_writer.add_ssm_group_count(self.n_group)
self.gguf_writer.add_ssm_inner_size(self.d_inner)
# NOTE: The mamba_dt_rank is _not_ the right field for how this is used
# in llama.cpp
- self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
+ self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads", "num_heads"]))
## Attention params ##
head_count_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
@@ -7641,6 +7645,55 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
Mamba2Model.set_vocab(self)
+@ModelBase.register("NemotronHForCausalLM")
+class NemotronHModel(GraniteHybridModel):
+ """Hybrid mamba2/attention model from NVIDIA"""
+ model_arch = gguf.MODEL_ARCH.NEMOTRON_H
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # Save the top-level head_dim for later
+ self.head_dim = self.hparams.get("head_dim", self.hparams.get("attention_head_dim"))
+ assert self.head_dim is not None, "Could not find the attention head dim in config"
+
+ # Don't use expand to calculate d_inner
+ self.d_inner = self.find_hparam(["num_heads"]) * self.d_model
+
+ # Update the ssm / attn / mlp layers
+ # M: Mamba2, *: Attention, -: MLP
+ hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
+ self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"]
+ self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"]
+
+ def get_attn_layers(self):
+ hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
+ assert len(hybrid_override_pattern) == self.block_count, "Mismatch between hybrid override and num_hidden_layers!"
+ return [i for i, val in enumerate(hybrid_override_pattern) if val == "*"]
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+
+ self.gguf_writer.add_key_length(self.head_dim)
+ self.gguf_writer.add_value_length(self.head_dim)
+
+ # Set feed_forward_length
+ # NOTE: This will trigger an override warning. This is preferrable to
+ # duplicating all the parent logic
+ n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
+ self.gguf_writer.add_feed_forward_length([
+ n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
+ ])
+
+ def set_vocab(self):
+ super().set_vocab()
+
+ # The tokenizer _does_ add a BOS token (via post_processor type
+ # TemplateProcessing) but does not set add_bos_token to true in the
+ # config, so we need to explicitly override it here.
+ self.gguf_writer.add_add_bos_token(True)
+
+
@ModelBase.register("BailingMoeForCausalLM")
class BailingMoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.BAILINGMOE
diff --git a/examples/diffusion/diffusion-cli.cpp b/examples/diffusion/diffusion-cli.cpp
index 8431dcea8..abf7fb357 100644
--- a/examples/diffusion/diffusion-cli.cpp
+++ b/examples/diffusion/diffusion-cli.cpp
@@ -564,7 +564,7 @@ int main(int argc, char ** argv) {
ctx_params.n_ctx = params.n_ctx;
ctx_params.n_batch = params.n_batch;
ctx_params.n_ubatch = params.n_ubatch;
- ctx_params.flash_attn = params.flash_attn;
+ ctx_params.flash_attn_type = params.flash_attn_type;
ctx_params.no_perf = params.no_perf;
ctx_params.type_k = params.cache_type_k;
ctx_params.type_v = params.cache_type_v;
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 3fa8d2090..fa3efe976 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -31,6 +31,7 @@
// backend buffer type
const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) {
+ GGML_ASSERT(buft);
return buft->iface.get_name(buft);
}
@@ -40,14 +41,17 @@ ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t
return ggml_backend_buffer_init(buft, {}, NULL, 0);
}
+ GGML_ASSERT(buft);
return buft->iface.alloc_buffer(buft, size);
}
size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) {
+ GGML_ASSERT(buft);
return buft->iface.get_alignment(buft);
}
size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {
+ GGML_ASSERT(buft);
// get_max_size is optional, defaults to SIZE_MAX
if (buft->iface.get_max_size) {
return buft->iface.get_max_size(buft);
@@ -56,6 +60,7 @@ size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {
}
size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
+ GGML_ASSERT(buft);
// get_alloc_size is optional, defaults to ggml_nbytes
if (buft->iface.get_alloc_size) {
size_t size = buft->iface.get_alloc_size(buft, tensor);
@@ -66,6 +71,7 @@ size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const s
}
bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) {
+ GGML_ASSERT(buft);
if (buft->iface.is_host) {
return buft->iface.is_host(buft);
}
@@ -73,6 +79,7 @@ bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) {
}
ggml_backend_dev_t ggml_backend_buft_get_device(ggml_backend_buffer_type_t buft) {
+ GGML_ASSERT(buft);
return buft->device;
}
@@ -110,10 +117,12 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
}
size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
+ GGML_ASSERT(buffer);
return buffer->size;
}
void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
+ GGML_ASSERT(buffer);
// get_base is optional if the buffer is zero-sized
if (buffer->size == 0) {
return NULL;
@@ -127,6 +136,7 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
}
enum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+ GGML_ASSERT(buffer);
// init_tensor is optional
if (buffer->iface.init_tensor) {
return buffer->iface.init_tensor(buffer, tensor);
@@ -135,6 +145,7 @@ enum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, s
}
void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ GGML_ASSERT(buffer);
// clear is optional if the buffer is zero-sized
if (buffer->size == 0) {
return;
@@ -160,6 +171,7 @@ bool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) {
}
void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
+ GGML_ASSERT(buffer);
buffer->usage = usage;
// FIXME: add a generic callback to the buffer interface
@@ -169,14 +181,17 @@ void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backe
}
enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage(ggml_backend_buffer_t buffer) {
+ GGML_ASSERT(buffer);
return buffer->usage;
}
ggml_backend_buffer_type_t ggml_backend_buffer_get_type(ggml_backend_buffer_t buffer) {
+ GGML_ASSERT(buffer);
return buffer->buft;
}
void ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) {
+ GGML_ASSERT(buffer);
if (buffer->iface.reset) {
buffer->iface.reset(buffer);
}
@@ -215,6 +230,7 @@ void ggml_backend_free(ggml_backend_t backend) {
}
ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) {
+ GGML_ASSERT(backend);
return ggml_backend_dev_buffer_type(backend->device);
}
@@ -231,6 +247,8 @@ size_t ggml_backend_get_max_size(ggml_backend_t backend) {
}
void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ GGML_ASSERT(backend);
+ GGML_ASSERT(tensor);
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
@@ -242,6 +260,8 @@ void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor *
}
void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ GGML_ASSERT(backend);
+ GGML_ASSERT(tensor);
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
@@ -283,6 +303,7 @@ void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, siz
}
void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+ GGML_ASSERT(tensor);
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
if (size == 0) {
@@ -298,6 +319,7 @@ void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size
}
void ggml_backend_synchronize(ggml_backend_t backend) {
+ GGML_ASSERT(backend);
if (backend->iface.synchronize == NULL) {
return;
}
@@ -306,18 +328,21 @@ void ggml_backend_synchronize(ggml_backend_t backend) {
}
ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.graph_plan_create != NULL);
return backend->iface.graph_plan_create(backend, cgraph);
}
void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+ GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.graph_plan_free != NULL);
backend->iface.graph_plan_free(backend, plan);
}
enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+ GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.graph_plan_compute != NULL);
return backend->iface.graph_plan_compute(backend, plan);
@@ -330,22 +355,27 @@ enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_
}
enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ GGML_ASSERT(backend);
return backend->iface.graph_compute(backend, cgraph);
}
bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+ GGML_ASSERT(backend);
return ggml_backend_dev_supports_op(backend->device, op);
}
bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+ GGML_ASSERT(backend);
return ggml_backend_dev_supports_buft(backend->device, buft);
}
bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+ GGML_ASSERT(backend);
return ggml_backend_dev_offload_op(backend->device, op);
}
ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {
+ GGML_ASSERT(backend);
return backend->device;
}
@@ -381,6 +411,7 @@ void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t b
return;
}
+ GGML_ASSERT(backend_dst);
if (backend_dst->iface.cpy_tensor_async != NULL) {
if (backend_dst->iface.cpy_tensor_async(backend_src, backend_dst, src, dst)) {
return;
@@ -412,18 +443,21 @@ void ggml_backend_event_free(ggml_backend_event_t event) {
}
void ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend) {
+ GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.event_record != NULL);
backend->iface.event_record(backend, event);
}
void ggml_backend_event_synchronize(ggml_backend_event_t event) {
+ GGML_ASSERT(event);
GGML_ASSERT(event->device->iface.event_synchronize);
event->device->iface.event_synchronize(event->device, event);
}
void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+ GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.event_wait != NULL);
backend->iface.event_wait(backend, event);
@@ -432,18 +466,22 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)
// Backend device
const char * ggml_backend_dev_name(ggml_backend_dev_t device) {
+ GGML_ASSERT(device);
return device->iface.get_name(device);
}
const char * ggml_backend_dev_description(ggml_backend_dev_t device) {
+ GGML_ASSERT(device);
return device->iface.get_description(device);
}
void ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
+ GGML_ASSERT(device);
device->iface.get_memory(device, free, total);
}
enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) {
+ GGML_ASSERT(device);
return device->iface.get_type(device);
}
@@ -453,18 +491,22 @@ void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_d
}
ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device) {
+ GGML_ASSERT(device);
return device->reg;
}
ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params) {
+ GGML_ASSERT(device);
return device->iface.init_backend(device, params);
}
ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
+ GGML_ASSERT(device);
return device->iface.get_buffer_type(device);
}
ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device) {
+ GGML_ASSERT(device);
if (device->iface.get_host_buffer_type == NULL) {
return NULL;
}
@@ -473,18 +515,22 @@ ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t
}
ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size) {
+ GGML_ASSERT(device);
return device->iface.buffer_from_host_ptr(device, ptr, size, max_tensor_size);
}
bool ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op) {
+ GGML_ASSERT(device);
return device->iface.supports_op(device, op);
}
bool ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft) {
+ GGML_ASSERT(device);
return device->iface.supports_buft(device, buft);
}
bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op) {
+ GGML_ASSERT(device);
if (device->iface.offload_op != NULL) {
return device->iface.offload_op(device, op);
}
@@ -495,18 +541,22 @@ bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_te
// Backend (reg)
const char * ggml_backend_reg_name(ggml_backend_reg_t reg) {
+ GGML_ASSERT(reg);
return reg->iface.get_name(reg);
}
size_t ggml_backend_reg_dev_count(ggml_backend_reg_t reg) {
+ GGML_ASSERT(reg);
return reg->iface.get_device_count(reg);
}
ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index) {
+ GGML_ASSERT(reg);
return reg->iface.get_device(reg, index);
}
void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+ GGML_ASSERT(reg);
if (!reg->iface.get_proc_address) {
return NULL;
}
@@ -521,6 +571,7 @@ struct ggml_backend_multi_buffer_context {
};
static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ GGML_ASSERT(buffer);
ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
for (size_t i = 0; i < ctx->n_buffers; i++) {
ggml_backend_buffer_free(ctx->buffers[i]);
@@ -531,6 +582,7 @@ static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer)
}
static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ GGML_ASSERT(buffer);
ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
for (size_t i = 0; i < ctx->n_buffers; i++) {
ggml_backend_buffer_clear(ctx->buffers[i], value);
@@ -566,10 +618,12 @@ ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer
}
bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer) {
+ GGML_ASSERT(buffer);
return buffer->iface.free_buffer == ggml_backend_multi_buffer_free_buffer;
}
void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
+ GGML_ASSERT(buffer);
GGML_ASSERT(ggml_backend_buffer_is_multi_buffer(buffer));
ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
for (size_t i = 0; i < ctx->n_buffers; i++) {
@@ -1355,6 +1409,7 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
}
static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
+ GGML_ASSERT(sched);
struct ggml_backend_sched_split * splits = sched->splits;
ggml_tensor * prev_ids_tensor = nullptr;
@@ -1623,6 +1678,7 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
}
void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
+ GGML_ASSERT(sched);
// reset state for the next run
if (!sched->is_reset) {
ggml_hash_set_reset(&sched->hash_set);
@@ -1634,6 +1690,7 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
}
bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
+ GGML_ASSERT(sched);
GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
ggml_backend_sched_synchronize(sched);
@@ -1650,6 +1707,7 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
}
bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+ GGML_ASSERT(sched);
GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs);
GGML_ASSERT(!sched->is_alloc);
@@ -1674,6 +1732,7 @@ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, st
}
enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+ GGML_ASSERT(sched);
if (!sched->is_reset && !sched->is_alloc) {
ggml_backend_sched_reset(sched);
}
@@ -1688,6 +1747,7 @@ enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sch
}
void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
+ GGML_ASSERT(sched);
for (int i = 0; i < sched->n_backends; i++) {
ggml_backend_synchronize(sched->backends[i]);
}
@@ -1700,28 +1760,34 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
}
void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
+ GGML_ASSERT(sched);
sched->callback_eval = callback;
sched->callback_eval_user_data = user_data;
}
int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
+ GGML_ASSERT(sched);
return sched->n_splits;
}
int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) {
+ GGML_ASSERT(sched);
return sched->n_copies;
}
int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched) {
+ GGML_ASSERT(sched);
return sched->n_backends;
}
ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i) {
+ GGML_ASSERT(sched);
GGML_ASSERT(i >= 0 && i < sched->n_backends);
return sched->backends[i];
}
size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {
+ GGML_ASSERT(sched);
int backend_index = ggml_backend_sched_backend_id(sched, backend);
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
@@ -1729,6 +1795,7 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe
}
void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
+ GGML_ASSERT(sched);
int backend_index = ggml_backend_sched_backend_id(sched, backend);
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
tensor_backend_id(node) = backend_index;
@@ -1737,6 +1804,7 @@ void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct gg
}
ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) {
+ GGML_ASSERT(sched);
int backend_index = tensor_backend_id(node);
if (backend_index == -1) {
return NULL;
@@ -1747,6 +1815,7 @@ ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched,
// utils
enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor) {
+ GGML_ASSERT(tensor);
GGML_ASSERT(tensor->buffer == NULL);
GGML_ASSERT(tensor->view_src != NULL);
GGML_ASSERT(tensor->view_src->buffer != NULL);
@@ -1758,6 +1827,7 @@ enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor) {
}
enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
+ GGML_ASSERT(tensor);
GGML_ASSERT(tensor->buffer == NULL);
GGML_ASSERT(tensor->data == NULL);
GGML_ASSERT(tensor->view_src == NULL);
@@ -1831,6 +1901,7 @@ static void graph_copy_init_tensor(struct ggml_hash_set * hash_set, struct ggml_
}
struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) {
+ GGML_ASSERT(graph);
struct ggml_hash_set hash_set = ggml_hash_set_new(graph->visited_hash_set.size);
struct ggml_tensor ** node_copies = (ggml_tensor **) calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT
bool * node_init = (bool *) calloc(hash_set.size, sizeof(node_init[0]));
@@ -1975,6 +2046,7 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
// CPU backend - buffer
static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
+ GGML_ASSERT(buffer);
uintptr_t data = (uintptr_t)buffer->context;
// align the buffer
@@ -1986,28 +2058,33 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
}
static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ GGML_ASSERT(buffer);
ggml_aligned_free(buffer->context, buffer->size);
}
static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+ GGML_ASSERT(tensor);
memset((char *)tensor->data + offset, value, size);
GGML_UNUSED(buffer);
}
static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ GGML_ASSERT(tensor);
memcpy((char *)tensor->data + offset, data, size);
GGML_UNUSED(buffer);
}
static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ GGML_ASSERT(tensor);
memcpy(data, (const char *)tensor->data + offset, size);
GGML_UNUSED(buffer);
}
static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+ GGML_ASSERT(src);
if (ggml_backend_buffer_is_host(src->buffer)) {
memcpy(dst->data, src->data, ggml_nbytes(src));
return true;
@@ -2018,6 +2095,7 @@ static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
}
static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ GGML_ASSERT(buffer);
memset(buffer->context, value, buffer->size);
}
diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h
index 1f6844e16..e08c30a34 100644
--- a/ggml/src/ggml-cpu/ggml-cpu-impl.h
+++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h
@@ -489,7 +489,7 @@ inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {
/**
* @see https://github.com/ggml-org/llama.cpp/pull/14037
*/
-inline float vec_hsum(float32x4_t v) {
+inline static float vec_hsum(float32x4_t v) {
float32x4_t v_temp = v + vec_reve(v);
return v_temp[0] + v_temp[1];
}
diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu
index e1fbf0e13..1c7656634 100644
--- a/ggml/src/ggml-cuda/binbcast.cu
+++ b/ggml/src/ggml-cuda/binbcast.cu
@@ -1,5 +1,6 @@
#include "binbcast.cuh"
#include
+#include
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
return b;
@@ -22,13 +23,16 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
return a / b;
}
-template
+
+
+template
static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
- int ne0, int ne1, int ne2, int ne3,
- int ne10, int ne11, int ne12, int ne13,
- /*int s0, */ int s1, int s2, int s3,
- /*int s00,*/ int s01, int s02, int s03,
- /*int s10,*/ int s11, int s12, int s13) {
+ const int ne0, const int ne1, const int ne2, const int ne3,
+ const int ne10, const int ne11, const int ne12, const int ne13,
+ /*int s0, */ const int s1, const int s2, const int s3,
+ /*int s00,*/ const int s01, const int s02, const int s03,
+ /*int s10,*/ const int s11, const int s12, const int s13,
+ src1_ptrs... src1s) {
const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
@@ -46,24 +50,31 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
- const src0_t * src0_row = src0 + i_src0;
- const src1_t * src1_row = src1 + i_src1;
+ const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
const int i10 = i0 % ne10;
- dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+
+ float result = src0_row ? (float) src0_row[i0] : 0.0f;
+ if constexpr (sizeof...(src1_ptrs) > 0) {
+ result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+ } else {
+ result = bin_op(result, (float)src1[i_src1 + i10]);
+ }
+
+ dst_row[i0] = (dst_t) result;
}
}
-template
-static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
- int ne0, int ne1, int ne2, int ne3,
- int ne10, int ne11, int ne12, int ne13,
- /*int s0, */ int s1, int s2, int s3,
- /*int s00,*/ int s01, int s02, int s03,
- /*int s10,*/ int s11, int s12, int s13) {
-
+template
+static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+ const int ne0, const int ne1, const int ne2,const int ne3,
+ const int ne10, const int ne11, const int ne12, const int ne13,
+ /*int s0, */ const int s1, const int s2, const int s3,
+ /*int s00,*/ const int s01, const int s02, const int s03,
+ /*int s10,*/ const int s11, const int s12, const int s13,
+ src1_ptrs ... src1s) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
const int i3 = i/(ne2*ne1*ne0);
@@ -83,12 +94,190 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
- const src0_t * src0_row = src0 + i_src0;
- const src1_t * src1_row = src1 + i_src1;
+ const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;
const int i10 = i0 % ne10;
- dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+
+ float result = src0_row ? (float) src0_row[i0] : 0.0f;
+ if constexpr (sizeof...(src1_ptrs) > 0) {
+ result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+ } else {
+ result = bin_op(result, (float)src1[i_src1 + i10]);
+ }
+
+ dst_row[i0] = (dst_t) result;
+}
+
+template
+static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
+ cudaStream_t stream, std::index_sequence) {
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ int nr0 = ne10 / ne0;
+ int nr1 = ne11 / ne1;
+ int nr2 = ne12 / ne2;
+ int nr3 = ne13 / ne3;
+
+ int nr[4] = { nr0, nr1, nr2, nr3 };
+
+ int64_t cne[] = { ne0, ne1, ne2, ne3 };
+ int64_t cne0[] = { ne00, ne01, ne02, ne03 };
+ int64_t cne1[] = { ne10, ne11, ne12, ne13 };
+
+ size_t cnb[] = { nb0, nb1, nb2, nb3 };
+ size_t cnb0[] = { nb00, nb01, nb02, nb03 };
+ size_t cnb1[] = { nb10, nb11, nb12, nb13 };
+
+ auto collapse = [](int64_t cne[]) {
+ cne[0] *= cne[1];
+ cne[1] = cne[2];
+ cne[2] = cne[3];
+ cne[3] = 1;
+ };
+
+ auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
+ cnb[1] *= cne[1];
+ cnb[2] *= cne[2];
+ cnb[3] *= cne[3];
+ };
+
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+ for (int i = 0; i < 4; i++) {
+ if (nr[i] != 1) {
+ break;
+ }
+ if (i > 0) {
+ collapse_nb(cnb, cne);
+ collapse_nb(cnb0, cne0);
+ collapse_nb(cnb1, cne1);
+ collapse(cne);
+ collapse(cne0);
+ collapse(cne1);
+ }
+ }
+ }
+
+ {
+ int64_t ne0 = cne[0];
+ int64_t ne1 = cne[1];
+ int64_t ne2 = cne[2];
+ int64_t ne3 = cne[3];
+
+ //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
+ //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
+ //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
+ //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
+
+ int64_t ne10 = cne1[0];
+ int64_t ne11 = cne1[1];
+ int64_t ne12 = cne1[2];
+ int64_t ne13 = cne1[3];
+
+ size_t nb0 = cnb[0];
+ size_t nb1 = cnb[1];
+ size_t nb2 = cnb[2];
+ size_t nb3 = cnb[3];
+
+ size_t nb00 = cnb0[0];
+ size_t nb01 = cnb0[1];
+ size_t nb02 = cnb0[2];
+ size_t nb03 = cnb0[3];
+
+ size_t nb10 = cnb1[0];
+ size_t nb11 = cnb1[1];
+ size_t nb12 = cnb1[2];
+ size_t nb13 = cnb1[3];
+
+ size_t s0 = nb0 / sizeof(dst_t);
+ size_t s1 = nb1 / sizeof(dst_t);
+ size_t s2 = nb2 / sizeof(dst_t);
+ size_t s3 = nb3 / sizeof(dst_t);
+
+ size_t s10 = nb10 / sizeof(src1_t);
+ size_t s11 = nb11 / sizeof(src1_t);
+ size_t s12 = nb12 / sizeof(src1_t);
+ size_t s13 = nb13 / sizeof(src1_t);
+
+ size_t s00 = nb00 / sizeof(src0_t);
+ size_t s01 = nb01 / sizeof(src0_t);
+ size_t s02 = nb02 / sizeof(src0_t);
+ size_t s03 = nb03 / sizeof(src0_t);
+
+ GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
+ GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
+ GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
+ GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
+
+ GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
+ GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
+ GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
+ GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
+
+ GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
+ GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
+ GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
+ GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
+
+ GGML_ASSERT(s0 == 1);
+ GGML_ASSERT(s00 == 1);
+ GGML_ASSERT(s10 == 1);
+
+ const int block_size = 128;
+
+ int64_t hne0 = std::max(ne0 / 2LL, 1LL);
+
+ dim3 block_dims;
+ block_dims.x = std::min(hne0, block_size);
+ block_dims.y = std::min(ne1, block_size / block_dims.x);
+ block_dims.z = std::min(std::min(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);
+
+ dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x,
+ (ne1 + block_dims.y - 1) / block_dims.y,
+ (ne2 * ne3 + block_dims.z - 1) / block_dims.z);
+
+ if (block_nums.z > 65535) {
+ int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
+ if constexpr (sizeof...(I) > 0) {
+ k_bin_bcast_unravel
+ <<>>(src0_dd, src1_dd, dst_dd,
+ ne0, ne1, ne2, ne3,
+ ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00,*/ s01, s02, s03,
+ /* s10,*/ s11, s12,s13,
+ (const src1_t *) dst->src[I + 1]->data...);
+ } else {
+ k_bin_bcast_unravel
+ <<>>(src0_dd, src1_dd, dst_dd,
+ ne0, ne1, ne2, ne3,
+ ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00,*/ s01, s02, s03,
+ /* s10,*/ s11, s12,s13);
+ }
+ } else {
+ if constexpr (sizeof...(I) > 0) {
+ k_bin_bcast
+ <<>>(src0_dd, src1_dd, dst_dd,
+ ne0, ne1, ne2, ne3,
+ ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00,*/ s01, s02, s03,
+ /* s10,*/ s11, s12,s13,
+ (const src1_t *) dst->src[I + 1]->data...);
+ } else {
+ k_bin_bcast
+ <<>>(src0_dd, src1_dd, dst_dd,
+ ne0, ne1, ne2, ne3,
+ ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00,*/ s01, s02, s03,
+ /* s10,*/ s11, s12,s13);
+ }
+ }
+ }
}
template
@@ -120,160 +309,14 @@ static __global__ void k_repeat_back(
dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
}
-template
+template
struct bin_bcast_cuda {
template
void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
cudaStream_t stream) {
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- int nr0 = ne10/ne0;
- int nr1 = ne11/ne1;
- int nr2 = ne12/ne2;
- int nr3 = ne13/ne3;
-
- int nr[4] = { nr0, nr1, nr2, nr3 };
-
- // collapse dimensions until first broadcast dimension
- int64_t cne[] = {ne0, ne1, ne2, ne3};
- int64_t cne0[] = {ne00, ne01, ne02, ne03};
- int64_t cne1[] = {ne10, ne11, ne12, ne13};
-
- size_t cnb[] = {nb0, nb1, nb2, nb3};
- size_t cnb0[] = {nb00, nb01, nb02, nb03};
- size_t cnb1[] = {nb10, nb11, nb12, nb13};
-
- auto collapse = [](int64_t cne[]) {
- cne[0] *= cne[1];
- cne[1] = cne[2];
- cne[2] = cne[3];
- cne[3] = 1;
- };
-
- auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
- cnb[1] *= cne[1];
- cnb[2] *= cne[2];
- cnb[3] *= cne[3];
- };
-
- if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
- for (int i = 0; i < 4; i++) {
- if (nr[i] != 1) {
- break;
- }
- if (i > 0) {
- collapse_nb(cnb, cne);
- collapse_nb(cnb0, cne0);
- collapse_nb(cnb1, cne1);
- collapse(cne);
- collapse(cne0);
- collapse(cne1);
- }
- }
- }
-
- {
- int64_t ne0 = cne[0];
- int64_t ne1 = cne[1];
- int64_t ne2 = cne[2];
- int64_t ne3 = cne[3];
-
- //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
- //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
- //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
- //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
-
- int64_t ne10 = cne1[0];
- int64_t ne11 = cne1[1];
- int64_t ne12 = cne1[2];
- int64_t ne13 = cne1[3];
-
- size_t nb0 = cnb[0];
- size_t nb1 = cnb[1];
- size_t nb2 = cnb[2];
- size_t nb3 = cnb[3];
-
- size_t nb00 = cnb0[0];
- size_t nb01 = cnb0[1];
- size_t nb02 = cnb0[2];
- size_t nb03 = cnb0[3];
-
- size_t nb10 = cnb1[0];
- size_t nb11 = cnb1[1];
- size_t nb12 = cnb1[2];
- size_t nb13 = cnb1[3];
-
- size_t s0 = nb0 / sizeof(dst_t);
- size_t s1 = nb1 / sizeof(dst_t);
- size_t s2 = nb2 / sizeof(dst_t);
- size_t s3 = nb3 / sizeof(dst_t);
-
- size_t s10 = nb10 / sizeof(src1_t);
- size_t s11 = nb11 / sizeof(src1_t);
- size_t s12 = nb12 / sizeof(src1_t);
- size_t s13 = nb13 / sizeof(src1_t);
-
- size_t s00 = nb00 / sizeof(src0_t);
- size_t s01 = nb01 / sizeof(src0_t);
- size_t s02 = nb02 / sizeof(src0_t);
- size_t s03 = nb03 / sizeof(src0_t);
-
- GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
- GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
- GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
- GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
-
- GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
- GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
- GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
- GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
-
- GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
- GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
- GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
- GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
-
- GGML_ASSERT(s0 == 1);
- GGML_ASSERT(s00 == 1);
- GGML_ASSERT(s10 == 1);
-
- const int block_size = 128;
-
- int64_t hne0 = std::max(ne0/2LL, 1LL);
-
- dim3 block_dims;
- block_dims.x = std::min(hne0, block_size);
- block_dims.y = std::min(ne1, block_size / block_dims.x);
- block_dims.z = std::min(std::min(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U);
-
- dim3 block_nums(
- (hne0 + block_dims.x - 1) / block_dims.x,
- (ne1 + block_dims.y - 1) / block_dims.y,
- (ne2*ne3 + block_dims.z - 1) / block_dims.z
- );
-
- if (block_nums.z > 65535) {
- // this is the maximum number of blocks in z dimension, fallback to 1D grid kernel
- int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
- k_bin_bcast_unravel<<>>(
- src0_dd, src1_dd, dst_dd,
- ne0, ne1, ne2, ne3,
- ne10, ne11, ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00, */ s01, s02, s03,
- /* s10, */ s11, s12, s13);
- } else {
- k_bin_bcast<<>>(
- src0_dd, src1_dd, dst_dd,
- ne0, ne1, ne2, ne3,
- ne10, ne11, ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00, */ s01, s02, s03,
- /* s10, */ s11, s12, s13);
- }
- }
+ launch_bin_bcast_pack(
+ src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence{});
}
};
@@ -312,7 +355,7 @@ static void ggml_cuda_op_bin_bcast(
}
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- ggml_cuda_op_bin_bcast>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
+ ggml_cuda_op_bin_bcast>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
}
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -331,6 +374,68 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
+template
+static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ cudaStream_t stream = ctx.stream();
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ launch_bin_bcast_pack(src0, src1, dst,
+ (const float *) src0->data, (const float *) src1->data, (float *) dst->data,
+ stream, std::make_index_sequence{});
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+ launch_bin_bcast_pack(src0, src1, dst,
+ (const half *) src0->data, (const half *) src1->data, (half *) dst->data,
+ stream, std::make_index_sequence{});
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+ launch_bin_bcast_pack(src0, src1, dst,
+ (const half *) src0->data, (const float *) src1->data, (half *) dst->data,
+ stream, std::make_index_sequence{});
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+ launch_bin_bcast_pack(src0, src1, dst,
+ (const half *) src0->data, (const float *) src1->data, (float *) dst->data,
+ stream, std::make_index_sequence{});
+ } else {
+ fprintf(stderr,
+ "%s: unsupported types for fusion: dst: %s, src0: %s, src1: %s\n",
+ __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+ GGML_ABORT("fatal error");
+ }
+}
+
+
+void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
+ GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);
+
+ switch (n_fuse) {
+ case 2:
+ ggml_cuda_op_fused_binbcast_impl(ctx, dst);
+ break;
+ case 3:
+ ggml_cuda_op_fused_binbcast_impl(ctx, dst);
+ break;
+ case 4:
+ ggml_cuda_op_fused_binbcast_impl(ctx, dst);
+ break;
+ case 5:
+ ggml_cuda_op_fused_binbcast_impl(ctx, dst);
+ break;
+ case 6:
+ ggml_cuda_op_fused_binbcast_impl(ctx, dst);
+ break;
+ case 7:
+ ggml_cuda_op_fused_binbcast_impl(ctx, dst);
+ break;
+ case 8:
+ ggml_cuda_op_fused_binbcast_impl(ctx, dst);
+ break;
+ default:
+ GGML_ASSERT(false && "Unsupported n_fuse value");
+ }
+}
+
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
diff --git a/ggml/src/ggml-cuda/binbcast.cuh b/ggml/src/ggml-cuda/binbcast.cuh
index 3ac1c9b03..62bc95011 100644
--- a/ggml/src/ggml-cuda/binbcast.cuh
+++ b/ggml/src/ggml-cuda/binbcast.cuh
@@ -7,3 +7,5 @@ void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);
diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu
new file mode 100644
index 000000000..bcb70762e
--- /dev/null
+++ b/ggml/src/ggml-cuda/conv2d.cu
@@ -0,0 +1,165 @@
+#include "conv2d.cuh"
+
+struct conv_params {
+ const int64_t IW, IH;
+ const int64_t OW, OH;
+ const int64_t KW, KH;
+ const int64_t ST_X, ST_Y;
+ const int64_t PD_X, PD_Y;
+ const int64_t DL_X, DL_Y;
+ const int64_t IC, OC;
+ const int64_t B;
+ const int64_t TOTAL;
+};
+
+struct kernel_bounds {
+ int64_t y_min, y_max;
+ int64_t x_min, x_max;
+};
+
+__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {
+ return (a > b) ? a : b;
+}
+
+__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {
+ return (a < b) ? a : b;
+}
+
+__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {
+ kernel_bounds bounds;
+ bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
+ bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
+ bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
+ bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
+ return bounds;
+}
+
+__device__ __forceinline__ int calculate_input_coord(int64_t out_coord,
+ int64_t kern_coord,
+ int64_t stride,
+ int64_t dilation,
+ int64_t padding) {
+ return out_coord * stride + kern_coord * dilation - padding;
+}
+
+struct whcn_layout {
+ __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
+ return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x;
+ }
+
+ __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) {
+ return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx;
+ }
+
+ __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
+ return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x;
+ }
+
+ __device__ static void unpack_indices(int64_t global_idx,
+ const conv_params & P,
+ int64_t & n,
+ int64_t & c,
+ int64_t & out_y,
+ int64_t & out_x) {
+ out_x = global_idx % P.OW;
+ out_y = (global_idx / P.OW) % P.OH;
+ c = (global_idx / (P.OW * P.OH)) % P.OC;
+ n = global_idx / (P.OW * P.OH * P.OC);
+ }
+};
+
+template
+static __global__ void conv2d_kernel(const float * __restrict__ input,
+ const T * __restrict__ kernel,
+ float * __restrict__ output,
+ const conv_params P) {
+ const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (global_idx >= P.TOTAL) {
+ return;
+ }
+
+ int64_t n, c_out, out_y, out_x;
+ Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);
+
+ float acc = 0.0f;
+
+ for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
+ kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
+
+ for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {
+ const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y);
+
+ for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
+ const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
+
+ const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
+ const float kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
+ acc += (input_val * kernel_val);
+ }
+ }
+ }
+
+ // [N, OC, OH, OW]
+ output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc;
+}
+
+template
+static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
+ const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
+ conv2d_kernel<<>>(X_D, K_D, Y_D, P);
+}
+
+static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
+ conv2d_cuda(X_D, K_D, Y_D, P, st);
+}
+
+static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
+ conv2d_cuda(X_D, K_D, Y_D, P, st);
+}
+
+void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * kernel = dst->src[0];
+ const ggml_tensor * input = dst->src[1];
+ float * K_D = (float *) kernel->data;
+ const float * X_D = (const float *) input->data;
+ float * Y_D = (float *) dst->data;
+
+ GGML_ASSERT(ggml_is_contiguous(kernel));
+ GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);
+
+ // same number of input channels
+ GGML_ASSERT(input->ne[2] == kernel->ne[2]);
+
+ cudaStream_t st = ctx.stream();
+
+ const int32_t * p = (const int32_t *) dst->op_params;
+ const int ST_X = p[0]; // stride_x
+ const int ST_Y = p[1]; // stride_y
+ const int PD_X = p[2]; // padding_x
+ const int PD_Y = p[3]; // padding_y
+ const int DL_X = p[4]; // dilation_x
+ const int DL_Y = p[5]; // dilation_y
+
+ // No cwhn
+ GGML_ASSERT(p[6] == false);
+
+ const int IW = input->ne[0]; // input_w
+ const int IH = input->ne[1]; // input_h
+ const int OW = dst->ne[0]; // output_w
+ const int OH = dst->ne[1]; // output_h
+ const int KW = kernel->ne[0]; // kernel_w
+ const int KH = kernel->ne[1]; // kernel_h
+ const int IC = input->ne[2]; // input_channels
+ const int OC = kernel->ne[3]; // ouptut_chanles
+ const int B = input->ne[3]; // n_batches
+
+ const int64_t total = B * OC * OH * OW;
+ conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };
+
+ if (kernel->type == GGML_TYPE_F16) {
+ conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
+ } else {
+ conv2d_cuda_f32(X_D, K_D, Y_D, params, st);
+ }
+}
diff --git a/ggml/src/ggml-cuda/conv2d.cuh b/ggml/src/ggml-cuda/conv2d.cuh
new file mode 100644
index 000000000..ce4802c7e
--- /dev/null
+++ b/ggml/src/ggml-cuda/conv2d.cuh
@@ -0,0 +1,5 @@
+#pragma once
+#include "common.cuh"
+
+#define CUDA_CONV2D_BLOCK_SIZE 256
+void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 646af9231..a08117bfe 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -14,6 +14,7 @@ bool g_mul_mat_q = true;
#include "ggml-cuda/clamp.cuh"
#include "ggml-cuda/concat.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"
+#include "ggml-cuda/conv2d.cuh"
#include "ggml-cuda/conv2d-dw.cuh"
#include "ggml-cuda/conv2d-transpose.cuh"
#include "ggml-cuda/convert.cuh"
@@ -2464,6 +2465,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_IM2COL:
ggml_cuda_op_im2col(ctx, dst);
break;
+ case GGML_OP_CONV_2D:
+ ggml_cuda_op_conv2d(ctx, dst);
+ break;
case GGML_OP_CONV_2D_DW:
ggml_cuda_op_conv2d_dw(ctx, dst);
break;
@@ -2830,9 +2834,14 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
return false;
}
- if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+ if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
const ggml_tensor *mul = cgraph->nodes[node_idx+1];
+ const ggml_tensor *add = nullptr;
+
+ if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
+ add = cgraph->nodes[node_idx+2];
+ }
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
@@ -2844,6 +2853,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
return false;
}
+ if (add && (add->src[0]->type != GGML_TYPE_F32 ||
+ add->src[1]->type != GGML_TYPE_F32 ||
+ add->type != GGML_TYPE_F32) ) {
+ return false;
+ }
+
//if rms norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
return false;
@@ -2854,6 +2869,10 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
return false;
}
+ if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) {
+ return false;
+ }
+
return true;
}
@@ -2900,7 +2919,46 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion) {
- if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
+
+ if (node->op == GGML_OP_ADD) {
+ int n_fuse = 0;
+ ggml_op ops[8];
+ std::fill(ops, ops + 8, GGML_OP_ADD);
+
+ for (; n_fuse <= 6; ++n_fuse){
+ if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
+ break;
+ }
+ if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
+ break;
+ }
+ if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
+ break;
+ }
+ }
+
+ n_fuse++;
+
+ if (n_fuse > 1) {
+ for (int j = 0; j < n_fuse - 1; ++j) {
+ node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
+ }
+ cgraph->nodes[i + n_fuse - 1]->data = node->data;
+ ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
+ i += n_fuse - 1;
+
+ continue;
+ }
+ }
+
+
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
+ ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+ i += 2;
+ continue;
+ }
+
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
i++;
continue;
@@ -3514,6 +3572,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
}
case GGML_OP_IM2COL:
+ case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D:
diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
index bddcca51b..d5157d958 100644
--- a/ggml/src/ggml-cuda/norm.cu
+++ b/ggml/src/ggml-cuda/norm.cu
@@ -104,12 +104,30 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
}
}
-template
-static __global__ void rms_norm_f32(
- const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
- const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0,
- const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0,
- const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) {
+template
+static __global__ void rms_norm_f32(const float * x, float * dst,
+ const int ncols,
+ const int64_t stride_row,
+ const int64_t stride_channel,
+ const int64_t stride_sample,
+ const float eps,
+ const float * mul = nullptr,
+ const int64_t mul_stride_row = 0,
+ const int64_t mul_stride_channel = 0,
+ const int64_t mul_stride_sample = 0,
+ const int mul_ncols = 0,
+ const int mul_nrows = 0,
+ const int mul_nchannels = 0,
+ const int mul_nsamples = 0,
+ const float * add = nullptr,
+ const int64_t add_stride_row = 0,
+ const int64_t add_stride_channel = 0,
+ const int64_t add_stride_sample = 0,
+ const int add_ncols = 0,
+ const int add_nrows = 0,
+ const int add_nchannels = 0,
+ const int add_nsamples = 0) {
+
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
@@ -118,6 +136,8 @@ static __global__ void rms_norm_f32(
const int sample = blockIdx.z;
const int tid = threadIdx.x;
+ static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");
+
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
@@ -128,6 +148,13 @@ static __global__ void rms_norm_f32(
mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
}
+ if constexpr (do_add) {
+ const int add_row = row % add_nrows;
+ const int add_channel = channel % add_nchannels;
+ const int add_sample = sample % add_nsamples;
+ add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
+ }
+
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
@@ -154,7 +181,11 @@ static __global__ void rms_norm_f32(
const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += block_size) {
- if constexpr (do_multiply) {
+ if constexpr (do_multiply && do_add) {
+ const int mul_col = col % mul_ncols;
+ const int add_col = col % add_ncols;
+ dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
+ } else if constexpr (do_multiply) {
const int mul_col = col % mul_ncols;
dst[col] = scale * x[col] * mul[mul_col];
} else {
@@ -331,23 +362,70 @@ static void rms_norm_f32_cuda(
}
}
-static void rms_norm_mul_f32_cuda(
- const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
- const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
- const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
- const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
- const float eps, cudaStream_t stream) {
+static void rms_norm_mul_f32_cuda(const float * x,
+ const float * mul,
+ const float * add,
+ float * dst,
+ const int ncols,
+ const int nrows,
+ const int nchannels,
+ const int nsamples,
+ const int64_t stride_row,
+ const int64_t stride_channel,
+ const int64_t stride_sample,
+ const int64_t mul_stride_row,
+ const int64_t mul_stride_channel,
+ const int64_t mul_stride_sample,
+ const int mul_ncols,
+ const int mul_nrows,
+ const int mul_nchannels,
+ const int mul_nsamples,
+ const int64_t add_stride_row,
+ const int64_t add_stride_channel,
+ const int64_t add_stride_sample,
+ const int add_ncols,
+ const int add_nrows,
+ const int add_nchannels,
+ const int add_nsamples,
+ const float eps,
+ cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (mul == nullptr) {
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
return;
}
- if (ncols < 1024) {
- const dim3 block_dims(WARP_SIZE, 1, 1);
- rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+ if (add == nullptr) {
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ rms_norm_f32<<>>(x, dst,
+ ncols, stride_row, stride_channel, stride_sample, eps,
+ mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ rms_norm_f32<1024, true><<>>(x, dst,
+ ncols, stride_row, stride_channel, stride_sample, eps,
+ mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+ }
} else {
- const dim3 block_dims(1024, 1, 1);
- rms_norm_f32<1024, true><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ rms_norm_f32<<>>(x, dst,
+ ncols, stride_row, stride_channel, stride_sample, eps,
+ mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+ add, add_stride_row, add_stride_channel, add_stride_sample,
+ add_ncols, add_nrows, add_nchannels, add_nsamples);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ rms_norm_f32<1024, true, true><<>>(x, dst,
+ ncols, stride_row, stride_channel, stride_sample, eps,
+ mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+ add, add_stride_row, add_stride_channel, add_stride_sample,
+ add_ncols, add_nrows, add_nchannels, add_nsamples);
+ }
}
}
@@ -491,7 +569,102 @@ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor *
const int mul_nchannels = mul_src->ne[2];
const int mul_nsamples = mul_src->ne[3];
- rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream);
+ rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d,
+ ne00, ne01, ne02, ne03,
+ /*s00*/ s01, s02, s03,
+ /*mul_s00*/ mul_s01, mul_s02, mul_s03,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+ /*add_s00*/ 0, 0, 0,
+ 0, 0, 0, 0,
+ eps, stream);
+}
+
+void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
+ ggml_tensor * dst,
+ ggml_tensor * mul_tensor,
+ ggml_tensor * add_tensor) {
+ const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
+ float eps = 0.0f;
+
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ const float * src0_d = (const float *) rms_norm_src->data;
+ const float * mul_d = nullptr;
+ const ggml_tensor * mul_src = nullptr;
+
+ if (mul_tensor->src[0] == dst) {
+ mul_d = (float *) mul_tensor->src[1]->data;
+ mul_src = mul_tensor->src[1];
+ } else if (mul_tensor->src[1] == dst) {
+ mul_d = (float *) mul_tensor->src[0]->data;
+ mul_src = mul_tensor->src[0];
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ const float * add_d = nullptr;
+ const ggml_tensor * add_src = nullptr;
+
+ if (add_tensor->src[0] == mul_tensor) {
+ add_d = (float *) add_tensor->src[1]->data;
+ add_src = add_tensor->src[1];
+ } else if (add_tensor->src[1] == mul_tensor) {
+ add_d = (float *) add_tensor->src[0]->data;
+ add_src = add_tensor->src[0];
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ float * dst_d = (float *) add_tensor->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
+ GGML_ASSERT(add_tensor->type == GGML_TYPE_F32);
+ GGML_ASSERT(eps >= 0.0f);
+
+ const int64_t ne00 = rms_norm_src->ne[0];
+ const int64_t ne01 = rms_norm_src->ne[1];
+ const int64_t ne02 = rms_norm_src->ne[2];
+ const int64_t ne03 = rms_norm_src->ne[3];
+
+ const size_t ts0 = ggml_type_size(rms_norm_src->type);
+ GGML_ASSERT(rms_norm_src->nb[0] == ts0);
+ const int64_t s01 = rms_norm_src->nb[1] / ts0;
+ const int64_t s02 = rms_norm_src->nb[2] / ts0;
+ const int64_t s03 = rms_norm_src->nb[3] / ts0;
+
+ const size_t ts_mul = ggml_type_size(mul_src->type);
+ GGML_ASSERT(mul_src->nb[0] == ts_mul);
+ const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
+ const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
+ const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
+
+ const int mul_ncols = mul_src->ne[0];
+ const int mul_nrows = mul_src->ne[1];
+ const int mul_nchannels = mul_src->ne[2];
+ const int mul_nsamples = mul_src->ne[3];
+
+ const size_t ts_add = ggml_type_size(add_src->type);
+ GGML_ASSERT(add_src->nb[0] == ts_add);
+ const int64_t add_s01 = add_src->nb[1] / ts_add;
+ const int64_t add_s02 = add_src->nb[2] / ts_add;
+ const int64_t add_s03 = add_src->nb[3] / ts_add;
+
+ const int add_ncols = add_src->ne[0];
+ const int add_nrows = add_src->ne[1];
+ const int add_nchannels = add_src->ne[2];
+ const int add_nsamples = add_src->ne[3];
+
+ rms_norm_mul_f32_cuda(src0_d, mul_d,add_d,dst_d,
+ ne00,ne01, ne02, ne03,
+ /*s00*/ s01, s02, s03,
+ /*mul_s00*/ mul_s01, mul_s02, mul_s03,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+ /*add_s00*/ add_s01, add_s02, add_s03,
+ add_ncols, add_nrows, add_nchannels, add_nsamples,
+ eps, stream);
}
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh
index 7ea7bd4df..a74f63767 100644
--- a/ggml/src/ggml-cuda/norm.cuh
+++ b/ggml/src/ggml-cuda/norm.cuh
@@ -8,6 +8,11 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor);
+void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
+ ggml_tensor * dst,
+ ggml_tensor * mul_tensor,
+ ggml_tensor * add_tensor);
+
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 9e39e5c3c..264601cbb 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -582,6 +582,7 @@ struct vk_device_struct {
bool disable_fusion;
bool disable_host_visible_vidmem;
+ bool allow_sysmem_fallback;
#ifdef GGML_VULKAN_MEMORY_DEBUG
std::unique_ptr memory_logger;
@@ -1824,8 +1825,8 @@ static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_pr
return UINT32_MAX;
}
-static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
- VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags) << ")");
+static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list) {
+ VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")");
if (size > device->max_memory_allocation_size) {
printf("\nWARNING: Requested buffer size (%zu) exceeds device memory allocation limit (%zu)!\n",size,device->max_memory_allocation_size);
}
@@ -1852,42 +1853,27 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
- uint32_t memory_type_index = UINT32_MAX;
+ for (auto &req_flags : req_flags_list) {
+ uint32_t memory_type_index = find_properties(&mem_props, &mem_req, req_flags);
- memory_type_index = find_properties(&mem_props, &mem_req, req_flags);
- buf->memory_property_flags = req_flags;
+ if (memory_type_index == UINT32_MAX) {
+ continue;
+ }
+ buf->memory_property_flags = req_flags;
- if (memory_type_index == UINT32_MAX && fallback_flags) {
- memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags);
- buf->memory_property_flags = fallback_flags;
+ try {
+ buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
+ break;
+ } catch (const vk::SystemError& e) {
+ // loop and retry
+ }
}
- if (memory_type_index == UINT32_MAX) {
+ if (buf->device_memory == VK_NULL_HANDLE) {
device->device.destroyBuffer(buf->buffer);
throw vk::OutOfDeviceMemoryError("No suitable memory type found");
}
- try {
- buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
- } catch (const vk::SystemError& e) {
- if (buf->memory_property_flags != fallback_flags) {
- // Try again with fallback flags
- memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags);
- buf->memory_property_flags = fallback_flags;
-
- try {
- buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
- }
- catch (const vk::SystemError& e) {
- device->device.destroyBuffer(buf->buffer);
- throw e;
- }
- } else {
- // Out of Host/Device memory, clean up buffer
- device->device.destroyBuffer(buf->buffer);
- throw e;
- }
- }
buf->ptr = nullptr;
if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
@@ -1908,7 +1894,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor
static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
try {
- return ggml_vk_create_buffer(device, size, req_flags, fallback_flags);
+ return ggml_vk_create_buffer(device, size, {req_flags, fallback_flags});
} catch (const vk::SystemError& e) {
std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl;
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
@@ -1920,15 +1906,29 @@ static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
vk_buffer buf;
try {
if (device->prefer_host_memory) {
- buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
+ vk::MemoryPropertyFlagBits::eDeviceLocal});
} else if (device->uma) {
// Fall back to host memory type
- buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
+ buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal,
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
} else if (device->disable_host_visible_vidmem) {
- buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ if (device->allow_sysmem_fallback) {
+ buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal,
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
+ } else {
+ buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal});
+ }
} else {
// use rebar if available, otherwise fallback to device only visible memory
- buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ if (device->allow_sysmem_fallback) {
+ buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
+ vk::MemoryPropertyFlagBits::eDeviceLocal,
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
+ } else {
+ buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
+ vk::MemoryPropertyFlagBits::eDeviceLocal});
+ }
}
} catch (const vk::SystemError& e) {
std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
@@ -2241,7 +2241,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_mmq_wg_denoms_k = { 32, 64, 1 };
// spec constants and tile sizes for quant matmul_id
- l_warptile_mmqid = { 256, 128, 128, 16, 0, device->subgroup_size };
+ l_warptile_mmqid = { 256, 128, 128, 16, 1, device->subgroup_size };
m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
l_mmqid_wg_denoms = { 128, 128, 1 };
@@ -3459,6 +3459,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->disable_host_visible_vidmem = true; //kcpp requested fix for vulkan BSOD on Nvidia
}
+ const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK");
+ device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr;
+
bool fp16_storage = false;
bool fp16_compute = false;
bool maintenance4_support = false;
@@ -4804,8 +4807,8 @@ static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_
static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
vk_buffer buf = ggml_vk_create_buffer(device, size,
- vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
- vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
+ {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n",
@@ -5830,11 +5833,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
ggml_vk_sync_buffers(ctx, subctx);
}
}
- if (y_non_contig || quantize_y) {
- if (ctx->prealloc_y_need_sync) {
- ggml_vk_sync_buffers(ctx, subctx);
- }
- }
if (x_non_contig) {
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
@@ -5846,6 +5844,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (y_non_contig) {
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
@@ -5854,6 +5855,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (quantize_y) {
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
@@ -6038,11 +6042,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
ggml_vk_sync_buffers(ctx, subctx);
}
}
- if (y_non_contig) {
- if (ctx->prealloc_y_need_sync) {
- ggml_vk_sync_buffers(ctx, subctx);
- }
- }
if (x_non_contig) {
GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
@@ -6052,6 +6051,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
@@ -6484,11 +6486,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
ggml_vk_sync_buffers(ctx, subctx);
}
}
- if (y_non_contig) {
- if (ctx->prealloc_y_need_sync) {
- ggml_vk_sync_buffers(ctx, subctx);
- }
- }
if (x_non_contig) {
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
@@ -6501,6 +6498,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (y_non_contig) {
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
@@ -6698,11 +6698,6 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
ggml_vk_sync_buffers(ctx, subctx);
}
}
- if (y_non_contig) {
- if (ctx->prealloc_y_need_sync) {
- ggml_vk_sync_buffers(ctx, subctx);
- }
- }
if (x_non_contig) {
GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
@@ -6712,6 +6707,9 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
@@ -7881,6 +7879,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
break;
case GGML_OP_GET_ROWS:
elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
+ elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
+ elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
break;
case GGML_OP_ARGSORT:
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
@@ -9217,7 +9217,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
if (ctx->prealloc_split_k != nullptr) {
ggml_vk_destroy_buffer(ctx->prealloc_split_k);
}
- ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal});
}
}
@@ -9227,9 +9227,9 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
ggml_pipeline_allocate_descriptor_sets(ctx);
- vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
- vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
- vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});
+ vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});
+ vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});
X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne);
Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne);
@@ -9455,8 +9455,8 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
float * x = (float *) malloc(x_sz);
void * qx = malloc(qx_sz);
- vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
- vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
+ vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, {vk::MemoryPropertyFlagBits::eDeviceLocal});
float * x_ref = (float *) malloc(x_sz);
ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16);
@@ -9561,8 +9561,8 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
// float * x = (float *) malloc(x_sz);
// block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz);
// block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz);
-// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
-// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
+// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
+// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
//
// for (size_t i = 0; i < ne; i++) {
// x[i] = rand() / (float)RAND_MAX;
@@ -9709,10 +9709,10 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
float * x = (float *) malloc(x_sz);
float * y = (float *) malloc(y_sz);
void * qx = malloc(qx_sz);
- vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
- vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
- vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
- vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
+ vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
+ vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
+ vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
float * d = (float *) malloc(d_sz);
float * d_chk = (float *) malloc(d_sz);
@@ -9739,7 +9739,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
if (ctx->prealloc_split_k != nullptr) {
ggml_vk_destroy_buffer(ctx->prealloc_split_k);
}
- ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal});
}
}
if (mmq) {
@@ -12047,16 +12047,13 @@ static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions) {
#ifdef __APPLE__
- bool portability_enumeration_ext = false;
// Check for portability enumeration extension for MoltenVK support
for (const auto& properties : instance_extensions) {
if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
return true;
}
}
- if (!portability_enumeration_ext) {
- std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
- }
+ std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
#endif
return false;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
index d40848e15..482445c6f 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
@@ -334,6 +334,9 @@ void main() {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] *= Lfrcp[r];
+#if defined(ACC_TYPE_MAX)
+ Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
+#endif
}
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
index 97c2a5412..63b32171b 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
@@ -373,6 +373,9 @@ void main() {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= ACC_TYPE(Lfrcp[r]);
+#if defined(ACC_TYPE_MAX)
+ Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+#endif
}
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
index 77ae5ff01..ab647e9bc 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
@@ -283,6 +283,10 @@ void main() {
O = Ldiag*O;
+#if defined(ACC_TYPE_MAX)
+ [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
coopmat O_D = coopmat(O);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
index 76ef4b6df..06e83822f 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
@@ -111,6 +111,10 @@ void main() {
}
}
O *= L;
+
+ const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);
+ O = clamp(O, -FLT_MAX, FLT_MAX);
+
data_d[iq3 * D * N + D * n + d] = O;
}
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp
index ee6b86a18..7ef75cd7a 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp
@@ -7,27 +7,36 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint i00 = gl_GlobalInvocationID.x;
- const uint i10 = gl_GlobalInvocationID.y;
- const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
- const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
if (i00 >= p.ne00) {
return;
}
- const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
+ uint gid_z = gl_GlobalInvocationID.z;
+ while (gid_z < p.ne11 * p.ne12) {
+ uint gid_y = gl_GlobalInvocationID.y;
+ while (gid_y < p.ne10) {
+ const uint i10 = gid_y;
+ const uint i11 = gid_z / p.ne12;
+ const uint i12 = gid_z % p.ne12;
- const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
- const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
+ const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
+
+ const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
+ const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
#if defined(DATA_A_BF16)
- FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
+ FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
#else
- FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
+ FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
#endif
#ifndef OPTIMIZATION_ERROR_WORKAROUND
- data_d[d_offset + i00] = D_TYPE(v);
+ data_d[d_offset + i00] = D_TYPE(v);
#else
- data_d[d_offset + i00] = D_TYPE(v);
+ data_d[d_offset + i00] = D_TYPE(v);
#endif
+ gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+ }
+ gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z;
+ }
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp
index cfd645a38..339f905fc 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp
@@ -10,9 +10,6 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint i00 = (gl_GlobalInvocationID.x)*2;
- const uint i10 = gl_GlobalInvocationID.y;
- const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
- const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
@@ -22,20 +19,33 @@ void main() {
return;
}
- const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
+ uint gid_z = gl_GlobalInvocationID.z;
+ while (gid_z < p.ne11 * p.ne12) {
+ uint gid_y = gl_GlobalInvocationID.y;
+ while (gid_y < p.ne10) {
+ const uint i10 = gid_y;
+ const uint i11 = gid_z / p.ne12;
+ const uint i12 = gid_z % p.ne12;
- const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
- const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
+ const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
- const uint ib = a_offset + i00/QUANT_K; // block index
- const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index
- const uint iybs = i00 - i00%QUANT_K; // dst block start index
- const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
+ const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
+ const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
- vec2 v = dequantize(ib, iqs, 0);
- const vec2 dm = get_dm(ib, 0);
- v = v * dm.x + dm.y;
+ const uint ib = a_offset + i00/QUANT_K; // block index
+ const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index
+ const uint iybs = i00 - i00%QUANT_K; // dst block start index
+ const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
- data_d[d_offset + iybs + iqs ] = D_TYPE(v.x);
- data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
+ vec2 v = dequantize(ib, iqs, 0);
+ const vec2 dm = get_dm(ib, 0);
+ v = v * dm.x + dm.y;
+
+ data_d[d_offset + iybs + iqs ] = D_TYPE(v.x);
+ data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
+
+ gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+ }
+ gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z;
+ }
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
index 5ecf68a64..7e10e99e9 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
@@ -891,6 +891,20 @@ void main() {
barrier();
}
+#if defined(ACC_TYPE_MAX)
+#ifdef COOPMAT
+ [[unroll]] for (uint j = 0; j < cms_per_row * cms_per_col; j++) {
+ [[unroll]] for (uint i = 0; i < sums[j].length(); ++i) {
+ sums[j][i] = clamp(sums[j][i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+ }
+ }
+#else
+ [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
+ sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+ }
+#endif
+#endif
+
const uint dr = ir * BM + warp_r * WM;
const uint dc = ic * BN + warp_c * WN;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
index f5aebf6e9..654105a49 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
@@ -349,6 +349,10 @@ void main() {
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
+#if defined(ACC_TYPE_MAX)
+ [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
coopmat mat_d = coopmat(sum);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
@@ -388,6 +392,10 @@ void main() {
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
+#if defined(ACC_TYPE_MAX)
+ [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
coopmat mat_d = coopmat(sum);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
@@ -428,6 +436,10 @@ void main() {
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
+#if defined(ACC_TYPE_MAX)
+ [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
coopmat mat_d = coopmat(sum);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
@@ -444,18 +456,105 @@ void main() {
tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
- coopmat sum;
- sum = coopmat(0.0);
-
uint k_iters = (end_k - start_k + BK - 1) / BK;
fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
+ store_scales(tid);
+
+#ifdef MUL_MAT_ID
+ if (enable_smaller_matrices && ic * BN + BNover4 >= _ne1) {
+ coopmat sum;
+ sum = coopmat(0.0);
+
+ [[dont_unroll]]
+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+ if ((block_k % QUANT_K) == 0) {
+ store_scales(tid);
+ }
+ if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
+ }
+
+ if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
+ coopmat mat_a;
+ coopmat mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ } else {
+ coopmat mat_a;
+ coopmat mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ }
+ }
+
+ // Convert from ACC_TYPE to D_TYPE
+ coopmat mat_d;
+ mat_d = coopmat(sum);
+
+ // Call callback to store each element, remapping row through shared memory
+ coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
+ return;
+ }
+ if (enable_smaller_matrices && ic * BN + BNover2 >= _ne1) {
+ coopmat sum;
+ sum = coopmat(0.0);
+
+ [[dont_unroll]]
+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+ if ((block_k % QUANT_K) == 0) {
+ store_scales(tid);
+ }
+ if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
+ }
+
+ if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
+ coopmat mat_a;
+ coopmat mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ } else {
+ coopmat mat_a;
+ coopmat mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ }
+ }
+
+ // Convert from ACC_TYPE to D_TYPE
+ coopmat mat_d;
+ mat_d = coopmat(sum);
+
+ // Call callback to store each element, remapping row through shared memory
+ coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
+ return;
+ }
+#endif
+ coopmat sum;
+ sum = coopmat(0.0);
[[dont_unroll]]
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
- store_scales(tid);
- if (block_k + BK < end_k) {
+ if ((block_k % QUANT_K) == 0) {
+ store_scales(tid);
+ }
+ if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
}
@@ -485,6 +584,9 @@ void main() {
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
}
+#if defined(ACC_TYPE_MAX)
+ [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
// Convert from ACC_TYPE to D_TYPE
coopmat mat_d;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index 05e396889..d7c8a83da 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -337,6 +337,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
+ if (f16acc) {
+ base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
+ }
if (coopmat) {
base_dict["COOPMAT"] = "1";
@@ -451,8 +454,12 @@ void process_shaders() {
// flash attention
for (const auto& f16acc : {false, true}) {
- std::string acctype = f16acc ? "float16_t" : "float";
- std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
+ std::map fa_base_dict = base_dict;
+ fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
+ fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
+ if (f16acc) {
+ fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
+ }
for (const auto& tname : type_names) {
if (tname == "f32") {
@@ -463,30 +470,30 @@ void process_shaders() {
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
- merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
+ merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
} else {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
- merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
+ merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
}
#endif
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
- merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
+ merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
- merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
+ merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
}
#endif
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
- merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
+ merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
- merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
+ merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
}
}
}
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index a581f9601..6156d35c2 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -367,6 +367,7 @@ class MODEL_ARCH(IntEnum):
T5ENCODER = auto()
JAIS = auto()
NEMOTRON = auto()
+ NEMOTRON_H = auto()
EXAONE = auto()
EXAONE4 = auto()
GRANITE = auto()
@@ -700,6 +701,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.T5ENCODER: "t5encoder",
MODEL_ARCH.JAIS: "jais",
MODEL_ARCH.NEMOTRON: "nemotron",
+ MODEL_ARCH.NEMOTRON_H: "nemotron_h",
MODEL_ARCH.EXAONE: "exaone",
MODEL_ARCH.EXAONE4: "exaone4",
MODEL_ARCH.GRANITE: "granite",
@@ -2297,6 +2299,25 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
+ MODEL_ARCH.NEMOTRON_H: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.SSM_IN,
+ MODEL_TENSOR.SSM_CONV1D,
+ MODEL_TENSOR.SSM_DT,
+ MODEL_TENSOR.SSM_A,
+ MODEL_TENSOR.SSM_D,
+ MODEL_TENSOR.SSM_NORM,
+ MODEL_TENSOR.SSM_OUT,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
MODEL_ARCH.EXAONE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index abb21fa82..497f48809 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -191,6 +191,7 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.q_proj", # llama4
"model.transformer.blocks.{bid}.q_proj", # llada
"layers.{bid}.self_attn.q_proj", # qwen3-embedding
+ "backbone.layers.{bid}.mixer.q_proj", # nemotron-h
),
# Attention key
@@ -209,6 +210,7 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.k_proj", # llama4
"model.transformer.blocks.{bid}.k_proj", # llada
"layers.{bid}.self_attn.k_proj", # qwen3-embedding
+ "backbone.layers.{bid}.mixer.k_proj", # nemotron-h
),
# Attention value
@@ -226,6 +228,7 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.v_proj", # llama4
"model.transformer.blocks.{bid}.v_proj", # llada
"layers.{bid}.self_attn.v_proj", # qwen3-embedding
+ "backbone.layers.{bid}.mixer.v_proj", # nemotron-h
),
# Attention output
@@ -260,6 +263,7 @@ class TensorNameMap:
"transformer_encoder.{bid}.wo", # neobert
"model.transformer.blocks.{bid}.attn_out", # llada
"layers.{bid}.self_attn.o_proj", # qwen3-embedding
+ "backbone.layers.{bid}.mixer.o_proj", # nemotron-h
),
# Attention output norm
@@ -387,6 +391,7 @@ class TensorNameMap:
"model.layers.{bid}.block_sparse_moe.up", # smallthinker
"model.transformer.blocks.{bid}.up_proj", # llada
"layers.{bid}.mlp.up_proj", # qwen3-embedding
+ "backbone.layers.{bid}.mixer.up_proj", # nemotron-h
),
MODEL_TENSOR.FFN_UP_EXP: (
@@ -480,6 +485,7 @@ class TensorNameMap:
"model.layers.{bid}.block_sparse_moe.down", # smallthinker
"model.transformer.blocks.{bid}.ff_out", # llada
"layers.{bid}.mlp.down_proj", # qwen3-embedding
+ "backbone.layers.{bid}.mixer.down_proj", # nemotron-h
),
MODEL_TENSOR.FFN_DOWN_EXP: (
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 22654b8fa..8a79d1895 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -607,7 +607,7 @@ static void speculative_decoding_setup(std::string spec_model_filename, const ll
draft_ctx_params.n_ubatch = base_ctx_params.n_ubatch;
draft_ctx_params.n_threads = base_ctx_params.n_threads;
draft_ctx_params.n_threads_batch = base_ctx_params.n_threads_batch;
- draft_ctx_params.flash_attn = base_ctx_params.flash_attn;
+ draft_ctx_params.flash_attn_type = base_ctx_params.flash_attn_type;
draft_ctx_params.type_k = base_ctx_params.type_k;
draft_ctx_params.type_v = base_ctx_params.type_v;
draft_ctx_params.swa_full = base_ctx_params.swa_full;
@@ -2401,7 +2401,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
llamamodel->vocab.set_eos_bos(0,0);
}
- llama_ctx_params.flash_attn = kcpp_data->flash_attn;
+ llama_ctx_params.flash_attn_type = (kcpp_data->flash_attn?LLAMA_FLASH_ATTN_TYPE_ENABLED:LLAMA_FLASH_ATTN_TYPE_DISABLED);
llama_ctx_params.swa_full = kcpp_data->swa_full;
llama_ctx_params.type_k = (inputs.quant_k>1?GGML_TYPE_Q4_0:(inputs.quant_k==1?GGML_TYPE_Q8_0:GGML_TYPE_F16));
llama_ctx_params.type_v = (inputs.quant_v>1?GGML_TYPE_Q4_0:(inputs.quant_v==1?GGML_TYPE_Q8_0:GGML_TYPE_F16));
diff --git a/include/llama.h b/include/llama.h
index 4b8cd8d5a..35c687070 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -182,6 +182,14 @@ extern "C" {
LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1,
};
+ enum llama_flash_attn_type {
+ LLAMA_FLASH_ATTN_TYPE_AUTO = -1,
+ LLAMA_FLASH_ATTN_TYPE_DISABLED = 0,
+ LLAMA_FLASH_ATTN_TYPE_ENABLED = 1,
+ };
+
+ LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type);
+
enum llama_split_mode {
LLAMA_SPLIT_MODE_NONE = 0, // single GPU
LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
@@ -306,6 +314,7 @@ extern "C" {
enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
enum llama_attention_type attention_type; // attention type to use for embeddings
+ enum llama_flash_attn_type flash_attn_type; // when to enable Flash Attention
// ref: https://github.com/ggml-org/llama.cpp/pull/2054
float rope_freq_base; // RoPE base frequency, 0 = from model
@@ -332,7 +341,6 @@ extern "C" {
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
- bool flash_attn; // use flash attention [EXPERIMENTAL]
bool no_perf; // measure performance timings
bool op_offload; // offload host tensor operations to device
bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
diff --git a/otherarch/embeddings_adapter.cpp b/otherarch/embeddings_adapter.cpp
index e9a57e031..8a63c9e57 100644
--- a/otherarch/embeddings_adapter.cpp
+++ b/otherarch/embeddings_adapter.cpp
@@ -136,7 +136,7 @@ bool embeddingstype_load_model(const embeddings_load_model_inputs inputs)
ctx_params.offload_kqv = false;
ctx_params.n_threads = nthreads;
ctx_params.n_threads_batch = nthreads;
- ctx_params.flash_attn = inputs.flash_attention;
+ ctx_params.flash_attn_type = (inputs.flash_attention?LLAMA_FLASH_ATTN_TYPE_ENABLED:LLAMA_FLASH_ATTN_TYPE_DISABLED);
ctx_params.kv_unified = true;
embeddings_ctx = llama_init_from_model(embeddingsmodel, ctx_params);
diff --git a/otherarch/tts_adapter.cpp b/otherarch/tts_adapter.cpp
index 6703bfc10..fc372c1ed 100644
--- a/otherarch/tts_adapter.cpp
+++ b/otherarch/tts_adapter.cpp
@@ -695,7 +695,7 @@ bool ttstype_load_model(const tts_load_model_inputs inputs)
tts_ctx_params.n_ubatch = 512;
tts_ctx_params.n_threads = nthreads;
tts_ctx_params.n_threads_batch = nthreads;
- tts_ctx_params.flash_attn = inputs.flash_attention;
+ tts_ctx_params.flash_attn_type = (inputs.flash_attention?LLAMA_FLASH_ATTN_TYPE_ENABLED:LLAMA_FLASH_ATTN_TYPE_DISABLED);
tts_ctx_params.kv_unified = true;
llama_model * ttcmodel = llama_model_load_from_file(modelfile_ttc.c_str(), tts_model_params);
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index a61dc177a..d5c8477f4 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -69,6 +69,7 @@ static const std::map LLM_ARCH_NAMES = {
{ LLM_ARCH_T5ENCODER, "t5encoder" },
{ LLM_ARCH_JAIS, "jais" },
{ LLM_ARCH_NEMOTRON, "nemotron" },
+ { LLM_ARCH_NEMOTRON_H, "nemotron_h" },
{ LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_EXAONE4, "exaone4" },
{ LLM_ARCH_RWKV6, "rwkv6" },
@@ -1550,6 +1551,31 @@ static const std::map> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
+ {
+ LLM_ARCH_NEMOTRON_H,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ // mamba(2) ssm layers
+ { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
+ { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
+ { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
+ { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
+ { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
+ { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
+ { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
+ // attention layers
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ // dense FFN
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ },
+ },
{
LLM_ARCH_EXAONE,
{
@@ -2355,6 +2381,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
case LLM_ARCH_PLAMO2:
case LLM_ARCH_GRANITE_HYBRID:
case LLM_ARCH_LFM2:
+ case LLM_ARCH_NEMOTRON_H:
return true;
default:
return false;
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 94b0bef71..86c119692 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -73,6 +73,7 @@ enum llm_arch {
LLM_ARCH_T5ENCODER,
LLM_ARCH_JAIS,
LLM_ARCH_NEMOTRON,
+ LLM_ARCH_NEMOTRON_H,
LLM_ARCH_EXAONE,
LLM_ARCH_EXAONE4,
LLM_ARCH_RWKV6,
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 410a0f5ee..1dda251ec 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -41,7 +41,6 @@ llama_context::llama_context(
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.embeddings = params.embeddings;
cparams.offload_kqv = params.offload_kqv;
- cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
cparams.warmup = false;
@@ -86,6 +85,8 @@ llama_context::llama_context(
cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
}
+ cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
+
// with causal attention, the batch size is limited by the context size
cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
@@ -119,7 +120,7 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
- LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
+ LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type));
LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
@@ -269,7 +270,7 @@ llama_context::llama_context(
}
}
- // reserve worst-case graph
+ // resolve automatic Flash Attention use and reserve worst-case graph
if (!hparams.vocab_only) {
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
@@ -300,6 +301,48 @@ llama_context::llama_context(
throw std::runtime_error("failed to allocate compute pp buffers");
}
+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+ ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+ const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
+ bool fa_device_mismatch = false;
+ for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+ ggml_tensor * n = ggml_graph_node(gf, i);
+ if (n->op != GGML_OP_FLASH_ATTN_EXT) {
+ continue;
+ }
+ ggml_backend_dev_t device_fa = ggml_backend_get_device(
+ ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+ // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
+ GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
+ const int il = std::stoi(n->name + prefix_len);
+ ggml_backend_dev_t device_kv = model.dev_layer(il);
+ if (device_fa != device_kv) {
+ LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
+ "is assigned to device %s (usually due to missing support)\n",
+ __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
+ // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
+ fa_device_mismatch = true;
+ break;
+ }
+ }
+ if (fa_device_mismatch) {
+ cparams.flash_attn = false;
+ LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
+ if (ggml_is_quantized(params.type_v)) {
+ throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
+ }
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+ if (!gf) {
+ throw std::runtime_error("failed to allocate compute pp buffers");
+ }
+ } else {
+ cparams.flash_attn = true;
+ LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
+ }
+ }
+
n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
n_nodes_pp = ggml_graph_n_nodes(gf);
}
@@ -2208,6 +2251,7 @@ llama_context_params llama_context_default_params() {
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
/*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
/*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
+ /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO,
/*.rope_freq_base =*/ 0.0f,
/*.rope_freq_scale =*/ 0.0f,
/*.yarn_ext_factor =*/ -1.0f,
@@ -2224,7 +2268,6 @@ llama_context_params llama_context_default_params() {
/*.abort_callback_data =*/ nullptr,
/*.embeddings =*/ false,
/*.offload_kqv =*/ true,
- /*.flash_attn =*/ false,
/*.no_perf =*/ true,
/*.op_offload =*/ true,
/*.swa_full =*/ true,
@@ -2252,12 +2295,30 @@ llama_context * llama_init_from_model(
return nullptr;
}
- if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
+ if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) {
LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
- params.flash_attn = false;
+ params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
}
- if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
+ const uint32_t blck_size = ggml_blck_size(params.type_k);
+ if (model->hparams.n_embd_head_k % blck_size != 0) {
+ LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
+ __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k);
+ return nullptr;
+ }
+ }
+
+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
+ const uint32_t blck_size = ggml_blck_size(params.type_v);
+ if (model->hparams.n_embd_head_v % blck_size != 0) {
+ LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n",
+ __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v);
+ return nullptr;
+ }
+ }
+
+ if (ggml_is_quantized(params.type_v) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED) {
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
return nullptr;
}
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 1f2fc3ab6..49ea5da7c 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1221,7 +1221,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
ggml_tensor * kq_mask,
ggml_tensor * sinks,
ggml_tensor * v_mla,
- float kq_scale) const {
+ float kq_scale,
+ int il) const {
const bool v_trans = v->nb[1] > v->nb[2];
// split the batch into streams if needed
@@ -1256,6 +1257,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
+ cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
ggml_flash_attn_ext_add_sinks(cur, sinks);
ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
@@ -1271,6 +1273,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
// The permutations are noops and only change how the tensor data is interpreted.
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
cur = ggml_mul_mat(ctx0, v_mla, cur);
+ cb(cur, "fattn_mla", il);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
#endif
@@ -1279,6 +1282,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
} else {
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+ cb(kq, "kq", il);
// note: this op tends to require high floating point range
// while for some models F16 is enough, for others it is not, so we default to F32 here
@@ -1292,32 +1296,42 @@ ggml_tensor * llm_graph_context::build_attn_mha(
// before the softmax below
kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
+ cb(kq, "kq_tanh", il);
kq = ggml_scale(ctx0, kq, 30);
+ cb(kq, "kq_scaled", il);
}
if (hparams.attn_soft_cap) {
kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
+ cb(kq, "kq_scaled_1", il);
kq = ggml_tanh (ctx0, kq);
+ cb(kq, "kq_tanh", il);
kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
+ cb(kq, "kq_scaled_2", il);
}
if (kq_b) {
kq = ggml_add(ctx0, kq, kq_b);
+ cb(kq, "kq_plus_kq_b", il);
}
kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
ggml_soft_max_add_sinks(kq, sinks);
+ cb(kq, "kq_soft_max", il);
if (!v_trans) {
// note: avoid this branch
v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
+ cb(v, "v_cont", il);
}
ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+ cb(kqv, "kqv", il);
// for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
if (v_mla) {
kqv = ggml_mul_mat(ctx0, v_mla, kqv);
+ cb(kqv, "kqv_mla", il);
}
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
@@ -1378,7 +1392,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k = k_cur;
ggml_tensor * v = v_cur;
- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
if (wo) {
@@ -1467,7 +1481,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
if (wo) {
@@ -1534,7 +1548,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
if (wo) {
@@ -1589,7 +1603,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k = k_cur;
ggml_tensor * v = v_cur;
- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
if (wo) {
diff --git a/src/llama-graph.h b/src/llama-graph.h
index e11d91d52..3c85333fd 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -687,7 +687,8 @@ struct llm_graph_context {
ggml_tensor * kq_mask,
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
- float kq_scale) const;
+ float kq_scale,
+ int il) const;
llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
diff --git a/src/llama-impl.h b/src/llama-impl.h
index 02b1d07f8..c5163e922 100644
--- a/src/llama-impl.h
+++ b/src/llama-impl.h
@@ -59,3 +59,5 @@ std::string llama_format_tensor_shape(const std::vector & ne);
std::string llama_format_tensor_shape(const struct ggml_tensor * t);
std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i);
+
+#define LLAMA_TENSOR_NAME_FATTN "__fattn__"
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
index 9f107d3f4..00b513475 100644
--- a/src/llama-model-loader.cpp
+++ b/src/llama-model-loader.cpp
@@ -793,6 +793,7 @@ const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::stri
}
struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list & ne, int flags) {
+ // LLAMA_LOG_DEBUG("%s: loading tensor %s\n", __func__, name.c_str());
const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED));
if (cur == NULL) {
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 17c06b8df..c3dcbb7d1 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1575,6 +1575,27 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
+ case LLM_ARCH_NEMOTRON_H:
+ {
+ ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
+ ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
+ ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
+ ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
+ ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
+
+ // A layer is recurrent IFF the n_head_kv value is set to 0 and
+ // the n_ff value is set to 0
+ for (uint32_t i = 0; i < hparams.n_layer; ++i) {
+ hparams.recurrent_layer_arr[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0);
+ }
+
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+ switch (hparams.n_layer) {
+ case 56: type = LLM_TYPE_9B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ } break;
case LLM_ARCH_EXAONE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -4784,6 +4805,75 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
}
} break;
+ case LLM_ARCH_NEMOTRON_H:
+ {
+ // mamba2 Mixer SSM params
+ // NOTE: int64_t for tensor dimensions
+ const int64_t d_conv = hparams.ssm_d_conv;
+ const int64_t d_inner = hparams.ssm_d_inner;
+ const int64_t d_state = hparams.ssm_d_state;
+ const int64_t n_ssm_head = hparams.ssm_dt_rank;
+ const int64_t n_group = hparams.ssm_n_group;
+ const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head;
+
+ // embeddings
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ // output
+ {
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+ // if output is NULL, init from the input tok embed, duplicated to allow offloading
+ if (output == NULL) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+ }
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ // all blocks use the attn norm
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+ if (hparams.is_recurrent(i)) {
+ // ssm layers
+ layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0);
+
+ layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0);
+ layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED);
+
+ layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0);
+
+ // no "weight" suffix for these
+ layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0);
+ layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0);
+
+ layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
+
+ // out_proj
+ layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
+ } else if (hparams.n_ff(i) == 0) {
+ // attention layers (with optional bias)
+ const int64_t n_head_i = hparams.n_head(i);
+ const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i);
+ const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i);
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0);
+ layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+ layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED);
+ layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED);
+ layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+ } else {
+ // mlp layers
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0);
+ layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+ layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED);
+ }
+ }
+ } break;
case LLM_ARCH_EXAONE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -5959,7 +6049,8 @@ void llama_model::print_info() const {
arch == LLM_ARCH_JAMBA ||
arch == LLM_ARCH_FALCON_H1 ||
arch == LLM_ARCH_PLAMO2 ||
- arch == LLM_ARCH_GRANITE_HYBRID) {
+ arch == LLM_ARCH_GRANITE_HYBRID ||
+ arch == LLM_ARCH_NEMOTRON_H) {
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
@@ -14229,6 +14320,138 @@ struct llm_build_nemotron : public llm_graph_context {
}
};
+struct llm_build_nemotron_h : public llm_graph_context_mamba {
+ llm_build_nemotron_h(
+ const llama_model & model,
+ const llm_graph_params & params) :
+ llm_graph_context_mamba(params) {
+
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+
+ inpL = build_inp_embd(model.tok_embd);
+
+ auto * inp = build_inp_mem_hybrid();
+
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
+
+ // norm
+ cur = build_norm(inpL,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "attn_norm", il);
+
+ if (hparams.is_recurrent(il)) {
+ // ssm layer //
+ cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il);
+ } else if (hparams.n_ff(il) == 0) {
+ // attention layer //
+ cur = build_attention_layer(cur, inp->get_attn(), model, n_embd_head, il);
+ } else {
+ cur = build_ffn_layer(cur, model, il);
+ }
+
+ if (il == n_layer - 1 && inp_out_ids) {
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ // add residual
+ cur = ggml_add(ctx0, cur, inpSA);
+ cb(cur, "block_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = build_norm(cur,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, -1);
+
+ cb(cur, "result_norm", -1);
+ res->t_embd = cur;
+
+ // lm_head
+ cur = build_lora_mm(model.output, cur);
+ cb(cur, "result_output", -1);
+ res->t_logits = cur;
+
+ ggml_build_forward_expand(gf, cur);
+ }
+
+ ggml_tensor * build_attention_layer(
+ ggml_tensor * cur,
+ llm_graph_input_attn_kv * inp_attn,
+ const llama_model & model,
+ const int64_t n_embd_head,
+ const int il) {
+
+ // compute Q and K and (optionally) RoPE them
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+ if (model.layers[il].bq) {
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+ cb(Qcur, "Qcur", il);
+ }
+
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+ if (model.layers[il].bk) {
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+ cb(Kcur, "Kcur", il);
+ }
+
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+ if (model.layers[il].bv) {
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+ cb(Vcur, "Vcur", il);
+ }
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
+
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+ cur = build_attn(inp_attn,
+ model.layers[il].wo, model.layers[il].bo,
+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+ cb(cur, "attn_out", il);
+ return cur;
+ }
+
+ ggml_tensor * build_ffn_layer(
+ ggml_tensor * cur,
+ const llama_model & model,
+ const int il) {
+
+ cur = build_ffn(cur,
+ model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
+ NULL, NULL, NULL,
+ model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+ NULL,
+ LLM_FFN_RELU_SQR, LLM_FFN_PAR, il);
+ cb(cur, "ffn_out", il);
+
+ cur = build_cvec(cur, il);
+ cb(cur, "l_out", il);
+
+ return cur;
+ }
+};
+
struct llm_build_exaone : public llm_graph_context {
llm_build_exaone(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -18377,6 +18600,23 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.n_seq_max,
nullptr);
} else if (llm_arch_is_hybrid(arch)) {
+
+ // The main difference between hybrid architectures is the
+ // layer filters, so pick the right one here
+ llama_memory_hybrid::layer_filter_cb filter_attn = nullptr;
+ llama_memory_hybrid::layer_filter_cb filter_recr = nullptr;
+ if (arch == LLM_ARCH_FALCON_H1) {
+ filter_attn = [&](int32_t) { return true; };
+ filter_recr = [&](int32_t) { return true; };
+ } else if (arch == LLM_ARCH_NEMOTRON_H) {
+ filter_attn = [&](int32_t il) {
+ return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0;
+ };
+ filter_recr = [&](int32_t il) {
+ return hparams.is_recurrent(il) && hparams.n_ff(il) == 0;
+ };
+ }
+
const auto padding = llama_kv_cache::get_padding(cparams);
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
@@ -18396,8 +18636,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* n_seq_max */ cparams.n_seq_max,
/* offload */ cparams.offload_kqv,
/* unified */ cparams.kv_unified,
- /* filter_attn */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr,
- /* filter_recr */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr);
+ /* filter_attn */ std::move(filter_attn),
+ /* filter_recr */ std::move(filter_recr));
} else {
const auto padding = llama_kv_cache::get_padding(cparams);
@@ -18725,6 +18965,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique(*this, params);
} break;
+ case LLM_ARCH_NEMOTRON_H:
+ {
+ llm = std::make_unique(*this, params);
+ } break;
case LLM_ARCH_EXAONE:
{
llm = std::make_unique(*this, params);
@@ -18850,7 +19094,7 @@ llama_model_params llama_model_default_params() {
llama_model_params result = {
/*.devices =*/ nullptr,
/*.tensor_buft_overrides =*/ nullptr,
- /*.n_gpu_layers =*/ 0,
+ /*.n_gpu_layers =*/ 999,
/*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER,
/*.main_gpu =*/ 0,
/*.tensor_split =*/ nullptr,
@@ -18864,11 +19108,6 @@ llama_model_params llama_model_default_params() {
/*.use_extra_bufts =*/ true,
};
-#ifdef GGML_USE_METAL
- // note: we usually have plenty of VRAM, so by default offload all layers to the GPU
- result.n_gpu_layers = 999;
-#endif
-
return result;
}
@@ -18960,6 +19199,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
case LLM_ARCH_WAVTOKENIZER_DEC:
+ case LLM_ARCH_NEMOTRON_H:
return LLAMA_ROPE_TYPE_NONE;
// use what we call a normal RoPE, operating on pairs of consecutive head values
diff --git a/src/llama.cpp b/src/llama.cpp
index b633804bf..a631c348a 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -52,6 +52,18 @@ static bool old_mixtral_warning_showed = false;
// interface implementation
//
+const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type) {
+ switch (flash_attn_type) {
+ case LLAMA_FLASH_ATTN_TYPE_AUTO:
+ return "auto";
+ case LLAMA_FLASH_ATTN_TYPE_DISABLED:
+ return "disabled";
+ case LLAMA_FLASH_ATTN_TYPE_ENABLED:
+ return "enabled";
+ }
+ GGML_ABORT("fatal error");
+}
+
struct llama_sampler_chain_params llama_sampler_chain_default_params() {
struct llama_sampler_chain_params result = {
/*.no_perf =*/ true,
diff --git a/tools/server/tests/unit/test_ctx_shift.py b/tools/server/tests/unit/test_ctx_shift.py
index 8f51bc301..92e49f2bb 100644
--- a/tools/server/tests/unit/test_ctx_shift.py
+++ b/tools/server/tests/unit/test_ctx_shift.py
@@ -15,25 +15,26 @@ Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deseru
def create_server():
global server
server = ServerPreset.tinyllama2()
- server.n_ctx = 256
+ server.n_ctx = 512
server.n_slots = 2
+ server.n_predict = 128
def test_ctx_shift_enabled():
# the prompt is 301 tokens
- # the slot context is 256/2 = 128 tokens
- # the prompt is truncated to keep the last 109 tokens
- # 64 tokens are generated thanks to shifting the context when it gets full
+ # the slot context is 512/2 = 256 tokens
+ # the prompt is truncated to keep the last (301 - 256/2) = 173 tokens
+ # 96 tokens are generated thanks to shifting the context when it gets full
global server
server.enable_ctx_shift = True
server.start()
res = server.make_request("POST", "/completion", data={
- "n_predict": 64,
+ "n_predict": 96,
"prompt": LONG_TEXT,
})
assert res.status_code == 200
- assert res.body["timings"]["prompt_n"] == 109
- assert res.body["timings"]["predicted_n"] == 64
+ assert res.body["timings"]["prompt_n"] == 173
+ assert res.body["timings"]["predicted_n"] == 96
assert res.body["truncated"] is True
diff --git a/tools/server/tests/unit/test_speculative.py b/tools/server/tests/unit/test_speculative.py
index 38ca4325b..65952de8b 100644
--- a/tools/server/tests/unit/test_speculative.py
+++ b/tools/server/tests/unit/test_speculative.py
@@ -14,6 +14,7 @@ def create_server():
server.model_draft = download_file(MODEL_DRAFT_FILE_URL)
server.draft_min = 4
server.draft_max = 8
+ server.fa = "off"
@pytest.fixture(autouse=True)
diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py
index f55a53947..82f7215d5 100644
--- a/tools/server/tests/utils.py
+++ b/tools/server/tests/utils.py
@@ -66,7 +66,7 @@ class ServerProcess:
n_slots: int | None = None
ctk: str | None = None
ctv: str | None = None
- fa: bool | None = None
+ fa: str | None = None
server_continuous_batching: bool | None = False
server_embeddings: bool | None = False
server_reranking: bool | None = False
@@ -161,7 +161,7 @@ class ServerProcess:
if self.ctv:
server_args.extend(["-ctv", self.ctv])
if self.fa is not None:
- server_args.append("-fa")
+ server_args.extend(["-fa", self.fa])
if self.n_predict:
server_args.extend(["--n-predict", self.n_predict])
if self.slot_save_path:
@@ -427,7 +427,7 @@ class ServerPreset:
server.n_batch = 300
server.n_ubatch = 300
server.n_slots = 2
- server.fa = True
+ server.fa = "on"
server.seed = 42
server.server_embeddings = True
return server