diff --git a/common/arg.cpp b/common/arg.cpp
index c5f115c77..92ccb84bf 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2468,7 +2468,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_N_CPU_MOE_DRAFT"));
add_opt(common_arg(
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
- "number of layers to store in VRAM",
+ string_format("max. number of layers to store in VRAM (default: %d)", params.n_gpu_layers),
[](common_params & params, int value) {
params.n_gpu_layers = value;
if (!llama_supports_gpu_offload()) {
diff --git a/common/chat.cpp b/common/chat.cpp
index 1200b738f..fa330a71c 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -623,6 +623,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
+ case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
default:
throw std::runtime_error("Unknown chat format");
}
@@ -1184,6 +1185,67 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
});
return data;
}
+
+static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
+ common_chat_params data;
+
+ // Generate the prompt using the apply() function with the template
+ data.prompt = apply(tmpl, inputs);
+ data.format = COMMON_CHAT_FORMAT_NEMOTRON_V2;
+
+ // Handle thinking tags appropriately based on inputs.enable_thinking
+ if (string_ends_with(data.prompt, "<think>\n")) {
+ if (!inputs.enable_thinking) {
+ data.prompt += "";
+ } else {
+ data.thinking_forced_open = true;
+ }
+ }
+
+ // When tools are present, build grammar for the format, similar to CommandR, but without tool call ID
+ if (!inputs.tools.is_null() && inputs.tools.is_array() && !inputs.tools.empty()) {
+ data.grammar_lazy = true;
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ auto schemas = json::array();
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ schemas.push_back({
+ { "type", "object" },
+ { "properties",
+ {
+ { "name",
+ {
+ { "type", "string" },
+ { "const", function.at("name") },
+ } },
+ { "arguments", function.at("parameters") },
+ } },
+ { "required", json::array({ "name", "arguments" }) },
+ });
+ });
+ auto schema = json{
+ { "type", "array" },
+ { "items", schemas.size() == 1 ? schemas[0] : json{ { "anyOf", schemas } } },
+ { "minItems", 1 },
+ };
+ if (!inputs.parallel_tool_calls) {
+ schema["maxItems"] = 1;
+ }
+ builder.add_rule("root",
+ std::string(data.thinking_forced_open ? "( \"\" space )? " : "") +
+ "\"\" " + builder.add_schema("tool_calls", schema) +
+ " \"\"");
+ });
+ data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
+ // If thinking_forced_open, then we capture the </think> tag in the grammar,
+ // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
+ std::string(data.thinking_forced_open ?
+ "[\\s\\S]*?(\\s*)" :
+ "(?:[\\s\\S]*?\\s*)?") +
+ "()[\\s\\S]*" });
+ }
+ return data;
+}
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
@@ -1830,7 +1892,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
// If thinking_forced_open, then we capture the </think> tag in the grammar,
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)" : "(?:[\\s\\S]*?\\s*)?") + (
- "(\\s*"
+ "\\s*("
"(?:"
"||||)?"
@@ -2060,6 +2122,33 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) {
}
}
+static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) {
+ // Parse thinking tags
+ builder.try_parse_reasoning("<think>", "</think>");
+ if (!builder.syntax().parse_tool_calls) {
+ builder.add_content(builder.consume_rest());
+ return;
+ }
+
+ // Look for tool calls
+ static const common_regex tool_call_regex(regex_escape("<TOOLCALL>"));
+ if (auto res = builder.try_find_regex(tool_call_regex)) {
+ builder.move_to(res->groups[0].end);
+
+ // Expect JSON array of tool calls
+ auto tool_calls_data = builder.consume_json();
+ if (tool_calls_data.json.is_array()) {
+ if (!builder.try_consume_literal("</TOOLCALL>")) {
+ throw common_chat_msg_partial_exception("Incomplete tool call");
+ }
+ builder.add_tool_calls(tool_calls_data.json);
+ } else {
+ throw common_chat_msg_partial_exception("Incomplete tool call");
+ }
+ }
+ builder.add_content(builder.consume_rest());
+}
+
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
// Parse thinking tags first - this handles the main reasoning content
builder.try_parse_reasoning("<seed:think>", "</seed:think>");
@@ -2293,6 +2382,11 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_seed_oss(tmpl, params, inputs);
}
+ // Nemotron v2
+ if (src.find("") != std::string::npos) {
+ return common_chat_params_init_nemotron_v2(tmpl, params);
+ }
+
// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -2454,6 +2548,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_SEED_OSS:
common_chat_parse_seed_oss(builder);
break;
+ case COMMON_CHAT_FORMAT_NEMOTRON_V2:
+ common_chat_parse_nemotron_v2(builder);
+ break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
diff --git a/common/chat.h b/common/chat.h
index b09ff3b12..ccd26f27f 100644
--- a/common/chat.h
+++ b/common/chat.h
@@ -112,6 +112,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_GRANITE,
COMMON_CHAT_FORMAT_GPT_OSS,
COMMON_CHAT_FORMAT_SEED_OSS,
+ COMMON_CHAT_FORMAT_NEMOTRON_V2,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 908435d7e..717b8a659 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -5122,6 +5122,15 @@ class Gemma3Model(TextModel):
return [(self.map_tensor_name(name), data_torch)]
+@ModelBase.register("Gemma3TextModel")
+class EmbeddingGemma(Gemma3Model):
+ model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ self._try_set_pooling_type()
+
+
@ModelBase.register("Gemma3ForConditionalGeneration")
class Gemma3VisionModel(MmprojModel):
def set_gguf_parameters(self):
diff --git a/examples/model-conversion/scripts/utils/curl-embedding-server.sh b/examples/model-conversion/scripts/utils/curl-embedding-server.sh
index b6898665f..7ed69e1ea 100755
--- a/examples/model-conversion/scripts/utils/curl-embedding-server.sh
+++ b/examples/model-conversion/scripts/utils/curl-embedding-server.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
curl --request POST \
--url http://localhost:8080/embedding \
--header "Content-Type: application/json" \
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 2a3173532..e78f2355c 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -517,6 +517,7 @@ extern "C" {
GGML_OP_CONV_TRANSPOSE_1D,
GGML_OP_IM2COL,
GGML_OP_IM2COL_BACK,
+ GGML_OP_IM2COL_3D,
GGML_OP_CONV_2D,
GGML_OP_CONV_3D,
GGML_OP_CONV_2D_DW,
@@ -1895,6 +1896,41 @@ extern "C" {
int d0, // dilation dimension 0
int d1); // dilation dimension 1
+ GGML_API struct ggml_tensor * ggml_im2col_3d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int64_t IC,
+ int s0, // stride width
+ int s1, // stride height
+ int s2, // stride depth
+ int p0, // padding width
+ int p1, // padding height
+ int p2, // padding depth
+ int d0, // dilation width
+ int d1, // dilation height
+ int d2, // dilation depth
+ enum ggml_type dst_type);
+
+ // a: [OC*IC, KD, KH, KW]
+ // b: [N*IC, ID, IH, IW]
+ // result: [N*OC, OD, OH, OW]
+ GGML_API struct ggml_tensor * ggml_conv_3d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int64_t IC,
+ int s0, // stride width
+ int s1, // stride height
+ int s2, // stride depth
+ int p0, // padding width
+ int p1, // padding height
+ int p2, // padding depth
+ int d0, // dilation width
+ int d1, // dilation height
+ int d2 // dilation depth
+ );
+
// kernel size is a->ne[0] x a->ne[1]
// stride is equal to kernel size
// padding is zero
@@ -1966,7 +2002,7 @@ extern "C" {
int d0, // dilation dimension 0
int d1); // dilation dimension 1
- GGML_API struct ggml_tensor * ggml_conv_3d(
+ GGML_API struct ggml_tensor * ggml_conv_3d_direct(
struct ggml_context * ctx,
struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC]
struct ggml_tensor * b, // input [W, H, D, C * N]
@@ -2073,6 +2109,19 @@ extern "C" {
int p2,
int p3);
+ GGML_API struct ggml_tensor * ggml_pad_ext(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int lp0,
+ int rp0,
+ int lp1,
+ int rp1,
+ int lp2,
+ int rp2,
+ int lp3,
+ int rp3
+ );
+
// pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
struct ggml_context * ctx,
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index d650b4dd7..d31c700b2 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2660,6 +2660,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_im2col_back_f32(params, tensor);
} break;
+ case GGML_OP_IM2COL_3D:
+ {
+ ggml_compute_forward_im2col_3d(params, tensor);
+ } break;
case GGML_OP_CONV_2D:
{
ggml_compute_forward_conv_2d(params, tensor);
@@ -3080,6 +3084,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
} break;
case GGML_OP_IM2COL:
case GGML_OP_IM2COL_BACK:
+ case GGML_OP_IM2COL_3D:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_3D:
case GGML_OP_CONV_2D_DW:
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 8c1f79488..0bb767e01 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -7027,6 +7027,209 @@ void ggml_compute_forward_im2col_back_f32(
}
}
+
+// ggml_compute_forward_im2col_3d_f16
+// src0: kernel [OC*IC, KD, KH, KW]
+// src1: image [N*IC, ID, IH, IW]
+// dst: result [N*OD, OH, OW, IC * KD * KH * KW]
+static void ggml_compute_forward_im2col_3d_f16(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+ const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+ const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t N = ne13 / IC;
+ const int64_t ID = ne12;
+ const int64_t IH = ne11;
+ const int64_t IW = ne10;
+
+ const int64_t OC = ne03 / IC;
+ GGML_UNUSED(OC);
+ const int64_t KD = ne02;
+ const int64_t KH = ne01;
+ const int64_t KW = ne00;
+
+ const int64_t OD = ne3 / N;
+ const int64_t OH = ne2;
+ const int64_t OW = ne1;
+ const int64_t OH_OW = OH*OW;
+ const int64_t KD_KH_KW = KD*KH*KW;
+ const int64_t KH_KW = KH*KW;
+ const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
+
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+ {
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
+
+ for (int64_t in = 0; in < N; in++) {
+ for (int64_t iod = 0; iod < OD; iod++) {
+ for (int64_t ioh = 0; ioh < OH; ioh++) {
+ for (int64_t iow = 0; iow < OW; iow++) {
+ for (int64_t iic = ith; iic < IC; iic += nth) {
+
+ // micro kernel
+ ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+ const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+ for (int64_t ikd = 0; ikd < KD; ikd++) {
+ for (int64_t ikh = 0; ikh < KH; ikh++) {
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
+ const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+ if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
+ } else {
+ const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+// ggml_compute_forward_im2col_3d_f32
+// src0: kernel [OC*IC, KD, KH, KW]
+// src1: image [N*IC, ID, IH, IW]
+// dst: result [N*OD, OH, OW, IC * KD * KH * KW]
+static void ggml_compute_forward_im2col_3d_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+ const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+ const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t N = ne13 / IC;
+ const int64_t ID = ne12;
+ const int64_t IH = ne11;
+ const int64_t IW = ne10;
+
+ const int64_t OC = ne03 / IC;
+ GGML_UNUSED(OC);
+ const int64_t KD = ne02;
+ const int64_t KH = ne01;
+ const int64_t KW = ne00;
+
+ const int64_t OD = ne3 / N;
+ const int64_t OH = ne2;
+ const int64_t OW = ne1;
+
+ const int64_t OH_OW = OH*OW;
+ const int64_t KD_KH_KW = KD*KH*KW;
+ const int64_t KH_KW = KH*KW;
+ const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
+
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+ {
+ float * const wdata = (float *) dst->data;
+
+ for (int64_t in = 0; in < N; in++) {
+ for (int64_t iod = 0; iod < OD; iod++) {
+ for (int64_t ioh = 0; ioh < OH; ioh++) {
+ for (int64_t iow = 0; iow < OW; iow++) {
+ for (int64_t iic = ith; iic < IC; iic += nth) {
+
+ // micro kernel
+ float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+ const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+ for (int64_t ikd = 0; ikd < KD; ikd++) {
+ for (int64_t ikh = 0; ikh < KH; ikh++) {
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
+ const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+ if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
+ } else {
+ const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+
+void ggml_compute_forward_im2col_3d(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+ switch (dst->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_im2col_3d_f16(params, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_im2col_3d_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
void * a, void * b, float * c) {
const ggml_type_traits * traits = ggml_get_type_traits(type);
@@ -8014,6 +8217,15 @@ static void ggml_compute_forward_pad_f32(
GGML_TENSOR_UNARY_OP_LOCALS
float * dst_ptr = (float *) dst->data;
+ const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
+ const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
+ const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
+ const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
+ const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
+ const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
+ const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
+ const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
+
// TODO: optimize
@@ -8022,10 +8234,12 @@ static void ggml_compute_forward_pad_f32(
for (int64_t i0 = 0; i0 < ne0; ++i0) {
for (int64_t i3 = 0; i3 < ne3; ++i3) {
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-
- const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-
- if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ if ((i0 >= lp0 && i0 < ne0 - rp0) \
+ && (i1 >= lp1 && i1 < ne1 - rp1) \
+ && (i2 >= lp2 && i2 < ne2 - rp2) \
+ && (i3 >= lp3 && i3 < ne3 - rp3)) {
+ const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
+ const float * src_ptr = (const float *)((char *) src0->data + src_idx);
dst_ptr[dst_idx] = *src_ptr;
} else {
dst_ptr[dst_idx] = 0;
diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h
index d0ea83843..9824a03b4 100644
--- a/ggml/src/ggml-cpu/ops.h
+++ b/ggml/src/ggml-cpu/ops.h
@@ -69,6 +69,7 @@ void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struc
void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_im2col_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index fbe7afad6..f78ccb1a9 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -568,6 +568,38 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#endif // CUDART_VERSION >= 12050
}
+// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
+// Precompute mp (m' in the paper) and L such that division
+// can be computed using a multiply (high 32b of 64b result)
+// and a shift:
+//
+// n/d = (mulhi(n, mp) + n) >> L;
+static const uint3 init_fastdiv_values(uint32_t d) {
+ // compute L = ceil(log2(d));
+ uint32_t L = 0;
+ while (L < 32 && (uint32_t{ 1 } << L) < d) {
+ L++;
+ }
+
+ uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
+ // pack divisor as well to reduce error surface
+ return make_uint3(mp, L, d);
+}
+
+static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fastdiv_values) {
+ // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z>
+ // fastdiv_values.z is unused and optimized away by the compiler.
+ // Compute high 32 bits of n * mp
+ const uint32_t hi = __umulhi(n, fastdiv_values.x);
+ // add n, apply bit shift
+ return (hi + n) >> fastdiv_values.y;
+}
+
+static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fastdiv_values) {
+ // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
+ return n - fastdiv(n, fastdiv_values) * fastdiv_values.z;
+}
+
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);
static __device__ __forceinline__ float get_alibi_slope(
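
For reference, a host-side sketch of the magic-number division that init_fastdiv_values/fastdiv above implement (per the cited divcnst-pldi94 figure 4.1). The helper names are illustrative and __umulhi is replaced by a portable 64-bit multiply; it only checks the identity n/d == (mulhi(n, mp) + n) >> L for a few divisors.

    #include <cassert>
    #include <cstdint>
    #include <initializer_list>

    // same precomputation as init_fastdiv_values, on the host
    static void init_fastdiv(uint32_t d, uint32_t & mp, uint32_t & L) {
        L = 0;
        while (L < 32 && (uint32_t{ 1 } << L) < d) {
            L++;
        }
        mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
    }

    // stand-in for __umulhi: high 32 bits of the 64-bit product
    static uint32_t mulhi(uint32_t a, uint32_t b) {
        return (uint32_t) (((uint64_t) a * b) >> 32);
    }

    int main() {
        for (uint32_t d : {1u, 3u, 7u, 255u, 4096u}) {
            uint32_t mp, L;
            init_fastdiv(d, mp, L);
            for (uint32_t n : {0u, 1u, d - 1, d, 12345678u}) {
                assert(((mulhi(n, mp) + n) >> L) == n / d); // n/d = (mulhi(n, mp) + n) >> L
            }
        }
    }
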
diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu
index 3ec0e957a..83d02474f 100644
--- a/ggml/src/ggml-cuda/getrows.cu
+++ b/ggml/src/ggml-cuda/getrows.cu
@@ -2,6 +2,8 @@
#include "dequantize.cuh"
#include "convert.cuh"
+#define MAX_GRIDDIM_Y 65535
+
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void k_get_rows(
const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
@@ -11,32 +13,29 @@ static __global__ void k_get_rows(
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
- // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
- const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2;
- const int i10 = blockIdx.x;
- const int i11 = blockIdx.z / ne12;
- const int i12 = blockIdx.z % ne12;
+ for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
+ // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
+ const int i10 = blockIdx.x;
+ const int i11 = blockIdx.z / ne12;
+ const int i12 = blockIdx.z % ne12;
- if (i00 >= ne00) {
- return;
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+ const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
+
+ const int ib = i00/qk; // block index
+ const int iqs = (i00%qk)/qr; // quant index
+ const int iybs = i00 - i00%qk; // dst block start index
+ const int y_offset = qr == 1 ? 1 : qk/2;
+
+ // dequantize
+ float2 v;
+ dequantize_kernel(src0_row, ib, iqs, v);
+
+ dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);
+ dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
}
-
- const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
- dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
- const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
-
- const int ib = i00/qk; // block index
- const int iqs = (i00%qk)/qr; // quant index
- const int iybs = i00 - i00%qk; // dst block start index
- const int y_offset = qr == 1 ? 1 : qk/2;
-
- // dequantize
- float2 v;
- dequantize_kernel(src0_row, ib, iqs, v);
-
- dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);
- dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
}
template <typename src0_t, typename dst_t>
@@ -48,22 +47,23 @@ static __global__ void k_get_rows_float(
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
- // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
- const int i00 = blockIdx.y * blockDim.x + threadIdx.x;
- const int i10 = blockIdx.x;
- const int i11 = blockIdx.z / ne12;
- const int i12 = blockIdx.z % ne12;
+ for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
+ // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
+ const int i10 = blockIdx.x;
+ const int i11 = blockIdx.z / ne12;
+ const int i12 = blockIdx.z % ne12;
- if (i00 >= ne00) {
- return;
+ if (i00 >= ne00) {
+ return;
+ }
+
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+ const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
+
+ dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
}
-
- const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
- dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
- const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
-
- dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
}
template
@@ -98,7 +98,7 @@ static void get_rows_cuda_q(
cudaStream_t stream) {
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
- const dim3 block_nums(ne10, block_num_y, ne11*ne12);
+ const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12);
// strides in elements
// const size_t s0 = nb0 / sizeof(dst_t);
@@ -131,7 +131,7 @@ static void get_rows_cuda_float(
cudaStream_t stream) {
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
- const dim3 block_nums(ne10, block_num_y, ne11*ne12);
+ const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12);
// strides in elements
// const size_t s0 = nb0 / sizeof(dst_t);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index a08117bfe..d9bd4a9f1 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2465,6 +2465,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_IM2COL:
ggml_cuda_op_im2col(ctx, dst);
break;
+ case GGML_OP_IM2COL_3D:
+ ggml_cuda_op_im2col_3d(ctx, dst);
+ break;
case GGML_OP_CONV_2D:
ggml_cuda_op_conv2d(ctx, dst);
break;
@@ -3572,6 +3575,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
}
case GGML_OP_IM2COL:
+ case GGML_OP_IM2COL_3D:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:
diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu
index 16bb9bec9..7737d6a5d 100644
--- a/ggml/src/ggml-cuda/im2col.cu
+++ b/ggml/src/ggml-cuda/im2col.cu
@@ -112,3 +112,132 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
}
}
+
+// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+template <typename T>
+static __global__ void im2col_3d_kernel(
+ const float * src, T * dst,
+ int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
+ int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+ int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW,
+ int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW,
+ int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH,
+ int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) {
+ const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
+ if (i >= IC_KD_KH_KW) {
+ return;
+ }
+
+ const int64_t iic = i / KD_KH_KW;
+ const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW;
+ const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
+ const int64_t ikw = i % KW;
+
+ const int64_t iow = blockIdx.y;
+ for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) {
+ const int64_t in = iz / OD_OH;
+ const int64_t iod = (iz - in*OD_OH) / OH;
+ const int64_t ioh = iz % OH;
+
+ const int64_t iiw = iow * s0 + ikw * d0 - p0;
+ const int64_t iih = ioh * s1 + ikh * d1 - p1;
+ const int64_t iid = iod * s2 + ikd * d2 - p2;
+
+ const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
+
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
+ dst[offset_dst] = 0.0f;
+ } else {
+ const int64_t offset_src = in*IC_ID_IH_IW + iic*ID_IH_IW + iid*IH_IW + iih*IW + iiw;
+ dst[offset_dst] = src[offset_src];
+ }
+ }
+}
+
+// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+template <typename T>
+static void im2col_3d_cuda(const float * src, T* dst,
+ int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
+ int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+ int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
+ const int64_t OH_OW = OH*OW;
+ const int64_t KD_KH_KW = KD*KH*KW;
+ const int64_t ID_IH_IW = ID*IH*IW;
+ const int64_t KH_KW = KH*KW;
+ const int64_t IH_IW = IH*IW;
+ const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
+ const int64_t OW_KD_KH_KW = OW*KD*KH*KW;
+ const int64_t N_OD_OH = N*OD*OH;
+ const int64_t OD_OH = OD*OH;
+ const int64_t IC_ID_IH_IW = IC*ID*IH*IW;
+ const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
+ const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
+ const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
+ const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
+ dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z));
+ im2col_3d_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
+ OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
+ IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,
+ OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH,
+ s0, s1, s2, p0, p1, p2, d0, d1, d2);
+}
+
+static void im2col_3d_cuda_f16(const float * src, half * dst,
+ int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
+ int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+ int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
+
+ im2col_3d_cuda(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+}
+
+static void im2col_3d_cuda_f32(const float * src, float * dst,
+ int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
+ int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+ int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
+
+ im2col_3d_cuda(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+}
+
+void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const float * src1_d = (const float *)src1->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+ const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+ const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+ const int64_t N = ne13 / IC;
+ const int64_t ID = ne12;
+ const int64_t IH = ne11;
+ const int64_t IW = ne10;
+
+ const int64_t OC = ne03 / IC;
+ const int64_t KD = ne02;
+ const int64_t KH = ne01;
+ const int64_t KW = ne00;
+
+ const int64_t OD = ne3 / N;
+ const int64_t OH = ne2;
+ const int64_t OW = ne1;
+
+ if(dst->type == GGML_TYPE_F16) {
+ im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+ } else {
+ im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+ }
+}
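
For reference, the flattened destination index that both the CPU and CUDA im2col_3d paths compute, written out as a standalone helper. The function name is illustrative; the [N*OD, OH, OW, IC*KD*KH*KW] layout is the one documented in the comments above.

    // dst layout: [N*OD, OH, OW, IC*KD*KH*KW], row-major over these axes.
    // Mirrors offset_dst in im2col_3d_kernel and the dst_data indexing in the CPU path.
    static inline int64_t im2col_3d_dst_index(
            int64_t in, int64_t iod, int64_t ioh, int64_t iow,  // output position
            int64_t iic, int64_t ikd, int64_t ikh, int64_t ikw, // kernel tap
            int64_t OD, int64_t OH, int64_t OW,
            int64_t IC, int64_t KD, int64_t KH, int64_t KW) {
        const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
        return ((in*OD + iod)*OH*OW + ioh*OW + iow)*IC_KD_KH_KW
             + iic*KD*KH*KW + ikd*KH*KW + ikh*KW + ikw;
    }
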
diff --git a/ggml/src/ggml-cuda/im2col.cuh b/ggml/src/ggml-cuda/im2col.cuh
index 1ce8fae4d..2da1223d6 100644
--- a/ggml/src/ggml-cuda/im2col.cuh
+++ b/ggml/src/ggml-cuda/im2col.cuh
@@ -3,3 +3,4 @@
#define CUDA_IM2COL_BLOCK_SIZE 256
void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
index d5157d958..4f153c571 100644
--- a/ggml/src/ggml-cuda/norm.cu
+++ b/ggml/src/ggml-cuda/norm.cu
@@ -105,29 +105,29 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
}
template <int block_size, bool do_multiply = false, bool do_add = false>
-static __global__ void rms_norm_f32(const float * x, float * dst,
+static __global__ void rms_norm_f32(const float * x,
+ float * dst,
const int ncols,
const int64_t stride_row,
const int64_t stride_channel,
const int64_t stride_sample,
const float eps,
- const float * mul = nullptr,
- const int64_t mul_stride_row = 0,
- const int64_t mul_stride_channel = 0,
- const int64_t mul_stride_sample = 0,
- const int mul_ncols = 0,
- const int mul_nrows = 0,
- const int mul_nchannels = 0,
- const int mul_nsamples = 0,
- const float * add = nullptr,
- const int64_t add_stride_row = 0,
- const int64_t add_stride_channel = 0,
- const int64_t add_stride_sample = 0,
- const int add_ncols = 0,
- const int add_nrows = 0,
- const int add_nchannels = 0,
- const int add_nsamples = 0) {
-
+ const float * mul = nullptr,
+ const int64_t mul_stride_row = 0,
+ const int64_t mul_stride_channel = 0,
+ const int64_t mul_stride_sample = 0,
+ const uint3 mul_ncols_packed = make_uint3(0, 0, 0),
+ const uint3 mul_nrows_packed = make_uint3(0, 0, 0),
+ const uint3 mul_nchannels_packed = make_uint3(0, 0, 0),
+ const uint3 mul_nsamples_packed = make_uint3(0, 0, 0),
+ const float * add = nullptr,
+ const int64_t add_stride_row = 0,
+ const int64_t add_stride_channel = 0,
+ const int64_t add_stride_sample = 0,
+ const uint3 add_ncols_packed = make_uint3(0, 0, 0),
+ const uint3 add_nrows_packed = make_uint3(0, 0, 0),
+ const uint3 add_nchannels_packed = make_uint3(0, 0, 0),
+ const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
@@ -142,16 +142,16 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
if constexpr (do_multiply) {
- const int mul_row = row % mul_nrows;
- const int mul_channel = channel % mul_nchannels;
- const int mul_sample = sample % mul_nsamples;
- mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
+ const uint32_t mul_row = fastmodulo(row, mul_nrows_packed);
+ const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);
+ const uint32_t mul_sample = fastmodulo(sample, mul_nsamples_packed);
+ mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
}
if constexpr (do_add) {
- const int add_row = row % add_nrows;
- const int add_channel = channel % add_nchannels;
- const int add_sample = sample % add_nsamples;
+ const int add_row = fastmodulo(row, add_nrows_packed);
+ const int add_channel = fastmodulo(channel, add_nchannels_packed);
+ const int add_sample = fastmodulo(sample, add_nsamples_packed);
add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
}
@@ -165,15 +165,18 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
// sum up partial sums
tmp = warp_reduce_sum(tmp);
if constexpr (block_size > WARP_SIZE) {
- static_assert(block_size == 1024, "unexpected block_size");
+ static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
__shared__ float s_sum[32];
- const int warp_id = threadIdx.x / WARP_SIZE;
- const int lane_id = threadIdx.x % WARP_SIZE;
+ const int warp_id = tid / WARP_SIZE;
+ const int lane_id = tid % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
- tmp = s_sum[lane_id];
+ tmp = 0.0f;
+ if (lane_id < (block_size / WARP_SIZE)) {
+ tmp = s_sum[lane_id];
+ }
tmp = warp_reduce_sum(tmp);
}
@@ -182,12 +185,12 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
for (int col = tid; col < ncols; col += block_size) {
if constexpr (do_multiply && do_add) {
- const int mul_col = col % mul_ncols;
- const int add_col = col % add_ncols;
- dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
+ const int mul_col = fastmodulo(col, mul_ncols_packed);
+ const int add_col = fastmodulo(col, add_ncols_packed);
+ dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
} else if constexpr (do_multiply) {
- const int mul_col = col % mul_ncols;
- dst[col] = scale * x[col] * mul[mul_col];
+ const int mul_col = fastmodulo(col, mul_ncols_packed);
+ dst[col] = scale * x[col] * mul[mul_col];
} else {
dst[col] = scale * x[col];
}
@@ -354,77 +357,86 @@ static void rms_norm_f32_cuda(
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
- const dim3 block_dims(WARP_SIZE, 1, 1);
- rms_norm_f32<WARP_SIZE, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ const dim3 block_dims(256, 1, 1);
+ rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
-static void rms_norm_mul_f32_cuda(const float * x,
- const float * mul,
- const float * add,
- float * dst,
- const int ncols,
- const int nrows,
- const int nchannels,
- const int nsamples,
- const int64_t stride_row,
- const int64_t stride_channel,
- const int64_t stride_sample,
- const int64_t mul_stride_row,
- const int64_t mul_stride_channel,
- const int64_t mul_stride_sample,
- const int mul_ncols,
- const int mul_nrows,
- const int mul_nchannels,
- const int mul_nsamples,
- const int64_t add_stride_row,
- const int64_t add_stride_channel,
- const int64_t add_stride_sample,
- const int add_ncols,
- const int add_nrows,
- const int add_nchannels,
- const int add_nsamples,
- const float eps,
- cudaStream_t stream) {
+static void rms_norm_mul_f32_cuda(const float * x,
+ const float * mul,
+ const float * add,
+ float * dst,
+ const int ncols,
+ const int nrows,
+ const int nchannels,
+ const int nsamples,
+ const int64_t stride_row,
+ const int64_t stride_channel,
+ const int64_t stride_sample,
+ const int64_t mul_stride_row,
+ const int64_t mul_stride_channel,
+ const int64_t mul_stride_sample,
+ const uint32_t mul_ncols,
+ const uint32_t mul_nrows,
+ const uint32_t mul_nchannels,
+ const uint32_t mul_nsamples,
+ const int64_t add_stride_row,
+ const int64_t add_stride_channel,
+ const int64_t add_stride_sample,
+ const uint32_t add_ncols,
+ const uint32_t add_nrows,
+ const uint32_t add_nchannels,
+ const uint32_t add_nsamples,
+ const float eps,
+ cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (mul == nullptr) {
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
return;
}
if (add == nullptr) {
+ const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
+ const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
+ const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
+ const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
if (ncols < 1024) {
- const dim3 block_dims(WARP_SIZE, 1, 1);
- rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
- ncols, stride_row, stride_channel, stride_sample, eps,
- mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
- mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+ const dim3 block_dims(256, 1, 1);
+ rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
} else {
const dim3 block_dims(1024, 1, 1);
- rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
- ncols, stride_row, stride_channel, stride_sample, eps,
- mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
- mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+ rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
}
} else {
+ const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
+ const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
+ const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
+ const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
+
+ const uint3 add_ncols_packed = init_fastdiv_values(add_ncols);
+ const uint3 add_nrows_packed = init_fastdiv_values(add_nrows);
+ const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels);
+ const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
if (ncols < 1024) {
- const dim3 block_dims(WARP_SIZE, 1, 1);
- rms_norm_f32<WARP_SIZE, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
- ncols, stride_row, stride_channel, stride_sample, eps,
- mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
- mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
- add, add_stride_row, add_stride_channel, add_stride_sample,
- add_ncols, add_nrows, add_nchannels, add_nsamples);
+ const dim3 block_dims(256, 1, 1);
+ rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
+ add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
+ add_nchannels_packed, add_nsamples_packed);
} else {
const dim3 block_dims(1024, 1, 1);
- rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
- ncols, stride_row, stride_channel, stride_sample, eps,
- mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
- mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
- add, add_stride_row, add_stride_channel, add_stride_sample,
- add_ncols, add_nrows, add_nchannels, add_nsamples);
+ rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
+ add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
+ add_nchannels_packed, add_nsamples_packed);
}
}
}
diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu
index 77432b046..29aef33c1 100644
--- a/ggml/src/ggml-cuda/pad.cu
+++ b/ggml/src/ggml-cuda/pad.cu
@@ -1,36 +1,50 @@
#include "pad.cuh"
-static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
- // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
- // blockIdx.y: idx of ne1
- // blockIDx.x: idx of ne0 / BLOCK_SIZE
- int nidx = threadIdx.x + blockIdx.x * blockDim.x;
- if (nidx >= ne0) {
+static __global__ void pad_f32(const float * src, float * dst,
+ const int lp0, const int rp0, const int lp1, const int rp1,
+ const int lp2, const int rp2, const int lp3, const int rp3,
+ const int ne0, const int ne1, const int ne2, const int ne3) {
+ // blockIdx.z: i3*ne2+i2
+ // blockIdx.y: i1
+ // blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE
+ // gridDim.y: ne1
+ int i0 = threadIdx.x + blockIdx.x * blockDim.x;
+ int i1 = blockIdx.y;
+ int i2 = blockIdx.z % ne2;
+ int i3 = blockIdx.z / ne2;
+ if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}
// operation
- int offset_dst =
- nidx +
- blockIdx.y * ne0 +
- blockIdx.z * ne0 * gridDim.y;
- if (nidx < ne00 && blockIdx.y < (unsigned)ne01 && blockIdx.z < (unsigned)(ne02*ne03)) {
- int offset_src =
- nidx +
- blockIdx.y * ne00 +
- blockIdx.z * ne00 * ne01;
- dst[offset_dst] = x[offset_src];
+ const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+ if ((i0 >= lp0 && i0 < ne0 - rp0) &&
+ (i1 >= lp1 && i1 < ne1 - rp1) &&
+ (i2 >= lp2 && i2 < ne2 - rp2) &&
+ (i3 >= lp3 && i3 < ne3 - rp3)) {
+ const int64_t i00 = i0 - lp0;
+ const int64_t i01 = i1 - lp1;
+ const int64_t i02 = i2 - lp2;
+ const int64_t i03 = i3 - lp3;
+ const int64_t ne02 = ne2 - lp2 - rp2;
+ const int64_t ne01 = ne1 - lp1 - rp1;
+ const int64_t ne00 = ne0 - lp0 - rp0;
+
+ const int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00;
+
+ dst[dst_idx] = src[src_idx];
} else {
- dst[offset_dst] = 0.0f;
+ dst[dst_idx] = 0.0f;
}
}
-static void pad_f32_cuda(const float * x, float * dst,
- const int ne00, const int ne01, const int ne02, const int ne03,
+static void pad_f32_cuda(const float * src, float * dst,
+ const int lp0, const int rp0, const int lp1, const int rp1,
+ const int lp2, const int rp2, const int lp3, const int rp3,
const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne1, ne2*ne3);
- pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
+ pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3);
}
void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -41,9 +55,18 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
- GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
+ const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
+ const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
+ const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
+ const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
+ const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
+ const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
+ const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
pad_f32_cuda(src0_d, dst_d,
- src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
- dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+ lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
+ dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
}
diff --git a/ggml/src/ggml-cuda/scale.cu b/ggml/src/ggml-cuda/scale.cu
index 2ee9e5889..0ddeff6a1 100644
--- a/ggml/src/ggml-cuda/scale.cu
+++ b/ggml/src/ggml-cuda/scale.cu
@@ -1,18 +1,19 @@
#include "scale.cuh"
-static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) {
- const int i = blockDim.x*blockIdx.x + threadIdx.x;
+#define MAX_GRIDDIM_X 0x7FFFFFFF
- if (i >= k) {
- return;
+static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) {
+ int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
+ int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x;
+
+ for (int64_t i = tid; i < nelements; i += stride) {
+ dst[i] = scale * x[i] + bias;
}
-
- dst[i] = scale * x[i] + bias;
}
-static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) {
- const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
- scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k);
+static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) {
+ const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
+ scale_f32<<<MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, nelements);
}
void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 749ba09a7..132feb15b 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -407,6 +407,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
@@ -1439,6 +1440,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10, mul_mm_id_map0_f16_ne20_10, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
@@ -1886,7 +1888,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
case GGML_OP_POOL_2D:
+ return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_PAD:
+ return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
+ (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
@@ -3976,6 +3981,7 @@ static int ggml_metal_encode_node(
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
+ case 10: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10].pipeline; break;
case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
}
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 9c5933d24..2d56c6267 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -7618,6 +7618,7 @@ template [[host_name("kernel_mul_mm_id_map0_f16_ne20_2" )]] kernel kernel_mul_mm
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 1ce8fbf04..9291d515f 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -545,6 +545,8 @@ struct vk_device_struct {
vk_pipeline pipeline_relu[2];
vk_pipeline pipeline_tanh[2];
vk_pipeline pipeline_sigmoid[2];
+ vk_pipeline pipeline_hardsigmoid[2];
+ vk_pipeline pipeline_hardswish[2];
vk_pipeline pipeline_geglu[2];
vk_pipeline pipeline_reglu[2];
@@ -2356,7 +2358,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
std::vector<std::future<void>> compiles;
- auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
+ auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
@@ -2393,6 +2395,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
};
+ auto const &ggml_vk_create_pipeline2 = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const char *entrypoint,
+ uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
+ uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
+ return ggml_vk_create_pipeline(device, pipeline, name.c_str(), spv_size, spv_data, entrypoint,
+ parameter_count, push_constant_size, wg_denoms, specialization_constants,
+ align, disable_robustness, require_full_subgroups, required_subgroup_size);
+ };
+
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
};
@@ -2943,9 +2953,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
// Ensure a subgroup size >= 16 is available
- const bool use_subgroups16 = use_subgroups &&
- (!device->subgroup_size_control && device->subgroup_size >= 16 ||
- device->subgroup_size_control && device->subgroup_min_size <= 16 && device->subgroup_max_size >= 16);
+ const bool use_subgroups16 = use_subgroups && subgroup_min_size_16;
const uint32_t subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control && device->subgroup_min_size <= 16 && device->subgroup_max_size >= 16) ? 16 : device->subgroup_size;
const uint32_t subgroup_size16 = std::max(subgroup_size, 16u);
@@ -3128,9 +3136,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
- ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true);
+ ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true);
} else {
- ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
+ ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
}
}
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 12 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
@@ -3214,7 +3222,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
bool rte = device->float_controls_rte_fp16;
#define CREATE_BINARY(name, namemod, spec, bindings) \
for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
- ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
+ ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
"main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
@@ -3232,8 +3240,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (device->multi_add) {
for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) {
- ggml_vk_create_pipeline(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_multi_add_rms[i], "multi_add_rms_f32_" + std::to_string(i+1), multi_add_rms_f32_len, multi_add_rms_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
+ ggml_vk_create_pipeline2(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
+ ggml_vk_create_pipeline2(device, device->pipeline_multi_add_rms[i], "multi_add_rms_f32_" + std::to_string(i+1), multi_add_rms_f32_len, multi_add_rms_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
}
}
@@ -3277,6 +3285,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_UNARY(relu)
CREATE_UNARY(tanh)
CREATE_UNARY(sigmoid)
+ CREATE_UNARY(hardsigmoid)
+ CREATE_UNARY(hardswish)
#undef CREATE_UNARY
#define CREATE_GLU(name) \
@@ -3325,7 +3335,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
- ggml_vk_create_pipeline(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i}, 1);
+ ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i}, 1);
}
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
@@ -4297,7 +4307,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
}
}
-static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
+static bool ggml_vk_instance_validation_ext_available();
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
@@ -4318,7 +4328,7 @@ static void ggml_vk_instance_init() {
vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version };
const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
- const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions);
+ const bool validation_ext = ggml_vk_instance_validation_ext_available();
#ifdef __APPLE__
const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
#endif
@@ -7563,6 +7573,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_SIGMOID:
return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16];
+ case GGML_UNARY_OP_HARDSIGMOID:
+ return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16];
+ case GGML_UNARY_OP_HARDSWISH:
+ return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
default:
break;
}
@@ -10231,6 +10245,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_SIGMOID:
+ case GGML_UNARY_OP_HARDSIGMOID:
+ case GGML_UNARY_OP_HARDSWISH:
break;
default:
return false;
@@ -10601,6 +10617,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_SIGMOID:
+ case GGML_UNARY_OP_HARDSIGMOID:
+ case GGML_UNARY_OP_HARDSWISH:
ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
break;
default:
@@ -10843,6 +10861,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_SIGMOID:
+ case GGML_UNARY_OP_HARDSIGMOID:
+ case GGML_UNARY_OP_HARDSWISH:
buf = tensor->buffer;
break;
default:
@@ -11794,6 +11814,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_SIGMOID:
+ case GGML_UNARY_OP_HARDSIGMOID:
+ case GGML_UNARY_OP_HARDSWISH:
return ggml_is_contiguous(op->src[0]) &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -12084,7 +12106,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_ACC:
case GGML_OP_CONCAT:
case GGML_OP_SCALE:
+ return true;
case GGML_OP_PAD:
+ return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
+ (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
case GGML_OP_ROLL:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
@@ -12226,22 +12251,23 @@ ggml_backend_reg_t ggml_backend_vk_reg() {
}
// Extension availability
-static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
+static bool ggml_vk_instance_validation_ext_available() {
#ifdef GGML_VULKAN_VALIDATE
- bool portability_enumeration_ext = false;
- // Check for portability enumeration extension for MoltenVK support
- for (const auto& properties : instance_extensions) {
- if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
- return true;
+ // Check if validation layer provides the extension
+ const std::string layer_name = "VK_LAYER_KHRONOS_validation";
+ for (const auto& layer : vk::enumerateInstanceLayerProperties()) {
+ if (layer_name == layer.layerName.data()) {
+ for (const auto& ext : vk::enumerateInstanceExtensionProperties(layer_name)) {
+ if (strcmp("VK_EXT_validation_features", ext.extensionName.data()) == 0) {
+ return true;
+ }
+ }
}
}
- if (!portability_enumeration_ext) {
- std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
- }
+
+ std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_validation_features not found." << std::endl;
#endif
return false;
-
- UNUSED(instance_extensions);
}
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
#ifdef __APPLE__
@@ -12610,6 +12636,12 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_UNARY_OP_SIGMOID:
tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
break;
+ case GGML_UNARY_OP_HARDSIGMOID:
+ tensor_clone = ggml_hardsigmoid(ggml_ctx, src_clone[0]);
+ break;
+ case GGML_UNARY_OP_HARDSWISH:
+ tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]);
+ break;
default:
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
GGML_ABORT("fatal error");
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp b/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp
new file mode 100644
index 000000000..1da252cc6
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp
@@ -0,0 +1,22 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+ if (i >= p.KX) {
+ return;
+ }
+
+ const float x = float(data_a[i]);
+ data_d[i] = D_TYPE(min(1.0f, max(0.0f, (x + 3.0f) / 6.0f)));
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp b/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp
new file mode 100644
index 000000000..3afc58827
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp
@@ -0,0 +1,22 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+ if (i >= p.KX) {
+ return;
+ }
+
+ const float x = float(data_a[i]);
+ data_d[i] = D_TYPE(x * min(1.0f, max(0.0f, (x + 3.0f) / 6.0f)));
+}
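Both new shaders are simple element-wise kernels. For reference, this is the scalar math they implement (an illustrative C++ sketch, not part of the patch):

    #include <algorithm>

    // hardsigmoid(x) = clamp((x + 3) / 6, 0, 1);  hardswish(x) = x * hardsigmoid(x)
    static float ref_hardsigmoid(float x) { return std::min(1.0f, std::max(0.0f, (x + 3.0f) / 6.0f)); }
    static float ref_hardswish  (float x) { return x * ref_hardsigmoid(x); }
    // e.g. ref_hardsigmoid(-3) == 0, ref_hardsigmoid(0) == 0.5, ref_hardsigmoid(3) == 1,
    //      ref_hardswish(-3)   == 0, ref_hardswish(3)   == 3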
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index 443c2066a..da9d4743d 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -671,6 +671,10 @@ void process_shaders() {
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("hardsigmoid_f16","hardsigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
for (auto rte : {false, true}) {
std::string suffix = rte ? "_rte" : "";
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index bd50f482d..f23ccb200 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -990,6 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CONV_TRANSPOSE_1D",
"IM2COL",
"IM2COL_BACK",
+ "IM2COL_3D",
"CONV_2D",
"CONV_3D",
"CONV_2D_DW",
@@ -1034,7 +1035,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};
-static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
+static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -1093,6 +1094,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"conv_transpose_1d(x)",
"im2col(x)",
"im2col_back(x)",
+ "im2col_3d(x)",
"conv_2d(x)",
"conv_3d(x)",
"conv_2d_dw(x)",
@@ -1137,7 +1139,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)",
};
-static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
+static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -4377,6 +4379,91 @@ struct ggml_tensor * ggml_conv_2d(
return result;
}
+// a: [OC*IC, KD, KH, KW]
+// b: [N*IC, ID, IH, IW]
+// result: [N*OD, OH, OW, IC * KD * KH * KW]
+struct ggml_tensor * ggml_im2col_3d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int64_t IC,
+ int s0, // stride width
+ int s1, // stride height
+ int s2, // stride depth
+ int p0, // padding width
+ int p1, // padding height
+ int p2, // padding depth
+ int d0, // dilation width
+ int d1, // dilation height
+ int d2, // dilation depth
+ enum ggml_type dst_type) {
+ const int64_t N = b->ne[3] / IC;
+ const int64_t ID = b->ne[2];
+ const int64_t IH = b->ne[1];
+ const int64_t IW = b->ne[0];
+
+ const int64_t OC = a->ne[3] / IC;
+ UNUSED(OC);
+ const int64_t KD = a->ne[2];
+ const int64_t KH = a->ne[1];
+ const int64_t KW = a->ne[0];
+ const int64_t OD = ggml_calc_conv_output_size(ID, KD, s2, p2, d2);
+ const int64_t OH = ggml_calc_conv_output_size(IH, KH, s1, p1, d1);
+ const int64_t OW = ggml_calc_conv_output_size(IW, KW, s0, p0, d0);
+
+ GGML_ASSERT((OD > 0) && "b too small compared to a");
+ GGML_ASSERT((OH > 0) && "b too small compared to a");
+ GGML_ASSERT((OW > 0) && "b too small compared to a");
+
+
+ const int64_t ne[4] = {KW*KH*KD*IC, OW, OH, OD*N};
+
+ struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
+ int32_t params[] = { s0, s1, s2, p0, p1, p2, d0, d1, d2, (int32_t)IC};
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_IM2COL_3D;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+// a: [OC*IC, KD, KH, KW]
+// b: [N*IC, ID, IH, IW]
+// result: [N*OC, OD, OH, OW]
+struct ggml_tensor * ggml_conv_3d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int64_t IC,
+ int s0, // stride width
+ int s1, // stride height
+ int s2, // stride depth
+ int p0, // padding width
+ int p1, // padding height
+ int p2, // padding depth
+ int d0, // dilation width
+ int d1, // dilation height
+ int d2 // dilation depth
+ ) {
+ struct ggml_tensor * im2col = ggml_im2col_3d(ctx, a, b, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, a->type); // [N*OD, OH, OW, IC * KD * KH * KW]
+
+ int64_t OC = a->ne[3] / IC;
+ int64_t N = b->ne[3] / IC;
+ struct ggml_tensor * result =
+ ggml_mul_mat(ctx,
+ ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N*OD, OH, OW, IC * KD * KH * KW] => [N*OD*OH*OW, IC * KD * KH * KW]
+ ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2] * IC), OC)); // [OC*IC, KD, KH, KW] => [OC, IC * KD * KH * KW]
+
+ int64_t OD = im2col->ne[3] / N;
+ result = ggml_reshape_4d(ctx, result, im2col->ne[1]*im2col->ne[2], OD, N, OC); // [OC, N*OD*OH*OW] => [OC, N, OD, OH*OW]
+ result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OD, OH*OW]
+ result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], OD, OC * N); // [N*OC, OD, OH, OW]
+
+ return result;
+}
+
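For reference, the output extents computed above follow the usual convolution size formula O = (I + 2*p - d*(K - 1) - 1) / s + 1. A hypothetical call with made-up sizes (ctx, a, b and the IC/stride/padding/dilation values below are illustration only, not from the patch):

    // kernel a: [OC*IC, KD, KH, KW], input b: [N*IC, ID, IH, IW], here with IC = 4,
    // ID = IH = IW = 16, KD = KH = KW = 3, stride 1, padding 1, dilation 1:
    //   OD = OH = OW = (16 + 2*1 - 1*(3 - 1) - 1) / 1 + 1 = 16
    struct ggml_tensor * cols = ggml_im2col_3d(ctx, a, b, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, GGML_TYPE_F32);
    // cols->ne = { 3*3*3*4, 16, 16, 16*N } = { 108, 16, 16, 16*N }
    struct ggml_tensor * out  = ggml_conv_3d(ctx, a, b, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1);
    // out is [N*OC, 16, 16, 16] after the matmul and the permute/reshape sequence above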
// ggml_conv_2d_sk_p0
struct ggml_tensor * ggml_conv_2d_sk_p0(
@@ -4498,9 +4585,9 @@ struct ggml_tensor * ggml_conv_2d_direct(
return result;
}
-// ggml_conv_3d
+// ggml_conv_3d_direct
-struct ggml_tensor * ggml_conv_3d(
+struct ggml_tensor * ggml_conv_3d_direct(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
@@ -4726,11 +4813,36 @@ struct ggml_tensor * ggml_pad(
int p1,
int p2,
int p3) {
+ return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
+}
+
+struct ggml_tensor * ggml_pad_ext(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int lp0,
+ int rp0,
+ int lp1,
+ int rp1,
+ int lp2,
+ int rp2,
+ int lp3,
+ int rp3
+ ) {
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
- a->ne[0] + p0,
- a->ne[1] + p1,
- a->ne[2] + p2,
- a->ne[3] + p3);
+ a->ne[0] + lp0 + rp0,
+ a->ne[1] + lp1 + rp1,
+ a->ne[2] + lp2 + rp2,
+ a->ne[3] + lp3 + rp3);
+
+ ggml_set_op_params_i32(result, 0, lp0);
+ ggml_set_op_params_i32(result, 1, rp0);
+ ggml_set_op_params_i32(result, 2, lp1);
+ ggml_set_op_params_i32(result, 3, rp1);
+ ggml_set_op_params_i32(result, 4, lp2);
+ ggml_set_op_params_i32(result, 5, rp2);
+ ggml_set_op_params_i32(result, 6, lp3);
+ ggml_set_op_params_i32(result, 7, rp3);
+
result->op = GGML_OP_PAD;
result->src[0] = a;
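A hypothetical usage sketch of the new asymmetric-padding entry point (ctx and t are illustrative):

    // pad dim 0 with 1 element on the left and 2 on the right, dim 1 with 3 on the right only
    struct ggml_tensor * padded = ggml_pad_ext(ctx, t,
            /*lp0, rp0*/ 1, 2,
            /*lp1, rp1*/ 0, 3,
            /*lp2, rp2*/ 0, 0,
            /*lp3, rp3*/ 0, 0);

ggml_pad() is now the all-right-padding special case (every lp* is 0), which is also the only case the Vulkan backend accepts above (op params 0, 2, 4 and 6 must be 0).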
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 6156d35c2..c922bacf6 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -340,6 +340,7 @@ class MODEL_ARCH(IntEnum):
GEMMA2 = auto()
GEMMA3 = auto()
GEMMA3N = auto()
+ GEMMA_EMBEDDING = auto()
STARCODER2 = auto()
RWKV6 = auto()
RWKV6QWEN2 = auto()
@@ -674,6 +675,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.GEMMA3: "gemma3",
MODEL_ARCH.GEMMA3N: "gemma3n",
+ MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
@@ -1719,6 +1721,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.LAUREL_R,
MODEL_TENSOR.LAUREL_POST_NORM,
],
+ MODEL_ARCH.GEMMA_EMBEDDING: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.FFN_PRE_NORM,
+ MODEL_TENSOR.FFN_POST_NORM,
+ ],
MODEL_ARCH.STARCODER2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 497f48809..b0c3d65e9 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -14,6 +14,7 @@ class TensorNameMap:
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 plamo2 granite-hybrid
+ "embed_tokens", # embeddinggemma
"tok_embeddings", # llama-pth
"embeddings.word_embeddings", # bert nomic-bert
"language_model.embedding.word_embeddings", # persimmon
@@ -141,6 +142,7 @@ class TensorNameMap:
"rwkv.blocks.{bid}.ln1", # rwkv6
"model.layers.{bid}.ln1", # rwkv7
"model.layers.{bid}.input_layernorm", # llama4
+ "layers.{bid}.input_layernorm", # embeddinggemma
"transformer_encoder.{bid}.attention_norm", # neobert
"model.layers.{bid}.operator_norm", # lfm2
"model.transformer.blocks.{bid}.attn_norm", # llada
@@ -179,6 +181,7 @@ class TensorNameMap:
# Attention query
MODEL_TENSOR.ATTN_Q: (
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 phimoe
+ "layers.{bid}.self_attn.q_proj", # embeddinggemma
"model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom
"layers.{bid}.attention.wq", # llama-pth
"encoder.layer.{bid}.attention.self.query", # bert
@@ -197,6 +200,7 @@ class TensorNameMap:
# Attention key
MODEL_TENSOR.ATTN_K: (
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 phimoe
+ "layers.{bid}.self_attn.k_proj", # embeddinggemma
"model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom
"layers.{bid}.attention.wk", # llama-pth
"encoder.layer.{bid}.attention.self.key", # bert
@@ -216,6 +220,7 @@ class TensorNameMap:
# Attention value
MODEL_TENSOR.ATTN_V: (
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe
+ "layers.{bid}.self_attn.v_proj", # embeddinggemma
"layers.{bid}.attention.wv", # llama-pth
"encoder.layer.{bid}.attention.self.value", # bert
"transformer.layer.{bid}.attention.v_lin", # distillbert
@@ -239,6 +244,7 @@ class TensorNameMap:
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe
+ "layers.{bid}.self_attn.o_proj", # embeddinggemma
"model.layers.{bid}.self_attn.out_proj", # lfm2
"model.layers.{bid}.self_attn.linear_attn", # deci
"layers.{bid}.attention.wo", # llama-pth
@@ -277,6 +283,7 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_POST_NORM: (
"model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge
+ "layers.{bid}.post_attention_layernorm", # embeddinggemma
"model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414
"model.layers.layers.{bid}.post_mixer_norm.weight", # plamo2
),
@@ -320,12 +327,14 @@ class TensorNameMap:
# Post feed-forward norm
MODEL_TENSOR.FFN_PRE_NORM: (
"model.layers.{bid}.pre_feedforward_layernorm", # gemma2
+ "layers.{bid}.pre_feedforward_layernorm", # embeddinggemma
"model.layers.{bid}.pre_ff_layernorm.weight",
),
# Post feed-forward norm
MODEL_TENSOR.FFN_POST_NORM: (
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
+ "layers.{bid}.post_feedforward_layernorm", # embeddinggemma
"model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
"model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2
"model.layers.{bid}.feed_forward.up_proj",
@@ -362,6 +371,7 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
"h.{bid}.mlp.dense_h_to_4h", # bloom
"model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo2
+ "layers.{bid}.mlp.up_proj", # embeddinggemma
"layers.{bid}.feed_forward.w3", # llama-pth
"encoder.layer.{bid}.intermediate.dense", # bert
"transformer.layer.{bid}.ffn.lin1", # distillbert
@@ -421,6 +431,7 @@ class TensorNameMap:
# Feed-forward gate
MODEL_TENSOR.FFN_GATE: (
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact olmo2
+ "layers.{bid}.mlp.gate_proj", # embeddinggemma
"layers.{bid}.feed_forward.w1", # llama-pth
"transformer.h.{bid}.mlp.w2", # qwen
"transformer.h.{bid}.mlp.c_fc2", # jais
@@ -461,6 +472,7 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
"h.{bid}.mlp.dense_4h_to_h", # bloom
"model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo2
+ "layers.{bid}.mlp.down_proj", # embeddinggemma
"layers.{bid}.feed_forward.w2", # llama-pth
"encoder.layer.{bid}.output.dense", # bert
"transformer.layer.{bid}.ffn.lin2", # distillbert
@@ -513,6 +525,7 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
"model.layers.{bid}.self_attn.query_layernorm", # hunyuan
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
+ "layers.{bid}.self_attn.q_norm", # embeddinggemma
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
"transformer.layers.{bid}.attn.q_norm", # openelm
@@ -525,6 +538,7 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
"model.layers.{bid}.self_attn.key_layernorm", # hunyuan
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
+ "layers.{bid}.self_attn.k_norm", # embeddinggemma
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
"transformer.layers.{bid}.attn.k_norm", # openelm
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index d5c8477f4..d92bf2dd8 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -45,6 +45,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GEMMA2, "gemma2" },
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_GEMMA3N, "gemma3n" },
+ { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_MAMBA2, "mamba2" },
@@ -1038,6 +1039,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" },
},
},
+ {
+ LLM_ARCH_GEMMA_EMBEDDING,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+ },
+ },
{
LLM_ARCH_STARCODER2,
{
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 86c119692..41f377923 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -49,6 +49,7 @@ enum llm_arch {
LLM_ARCH_GEMMA2,
LLM_ARCH_GEMMA3,
LLM_ARCH_GEMMA3N,
+ LLM_ARCH_GEMMA_EMBEDDING,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_MAMBA2,
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index bc42e7a21..de1d8f58c 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -285,6 +285,9 @@ llama_context::llama_context(
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+ // avoid reserving graphs with zero outputs
+ n_outputs = 1;
+
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
// resolve automatic Flash Attention use
@@ -1367,7 +1370,8 @@ llm_graph_result * llama_context::get_gf_res_reserve() const {
}
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
- // LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
+ //LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
+ GGML_ASSERT_CONTINUE(n_outputs >= 1);
if (n_tokens % n_seqs != 0) {
n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 49ea5da7c..7ce2960eb 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -258,6 +258,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
}
}
+static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+ LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
+ const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
+ (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
+ (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
+ (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
+ LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
+ LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
+ LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
+
+ LLAMA_LOG_DEBUG(" ");
+ for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
+ LLAMA_LOG_DEBUG("%2d", j);
+ }
+ LLAMA_LOG_DEBUG("\n");
+
+ for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
+ LLAMA_LOG_DEBUG(" %2d ", i);
+ for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
+ float val = data[i * n_kv + j];
+ if (val == -INFINITY) {
+ LLAMA_LOG_DEBUG(" ∞");
+ } else {
+ LLAMA_LOG_DEBUG(" 0");
+ }
+ }
+ LLAMA_LOG_DEBUG("\n");
+ }
+}
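For a small causal batch (4 tokens, one sequence, no SWA) the dump produced by this helper would look roughly like this (illustrative, spacing approximate):

         0 1 2 3
      0  0 ∞ ∞ ∞
      1  0 0 ∞ ∞
      2  0 0 0 ∞
      3  0 0 0 0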
+
void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
const int64_t n_kv = ubatch->n_tokens;
const int64_t n_tokens = ubatch->n_tokens;
@@ -277,21 +307,32 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
const llama_seq_id s0 = ubatch->seq_id[i0][0];
- // TODO: reimplement this like in llama_kv_cache
- if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
- if (hparams.use_alibi) {
- f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
- } else {
- f = 0.0f;
- }
- break;
+ if (s0 != s1) {
+ continue; // skip different sequences
+ }
+
+ if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
+ continue; // skip future tokens for causal attention
+ }
+
+ if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
+ continue; // skip masked tokens for SWA
+ }
+
+ // TODO: reimplement this like in llama_kv_cache_unified
+ if (hparams.use_alibi) {
+ f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+ } else {
+ f = 0.0f;
}
}
-
data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
}
}
}
+ if (debug) {
+ print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+ }
}
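The rewritten loop above is equivalent to the following predicate, shown as a standalone sketch (illustrative helper, not part of the patch); setting LLAMA_GRAPH_INPUT_DEBUG=1 dumps the resulting mask via print_mask():

    // a (key i0, query i1) pair contributes to the mask only if all three checks pass
    static bool pair_masked(llama_seq_id s0, llama_seq_id s1, llama_pos p0, llama_pos p1,
                            bool causal_attn, const llama_hparams & hparams) {
        if (s0 != s1)                      return true; // different sequences never attend
        if (causal_attn && p0 > p1)        return true; // causal: no attention to future tokens
        if (hparams.is_masked_swa(p0, p1)) return true; // outside the sliding window
        return false;
    }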
void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 3c85333fd..ca90fdf61 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -78,6 +78,11 @@ struct llm_graph_params;
class llm_graph_input_i {
public:
+ llm_graph_input_i() {
+ const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
+ debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
+ }
+
virtual ~llm_graph_input_i() = default;
virtual void set_input(const llama_ubatch * ubatch) = 0;
@@ -90,6 +95,9 @@ public:
GGML_UNUSED(params);
return false;
}
+protected:
+ // env: LLAMA_GRAPH_INPUT_DEBUG
+ int debug = 0;
};
using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 91636572d..4b7a65362 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -1,6 +1,7 @@
#include "llama-hparams.h"
#include "ggml.h"
+#include <cassert>
void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
if (dense_first) {
@@ -178,3 +179,39 @@ uint32_t llama_hparams::n_layer_kv() const {
return res;
}
+
+bool llama_hparams::is_masked_swa(llama_pos p0, llama_pos p1) const {
+ assert(p0 >= 0 && p1 >= 0);
+
+ switch (swa_type) {
+ case LLAMA_SWA_TYPE_NONE:
+ {
+ } break;
+ case LLAMA_SWA_TYPE_STANDARD:
+ {
+ if (p1 - p0 >= (int32_t) n_swa) {
+ return true;
+ }
+ } break;
+ case LLAMA_SWA_TYPE_CHUNKED:
+ {
+ const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
+
+ if (p0 < pos_chunk_start) {
+ return true;
+ }
+ } break;
+ case LLAMA_SWA_TYPE_SYMMETRIC:
+ {
+ const int32_t half_n_swa = (int32_t) n_swa / 2;
+ const int32_t pos_diff = p1 - p0;
+
+ // Mask if outside the symmetric window
+ if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
+ return true;
+ }
+ } break;
+ }
+
+ return false;
+}
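A worked example of the new symmetric case (values assumed for illustration): with n_swa = 6, half_n_swa = 3, so a query at position 10 can only attend to keys at positions 7..13.

    // assuming swa_type == LLAMA_SWA_TYPE_SYMMETRIC and n_swa == 6:
    //   is_masked_swa( 7, 10) -> false   (p1 - p0 = +3, inside the window)
    //   is_masked_swa( 6, 10) -> true    (p1 - p0 = +4, outside)
    //   is_masked_swa(13, 10) -> false   (p1 - p0 = -3, inside)
    //   is_masked_swa(14, 10) -> true    (p1 - p0 = -4, outside)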
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 60415f0c2..06d1e51db 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -16,9 +16,10 @@ enum llama_expert_gating_func_type {
};
enum llama_swa_type {
- LLAMA_SWA_TYPE_NONE = 0,
- LLAMA_SWA_TYPE_STANDARD = 1,
- LLAMA_SWA_TYPE_CHUNKED = 2,
+ LLAMA_SWA_TYPE_NONE = 0,
+ LLAMA_SWA_TYPE_STANDARD = 1,
+ LLAMA_SWA_TYPE_CHUNKED = 2,
+ LLAMA_SWA_TYPE_SYMMETRIC = 3,
};
struct llama_hparams_posnet {
@@ -227,6 +228,8 @@ struct llama_hparams {
// number of layers for which has_kv() returns true
uint32_t n_layer_kv() const;
+
+ bool is_masked_swa(llama_pos p0, llama_pos p1) const;
};
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
diff --git a/src/llama-kv-cache-iswa.cpp b/src/llama-kv-cache-iswa.cpp
index 1de943869..3f35882c8 100644
--- a/src/llama-kv-cache-iswa.cpp
+++ b/src/llama-kv-cache-iswa.cpp
@@ -64,14 +64,14 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
kv_base = std::make_unique<llama_kv_cache>(
model, type_k, type_v,
v_trans, offload, unified, size_base, n_seq_max, n_pad,
- 0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
+ 0, filter_base, reuse);
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
kv_swa = std::make_unique<llama_kv_cache>(
model, type_k, type_v,
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
- hparams.n_swa, hparams.swa_type, filter_swa, reuse);
+ hparams.n_swa, filter_swa, reuse);
}
void llama_kv_cache_iswa::clear(bool data) {
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 061f44ae1..bfa43ea38 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -27,11 +27,10 @@ llama_kv_cache::llama_kv_cache(
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
- llama_swa_type swa_type,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse) :
model(model), hparams(model.hparams), v_trans(v_trans),
- n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
+ n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa) {
GGML_ASSERT(kv_size % n_pad == 0);
@@ -1393,29 +1392,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
}
bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
- assert(p0 >= 0 && p1 >= 0);
-
- switch (swa_type) {
- case LLAMA_SWA_TYPE_NONE:
- {
- } break;
- case LLAMA_SWA_TYPE_STANDARD:
- {
- if (p1 - p0 >= (int32_t) n_swa) {
- return true;
- }
- } break;
- case LLAMA_SWA_TYPE_CHUNKED:
- {
- const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
-
- if (p0 < pos_chunk_start) {
- return true;
- }
- } break;
- }
-
- return false;
+ return hparams.is_masked_swa(p0, p1);
}
void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index 07d29bb81..55ee355b2 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -89,7 +89,6 @@ public:
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
- llama_swa_type swa_type,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
@@ -212,8 +211,6 @@ private:
// env: LLAMA_KV_CACHE_DEBUG
int debug = 0;
- const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
-
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp
index ba61ebaa8..38f92da11 100644
--- a/src/llama-memory-hybrid.cpp
+++ b/src/llama-memory-hybrid.cpp
@@ -17,7 +17,6 @@ llama_memory_hybrid::llama_memory_hybrid(
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
- llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
@@ -41,7 +40,6 @@ llama_memory_hybrid::llama_memory_hybrid(
n_seq_max,
n_pad,
n_swa,
- swa_type,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn,
diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h
index 11a356517..0eb63f5ef 100644
--- a/src/llama-memory-hybrid.h
+++ b/src/llama-memory-hybrid.h
@@ -27,7 +27,6 @@ public:
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
- llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index c3dcbb7d1..f81fd0a51 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1115,7 +1115,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
- case 18: type = LLM_TYPE_537M; break;
+ case 18: type = LLM_TYPE_270M; break;
case 26: type = LLM_TYPE_1B; break;
case 34: type = LLM_TYPE_4B; break;
case 48: type = LLM_TYPE_12B; break;
@@ -1147,6 +1147,26 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
+ case LLM_ARCH_GEMMA_EMBEDDING:
+ {
+ hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC;
+ hparams.set_swa_pattern(6);
+
+ hparams.causal_attn = false; // embeddings do not use causal attention
+ hparams.rope_freq_base_train_swa = 10000.0f;
+ hparams.rope_freq_scale_train_swa = 1.0f;
+
+ ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
+
+ switch (hparams.n_layer) {
+ case 24: type = LLM_TYPE_0_3B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
+
+ } break;
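Illustrative numbers for the attention scale set just above (the head size is assumed here, not stated in this patch): with n_embd_head_k = 256,

    // f_attention_scale = 1 / sqrt(256) = 0.0625
    // Qcur is multiplied by this factor in the graph builder further below, and build_attn
    // is then called with kq_scale = 1.0f so the 1/sqrt(d_k) term is not applied twice.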
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -3580,6 +3600,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
} break;
case LLM_ARCH_GEMMA3:
+ case LLM_ARCH_GEMMA_EMBEDDING:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -11145,6 +11166,136 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
}
};
+struct llm_build_gemma_embedding_iswa : public llm_graph_context {
+ llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+ const int64_t n_embd_head = hparams.n_embd_head_k;
+
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+
+ inpL = build_inp_embd(model.tok_embd);
+
+ // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
+ if (ubatch.token) {
+ inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+ cb(inpL, "inp_scaled", -1);
+ }
+
+ // inp_pos - contains the positions
+ ggml_tensor * inp_pos = build_inp_pos();
+
+ auto * inp_attn = build_attn_inp_no_cache();
+
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+ for (int il = 0; il < n_layer; ++il) {
+ const float freq_base_l = model.get_rope_freq_base (cparams, il);
+ const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
+ // norm
+ cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ // compute Q and K and RoPE them
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+ cb(Qcur, "Qcur_normed", il);
+
+ Qcur = ggml_rope_ext(
+ ctx0, Qcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+ cb(Kcur, "Kcur_normed", il);
+
+ Kcur = ggml_rope_ext(
+ ctx0, Kcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
+ Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
+
+ cur = build_attn(inp_attn,
+ model.layers[il].wo, NULL,
+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
+ }
+
+ if (il == n_layer - 1 && inp_out_ids) {
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
+ cur = build_norm(cur,
+ model.layers[il].attn_post_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "attn_post_norm", il);
+
+ ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+ cb(sa_out, "sa_out", il);
+
+ cur = build_norm(sa_out,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "ffn_norm", il);
+
+ // feed-forward network
+ {
+ cur = build_ffn(cur,
+ model.layers[il].ffn_up, NULL, NULL,
+ model.layers[il].ffn_gate, NULL, NULL,
+ model.layers[il].ffn_down, NULL, NULL,
+ NULL,
+ LLM_FFN_GELU, LLM_FFN_PAR, il);
+ cb(cur, "ffn_out", il);
+ }
+
+ cur = build_norm(cur,
+ model.layers[il].ffn_post_norm, NULL,
+ LLM_NORM_RMS, -1);
+ cb(cur, "ffn_post_norm", -1);
+
+ cur = ggml_add(ctx0, cur, sa_out);
+
+ cur = build_cvec(cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = build_norm(cur,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, -1);
+
+ cb(cur, "result_norm", -1);
+ res->t_embd = cur;
+
+ ggml_build_forward_expand(gf, cur);
+ }
+};
+
// TODO: move up next to build_starcoder
struct llm_build_starcoder2 : public llm_graph_context {
llm_build_starcoder2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
@@ -18581,6 +18732,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
case LLM_ARCH_NOMIC_BERT_MOE:
case LLM_ARCH_NEO_BERT:
case LLM_ARCH_WAVTOKENIZER_DEC:
+ case LLM_ARCH_GEMMA_EMBEDDING:
case LLM_ARCH_DREAM:
case LLM_ARCH_LLADA:
{
@@ -18629,7 +18781,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* attn_kv_size */ cparams.n_ctx,
/* attn_n_pad */ padding,
/* attn_n_swa */ hparams.n_swa,
- /* attn_swa_type */ hparams.swa_type,
/* recurrent_type_k */ GGML_TYPE_F32,
/* recurrent_type_v */ GGML_TYPE_F32,
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
@@ -18699,7 +18850,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.n_seq_max,
padding,
hparams.n_swa,
- hparams.swa_type,
nullptr,
nullptr);
}
@@ -18861,6 +19011,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params);
} break;
+ case LLM_ARCH_GEMMA_EMBEDDING:
+ {
+ llm = std::make_unique<llm_build_gemma_embedding_iswa>(*this, params);
+ } break;
case LLM_ARCH_STARCODER2:
{
llm = std::make_unique<llm_build_starcoder2>(*this, params);
@@ -19261,6 +19415,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GEMMA2:
case LLM_ARCH_GEMMA3:
case LLM_ARCH_GEMMA3N:
+ case LLM_ARCH_GEMMA_EMBEDDING:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX:
diff --git a/src/llama-model.h b/src/llama-model.h
index fa44d800d..10b1767f2 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -39,7 +39,6 @@ enum llm_type {
LLM_TYPE_410M,
LLM_TYPE_450M,
LLM_TYPE_475M,
- LLM_TYPE_537M,
LLM_TYPE_558M,
LLM_TYPE_700M,
LLM_TYPE_770M,
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index e8c0fc341..2186f827b 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -604,10 +604,73 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_dist *) smpl->ctx;
- // sorting is not necessary here
- llama_sampler_softmax_impl(cur_p, false);
+ // edge cases
+ if (cur_p->size == 0) {
+ cur_p->selected = -1;
+ return;
+ }
+
+ cur_p->selected = 0;
+
+ if (cur_p->size == 1) {
+ cur_p->data[0].p = 1.0f;
+ return;
+ }
+
+ // max logit for numerical stability
+ float max_l = cur_p->data[0].logit;
+ if (!cur_p->sorted) {
+ for (size_t i = 1; i < cur_p->size; ++i) {
+ max_l = std::max(max_l, cur_p->data[i].logit);
+ }
+ }
+
+ // apply softmax to obtain the probabilities
+ double sum_cum = 0.0f;
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ float p = expf(cur_p->data[i].logit - max_l);
+ cur_p->data[i].p = p;
+ sum_cum += p;
+ }
+
+#if 1
+ // sample from the obtained probabilities and normalize the probs in a single pass
+ // this is ~3x faster on Mac with full gpt-oss vocab than the version below
+ //
+ std::uniform_real_distribution<double> dist(0.0f, 1.0f);
+ const double rnd = dist(ctx->rng);
+
+ double sum_run = 0.0f;
+ const double sum_tgt = sum_cum*rnd;
+
+ bool found = false;
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ if (!found) {
+ // accumulate probs until we reach the target sum
+ sum_run += cur_p->data[i].p;
+ if (sum_run >= sum_tgt) {
+ cur_p->selected = i;
+ found = true;
+ }
+ }
+
+ // normalize probs
+ cur_p->data[i].p /= sum_cum;
+ }
+
+ // fallback to the last token (don't think this can happen)
+ assert(found);
+ if (!found) {
+ cur_p->selected = cur_p->size - 1;
+ }
+#else
+ // for clarity, this is the same as above but does one pass for normalization and one extra pass for sampling
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ cur_p->data[i].p /= sum_cum;
+ }
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
+#endif
}
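The same single-pass idea in isolation, as a minimal standalone sketch (illustrative, not the llama.cpp API): draw a target in [0, sum of unnormalized weights) and take the first index whose running sum reaches it.

    #include <cstddef>
    #include <random>
    #include <vector>

    static size_t sample_unnormalized(const std::vector<double> & w, std::mt19937 & rng) {
        double sum = 0.0;
        for (double x : w) sum += x;                        // total unnormalized mass
        std::uniform_real_distribution<double> dist(0.0, 1.0);
        const double target = sum * dist(rng);              // target point in [0, sum)
        double run = 0.0;
        for (size_t i = 0; i < w.size(); ++i) {
            run += w[i];
            if (run >= target) return i;                    // first index crossing the target
        }
        return w.size() - 1;                                // numerical fallback, mirrors the assert path
    }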
static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index e0302e2f2..44487eca5 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -86,6 +86,7 @@ enum error_type {
ERROR_TYPE_PERMISSION,
ERROR_TYPE_UNAVAILABLE, // custom error
ERROR_TYPE_NOT_SUPPORTED, // custom error
+ ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error
};
static bool server_task_type_need_embd(server_task_type task_type) {
@@ -1224,6 +1225,10 @@ static json format_error_response(const std::string & message, const enum error_
type_str = "unavailable_error";
code = 503;
break;
+ case ERROR_TYPE_EXCEED_CONTEXT_SIZE:
+ type_str = "exceed_context_size_error";
+ code = 400;
+ break;
}
return json {
{"code", code},
@@ -1237,12 +1242,21 @@ struct server_task_result_error : server_task_result {
error_type err_type = ERROR_TYPE_SERVER;
std::string err_msg;
+ // for ERROR_TYPE_EXCEED_CONTEXT_SIZE
+ int32_t n_prompt_tokens = 0;
+ int32_t n_ctx = 0;
+
virtual bool is_error() override {
return true;
}
virtual json to_json() override {
- return format_error_response(err_msg, err_type);
+ json res = format_error_response(err_msg, err_type);
+ if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
+ res["n_prompt_tokens"] = n_prompt_tokens;
+ res["n_ctx"] = n_ctx;
+ }
+ return res;
}
};
@@ -2605,16 +2619,22 @@ struct server_context {
}
void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
- send_error(slot.id_task, error, type);
+ send_error(slot.id_task, error, type, slot.n_prompt_tokens, slot.n_ctx);
}
- void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
+ void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
+ if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
+ GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0);
+ }
+
+ auto res = std::make_unique<server_task_result_error>();
- res->id = id_task;
- res->err_type = type;
- res->err_msg = error;
+ res->id = id_task;
+ res->err_type = type;
+ res->err_msg = error;
+ res->n_prompt_tokens = n_prompt_tokens;
+ res->n_ctx = n_ctx;
queue_results.send(std::move(res));
}
@@ -3286,7 +3306,7 @@ struct server_context {
if (slot.n_prompt_tokens > slot.n_ctx) {
slot.release();
- send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER);
+ send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
continue;
}
} else {
@@ -3296,7 +3316,7 @@ struct server_context {
// context shift should be applied only during the generation phase
if (slot.n_prompt_tokens >= slot.n_ctx) {
slot.release();
- send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
+ send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
continue;
}
}
diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py
index 509c024b7..22dbfdd9b 100644
--- a/tools/server/tests/unit/test_chat_completion.py
+++ b/tools/server/tests/unit/test_chat_completion.py
@@ -385,3 +385,20 @@ def test_logit_bias():
output_text = res.choices[0].message.content
assert output_text
assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
+
+def test_context_size_exceeded():
+ global server
+ server.start()
+ res = server.make_request("POST", "/chat/completions", data={
+ "messages": [
+ {"role": "system", "content": "Book"},
+ {"role": "user", "content": "What is the best book"},
+ ] * 100, # make the prompt too long
+ })
+ assert res.status_code == 400
+ assert "error" in res.body
+ assert res.body["error"]["type"] == "exceed_context_size_error"
+ assert res.body["error"]["n_prompt_tokens"] > 0
+ assert server.n_ctx is not None
+ assert server.n_slots is not None
+ assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots