mirror of https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 01:24:36 +00:00

commit 6eea7b88d2
Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	README.md
#	ggml/src/ggml-cann/ggml-cann.cpp
#	ggml/src/ggml-opencl/ggml-opencl.cpp
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	tests/test-backend-ops.cpp
#	tests/test-chat-template.cpp

80 changed files with 2737 additions and 185 deletions
@@ -2949,11 +2949,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             "controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
             "- none: leaves thoughts unparsed in `message.content`\n"
             "- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n"
-            "(default: deepseek)",
+            "(default: auto)",
             [](common_params & params, const std::string & value) {
                 /**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
                 else if (value == "deepseek-legacy") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; }
                 else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
+                else if (value == "auto") { params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; }
                 else { throw std::invalid_argument("invalid value"); }
             }
         ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK"));
@@ -126,6 +126,8 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
 typedef minja::chat_template common_chat_template;

 struct common_chat_templates {
+    bool add_bos;
+    bool add_eos;
     bool has_explicit_template; // Model had builtin template or template overridde was specified.
     std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
     std::unique_ptr<common_chat_template> template_tool_use;

@@ -143,6 +145,8 @@ struct templates_params {
     bool enable_thinking = true;
     std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
     json extra_context;
+    bool add_bos;
+    bool add_eos;
 };

 common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {

@@ -445,6 +449,8 @@ std::string common_chat_format_single(

     common_chat_templates_inputs inputs;
     inputs.use_jinja = use_jinja;
+    inputs.add_bos = tmpls->add_bos;
+    inputs.add_eos = tmpls->add_eos;

     std::string fmt_past_msg;
     if (!past_msg.empty()) {

@@ -469,6 +475,8 @@ std::string common_chat_format_single(
 std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
     common_chat_templates_inputs inputs;
     inputs.use_jinja = use_jinja;
+    inputs.add_bos = tmpls->add_bos;
+    inputs.add_eos = tmpls->add_eos;
     auto add_simple_msg = [&](auto role, auto content) {
         common_chat_msg msg;
         msg.role = role;

@@ -546,6 +554,8 @@ common_chat_templates_ptr common_chat_templates_init(
     }
     std::string token_bos = bos_token_override;
     std::string token_eos = eos_token_override;
+    bool add_bos = false;
+    bool add_eos = false;
     if (model) {
         const auto * vocab = llama_model_get_vocab(model);
         const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {

@@ -560,9 +570,13 @@ common_chat_templates_ptr common_chat_templates_init(
         };
         token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
         token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
+        add_bos = llama_vocab_get_add_bos(vocab);
+        add_eos = llama_vocab_get_add_eos(vocab);
     }
     common_chat_templates_ptr tmpls(new common_chat_templates());
     tmpls->has_explicit_template = has_explicit_template;
+    tmpls->add_bos = add_bos;
+    tmpls->add_eos = add_eos;
     try {
         tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
     } catch (const std::exception & e) {

@@ -592,6 +606,7 @@ const char * common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
         case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
         case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
+        case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
         default:
             throw std::runtime_error("Unknown chat format");
     }

@@ -600,6 +615,7 @@ const char * common_chat_format_name(common_chat_format format) {
 const char * common_reasoning_format_name(common_reasoning_format format) {
     switch (format) {
         case COMMON_REASONING_FORMAT_NONE: return "none";
+        case COMMON_REASONING_FORMAT_AUTO: return "auto";
         case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
         case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
         default:

@@ -748,10 +764,10 @@ static std::string apply(
     // instead of using `chat_template_options.use_bos_token = false`, since these tokens
     // may be needed inside the template / between messages too.
     auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
-    if (string_starts_with(result, tmpl.bos_token())) {
+    if (inputs.add_bos && string_starts_with(result, tmpl.bos_token())) {
         result = result.substr(tmpl.bos_token().size());
     }
-    if (string_ends_with(result, tmpl.eos_token())) {
+    if (inputs.add_eos && string_ends_with(result, tmpl.eos_token())) {
         result = result.substr(0, result.size() - tmpl.eos_token().size());
     }
     return result;
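Note: the point of the new add_bos / add_eos gating is that the rendered prompt is later tokenized with the vocab's own add-BOS/add-EOS behaviour; the token should be stripped from the template output only when the tokenizer is going to add it again, otherwise (old unconditional behaviour) a vocab that does not re-add BOS silently loses it. A standalone C++ sketch of the same guard, using plain std::string rather than the llama.cpp helpers:

#include <string>

// Sketch only (not the llama.cpp API): drop a leading BOS / trailing EOS from a
// rendered prompt, but only when the tokenizer will re-add that token itself.
static std::string strip_special(std::string s, const std::string & bos, const std::string & eos,
                                 bool add_bos, bool add_eos) {
    if (add_bos && s.rfind(bos, 0) == 0) {
        s.erase(0, bos.size());
    }
    if (add_eos && s.size() >= eos.size() &&
        s.compare(s.size() - eos.size(), eos.size(), eos) == 0) {
        s.erase(s.size() - eos.size());
    }
    return s;
}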
@@ -1289,6 +1305,26 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
         tool_calls_end);
 }

+static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    common_chat_params data;
+    auto prompt = apply(tmpl, inputs);
+
+    data.prompt = prompt;
+    data.format = COMMON_CHAT_FORMAT_GPT_OSS;
+
+    // TODO: support tool calls in GPT-OSS?
+
+    return data;
+}
+static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
+    // TODO @ngxson : this won't work with --special enabled, we should fix that
+    builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>");
+    if (!builder.syntax().parse_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
+}
+
 static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
     LOG_DBG("%s\n", __func__);
     common_chat_params data;
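Note: the two tags handed to try_parse_reasoning are the GPT-OSS (harmony-style) channel markers. For output of the rough shape

    <|channel|>analysis<|message|> ...model reasoning... <|start|>assistant<|channel|>final<|message|> ...final answer...

the text between the markers is extracted into message.reasoning_content and the remainder becomes message.content (with reasoning extraction disabled, the whole string stays in message.content). The tag spellings are taken from the call above; the surrounding example is illustrative only.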
@@ -1731,6 +1767,8 @@ static common_chat_params common_chat_templates_apply_jinja(
     params.enable_thinking = inputs.enable_thinking;
     params.grammar = inputs.grammar;
     params.now = inputs.now;
+    params.add_bos = inputs.add_bos;
+    params.add_eos = inputs.add_eos;

     params.extra_context = json::object();
     for (auto el : inputs.chat_template_kwargs) {

@@ -1772,6 +1810,11 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_hermes_2_pro(tmpl, params);
     }

+    // GPT-OSS
+    if (src.find("<|channel|>") != std::string::npos && params.json_schema.is_null()) {
+        return common_chat_params_init_gpt_oss(tmpl, params);
+    }
+
     // Use generic handler when mixing tools + JSON schema.
     // TODO: support that mix in handlers below.
     if ((params.tools.is_array() && params.json_schema.is_object())) {

@@ -1923,6 +1966,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
         case COMMON_CHAT_FORMAT_COMMAND_R7B:
             common_chat_parse_command_r7b(builder);
             break;
+        case COMMON_CHAT_FORMAT_GPT_OSS:
+            common_chat_parse_gpt_oss(builder);
+            break;
         default:
             throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
     }
@@ -109,6 +109,7 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
     COMMON_CHAT_FORMAT_HERMES_2_PRO,
     COMMON_CHAT_FORMAT_COMMAND_R7B,
+    COMMON_CHAT_FORMAT_GPT_OSS,

     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
 };

@@ -127,6 +128,8 @@ struct common_chat_templates_inputs {
     bool enable_thinking = true;
     std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
     std::map<std::string, std::string> chat_template_kwargs;
+    bool add_bos = false;
+    bool add_eos = false;
 };

 struct common_chat_params {
@@ -232,6 +232,7 @@ struct common_params_diffusion {

 enum common_reasoning_format {
     COMMON_REASONING_FORMAT_NONE,
+    COMMON_REASONING_FORMAT_AUTO,
     COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
     COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
 };

@@ -390,7 +391,7 @@ struct common_params {
     std::string chat_template = ""; // NOLINT
     bool use_jinja = false; // NOLINT
     bool enable_chat_template = true;
-    common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
+    common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
     int reasoning_budget = -1;
     bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response

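Note: with AUTO the behaviour is no longer hard-wired to DeepSeek-style extraction; how it resolves to a concrete behaviour is decided in the chat-handling code and is not visible in this excerpt. Purely as an illustration of the idea (the function and rule below are assumptions, not the actual implementation):

// Illustrative only: one plausible way an "auto" reasoning format could be resolved.
static common_reasoning_format resolve_reasoning_format(common_reasoning_format requested,
                                                        bool format_supports_reasoning) {
    if (requested != COMMON_REASONING_FORMAT_AUTO) {
        return requested; // explicit user choice wins
    }
    return format_supports_reasoning ? COMMON_REASONING_FORMAT_DEEPSEEK
                                     : COMMON_REASONING_FORMAT_NONE;
}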
@@ -7950,6 +7950,119 @@ class SmolLM3Model(LlamaModel):
         self.gguf_writer.add_chat_template(chat_template)


+@ModelBase.register("GptOssForCausalLM")
+class GptOssModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.GPT_OSS
+
+    def transform_nibble_layout(self, tensor):
+        assert tensor.dtype == torch.uint8
+        assert tensor.shape[-1] == 16
+        # swap nibbles
+        t_lo = tensor & 0x0F
+        t_hi = tensor & 0xF0
+        t_swapped = (t_lo << 4) | (t_hi >> 4)
+        tensor = t_swapped
+        # transform aaaa...bbbb... to abababab...
+        blk_a, blk_b = tensor.chunk(2, dim=-1)
+        # get a_
+        blk_a0 = (blk_a & 0xF0).view(-1, 1)
+        blk_a1 = (blk_a << 4).view(-1, 1)
+        blk_a = torch.stack((blk_a0, blk_a1), dim=2).view(tensor.shape)
+        # get _b
+        blk_b0 = (blk_b >> 4).view(-1, 1)
+        blk_b1 = (blk_b & 0x0F).view(-1, 1)
+        blk_b = torch.stack((blk_b0, blk_b1), dim=2).view(tensor.shape)
+        # swap once more
+        out = blk_a | blk_b
+        out_h = out & 0xF0
+        out_l = out & 0x0F
+        out = (out_h >> 4) | (out_l << 4)
+        return out
+
+    def repack_mxfp4(self, new_name: str, blocks: Tensor, scales: Tensor):
+        assert blocks.dtype == torch.uint8
+        assert scales.dtype == torch.uint8
+        scales = scales.unsqueeze(-1)
+        assert len(blocks.shape) == 4
+        assert len(scales.shape) == 4
+        blocks = self.transform_nibble_layout(blocks)
+        new_data = torch.concat((scales, blocks), dim=-1)
+        new_shape = [new_data.shape[0], new_data.shape[1], new_data.shape[2] * 32]
+        logger.info(f"Repacked {new_name} with shape {new_shape} and quantization MXFP4")
+        # flatten last dim
+        new_data = new_data.view(new_data.shape[0], new_data.shape[1], new_data.shape[2] * new_data.shape[3])
+        new_data = new_data.numpy()
+        self.gguf_writer.add_tensor(new_name, new_data, raw_dtype=gguf.GGMLQuantizationType.MXFP4)
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        blocks0: Tensor = torch.zeros(1)
+        blocks1: Tensor = torch.zeros(1)
+        found_mxfp4_tensors = False
+        # we assume that tensors are loaded in the correct order
+        for name, data_torch in self.get_tensors():
+            if "mlp.experts.down_proj_blocks" in name:
+                blocks0 = data_torch
+            elif "mlp.experts.down_proj_scales" in name:
+                new_name = self.map_tensor_name(name.replace("_scales", ".weight"))
+                self.repack_mxfp4(new_name, blocks0, data_torch)
+                found_mxfp4_tensors = True
+            elif "mlp.experts.gate_up_proj_blocks" in name:
+                blocks0, blocks1 = data_torch[:, ::2, :, :], data_torch[:, 1::2, :, :]
+            elif "mlp.experts.gate_up_proj_scales" in name:
+                scales0, scales1 = data_torch[:, ::2, :], data_torch[:, 1::2, :]
+                new_name_gate = self.map_tensor_name(name.replace("gate_up_proj_scales", "gate_proj.weight"))
+                new_name_up = self.map_tensor_name(name.replace("gate_up_proj_scales", "up_proj.weight"))
+                self.repack_mxfp4(new_name_gate, blocks0, scales0)
+                self.repack_mxfp4(new_name_up, blocks1, scales1)
+                found_mxfp4_tensors = True
+        if not found_mxfp4_tensors:
+            raise ValueError("No MXFP4 tensors found in the model. Please make sure you are using MXFP4 model.")
+        return []
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if "sinks" in name:
+            name += ".weight"
+
+        # correct naming for down_proj
+        if "down_proj" in name:
+            if name.endswith("_bias"):
+                name = name.replace("down_proj_bias", "down_proj.bias")
+            else:
+                return []
+
+        # split the gate_up into gate and up
+        if "gate_up_proj" in name:
+            if name.endswith("_bias"):
+                name_up = name.replace("gate_up_proj_bias", "up_proj.bias")
+                name_gate = name.replace("gate_up_proj_bias", "gate_proj.bias")
+                gate_proj_bias, up_proj_bias = data_torch[..., ::2], data_torch[..., 1::2]
+                return [
+                    (self.map_tensor_name(name_gate), gate_proj_bias),
+                    (self.map_tensor_name(name_up), up_proj_bias)
+                ]
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
+        self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size"])
+
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
+        assert rope_type == "yarn", f"GPT-OSS only supports yarn rope scaling, got {rope_type}"
+        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+        self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+        self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling.get("original_max_position_embeddings", 4096))
+
+
 @ModelBase.register("Lfm2ForCausalLM")
 @ModelBase.register("LFM2ForCausalLM")
 class LFM2Model(TextModel):
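Note: per 32-element block, repack_mxfp4 writes one E8M0 scale byte followed by 16 bytes of packed nibbles, i.e. 17 bytes per block; this is the same block_mxfp4 layout that the ggml headers further down declare, with byte j of qs carrying element j in its low nibble and element j+16 in its high nibble. A minimal C++ sketch of that layout, for reference only:

#include <cstdint>

constexpr int QK_MXFP4 = 32;

// layout produced by repack_mxfp4() and declared as block_mxfp4 on the ggml side
struct block_mxfp4_sketch {
    uint8_t e;                 // shared E8M0 exponent for the whole block
    uint8_t qs[QK_MXFP4 / 2];  // byte j: element j (low nibble), element j + 16 (high nibble)
};

static_assert(sizeof(block_mxfp4_sketch) == 1 + QK_MXFP4 / 2, "17 bytes per 32 values");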
@@ -8089,6 +8202,7 @@ class LazyTorchTensor(gguf.LazyBase):
     _dtype_map: dict[torch.dtype, type] = {
         torch.float16: np.float16,
         torch.float32: np.float32,
+        torch.uint8: np.uint8,
     }

     # used for safetensors slices
@@ -310,6 +310,16 @@
     GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
     GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

+#define GGML_TENSOR_TERNARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \
+    GGML_TENSOR_LOCALS(size_t, nb2, src2, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
 #define GGML_TENSOR_BINARY_OP_LOCALS01 \
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
     GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \

@@ -401,7 +411,8 @@ extern "C" {
         GGML_TYPE_IQ4_NL_4_4 = 36, //deprecated upstream
         // GGML_TYPE_IQ4_NL_4_8 = 37,
         // GGML_TYPE_IQ4_NL_8_8 = 38,
-        GGML_TYPE_COUNT = 39,
+        GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
+        GGML_TYPE_COUNT = 40,
     };

     // precision

@@ -436,6 +447,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
         GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
+        GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors
     };

     // available tensor operations:

@@ -444,6 +456,7 @@ extern "C" {

         GGML_OP_DUP,
         GGML_OP_ADD,
+        GGML_OP_ADD_ID,
         GGML_OP_ADD1,
         GGML_OP_ACC,
         GGML_OP_SUB,

@@ -563,6 +576,7 @@ extern "C" {
         GGML_GLU_OP_REGLU,
         GGML_GLU_OP_GEGLU,
         GGML_GLU_OP_SWIGLU,
+        GGML_GLU_OP_SWIGLU_OAI,
         GGML_GLU_OP_GEGLU_ERF,
         GGML_GLU_OP_GEGLU_QUICK,

@@ -844,6 +858,13 @@ extern "C" {
             struct ggml_tensor  * b,
             enum   ggml_type      type);

+    // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
+    GGML_API struct ggml_tensor * ggml_add_id(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * ids);
+
     GGML_API struct ggml_tensor * ggml_add1(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
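Note: the comment pins down the semantics — for each row (i1, i2) of a, the row of b selected by ids[i1, i2] is broadcast-added along i0. A standalone reference loop over contiguous float data (sketch only; the real op works on strided ggml tensors, see ggml_compute_forward_add_id_f32 later in this diff):

#include <cstdint>

// Reference semantics of ggml_add_id on contiguous float data:
// dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
static void add_id_ref(float * dst, const float * a, const float * b, const int32_t * ids,
                       int64_t ne0, int64_t ne1, int64_t ne2) {
    for (int64_t i2 = 0; i2 < ne2; ++i2) {
        for (int64_t i1 = 0; i1 < ne1; ++i1) {
            const int32_t row = ids[i2*ne1 + i1];            // selects one row of b
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
                dst[(i2*ne1 + i1)*ne0 + i0] = a[(i2*ne1 + i1)*ne0 + i0] + b[row*ne0 + i0];
            }
        }
    }
}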
@@ -1211,6 +1232,13 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);

+    GGML_API struct ggml_tensor * ggml_swiglu_oai(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            float                 alpha,
+            float                 limit);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
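Note: the CPU kernel added later in this diff (ggml_compute_forward_swiglu_oai_f32) defines what this GLU variant computes per element: the gate input is clamped from above by limit, the linear input is clamped to [-limit, limit], the gate goes through a sigmoid scaled by alpha, and the linear branch is offset by one. Collected into a compact reference:

#include <cmath>

// Per-element reference for ggml_swiglu_oai (mirrors ggml_compute_forward_swiglu_oai_f32):
// out = min(x, limit) * sigmoid(alpha * min(x, limit)) * (clamp(y, -limit, limit) + 1)
static inline float swiglu_oai_ref(float x, float y, float alpha, float limit) {
    x = x < limit ? x : limit;                              // gate input clamped from above only
    y = y < -limit ? -limit : (y > limit ? limit : y);      // linear input clamped to [-limit, limit]
    const float gate = x / (1.0f + std::exp(-alpha * x));   // x * sigmoid(alpha * x)
    return gate * (y + 1.0f);
}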
@@ -1583,6 +1611,10 @@ extern "C" {
             float                 scale,
             float                 max_bias);

+    GGML_API void ggml_soft_max_add_sinks(
+            struct ggml_tensor * a,
+            struct ggml_tensor * sinks);
+
     GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,

@@ -2065,6 +2097,10 @@ extern "C" {
     GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
             const struct ggml_tensor * a);

+    GGML_API void ggml_flash_attn_ext_add_sinks(
+            struct ggml_tensor * a,
+            struct ggml_tensor * sinks);
+
     // TODO: needs to be adapted to ggml_flash_attn_ext
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
             struct ggml_context * ctx,
@@ -29,6 +29,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
         case GGML_OP_DIAG_MASK_ZERO:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_ADD:
+        case GGML_OP_ADD_ID:
        case GGML_OP_ADD1:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
@@ -99,6 +99,9 @@ typedef sycl::half2 ggml_half2;
 #define QI4_1 (QK4_1 / (4 * QR4_1))
 #define QR4_1 2

+#define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4))
+#define QR_MXFP4 2
+
 #define QI5_0 (QK5_0 / (4 * QR5_0))
 #define QR5_0 2

@@ -184,6 +187,13 @@ typedef struct {
 } block_q4_1;
 static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");

+#define QK_MXFP4 32
+typedef struct {
+    uint8_t e; // E8M0
+    uint8_t qs[QK_MXFP4/2];
+} block_mxfp4;
+static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding");
+
 #define QK5_0 32
 typedef struct {
     ggml_half d; // delta

@@ -1074,10 +1084,17 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
     0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
 GGML_TABLE_END()

+// TODO: fix name to kvalues_iq4_nl
 GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
     -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
 GGML_TABLE_END()

+// e2m1 values (doubled)
+// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
+    0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
+GGML_TABLE_END()
+
 #define NGRID_IQ1S 2048
 #define IQ1S_DELTA 0.125f
 #define IQ1M_DELTA 0.125f
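Note: each nibble of qs indexes the doubled-e2m1 table above, and the per-block scale is the E8M0 exponent byte; the scalar path of ggml_vec_dot_mxfp4_q8_0 further down multiplies by GGML_E8M0_TO_FP32_HALF(e), i.e. the power-of-two scale halved to undo the doubling of the table entries. A reference dequantizer for one block, assuming the usual E8M0 meaning of 2^(e-127) (the macro itself is not shown in this excerpt):

#include <cmath>
#include <cstdint>

static const int8_t kvalues_mxfp4_ref[16] = {0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12};

// sketch: expand one 32-element MXFP4 block (scale byte e, 16 nibble bytes qs) to floats
static void dequant_mxfp4_block(uint8_t e, const uint8_t qs[16], float out[32]) {
    const float d = std::ldexp(1.0f, (int) e - 127) * 0.5f; // halved: table entries are doubled e2m1 values
    for (int j = 0; j < 16; ++j) {
        out[j]      = d * kvalues_mxfp4_ref[qs[j] & 0x0F];  // low nibble  -> element j
        out[j + 16] = d * kvalues_mxfp4_ref[qs[j] >> 4];    // high nibble -> element j + 16
    }
}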
@@ -13,6 +13,7 @@
 #define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
 #define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
 #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K

@@ -68,6 +69,7 @@
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8

@@ -90,6 +92,7 @@
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8

@@ -120,6 +123,7 @@
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
 #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8

@@ -149,6 +153,7 @@
 #define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
 #define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8

@@ -179,6 +184,7 @@
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
 #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -589,6 +589,67 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = sumf;
 }

+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined __ARM_NEON
+    const int8x16_t values = vld1q_s8(kvalues_mxfp4);
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+    uint8x16x2_t q4bits;
+    int8x16x4_t q4b;
+    int8x16x4_t q8b;
+    int32x4_t prod_1;
+    int32x4_t prod_2;
+
+    for (; ib + 1 < nb; ib += 2) {
+        q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
+        q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
+        q8b.val[0]    = vld1q_s8(y[ib + 0].qs);
+        q8b.val[1]    = vld1q_s8(y[ib + 0].qs + 16);
+        q8b.val[2]    = vld1q_s8(y[ib + 1].qs);
+        q8b.val[3]    = vld1q_s8(y[ib + 1].qs + 16);
+
+        q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
+        q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
+        q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
+        q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+
+        prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
+        prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
+
+        sumf +=
+            GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
+            GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
+    }
+
+#endif
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+        int sumi1 = 0;
+        int sumi2 = 0;
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            sumi1 += y[ib].qs[j + 0]          * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -67,6 +67,12 @@ static inline int hsum_i32_4(const __m128i a) {
 }

 #if defined(__AVX2__) || defined(__AVX512F__)
+static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return _mm256_maddubs_epi16(ax, sy);
+}
+
 // spread 32 bits to 32 bytes { 0x00, 0xFF }
 static inline __m256i bytes_from_bits_32(const uint8_t * x) {
     uint32_t x32;

@@ -262,6 +268,11 @@ static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const
     return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)),
                            _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));
 }
+
+static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) {
+    return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
+                           _mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
+}
 #endif
 #elif defined(__SSSE3__)
 // horizontally add 4x4 floats
@@ -747,6 +758,91 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
 #endif
 }

+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined __AVX2__
+
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
+    const __m128i m4b = _mm_set1_epi8(0x0f);
+    const __m256i mone = _mm256_set1_epi16(1);
+
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
+        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
+        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
+        const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+        const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
+        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
+        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
+                                 _mm256_cvtepi32_ps(p_1), accum1);
+        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
+                                 _mm256_cvtepi32_ps(p_2), accum2);
+    }
+
+    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
+
+#elif defined __AVX__
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
+    const __m128i m4b = _mm_set1_epi8(0x0f);
+
+    __m256 accum = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
+        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
+        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
+
+        const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
+        const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
+        const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
+        const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
+
+        const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
+        const __m256 deltas = quad_mx_delta_float(x[ib].e, y[ib].d, x[ib + 1].e, y[ib + 1].d);
+        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
+    }
+
+    sumf = hsum_float_8(accum);
+
+#endif
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+        int sumi1 = 0;
+        int sumi2 = 0;
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            sumi1 += y[ib].qs[j + 0]          * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -3207,14 +3303,6 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 #endif
 }

-#if defined(__AVX2__)
-static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
-    const __m256i ax = _mm256_sign_epi8(x, x);
-    const __m256i sy = _mm256_sign_epi8(y, x);
-    return _mm256_maddubs_epi16(ax, sy);
-}
-#endif
-
 void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
@@ -257,6 +257,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type = GGML_TYPE_Q8_1,
         .nrows = 1,
     },
+    [GGML_TYPE_MXFP4] = {
+        .from_float = quantize_row_mxfp4,
+        .vec_dot = ggml_vec_dot_mxfp4_q8_0,
+        .vec_dot_type = GGML_TYPE_Q8_0,
+        .nrows = 1,
+    },
     [GGML_TYPE_Q2_K] = {
         .from_float = quantize_row_q2_K,
         .vec_dot = ggml_vec_dot_q2_K_q8_K,
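Note: the key field is vec_dot_type = GGML_TYPE_Q8_0 — in ggml's CPU matmul path the non-weight operand of each dot product is converted to that type on the fly, and vec_dot is then called on (MXFP4 row, Q8_0 row) pairs, which is the ggml_vec_dot_mxfp4_q8_0 kernel added above; from_float is the float-to-MXFP4 row quantizer used when tensors of this type are produced. A simplified sketch of the shape of such a traits entry (illustrative only, not the real struct definition):

#include <cstddef>
#include <cstdint>

// Illustrative shape of a CPU type-traits entry as used by the matmul path.
typedef void (*ggml_vec_dot_fn)(int n, float * s, size_t bs, const void * x, size_t bx,
                                const void * y, size_t by, int nrc);

struct cpu_type_traits_sketch {
    void (*from_float)(const float * x, void * y, int64_t k); // float -> this type (row quantizer)
    ggml_vec_dot_fn vec_dot;                                   // this-type row . vec_dot_type row
    int vec_dot_type;                                          // type the other operand is converted to
    int nrows;                                                 // rows handled per vec_dot call
};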
@@ -1684,6 +1690,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add(params, tensor);
             } break;
+        case GGML_OP_ADD_ID:
+            {
+                ggml_compute_forward_add_id(params, tensor);
+            } break;
         case GGML_OP_ADD1:
             {
                 ggml_compute_forward_add1(params, tensor);

@@ -1938,7 +1948,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
-                ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+                ggml_compute_forward_flash_attn_ext(params, tensor);
             } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {

@@ -2125,6 +2135,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_DUP:
         case GGML_OP_CONT:
         case GGML_OP_ADD:
+        case GGML_OP_ADD_ID:
         case GGML_OP_ADD1:
         case GGML_OP_ACC:
             {

@@ -2186,6 +2197,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_SWIGLU_OAI:
                 case GGML_GLU_OP_GEGLU_ERF:
                 case GGML_GLU_OP_GEGLU_QUICK:
                     {

@@ -2696,6 +2708,7 @@ struct ggml_cplan ggml_graph_plan(
                     }
                 } break;
             case GGML_OP_ADD:
+            case GGML_OP_ADD_ID:
             case GGML_OP_ADD1:
                 {
                     if (ggml_is_quantized(node->src[0]->type)) {
@@ -8,6 +8,7 @@
 #include "vec.h"

 #include <float.h>
+#include <algorithm>

 // ggml_compute_forward_dup

@@ -1283,6 +1284,7 @@ void ggml_compute_forward_add(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -1309,6 +1311,77 @@
     }
 }

+// ggml_compute_forward_add_id
+
+static void ggml_compute_forward_add_id_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        // src1 indices
+        const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
+
+        GGML_ASSERT(i11 >= 0 && i11 < ne11);
+
+        ggml_vec_add_f32(ne0,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+                (float *) ((char *) src1->data + i11*nb11));
+    }
+}
+
+void ggml_compute_forward_add_id(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_add_id_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
+            }
+    }
+}
+
 // ggml_compute_forward_add1

 static void ggml_compute_forward_add1_f32(
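Note: combined with the ggml_add_id declaration earlier, the natural call site for this op is adding a per-row bias selected by an id tensor — for example the per-expert down_proj.bias and gate/up biases that the GPT-OSS converter above emits, indexed by the chosen expert ids, roughly `ggml_add_id(ctx, cur, expert_bias, selected_experts)`. That reading is an inference from this diff; the graph-building code that actually calls the op is not part of this excerpt.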
@@ -1660,6 +1733,7 @@ void ggml_compute_forward_add1(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:

@@ -1787,6 +1861,7 @@ void ggml_compute_forward_acc(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -3614,6 +3689,93 @@ static void ggml_compute_forward_swiglu(
     }
 }

+// ggml_compute_forward_swiglu_oai
+
+static void ggml_compute_forward_swiglu_oai_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+    const float alpha = ggml_get_op_params_f32(dst, 2);
+    const float limit = ggml_get_op_params_f32(dst, 3);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+        float * dst_p  = (float *) ((char *) dst->data + i1*(dst->nb[1]));
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        for (int k = 0; k < nc; k++) {
+            const float x = std::min(src0_p[k], limit);
+            const float y = std::clamp(src1_p[k], -limit, limit);
+            const float out_glu = x / (1.f + expf(alpha * (-x)));
+            dst_p[k] = out_glu * (y + 1.f);
+        }
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = dst_p[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_swiglu_oai(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_swiglu_oai_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_geglu_erf

 static void ggml_compute_forward_geglu_erf_f32(
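Note: the SWIGLU_OAI kernel added above clamps the gate input to at most `limit`, clamps the linear input to [-limit, limit], applies a sigmoid scaled by `alpha` to the gate, and offsets the linear branch by +1. A minimal scalar sketch of the same per-element math (the function name and test values are illustrative, not part of ggml):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // out = min(x, limit) * sigmoid(alpha * min(x, limit)) * (clamp(y, -limit, limit) + 1)
    static float swiglu_oai_ref(float x, float y, float alpha, float limit) {
        x = std::min(x, limit);                                // gate input: clamped from above only
        y = std::clamp(y, -limit, limit);                      // linear input: clamped on both sides
        const float glu = x / (1.0f + std::exp(alpha * (-x))); // x * sigmoid(alpha * x)
        return glu * (y + 1.0f);
    }

    int main() {
        std::printf("%f\n", swiglu_oai_ref(2.5f, -0.3f, 1.702f, 7.0f));
    }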
@@ -4599,6 +4761,7 @@ void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:

@@ -4873,6 +5036,7 @@ void ggml_compute_forward_set(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:

@@ -5134,6 +5298,7 @@ void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:

@@ -5523,6 +5688,7 @@ static void ggml_compute_forward_soft_max_f32(

     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];

     assert(ggml_is_contiguous(dst));
     assert(ggml_are_same_shape(src0, dst));

@@ -5557,6 +5723,9 @@ static void ggml_compute_forward_soft_max_f32(

     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);

+    // sinks
+    const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
+
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             for (int64_t i01 = ith; i01 < ne01; i01 += nth) {

@@ -5599,9 +5768,18 @@ static void ggml_compute_forward_soft_max_f32(
                 float max = -INFINITY;
                 ggml_vec_max_f32(ne00, &max, wp);
+
+                // if we have sinks, make a correction as if they were included in the softmax
+                if (sk) {
+                    max = MAX(max, sk[i02]);
+                }
+
                 ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
                 assert(sum > 0.0);
+
+                if (sk) {
+                    sum += (ggml_float) expf(sk[i02] - max);
+                }
+
                 sum = 1.0/sum;
                 ggml_vec_scale_f32(ne00, dp, sum);
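Note: the sink handling added to soft_max above behaves as if one extra logit `sk[i02]` were part of the row: it is folded into the row maximum and its exponential is added to the normalizer, but no output element is produced for it, so the visible probabilities sum to less than one. A small reference sketch of that normalization (illustrative, not ggml API):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Softmax over x with one extra "sink" logit that absorbs probability mass
    // but has no output slot of its own.
    static std::vector<float> softmax_with_sink(const std::vector<float> & x, float sink) {
        float max = sink;
        for (float v : x) max = std::max(max, v);

        double sum = std::exp(double(sink) - max); // the sink only contributes to the denominator
        std::vector<float> y(x.size());
        for (size_t i = 0; i < x.size(); ++i) {
            y[i] = std::exp(x[i] - max);
            sum += y[i];
        }
        for (float & v : y) v = float(v / sum);    // sum(y) < 1 whenever the sink is finite
        return y;
    }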
@@ -5836,6 +6014,7 @@ void ggml_compute_forward_clamp(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:

@@ -7989,12 +8168,14 @@ void ggml_compute_forward_argsort(

 static void ggml_compute_forward_flash_attn_ext_f16(
         const ggml_compute_params * params,
-        const ggml_tensor * q,
-        const ggml_tensor * k,
-        const ggml_tensor * v,
-        const ggml_tensor * mask,
         ggml_tensor * dst) {
+
+    const ggml_tensor * q     = dst->src[0];
+    const ggml_tensor * k     = dst->src[1];
+    const ggml_tensor * v     = dst->src[2];
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
+
     GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
     GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
     GGML_TENSOR_LOCALS(int64_t, nek, k, ne)

@@ -8189,6 +8370,23 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             }
         }

+        // sinks
+        if (sinks) {
+            const float s = ((float *)((char *) sinks->data))[h];
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                ms = expf(M - s);
+                ggml_vec_scale_f32(DV, VKQ32, ms);
+            } else {
+                vs = expf(s - M);
+            }
+
+            S = S*ms + vs;
+        }
+
         // V /= S
         const float S_inv = 1.0f/S;
         ggml_vec_scale_f32(DV, VKQ32, S_inv);
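Note: the flash-attention change above folds the sink into the streaming softmax state after the KV loop: `M` is the running maximum, `S` the running denominator, and `VKQ32` the unnormalized value accumulator. A hedged sketch of that update (`fold_sink` and its signature are illustrative):

    #include <cmath>

    // Fold one sink logit s into the online-softmax state (M, S, VKQ[0..DV)).
    static void fold_sink(float s, float M, float & S, float * VKQ, int DV) {
        float ms = 1.0f; // rescale factor for what has been accumulated so far
        float vs = 1.0f; // weight contributed by the sink itself
        if (s > M) {
            ms = std::exp(M - s);        // re-express the accumulator relative to the new max s
            for (int i = 0; i < DV; ++i) {
                VKQ[i] *= ms;
            }
        } else {
            vs = std::exp(s - M);
        }
        S = S*ms + vs;                   // the sink adds mass to the denominator but no value vector
    }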
@@ -8208,17 +8406,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(

 void ggml_compute_forward_flash_attn_ext(
         const ggml_compute_params * params,
-        const ggml_tensor * q,
-        const ggml_tensor * k,
-        const ggml_tensor * v,
-        const ggml_tensor * mask,
         ggml_tensor * dst) {
     switch (dst->op_params[3]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {
                 // uses F32 accumulators
-                ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+                ggml_compute_forward_flash_attn_ext_f16(params, dst);
             } break;
         default:
             {

@@ -9080,6 +9274,10 @@ void ggml_compute_forward_glu(
             {
                 ggml_compute_forward_swiglu(params, dst);
             } break;
+        case GGML_GLU_OP_SWIGLU_OAI:
+            {
+                ggml_compute_forward_swiglu_oai(params, dst);
+            } break;
         case GGML_GLU_OP_GEGLU_ERF:
             {
                 ggml_compute_forward_geglu_erf(params, dst);

@@ -29,6 +29,7 @@ extern "C" {

 void ggml_compute_forward_dup(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_add(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_add_id(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);

@@ -82,13 +83,7 @@ void ggml_compute_forward_arange(const struct ggml_compute_params * params, stru
 void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_flash_attn_ext(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * q,
-        const struct ggml_tensor * k,
-        const struct ggml_tensor * v,
-        const struct ggml_tensor * mask,
-        struct ggml_tensor * dst);
+void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_flash_attn_back(
         const struct ggml_compute_params * params,
         const bool masked,

@@ -46,6 +46,10 @@ void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRI
     quantize_row_q8_1_ref(x, y, k);
 }

+void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
+    quantize_row_mxfp4_ref(x, y, k);
+}
+
 //
 // 2-6 bit quantization in super-blocks
 //

@@ -181,6 +185,37 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
     *s = sumf;
 }

+void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+
+        int sumi1 = 0;
+        int sumi2 = 0;
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            sumi1 += y[ib].qs[j + 0]          * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
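Note: ggml_vec_dot_mxfp4_q8_0_generic above pairs one MXFP4 block (32 four-bit codes sharing an E8M0 exponent) with one Q8_0 block; low nibbles map to the first half of the block and high nibbles to the second half. A standalone sketch of the same pairing, with a plain float scale in place of the fp16/E8M0 helpers (the table follows the E2M1 code order used by kvalues_mxfp4, stored doubled; names and exact values are illustrative):

    #include <cstdint>

    static const int8_t kvalues_fp4[16] = {
        0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
    };

    // One 32-element block: 16 packed nibble pairs on the MXFP4 side, 32 int8 on the Q8 side.
    // d_x and d_y stand in for the decoded E8M0 and fp16 block scales.
    static float vec_dot_mxfp4_q8_block(const uint8_t qs[16], const int8_t q8[32], float d_x, float d_y) {
        int sumi = 0;
        for (int j = 0; j < 16; ++j) {
            sumi += q8[j +  0] * kvalues_fp4[qs[j] & 0x0f]; // low nibble -> first half of the block
            sumi += q8[j + 16] * kvalues_fp4[qs[j] >>   4]; // high nibble -> second half
        }
        return d_x * d_y * sumi;
    }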
@@ -19,6 +19,8 @@ void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
 void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);

+void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
 void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);

@@ -39,6 +41,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
 void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

@@ -67,8 +71,12 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
 void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

+void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

 void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

@@ -55,7 +55,22 @@ inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x)

 inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
+
+inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
+    int i = 0;
+#if defined(__AVX2__)
+    for (; i + 7 < n; i += 8) {
+        __m256 vx = _mm256_loadu_ps(x + i);
+        __m256 vy = _mm256_loadu_ps(y + i);
+        __m256 vz = _mm256_add_ps(vx, vy);
+        _mm256_storeu_ps(z + i, vz);
+    }
+#endif
+    for (; i < n; ++i) {
+        z[i] = x[i] + y[i];
+    }
+}
+
 inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
     for (int i = 0; i < n; ++i) {
         z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) + GGML_CPU_FP16_TO_FP32(y[i]));
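Note: the rewritten ggml_vec_add_f32 above uses the common split of an unaligned AVX2 main loop (8 floats per iteration) plus a scalar tail for the remaining n % 8 elements, and falls back to the plain scalar loop when __AVX2__ is not defined. A trivial usage sketch (assumes ggml's CPU vec.h is available in the translation unit):

    #include <vector>

    void add_example() {
        std::vector<float> x(100, 1.0f), y(100, 2.0f), z(100);
        ggml_vec_add_f32((int) z.size(), z.data(), x.data(), y.data()); // every z[i] becomes 3.0f
    }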
@@ -992,9 +1007,9 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *

 inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
     for (int i = 0; i < n; ++i) {
-        float v = GGML_CPU_FP16_TO_FP32(x[i]);
-        float w = GGML_CPU_FP16_TO_FP32(g[i]);
-        y[i] = GGML_CPU_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
+        float xi = GGML_CPU_FP16_TO_FP32(x[i]);
+        float gi = GGML_CPU_FP16_TO_FP32(g[i]);
+        y[i] = GGML_CPU_FP32_TO_FP16((xi/(1.0f + expf(-xi))) * gi);
     }
 }
ggml/src/ggml-cuda/add-id.cu (new file, 58 lines)
@@ -0,0 +1,58 @@
+#include "add-id.cuh"
+
+static __global__ void add_id_kernel(
+        const float * src0, const float * src1, const int32_t * src2, float * dst,
+        int64_t ne0, int64_t ne1,
+        size_t nb01, size_t nb02,
+        size_t nb11,
+        size_t nb21
+    ) {
+
+    const int64_t i1 = blockIdx.x;
+    const int64_t i2 = blockIdx.y;
+
+    const int i11 = *(int32_t *) ((char *) src2 + i1*sizeof(int32_t) + i2*nb21);
+
+    const size_t nb1 = ne0 * sizeof(float);
+    const size_t nb2 = ne1 * nb1;
+
+    float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2);
+    const float * src0_row = (const float *)((char *)src0 + i1*nb01 + i2*nb02);
+    const float * src1_row = (const float *)((char *)src1 + i11*nb11);
+
+    for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
+        dst_row[i0] = src0_row[i0] + src1_row[i0];
+    }
+}
+
+void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb20 == sizeof(int32_t));
+
+    const float * src0_d = (const float *)src0->data;
+    const float * src1_d = (const float *)src1->data;
+    const int32_t * src2_d = (const int32_t *)src2->data;
+    float * dst_d = (float *)dst->data;
+
+    int threads = std::min((int)ne00, 768); // cols
+    dim3 blocks(ne01, ne02); // n_experts_used, n_tokens
+    add_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
+        src0_d, src1_d, src2_d, dst_d,
+        ne0, ne1,
+        nb01, nb02,
+        nb11,
+        nb21
+    );
+}
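Note: the new add-id.cu implements GGML_OP_ADD_ID, which the graph code uses for per-expert biases: for token i2 and expert slot i1, src2[i2][i1] selects which row of src1 is added to the matching row of src0. A CPU-side sketch of that indexing (names and the flattened layout are illustrative):

    #include <cstdint>
    #include <vector>

    // dst[i2][i1][:] = src0[i2][i1][:] + src1[src2[i2][i1]][:]
    static void add_id_ref(std::vector<float> & dst,
                           const std::vector<float> & src0,   // [n_tokens][ne1][ne0]
                           const std::vector<float> & src1,   // [n_expert][ne0] bias rows
                           const std::vector<int32_t> & src2, // [n_tokens][ne1] selected expert ids
                           int64_t ne0, int64_t ne1, int64_t n_tokens) {
        for (int64_t i2 = 0; i2 < n_tokens; ++i2) {
            for (int64_t i1 = 0; i1 < ne1; ++i1) {
                const int32_t expert = src2[i2*ne1 + i1];
                for (int64_t i0 = 0; i0 < ne0; ++i0) {
                    dst [(i2*ne1 + i1)*ne0 + i0] =
                    src0[(i2*ne1 + i1)*ne0 + i0] + src1[expert*ne0 + i0];
                }
            }
        }
    }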
ggml/src/ggml-cuda/add-id.cuh (new file, 3 lines)
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

@@ -1,6 +1,7 @@
 #pragma once

 #include "ggml.h"
+#include "ggml-impl.h"
 #include "ggml-cuda.h"

 #include <cstdint>

@@ -553,6 +554,24 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
 #endif // defined(GGML_USE_HIP)
 }

+static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
+#if CUDART_VERSION >= 12080
+    const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
+    return (float) e;
+#else
+    uint32_t bits;
+    if (x == 0) {
+        bits = 0x00400000;
+    } else {
+        bits = (uint32_t) x << 23;
+    }
+
+    float result;
+    memcpy(&result, &bits, sizeof(float));
+    return result;
+#endif // CUDART_VERSION >= 12050
+}
+
 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);

 static __device__ __forceinline__ float get_alibi_slope(
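Note: ggml_cuda_e8m0_to_fp32 above decodes an E8M0 block scale, i.e. an 8-bit biased exponent with no mantissa: code x represents 2^(x - 127), so on older CUDA the fallback path simply shifts x into the fp32 exponent field, mapping x == 0 to the subnormal pattern 0x00400000 (= 2^-127). A host-side sketch of that fallback (the function name is illustrative):

    #include <cassert>
    #include <cmath>
    #include <cstdint>
    #include <cstring>

    static float e8m0_to_fp32_ref(uint8_t x) {
        const uint32_t bits = (x == 0) ? 0x00400000u : (uint32_t(x) << 23);
        float result;
        std::memcpy(&result, &bits, sizeof(result)); // reinterpret the bit pattern as fp32
        return result;
    }

    int main() {
        assert(e8m0_to_fp32_ref(127) == 1.0f);                   // 2^0
        assert(e8m0_to_fp32_ref(128) == 2.0f);                   // 2^1
        assert(e8m0_to_fp32_ref(0)   == std::ldexp(1.0f, -127)); // subnormal special case
    }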
@@ -611,6 +630,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
     static constexpr int qi = QI8_0;
 };

+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
+    static constexpr int qk = QK_MXFP4;
+    static constexpr int qr = QR_MXFP4;
+    static constexpr int qi = QI_MXFP4;
+};
+
 template<>
 struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
     static constexpr int qk = QK_K;

@@ -465,6 +465,24 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
     }
 }

+template<typename dst_t>
+static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i = blockIdx.x;
+    const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4);
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il  = tid/8; // 0...3
+    const int64_t ib  = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t * q4 = x[ib].qs + 4*il;
+    const float d = ggml_cuda_e8m0_to_fp32(x[ib].e);
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;
+        y[j+16] = d * kvalues_mxfp4[q4[j] >> 4]*0.5f;
+    }
+}
+
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static void dequantize_block_cuda(const void * vx, dst_t * y,
     const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
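Note: dequantize_block_mxfp4 above multiplies each looked-up code by 0.5f because the kvalues_mxfp4 table stores the E2M1 values doubled, so d * k * 0.5 recovers the true value; each block (16 nibble pairs plus one E8M0 scale byte) expands to 32 floats. A per-block sketch, reusing the helpers assumed in the earlier sketches (e8m0_to_fp32_ref and kvalues_fp4):

    #include <cstdint>

    // Expand one MXFP4 block into 32 floats.
    static void dequant_mxfp4_block(const uint8_t qs[16], uint8_t e, float out[32]) {
        const float d = e8m0_to_fp32_ref(e);                    // decoded block scale
        for (int j = 0; j < 16; ++j) {
            out[j +  0] = d * kvalues_fp4[qs[j] & 0x0f] * 0.5f; // table values are stored doubled
            out[j + 16] = d * kvalues_fp4[qs[j] >>   4] * 0.5f;
        }
    }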
@@ -588,6 +606,12 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
     dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
 }

+template<typename dst_t>
+static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+    dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template <typename src_t, typename dst_t>
 static __global__ void convert_unary(
     const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,

@@ -677,6 +701,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq4_xs_cuda;
         case GGML_TYPE_IQ3_S:
             return dequantize_row_iq3_s_cuda;
+        case GGML_TYPE_MXFP4:
+            return dequantize_row_mxfp4_cuda;
         case GGML_TYPE_F32:
             return convert_unary_cont_cuda<float>;
         case GGML_TYPE_BF16:

@@ -726,6 +752,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq4_xs_cuda;
         case GGML_TYPE_IQ3_S:
            return dequantize_row_iq3_s_cuda;
+        case GGML_TYPE_MXFP4:
+            return dequantize_row_mxfp4_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cont_cuda<half>;
         case GGML_TYPE_BF16:
@@ -15,6 +15,7 @@ typedef void (* fattn_kernel_t)(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
         const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,

@@ -737,6 +738,7 @@ void launch_fattn(
    GGML_ASSERT(V || is_mla);

    const ggml_tensor * mask = dst->src[3];
+   const ggml_tensor * sinks = dst->src[4];

    ggml_tensor * KQV = dst;

@@ -940,6 +942,7 @@ void launch_fattn(
        K_data,
        V_data,
        mask ? ((const char *) mask->data) : nullptr,
+       sinks ? ((const char *) sinks->data) : nullptr,
        KV_max.ptr,
        !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
        scale, max_bias, m0, m1, n_head_log2, logit_softcap,

@@ -1206,6 +1206,7 @@ static __global__ void flash_attn_ext_f16(
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
+       const char * __restrict__ sinks,
        const int  * __restrict__ KV_max,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,

@@ -1267,6 +1268,7 @@ static __global__ void flash_attn_ext_f16(
    // kb0 == k start index when in the output tile.
    int kb0_start = kbc % iter_k;
    int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);

    while (kbc < kbc_stop && kb0_stop == iter_k) {
        const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
        const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);

@@ -1340,7 +1342,7 @@ static __global__ void flash_attn_ext_f16(
        (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
         ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
 #else
-    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
     GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
     GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);

@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f16(
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
+       const char * __restrict__ sinks,
        const int  * __restrict__ KV_max,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,

@@ -272,7 +273,7 @@ static __global__ void flash_attn_tile_ext_f16(
        }
    }
 #else
-    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
     GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
     GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);

@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f32(
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
+       const char * __restrict__ sinks,
        const int  * __restrict__ KV_max,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,

@@ -37,7 +38,7 @@ static __global__ void flash_attn_tile_ext_f32(
        return;
 #endif // FP16_MMA_AVAILABLE
    if (use_logit_softcap && !(D == 128 || D == 256)) {
-       GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+       GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
        GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
        GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
        GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);

@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16(
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
+       const char * __restrict__ sinks,
        const int  * __restrict__ KV_max,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,

@@ -62,6 +63,7 @@ static __global__ void flash_attn_vec_ext_f16(
    V += nb23*sequence + nb22*(head / gqa_ratio);

    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+   const float * sinksf = (const float *) (sinks);

    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
    const half  slopeh = __float2half(slopef);

@@ -75,11 +77,12 @@ static __global__ void flash_attn_vec_ext_f16(
    half2 * KQ2 = (half2 *) KQ;

    half kqmax[ncols];
+   half kqsum[ncols];
 #pragma unroll
    for (int j = 0; j < ncols; ++j) {
        kqmax[j] = -HALF_MAX_HALF;
+       kqsum[j] = 0.0f;
    }
-   half kqsum[ncols] = {0.0f};

    __shared__ half kqmax_shared[ncols][WARP_SIZE];
    __shared__ half kqsum_shared[ncols][WARP_SIZE];
@@ -283,6 +286,39 @@ static __global__ void flash_attn_vec_ext_f16(
        __syncthreads();
    }

+   if (sinksf && blockIdx.y == 0) {
+       const half sink = __float2half(sinksf[head]);
+
+#pragma unroll
+       for (int j = 0; j < ncols; ++j) {
+           if (threadIdx.x == 0) {
+               kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
+           }
+       }
+
+       __syncthreads();
+
+#pragma unroll
+       for (int j = 0; j < ncols; ++j) {
+           half kqmax_new_j = kqmax_shared[j][threadIdx.x];
+           kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+           const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
+           kqmax[j] = kqmax_new_j;
+
+           const half val = hexp(sink - kqmax[j]);
+           kqsum[j] = kqsum[j]*KQ_max_scale;
+
+           if (tid == 0) {
+               kqsum[j] += val;
+           }
+
+           VKQ[j] *= __half2half2(KQ_max_scale);
+       }
+
+       __syncthreads();
+   }
+
 #pragma unroll
    for (int j = 0; j < ncols; ++j) {
        kqsum[j] = warp_reduce_sum((float)kqsum[j]);

@@ -313,7 +349,7 @@ static __global__ void flash_attn_vec_ext_f16(
        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
    }
 #else
-   GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+   GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
    GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
    GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);

@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32(
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
+       const char * __restrict__ sinks,
        const int  * __restrict__ KV_max,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,

@@ -73,6 +74,7 @@ static __global__ void flash_attn_vec_ext_f32(
    V += nb23*sequence + nb22*(head / gqa_ratio);

    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+   const float * sinksf = (const float *) (sinks);

    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);

@@ -88,11 +90,12 @@ static __global__ void flash_attn_vec_ext_f32(
    }

    float kqmax[ncols];
+   float kqsum[ncols];
 #pragma unroll
    for (int j = 0; j < ncols; ++j) {
        kqmax[j] = -FLT_MAX/2.0f;
+       kqsum[j] = 0.0f;
    }
-   float kqsum[ncols] = {0.0f};

    __shared__ float kqmax_shared[ncols][WARP_SIZE];
    __shared__ float kqsum_shared[ncols][WARP_SIZE];

@@ -279,6 +282,39 @@ static __global__ void flash_attn_vec_ext_f32(
        __syncthreads();
    }

+   if (sinksf && blockIdx.y == 0) {
+       const float sink = sinksf[head];
+
+#pragma unroll
+       for (int j = 0; j < ncols; ++j) {
+           if (threadIdx.x == 0) {
+               kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
+           }
+       }
+
+       __syncthreads();
+
+#pragma unroll
+       for (int j = 0; j < ncols; ++j) {
+           float kqmax_new_j = kqmax_shared[j][threadIdx.x];
+           kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+           const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
+           kqmax[j] = kqmax_new_j;
+
+           const float val = expf(sink - kqmax[j]);
+           kqsum[j] = kqsum[j]*KQ_max_scale;
+
+           if (tid == 0) {
+               kqsum[j] += val;
+           }
+
+           VKQ[j] *= KQ_max_scale;
+       }
+
+       __syncthreads();
+   }
+
 #pragma unroll
    for (int j = 0; j < ncols; ++j) {
        kqsum[j] = warp_reduce_sum(kqsum[j]);
@@ -29,6 +29,7 @@ static __global__ void flash_attn_ext_f16(
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
+       const char * __restrict__ sinks,
        const int  * __restrict__ KV_max,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,

@@ -423,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
        dst_meta[j_dst_unrolled] = dst_meta_val;
    }
 #else
-   GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+   GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);

@@ -274,12 +274,23 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];
    const ggml_tensor * mask = dst->src[3];
+   const ggml_tensor * sinks = dst->src[4];

    ggml_cuda_set_device(ctx.device);
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);

+   // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
+   if (sinks) {
+       if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+           ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+       } else {
+           ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+       }
+       return;
+   }
+
 #if defined(GGML_HIP_ROCWMMA_FATTN)
    if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
@@ -6,6 +6,7 @@ bool g_mul_mat_q = true;

 #include "ggml-cuda/common.cuh"
 #include "ggml-cuda/acc.cuh"
+#include "ggml-cuda/add-id.cuh"
 #include "ggml-cuda/arange.cuh"
 #include "ggml-cuda/argmax.cuh"
 #include "ggml-cuda/argsort.cuh"

@@ -2264,6 +2265,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
        case GGML_OP_ADD1: // TODO: more efficient implementation
            ggml_cuda_op_add(ctx, dst);
            break;
+       case GGML_OP_ADD_ID:
+           ggml_cuda_op_add_id(ctx, dst);
+           break;
        case GGML_OP_SUB:
            ggml_cuda_op_sub(ctx, dst);
            break;

@@ -2338,6 +2342,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                case GGML_GLU_OP_SWIGLU:
                    ggml_cuda_op_swiglu(ctx, dst);
                    break;
+               case GGML_GLU_OP_SWIGLU_OAI:
+                   ggml_cuda_op_swiglu_oai(ctx, dst);
+                   break;
                case GGML_GLU_OP_GEGLU_ERF:
                    ggml_cuda_op_geglu_erf(ctx, dst);
                    break;

@@ -2612,6 +2619,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud

    const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
    const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
+   const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
+   const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
+   const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";

    for (int i = 0; i < cgraph->n_nodes; i++) {
        ggml_tensor * node = cgraph->nodes[i];

@@ -2634,7 +2644,13 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 #endif
        }

-       if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true)) {
+       if (node->op == GGML_OP_ADD &&
+               node->src[1] && node->src[1]->ne[1] > 1 &&
+               (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
+               (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
+               strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
+               strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
+               strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
            // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
            // by means of matching node names. See
            // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and

@@ -3232,6 +3248,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                case GGML_GLU_OP_REGLU:
                case GGML_GLU_OP_GEGLU:
                case GGML_GLU_OP_SWIGLU:
+               case GGML_GLU_OP_SWIGLU_OAI:
                case GGML_GLU_OP_GEGLU_ERF:
                case GGML_GLU_OP_GEGLU_QUICK:
                    return ggml_is_contiguous_1(op->src[0]);

@@ -3282,6 +3299,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                case GGML_TYPE_Q5_0:
                case GGML_TYPE_Q5_1:
                case GGML_TYPE_Q8_0:
+               case GGML_TYPE_MXFP4:
                case GGML_TYPE_Q2_K:
                case GGML_TYPE_Q3_K:
                case GGML_TYPE_Q4_K:

@@ -3428,6 +3446,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
        case GGML_OP_ADD:
+       case GGML_OP_ADD_ID:
        case GGML_OP_ADD1:
        case GGML_OP_SUB:
        case GGML_OP_MUL:

@@ -3508,6 +3527,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
                return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
            }
+           // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
+           if (op->src[4] && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) { // currently only sinks for head_size 64 and 128 are supported
+               return false;
+           }
            if (op->src[0]->ne[0] == 192) {
                return false;
            }
@@ -1,7 +1,5 @@
 #include "im2col.cuh"

-#define MIN(a, b) (a) < (b) ? (a) : (b)
-
 #define MAX_GRIDDIM_Z 65535

 template <typename T>

@@ -38,6 +36,9 @@ static __global__ void im2col_kernel(
            dst[offset_dst] = x[offset_src + iih * IW + iiw];
        }
    }
+
+   GGML_UNUSED(IC);
+   GGML_UNUSED(KH);
 }

 // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]

@@ -20,6 +20,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con
        case GGML_TYPE_Q8_0:
            mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
            break;
+       case GGML_TYPE_MXFP4:
+           mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
+           break;
        case GGML_TYPE_Q2_K:
            mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
            break;
@@ -284,6 +287,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
+       case GGML_TYPE_MXFP4:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:

@@ -59,6 +59,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
            return MMQ_Q8_1_DS_LAYOUT_DS4;
        case GGML_TYPE_Q8_0:
            return MMQ_Q8_1_DS_LAYOUT_D4;
+       case GGML_TYPE_MXFP4:
+           return MMQ_Q8_1_DS_LAYOUT_D4;
        case GGML_TYPE_Q2_K:
            return MMQ_Q8_1_DS_LAYOUT_D2S6;
        case GGML_TYPE_Q3_K:

@@ -171,6 +173,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
        case GGML_TYPE_Q5_0:  return MMQ_DP4A_TXS_Q8_0;
        case GGML_TYPE_Q5_1:  return MMQ_DP4A_TXS_Q8_1;
        case GGML_TYPE_Q8_0:  return MMQ_DP4A_TXS_Q8_0;
+       case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
        case GGML_TYPE_Q2_K:  return MMQ_DP4A_TXS_Q2_K;
        case GGML_TYPE_Q3_K:  return MMQ_DP4A_TXS_Q3_K;
        case GGML_TYPE_Q4_K:  return MMQ_DP4A_TXS_Q4_K;

@@ -207,6 +210,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
        case GGML_TYPE_Q5_0:  return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_Q5_1:  return MMQ_MMA_TILE_X_K_Q8_1;
        case GGML_TYPE_Q8_0:  return MMQ_MMA_TILE_X_K_Q8_0;
+       case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
        case GGML_TYPE_Q2_K:  return MMQ_MMA_TILE_X_K_Q2_K;
        case GGML_TYPE_Q3_K:  return MMQ_MMA_TILE_X_K_Q3_K;
        case GGML_TYPE_Q4_K:  return MMQ_MMA_TILE_X_K_Q8_1;
@ -693,6 +697,71 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(
|
||||||
|
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
||||||
|
constexpr int nwarps = mmq_get_nwarps_device();
|
||||||
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||||
|
|
||||||
|
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
||||||
|
int * x_qs = (int *) x_tile;
|
||||||
|
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
||||||
|
#else
|
||||||
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
|
||||||
|
int * x_qs = (int *) x_tile;
|
||||||
|
float * x_df = (float *) (x_qs + txs.qs);
|
||||||
|
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
||||||
|
|
||||||
|
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
|
||||||
|
constexpr int nrows = warp_size / threads_per_row;
|
||||||
|
const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
|
||||||
|
const int kbx = txi / QI_MXFP4;
|
||||||
|
const int kqsx = txi % QI_MXFP4;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
|
||||||
|
int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
|
||||||
|
|
||||||
|
if (need_check) {
|
||||||
|
i = min(i, i_max);
|
||||||
|
}
|
||||||
|
|
||||||
|
const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx;
|
||||||
|
|
||||||
|
const int aux_q4 = get_int_b1(bxi->qs, kqsx);
|
||||||
|
const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
|
||||||
|
const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
|
||||||
|
|
||||||
|
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
||||||
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x;
|
||||||
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
|
||||||
|
#else
|
||||||
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
|
||||||
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
|
||||||
|
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
|
||||||
|
constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
|
||||||
|
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
|
||||||
|
int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
|
||||||
|
|
||||||
|
if (need_check) {
|
||||||
|
i = min(i, i_max);
|
||||||
|
}
|
||||||
|
|
||||||
|
const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
|
||||||
|
|
||||||
|
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
||||||
|
x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
|
||||||
|
#else
|
||||||
|
x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
|
||||||
|
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <int mmq_x, int mmq_y>
|
template <int mmq_x, int mmq_y>
|
||||||
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
|
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
|
||||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
||||||
|
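Editor's note: for readers new to the format, the tile loader above implies the usual ggml MXFP4 layout: a block of 32 elements sharing one E8M0 exponent byte, with each element stored as a 4-bit index into a doubled E2M1 value table (hence the trailing * 0.5f on the scale). The host-side C++ below is a minimal dequantization sketch under those assumptions; the struct, table contents, and nibble ordering are illustrative, not copied from ggml headers.

// Host-side sketch of MXFP4 dequantization, mirroring what load_tiles_mxfp4
// and vec_dot_mxfp4_q8_1 do on the GPU. Assumptions (not taken from ggml):
// 32 elements per block, one shared E8M0 exponent, 4-bit codes indexing a
// doubled E2M1 value table, low nibbles = elements 0..15, high = 16..31.
#include <cmath>
#include <cstdint>

struct block_mxfp4_sketch {
    uint8_t e;       // shared scale, E8M0 (a pure power of two)
    uint8_t qs[16];  // 32 x 4-bit codes, two per byte
};

// assumed doubled E2M1 values: {0, 0.5, 1, 1.5, 2, 3, 4, 6} * 2 and negatives
static const int8_t kvalues_sketch[16] = {0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12};

static float e8m0_to_fp32_sketch(uint8_t x) {
    return std::ldexp(1.0f, (int) x - 127);  // 2^(x - 127); x == 0 gives the denormal 2^-127
}

static void dequant_mxfp4_sketch(const block_mxfp4_sketch & b, float out[32]) {
    const float d = e8m0_to_fp32_sketch(b.e) * 0.5f;  // undo the doubled table
    for (int j = 0; j < 16; ++j) {
        out[j +  0] = d * kvalues_sketch[b.qs[j] & 0x0F];  // low nibble
        out[j + 16] = d * kvalues_sketch[b.qs[j] >>   4];  // high nibble
    }
}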
@@ -2269,7 +2338,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
 
         const int aux_q4 = get_int_b2(bxi->qs, kqsx);
-        const int2 v = get_int_from_table_16(aux_q4);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
         const int k0 = kbx * (2 * QI4_NL) + kqsx;
 
 #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
@@ -2708,7 +2777,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
 
         const int aux_q4 = get_int_b4(bxi->qs, kqsx);
-        const int2 v = get_int_from_table_16(aux_q4);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
         const int k0 = 8 * (kqsx / 4) + kqsx % 4;
 
 #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
@@ -2864,6 +2933,14 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
 };
 
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
+    static constexpr int              vdr          = VDR_MXFP4_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_mxfp4<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
 template <int mmq_x, int mmq_y, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
     static constexpr int              vdr          = VDR_Q2_K_Q8_1_MMQ;
@@ -3643,6 +3720,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
 extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
+extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
 extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
@@ -13,6 +13,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
         case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
         case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
         case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
+        case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1;
         case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
         case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
         case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
@@ -38,6 +39,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
         case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
         case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
         case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
+        case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ;
         case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
         case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
         case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
@@ -384,6 +386,13 @@ static void mul_mat_vec_q_switch_type(
                  nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                  stream);
             break;
+        case GGML_TYPE_MXFP4:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
         case GGML_TYPE_Q2_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
                 (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
@@ -45,7 +45,7 @@ struct soft_max_params {
 #endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
-        const float * x, const T * mask, float * dst, const soft_max_params p) {
+        const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {
     const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
 
     const int tid = threadIdx.x;
@@ -77,7 +77,7 @@ static __global__ void soft_max_f32(
     // shared memory buffer to cache values between iterations:
     float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
 
-    float max_val = -INFINITY;
+    float max_val = sinks ? sinks[i02] : -INFINITY;
 
 #pragma unroll
     for (int col0 = 0; col0 < ncols; col0 += block_size) {
@@ -143,6 +143,10 @@ static __global__ void soft_max_f32(
         tmp = warp_reduce_sum(tmp);
     }
 
+    if (sinks) {
+        tmp += expf(sinks[i02] - max_val);
+    }
+
     const float inv_sum = 1.0f / tmp;
 
 #pragma unroll
@@ -183,7 +187,7 @@ static __global__ void soft_max_back_f32(
 }
 
 template<int... Ns, typename T>
-static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
+static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,
         const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
 {
     const int id = ggml_cuda_get_device();
@@ -196,7 +200,7 @@ static void launch_soft_max_kernels(const float * x, const T * mask, float * dst
         if (p.ncols == ncols) {
             CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
             soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, mask, dst, p);
+                (x, mask, sinks, dst, p);
             return true;
         }
         return false;
@@ -209,12 +213,12 @@ static void launch_soft_max_kernels(const float * x, const T * mask, float * dst
 
     //default case
     CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
-    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
+    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
 }
 
 
 template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
     const int64_t ncols_x = params.ncols;
 
@@ -230,10 +234,10 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
 
 
     if (nbytes_shared <= smpbo) {
-        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
+        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
     }
 }
 
@@ -249,9 +253,11 @@ static void soft_max_back_f32_cuda(
 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
 
     const float * src0_d = (const float *) src0->data;
     const void  * src1_d = src1 ? (const void *) src1->data : nullptr;
+    const void  * src2_d = src2 ? (const void *) src2->data : nullptr;
     float * dst_d = (float *) dst->data;
 
     cudaStream_t stream = ctx.stream();
@@ -309,9 +315,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     params.m1 = m1;
 
     if (use_f16) {
-        soft_max_f32_cuda(src0_d, (const half  *) src1_d, dst_d, params, stream);
+        soft_max_f32_cuda(src0_d, (const half  *) src1_d, (const float *) src2_d, dst_d, params, stream);
     } else {
-        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
     }
 }
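Editor's note: the soft_max changes above thread a new sinks pointer through the CUDA path. The sink value seeds the running maximum and contributes one extra exp term to the denominator, but produces no output column, so the row's probabilities sum to less than one. A minimal scalar reference, assuming one sink value per row as the i02 indexing in the kernel suggests; all names here are illustrative.

// Scalar reference for softmax with an attention sink, mirroring the CUDA path:
// max over {sink, x[0..n)}, denominator includes exp(sink - max), outputs cover x only.
#include <cmath>
#include <cstdio>

static void softmax_with_sink(const float * x, float sink, float * dst, int n) {
    float max_val = sink;                    // matches: max_val = sinks ? sinks[i02] : -INFINITY
    for (int i = 0; i < n; ++i) {
        max_val = std::fmax(max_val, x[i]);
    }

    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        dst[i] = std::exp(x[i] - max_val);
        sum += dst[i];
    }
    sum += std::exp(sink - max_val);         // matches: tmp += expf(sinks[i02] - max_val)

    const float inv_sum = 1.0f / sum;
    for (int i = 0; i < n; ++i) {
        dst[i] *= inv_sum;
    }
}

int main() {
    const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    float y[4];
    softmax_with_sink(x, 0.0f, y, 4);
    std::printf("%f %f %f %f\n", y[0], y[1], y[2], y[3]);  // sums to < 1: the sink absorbs the rest
    return 0;
}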
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_MXFP4);
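Editor's note: the new five-line file above follows the existing per-type instance pattern: the header declares extern DECL_MMQ_CASE(GGML_TYPE_MXFP4) and each autogenerated .cu provides the matching explicit instantiation, so every quantization type compiles in its own translation unit. The C++ sketch below shows the same split with made-up names; it is not the ggml macro machinery itself.

// matmul_case.h -- sketch of "one explicit instantiation per source file".
// Names are illustrative, not ggml's.
#include <cstdio>

template <int TYPE>
void mul_mat_case(int n) {               // stand-in for a heavy templated kernel wrapper
    std::printf("type %d, n = %d\n", TYPE, n);
}

#define DECL_CASE(T) template void mul_mat_case<T>(int)

extern DECL_CASE(0);                     // declaration only; used from the dispatcher
extern DECL_CASE(1);

// A per-type source file (like the autogenerated one above) would contain only:
//   #include "matmul_case.h"
//   DECL_CASE(1);
// so adding a new type adds one small translation unit instead of growing a single huge file.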
@@ -300,6 +300,81 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
 }
 
+// swiglu_oai
+
+template <typename T>
+static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
+    const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    // perform base op and multiply with gate (either offset in same tensor or a separate one)
+    const int64_t j0 = (i / n) * o0 + (i % n);
+    const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+
+    float xi = x[j0];
+    float gi = g[j1];
+    xi = fminf(xi, limit);
+    gi = fmaxf(fminf(gi, limit), -limit);
+
+    float out_glu = xi / (1.0f + expf(-xi * alpha));
+    out_glu = out_glu * (1.0f + gi);
+
+    dst[i] = out_glu;
+}
+
+template <typename T>
+static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
+    const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
+    swiglu_oai_kernel<<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit);
+}
+
+void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    void * src0_d = src0->data;
+    void * src1_d = src1 ? src1->data : src0->data;
+    const int64_t src0_o = src0->nb[1];
+    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+    void * dst_d = dst->data;
+    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
+        GGML_ASSERT(src1->ne[0] == nc);
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    //const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+    const float alpha = ggml_get_op_params_f32(dst, 2);
+    const float limit = ggml_get_op_params_f32(dst, 3);
+
+    float * src0_p = (float *) src0_d;
+    float * src1_p = (float *) src1_d;
+
+    if (!src1) {
+        src0_p += swapped ? nc : 0;
+        src1_p += swapped ? 0 : nc;
+    }
+
+    swiglu_oai_cuda(src0_p, src1_p, (float *) dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
+}
+
 /* silu_back */
 
 static __device__ __forceinline__ float op_silu_back(float grad, float x) {
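Editor's note: the kernel above is a clamped SwiGLU variant: the value branch is clipped from above at limit, the gate is clipped to [-limit, limit], and the output is x * sigmoid(alpha * x) * (1 + g). A plain C++ per-element reference of that math; the parameter values in main are only examples, not values taken from this commit.

// Per-element reference of the swiglu_oai math in the kernel above.
#include <cmath>
#include <cstdio>

static float swiglu_oai_ref(float x, float g, float alpha, float limit) {
    x = std::fmin(x, limit);                         // xi = fminf(xi, limit)
    g = std::fmax(std::fmin(g, limit), -limit);      // gi clamped to [-limit, limit]

    const float silu = x / (1.0f + std::exp(-x * alpha));  // x * sigmoid(alpha * x)
    return silu * (1.0f + g);
}

int main() {
    // alpha = 1.702 and limit = 7.0 are illustrative example op params
    std::printf("%f\n", swiglu_oai_ref(2.0f, 0.5f, 1.702f, 7.0f));
    return 0;
}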
@@ -67,6 +67,8 @@ void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -1,8 +1,20 @@
 #pragma once
 
 #include "common.cuh"
 
 #include <cstdint>
 
+static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) {
+    const uint8_t * x8 = (const uint8_t *) x;
+
+    int x32  = x8[4*i32 + 0] <<  0;
+    x32     |= x8[4*i32 + 1] <<  8;
+    x32     |= x8[4*i32 + 2] << 16;
+    x32     |= x8[4*i32 + 3] << 24;
+
+    return x32;
+}
+
 static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
     const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
 
@@ -16,6 +28,20 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
     return ((const int *) x)[i32]; // assume at least 4 byte alignment
 }
 
+static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
+    const int      q0_32  = (q4 >> 0) & 0x0F0F0F0F;
+    const int8_t * q0_8   = (const int8_t *) &q0_32;
+    const char4    val0_8 = make_char4(
+        table[q0_8[0]], table[q0_8[1]], table[q0_8[2]], table[q0_8[3]]);
+
+    const int      q1_32  = (q4 >> 4) & 0x0F0F0F0F;
+    const int8_t * q1_8   = (const int8_t *) &q1_32;
+    const char4    val1_8 = make_char4(
+        table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
+
+    return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
+}
+
 // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
 // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
 
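Editor's note: get_int_b1 and the table-taking get_int_from_table_16 above are the building blocks for 4-bit formats: one 32-bit load carries eight 4-bit codes, and the low and high nibbles are expanded through a 16-entry int8 table into two 32-bit words of four int8 values each, ready for a packed dot product. A scalar C++ equivalent, with std::array standing in for CUDA's char4/int2:

// Scalar equivalent of get_int_from_table_16: expand eight 4-bit codes packed
// in one 32-bit word into two groups of four int8 values via a lookup table.
#include <array>
#include <cstdint>

static std::array<std::array<int8_t, 4>, 2> table_16_ref(uint32_t q4, const int8_t table[16]) {
    std::array<std::array<int8_t, 4>, 2> out{};
    for (int i = 0; i < 4; ++i) {
        out[0][i] = table[(q4 >> (8*i))     & 0x0F];  // low nibble of byte i  (q4 & 0x0F0F0F0F)
        out[1][i] = table[(q4 >> (8*i + 4)) & 0x0F];  // high nibble of byte i ((q4 >> 4) & 0x0F0F0F0F)
    }
    return out;
}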
@@ -211,6 +237,30 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_
     return d8_1*sumf;
 }
 
+#define VDR_MXFP4_Q8_1_MMVQ 2
+#define VDR_MXFP4_Q8_1_MMQ  4
+
+static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx;
+
+    const int * q8 = (const int *) bq8_1->qs + iqs;
+
+    int sumi = 0;
+#pragma unroll
+    for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {
+        const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+
+        sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
+        sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
+    }
+
+    const float d = ggml_cuda_e8m0_to_fp32(bq4->e) * 0.5f * __low2float(bq8_1->ds);
+    return d * sumi;
+}
+
 #define VDR_Q2_K_Q8_1_MMVQ 1
 #define VDR_Q2_K_Q8_1_MMQ  4
 
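Editor's note: vec_dot_mxfp4_q8_1 accumulates with ggml_cuda_dp4a, which on hardware with DP4A multiplies four packed signed 8-bit lanes from each operand and adds the result to a 32-bit accumulator. A scalar stand-in that matches that contract (the wrapper name and exact fallback behaviour on older GPUs are not shown in this diff):

// Scalar stand-in for a DP4A-style accumulate: treat each 32-bit word as four
// signed 8-bit lanes, multiply lane-wise, and add to the accumulator.
#include <cstdint>

static int32_t dp4a_ref(int32_t a, int32_t b, int32_t acc) {
    for (int i = 0; i < 4; ++i) {
        const int8_t ai = (int8_t) ((uint32_t) a >> (8*i));
        const int8_t bi = (int8_t) ((uint32_t) b >> (8*i));
        acc += (int32_t) ai * (int32_t) bi;
    }
    return acc;
}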
@@ -1068,20 +1118,6 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
     return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
 }
 
-static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
-    const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
-    const int8_t * q0_8 = (const int8_t *) &q0_32;
-    const char4 val0_8 = make_char4(
-        kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
-
-    const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
-    const int8_t * q1_8 = (const int8_t *) &q1_32;
-    const char4 val1_8 = make_char4(
-        kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);
-
-    return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
-}
-
 #define VDR_IQ4_NL_Q8_1_MMVQ 2
 #define VDR_IQ4_NL_Q8_1_MMQ  4
 
@@ -1096,7 +1132,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
 #pragma unroll
     for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
         const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
-        const int2 v = get_int_from_table_16(aux_q4);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
 
         sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
         sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
@@ -1118,7 +1154,7 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
 #pragma unroll
     for (int j = 0; j < 4; ++j) {
         const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
-        const int2 v = get_int_from_table_16(aux_q4);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
 
         const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
         const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);
ggml/src/ggml-cuda/vendors/cuda.h (vendored)
@@ -6,6 +6,10 @@
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 
+#if CUDART_VERSION >= 12050
+#include <cuda_fp8.h>
+#endif // CUDART_VERSION >= 12050
+
 #if CUDART_VERSION < 11020
 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
 #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
@@ -410,6 +410,67 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
 #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
 #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
 
+static inline float ggml_e8m0_to_fp32(uint8_t x) {
+    uint32_t bits;  // Stores the raw bit representation of the float
+
+    // Handle special case for minimum exponent (denormalized float)
+    if (x == 0) {
+        // Bit pattern for 2^(-127):
+        // - Sign bit: 0 (positive)
+        // - Exponent: 0 (denormalized number)
+        // - Mantissa: 0x400000 (0.5 in fractional form)
+        // Value = 0.5 * 2^(-126) = 2^(-127)
+        bits = 0x00400000;
+    }
+    // note: disabled as we don't need to handle NaNs
+    //// Handle special case for NaN (all bits set)
+    //else if (x == 0xFF) {
+    //    // Standard quiet NaN pattern:
+    //    // - Sign bit: 0
+    //    // - Exponent: all 1s (0xFF)
+    //    // - Mantissa: 0x400000 (quiet NaN flag)
+    //    bits = 0x7FC00000;
+    //}
+    // Normalized values (most common case)
+    else {
+        // Construct normalized float by shifting exponent into position:
+        // - Exponent field: 8 bits (positions 30-23)
+        // - Mantissa: 0 (implicit leading 1)
+        // Value = 2^(x - 127)
+        bits = (uint32_t) x << 23;
+    }
+
+    float result;  // Final float value
+    // Safely reinterpret bit pattern as float without type-punning issues
+    memcpy(&result, &bits, sizeof(float));
+    return result;
+}
+
+// Equal to ggml_e8m0_to_fp32/2
+// Useful with MXFP4 quantization since the E0M2 values are doubled
+static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
+    uint32_t bits;
+
+    // For x < 2: use precomputed denormal patterns
+    if (x < 2) {
+        // 0x00200000 = 2^(-128), 0x00400000 = 2^(-127)
+        bits = 0x00200000 << x;
+    }
+    // For x >= 2: normalized exponent adjustment
+    else {
+        // 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1)
+        bits = (uint32_t)(x - 1) << 23;
+    }
+    // Note: NaNs are not handled here
+
+    float result;
+    memcpy(&result, &bits, sizeof(float));
+    return result;
+}
+
+#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
+#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
+
 /**
  * Converts brain16 to float32.
  *
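Editor's note: ggml_e8m0_to_fp32 treats the byte as a pure power-of-two exponent biased by 127, with x = 0 falling back to the denormal 2^-127, and the _half variant returns half of that value. A few host-side spot checks under that reading, using ldexp as the reference; this is a verification sketch, not code from the commit.

// Spot checks for the E8M0 reading: value(x) = 2^(x - 127).
#include <cassert>
#include <cmath>
#include <cstdint>

static float e8m0_ref(uint8_t x) {
    return std::ldexp(1.0f, (int) x - 127);  // x == 0 naturally yields the denormal 2^-127
}

int main() {
    assert(e8m0_ref(127) == 1.0f);                    // bias point
    assert(e8m0_ref(128) == 2.0f);
    assert(e8m0_ref(126) == 0.5f);
    assert(e8m0_ref(0)   == std::ldexp(1.0f, -127));  // smallest (denormal) scale
    // the *_half variant in the diff equals e8m0_ref(x) / 2, used because the
    // MXFP4 value table stores doubled values
    return 0;
}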
@@ -23,6 +23,9 @@
 #define N_R0_Q8_0 4
 #define N_SG_Q8_0 2
 
+#define N_R0_MXFP4 2
+#define N_SG_MXFP4 2
+
 #define N_R0_Q2_K 4
 #define N_SG_Q2_K 2
 
@@ -129,6 +132,15 @@ typedef struct {
     uint64_t o1[8];
 } ggml_metal_kargs_bin;
 
+typedef struct {
+    int64_t ne0;
+    int64_t ne1;
+    size_t  nb01;
+    size_t  nb02;
+    size_t  nb11;
+    size_t  nb21;
+} ggml_metal_kargs_add_id;
+
 typedef struct {
     int32_t ne00;
     int32_t ne01;
@@ -444,6 +456,8 @@ typedef struct{
     uint64_t nb1;
     int32_t  i00;
     int32_t  i10;
+    float    alpha;
+    float    limit;
 } ggml_metal_kargs_glu;
 
 typedef struct {
@@ -195,6 +195,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
     GGML_METAL_KERNEL_TYPE_DIV,
     GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
+    GGML_METAL_KERNEL_TYPE_ADD_ID,
     GGML_METAL_KERNEL_TYPE_REPEAT_F32,
     GGML_METAL_KERNEL_TYPE_REPEAT_F16,
     GGML_METAL_KERNEL_TYPE_REPEAT_I32,
@@ -234,6 +235,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
@@ -286,6 +288,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
@@ -310,6 +313,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
@@ -351,6 +358,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
@@ -373,6 +381,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
@@ -397,6 +406,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
@@ -579,6 +589,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_REGLU,
     GGML_METAL_KERNEL_TYPE_GEGLU,
     GGML_METAL_KERNEL_TYPE_SWIGLU,
+    GGML_METAL_KERNEL_TYPE_SWIGLU_OAI,
     GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
     GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -1199,6 +1210,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ID, add_id, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
@@ -1238,6 +1250,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
@@ -1290,6 +1303,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
@@ -1314,6 +1328,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2, mul_mv_ext_mxfp4_f32_r1_2, has_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3, mul_mv_ext_mxfp4_f32_r1_3, has_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4, mul_mv_ext_mxfp4_f32_r1_4, has_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5, mul_mv_ext_mxfp4_f32_r1_5, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
@@ -1355,6 +1373,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
@@ -1377,6 +1396,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
@@ -1401,6 +1422,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, mul_mm_id_mxfp4_f16, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm);
@@ -1583,6 +1605,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_OAI, swiglu_oai, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
@@ -1774,6 +1797,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             case GGML_GLU_OP_REGLU:
             case GGML_GLU_OP_GEGLU:
             case GGML_GLU_OP_SWIGLU:
+            case GGML_GLU_OP_SWIGLU_OAI:
            case GGML_GLU_OP_GEGLU_ERF:
             case GGML_GLU_OP_GEGLU_QUICK:
                 return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
@@ -1791,6 +1815,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+        case GGML_OP_ADD_ID:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ACC:
         case GGML_OP_REPEAT:
@@ -2042,6 +2067,7 @@ static int ggml_metal_encode_node(
 
     const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
     const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+    const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT;
     const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
 
     size_t offs_src0 = 0;
@@ -2291,6 +2317,38 @@ static int ggml_metal_encode_node(
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 }
             } break;
+        case GGML_OP_ADD_ID:
+            {
+                GGML_ASSERT(src0t == GGML_TYPE_F32);
+                GGML_ASSERT(src1t == GGML_TYPE_F32);
+                GGML_ASSERT(src2t == GGML_TYPE_I32);
+                GGML_ASSERT(dstt  == GGML_TYPE_F32);
+
+                GGML_ASSERT(ggml_is_contiguous_rows(src0));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ID].pipeline;
+
+                ggml_metal_kargs_add_id args = {
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb11 =*/ nb11,
+                    /*.nb21 =*/ nb21,
+
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:4];
+
+                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
         case GGML_OP_REPEAT:
             {
                 id<MTLComputePipelineState> pipeline;
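Editor's note: GGML_OP_ADD_ID is wired up for Metal here. Judging from the buffers bound (src0 activations, src1 a table of bias rows, src2 int32 ids), it adds to every src0 row the src1 row selected by the corresponding id. The C++ sketch below encodes that assumed semantic; it is an interpretation of the encoder code above, not taken from ggml documentation.

// Assumed CPU reference for an "add_id" style op: dst[r] = src0[r] + bias[ids[r]], row-wise.
#include <cstdint>
#include <vector>

static void add_id_ref(const std::vector<std::vector<float>> & src0,  // one row per token/expert slot
                       const std::vector<std::vector<float>> & bias,  // one row per expert
                       const std::vector<int32_t> & ids,              // selected expert per src0 row
                       std::vector<std::vector<float>> & dst) {
    dst = src0;
    for (size_t r = 0; r < src0.size(); ++r) {
        const std::vector<float> & b = bias[ids[r]];
        for (size_t c = 0; c < src0[r].size(); ++c) {
            dst[r][c] += b[c];
        }
    }
}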
@@ -2710,6 +2768,9 @@
                     case GGML_GLU_OP_SWIGLU:
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
                         break;
+                    case GGML_GLU_OP_SWIGLU_OAI:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_OAI].pipeline;
+                        break;
                     case GGML_GLU_OP_GEGLU_ERF:
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
                         break;
@@ -2720,7 +2781,9 @@
                         GGML_ABORT("fatal error");
                 }
 
-                const int32_t swp = ((const int32_t *) dst->op_params)[1];
+                const int32_t swp   = ggml_get_op_params_i32(dst, 1);
+                const float   alpha = ggml_get_op_params_f32(dst, 2);
+                const float   limit = ggml_get_op_params_f32(dst, 3);
 
                 const int32_t i00 = swp ? ne0 : 0;
                 const int32_t i10 = swp ? 0 : ne0;
@@ -2734,6 +2797,8 @@
                     /*.nb1  =*/ nb1,
                     /*.i00  =*/ src1 ? 0 : i00,
                     /*.i10  =*/ src1 ? 0 : i10,
+                    /*.alpha=*/ alpha,
+                    /*.limit=*/ limit
                 };
 
                 [encoder setComputePipelineState:pipeline];
@@ -2992,8 +3057,13 @@
             } else {
                 [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
             }
-            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-            [encoder setBytes:&args length:sizeof(args) atIndex:3];
+            if (id_src2) {
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+            } else {
+                [encoder setBuffer:h_src0 offset:offs_src0 atIndex:2];
+            }
+            [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+            [encoder setBytes:&args length:sizeof(args) atIndex:4];
 
             [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
@ -3291,6 +3361,7 @@ static int ggml_metal_encode_node(
|
||||||
src0t == GGML_TYPE_Q5_0 ||
|
src0t == GGML_TYPE_Q5_0 ||
|
||||||
src0t == GGML_TYPE_Q5_1 ||
|
src0t == GGML_TYPE_Q5_1 ||
|
||||||
src0t == GGML_TYPE_Q8_0 ||
|
src0t == GGML_TYPE_Q8_0 ||
|
||||||
|
src0t == GGML_TYPE_MXFP4 ||
|
||||||
src0t == GGML_TYPE_IQ4_NL ||
|
src0t == GGML_TYPE_IQ4_NL ||
|
||||||
false) && (ne11 >= 2 && ne11 <= 8)
|
false) && (ne11 >= 2 && ne11 <= 8)
|
||||||
) ||
|
) ||
|
||||||
|
@ -3383,6 +3454,14 @@ static int ggml_metal_encode_node(
|
||||||
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
|
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
|
||||||
default: GGML_ABORT("not implemented");
|
default: GGML_ABORT("not implemented");
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_TYPE_MXFP4:
|
||||||
|
switch (r1ptg) {
|
||||||
|
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2].pipeline; break;
|
||||||
|
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3].pipeline; break;
|
||||||
|
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4].pipeline; break;
|
||||||
|
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5].pipeline; break;
|
||||||
|
default: GGML_ABORT("not implemented");
|
||||||
|
} break;
|
||||||
case GGML_TYPE_Q4_K:
|
case GGML_TYPE_Q4_K:
|
||||||
switch (r1ptg) {
|
switch (r1ptg) {
|
||||||
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
|
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
|
||||||
|
@ -3481,6 +3560,7 @@ static int ggml_metal_encode_node(
|
||||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
|
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
|
||||||
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
|
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
|
||||||
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
|
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
|
||||||
|
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
|
||||||
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
|
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
|
||||||
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
|
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
|
||||||
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
|
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
|
||||||
|
@ -3623,6 +3703,13 @@ static int ggml_metal_encode_node(
|
||||||
nr0 = N_R0_Q8_0;
|
nr0 = N_R0_Q8_0;
|
||||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_TYPE_MXFP4:
|
||||||
|
{
|
||||||
|
nsg = N_SG_MXFP4;
|
||||||
|
nr0 = N_R0_MXFP4;
|
||||||
|
smem = 32*sizeof(float);
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline;
|
||||||
|
} break;
|
||||||
case GGML_TYPE_Q2_K:
|
case GGML_TYPE_Q2_K:
|
||||||
{
|
{
|
||||||
nsg = N_SG_Q2_K;
|
nsg = N_SG_Q2_K;
|
||||||
|
@@ -3756,8 +3843,6 @@ static int ggml_metal_encode_node(
         case GGML_OP_MUL_MAT_ID:
             {
                 // src2 = ids
-                const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
-
                 GGML_ASSERT(src2t == GGML_TYPE_I32);
 
                 GGML_ASSERT(!ggml_is_transposed(src0));
@@ -3883,6 +3968,7 @@ static int ggml_metal_encode_node(
             case GGML_TYPE_Q5_0:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
             case GGML_TYPE_Q5_1:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
             case GGML_TYPE_Q8_0:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
+            case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16].pipeline; break;
             case GGML_TYPE_Q2_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
             case GGML_TYPE_Q3_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
             case GGML_TYPE_Q4_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;
@@ -4018,6 +4104,13 @@ static int ggml_metal_encode_node(
                     nr0 = N_R0_Q8_0;
                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
                 } break;
+            case GGML_TYPE_MXFP4:
+                {
+                    nsg = N_SG_MXFP4;
+                    nr0 = N_R0_MXFP4;
+                    smem = 32*sizeof(float);
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline;
+                } break;
             case GGML_TYPE_Q2_K:
                 {
                     nsg = N_SG_Q2_K;
@@ -4170,6 +4263,7 @@ static int ggml_metal_encode_node(
             case GGML_TYPE_Q5_0:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
             case GGML_TYPE_Q5_1:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
             case GGML_TYPE_Q8_0:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
+            case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4].pipeline; break;
             case GGML_TYPE_Q2_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
             case GGML_TYPE_Q3_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
             case GGML_TYPE_Q4_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
@@ -4980,11 +5074,14 @@ static int ggml_metal_encode_node(
                 GGML_ASSERT(ne11 == ne21);
                 GGML_ASSERT(ne12 == ne22);
 
-                struct ggml_tensor * src3 = node->src[3];
+                struct ggml_tensor * src3 = node->src[3]; // mask
+                struct ggml_tensor * src4 = node->src[4]; // sinks
 
                 size_t offs_src3 = 0;
+                size_t offs_src4 = 0;
 
                 id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+                id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
 
                 GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
                 GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
@@ -5000,8 +5097,6 @@ static int ggml_metal_encode_node(
                 const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
                 const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
 
-                const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
-
                 float scale;
                 float max_bias;
                 float logit_softcap;
@@ -5389,7 +5484,12 @@ static int ggml_metal_encode_node(
                 } else {
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
                 }
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:5];
+                if (id_src4) {
+                    [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
+                } else {
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
+                }
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
 
                 if (!use_vec_kernel) {
                     // half8x8 kernel
 
@@ -35,6 +35,10 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
     -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
 };
 
+constexpr constant static float kvalues_mxfp4_f[16] = {
+    0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
+};
+
 static inline int best_index_int8(int n, constant float * val, float x) {
     if (x <= val[0]) return 0;
     if (x >= val[n-1]) return n-1;
@@ -46,6 +50,18 @@ static inline int best_index_int8(int n, constant float * val, float x) {
     return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
 }
 
+static inline float e8m0_to_fp32(uint8_t x) {
+    uint32_t bits;
+
+    if (x == 0) {
+        bits = 0x00400000;
+    } else {
+        bits = (uint32_t) x << 23;
+    }
+
+    return as_type<float>(bits);
+}
+
 // NOTE: this is not dequantizing - we are simply fitting the template
 template <typename type4x4>
 void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
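For reference, the E8M0 byte decoded above is just a biased power-of-two exponent: a non-zero value e maps to 2^(e - 127), and e == 0 maps to 2^-127 (the same value as the subnormal bit pattern 0x00400000 used by the kernel). A minimal host-side C sketch of that mapping, not part of the diff itself (the _ref name is illustrative):

    #include <math.h>
    #include <stdint.h>

    // E8M0 block scale: non-zero e encodes 2^(e - 127); e == 0 decodes to 2^-127.
    static float e8m0_to_fp32_ref(uint8_t e) {
        if (e == 0) {
            return ldexpf(1.0f, -127);
        }
        return ldexpf(1.0f, (int) e - 127);
    }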
@@ -242,6 +258,27 @@ void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
     }
 }
 
+void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
+#pragma METAL fp math_mode(safe)
+    float amax = 0.0f; // absolute max
+
+    for (int j = 0; j < QK8_0; j++) {
+        const float v = src[j];
+        amax = MAX(amax, fabs(v));
+    }
+
+    const float d = amax / ((1 << 7) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+
+    for (int j = 0; j < QK8_0; ++j) {
+        const float x0 = src[j]*id;
+
+        dst.qs[j] = round(x0);
+    }
+}
+
 void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
 #pragma METAL fp math_mode(safe)
     float amax = 0.0f; // absolute max
@@ -462,25 +499,34 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
     }
 }
 
-void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
-#pragma METAL fp math_mode(safe)
-    float amax = 0.0f; // absolute max
-
-    for (int j = 0; j < QK8_0; j++) {
-        const float v = src[j];
-        amax = MAX(amax, fabs(v));
-    }
-
-    const float d = amax / ((1 << 7) - 1);
-    const float id = d ? 1.0f/d : 0.0f;
-
-    dst.d = d;
-
-    for (int j = 0; j < QK8_0; ++j) {
-        const float x0 = src[j]*id;
-
-        dst.qs[j] = round(x0);
-    }
+template <typename type4x4>
+void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
+    device const uint8_t * q2 = (device const uint8_t *)xb->qs;
+
+    const float d = e8m0_to_fp32(xb->e);
+    const uint8_t shr = il >= 1 ? 4 : 0;
+
+    for (int i = 0; i < 4; ++i) {
+        reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];
+        reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];
+        reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];
+        reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];
+    }
+}
+
+template <typename type4>
+void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
+    device const uint8_t * q2 = (device const uint8_t *)xb->qs;
+
+    const float d = e8m0_to_fp32(xb->e);
+    const short il4 = il%4;
+
+    const uint8_t shr = il >= 4 ? 4 : 0;
+
+    reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];
+    reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];
+    reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];
+    reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
 }
 
 template <typename type4x4>
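An MXFP4 block covers 32 weights with one shared E8M0 scale byte and 16 bytes of packed 4-bit indices into kvalues_mxfp4_f; at the row level, low nibbles hold the first half of the block and high nibbles the second half. A standalone CPU sketch of the expansion (struct and function names here are illustrative, not the ggml API):

    #include <math.h>
    #include <stdint.h>

    #define QK_MXFP4 32

    // Plain-C mirror of the block layout used by the Metal kernels above.
    typedef struct {
        uint8_t e;               // shared E8M0 scale
        uint8_t qs[QK_MXFP4/2];  // 4-bit value indices, two per byte
    } block_mxfp4_ref;

    static const float mxfp4_values[16] = {
        0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0.f, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
    };

    // Low nibbles are elements 0..15 of the block, high nibbles elements 16..31.
    static void dequant_block_mxfp4_ref(const block_mxfp4_ref * b, float * out) {
        const float d = ldexpf(1.0f, b->e ? (int) b->e - 127 : -127);
        for (int j = 0; j < QK_MXFP4/2; ++j) {
            out[j]              = d * mxfp4_values[b->qs[j] & 0x0F];
            out[j + QK_MXFP4/2] = d * mxfp4_values[b->qs[j] >> 4];
        }
    }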
@@ -960,6 +1006,32 @@ kernel void kernel_div(
     }
 }
 
+kernel void kernel_add_id(
+        constant ggml_metal_kargs_add_id & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i1 = tgpig.x;
+    const int i2 = tgpig.y;
+
+    const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
+
+    const size_t nb1 = args.ne0 * sizeof(float);
+    const size_t nb2 = args.ne1 * nb1;
+
+    device       float * dst_row  = (device       float *)((device char *)dst  + i1*nb1       + i2*nb2);
+    device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
+    device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        dst_row[i0] = src0_row[i0] + src1_row[i0];
+    }
+}
+
 template<typename T>
 kernel void kernel_repeat(
         constant ggml_metal_kargs_repeat & args,
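kernel_add_id adds a per-row bias to src0, where the bias row of src1 is selected by an int32 id in src2 (one id per output row, as used for per-expert biases in MoE graphs). A CPU reference of the same indexing, a sketch assuming contiguous f32 tensors (the helper name is illustrative):

    #include <stdint.h>

    // dst[i2][i1][:] = src0[i2][i1][:] + src1[ids[i2][i1]][:]
    static void add_id_ref(const float * src0, const float * src1, const int32_t * ids,
                           float * dst, int64_t ne0, int64_t ne1, int64_t ne2, int64_t n_bias_rows) {
        for (int64_t i2 = 0; i2 < ne2; ++i2) {
            for (int64_t i1 = 0; i1 < ne1; ++i1) {
                const int32_t row = ids[i2*ne1 + i1];
                if (row < 0 || row >= n_bias_rows) continue; // bounds check only in this sketch
                const float * a = src0 + (i2*ne1 + i1)*ne0;
                const float * b = src1 + (int64_t) row*ne0;
                float       * d = dst  + (i2*ne1 + i1)*ne0;
                for (int64_t i0 = 0; i0 < ne0; ++i0) {
                    d[i0] = a[i0] + b[i0];
                }
            }
        }
    }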
@@ -1431,6 +1503,32 @@ kernel void kernel_swiglu(
     }
 }
 
+kernel void kernel_swiglu_oai(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        float x0 = src0_row[i0];
+        float x1 = src1_row[i0];
+
+        x0 = min(x0, args.limit);
+        x1 = max(min(x1, args.limit), -args.limit);
+
+        float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
+        out_glu = out_glu * (1.0f + x1);
+
+        dst_row[i0] = out_glu;
+    }
+}
+
 kernel void kernel_geglu_erf(
         device const char * src0,
         device const char * src1,
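The swiglu_oai variant clamps the gate input from above at limit, clamps the linear branch to [-limit, limit], scales the sigmoid by alpha, and multiplies by (1 + x1) rather than x1. A scalar C sketch of the same activation, for reference only:

    #include <math.h>

    // out = clamp(x0) * sigmoid(alpha * clamp(x0)) * (1 + clamp(x1))
    static float swiglu_oai_ref(float x0, float x1, float alpha, float limit) {
        x0 = fminf(x0, limit);                 // gate branch: clamped from above only
        x1 = fmaxf(fminf(x1, limit), -limit);  // linear branch: clamped to [-limit, limit]

        const float sig = 1.0f / (1.0f + expf(-x0 * alpha));
        return (x0 * sig) * (1.0f + x1);
    }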
@@ -1534,6 +1632,7 @@ template<typename T>
 kernel void kernel_soft_max(
         device const  char * src0,
         device const  char * src1,
+        device const  char * src2,
         device        char * dst,
         constant ggml_metal_kargs_soft_max & args,
         threadgroup  float * buf [[threadgroup(0)]],
@@ -1552,6 +1651,7 @@ kernel void kernel_soft_max(
 
     device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
     device const T     * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+    device const float * psrc2 = src2 != src0 ? (device const float *) (src2) : nullptr;
     device       float * pdst  = (device       float *) (dst  + i01*args.nb1  + i02*args.nb2  + i03*args.nb3);
 
     float slope = 1.0f;
@@ -1567,7 +1667,7 @@ kernel void kernel_soft_max(
     }
 
     // parallel max
-    float lmax = -INFINITY;
+    float lmax = psrc2 ? psrc2[i02] : -INFINITY;
 
     for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
         lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
@@ -1623,6 +1723,10 @@ kernel void kernel_soft_max(
         sum = simd_sum(sum);
     }
 
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max_val);
+    }
+
     const float inv_sum = 1.0f/sum;
 
     for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
@@ -1634,6 +1738,7 @@ template<typename T>
 kernel void kernel_soft_max_4(
         device const  char * src0,
         device const  char * src1,
+        device const  char * src2,
         device        char * dst,
         constant ggml_metal_kargs_soft_max & args,
         threadgroup  float * buf [[threadgroup(0)]],
@@ -1652,6 +1757,7 @@ kernel void kernel_soft_max_4(
 
     device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
     device const T      * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+    device const float  * psrc2 = src2 != src0 ? (device const float * ) (src2) : nullptr;
     device       float4 * pdst4 = (device       float4 *) (dst  + i01*args.nb1  + i02*args.nb2  + i03*args.nb3);
 
     float slope = 1.0f;
@@ -1666,7 +1772,7 @@ kernel void kernel_soft_max_4(
     }
 
     // parallel max
-    float4 lmax4 = -INFINITY;
+    float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
 
     for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
         lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
@@ -1725,6 +1831,10 @@ kernel void kernel_soft_max_4(
         sum = simd_sum(sum);
     }
 
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max_val);
+    }
+
     const float inv_sum = 1.0f/sum;
 
     for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
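An attention sink is one extra per-head logit that competes in the softmax but produces no output element, so it only needs to enter the running max and the denominator, which is exactly what the two kernels above do. A scalar C reference for a single row, as a sketch:

    #include <math.h>
    #include <stddef.h>

    // Softmax over one row with an extra sink logit: the sink joins the max and the
    // denominator, but no probability is written out for it.
    static void softmax_with_sink_ref(const float * x, float * p, size_t n, float sink) {
        float max_val = sink;
        for (size_t i = 0; i < n; ++i) {
            max_val = fmaxf(max_val, x[i]);
        }

        float sum = expf(sink - max_val); // the sink's share of the denominator
        for (size_t i = 0; i < n; ++i) {
            p[i] = expf(x[i] - max_val);
            sum += p[i];
        }

        const float inv_sum = 1.0f / sum;
        for (size_t i = 0; i < n; ++i) {
            p[i] *= inv_sum;
        }
    }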
@@ -3106,6 +3216,11 @@ template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0,   32, dequantize_q8_0_t4>;
 template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0,   32, dequantize_q8_0_t4>;
 template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0,   32, dequantize_q8_0_t4>;
 
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_2")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_mxfp4,  32, dequantize_mxfp4_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_3")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_mxfp4,  32, dequantize_mxfp4_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_4")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4,  32, dequantize_mxfp4_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_5")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4,  32, dequantize_mxfp4_t4>;
+
 template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
 template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
 template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
@@ -4092,6 +4207,7 @@ kernel void kernel_flash_attn_ext(
         device const char * k,
         device const char * v,
         device const char * mask,
+        device const char * sinks,
         device       char * dst,
         threadgroup  half * shmem_f16 [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
@@ -4407,6 +4523,35 @@ kernel void kernel_flash_attn_ext(
             }
         }
 
+        if (sinks != q && sgitg == 0) {
+            for (ushort j = 0; j < Q; ++j) {
+                const float m = M[j];
+                const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
+
+                M[j] = simd_max(max(M[j], s));
+
+                const float ms = exp(m - M[j]);
+                const float vs = exp(s - M[j]);
+
+                S[j] = S[j]*ms + simd_sum(vs);
+
+                if (tiisg == j) {
+                    ss[j*TS + 2*C + j] = ms;
+                }
+            }
+
+            // O = diag(ms)*O
+            {
+                s8x8_t ms;
+                simdgroup_load(ms, ss + 2*C, TS, 0, false);
+
+#pragma unroll(DV8)
+                for (short i = 0; i < DV8; ++i) {
+                    simdgroup_multiply(lo[i], ms, lo[i]);
+                }
+            }
+        }
+
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
         for (short j = tiisg; j < Q; j += NW) {
             ss[j*TS + 0] = S[j];
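In this streaming formulation the sink is folded in after the KV loop as one more "block": the running maximum is raised to include the sink logit, the accumulated denominator and output are rescaled by exp(m_old - m_new), and exp(sink - m_new) is added to the denominator while contributing no value vector. A scalar C sketch of that state update, assuming a single query row with running state (m, s) and accumulator o[]:

    #include <math.h>
    #include <stddef.h>

    // Fold an attention-sink logit into the running online-softmax state.
    static void fold_sink_ref(float * m, float * s, float * o, size_t d, float sink) {
        const float m_old = *m;
        const float m_new = fmaxf(m_old, sink);

        const float ms = expf(m_old - m_new); // rescale factor for the existing state
        const float vs = expf(sink  - m_new); // the sink's own weight (no value vector)

        for (size_t i = 0; i < d; ++i) {
            o[i] *= ms;
        }

        *s = (*s)*ms + vs;
        *m = m_new;
    }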
@@ -4618,6 +4763,7 @@ kernel void kernel_flash_attn_ext_vec(
         device const char * k,
         device const char * v,
         device const char * mask,
+        device const char * sinks,
         device       char * dst,
         threadgroup  half * shmem_f16 [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
@@ -4835,6 +4981,23 @@ kernel void kernel_flash_attn_ext_vec(
             }
         }
 
+        if (sinks != q && sgitg == 0) {
+            const float m = M;
+            const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
+
+            M = simd_max(max(M, s));
+
+            const float ms = exp(m - M);
+            const float vs = exp(s - M);
+
+            S = S*ms + simd_sum(vs);
+
+#pragma unroll(DV4/NL)
+            for (short ii = 0; ii < DV4; ii += NL) {
+                lo[ii/NL] *= ms;
+            }
+        }
+
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
         if (tiisg == 0) {
             ss[0] = (s_t) S;
@@ -6940,6 +7103,95 @@ kernel void kernel_mul_mv_iq4_xs_f32(
     kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
+template<int nr0, int nsg, int nw, typename args_t>
+void kernel_mul_mv_mxfp4_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    threadgroup float * shmem_f32 = (threadgroup float *) shmem;
+    const int nb = args.ne00/QK_MXFP4;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * nsg + sgitg) * nr0;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
+    device const float       * y = (device const float       *) (src1 + offset1);
+
+    const short ix = tiisg/2;  // 0...15
+    const short it = tiisg%2;  // 0 or 1
+
+    shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float4 yl[4];
+    float sumf[nr0]={0.f};
+
+    device const float * yb = y + ix * QK_MXFP4 + it * 8;
+
+    for (int ib = ix; ib < nb; ib += 16) {
+        device const float4 * y4 = (device const float4 *)yb;
+        yl[0] = y4[0];
+        yl[1] = y4[4];
+        yl[2] = y4[1];
+        yl[3] = y4[5];
+
+#pragma unroll(nr0)
+        for (short row = 0; row < nr0; row++) {
+            device const block_mxfp4 & xb = x[row*nb + ib];
+            device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it);
+
+            float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
+            float4 acc2 = yl[1]*float4(shmem_f32[q2[0] >> 4  ], shmem_f32[q2[1] >> 4  ], shmem_f32[q2[2] >> 4  ], shmem_f32[q2[3] >> 4  ]);
+            float4 acc3 = yl[2]*float4(shmem_f32[q2[4] & 0x0F], shmem_f32[q2[5] & 0x0F], shmem_f32[q2[6] & 0x0F], shmem_f32[q2[7] & 0x0F]);
+            float4 acc4 = yl[3]*float4(shmem_f32[q2[4] >> 4  ], shmem_f32[q2[5] >> 4  ], shmem_f32[q2[6] >> 4  ], shmem_f32[q2[7] >> 4  ]);
+
+            acc1 = (acc1 + acc3) + (acc2 + acc4);
+
+            sumf[row] += e8m0_to_fp32(xb.e) * ((acc1[0] + acc1[1]) + (acc1[2] + acc1[3]));
+        }
+
+        yb += 16 * QK_MXFP4;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = sum_all;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_mxfp4_f32")]]
+kernel void kernel_mul_mv_mxfp4_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows_q(
         constant ggml_metal_kargs_get_rows & args,
@@ -7475,6 +7727,7 @@ template [[host_name("kernel_get_rows_q4_1")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_1,  2, dequantize_q4_1>;
 template [[host_name("kernel_get_rows_q5_0")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_0,  2, dequantize_q5_0>;
 template [[host_name("kernel_get_rows_q5_1")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_1,  2, dequantize_q5_1>;
 template [[host_name("kernel_get_rows_q8_0")]]    kernel get_rows_q_t kernel_get_rows_q<block_q8_0,  2, dequantize_q8_0>;
+template [[host_name("kernel_get_rows_mxfp4")]]   kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
 template [[host_name("kernel_get_rows_q2_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q2_K,  QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_get_rows_q3_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q3_K,  QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_get_rows_q4_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_K,  QK_NL, dequantize_q4_K>;
@@ -7527,6 +7780,7 @@ template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1,  2, dequantize_q4_1>;
 template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0,  2, dequantize_q5_0>;
 template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1,  2, dequantize_q5_1>;
 template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0,  2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_mxfp4_f32")]]   kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K,  QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K,  QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K,  QK_NL, dequantize_q4_K>;
@@ -7558,6 +7812,7 @@ template [[host_name("kernel_mul_mm_id_q4_1_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_1,  2, dequantize_q4_1>;
 template [[host_name("kernel_mul_mm_id_q5_0_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0,  2, dequantize_q5_0>;
 template [[host_name("kernel_mul_mm_id_q5_1_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1,  2, dequantize_q5_1>;
 template [[host_name("kernel_mul_mm_id_q8_0_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0,  2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_id_mxfp4_f16")]]   kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
 template [[host_name("kernel_mul_mm_id_q2_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K,  QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_id_q3_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K,  QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_mul_mm_id_q4_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K,  QK_NL, dequantize_q4_K>;
@@ -7703,6 +7958,8 @@ template [[host_name("kernel_mul_mv_id_q4_1_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH>>>;
 template [[host_name("kernel_mul_mv_id_q5_0_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH>>>;
 template [[host_name("kernel_mul_mv_id_q5_1_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH>>>;
 
+template [[host_name("kernel_mul_mv_id_mxfp4_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH>>>;
+
 template [[host_name("kernel_mul_mv_id_q2_K_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH>>>;
 template [[host_name("kernel_mul_mv_id_q3_K_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH>>>;
 template [[host_name("kernel_mul_mv_id_q4_K_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH>>>;
 
@@ -21,6 +21,17 @@
 
 #define UNUSED GGML_UNUSED
 
+static inline int best_index_int8(int n, const int8_t * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
 // reference implementation for deterministic creation of model files
 void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) {
     static const int qk = QK4_0;
@@ -246,6 +257,53 @@ void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k) {
     }
 }
 
+static inline int best_index_mxfp4(float x, float e) {
+    int best_index = 0;
+    float best_err = fabsf(kvalues_mxfp4[0]*e - x);
+    for (int i = 1; i < 16; i++) {
+        float err = fabsf(kvalues_mxfp4[i]*e - x);
+        if (err < best_err) {
+            best_index = i;
+            best_err = err;
+        }
+    }
+    return best_index;
+}
+
+void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
+    static const int qk = QK_MXFP4;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+            }
+        }
+
+        const uint8_t e = (uint8_t) (floorf(log2f(amax)) - 2 + 127);
+
+        const float d = GGML_E8M0_TO_FP32_HALF(e);
+
+        y[i].e = e;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t x0 = best_index_mxfp4(x[i*qk + 0    + j], d);
+            const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d);
+
+            y[i].qs[j]  = x0;
+            y[i].qs[j] |= x1 << 4;
+        }
+    }
+}
+
 void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     static const int qk = QK4_0;
 
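The reference quantizer picks the shared exponent so that the block maximum lands near the top of the code range: with a power-of-two scale d = 2^(floor(log2(amax)) - 2), the ratio amax/d falls in [4, 8), right around the largest code magnitude 6. A self-contained sketch of that scale choice plus nearest-value rounding on the float grid; the code above expresses the same idea through the E8M0 byte, the integer kvalues_mxfp4 table and the GGML_E8M0_TO_FP32_HALF macro, and the example values here are illustrative:

    #include <math.h>
    #include <stdio.h>

    static const float grid[16] = {
        0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0.f, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
    };

    // Nearest grid code for x at scale d (same search as best_index_mxfp4, on the float grid).
    static int nearest_code(float x, float d) {
        int best = 0;
        float best_err = fabsf(grid[0]*d - x);
        for (int i = 1; i < 16; ++i) {
            const float err = fabsf(grid[i]*d - x);
            if (err < best_err) { best = i; best_err = err; }
        }
        return best;
    }

    int main(void) {
        const float xs[4] = { 3.7f, -0.2f, 1.1f, -3.0f }; // example block values

        float amax = 0.0f;
        for (int j = 0; j < 4; ++j) amax = fmaxf(amax, fabsf(xs[j]));

        // power-of-two scale: amax/d lands in [4, 8), near the top of the +/-6 grid
        const float d = ldexpf(1.0f, (int) floorf(log2f(amax)) - 2);

        for (int j = 0; j < 4; ++j) {
            const int c = nearest_code(xs[j], d);
            printf("% .2f -> code %2d -> % .2f\n", xs[j], c, grid[c]*d);
        }
        return 0;
    }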
@@ -356,6 +414,26 @@ void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     }
 }
 
+void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
+    static const int qk = QK_MXFP4;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_E8M0_TO_FP32_HALF(x[i].e);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int8_t x0 = kvalues_mxfp4[x[i].qs[j] & 0x0F];
+            const int8_t x1 = kvalues_mxfp4[x[i].qs[j] >>   4];
+
+            y[i*qk + j + 0   ] = x0*d;
+            y[i*qk + j + qk/2] = x1*d;
+        }
+    }
+}
+
 //
 // 2-6 bit quantization in super-blocks
 //
@@ -2014,6 +2092,12 @@ size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     return nrow * row_size;
 }
 
+size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    GGML_UNUSED(quant_weights);
+    quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
+    return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
+}
+
 // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
 
 void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {
@@ -4551,17 +4635,6 @@ size_t quantize_iq1_m(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
 
 // ============================ 4-bit non-linear quants
 
-static inline int best_index_int8(int n, const int8_t * val, float x) {
-    if (x <= val[0]) return 0;
-    if (x >= val[n-1]) return n-1;
-    int ml = 0, mu = n-1;
-    while (mu-ml > 1) {
-        int mav = (ml+mu)/2;
-        if (x < val[mav]) mu = mav; else ml = mav;
-    }
-    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
-}
-
 static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x,
         ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
         float * scales, float * weight, uint8_t * L,
@@ -4961,6 +5034,15 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
     return true;
 }
 
+static bool validate_e_e8m0(uint8_t e, size_t i) {
+    if (e == 0xff) {
+        fprintf(stderr, "ggml_validate_row_data: found invalid e value %d at block %zu\n", e, i);
+        return false;
+    }
+
+    return true;
+}
+
 #define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
     const type * q = (const type *) (data); \
     for (size_t i = 0; i < (nb); ++i) { \
@@ -4977,6 +5059,14 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
         } \
     }
 
+#define VALIDATE_ROW_DATA_E_E8M0_IMPL(type, data, nb) \
+    const type * q = (const type *) (data); \
+    for (size_t i = 0; i < (nb); ++i) { \
+        if (!validate_e_e8m0(q[i].e, i)) { \
+            return false; \
+        } \
+    }
+
 #define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
     const type * q = (const type *) (data); \
     for (size_t i = 0; i < (nb); ++i) { \
@@ -5130,6 +5220,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) {
             {
                 VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
             } break;
+        case GGML_TYPE_MXFP4:
+            {
+                VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
+            } break;
         case GGML_TYPE_Q2_K:
             {
                 VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
 
@@ -21,6 +21,8 @@ GGML_API void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
 
+GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
+
 GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
@@ -45,6 +47,8 @@ GGML_API void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 //GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 
+GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+
 GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 GGML_API void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -90,6 +94,8 @@ GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 
+GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
 GGML_API void iq2xs_init_impl(enum ggml_type type);
 GGML_API void iq2xs_free_impl(enum ggml_type type);
 GGML_API void iq3xs_init_impl(int grid_size);
 
@ -465,6 +465,8 @@ struct vk_device_struct {
|
||||||
vk_pipeline pipeline_div[2][2][2];
|
vk_pipeline pipeline_div[2][2][2];
|
||||||
vk_pipeline pipeline_div_norepeat[2][2][2];
|
vk_pipeline pipeline_div_norepeat[2][2][2];
|
||||||
|
|
||||||
|
vk_pipeline pipeline_add_id_f32;
|
||||||
|
|
||||||
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
|
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
|
||||||
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
|
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
|
||||||
vk_pipeline pipeline_scale_f32;
|
vk_pipeline pipeline_scale_f32;
|
||||||
|
@ -499,6 +501,7 @@ struct vk_device_struct {
|
||||||
vk_pipeline pipeline_geglu[2];
|
vk_pipeline pipeline_geglu[2];
|
||||||
vk_pipeline pipeline_reglu[2];
|
vk_pipeline pipeline_reglu[2];
|
||||||
vk_pipeline pipeline_swiglu[2];
|
vk_pipeline pipeline_swiglu[2];
|
||||||
|
vk_pipeline pipeline_swiglu_oai[2];
|
||||||
vk_pipeline pipeline_geglu_erf[2];
|
vk_pipeline pipeline_geglu_erf[2];
|
||||||
vk_pipeline pipeline_geglu_quick[2];
|
vk_pipeline pipeline_geglu_quick[2];
|
||||||
|
|
||||||
|
@ -721,6 +724,8 @@ struct vk_op_glu_push_constants {
|
||||||
uint32_t ne00;
|
uint32_t ne00;
|
||||||
uint32_t ne20;
|
uint32_t ne20;
|
||||||
uint32_t mode; // 0: default, 1: swapped, 2: split
|
uint32_t mode; // 0: default, 1: swapped, 2: split
|
||||||
|
float alpha; // for swiglu_oai
|
||||||
|
float limit;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct vk_op_unary_push_constants {
|
struct vk_op_unary_push_constants {
|
||||||
|
@ -810,6 +815,15 @@ struct vk_op_binary_push_constants {
|
||||||
float param1; float param2; int32_t param3;
|
float param1; float param2; int32_t param3;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct vk_op_add_id_push_constants {
|
||||||
|
uint32_t ne0;
|
||||||
|
uint32_t ne1;
|
||||||
|
uint32_t s01;
|
||||||
|
uint32_t s02;
|
||||||
|
uint32_t s11;
|
||||||
|
uint32_t s21;
|
||||||
|
};
|
||||||
|
|
||||||
struct vk_op_diag_mask_push_constants {
|
struct vk_op_diag_mask_push_constants {
|
||||||
uint32_t ncols;
|
uint32_t ncols;
|
||||||
uint32_t rows_per_channel;
|
uint32_t rows_per_channel;
|
||||||
|
@ -851,6 +865,7 @@ struct vk_op_soft_max_push_constants {
|
||||||
float m1;
|
float m1;
|
||||||
uint32_t n_head_log2;
|
uint32_t n_head_log2;
|
||||||
uint32_t nrows_x;
|
uint32_t nrows_x;
|
||||||
|
uint32_t has_sinks;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct vk_op_argsort_push_constants {
|
struct vk_op_argsort_push_constants {
|
||||||
@@ -1993,6 +2008,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
         break;
     case GGML_TYPE_IQ4_NL:
     case GGML_TYPE_IQ4_XS:
+    case GGML_TYPE_MXFP4:
         lut_size = 4*16;
         break;
     default:
@@ -2369,6 +2385,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
     CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
     CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)

     CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
 #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
@@ -2395,6 +2412,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
     CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
     CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+    CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
 #undef CREATE_MM
 #undef CREATE_MM2
     } else
@@ -2456,6 +2474,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+    CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     } else {
     CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2477,6 +2496,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+    CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     }

     CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2509,6 +2529,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+    CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     } else {
     CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2530,6 +2551,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+    CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     }
 #undef CREATE_MM2
 #undef CREATE_MM
@@ -2597,6 +2619,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+    CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );

 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
     if (device->integer_dot_product) {
@@ -2634,6 +2657,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+    CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 #undef CREATE_MM2
 #undef CREATE_MMQ
 #undef CREATE_MM
@@ -2688,6 +2712,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+    CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );

 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
     if (device->integer_dot_product) {
@@ -2725,6 +2750,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+    CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     }
     // reusing CREATE_MM from the fp32 path
     if ((device->coopmat2 || device->coopmat_support)
@@ -2783,6 +2809,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f32_f32_len, mul_mat_vec_mxfp4_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);

     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
@@ -2806,6 +2833,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f16_f32_len, mul_mat_vec_mxfp4_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
     }

     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@@ -2830,6 +2858,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);

     // dequant shaders
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -2852,6 +2881,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);

     // get_rows
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -2871,6 +2901,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);

     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -2889,6 +2920,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);

     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
@@ -2992,6 +3024,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_BINARY(div, _norepeat, {1})
 #undef CREATE_BINARY

+    ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);

     ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -3042,6 +3076,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_GLU(geglu)
     CREATE_GLU(reglu)
     CREATE_GLU(swiglu)
+    CREATE_GLU(swiglu_oai)
     CREATE_GLU(geglu_erf)
     CREATE_GLU(geglu_quick)
 #undef CREATE_GLU
@@ -3051,10 +3086,10 @@ static void ggml_vk_load_shaders(vk_device& device) {

     ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);

-    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -4268,6 +4303,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
     case GGML_TYPE_IQ3_S:
     case GGML_TYPE_IQ4_XS:
     case GGML_TYPE_IQ4_NL:
+    case GGML_TYPE_MXFP4:
         break;
     default:
         return nullptr;
@@ -4338,6 +4374,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
     case GGML_TYPE_IQ3_S:
     case GGML_TYPE_IQ4_XS:
     case GGML_TYPE_IQ4_NL:
+    case GGML_TYPE_MXFP4:
         break;
     default:
         return nullptr;
@@ -4381,6 +4418,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
     case GGML_TYPE_IQ3_S:
     case GGML_TYPE_IQ4_XS:
     case GGML_TYPE_IQ4_NL:
+    case GGML_TYPE_MXFP4:
         break;
     default:
         return nullptr;
@@ -4435,6 +4473,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
     case GGML_TYPE_IQ3_S:
     case GGML_TYPE_IQ4_XS:
     case GGML_TYPE_IQ4_NL:
+    case GGML_TYPE_MXFP4:
         break;
     default:
         return nullptr;
@@ -4470,6 +4509,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
     case GGML_TYPE_IQ3_S:
     case GGML_TYPE_IQ4_XS:
     case GGML_TYPE_IQ4_NL:
+    case GGML_TYPE_MXFP4:
         break;
     default:
         return nullptr;
@@ -4655,6 +4695,7 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
     std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
     GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
     GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
+    GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());

     vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
     vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
@@ -6871,6 +6912,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             break;
         }
         return nullptr;
+    case GGML_OP_ADD_ID:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_add_id_f32;
+        }
+        return nullptr;
     case GGML_OP_CONCAT:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_concat_f32;
@@ -7016,6 +7062,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
         case GGML_GLU_OP_SWIGLU:
             return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
+        case GGML_GLU_OP_SWIGLU_OAI:
+            return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16];
         case GGML_GLU_OP_GEGLU_ERF:
             return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
         case GGML_GLU_OP_GEGLU_QUICK:
@@ -7031,6 +7079,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_SOFT_MAX:
         GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
+        GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);

         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
             return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
@@ -7201,6 +7250,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
+    case GGML_OP_ADD_ID:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
     case GGML_OP_SQR:
@@ -7547,6 +7597,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             elements = { ne, 1, 1 };
             }
         } break;
+    case GGML_OP_ADD_ID:
+        {
+            elements = { (uint32_t)ne01, (uint32_t)ne02, 1 };
+        } break;
     case GGML_OP_SET_ROWS:
         {
             uint32_t ne = ggml_nelements(src0);
@@ -7586,8 +7640,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         }
     }

-    if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
-        // Empty src1 is possible in soft_max, but the shader needs a buffer
+    if (op == GGML_OP_GLU) {
+        // Empty src1 is possible in glu, but the shader needs a buffer
         vk_subbuffer subbuf_y;
         if (use_src1) {
             subbuf_y = { d_Y, y_buf_offset, y_sz };
@@ -7597,6 +7651,24 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co

         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+    } else if (op == GGML_OP_SOFT_MAX) {
+        // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer
+        vk_subbuffer subbuf_y;
+        if (use_src1) {
+            subbuf_y = { d_Y, y_buf_offset, y_sz };
+        } else {
+            subbuf_y = { d_X, 0, x_sz };
+        }
+
+        vk_subbuffer subbuf_z;
+        if (use_src2) {
+            subbuf_z = { d_Z, z_buf_offset, z_sz };
+        } else {
+            subbuf_z = { d_X, 0, x_sz };
+        }
+
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
         // Empty src2 is possible in rope, but the shader needs a buffer
         vk_subbuffer subbuf_z;
@@ -7725,6 +7797,21 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }

+static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t src2_type_size = ggml_type_size(src2->type);
+
+    ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ADD_ID, {
+        (uint32_t)dst->ne[0],
+        (uint32_t)dst->ne[1],
+        (uint32_t)src0->nb[1] / src0_type_size,
+        (uint32_t)src0->nb[2] / src0_type_size,
+        (uint32_t)src1->nb[1] / src1_type_size,
+        (uint32_t)src2->nb[1] / src2_type_size,
+    }, dryrun);
+}
+
 static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
     GGML_ASSERT(version == 6 || version == 7);
     int num_srcs = version == 6 ? 6 : 7;
@@ -8143,8 +8230,12 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
 }

 static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    const float * op_params_f = (const float *)dst->op_params;
+
     const bool swapped = (bool)dst->op_params[1];
     const bool split = src1 != nullptr;
+    const float alpha = op_params_f[2];
+    const float limit = op_params_f[3];

     GGML_ASSERT(ggml_is_contiguous(src0));

@@ -8158,7 +8249,15 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const

     const uint32_t mode = split ? 2 : (swapped ? 1 : 0);

-    ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun);
+    ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU,
+        {
+            (uint32_t)ggml_nelements(dst),
+            (uint32_t)src0->ne[0],
+            (uint32_t)dst->ne[0],
+            mode,
+            alpha,
+            limit
+        }, dryrun);
 }

 static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -8166,7 +8265,7 @@ static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& sub
     ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
 }

-static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;

     float scale = op_params[0];
@@ -8188,7 +8287,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

-    ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
+    ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
         ncols,
         src1 != nullptr ? nrows_y : (uint32_t)0,
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
@@ -8198,6 +8297,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
         m0, m1,
         n_head_log2,
         nrows_x,
+        src2 != nullptr
     }, dryrun);
 }

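A note on the new src2/has_sinks plumbing above: the extra tensor carries one "sink" logit per attention head, and the soft_max shader folds it into the normalization. As a sketch of the intended math (the usual attention-sink formulation; the shader body itself is not part of this excerpt), for one row of logits x_j and that head's sink value s:

\[
y_i = \frac{e^{x_i - M}}{e^{s - M} + \sum_j e^{x_j - M}}, \qquad M = \max\bigl(s,\ \max_j x_j\bigr)
\]

so the sink only enlarges the denominator, and the row of probabilities no longer has to sum to exactly 1.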
@@ -9437,6 +9537,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_SWIGLU_OAI:
         case GGML_GLU_OP_GEGLU_ERF:
         case GGML_GLU_OP_GEGLU_QUICK:
             break;
@@ -9448,6 +9549,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
+    case GGML_OP_ADD_ID:
     case GGML_OP_ACC:
     case GGML_OP_SUB:
     case GGML_OP_MUL:
@@ -9602,6 +9704,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_DIV:
         ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun);

+        break;
+    case GGML_OP_ADD_ID:
+        ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+
         break;
     case GGML_OP_CONCAT:
         ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9699,6 +9805,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_SWIGLU_OAI:
         case GGML_GLU_OP_GEGLU_ERF:
         case GGML_GLU_OP_GEGLU_QUICK:
             ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9712,7 +9819,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr

         break;
     case GGML_OP_SOFT_MAX:
-        ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
+        ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);

         break;
     case GGML_OP_SOFT_MAX_BACK:
@@ -9858,6 +9965,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
+    case GGML_OP_ADD_ID:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
@@ -9927,6 +10035,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_SWIGLU_OAI:
         case GGML_GLU_OP_GEGLU_ERF:
         case GGML_GLU_OP_GEGLU_QUICK:
             buf = tensor->buffer;
@@ -10776,6 +10885,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_SWIGLU_OAI:
         case GGML_GLU_OP_GEGLU_ERF:
         case GGML_GLU_OP_GEGLU_QUICK:
             return ggml_is_contiguous(op->src[0]) &&
@@ -10821,6 +10931,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             case GGML_TYPE_IQ3_S:
             case GGML_TYPE_IQ4_XS:
             case GGML_TYPE_IQ4_NL:
+            case GGML_TYPE_MXFP4:
                 break;
             default:
                 return false;
@@ -10858,6 +10969,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
                 return false;
             }
+            // TODO: support attention sinks [TAG_ATTN_SINKS]
+            if (op->src[4]) {
+                return false;
+            }
             if (op->src[0]->type != GGML_TYPE_F32) {
                 return false;
             }
@@ -10930,6 +11045,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             case GGML_TYPE_IQ3_S:
             case GGML_TYPE_IQ4_XS:
             case GGML_TYPE_IQ4_NL:
+            case GGML_TYPE_MXFP4:
                 return true;
             default:
                 return false;
@@ -11028,6 +11144,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                    (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
                    (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
+        case GGML_OP_ADD_ID:
+            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 &&
+                   op->type == GGML_TYPE_F32;
         case GGML_OP_SILU_BACK:
         case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_SQR:
ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp (new file, 42 lines)
@@ -0,0 +1,42 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+#include "types.comp"
+
+layout (push_constant) uniform parameter
+{
+    uint ne0;
+    uint ne1;
+    uint s01;
+    uint s02;
+    uint s11;
+    uint s21;
+} p;
+
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
+layout (binding = 2) readonly buffer Z {int32_t data_c[];};
+layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const uint i1 = gl_WorkGroupID.x;
+    const uint i2 = gl_WorkGroupID.y;
+
+    const uint i11 = data_c[i1 + i2 * p.s21];
+
+    const uint s1 = p.ne0;
+    const uint s2 = p.ne0 * p.ne1;
+
+    const uint d0 = i1 * s1 + i2 * s2;
+    const uint a0 = i1 * p.s01 + i2 * p.s02;
+    const uint b0 = i11 * p.s11;
+
+    for (uint i0 = gl_LocalInvocationID.x; i0 < p.ne0; i0 += BLOCK_SIZE) {
+        data_d[d0 + i0] = data_a[a0 + i0] + data_b[b0 + i0];
+    }
+}
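For orientation, this shader is the Vulkan side of the new GGML_OP_ADD_ID operator: each output row i1 of batch i2 is the corresponding row of A plus one row of B selected by the int32 ids tensor. A minimal CPU reference of the same semantics, assuming contiguous float tensors (the helper and its names are illustrative, not part of the patch):

#include <cstdint>

// dst[i2][i1][:] = a[i2][i1][:] + b[ids[i2][i1]][:]
// ne0 = row length, ne1 = rows per batch, ne2 = number of batches.
static void add_id_reference(const float * a, const float * b, const int32_t * ids,
                             float * dst, int64_t ne0, int64_t ne1, int64_t ne2) {
    for (int64_t i2 = 0; i2 < ne2; ++i2) {
        for (int64_t i1 = 0; i1 < ne1; ++i1) {
            const int64_t row = ids[i2 * ne1 + i1];          // which row of b to add
            const float * pa = a   + (i2 * ne1 + i1) * ne0;
            const float * pb = b   + row * ne0;
            float       * pd = dst + (i2 * ne1 + i1) * ne0;
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
                pd[i0] = pa[i0] + pb[i0];
            }
        }
    }
}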
@@ -4,8 +4,8 @@
 #include "generic_unary_head.comp"
 #include "dequant_funcs.comp"
 
-#if defined(DATA_A_IQ4_NL)
-// 16 invocations needed for init_iq4nl_shmem
+#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4)
+// 16 invocations needed for init_iq_shmem
 layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
 #else
 layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
@@ -434,6 +434,18 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
 }
 #endif
 
+#if defined(DATA_A_MXFP4)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
+    return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+    vec2 v0 = dequantize(ib, iqs, a_offset);
+    vec2 v1 = dequantize(ib, iqs + 1, a_offset);
+    return vec4(v0.x, v0.y, v1.x, v1.y);
+}
+#endif
+
 #if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
 vec2 get_dm(uint ib, uint a_offset) {
     return vec2(0, 0);
@@ -455,6 +467,12 @@ vec2 get_dm(uint ib, uint a_offset) {
 }
 #endif
 
+#if defined(DATA_A_MXFP4)
+vec2 get_dm(uint ib, uint a_offset) {
+    return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0);
+}
+#endif
+
 #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
 vec2 get_dm(uint ib, uint a_offset) {
     return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
@@ -654,6 +654,25 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
 }
 #endif
 
+#if defined(DATA_A_MXFP4)
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
+    block_mxfp4 block;
+};
+
+float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+    const float d = e8m0_to_fp32(bl.block.e);
+    const uint idx = coordInBlock[1];
+    const uint iqs = idx & 0xF;
+    const uint shift = (idx & 0x10) >> 2;
+    uint32_t qs = bl.block.qs[iqs];
+    qs >>= shift;
+    qs &= 0xF;
+    float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
+    return ret;
+}
+#endif
+
 #if defined(DATA_A_Q4_0)
 #define dequantFuncA dequantFuncQ4_0
 #elif defined(DATA_A_Q4_1)
@@ -696,4 +715,6 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
 #define dequantFuncA dequantFuncIQ4_XS
 #elif defined(DATA_A_IQ4_NL)
 #define dequantFuncA dequantFuncIQ4_NL
+#elif defined(DATA_A_MXFP4)
+#define dequantFuncA dequantFuncMXFP4
 #endif
ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp (new file, 32 lines)
@@ -0,0 +1,32 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_mxfp4 data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
+
+    init_iq_shmem(gl_WorkGroupSize);
+
+    const uint tid = gl_LocalInvocationID.x % 64;
+    const uint il = tid/32;
+    const uint ir = tid%32;
+    const uint ib = 32*i + ir;
+    if (ib >= p.nel / 32) {
+        return;
+    }
+
+    const uint q_idx = 8*il;
+    const uint b_idx = 1024*i + 32*ir + q_idx;
+
+    const float d = e8m0_to_fp32(data_a[ib].e);
+
+    [[unroll]] for (uint l = 0; l < 8; ++l) {
+        data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
+        data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
+    }
+}
@@ -14,4 +14,6 @@ layout (push_constant) uniform parameter
     uint ne00;
     uint ne20;
     uint mode;
+    float alpha;
+    float limit;
 } p;
@@ -747,6 +747,21 @@ void main() {
             buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
             buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
             buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
+#elif defined(DATA_A_MXFP4)
+            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
+
+            const uint ib = idx / 8;
+            const uint iqs = (idx & 0x07) * 2;
+
+            const float d = e8m0_to_fp32(data_a[ib].e);
+            const uint vui = uint(data_a[ib].qs[iqs]);
+            const uint vui2 = uint(data_a[ib].qs[iqs+1]);
+
+            buf_a[buf_idx     ] = FLOAT_TYPE(kvalues_mxfp4[vui & 0xF] * d);
+            buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_mxfp4[vui >> 4] * d);
+            buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_mxfp4[vui2 & 0xF] * d);
+            buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_mxfp4[vui2 >> 4] * d);
 #endif
         }
         [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
@@ -92,6 +92,12 @@ FLOAT_TYPE get_d(uint ib) {
 }
 #endif
 
+#if defined(DATA_A_MXFP4)
+FLOAT_TYPE get_d(uint ib) {
+    return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
+}
+#endif
+
 #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
 FLOAT_TYPE_VEC2 get_dm(uint ib) {
     return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
@@ -20,6 +20,7 @@ layout (push_constant) uniform parameter
     float m1;
     uint n_head_log2;
     uint nrows_x;
+    uint has_sinks;
 } p;
 
 #include "types.comp"
@@ -29,7 +30,8 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
-layout (binding = 2) buffer D {D_TYPE data_d[];};
+layout (binding = 2) readonly buffer Z {float data_c[];};
+layout (binding = 3) buffer D {D_TYPE data_d[];};
 
 shared FLOAT_TYPE vals[BLOCK_SIZE];
 
@@ -66,7 +68,7 @@ void soft_max(uint num_iters) {
     }
 
     // Find max
-    FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000);
+    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
 
     // Cache values while we compute the max, so we don't need to read them
     // again when we're ready to compute exp(x-max).
@@ -148,6 +150,10 @@ void soft_max(uint num_iters) {
     }
     sum = vals[0];
 
+    if (p.has_sinks != 0) {
+        sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val));
+    }
+
     FLOAT_TYPE rcpdivisor = 1.0/sum;
 
     [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
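The has_sinks plumbing above matches the attention-sink formulation: the per-head sink value joins the running maximum and the normalizer, but no output column is produced for it, so each row's probabilities sum to slightly less than one. A scalar C++ sketch of that normalization (illustrative, one row plus one sink value):

#include <algorithm>
#include <cmath>
#include <vector>

// Softmax of one row x[0..n) with an extra sink logit that only affects the denominator.
static std::vector<float> softmax_with_sink(const std::vector<float> & x, float sink) {
    float max_val = sink;                              // the sink participates in the max ...
    for (float v : x) max_val = std::max(max_val, v);

    float sum = std::exp(sink - max_val);              // ... and in the normalizer
    std::vector<float> y(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        y[i] = std::exp(x[i] - max_val);
        sum += y[i];
    }
    for (float & v : y) v /= sum;                      // but gets no output column of its own
    return y;
}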
ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp (new file, 14 lines)
@@ -0,0 +1,14 @@
+#version 450
+
+#include "glu_head.comp"
+
+float op(float a, float b) {
+    float xi = min(a, p.limit);
+    float gi = max(min(b, p.limit), -p.limit);
+
+    float out_glu = xi / (1.0f + exp(-xi * p.alpha));
+    out_glu = out_glu * (1.0f + gi);
+    return out_glu;
+}
+
+#include "glu_main.comp"
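The op() above is the clamped SwiGLU variant used by the gpt-oss experts; alpha and limit arrive through the new push constants added to glu_head.comp, and the graph code later in this commit passes alpha = 1.702 and limit = 7.0. A scalar C++ sketch of the same formula (illustrative only):

#include <algorithm>
#include <cmath>

// Clamped SwiGLU: x is clamped from above, the gate from both sides,
// then silu(alpha * x) is scaled by (1 + gate).
static float swiglu_oai(float x, float gate, float alpha = 1.702f, float limit = 7.0f) {
    const float xi = std::min(x, limit);
    const float gi = std::max(std::min(gate, limit), -limit);
    const float silu = xi / (1.0f + std::exp(-xi * alpha));   // xi * sigmoid(alpha * xi)
    return silu * (1.0f + gi);
}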
@@ -1337,6 +1337,29 @@ struct block_iq4_nl_packed16
 #define A_TYPE_PACKED16 block_iq4_nl_packed16
 #endif
 
+#define QUANT_K_MXFP4 32
+#define QUANT_R_MXFP4 2
+
+struct block_mxfp4
+{
+    uint8_t e;
+    uint8_t qs[QUANT_K_MXFP4/2];
+};
+
+//struct block_mxfp4_packed16
+//{
+//    uint8_t e;
+//    uint16_t qs[QUANT_K_MXFP4/2/2];
+//};
+
+#if defined(DATA_A_MXFP4)
+#define QUANT_K QUANT_K_MXFP4
+#define QUANT_R QUANT_R_MXFP4
+#define QUANT_AUXF 1
+#define A_TYPE block_mxfp4
+//#define A_TYPE_PACKED16 block_mxfp4_packed16
+#endif
+
 #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
 const int8_t kvalues_iq4nl_const[16] = {
     int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
@@ -1356,6 +1379,25 @@ void init_iq_shmem(uvec3 wgsize)
 }
 #endif
 
+#if defined(DATA_A_MXFP4)
+const FLOAT_TYPE kvalues_mxfp4_const[16] = {
+    FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f),
+    FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f)
+};
+
+shared FLOAT_TYPE kvalues_mxfp4[16];
+
+#define NEEDS_INIT_IQ_SHMEM
+void init_iq_shmem(uvec3 wgsize)
+{
+    // copy the table into shared memory and sync
+    for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) {
+        kvalues_mxfp4[i] = kvalues_mxfp4_const[i];
+    }
+    barrier();
+}
+#endif
+
 // returns the bfloat value in the low 16b.
 // See ggml_compute_fp32_to_bf16
 uint32_t fp32_to_bf16(float f)
@@ -1370,4 +1412,17 @@ float bf16_to_fp32(uint32_t u)
     return uintBitsToFloat(u << 16);
 }
 
+float e8m0_to_fp32(uint8_t x) {
+    uint32_t bits;
+
+    if (x == 0) {
+        bits = 0x00400000;
+    } else {
+        bits = x;
+        bits = bits << 23;
+    }
+
+    return uintBitsToFloat(bits);
+}
+
 #endif // !defined(GGML_TYPES_COMP)
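With the definitions above, one MXFP4 block is 17 bytes covering 32 weights: a single E8M0 scale byte e plus 16 bytes of packed 4-bit codes that index the ±{0, 0.5, 1, 1.5, 2, 3, 4, 6} table. A host-side C++ sketch of the same decode, mirroring e8m0_to_fp32 and the nibble layout used by the shaders (illustrative, not part of the patch):

#include <cstdint>
#include <cstring>

// One MXFP4 block: E8M0 scale + 16 bytes of 4-bit codes = 32 values.
struct block_mxfp4_host {
    uint8_t e;
    uint8_t qs[16];
};

static const float kvalues_mxfp4_host[16] = {
    0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
};

// E8M0: the byte is used directly as a float32 exponent; 0 maps to a tiny
// non-zero scale, matching the shader's special case.
static float e8m0_to_fp32_host(uint8_t x) {
    const uint32_t bits = (x == 0) ? 0x00400000u : (uint32_t) x << 23;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// Low nibbles fill out[0..15], high nibbles fill out[16..31],
// the same layout the Vulkan shaders produce.
static void dequant_block_mxfp4(const block_mxfp4_host & b, float out[32]) {
    const float d = e8m0_to_fp32_host(b.e);
    for (int i = 0; i < 16; ++i) {
        out[i]      = d * kvalues_mxfp4_host[b.qs[i] & 0xF];
        out[i + 16] = d * kvalues_mxfp4_host[b.qs[i] >> 4];
    }
}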
@@ -76,6 +76,7 @@ const std::vector<std::string> type_names = {
     "iq3_s",
     "iq4_xs",
     "iq4_nl",
+    "mxfp4",
     "bf16",
 };
 
@@ -376,7 +377,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
         std::string load_vec_quant = "2";
         if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
             load_vec_quant = "8";
-        else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl"))
+        else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
             load_vec_quant = "4";
 
         if (tname == "bf16") {
@@ -616,6 +617,8 @@ void process_shaders() {
         string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
         string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
         string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("swiglu_oai_f16" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("swiglu_oai_f32" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
         string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
         string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
         string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
@@ -685,6 +688,8 @@ void process_shaders() {
 
     string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 
+    string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+
     for (auto &c : compiles) {
         c.wait();
     }
@@ -118,8 +118,6 @@ struct webgpu_context_struct {
     wgpu::Limits limits;
 
     std::recursive_mutex mutex;
-    std::mutex get_tensor_mutex;
-    std::mutex init_mutex;
 
     bool device_init = false;
 
@@ -139,6 +137,8 @@ struct webgpu_context_struct {
 
     // Parameter buffers associated with the staged command buffers
     std::vector<webgpu_param_bufs> staged_param_bufs;
 
+    std::vector<wgpu::FutureWaitInfo> callback_futures;
+
 };
 
 typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
@@ -221,25 +221,39 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
 
 /** WebGPU Actions */
 
+// Wait for the queue to finish processing all submitted work
 static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
-    // Wait for the queue to finish processing all commands
+    std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+    if (ctx->callback_futures.empty()) {
+        // no existing callbacks, wait on queue submission
     ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
         wgpu::CallbackMode::AllowSpontaneous,
         [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
            if (status != wgpu::QueueWorkDoneStatus::Success) {
-                GGML_LOG_ERROR("ggml_webgpu: Failed to wait on queue: %s\n", message.data);
+                GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
            }
        }),
        UINT64_MAX);
+    } else {
+        // existing callbacks, wait on them
+        ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
+        ctx->callback_futures.clear();
+    }
 }
 
 static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
     std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_submit_queue()");
+    if (ctx->staged_command_bufs.empty()) {
+        // Nothing to submit
+        return;
+    }
     ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());
     ctx->staged_command_bufs.clear();
     std::vector<webgpu_param_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
 
     // Free the staged parameter buffers once the submission completes
-    ctx->queue.OnSubmittedWorkDone(
+    wgpu::Future f = ctx->queue.OnSubmittedWorkDone(
         wgpu::CallbackMode::AllowSpontaneous,
         [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
             if (status != wgpu::QueueWorkDoneStatus::Success) {
@@ -248,6 +262,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
             // Free the staged parameter buffers
             ctx->param_buf_pool.free_bufs(staged_param_bufs);
         });
+    ctx->callback_futures.push_back({ f });
 }
 
 static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
@@ -273,7 +288,7 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
                                                   std::vector<uint32_t> params,
                                                   std::vector<wgpu::BindGroupEntry> bind_group_entries,
                                                   uint32_t wg_x,
-                                                  bool submit_imm = false) {
+                                                  bool submit_and_wait = false) {
     webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
 
     ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
@@ -304,17 +319,18 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
     pass.DispatchWorkgroups(wg_x, 1, 1);
     pass.End();
     wgpu::CommandBuffer commands = encoder.Finish();
-    if (submit_imm) {
-        // Submit immediately
+    if (submit_and_wait) {
+        // Submit and wait immediately
         ctx->queue.Submit(1, &commands);
-        ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
+        ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
+            wgpu::CallbackMode::AllowSpontaneous,
             [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
                 if (status != wgpu::QueueWorkDoneStatus::Success) {
-                    GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
-                                   message.data);
+                    GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
                 }
                 ctx->param_buf_pool.free_bufs({ params_bufs });
-            });
+            }),
+            UINT64_MAX);
     } else {
         // Lock the context mutex when pushing to the staging vectors.
         std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
@@ -579,6 +595,9 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
         // memset the remaining bytes
         ggml_backend_webgpu_buffer_memset(
             webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
+    } else {
+        // wait for WriteBuffer to complete
+        ggml_backend_webgpu_wait_on_submission(webgpu_ctx);
     }
 }
 
@@ -602,7 +621,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
         final_size = size + (4 - (size % 4));
     }
 
-    std::lock_guard<std::mutex> lock(webgpu_ctx->get_tensor_mutex);
+    std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);
 
     if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
         // Create a new staging buffer if it doesn't exist or is too small
@@ -768,10 +787,11 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
     webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
 
     // Multiple threads may try to initialize the device
-    std::lock_guard<std::mutex> lock(webgpu_ctx->init_mutex);
+    std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);
     if (!webgpu_ctx->device_init) {
         // Initialize device
-        std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16, wgpu::FeatureName::ImplicitDeviceSynchronization };
+        std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
+                                                             wgpu::FeatureName::ImplicitDeviceSynchronization };
         wgpu::DeviceDescriptor dev_desc;
         dev_desc.requiredLimits = &webgpu_ctx->limits;
         dev_desc.requiredFeatures = required_features.data();
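In short, the WebGPU changes above collapse the two special-purpose mutexes into the existing recursive mutex, skip empty submissions, and record the OnSubmittedWorkDone future of each staged submission in callback_futures so that ggml_backend_webgpu_wait_on_submission can wait either on a fresh queue submission or on already pending callbacks; the immediate-submit path now also blocks until its own completion callback has run.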
@@ -586,9 +586,6 @@ FILE * ggml_fopen(const char * fname, const char * mode) {
 #endif
 
 }
-static void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
-static void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
-static void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);
 
 static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
     [GGML_TYPE_I8] = {
@@ -694,6 +691,14 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref,
     },
+    [GGML_TYPE_MXFP4] = {
+        .type_name = "mxfp4",
+        .blck_size = QK_MXFP4,
+        .type_size = sizeof(block_mxfp4),
+        .is_quantized = true,
+        .to_float = (ggml_to_float_t) dequantize_row_mxfp4,
+        .from_float_ref = (ggml_from_float_t) quantize_row_mxfp4_ref,
+    },
     [GGML_TYPE_Q2_K] = {
         .type_name = "q2_K",
         .blck_size = QK_K,
@@ -933,6 +938,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
 
     "DUP",
     "ADD",
+    "ADD_ID",
     "ADD1",
     "ACC",
     "SUB",
@@ -1026,13 +1032,14 @@
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
+static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
 
     "x",
     "x+y",
+    "x[i]+y",
     "x+y",
     "view(x,nb,offset)+=y->x",
     "x-y",
@@ -1126,7 +1133,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
+static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1156,11 +1163,12 @@ static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
     "REGLU",
     "GEGLU",
     "SWIGLU",
+    "SWIGLU_OAI",
     "GEGLU_ERF",
     "GEGLU_QUICK",
 };
 
-static_assert(GGML_GLU_OP_COUNT == 5, "GGML_GLU_OP_COUNT != 5");
+static_assert(GGML_GLU_OP_COUNT == 6, "GGML_GLU_OP_COUNT != 6");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -1328,6 +1336,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
         case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
         case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
+        case GGML_FTYPE_MOSTLY_MXFP4: wtype = GGML_TYPE_MXFP4; break;
         case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break;
         case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
         case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
@@ -1978,6 +1987,27 @@ struct ggml_tensor * ggml_add_cast(
     return ggml_add_cast_impl(ctx, a, b, type);
 }
 
+struct ggml_tensor * ggml_add_id(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * ids) {
+
+    GGML_ASSERT(a->ne[0] == b->ne[0]);
+    GGML_ASSERT(a->ne[1] == ids->ne[0]);
+    GGML_ASSERT(a->ne[2] == ids->ne[1]);
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op     = GGML_OP_ADD_ID;
+    result->src[0] = a;
+    result->src[1] = b;
+    result->src[2] = ids;
+
+    return result;
+}
+
 // ggml_add1
 
 static struct ggml_tensor * ggml_add1_impl(
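In the MoE graph code later in this commit, ids is the selected_experts tensor, so ggml_add_id adds the chosen expert's bias row to each token's row of a; the shape asserts above encode exactly that pairing.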
@@ -2828,6 +2858,19 @@ struct ggml_tensor * ggml_geglu_quick_split(
     return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false);
 }
 
+struct ggml_tensor * ggml_swiglu_oai(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        float                 alpha,
+        float                 limit) {
+    struct ggml_tensor * result = ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU_OAI, false);
+    ggml_set_op_params_f32(result, 2, alpha);
+    ggml_set_op_params_f32(result, 3, limit);
+
+    return result;
+}
+
 // ggml_norm
 
 static struct ggml_tensor * ggml_norm_impl(
@@ -3795,6 +3838,22 @@ struct ggml_tensor * ggml_soft_max_ext(
     return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
 }
 
+void ggml_soft_max_add_sinks(
+        struct ggml_tensor * a,
+        struct ggml_tensor * sinks) {
+    if (!sinks) {
+        a->src[2] = NULL;
+        return;
+    }
+
+    GGML_ASSERT(a->op == GGML_OP_SOFT_MAX);
+    GGML_ASSERT(a->src[2] == NULL);
+    GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
+    GGML_ASSERT(sinks->type == GGML_TYPE_F32);
+
+    a->src[2] = sinks;
+}
+
 // ggml_soft_max_ext_back
 
 static struct ggml_tensor * ggml_soft_max_ext_back_impl(
@@ -4828,6 +4887,22 @@ enum ggml_prec ggml_flash_attn_ext_get_prec(
     return (enum ggml_prec) prec_i32;
 }
 
+void ggml_flash_attn_ext_add_sinks(
+        struct ggml_tensor * a,
+        struct ggml_tensor * sinks) {
+    if (!sinks) {
+        a->src[4] = NULL;
+        return;
+    }
+
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+    GGML_ASSERT(a->src[4] == NULL);
+    GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
+    GGML_ASSERT(sinks->type == GGML_TYPE_F32);
+
+    a->src[4] = sinks;
+}
+
 // ggml_flash_attn_back
 
 struct ggml_tensor * ggml_flash_attn_back(
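Both helpers are deliberately tolerant of a NULL sinks tensor, so callers can pass a possibly absent model tensor straight through; the llama-graph.cpp changes further down call them immediately after building the corresponding ggml_soft_max_ext / ggml_flash_attn_ext node.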
@@ -6888,6 +6963,7 @@ size_t ggml_quantize_chunk(
         case GGML_TYPE_Q5_0:  result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q5_1:  result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q8_0:  result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q2_K:  result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q3_K:  result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q4_K:  result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
@@ -380,6 +380,7 @@ class MODEL_ARCH(IntEnum):
     HUNYUAN_MOE = auto()
     HUNYUAN_DENSE = auto()
     SMOLLM3 = auto()
+    GPT_OSS = auto()
     LFM2 = auto()
     DREAM = auto()
     SMALLTHINKER = auto()
@@ -416,6 +417,7 @@ class MODEL_TENSOR(IntEnum):
     ATTN_OUT_NORM = auto()
     ATTN_POST_NORM = auto()
     ATTN_ROT_EMBD = auto()
+    ATTN_SINKS = auto()
     FFN_GATE_INP = auto()
     FFN_GATE_INP_SHEXP = auto()
     FFN_NORM = auto()
@@ -710,6 +712,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
     MODEL_ARCH.HUNYUAN_DENSE: "hunyuan-dense",
     MODEL_ARCH.SMOLLM3: "smollm3",
+    MODEL_ARCH.GPT_OSS: "gpt-oss",
     MODEL_ARCH.LFM2: "lfm2",
     MODEL_ARCH.DREAM: "dream",
     MODEL_ARCH.SMALLTHINKER: "smallthinker",
@@ -744,6 +747,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
     MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
     MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
+    MODEL_TENSOR.ATTN_SINKS: "blk.{bid}.attn_sinks",
     MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
     MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
     MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
@@ -2553,6 +2557,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.GPT_OSS: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_SINKS,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     MODEL_ARCH.LFM2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.TOKEN_EMBD_NORM,
@@ -2707,6 +2727,7 @@ class GGMLQuantizationType(IntEnum):
     BF16 = 30
     TQ1_0 = 34
     TQ2_0 = 35
+    MXFP4 = 39
 
 
 class ExpertGatingFuncType(IntEnum):
@@ -2847,6 +2868,7 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
     GGMLQuantizationType.BF16: (1, 2),
     GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13),
     GGMLQuantizationType.TQ2_0: (256, 2 + 64),
+    GGMLQuantizationType.MXFP4: (32, 1 + 16),
 }
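The new GGML_QUANT_SIZES entry reads as 32 weights per block and 1 + 16 = 17 bytes per block (one E8M0 scale byte plus 16 bytes of packed 4-bit values), i.e. 4.25 bits per weight.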
@@ -138,8 +138,9 @@ class GGUFWriter:
             size = prod(shape)
 
             if "_exps." in name:
-                expert_params += (size // shape[-3])
-                expert_sum += shape[-3]
+                expert_count = shape[-2 if ".bias" in name else -3]
+                expert_params += (size // expert_count)
+                expert_sum += expert_count
                 n_expert_tensors += 1
             else:
                 shared_params += size
@@ -285,6 +285,10 @@ class TensorNameMap:
             "transformer.h.{bid}.attn.rotary_emb.inv_freq",  # codeshell
         ),
 
+        MODEL_TENSOR.ATTN_SINKS: (
+            "model.layers.{bid}.self_attn.sinks",  # openai-moe
+        ),
+
         # Feed-forward norm
         MODEL_TENSOR.FFN_NORM: (
             "gpt_neox.layers.{bid}.post_attention_layernorm",  # gptneox
@@ -332,6 +336,7 @@
             "model.layers.{bid}.block_sparse_moe.router.layer",    # granitemoe
             "model.layers.{bid}.feed_forward.router",              # llama4 jamba
             "encoder.layers.{bid}.mlp.router.layer",               # nomic-bert-moe
+            "model.layers.{bid}.mlp.router",                       # openai-moe
             "model.layers.{bid}.mlp.gate.wg",                      # hunyuan
             "model.layers.{bid}.block_sparse_moe.primary_router",  # smallthinker
         ),
@@ -155,6 +155,7 @@ extern "C" {
         //LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // removed from gguf files, use Q4_0 and runtime repack
         LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38, // except 1d tensors
 
         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };
@@ -88,6 +88,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
     { LLM_ARCH_HUNYUAN_DENSE, "hunyuan-dense" },
     { LLM_ARCH_SMOLLM3, "smollm3" },
+    { LLM_ARCH_OPENAI_MOE, "gpt-oss" },
     { LLM_ARCH_LFM2, "lfm2" },
     { LLM_ARCH_DREAM, "dream" },
     { LLM_ARCH_SMALLTHINKER, "smallthinker" },
@@ -1971,6 +1972,25 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_OPENAI_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_SINKS, "blk.%d.attn_sinks" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_LFM2,
         {
@@ -2086,6 +2106,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_SINKS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SCALE}},
     {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@@ -92,6 +92,7 @@ enum llm_arch {
     LLM_ARCH_HUNYUAN_MOE,
     LLM_ARCH_HUNYUAN_DENSE,
     LLM_ARCH_SMOLLM3,
+    LLM_ARCH_OPENAI_MOE,
     LLM_ARCH_LFM2,
     LLM_ARCH_DREAM,
     LLM_ARCH_SMALLTHINKER,
@@ -265,6 +266,7 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_OUT_NORM,
     LLM_TENSOR_ATTN_POST_NORM,
     LLM_TENSOR_ATTN_ROT_EMBD,
+    LLM_TENSOR_ATTN_SINKS,
     LLM_TENSOR_FFN_GATE_INP,
     LLM_TENSOR_FFN_GATE_INP_SHEXP,
     LLM_TENSOR_FFN_NORM,
@@ -66,6 +66,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
     { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
     { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
+    { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE },
     { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE },
     { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
 };
@@ -194,6 +195,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_DOTS1;
     } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
         return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
+    } else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) {
+        return LLM_CHAT_TEMPLATE_OPENAI_MOE;
     } else if (tmpl_contains("<|hy_place▁holder▁no▁2|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) {
         return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE;
     } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
@@ -706,6 +709,16 @@ int32_t llm_chat_apply_template(
                 ss << "<|startoftext|>" << message->content << "<|extra_0|>";
             }
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_OPENAI_MOE) {
+        // OpenAI MoE (based on Harmony chat template)
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|start|>" << role << "<|message|>" << message->content;
+            ss << (role == "assistant" ? "<|return|>" : "<|end|>");
+        }
+        if (add_ass) {
+            ss << "<|start|>assistant";
+        }
     } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_DENSE) {
         // tencent/Hunyuan-4B-Instruct
         for (size_t i = 0; i < chat.size(); i++) {
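For a short conversation with add_ass enabled, the gpt-oss branch above renders a prompt of the following shape (illustrative):

<|start|>user<|message|>Hello<|end|><|start|>assistant<|message|>Hi there!<|return|><|start|>assistant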
@@ -46,6 +46,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_SMOLVLM,
     LLM_CHAT_TEMPLATE_DOTS1,
     LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
+    LLM_CHAT_TEMPLATE_OPENAI_MOE,
     LLM_CHAT_TEMPLATE_HUNYUAN_DENSE,
     LLM_CHAT_TEMPLATE_KIMI_K2,
     LLM_CHAT_TEMPLATE_UNKNOWN,
|
@ -740,6 +740,8 @@ ggml_tensor * llm_graph_context::build_ffn(
|
||||||
cur = ggml_reglu(ctx0, cur);
|
cur = ggml_reglu(ctx0, cur);
|
||||||
cb(cur, "ffn_reglu", il);
|
cb(cur, "ffn_reglu", il);
|
||||||
} break;
|
} break;
|
||||||
|
default:
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (gate && type_gate == LLM_FFN_PAR) {
|
if (gate && type_gate == LLM_FFN_PAR) {
|
||||||
|
@ -787,6 +789,45 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||||
llama_expert_gating_func_type gating_op,
|
llama_expert_gating_func_type gating_op,
|
||||||
int il,
|
int il,
|
||||||
ggml_tensor * probs_in) const {
|
ggml_tensor * probs_in) const {
|
||||||
|
return build_moe_ffn(
|
||||||
|
cur,
|
||||||
|
gate_inp, /* gate_inp_b */ nullptr,
|
||||||
|
up_exps, /* up_exps_b */ nullptr,
|
||||||
|
gate_exps, /* gate_exps_b */ nullptr,
|
||||||
|
down_exps, /* down_exps_b */ nullptr,
|
||||||
|
exp_probs_b,
|
||||||
|
n_expert,
|
||||||
|
n_expert_used,
|
||||||
|
type_op,
|
||||||
|
norm_w,
|
||||||
|
scale_w,
|
||||||
|
w_scale,
|
||||||
|
gating_op,
|
||||||
|
il,
|
||||||
|
probs_in
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||||
|
ggml_tensor * cur,
|
||||||
|
ggml_tensor * gate_inp,
|
||||||
|
ggml_tensor * gate_inp_b,
|
||||||
|
ggml_tensor * up_exps,
|
||||||
|
ggml_tensor * up_exps_b,
|
||||||
|
ggml_tensor * gate_exps,
|
||||||
|
ggml_tensor * gate_exps_b,
|
||||||
|
ggml_tensor * down_exps,
|
||||||
|
ggml_tensor * down_exps_b,
|
||||||
|
ggml_tensor * exp_probs_b,
|
||||||
|
int64_t n_expert,
|
||||||
|
int64_t n_expert_used,
|
||||||
|
llm_ffn_op_type type_op,
|
||||||
|
bool norm_w,
|
||||||
|
bool scale_w,
|
||||||
|
float w_scale,
|
||||||
|
llama_expert_gating_func_type gating_op,
|
||||||
|
int il,
|
||||||
|
ggml_tensor * probs_in) const {
|
||||||
const int64_t n_embd = cur->ne[0];
|
const int64_t n_embd = cur->ne[0];
|
||||||
const int64_t n_tokens = cur->ne[1];
|
const int64_t n_tokens = cur->ne[1];
|
||||||
const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
|
const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
|
||||||
|
@ -800,6 +841,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||||
logits = probs_in;
|
logits = probs_in;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (gate_inp_b) {
|
||||||
|
logits = ggml_add(ctx0, logits, gate_inp_b);
|
||||||
|
cb(logits, "ffn_moe_logits_biased", il);
|
||||||
|
}
|
||||||
|
|
||||||
ggml_tensor * probs = nullptr;
|
ggml_tensor * probs = nullptr;
|
||||||
switch (gating_op) {
|
switch (gating_op) {
|
||||||
case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
|
case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
|
||||||
|
@ -810,6 +856,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||||
{
|
{
|
||||||
probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
|
probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
|
||||||
} break;
|
} break;
|
||||||
|
case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
|
||||||
|
{
|
||||||
|
probs = logits; // [n_expert, n_tokens]
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
@ -838,6 +888,13 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||||
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
|
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
|
||||||
cb(weights, "ffn_moe_weights", il);
|
cb(weights, "ffn_moe_weights", il);
|
||||||
|
|
||||||
|
if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
|
||||||
|
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
|
||||||
|
weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
|
||||||
|
weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
|
||||||
|
cb(weights, "ffn_moe_weights_softmax", il);
|
||||||
|
}
|
||||||
|
|
||||||
if (norm_w) {
|
if (norm_w) {
|
||||||
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
|
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
|
||||||
|
|
||||||
|
@@ -866,6 +923,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(up, "ffn_moe_up", il);
 
+    if (up_exps_b) {
+        up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
+        cb(up, "ffn_moe_up_biased", il);
+    }
+
     ggml_tensor * experts = nullptr;
     if (gate_exps) {
         cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
@@ -874,6 +936,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cur = up;
     }
 
+    if (gate_exps_b) {
+        cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
+        cb(cur, "ffn_moe_gate_biased", il);
+    }
+
     switch (type_op) {
         case LLM_FFN_SILU:
             if (gate_exps) {
@@ -891,6 +958,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                 cur = ggml_gelu(ctx0, cur);
                 cb(cur, "ffn_moe_gelu", il);
            } break;
+        case LLM_FFN_SWIGLU_OAI_MOE:
+            {
+                // TODO: move to hparams?
+                constexpr float alpha = 1.702f;
+                constexpr float limit = 7.0f;
+                cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
+                cb(cur, "ffn_moe_swiglu_oai", il);
+            } break;
         case LLM_FFN_RELU:
             if (gate_exps) {
                 cur = ggml_reglu_split(ctx0, cur, up);
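For context on the new LLM_FFN_SWIGLU_OAI_MOE branch: ggml_swiglu_oai is understood here to implement the clamped SwiGLU variant from the gpt-oss reference, with alpha = 1.702 and limit = 7.0. The scalar sketch below is an assumption about that element-wise formula (gate channel clamped from above, linear channel clamped on both sides and offset by one), not the ggml kernel itself.

    #include <algorithm>
    #include <cmath>

    // Scalar sketch of the assumed clamped SwiGLU activation used per element.
    float swiglu_oai_scalar(float gate, float linear, float alpha = 1.702f, float limit = 7.0f) {
        gate   = std::min(gate, limit);                              // clamp the gated channel from above
        linear = std::clamp(linear, -limit, limit);                  // clamp the linear channel on both sides
        const float silu = gate / (1.0f + std::exp(-alpha * gate));  // gate * sigmoid(alpha * gate)
        return silu * (linear + 1.0f);                               // +1 offset on the linear channel
    }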
@@ -906,6 +981,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
+    if (down_exps_b) {
+        experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
+        cb(experts, "ffn_moe_down_biased", il);
+    }
+
     if (!weight_before_ffn) {
         experts = ggml_mul(ctx0, experts, weights);
         cb(cur, "ffn_moe_weighted", il);
@@ -1144,6 +1224,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
         ggml_tensor * v_mla,
+        ggml_tensor * sinks,
               float   kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];
 
@@ -1180,6 +1261,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
+        ggml_flash_attn_ext_add_sinks(cur, sinks);
         ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
 
         if (v_mla) {
@@ -1228,6 +1310,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         }
 
         kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+        ggml_soft_max_add_sinks(kq, sinks);
 
         if (!v_trans) {
             // note: avoid this branch
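The sinks hooked in above (ggml_soft_max_add_sinks and its flash-attention counterpart) are understood to act as one extra learned logit per attention head: it competes in the softmax normalization but carries no value, so the value-weighted sum can receive less than the full probability mass. The C++ sketch below illustrates that assumed behaviour; it is not the ggml kernel.

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Softmax over one row of KQ logits with a per-head sink logit.
    // The sink joins the denominator, then its share is simply discarded.
    std::vector<float> softmax_with_sink(const std::vector<float> & kq_logits, float sink) {
        float mx = sink;
        for (float x : kq_logits) mx = std::max(mx, x);

        float denom = std::exp(sink - mx);                 // sink contributes to the normalization
        for (float x : kq_logits) denom += std::exp(x - mx);

        std::vector<float> p(kq_logits.size());
        for (size_t i = 0; i < p.size(); ++i) {
            p[i] = std::exp(kq_logits[i] - mx) / denom;    // sink's probability mass is dropped
        }
        return p;
    }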
@@ -1298,7 +1381,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1386,7 +1469,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1415,6 +1498,32 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * v_mla,
               float   kq_scale,
                 int   il) const {
+    return build_attn_with_sinks(
+            inp,
+            wo,
+            wo_b,
+            q_cur,
+            k_cur,
+            v_cur,
+            kq_b,
+            v_mla,
+            nullptr,
+            kq_scale,
+            il);
+}
+
+ggml_tensor * llm_graph_context::build_attn_with_sinks(
+        llm_graph_input_attn_kv_unified_iswa * inp,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        ggml_tensor * sinks,
+              float   kq_scale,
+                int   il) const {
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
     ggml_build_forward_expand(gf, q_cur);
@@ -1452,7 +1561,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, sinks, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1506,7 +1615,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -39,6 +39,7 @@ enum llm_ffn_op_type {
     LLM_FFN_SWIGLU,
     LLM_FFN_GEGLU,
     LLM_FFN_REGLU,
+    LLM_FFN_SWIGLU_OAI_MOE,
 };
 
 enum llm_ffn_gate_type {
@@ -619,6 +620,7 @@ struct llm_graph_context {
             llm_ffn_gate_type type_gate,
                     int   il) const;
 
+    // build MoE FFN without bias tensors
     ggml_tensor * build_moe_ffn(
             ggml_tensor * cur,
             ggml_tensor * gate_inp,
@@ -636,6 +638,27 @@ struct llm_graph_context {
                     int   il,
             ggml_tensor * probs_in = nullptr) const;
 
+    ggml_tensor * build_moe_ffn(
+            ggml_tensor * cur,
+            ggml_tensor * gate_inp,
+            ggml_tensor * gate_inp_b,
+            ggml_tensor * up_exps,
+            ggml_tensor * up_exps_b,
+            ggml_tensor * gate_exps,
+            ggml_tensor * gate_exps_b,
+            ggml_tensor * down_exps,
+            ggml_tensor * down_exps_b,
+            ggml_tensor * exp_probs_b,
+                int64_t   n_expert,
+                int64_t   n_expert_used,
+            llm_ffn_op_type type_op,
+                   bool   norm_w,
+                   bool   scale_w,
+                  float   w_scale,
+            llama_expert_gating_func_type gating_op,
+                    int   il,
+            ggml_tensor * probs_in = nullptr) const;
+
     //
     // inputs
     //
@@ -662,6 +685,7 @@ struct llm_graph_context {
             ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
             ggml_tensor * kq_b,
             ggml_tensor * kq_mask,
+            ggml_tensor * sinks,
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                   float   kq_scale) const;
 
@@ -708,6 +732,20 @@ struct llm_graph_context {
                   float   kq_scale,
                     int   il) const;
 
+    // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
+    ggml_tensor * build_attn_with_sinks(
+            llm_graph_input_attn_kv_unified_iswa * inp,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
+            ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+            ggml_tensor * sinks, // [n_head_q]
+                  float   kq_scale,
+                    int   il) const;
+
     llm_graph_input_attn_cross * build_attn_inp_cross() const;
 
     ggml_tensor * build_attn(
@@ -12,6 +12,7 @@ enum llama_expert_gating_func_type {
    LLAMA_EXPERT_GATING_FUNC_TYPE_NONE    = 0,
    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1,
    LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
+   LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT = 3, // applied to the router weights instead of the logits
 };
 
 enum llama_swa_type {
@@ -39,6 +39,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
         case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0";
         case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1";
         case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0";
+        case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return "MXFP4 MoE";
         case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium";
         case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small";
         case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small";
@@ -197,6 +197,13 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
                 ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
                 op_tensor = ggml_add(ctx, a, w);
             } break;
+        case GGML_OP_ADD_ID:
+            {
+                int n_expert_used = hparams.n_expert_used;
+                ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512);
+                ggml_tensor * c = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512);
+                op_tensor = ggml_add_id(ctx, a, w, c);
+            } break;
         case GGML_OP_MUL:
             {
                 ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
@@ -265,6 +272,10 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
                 ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1);
                 op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16);
             } break;
+        case GGML_OP_SCALE:
+            {
+                op_tensor = ggml_scale(ctx, w, 1.0f);
+            } break;
         default:
             GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name);
     }
@@ -1818,6 +1829,17 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_OPENAI_MOE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+
+                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                hparams.set_swa_pattern(2);
+
+                // TODO: switch (hparams.n_layer)
+            } break;
         case LLM_ARCH_LFM2:
             {
                 ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache);
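A brief note on the iSWA setup above: this architecture interleaves sliding-window and full-attention layers, and hparams.set_swa_pattern(2) is assumed to mark every other layer as a sliding-window layer of n_swa tokens. The sketch below is an illustrative guess at that layer pattern, not the llama.cpp helper itself.

    #include <cstdint>
    #include <vector>

    // Hypothetical reconstruction of the per-layer SWA flags for a pattern of 2:
    // the last layer of each pattern block keeps full attention, the rest use the window.
    std::vector<bool> make_swa_layers(uint32_t n_layer, uint32_t n_pattern /* = 2 for this arch */) {
        std::vector<bool> is_swa(n_layer);
        for (uint32_t il = 0; il < n_layer; ++il) {
            is_swa[il] = (il % n_pattern) < (n_pattern - 1); // assumption about the pattern convention
        }
        return is_swa;
    }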
@@ -2061,11 +2083,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             return nullptr;
         }
 
-        // tensors with "bias" suffix are always used with GGML_OP_ADD
+        // tensors with "bias" suffix are always used with GGML_OP_ADD or GGML_OP_ADD_ID
         ggml_op op;
         bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0;
         if (bias) {
+            if (info.op == GGML_OP_MUL_MAT_ID) {
+                op = GGML_OP_ADD_ID;
+            } else {
                 op = GGML_OP_ADD;
+            }
         } else {
             op = info.op;
         }
@@ -5496,6 +5522,46 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
                     }
                 } break;
+            case LLM_ARCH_OPENAI_MOE:
+                {
+                    const int64_t n_ff_exp = hparams.n_ff_exp;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm      = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0);
+
+                        layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0);
+
+                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+
+                        // bias
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_head * n_rot}, 0);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_head_kv * n_rot}, 0);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_head_kv * n_rot}, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_gate_inp_b  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0);
+                        layer.ffn_gate_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0);
+                        layer.ffn_down_exps_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "bias", i), { n_embd, n_expert}, 0);
+                        layer.ffn_up_exps_b   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "bias", i), {n_ff_exp, n_expert}, 0);
+                    }
+                } break;
            case LLM_ARCH_LFM2:
                {
                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -5869,7 +5935,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
     }
 
-    if (arch == LLM_ARCH_QWEN3MOE) {
+    if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE) {
         LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
     }
 
@@ -17640,6 +17706,136 @@ struct llm_build_smollm3 : public llm_graph_context {
     }
 };
 
+struct llm_build_openai_moe_iswa : public llm_graph_context {
+    llm_build_openai_moe_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, nullptr,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn_with_sinks(inp_attn,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].attn_sinks, 1.0f/sqrtf(float(n_rot)), il);
+
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            cur = ffn_inp;
+            cur = build_norm(cur,
+                    model.layers[il].attn_post_norm, nullptr,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_post_norm", il);
+
+            // MoE branch
+            cur = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp, model.layers[il].ffn_gate_inp_b,
+                    model.layers[il].ffn_up_exps, model.layers[il].ffn_up_exps_b,
+                    model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b,
+                    model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b,
+                    nullptr,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SWIGLU_OAI_MOE, false,
+                    false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT,
+                    il);
+            cb(cur, "ffn_moe_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_lfm2 : public llm_graph_context {
     const llama_model & model;
 
@@ -18383,6 +18579,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_smollm3>(*this, params);
             } break;
+        case LLM_ARCH_OPENAI_MOE:
+            {
+                llm = std::make_unique<llm_build_openai_moe_iswa>(*this, params);
+            } break;
         case LLM_ARCH_FALCON_H1:
             {
                 llm = std::make_unique<llm_build_falcon_h1>(*this, params);
@@ -18598,6 +18798,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_MINICPM3:
         case LLM_ARCH_DOTS1:
         case LLM_ARCH_HUNYUAN_MOE:
+        case LLM_ARCH_OPENAI_MOE:
        case LLM_ARCH_HUNYUAN_DENSE:
        case LLM_ARCH_LFM2:
        case LLM_ARCH_SMALLTHINKER:
@@ -256,6 +256,10 @@ struct llama_layer {
     struct ggml_tensor * ffn_gate_exps = nullptr;
     struct ggml_tensor * ffn_down_exps = nullptr;
     struct ggml_tensor * ffn_up_exps = nullptr;
+    struct ggml_tensor * ffn_gate_inp_b  = nullptr;
+    struct ggml_tensor * ffn_gate_exps_b = nullptr;
+    struct ggml_tensor * ffn_down_exps_b = nullptr;
+    struct ggml_tensor * ffn_up_exps_b   = nullptr;
 
     // ff shared expert (shexp)
     struct ggml_tensor * ffn_gate_inp_shexp = nullptr;
@@ -360,6 +364,9 @@ struct llama_layer {
     struct ggml_tensor * laurel_r = nullptr;
     struct ggml_tensor * laurel_post_norm = nullptr;
 
+    // openai-moe
+    struct ggml_tensor * attn_sinks = nullptr;
+
     struct llama_layer_posnet posnet;
 
     struct llama_layer_convnext convnext;
@@ -211,7 +211,10 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
     const int64_t nx = tensor->ne[0];
     const int64_t qk_k = ggml_blck_size(new_type);
 
-    if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) {
+    if (ftype == LLAMA_FTYPE_MOSTLY_MXFP4_MOE) {
+        new_type = GGML_TYPE_Q8_0;
+    }
+    else if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) {
         new_type = GGML_TYPE_Q8_0;
     }
     else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
@@ -223,6 +226,14 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
                 new_type = GGML_TYPE_Q6_K;
             }
         }
+    } else if (ftype == LLAMA_FTYPE_MOSTLY_MXFP4_MOE) {
+        // MoE tensors -> MXFP4
+        // other tensors -> Q8_0
+        if (tensor->ne[2] > 1) {
+            new_type = GGML_TYPE_MXFP4;
+        } else {
+            new_type = GGML_TYPE_Q8_0;
+        }
     } else if (name == "token_embd.weight" || name == "per_layer_token_embd.weight") {
         if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
             new_type = qs.params->token_embedding_type;
@@ -536,6 +547,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
         case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break;
 
+        case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: default_type = GGML_TYPE_MXFP4; break;
+
         // K-quants
         case LLAMA_FTYPE_MOSTLY_Q2_K_S:
         case LLAMA_FTYPE_MOSTLY_Q2_K: default_type = GGML_TYPE_Q2_K; break;
@@ -987,6 +1000,29 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
                 const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
 
                 new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
+
+                // TODO: temporary sanity check that the F16 -> MXFP4 is lossless
+#if 1
+                if (new_type == GGML_TYPE_MXFP4) {
+                    auto * x = f32_data_03;
+
+                    //LLAMA_LOG_INFO("nrows = %d, n_per_row = %d\n", nrows, n_per_row);
+                    std::vector<float> deq(nrows*n_per_row);
+                    const ggml_type_traits * qtype = ggml_get_type_traits(new_type);
+                    qtype->to_float(new_data_03, deq.data(), deq.size());
+
+                    double err = 0.0f;
+                    for (int i = 0; i < (int) deq.size(); ++i) {
+                        err += fabsf(deq[i] - x[i]);
+                        //if (fabsf(deq[i] - x[i]) > 0.00001 && i < 256) {
+                        if (deq[i] != x[i]) {
+                            LLAMA_LOG_INFO("deq[%d] = %f, x[%d] = %f\n", i, deq[i], i, x[i]);
+                        }
+                    }
+                    //LLAMA_LOG_INFO("err = %f\n", err);
+                    GGML_ASSERT(err == 0.00000);
+                }
+#endif
             }
             LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
         }
@@ -2553,6 +2553,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<|eot_id|>"
                     || t.first == "<|im_end|>"
                     || t.first == "<|end|>"
+                    || t.first == "<|return|>" // o200k_harmony
+                    || t.first == "<|call|>" // o200k_harmony
                     || t.first == "<end_of_turn>"
                     || t.first == "<|endoftext|>"
                    || t.first == "<|eom_id|>"
@@ -2576,6 +2578,13 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
            }
        }
 
+        // @ngxson : quick hack for gpt-oss, always render these tokens
+        for (const auto & t : token_to_id) {
+            if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>") {
+                id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
+            }
+        }
+
        // sanity checks
        if (special_eos_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eos_id) == 0) {
            special_eog_ids.insert(special_eos_id);
@@ -2591,6 +2600,36 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
            special_eog_ids.insert(special_eom_id);
            LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
        }
+
+        // TODO: workaround for o200k_harmony tokenizer: the "<|end|>" token should not be EOG
+        // we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens,
+        // we remove the "<|end|>" token from the EOG list
+        {
+            bool has_return = false;
+            bool has_call = false;
+            bool has_end = false;
+
+            llama_token end_id = LLAMA_TOKEN_NULL;
+
+            LLAMA_LOG_INFO("%s: printing all EOG tokens:\n", __func__);
+            for (auto tid : special_eog_ids) {
+                LLAMA_LOG_INFO("%s: - %d ('%s')\n", __func__, tid, id_to_token[tid].text.c_str());
+
+                if (id_to_token[tid].text == "<|return|>") {
+                    has_return = true;
+                } else if (id_to_token[tid].text == "<|call|>") {
+                    has_call = true;
+                } else if (id_to_token[tid].text == "<|end|>") {
+                    has_end = true;
+                    end_id = tid;
+                }
+            }
+
+            if (has_return && has_call && has_end) {
+                special_eog_ids.erase(end_id);
+                LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>' tokens, removing '<|end|>' token from EOG list\n", __func__);
+            }
+        }
    }
 
    // build special tokens cache
@@ -23,6 +23,7 @@ struct quant_option {
 static const std::vector<quant_option> QUANT_OPTIONS = {
     { "Q4_0", LLAMA_FTYPE_MOSTLY_Q4_0, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
     { "Q4_1", LLAMA_FTYPE_MOSTLY_Q4_1, " 4.78G, +0.4511 ppl @ Llama-3-8B", },
+    { "MXFP4_MOE",LLAMA_FTYPE_MOSTLY_MXFP4_MOE," MXFP4 MoE", },
     { "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0, " 5.21G, +0.1316 ppl @ Llama-3-8B", },
     { "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1, " 5.65G, +0.1062 ppl @ Llama-3-8B", },
     { "IQ2_XXS", LLAMA_FTYPE_MOSTLY_IQ2_XXS, " 2.06 bpw quantization", },
Binary file not shown.
@@ -61,20 +61,23 @@ export default function ChatMessage({
    if (msg.content === null || msg.role !== 'assistant') {
      return { content: msg.content };
    }
+    const REGEX_THINK_OPEN = /<think>|<\|channel\|>analysis<\|message\|>/;
+    const REGEX_THINK_CLOSE =
+      /<\/think>|<\|start\|>assistant<\|channel\|>final<\|message\|>/;
    let actualContent = '';
    let thought = '';
    let isThinking = false;
-    let thinkSplit = msg.content.split('<think>', 2);
+    let thinkSplit = msg.content.split(REGEX_THINK_OPEN, 2);
    actualContent += thinkSplit[0];
    while (thinkSplit[1] !== undefined) {
      // <think> tag found
-      thinkSplit = thinkSplit[1].split('</think>', 2);
+      thinkSplit = thinkSplit[1].split(REGEX_THINK_CLOSE, 2);
      thought += thinkSplit[0];
      isThinking = true;
      if (thinkSplit[1] !== undefined) {
        // </think> closing tag found
        isThinking = false;
-        thinkSplit = thinkSplit[1].split('<think>', 2);
+        thinkSplit = thinkSplit[1].split(REGEX_THINK_OPEN, 2);
        actualContent += thinkSplit[0];
      }
    }