diff --git a/common/arg.cpp b/common/arg.cpp index 9d436d5b3..90ac457c9 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2796,6 +2796,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.ssl_file_cert = value; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_CERT_FILE")); + add_opt(common_arg( + {"--chat-template-kwargs"}, "STRING", + string_format("sets additional params for the json template parser"), + [](common_params & params, const std::string & value) { + auto parsed = json::parse(value); + for (const auto & item : parsed.items()) { + params.default_template_kwargs[item.key()] = item.value().dump(); + } + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_CHAT_TEMPLATE_KWARGS")); add_opt(common_arg( {"-to", "--timeout"}, "N", string_format("server read/write timeout in seconds (default: %d)", params.timeout_read), diff --git a/common/chat.cpp b/common/chat.cpp index 478b59d6b..49157b60d 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -17,6 +17,8 @@ #include #include +using json = nlohmann::ordered_json; + static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) { auto time = std::chrono::system_clock::to_time_t(now); auto local_time = *std::localtime(&time); @@ -140,6 +142,7 @@ struct templates_params { bool add_generation_prompt = true; bool enable_thinking = true; std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + json extra_context; }; common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) { @@ -720,16 +723,23 @@ static void foreach_function(const json & tools, const std::function & messages_override = std::nullopt, + const std::optional & tools_override = std::nullopt, + const std::optional & additional_context = std::nullopt) { minja::chat_template_inputs tmpl_inputs; - tmpl_inputs.messages = messages; - tmpl_inputs.tools = tools; - tmpl_inputs.add_generation_prompt = add_generation_prompt; - tmpl_inputs.extra_context = extra_context; + tmpl_inputs.messages = messages_override ? *messages_override : inputs.messages; + if (tools_override) { + tmpl_inputs.tools = *tools_override; + } else { + tmpl_inputs.tools = inputs.tools.empty() ? json() : inputs.tools; + } + tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt; + tmpl_inputs.extra_context = inputs.extra_context; + if (additional_context) { + tmpl_inputs.extra_context.merge_patch(*additional_context); + } // TODO: add flag to control date/time, if only for testing purposes. // tmpl_inputs.now = std::chrono::system_clock::now(); @@ -828,7 +838,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp inputs.messages, "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request"); - data.prompt = apply(tmpl, tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages); data.format = COMMON_CHAT_FORMAT_GENERIC; return data; } @@ -904,7 +914,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat data.preserved_tokens = { "[TOOL_CALLS]", }; - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.prompt = apply(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO; return data; } @@ -934,7 +944,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ adjusted_messages.push_back(msg); } } - data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {}); + data.prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages); data.format = COMMON_CHAT_FORMAT_COMMAND_R7B; if (string_ends_with(data.prompt, "<|START_THINKING|>")) { if (!inputs.enable_thinking) { @@ -1122,7 +1132,7 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te } else { data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; } - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, { + data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ std::nullopt, json { {"date_string", format_time(inputs.now, "%d %b %Y")}, {"tools_in_user_message", false}, {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools}, @@ -1187,7 +1197,7 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; - auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + auto prompt = apply(tmpl, inputs); // Hacks to fix the official (broken) prompt. // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead, @@ -1282,7 +1292,7 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { LOG_DBG("%s\n", __func__); common_chat_params data; - data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, { + data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ json(), json { {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")}, {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))}, }); @@ -1338,7 +1348,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar // If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code. common_chat_params data; - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.prompt = apply(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2; if (inputs.tools.is_array() && !inputs.tools.empty()) { data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -1465,7 +1475,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; } - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.prompt = apply(tmpl, inputs); // TODO: if (has_raw_python) return data; } @@ -1498,14 +1508,15 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; - json additional_context = { + json extra_context = json { {"enable_thinking", inputs.enable_thinking}, }; + extra_context.update(inputs.extra_context); - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, additional_context); + data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ std::nullopt, extra_context); data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO; if (string_ends_with(data.prompt, "\n")) { - if (!inputs.enable_thinking) { + if (!extra_context["enable_thinking"]) { data.prompt += ""; } else { data.thinking_forced_open = true; @@ -1691,7 +1702,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.prompt = apply(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; data.grammar_lazy = false; if (!inputs.json_schema.is_null()) { @@ -1722,6 +1733,12 @@ static common_chat_params common_chat_templates_apply_jinja( params.enable_thinking = inputs.enable_thinking; params.grammar = inputs.grammar; params.now = inputs.now; + + params.extra_context = json::object(); + for (auto el : inputs.chat_template_kwargs) { + params.extra_context[el.first] = json::parse(el.second); + } + if (!inputs.json_schema.empty()) { params.json_schema = json::parse(inputs.json_schema); } diff --git a/common/chat.h b/common/chat.h index 9f59e6b08..ca807c145 100644 --- a/common/chat.h +++ b/common/chat.h @@ -7,6 +7,7 @@ #include #include #include +#include struct common_chat_templates; @@ -125,6 +126,7 @@ struct common_chat_templates_inputs { common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; bool enable_thinking = true; std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + std::map chat_template_kwargs; }; struct common_chat_params { diff --git a/common/common.h b/common/common.h index d6223da84..84b94ec20 100644 --- a/common/common.h +++ b/common/common.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #ifdef _WIN32 @@ -377,6 +378,8 @@ struct common_params { std::string ssl_file_key = ""; // NOLINT std::string ssl_file_cert = ""; // NOLINT + std::map default_template_kwargs; + // "advanced" endpoints are disabled by default for better security bool webui = true; bool endpoint_slots = false; diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 778927f68..a2977ea2e 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -339,7 +339,7 @@ extern "C" { typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data); // Compare the output of two backends - GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data); + GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node); // Tensor initialization GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 141dced2e..81d42e21f 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -526,6 +526,8 @@ extern "C" { GGML_OP_CROSS_ENTROPY_LOSS_BACK, GGML_OP_OPT_STEP_ADAMW, + GGML_OP_GLU, + GGML_OP_COUNT, }; @@ -549,6 +551,14 @@ extern "C" { GGML_UNARY_OP_COUNT, }; + enum ggml_glu_op { + GGML_GLU_OP_REGLU, + GGML_GLU_OP_GEGLU, + GGML_GLU_OP_SWIGLU, + + GGML_GLU_OP_COUNT, + }; + enum ggml_object_type { GGML_OBJECT_TYPE_TENSOR, GGML_OBJECT_TYPE_GRAPH, @@ -671,6 +681,7 @@ extern "C" { GGML_API const char * ggml_op_symbol(enum ggml_op op); GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op); + GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op); GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor); @@ -775,6 +786,7 @@ extern "C" { GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3); GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor); + GGML_API enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor); GGML_API void * ggml_get_data (const struct ggml_tensor * tensor); GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor); @@ -1103,6 +1115,63 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // gated linear unit ops + // A: n columns, r rows, + // result is n / 2 columns, r rows, + // expects gate in second half of row, unless swapped is true + GGML_API struct ggml_tensor * ggml_glu( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_glu_op op, + bool swapped); + + GGML_API struct ggml_tensor * ggml_reglu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_reglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_swiglu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_swiglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // A: n columns, r rows, + // B: n columns, r rows, + GGML_API struct ggml_tensor * ggml_glu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_glu_op op); + + GGML_API struct ggml_tensor * ggml_reglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_geglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_swiglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + // normalize along rows GGML_API struct ggml_tensor * ggml_norm( struct ggml_context * ctx, diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 6f040c41a..eafd5fd34 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -823,8 +823,9 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str } if (sched->debug > 1) { ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node); - GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name, - fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node)); + GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name, + fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node), + graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]); for (int j = 0; j < GGML_MAX_SRC; j++) { struct ggml_tensor * src = node->src[j]; if (src == NULL) { @@ -1832,7 +1833,7 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) { ggml_free(copy.ctx_unallocated); } -bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) { +bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node) { struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph); if (copy.buffer == NULL) { return false; @@ -1843,28 +1844,45 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t assert(g1->n_nodes == g2->n_nodes); - for (int i = 0; i < g1->n_nodes; i++) { - struct ggml_tensor * t1 = g1->nodes[i]; - struct ggml_tensor * t2 = g2->nodes[i]; + if (test_node != nullptr) { + // Compute the whole graph and only test the output for a specific tensor + ggml_backend_graph_compute(backend1, g1); + ggml_backend_graph_compute(backend2, g2); - assert(t1->op == t2->op && ggml_are_same_layout(t1, t2)); - - struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1); - struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1); - - ggml_backend_graph_compute(backend1, &g1v); - ggml_backend_graph_compute(backend2, &g2v); - - if (ggml_is_view_op(t1->op)) { - continue; + int test_node_idx = -1; + for (int i = 0; i < g1->n_nodes; i++) { + struct ggml_tensor * t1 = g1->nodes[i]; + if (t1 == test_node) { + test_node_idx = i; + break; + } } + GGML_ASSERT(test_node_idx != -1); - // compare results, calculate rms etc - if (!callback(i, t1, t2, user_data)) { - break; + callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data); + } else { + for (int i = 0; i < g1->n_nodes; i++) { + struct ggml_tensor * t1 = g1->nodes[i]; + struct ggml_tensor * t2 = g2->nodes[i]; + + assert(t1->op == t2->op && ggml_are_same_layout(t1, t2)); + + struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1); + struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1); + + ggml_backend_graph_compute(backend1, &g1v); + ggml_backend_graph_compute(backend2, &g2v); + + if (ggml_is_view_op(t1->op)) { + continue; + } + + // compare results, calculate rms etc + if (!callback(i, t1, t2, user_data)) { + break; + } } } - ggml_backend_graph_copy_free(copy); return true; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 6604601c9..4a6472df9 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1963,6 +1963,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_unary(params, tensor); } break; + case GGML_OP_GLU: + { + ggml_compute_forward_glu(params, tensor); + } break; case GGML_OP_GET_REL_POS: { ggml_compute_forward_get_rel_pos(params, tensor); @@ -2173,6 +2177,18 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { GGML_ABORT("fatal error"); } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_SWIGLU: + { + n_tasks = n_threads; + } break; + default: + GGML_ABORT("fatal error"); + } + break; case GGML_OP_SILU_BACK: case GGML_OP_MUL: case GGML_OP_DIV: diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 9f17ea43c..27586ed1f 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3184,6 +3184,435 @@ void ggml_compute_forward_silu_back( } } +// ggml_compute_forward_reglu + +static void ggml_compute_forward_reglu_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_reglu_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o); + ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_reglu( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_reglu_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_reglu_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_geglu + +static void ggml_compute_forward_geglu_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_geglu_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o); + ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_geglu( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_geglu_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_geglu_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_swiglu + +static void ggml_compute_forward_swiglu_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_swiglu_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o); + ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_swiglu( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_swiglu_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_swiglu_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_norm static void ggml_compute_forward_norm_f32( @@ -8052,6 +8481,34 @@ void ggml_compute_forward_unary( } } +//ggml_compute_forward_glu + +void ggml_compute_forward_glu( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_glu_op op = ggml_get_glu_op(dst); + + switch (op) { + case GGML_GLU_OP_REGLU: + { + ggml_compute_forward_reglu(params, dst); + } break; + case GGML_GLU_OP_GEGLU: + { + ggml_compute_forward_geglu(params, dst); + } break; + case GGML_GLU_OP_SWIGLU: + { + ggml_compute_forward_swiglu(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_get_rel_pos static void ggml_compute_forward_get_rel_pos_f16( diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 3a395fdcd..5b384e4ba 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -94,6 +94,7 @@ void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, st void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_glu(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 5e34d79a1..ed5d7aefc 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -254,6 +254,30 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) { } } +void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g) { + int i = 0; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i))); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i))); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i))); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i))); + } +#endif + for (; i < n; ++i) { + y[i] = ggml_silu_f32(x[i]) * g[i]; + } +} + ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) { int i = 0; ggml_float sum = 0; diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 84f6c0e6d..d5507d756 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -905,6 +905,60 @@ inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, con } } +inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) { + for (int i = 0; i < n; ++i) { + y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f; + } +} + +inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) { + for (int i = 0; i < n; ++i) { + float v = GGML_CPU_FP16_TO_FP32(x[i]); + y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v * GGML_CPU_FP16_TO_FP32(g[i]) : 0.f); + } +} + +#ifdef GGML_GELU_FP16 +inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) { + uint16_t t; + for (int i = 0; i < n; ++i) { + if (x[i] <= -10.0f) { + y[i] = 0.0f; + } else if (x[i] >= 10.0f) { + y[i] = x[i] * g[i]; + } else { + ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[t]) * g[i]; + } + } +} +#else +inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_gelu_f32(x[i]) * g[i]; + } +} +#endif + +inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + float v = GGML_CPU_FP16_TO_FP32(g[i]); + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v); + } +} + +void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g); + +inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) { + for (int i = 0; i < n; ++i) { + float v = GGML_CPU_FP16_TO_FP32(x[i]); + float w = GGML_CPU_FP16_TO_FP32(g[i]); + y[i] = GGML_CPU_FP32_TO_FP16((v/(1.0f + expf(-v))) * w); + } +} + inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE ggml_float sum = 0.0; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 988c9180d..5d9503ee6 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2308,6 +2308,21 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(dst)) { + case GGML_GLU_OP_REGLU: + ggml_cuda_op_reglu(ctx, dst); + break; + case GGML_GLU_OP_GEGLU: + ggml_cuda_op_geglu(ctx, dst); + break; + case GGML_GLU_OP_SWIGLU: + ggml_cuda_op_swiglu(ctx, dst); + break; + default: + return false; + } + break; case GGML_OP_NORM: ggml_cuda_op_norm(ctx, dst); break; @@ -3101,6 +3116,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_SWIGLU: + return ggml_is_contiguous_1(op->src[0]); + default: + return false; + } + break; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: { diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 2c0375fbe..ba3c0f137 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -196,6 +196,95 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary(ctx, dst); } +/* gated ops */ + +template +static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) { + const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + // perform base op and multiply with gate (either offset in same tensor or a separate one) + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + + dst[i] = (T)(op((float)x[j0]) * (float)g[j1]); +} + +template +static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) { + const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE; + unary_gated_op_kernel<<>>(x, g, dst, k, n, o0, o1); +} + +template +void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + void * src0_d = src0->data; + void * src1_d = src1 ? src1->data : src0->data; + const int64_t src0_o = src0->nb[1]; + const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + void * dst_d = dst->data; + const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(src0->nb[0] == ggml_element_size(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src1->nb[0] == ggml_element_size(src1)); + GGML_ASSERT(src1->ne[0] == nc); + GGML_ASSERT(src0->type == src1->type); + } + + const int32_t swapped = ((const int32_t *) dst->op_params)[1]; + + if (src0->type == GGML_TYPE_F16) { + half * src0_p = (half *) src0_d; + half * src1_p = (half *) src1_d; + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + unary_gated_cuda(src0_p, src1_p, (half *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(half), src1_o / sizeof(half), stream); + } else { + float * src0_p = (float *) src0_d; + float * src1_p = (float *) src1_d; + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + unary_gated_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), stream); + } +} + +void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst); +} + +void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst); +} + +void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst); +} + /* silu_back */ static __device__ __forceinline__ float op_silu_back(float grad, float x) { diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 6686fc17e..9094f1d0b 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -15,6 +15,7 @@ #define CUDA_SQRT_BLOCK_SIZE 256 #define CUDA_SIN_BLOCK_SIZE 256 #define CUDA_COS_BLOCK_SIZE 256 +#define CUDA_GLU_BLOCK_SIZE 256 void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst); @@ -57,3 +58,9 @@ void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 57761644f..4972558c9 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -301,6 +301,7 @@ struct ggml_cgraph { struct ggml_tensor ** grads; // the outputs of these tensors are the gradients of the nodes struct ggml_tensor ** grad_accs; // accumulators for node gradients struct ggml_tensor ** leafs; // tensors with constant data + int32_t * use_counts;// number of uses of each tensor, indexed by hash table slot struct ggml_hash_set visited_hash_set; @@ -467,13 +468,76 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) { #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x) #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x) +// return true if the node's results are only used by N other nodes +// and can be fused into their calculations. +static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) { + const struct ggml_tensor * node = cgraph->nodes[node_idx]; + + // check the use count against how many we're replacing + size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node); + if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) { + return false; + } + + // if node is a view, some other node might be using the intermediate result + // via the view source. + if (node->view_src) { + return false; + } + + // If the user requested output for the node, can't fuse + if (node->flags & GGML_TENSOR_FLAG_OUTPUT) { + return false; + } + + return true; +} + +// Returns true if nodes [i, i+ops.size()) are the sequence of ggml_ops in ops[] +// and are fusable. Nodes are considered fusable according to this function if: +// - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses). +// - all nodes except the last are a src of the following node. +// - all nodes are the same shape. +// TODO: Consider allowing GGML_OP_NONE nodes in between +static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) { + if (node_idx + num_ops > cgraph->n_nodes) { + return false; + } + + for (int i = 0; i < num_ops; ++i) { + struct ggml_tensor * node = cgraph->nodes[node_idx + i]; + if (node->op != ops[i]) { + return false; + } + if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idx + i, 1)) { + return false; + } + if (i > 0) { + struct ggml_tensor * prev = cgraph->nodes[node_idx + i - 1]; + if (node->src[0] != prev && node->src[1] != prev) { + return false; + } + if (!ggml_are_same_shape(node, prev)) { + return false; + } + } + } + return true; +} + #ifdef __cplusplus } #endif #ifdef __cplusplus +#include #include +// nicer C++ syntax for ggml_can_fuse +inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops) { + return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size()); +} + // expose GGUF internals for test code GGML_API size_t gguf_type_size(enum gguf_type type); GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params); diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 260440aed..7a9aab316 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -422,6 +422,17 @@ typedef struct { int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources } ggml_metal_kargs_im2col; +typedef struct{ + int32_t ne00; + uint64_t nb01; + int32_t ne10; + uint64_t nb11; + int32_t ne0; + uint64_t nb1; + int32_t i00; + int32_t i10; +} ggml_metal_kargs_glu; + typedef struct { int64_t ne00; int64_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 77df24e81..ceae3c7c3 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -526,6 +526,9 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_SIN, GGML_METAL_KERNEL_TYPE_COS, GGML_METAL_KERNEL_TYPE_NEG, + GGML_METAL_KERNEL_TYPE_REGLU, + GGML_METAL_KERNEL_TYPE_GEGLU, + GGML_METAL_KERNEL_TYPE_SWIGLU, GGML_METAL_KERNEL_TYPE_SUM_ROWS, GGML_METAL_KERNEL_TYPE_MEAN, GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, @@ -1502,6 +1505,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); @@ -1680,6 +1686,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex default: return false; } + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_SWIGLU: + return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + default: + return false; + } case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -2419,6 +2434,62 @@ static bool ggml_metal_encode_node( GGML_ABORT("fatal error"); } } break; + case GGML_OP_GLU: + { + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + if (src1) { + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + } + + id pipeline = nil; + + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_REGLU: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline; + break; + case GGML_GLU_OP_GEGLU: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline; + break; + case GGML_GLU_OP_SWIGLU: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline; + break; + default: + GGML_ABORT("fatal error"); + } + + const int32_t swp = ((const int32_t *) dst->op_params)[1]; + + const int32_t i00 = swp ? ne0 : 0; + const int32_t i10 = swp ? 0 : ne0; + + ggml_metal_kargs_glu args = { + /*.ne00 =*/ ne00, + /*.nb01 =*/ nb01, + /*.ne10 =*/ src1 ? ne10 : ne00, + /*.nb11 =*/ src1 ? nb11 : nb01, + /*.ne0 =*/ ne0, + /*.nb1 =*/ nb1, + /*.i00 =*/ src1 ? 0 : i00, + /*.i10 =*/ src1 ? 0 : i10, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (src1) { + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; + + const int64_t nrows = ggml_nrows(src0); + + const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_SQR: { GGML_ASSERT(ggml_is_contiguous(src0)); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 984a0ab50..fc3cfe35a 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1191,6 +1191,70 @@ kernel void kernel_neg( dst[tpig] = -src0[tpig]; } +kernel void kernel_reglu( + device const char * src0, + device const char * src1, + device char * dst, + constant ggml_metal_kargs_glu & args, + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + + for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + dst_row[i0] = x0*x1*(x0 > 0.0f); + } +} + +kernel void kernel_geglu( + device const char * src0, + device const char * src1, + device char * dst, + constant ggml_metal_kargs_glu & args, + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + + for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0))); + + dst_row[i0] = gelu*x1; + } +} + +kernel void kernel_swiglu( + device const char * src0, + device const char * src1, + device char * dst, + constant ggml_metal_kargs_glu & args, + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + + for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + const float silu = x0 / (1.0f + exp(-x0)); + + dst_row[i0] = silu*x1; + } +} + template kernel void kernel_sum_rows( constant ggml_metal_kargs_sum_rows & args, diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 1bd1dcdf3..9337d0824 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -441,6 +441,7 @@ struct vk_device_struct { vk_pipeline pipeline_norm_f32; vk_pipeline pipeline_group_norm_f32; vk_pipeline pipeline_rms_norm_f32; + vk_pipeline pipeline_rms_norm_mul_f32; vk_pipeline pipeline_rms_norm_back_f32; vk_pipeline pipeline_l2_norm_f32; @@ -452,6 +453,10 @@ struct vk_device_struct { vk_pipeline pipeline_tanh[2]; vk_pipeline pipeline_sigmoid[2]; + vk_pipeline pipeline_geglu[2]; + vk_pipeline pipeline_reglu[2]; + vk_pipeline pipeline_swiglu[2]; + vk_pipeline pipeline_leaky_relu_f32; vk_pipeline pipeline_silu_back_f32; vk_pipeline pipeline_diag_mask_inf_f32; @@ -676,6 +681,13 @@ struct vk_op_push_constants { float param2; }; +struct vk_op_glu_push_constants { + uint32_t N; + uint32_t ne00; + uint32_t ne20; + uint32_t mode; // 0: default, 1: swapped, 2: split +}; + struct vk_op_unary_push_constants { uint32_t ne; uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; @@ -994,6 +1006,10 @@ struct ggml_backend_vk_context { vk_command_pool compute_cmd_pool; vk_command_pool transfer_cmd_pool; + + // number of additional consecutive nodes that are being fused with the + // node currently being processed + uint32_t num_additional_fused_ops {}; }; static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT @@ -2671,7 +2687,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -2767,6 +2784,15 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY(sigmoid) #undef CREATE_UNARY +#define CREATE_GLU(name) \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); + + CREATE_GLU(geglu) + CREATE_GLU(reglu) + CREATE_GLU(swiglu) +#undef CREATE_GLU + ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -6454,7 +6480,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_RMS_NORM: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_rms_norm_f32; + return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32; } return nullptr; case GGML_OP_RMS_NORM_BACK: @@ -6491,6 +6517,24 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const break; } return nullptr; + case GGML_OP_GLU: + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) || + (src0->type != dst->type)) { + return nullptr; + } + + switch (ggml_get_glu_op(dst)) { + case GGML_GLU_OP_GEGLU: + return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_REGLU: + return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_SWIGLU: + return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16]; + default: + break; + } + return nullptr; case GGML_OP_DIAG_MASK_INF: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_diag_mask_inf_f32; @@ -6951,6 +6995,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_CONCAT: case GGML_OP_UPSCALE: case GGML_OP_UNARY: + case GGML_OP_GLU: case GGML_OP_CONV_2D_DW: { uint32_t ne = ggml_nelements(dst); @@ -6991,7 +7036,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } } - if (op == GGML_OP_SOFT_MAX) { + if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) { // Empty src1 is possible in soft_max, but the shader needs a buffer vk_subbuffer subbuf_y; if (use_src1) { @@ -7554,18 +7599,19 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); } -static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { +static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)ggml_nelements(src0), - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - op_params[0], 0.0f, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + op_params[0], 0.0f, 0, }, dryrun); } @@ -7583,6 +7629,25 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); } +static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const bool swapped = (bool)dst->op_params[1]; + const bool split = src1 != nullptr; + + GGML_ASSERT(ggml_is_contiguous(src0)); + + if (!split) { + GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]); + } else { + GGML_ASSERT(src0->ne[0] == src1->ne[0]); + GGML_ASSERT(src0->ne[0] == dst->ne[0]); + GGML_ASSERT(src0->type == src1->type); + } + + const uint32_t mode = split ? 2 : (swapped ? 1 : 0); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun); +} + static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { int32_t * op_params = (int32_t *)dst->op_params; ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun); @@ -8760,7 +8825,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* t // Returns true if node has enqueued work into the queue, false otherwise // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution. -static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){ +static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){ + ggml_tensor * node = cgraph->nodes[node_idx]; if (ggml_is_empty(node) || !node->buffer) { return false; } @@ -8794,6 +8860,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + break; + default: + return false; + } + break; case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_GET_ROWS: @@ -8886,6 +8962,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_RMS_NORM_BACK: case GGML_OP_L2_NORM: case GGML_OP_UNARY: + case GGML_OP_GLU: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: @@ -8998,8 +9075,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod break; case GGML_OP_RMS_NORM: - ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun); - + if (ctx->num_additional_fused_ops > 0) { + // fused rms_norm + mul + ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0]; + ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun); + } else { + ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun); + } break; case GGML_OP_RMS_NORM_BACK: ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun); @@ -9023,6 +9106,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun); + break; + default: + return false; + } + break; case GGML_OP_DIAG_MASK_INF: ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun); @@ -9148,8 +9242,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod if (!ok) { if (node->op == GGML_OP_UNARY) { std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; - } - else { + } else if (node->op == GGML_OP_GLU) { + std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast(node->op_params[0])) << ")" << std::endl; + } else { std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl; } } @@ -9228,6 +9323,17 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(tensor)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + buf = tensor->buffer; + break; + default: + return false; + } + break; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: case GGML_OP_FLASH_ATTN_EXT: @@ -9734,10 +9840,15 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg uint64_t total_mat_mul_bytes = 0; for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false); + if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; + } + ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false); if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); } + i += ctx->num_additional_fused_ops; + ctx->num_additional_fused_ops = 0; } if (ctx->device->need_compiles) { ggml_vk_load_shaders(ctx->device); @@ -9799,14 +9910,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); } + if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; + } + // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; bool submit = (submitted_nodes >= nodes_per_submit) || (mul_mat_bytes >= mul_mat_bytes_per_submit) || - (i == last_node) || + (i + ctx->num_additional_fused_ops == last_node) || (almost_ready && !ctx->almost_ready_fence_pending); - bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit); + bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit); if (vk_perf_logger_enabled) { if (ctx->compute_ctx.expired()) { @@ -9816,7 +9931,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg } else { compute_ctx = ctx->compute_ctx.lock(); } - compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1); + // If there are fused ops, just write out timestamps for all nodes to keep the accounting simple + for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) { + compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1); + } } if (enqueued) { @@ -9838,6 +9956,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg } submit_count++; } + i += ctx->num_additional_fused_ops; + ctx->num_additional_fused_ops = 0; } if (vk_perf_logger_enabled) { @@ -10012,6 +10132,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + return ggml_is_contiguous(op->src[0]) && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (op->src[0]->type == op->type); + default: + return false; + } + break; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: { @@ -10742,6 +10875,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); } + } else if (tensor->op == GGML_OP_GLU) { + if (src_clone[1] == nullptr) { + tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]); + } else { + tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]); + } } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) { if (src1 == nullptr) { tensor_clone = ggml_dup(ggml_ctx, src_clone[0]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp new file mode 100644 index 000000000..f4268ed24 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp @@ -0,0 +1,13 @@ +#version 450 + +#include "glu_head.comp" + +const float GELU_COEF_A = 0.044715f; +const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +float op(float a, float b) { + const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a); + return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b; +} + +#include "glu_main.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp new file mode 100644 index 000000000..41a298890 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp @@ -0,0 +1,15 @@ +#extension GL_EXT_shader_16bit_storage : require + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {A_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +layout (push_constant) uniform parameter +{ + uint N; + uint ne00; + uint ne20; + uint mode; +} p; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp new file mode 100644 index 000000000..85cf65a9e --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp @@ -0,0 +1,29 @@ +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.N) { + return; + } + + const uint row = i / p.ne20; + const uint col = i - row * p.ne20; + + if (p.mode == 0) { + // Default + const uint offset = p.ne00 / 2; + const uint idx = row * p.ne00 + col; + + data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); + } else if (p.mode == 1) { + // Swapped + const uint offset = p.ne00 / 2; + const uint idx = row * p.ne00 + col; + + data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); + } else { + // Split + const uint idx = row * p.ne00 + col; + + data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp new file mode 100644 index 000000000..0073d8f76 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp @@ -0,0 +1,9 @@ +#version 450 + +#include "glu_head.comp" + +float op(float a, float b) { + return max(a, 0.0f) * b; +} + +#include "glu_main.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index deb8ee996..6428ca7ba 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -1,11 +1,13 @@ #version 450 -#include "generic_unary_head.comp" +#include "generic_binary_head.comp" #include "types.comp" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 +layout (constant_id = 1) const bool do_multiply = false; + layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; shared FLOAT_TYPE sum[BLOCK_SIZE]; @@ -25,6 +27,7 @@ void main() { const uint stride_sample = p.nb03; uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset(); + uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset(); uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp @@ -46,7 +49,13 @@ void main() { const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols); const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); - [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { - data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + if (do_multiply) { + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); + } + } else { + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp new file mode 100644 index 000000000..a28e7c6cc --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp @@ -0,0 +1,9 @@ +#version 450 + +#include "glu_head.comp" + +float op(float a, float b) { + return a / (1.0f + exp(-a)) * b; +} + +#include "glu_main.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 79a110593..cc161f38b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -511,7 +511,7 @@ void process_shaders() { // Norms string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); - string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -599,6 +599,13 @@ void process_shaders() { string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 17eb889b2..24ab95a30 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -995,9 +995,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS", "CROSS_ENTROPY_LOSS_BACK", "OPT_STEP_ADAMW", + + "GLU", }; -static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84"); +static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1092,9 +1094,11 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss(x,y)", "cross_entropy_loss_back(x,y)", "adamw(x)", + + "glu(x)", }; -static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84"); +static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -1120,6 +1124,15 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15"); +static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = { + "REGLU", + "GEGLU", + "SWIGLU", +}; + +static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3"); + + static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); @@ -1222,11 +1235,19 @@ const char * ggml_unary_op_name(enum ggml_unary_op op) { return GGML_UNARY_OP_NAME[op]; } +const char * ggml_glu_op_name(enum ggml_glu_op op) { + return GGML_GLU_OP_NAME[op]; +} + const char * ggml_op_desc(const struct ggml_tensor * t) { if (t->op == GGML_OP_UNARY) { enum ggml_unary_op uop = ggml_get_unary_op(t); return ggml_unary_op_name(uop); } + if (t->op == GGML_OP_GLU) { + enum ggml_glu_op gop = ggml_get_glu_op(t); + return ggml_glu_op_name(gop); + } return ggml_op_name(t->op); } @@ -1743,6 +1764,11 @@ enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) { return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0); } +enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->op == GGML_OP_GLU); + return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0); +} + const char * ggml_get_name(const struct ggml_tensor * tensor) { return tensor->name; } @@ -2622,6 +2648,114 @@ struct ggml_tensor * ggml_exp_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP); } +// ggml_glu + +static struct ggml_tensor * ggml_glu_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_glu_op op, + bool swapped) { + GGML_ASSERT(ggml_is_contiguous_1(a)); + + if (b) { + GGML_ASSERT(ggml_is_contiguous_1(b)); + GGML_ASSERT(ggml_are_same_shape(a, b)); + GGML_ASSERT(a->type == b->type); + } + + int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i]; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0); + + ggml_set_op_params_i32(result, 0, (int32_t) op); + ggml_set_op_params_i32(result, 1, (int32_t) swapped); + + result->op = GGML_OP_GLU; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_glu( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_glu_op op, + bool swapped) { + return ggml_glu_impl(ctx, a, NULL, op, swapped); +} + +struct ggml_tensor * ggml_glu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_glu_op op) { + return ggml_glu_impl(ctx, a, b, op, false); +} + +// ggml_reglu + +struct ggml_tensor * ggml_reglu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, false); +} + +struct ggml_tensor * ggml_reglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, true); +} + +struct ggml_tensor * ggml_reglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_REGLU, false); +} + +// ggml_geglu + +struct ggml_tensor * ggml_geglu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, false); +} + +struct ggml_tensor * ggml_geglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, true); +} + +struct ggml_tensor * ggml_geglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU, false); +} + +// ggml_swiglu + +struct ggml_tensor * ggml_swiglu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, false); +} + +struct ggml_tensor * ggml_swiglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, true); +} + +struct ggml_tensor * ggml_swiglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false); +} + // ggml_norm static struct ggml_tensor * ggml_norm_impl( @@ -5854,19 +5988,32 @@ static void ggml_compute_backward( GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2])); } -static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { +static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { // check if already visited - if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) { - return; + size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node); + GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL); + if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) { + // This is the first time we see this node in the current graph. + cgraph->visited_hash_set.keys[node_hash_pos] = node; + ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos); + cgraph->use_counts[node_hash_pos] = 0; + } else { + // already visited + return node_hash_pos; } for (int i = 0; i < GGML_MAX_SRC; ++i) { const int k = (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i : (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) : - /* unknown order, just fall back to using i*/ i; - if (node->src[k]) { - ggml_visit_parents(cgraph, node->src[k]); + /* unknown order, just fall back to using i */ i; + + struct ggml_tensor * src = node->src[k]; + if (src) { + size_t src_hash_pos = ggml_visit_parents(cgraph, src); + + // Update the use count for this operand. + cgraph->use_counts[src_hash_pos]++; } } @@ -5890,6 +6037,8 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * cgraph->nodes[cgraph->n_nodes] = node; cgraph->n_nodes++; } + + return node_hash_pos; } static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) { @@ -6027,6 +6176,7 @@ static size_t ggml_graph_nbytes(size_t size, bool grads) { incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1); incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs + incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // use_counts incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys if (grads) { incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads @@ -6056,11 +6206,12 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz void * p = cgraph + 1; - struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); - struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); - struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); - struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; - struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; + struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + int32_t * use_counts_ptr = incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); + struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; + struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t)); @@ -6075,6 +6226,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz /*.grads =*/ grads_ptr, /*.grad_accs =*/ grad_accs_ptr, /*.leafs =*/ leafs_ptr, + /*.use_counts =*/ use_counts_ptr, /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr }, /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, }; @@ -6101,7 +6253,8 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) /*.grads =*/ NULL, // gradients would need visited_hash_set /*.grad_accs =*/ NULL, /*.leafs =*/ NULL, - /*.visited_hash_set =*/ { 0, NULL, NULL }, + /*.use_counts =*/ cgraph0->use_counts, + /*.visited_hash_set =*/ cgraph0->visited_hash_set, /*.order =*/ cgraph0->order, }; @@ -6128,7 +6281,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) { for (size_t i = 0; i < src->visited_hash_set.size; ++i) { // copy all hashset keys (tensors) that are in use if (ggml_bitset_get(src->visited_hash_set.used, i)) { - ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]); + size_t new_hash_pos = ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]); + dst->use_counts[new_hash_pos] = src->use_counts[i]; } } diff --git a/koboldcpp.py b/koboldcpp.py index 780b3a3f6..ce699e7dd 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -62,7 +62,7 @@ dry_seq_break_max = 128 extra_images_max = 4 # global vars -KcppVersion = "1.95" +KcppVersion = "1.95.1" showdebug = True kcpp_instance = None #global running instance global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False, "restart_override_config_target":""} diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 71ee431a9..010300df6 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -560,12 +560,20 @@ ggml_tensor * llm_graph_context::build_ffn( switch (type_op) { case LLM_FFN_SILU: - { + if (gate && type_gate == LLM_FFN_PAR) { + cur = ggml_swiglu_split(ctx0, cur, tmp); + cb(cur, "ffn_swiglu", il); + type_gate = LLM_FFN_SEQ; + } else { cur = ggml_silu(ctx0, cur); cb(cur, "ffn_silu", il); } break; case LLM_FFN_GELU: - { + if (gate && type_gate == LLM_FFN_PAR) { + cur = ggml_geglu_split(ctx0, cur, tmp); + cb(cur, "ffn_geglu", il); + type_gate = LLM_FFN_SEQ; + } else { cur = ggml_gelu(ctx0, cur); cb(cur, "ffn_gelu", il); if (act_scales != NULL) { @@ -574,7 +582,11 @@ ggml_tensor * llm_graph_context::build_ffn( } } break; case LLM_FFN_RELU: - { + if (gate && type_gate == LLM_FFN_PAR) { + cur = ggml_reglu_split(ctx0, cur, tmp); + cb(cur, "ffn_reglu", il); + type_gate = LLM_FFN_SEQ; + } else { cur = ggml_relu(ctx0, cur); cb(cur, "ffn_relu", il); } break; @@ -588,32 +600,19 @@ ggml_tensor * llm_graph_context::build_ffn( } break; case LLM_FFN_SWIGLU: { - // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf - int64_t split_point = cur->ne[0] / 2; - // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217 - ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); - ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); - - x0 = ggml_silu(ctx0, x0); - cb(cur, "ffn_silu", il); - - cur = ggml_mul(ctx0, x0, x1); - cb(cur, "ffn_mul", il); + cur = ggml_swiglu(ctx0, cur); + cb(cur, "ffn_swiglu", il); } break; case LLM_FFN_GEGLU: { - // Split into two equal parts - int64_t split_point = cur->ne[0] / 2; - // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217 - ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); - ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); - - x0 = ggml_gelu(ctx0, x0); - cb(x0, "ffn_gelu", il); - - cur = ggml_mul(ctx0, x0, x1); + cur = ggml_geglu(ctx0, cur); cb(cur, "ffn_geglu", il); } break; + case LLM_FFN_REGLU: + { + cur = ggml_reglu(ctx0, cur); + cb(cur, "ffn_reglu", il); + } break; } if (gate && type_gate == LLM_FFN_PAR) { @@ -743,12 +742,18 @@ ggml_tensor * llm_graph_context::build_moe_ffn( switch (type_op) { case LLM_FFN_SILU: - { + if (gate_exps) { + cur = ggml_swiglu_split(ctx0, cur, up); + cb(cur, "ffn_moe_swiglu", il); + } else { cur = ggml_silu(ctx0, cur); cb(cur, "ffn_moe_silu", il); } break; case LLM_FFN_GELU: - { + if (gate_exps) { + cur = ggml_geglu_split(ctx0, cur, up); + cb(cur, "ffn_moe_geglu", il); + } else { cur = ggml_gelu(ctx0, cur); cb(cur, "ffn_moe_gelu", il); } break; @@ -756,11 +761,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn( GGML_ABORT("fatal error"); } - if (gate_exps) { - cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens] - cb(cur, "ffn_moe_gate_par", il); - } - experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); diff --git a/src/llama-graph.h b/src/llama-graph.h index ee2197e89..ceddb6021 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -38,6 +38,7 @@ enum llm_ffn_op_type { LLM_FFN_RELU_SQR, LLM_FFN_SWIGLU, LLM_FFN_GEGLU, + LLM_FFN_REGLU, }; enum llm_ffn_gate_type { diff --git a/tests/test-tokenizers-repo.sh b/tests/test-tokenizers-repo.sh index 86e839133..1158aebae 100755 --- a/tests/test-tokenizers-repo.sh +++ b/tests/test-tokenizers-repo.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash if [ $# -lt 2 ]; then printf "Usage: $0 []\n" diff --git a/tools/gguf-split/tests.sh b/tools/gguf-split/tests.sh index 05a932227..c9ad85da0 100755 --- a/tools/gguf-split/tests.sh +++ b/tools/gguf-split/tests.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash set -eu diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index aa0019893..b25024c2f 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # make sure we are in the right directory SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) diff --git a/tools/quantize/tests.sh b/tools/quantize/tests.sh index 70f7610f9..ba9616148 100644 --- a/tools/quantize/tests.sh +++ b/tools/quantize/tests.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash set -eu diff --git a/tools/server/chat-llama2.sh b/tools/server/chat-llama2.sh index 1fc79b7e1..450445f17 100755 --- a/tools/server/chat-llama2.sh +++ b/tools/server/chat-llama2.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash API_URL="${API_URL:-http://127.0.0.1:8080}" diff --git a/tools/server/chat.sh b/tools/server/chat.sh index da0a6ca68..84cea2d56 100755 --- a/tools/server/chat.sh +++ b/tools/server/chat.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash API_URL="${API_URL:-http://127.0.0.1:8080}" diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 0fb01665a..53b71079c 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 852352383..d3f627193 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2110,6 +2110,7 @@ struct server_context { /* use_jinja */ params_base.use_jinja, /* prefill_assistant */ params_base.prefill_assistant, /* reasoning_format */ params_base.reasoning_format, + /* chat_template_kwargs */ params_base.default_template_kwargs, /* common_chat_templates */ chat_templates.get(), /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, /* allow_audio */ mctx ? mtmd_support_audio (mctx) : false, diff --git a/tools/server/tests/tests.sh b/tools/server/tests/tests.sh index 33fa8cc64..709b5841a 100755 --- a/tools/server/tests/tests.sh +++ b/tools/server/tests/tests.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # make sure we are in the right directory SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index f8fab2c86..2ef9a1645 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -579,6 +579,7 @@ struct oaicompat_parser_options { bool use_jinja; bool prefill_assistant; common_reasoning_format reasoning_format; + std::map chat_template_kwargs; common_chat_templates * tmpls; bool allow_image; bool allow_audio; @@ -756,6 +757,13 @@ static json oaicompat_chat_params_parse( llama_params["parse_tool_calls"] = true; } + // merge the template args provided from command line with the args provided in the user request + auto chat_template_kwargs_object = json_value(body, "chat_template_kwargs", json::object()); + inputs.chat_template_kwargs = opt.chat_template_kwargs; + for (const auto & item : chat_template_kwargs_object.items()) { + inputs.chat_template_kwargs[item.key()] = item.value().dump(); + } + // if the assistant message appears at the end of list, we do not add end-of-turn token // for ex. this can be useful to modify the reasoning process in reasoning models bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && opt.prefill_assistant; @@ -771,6 +779,11 @@ static json oaicompat_chat_params_parse( /* TODO: test this properly */ inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE; + + if ( (!inputs.enable_thinking) || inputs.chat_template_kwargs.find("enable_thinking") != inputs.chat_template_kwargs.end()) { + throw std::runtime_error("Assistant response prefill is incompatible with enable_thinking."); + } + inputs.add_generation_prompt = true; } diff --git a/tools/server/webui/src/components/Sidebar.tsx b/tools/server/webui/src/components/Sidebar.tsx index a77cb83b4..b52a8df03 100644 --- a/tools/server/webui/src/components/Sidebar.tsx +++ b/tools/server/webui/src/components/Sidebar.tsx @@ -231,7 +231,7 @@ function ConversationItem({ > {conv.name} -
+