diff --git a/common/jinja/lexer.cpp b/common/jinja/lexer.cpp
index 85eaa1a76..598982c2f 100644
--- a/common/jinja/lexer.cpp
+++ b/common/jinja/lexer.cpp
@@ -91,6 +91,16 @@ lexer_result lexer::tokenize(const std::string & source) {
         return str;
     };
 
+    auto consume_numeric = [&]() -> std::string {
+        std::string num = consume_while(is_integer);
+        if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) {
+            ++pos; // Consume '.'
+            std::string frac = consume_while(is_integer);
+            num += "." + frac;
+        }
+        return num;
+    };
+
     auto next_pos_is = [&](std::initializer_list<char> chars, size_t n = 1) -> bool {
         if (pos + n >= src.size()) return false;
         for (char c : chars) {
@@ -258,7 +268,7 @@ lexer_result lexer::tokenize(const std::string & source) {
             ++pos; // Consume the operator
 
             // Check for numbers following the unary operator
-            std::string num = consume_while(is_integer);
+            std::string num = consume_numeric();
             std::string value = std::string(1, ch) + num;
             token::type t = num.empty() ? token::unary_operator : token::numeric_literal;
             // JJ_DEBUG("consumed unary operator or numeric literal: '%s'", value.c_str());
@@ -307,12 +317,7 @@ lexer_result lexer::tokenize(const std::string & source) {
         // Numbers
         if (is_integer(ch)) {
             start_pos = pos;
-            std::string num = consume_while(is_integer);
-            if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) {
-                ++pos; // Consume '.'
-                std::string frac = consume_while(is_integer);
-                num += "." + frac;
-            }
+            std::string num = consume_numeric();
             // JJ_DEBUG("consumed numeric literal: '%s'", num.c_str());
             tokens.push_back({token::numeric_literal, num, start_pos});
             continue;
diff --git a/common/jinja/runtime.cpp b/common/jinja/runtime.cpp
index ba07f7a6d..d8ef27908 100644
--- a/common/jinja/runtime.cpp
+++ b/common/jinja/runtime.cpp
@@ -268,8 +268,7 @@ value binary_expression::execute_impl(context & ctx) {
     // String in object
     if (is_val(left_val) && is_val(right_val)) {
         auto key = left_val->as_string().str();
-        auto & obj = right_val->as_object();
-        bool has_key = obj.find(key) != obj.end();
+        bool has_key = right_val->has_key(key);
         if (op.value == "in") {
             return mk_val(has_key);
         } else if (op.value == "not in") {
@@ -464,7 +463,7 @@ value for_statement::execute_impl(context & ctx) {
     std::vector<value> items;
     if (is_val(iterable_val)) {
         JJ_DEBUG("%s", "For loop over object keys");
-        auto & obj = iterable_val->as_object();
+        auto & obj = iterable_val->as_ordered_object();
         for (auto & p : obj) {
             auto tuple = mk_val();
             if (iterable_val->val_obj.is_key_numeric) {
@@ -560,6 +559,7 @@ value for_statement::execute_impl(context & ctx) {
     for (size_t i = 0; i < filtered_items.size(); i++) {
         JJ_DEBUG("For loop iteration %zu/%zu", i + 1, filtered_items.size());
         value_object loop_obj = mk_val();
+        loop_obj->has_builtins = false; // loop object has no builtins
         loop_obj->insert("index", mk_val(i + 1));
         loop_obj->insert("index0", mk_val(i));
         loop_obj->insert("revindex", mk_val(filtered_items.size() - i));
@@ -717,6 +717,7 @@ value member_expression::execute_impl(context & ctx) {
     value property;
 
     if (this->computed) {
+        // syntax: obj[expr]
         JJ_DEBUG("Member expression, computing property type %s", this->property->type().c_str());
 
         int64_t arr_size = 0;
@@ -745,10 +746,24 @@ value member_expression::execute_impl(context & ctx) {
             property = this->property->execute(ctx);
         }
     } else {
+        // syntax: obj.prop
        if (!is_stmt(this->property)) {
-            throw std::runtime_error("Non-computed member property must be an identifier");
+            throw std::runtime_error("Static member property must be an identifier");
         }
         property = mk_val(cast_stmt(this->property)->val);
+        std::string prop = property->as_string().str();
+        JJ_DEBUG("Member expression, object type %s, static property '%s'", object->type().c_str(), prop.c_str());
+
+        // jinja2 behavior: if obj has prop both as a built-in function AND as an object key,
+        // then obj.prop returns the built-in function, not the property value,
+        // while obj['prop'] returns the property value.
+        // example: {"obj": {"items": 123}} -> obj.items is the built-in function, obj['items'] is 123
+
+        value val = try_builtin_func(ctx, prop, object, true);
+        if (!is_val(val)) {
+            return val;
+        }
+        // else, fall through to normal property access below
     }
 
     JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str());
@@ -763,11 +778,8 @@ value member_expression::execute_impl(context & ctx) {
         throw std::runtime_error("Cannot access object with non-string: got " + property->type());
     }
     auto key = property->as_string().str();
-    auto & obj = object->as_object();
-    auto it = obj.find(key);
-    if (it != obj.end()) {
-        val = it->second;
-    } else {
+    val = object->at(key, val);
+    if (is_val(val)) {
         val = try_builtin_func(ctx, key, object, true);
     }
     JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str());
diff --git a/common/jinja/runtime.h b/common/jinja/runtime.h
index 1e7c63b85..dc7f4e471 100644
--- a/common/jinja/runtime.h
+++ b/common/jinja/runtime.h
@@ -56,6 +56,7 @@ struct context {
     // src is optional, used for error reporting
     context(std::string src = "") : src(std::make_shared<std::string>(std::move(src))) {
         env = mk_val();
+        env->has_builtins = false; // context object has no builtins
         env->insert("true", mk_val(true));
         env->insert("True", mk_val(true));
         env->insert("false", mk_val(false));
@@ -68,7 +69,7 @@
 
     context(const context & parent) : context() {
         // inherit variables (for example, when entering a new scope)
-        auto & pvar = parent.env->as_object();
+        auto & pvar = parent.env->as_ordered_object();
         for (const auto & pair : pvar) {
             set_val(pair.first, pair.second);
         }
@@ -265,7 +266,7 @@ struct comment_statement : public statement {
 struct member_expression : public expression {
     statement_ptr object;
     statement_ptr property;
-    bool computed;
+    bool computed; // true if obj[expr] and false if obj.prop
 
     member_expression(statement_ptr && object, statement_ptr && property, bool computed)
         : object(std::move(object)), property(std::move(property)), computed(computed) {
diff --git a/common/jinja/value.cpp b/common/jinja/value.cpp
index 0ae9d1c56..e414aad44 100644
--- a/common/jinja/value.cpp
+++ b/common/jinja/value.cpp
@@ -698,6 +698,7 @@ const func_builtins & value_bool_t::get_builtins() const {
             bool val = args.get_pos(0)->as_bool();
             return mk_val(val ? "True" : "False");
         }},
+        {"tojson", tojson},
     };
     return builtins;
 }
@@ -775,19 +776,30 @@ const func_builtins & value_array_t::get_builtins() const {
             if (!is_val(args.get_pos(0))) {
                 throw raised_exception("join() first argument must be an array");
             }
-            value val_delim = args.get_kwarg_or_pos("d", 1);
-            value val_attribute = args.get_kwarg_or_pos("attribute", 2);
-            if (!val_attribute->is_undefined()) {
-                throw not_implemented_exception("array attribute join not implemented");
-            }
+            value val_delim = args.get_kwarg_or_pos("d", 1);
+            value attribute = args.get_kwarg_or_pos("attribute", 2);
             const auto & arr = args.get_pos(0)->as_array();
-            std::string delim = is_val(val_delim) ? val_delim->as_string().str() : "";
+            const bool attr_is_int = is_val(attribute);
+            if (!attribute->is_undefined() && !is_val(attribute) && !attr_is_int) {
+                throw raised_exception("join() attribute must be string or integer");
+            }
+            const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
+            const std::string delim = val_delim->is_undefined() ? "" : val_delim->as_string().str();
+            const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
             std::string result;
             for (size_t i = 0; i < arr.size(); ++i) {
-                if (!is_val(arr[i]) && !is_val(arr[i]) && !is_val(arr[i])) {
+                value val_arr = arr[i];
+                if (!attribute->is_undefined()) {
+                    if (attr_is_int && is_val(val_arr)) {
+                        val_arr = val_arr->at(attr_int);
+                    } else if (!attr_is_int && !attr_name.empty() && is_val(val_arr)) {
+                        val_arr = val_arr->at(attr_name);
+                    }
+                }
+                if (!is_val(val_arr) && !is_val(val_arr) && !is_val(val_arr)) {
                     throw raised_exception("join() can only join arrays of strings or numerics");
                 }
-                result += arr[i]->as_string().str();
+                result += val_arr->as_string().str();
                 if (i < arr.size() - 1) {
                     result += delim;
                 }
@@ -802,26 +814,30 @@ const func_builtins & value_array_t::get_builtins() const {
         }},
         {"tojson", tojson},
         {"map", [](const func_args & args) -> value {
-            args.ensure_count(2, 3);
+            args.ensure_count(2);
             if (!is_val(args.get_pos(0))) {
                 throw raised_exception("map: first argument must be an array");
             }
-            value attribute = args.get_kwarg_or_pos("attribute", 1);
-            if (is_val(attribute)) {
-                throw not_implemented_exception("map: integer attribute not implemented");
+            if (!is_val(args.get_args().at(1))) {
+                throw not_implemented_exception("map: filter-mapping not implemented");
             }
-            if (!is_val(attribute)) {
+            value attribute = args.get_kwarg_or_pos("attribute", 1);
+            const bool attr_is_int = is_val(attribute);
+            if (!is_val(attribute) && !attr_is_int) {
                 throw raised_exception("map: attribute must be string or integer");
             }
-            std::string attr_name = attribute->as_string().str();
+            const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
+            const std::string attr_name = attribute->as_string().str();
             value default_val = args.get_kwarg("default", mk_val());
             auto out = mk_val();
             auto arr = args.get_pos(0)->as_array();
             for (const auto & item : arr) {
-                if (!is_val(item)) {
-                    throw raised_exception("map: item is not an object");
+                value attr_val;
+                if (attr_is_int) {
+                    attr_val = is_val(item) ? item->at(attr_int, default_val) : default_val;
+                } else {
+                    attr_val = is_val(item) ? item->at(attr_name, default_val) : default_val;
                 }
-                value attr_val = item->at(attr_name, default_val);
                 out->push_back(attr_val);
             }
             return out;
@@ -847,29 +863,35 @@ const func_builtins & value_array_t::get_builtins() const {
             return arr_editable->pop_at(index);
         }},
         {"sort", [](const func_args & args) -> value {
-            args.ensure_count(1, 3);
+            args.ensure_count(1, 4);
             if (!is_val(args.get_pos(0))) {
                 throw raised_exception("sort: first argument must be an array");
             }
-            bool reverse = args.get_kwarg("reverse", mk_val())->as_bool();
-            value attribute = args.get_kwarg("attribute", mk_val());
-            std::string attr = attribute->is_undefined() ? "" : attribute->as_string().str();
+            value val_reverse = args.get_kwarg_or_pos("reverse", 1);
+            value val_case    = args.get_kwarg_or_pos("case_sensitive", 2);
+            value attribute   = args.get_kwarg_or_pos("attribute", 3);
+            // FIXME: sorting is currently always case sensitive
+            //const bool case_sensitive = val_case->as_bool(); // undefined == false
+            const bool reverse = val_reverse->as_bool(); // undefined == false
+            const bool attr_is_int = is_val(attribute);
+            const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
+            const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
             std::vector arr = cast_val(args.get_pos(0))->as_array(); // copy
             std::sort(arr.begin(), arr.end(), [&](const value & a, const value & b) {
                 value val_a = a;
                 value val_b = b;
                 if (!attribute->is_undefined()) {
-                    if (!is_val(a) || !is_val(b)) {
-                        throw raised_exception("sort: items are not objects");
+                    if (attr_is_int && is_val(a) && is_val(b)) {
+                        val_a = a->at(attr_int);
+                        val_b = b->at(attr_int);
+                    } else if (!attr_is_int && !attr_name.empty() && is_val(a) && is_val(b)) {
+                        val_a = a->at(attr_name);
+                        val_b = b->at(attr_name);
+                    } else {
+                        throw raised_exception("sort: unsupported object attribute comparison");
                     }
-                    val_a = attr.empty() ? a : a->at(attr);
-                    val_b = attr.empty() ? b : b->at(attr);
-                }
-                if (reverse) {
-                    return value_compare(val_a, val_b, value_compare_op::gt);
-                } else {
-                    return !value_compare(val_a, val_b, value_compare_op::gt);
                 }
+                return value_compare(val_a, val_b, reverse ? value_compare_op::gt : value_compare_op::lt);
             });
             return mk_val(arr);
         }},
@@ -888,6 +910,11 @@ const func_builtins & value_array_t::get_builtins() const {
 
 const func_builtins & value_object_t::get_builtins() const {
+    if (!has_builtins) {
+        static const func_builtins no_builtins = {};
+        return no_builtins;
+    }
+
     static const func_builtins builtins = {
         // {"default", default_value}, // cause issue with gpt-oss
         {"get", [](const func_args & args) -> value {
@@ -902,18 +929,13 @@ const func_builtins & value_object_t::get_builtins() const {
             if (args.count() == 3) {
                 default_val = args.get_pos(2);
             }
-            const auto & obj = args.get_pos(0)->as_object();
+            const value obj = args.get_pos(0);
             std::string key = args.get_pos(1)->as_string().str();
-            auto it = obj.find(key);
-            if (it != obj.end()) {
-                return it->second;
-            } else {
-                return default_val;
-            }
+            return obj->at(key, default_val);
         }},
         {"keys", [](const func_args & args) -> value {
             args.ensure_vals();
-            const auto & obj = args.get_pos(0)->as_object();
+            const auto & obj = args.get_pos(0)->as_ordered_object();
             auto result = mk_val();
             for (const auto & pair : obj) {
                 result->push_back(mk_val(pair.first));
@@ -922,7 +944,7 @@ const func_builtins & value_object_t::get_builtins() const {
         }},
         {"values", [](const func_args & args) -> value {
             args.ensure_vals();
-            const auto & obj = args.get_pos(0)->as_object();
+            const auto & obj = args.get_pos(0)->as_ordered_object();
             auto result = mk_val();
             for (const auto & pair : obj) {
                 result->push_back(pair.second);
@@ -931,7 +953,7 @@ const func_builtins & value_object_t::get_builtins() const {
         }},
         {"items", [](const func_args & args) -> value {
             args.ensure_vals();
-            const auto & obj = args.get_pos(0)->as_object();
+            const auto & obj = args.get_pos(0)->as_ordered_object();
             auto result = mk_val();
             for (const auto & pair : obj) {
                 auto item = mk_val();
@@ -945,7 +967,7 @@ const func_builtins & value_object_t::get_builtins() const {
         {"string", tojson},
         {"length", [](const func_args & args) -> value {
             args.ensure_vals();
-            const auto & obj = args.get_pos(0)->as_object();
+            const auto & obj = args.get_pos(0)->as_ordered_object();
             return mk_val(static_cast<int64_t>(obj.size()));
         }},
         {"tojson", [](const func_args & args) -> value {
@@ -958,21 +980,18 @@ const func_builtins & value_object_t::get_builtins() const {
             value val_case = args.get_kwarg_or_pos("case_sensitive", 1);
             value val_by = args.get_kwarg_or_pos("by", 2);
             value val_reverse = args.get_kwarg_or_pos("reverse", 3);
-            // FIXME: sorting is case sensitive
+            // FIXME: sorting is currently always case sensitive
             //const bool case_sensitive = val_case->as_bool(); // undefined == false
             const bool reverse = val_reverse->as_bool(); // undefined == false
-            if (!val_by->is_undefined()) {
-                throw not_implemented_exception("dictsort by key not implemented");
-            }
-            if (reverse) {
-                throw not_implemented_exception("dictsort reverse not implemented");
-            }
-            value_t::map obj = val_input->val_obj; // copy
-            std::sort(obj.ordered.begin(), obj.ordered.end(), [&](const auto & a, const auto & b) {
-                return a.first < b.first;
+            const bool by_value = is_val(val_by) && val_by->as_string().str() == "value" ? true : false;
+            auto result = mk_val(val_input); // copy
+            std::sort(result->val_obj.ordered.begin(), result->val_obj.ordered.end(), [&](const auto & a, const auto & b) {
+                if (by_value) {
+                    return value_compare(a.second, b.second, reverse ? value_compare_op::gt : value_compare_op::lt);
+                } else {
+                    return reverse ? a.first > b.first : a.first < b.first;
+                }
             });
-            auto result = mk_val();
-            result->val_obj = std::move(obj);
             return result;
         }},
         {"join", [](const func_args &) -> value {
@@ -1169,7 +1188,7 @@ static void value_to_json_internal(std::ostringstream & oss, const value & val,
         }
         oss << "]";
     } else if (is_val(val)) {
-        const auto & obj = val->val_obj.ordered; // IMPORTANT: need to keep exact order
+        const auto & obj = val->as_ordered_object(); // IMPORTANT: need to keep exact order
         oss << "{";
         if (!obj.empty()) {
             oss << newline();
diff --git a/common/jinja/value.h b/common/jinja/value.h
index 05e7d1e41..4e916919b 100644
--- a/common/jinja/value.h
+++ b/common/jinja/value.h
@@ -146,7 +146,7 @@ struct value_t {
     virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); }
     virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); }
     virtual const std::vector<value> & as_array() const { throw std::runtime_error(type() + " is not an array value"); }
-    virtual const std::map<std::string, value> & as_object() const { throw std::runtime_error(type() + " is not an object value"); }
+    virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); }
     virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); }
     virtual bool is_none() const { return false; }
     virtual bool is_undefined() const { return false; }
     virtual const func_builtins & get_builtins() const {
         throw std::runtime_error("No builtins available for type " + type());
     }
+    virtual bool has_key(const std::string & key) {
+        return val_obj.unordered.find(key) != val_obj.unordered.end();
+    }
     virtual value & at(const std::string & key, value & default_val) {
         auto it = val_obj.unordered.find(key);
         if (it == val_obj.unordered.end()) {
             return default_val;
         }
@@ -168,8 +171,20 @@ struct value_t {
         }
         return val_obj.unordered.at(key);
     }
-    virtual value & at(size_t index) {
-        if (index >= val_arr.size()) {
+    virtual value & at(int64_t index, value & default_val) {
+        if (index < 0) {
+            index += val_arr.size();
+        }
+        if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
+            return default_val;
+        }
+        return val_arr[index];
+    }
+    virtual value & at(int64_t index) {
+        if (index < 0) {
+            index += val_arr.size();
+        }
+        if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
             throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
         }
         return val_arr[index];
     }
@@ -286,6 +301,7 @@
 using value_array = std::shared_ptr<value_array_t>;
 
 struct value_object_t : public value_t {
+    bool has_builtins = true; // context and loop objects do not have builtins
     value_object_t() = default;
     value_object_t(value & v) {
         val_obj = v->val_obj;
     }
     value_object_t(const std::map<std::string, value> & obj) {
         for (const auto & pair : obj) {
             val_obj.insert(pair.first, pair.second);
         }
     }
+    value_object_t(const std::vector<std::pair<std::string, value>> & obj) {
+        for (const auto & pair : obj) {
+            val_obj.insert(pair.first, pair.second);
+        }
+    }
     void insert(const std::string & key, const value & val) {
         val_obj.insert(key, val);
     }
     virtual std::string type() const override { return "Object"; }
-    virtual const std::map<std::string, value> & as_object() const override { return val_obj.unordered; }
+    virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const override { return val_obj.ordered; }
     virtual bool as_bool() const override { return !val_obj.unordered.empty(); }
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 303278397..387e2fe42 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -7,10 +7,9 @@
 #include "unary-ops.h"
 #include "vec.h"
 
-#include
 #include
+#include
 #include
-#include
 
 // ggml_compute_forward_dup
 
@@ -7110,12 +7109,13 @@ void ggml_compute_forward_conv_2d_dw(
     }
 }
 
-// ggml_compute_forward_pool_1d_sk_p0
-
-static void ggml_compute_forward_pool_1d_sk_p0(
+// ggml_compute_forward_pool_1d_ksp
+static void ggml_compute_forward_pool_1d_ksp(
         const ggml_compute_params * params,
         const ggml_op_pool op,
         const int k,
+        const int s,
+        const int p,
         ggml_tensor * dst) {
 
     const ggml_tensor * src = dst->src[0];
@@ -7126,39 +7126,56 @@ static void ggml_compute_forward_pool_1d_ksp(
         return;
     }
 
-    const char * cdata = (const char *)src->data;
-    const char * const data_end = cdata + ggml_nbytes(src);
-    float * drow = (float *)dst->data;
+    const int64_t IW = src->ne[0];
+    const int64_t OW = dst->ne[0];
 
-    const int64_t rs = dst->ne[0];
+    const int64_t nr = ggml_nrows(src);
 
-    while (cdata < data_end) {
-        const void * srow = (const void *)cdata;
-        int j = 0;
-        for (int64_t i = 0; i < rs; ++i) {
+    for (int64_t ir = 0; ir < nr; ++ir) {
+        const char * srow_bytes = (const char *) src->data + ir * src->nb[1];
+        float * drow = (float *) ((char *) dst->data + ir * dst->nb[1]);
+
+        for (int64_t ow = 0; ow < OW; ++ow) {
+            float res = 0;
             switch (op) {
-                case GGML_OP_POOL_AVG: drow[i] = 0;        break;
-                case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
+                case GGML_OP_POOL_AVG: res = 0.0f;     break;
+                case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
                 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
             }
+
+            int count = 0;
+            const int base = (int) ow * s - p;
+
             for (int ki = 0; ki < k; ++ki) {
-                const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
-                switch (op) {
-                    case GGML_OP_POOL_AVG:                         drow[i] += srow_j; break;
-                    case GGML_OP_POOL_MAX: if (srow_j > drow[i])   drow[i]  = srow_j; break;
-                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
+                const int j = base + ki;
+                if (j < 0 || j >= (int) IW) {
+                    continue;
                 }
-                ++j;
+
+                float v;
+                if (src->type == GGML_TYPE_F32) {
+                    v = ((const float *) srow_bytes)[j];
+                } else {
+                    v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);
+                }
+
+                switch (op) {
+                    case GGML_OP_POOL_AVG: res += v;               break;
+                    case GGML_OP_POOL_MAX: res = std::max(v, res); break;
+                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
+                }
+
+                ++count;
             }
+
             switch (op) {
-                case GGML_OP_POOL_AVG: drow[i] /= k; break;
-                case GGML_OP_POOL_MAX:               break;
+                case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;
+                case GGML_OP_POOL_MAX:                                           break;
                 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
             }
-        }
 
-        cdata += src->nb[1];
-        drow  += rs;
+            drow[ow] = res;
+        }
     }
 }
 
@@ -7173,10 +7190,8 @@ void ggml_compute_forward_pool_1d(
     const int k0 = opts[1];
     const int s0 = opts[2];
     const int p0 = opts[3];
-    GGML_ASSERT(p0 == 0); // padding not supported
-    GGML_ASSERT(k0 == s0); // only s = k supported
 
-    ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
+    ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);
 }
 
 // ggml_compute_forward_pool_2d
@@ -7194,6 +7209,7 @@ void ggml_compute_forward_pool_2d(
     }
 
     const int32_t * opts = (const int32_t *)dst->op_params;
+
     ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
     const int k0 = opts[1];
     const int k1 = opts[2];
@@ -7217,11 +7233,13 @@ void ggml_compute_forward_pool_2d(
     while (cdata < data_end) {
         for (int oy = 0; oy < py; ++oy) {
             float * const drow = dplane + oy * px;
+            float * const out = drow;
+
             for (int ox = 0; ox < px; ++ox) {
-                float * const out = drow + ox;
+                float res = 0;
                 switch (op) {
-                    case GGML_OP_POOL_AVG: *out = 0;        break;
-                    case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
+                    case GGML_OP_POOL_AVG: res = 0;        break;
+                    case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
                     case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
                 }
 
                 const int iy = offset1 + oy * s1;
 
                 for (int ky = 0; ky < k1; ++ky) {
-                    if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
+                    if (iy + ky < 0 || iy + ky >= src->ne[1]) {
+                        continue;
+                    }
+
                     const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
                     for (int kx = 0; kx < k0; ++kx) {
                         int j = ix + kx;
-                        if (j < 0 || j >= src->ne[0]) continue;
+                        if (j < 0 || j >= src->ne[0]) {
+                            continue;
+                        }
+
                         const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
                         switch (op) {
-                            case GGML_OP_POOL_AVG:                    *out += srow_j; break;
-                            case GGML_OP_POOL_MAX: if (srow_j > *out) *out  = srow_j; break;
+                            case GGML_OP_POOL_AVG: res += srow_j;               break;
+                            case GGML_OP_POOL_MAX: res = std::max(srow_j, res); break;
                             case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
                         }
                     }
                 }
                 switch (op) {
-                    case GGML_OP_POOL_AVG: *out /= ka; break;
-                    case GGML_OP_POOL_MAX:             break;
+                    case GGML_OP_POOL_AVG: res /= ka; break;
+                    case GGML_OP_POOL_MAX:            break;
                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
                }
+
+                out[ox] = res;
             }
         }
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index b0734797f..04c6137c5 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -94,6 +94,31 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_l
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
+    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
+
+    const char * pool_str = "undefined";
+    switch (op_pool) {
+        case GGML_OP_POOL_AVG: pool_str = "avg"; break;
+        case GGML_OP_POOL_MAX: pool_str = "max"; break;
+        default: GGML_ASSERT(false && "not implemented");
+    };
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, sizeof(base), "kernel_pool_1d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
+    snprintf(name, sizeof(name), "%s", base);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
     GGML_ASSERT(ggml_is_contiguous(op->src[0]));
     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
index 9c3b00148..3d01c56fb 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.h
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
@@ -104,6 +104,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base     (ggml_metal_library_t lib, enum ggml_op op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy      (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d  (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d  (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index 9d75ebaa4..ef45e7e33 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -1050,10 +1050,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 op->src[1]->type == GGML_TYPE_F32 &&
                 op->type == GGML_TYPE_F32 &&
                 (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
-        case GGML_OP_POOL_1D:
-            return false;
         case GGML_OP_UPSCALE:
             return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
+        case GGML_OP_POOL_1D:
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_POOL_2D:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_PAD:
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index d3b0e732e..59d88b01a 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -928,6 +928,15 @@ typedef struct {
     int64_t np;
 } ggml_metal_kargs_pool_2d;
 
+typedef struct {
+    int32_t k0;
+    int32_t s0;
+    int32_t p0;
+    int64_t IW;
+    int64_t OW;
+    int64_t np;
+} ggml_metal_kargs_pool_1d;
+
 typedef struct {
     int64_t ne00;
     uint64_t nb01;
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index a50b12b6f..680ad794d 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -432,6 +432,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_cpy(ctx, idx);
             } break;
+        case GGML_OP_POOL_1D:
+            {
+                n_fuse = ggml_metal_op_pool_1d(ctx, idx);
+            } break;
         case GGML_OP_POOL_2D:
             {
                 n_fuse = ggml_metal_op_pool_2d(ctx, idx);
@@ -1622,6 +1626,54 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    const int32_t * opts = op->op_params;
+    ggml_op_pool op_pool = (ggml_op_pool) opts[0];
+
+    const int32_t k0 = opts[1];
+    const int32_t s0 = opts[2];
+    const int32_t p0 = opts[3];
+
+    const int64_t IW = op->src[0]->ne[0];
+    const int64_t OW = op->ne[0];
+
+    const int64_t np = ggml_nelements(op);
+
+    ggml_metal_kargs_pool_1d args_pool_1d = {
+        /* .k0 = */ k0,
+        /* .s0 = */ s0,
+        /* .p0 = */ p0,
+        /* .IW = */ IW,
+        /* .OW = */ OW,
+        /* .np = */ np
+    };
+
+    auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);
+
+    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
+    const int ntg = (np + nth - 1) / nth;
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args_pool_1d, sizeof(args_pool_1d), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
index c1025d356..10686a334 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -61,6 +61,7 @@ int ggml_metal_op_ssm_conv  (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_scan  (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rwkv      (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_cpy       (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_pool_1d   (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pool_2d   (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_mul_mat   (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 16d17d26a..a4e1cafe5 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -9869,6 +9869,74 @@ kernel void kernel_pool_2d_avg_f32(
     o_ptr[cur_oh * args.OW + cur_ow] = res;
 }
 
+kernel void kernel_pool_1d_max_f32(
+    constant ggml_metal_kargs_pool_1d & args,
+    device const float * src,
+    device       float * dst,
+    uint gid [[thread_position_in_grid]]
+) {
+
+    if (gid >= args.np) {
+        return;
+    }
+
+    const int ow  = (int)gid % args.OW;
+    const int row = (int)gid / args.OW;
+
+    const int base = ow * args.s0 - args.p0;
+
+    float acc = -INFINITY;
+
+    const int src_off = row * args.IW;
+    const int dst_off = row * args.OW;
+
+    for (int ki = 0; ki < args.k0; ++ki) {
+        int j = base + ki;
+        if (j < 0 || j >= args.IW) {
+            continue;
+        }
+        float v = src[src_off + j];
+        acc = max(acc, v);
+    }
+
+    dst[dst_off + ow] = acc;
+}
+
+kernel void kernel_pool_1d_avg_f32(
+    constant ggml_metal_kargs_pool_1d & args,
+    device const float * src,
+    device       float * dst,
+    uint gid [[thread_position_in_grid]]
+) {
+
+    if (gid >= args.np) {
+        return;
+    }
+
+    const int ow  = (int)gid % args.OW;
+    const int row = (int)gid / args.OW;
+
+    const int base = ow * args.s0 - args.p0;
+
+    float acc = 0.0f;
+    int   cnt = 0;
+
+    const int src_off = row * args.IW;
+    const int dst_off = row * args.OW;
+
+    for (int ki = 0; ki < args.k0; ++ki) {
+        const int j = base + ki;
+        if (j < 0 || j >= args.IW) {
+            continue;
+        }
+        acc += src[src_off + j];
+        cnt += 1;
+    }
+
+    dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f;
+}
+
 kernel void kernel_opt_step_adamw_f32(
     constant ggml_metal_kargs_opt_step_adamw & args,
     device float * x,
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 4cf3e03bf..275311ca7 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -4854,6 +4854,8 @@ struct ggml_tensor * ggml_pool_1d(
         a->ne[2],
         a->ne[3],
     };
+
+    GGML_ASSERT(ne[0] > 0);
 
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     int32_t params[] = { op, k0, s0, p0 };
@@ -4884,6 +4886,9 @@ struct ggml_tensor * ggml_pool_2d(
         a->ne[2],
         a->ne[3],
     };
+
+    GGML_ASSERT(ne[0] > 0);
+    GGML_ASSERT(ne[1] > 0);
 
     result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
diff --git a/koboldcpp.py b/koboldcpp.py
index 675e5f000..b8bde6da5 100755
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -67,7 +67,7 @@ dry_seq_break_max = 128
 extra_images_max = 4 # for kontext/qwen img
 
 # global vars
-KcppVersion = "1.106.2"
+KcppVersion = "1.107"
 showdebug = True
 kcpp_instance = None #global running instance
 global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False, "restart_override_config_target":""}
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index c847ef91b..5f1df995f 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -200,42 +200,6 @@ uint32_t llama_hparams::n_layer_kv() const {
     return res;
 }
 
-bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
-    assert(p0 >= 0 && p1 >= 0);
-
-    switch (swa_type) {
-        case LLAMA_SWA_TYPE_NONE:
-            {
-            } break;
-        case LLAMA_SWA_TYPE_STANDARD:
-            {
-                if (p1 - p0 >= (int32_t) n_swa) {
-                    return true;
-                }
-            } break;
-        case LLAMA_SWA_TYPE_CHUNKED:
-            {
-                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
-
-                if (p0 < pos_chunk_start) {
-                    return true;
-                }
-            } break;
-        case LLAMA_SWA_TYPE_SYMMETRIC:
-            {
-                const int32_t half_n_swa = (int32_t) n_swa / 2;
-                const int32_t pos_diff = p1 - p0;
-
-                // Mask if outside the symmetric window
-                if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
-                    return true;
-                }
-            } break;
-    }
-
-    return false;
-}
-
 bool llama_hparams::use_mrope() const {
     return rope_sections[0] > 0 && rope_sections[1] > 0;
 }
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 7ae3ec292..2bf866552 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -3,6 +3,7 @@
 #include "llama.h"
 
 #include <array>
+#include <cassert>
 
 // bump if necessary
 #define LLAMA_MAX_LAYERS  512
@@ -274,9 +275,45 @@ struct llama_hparams {
     uint32_t n_layer_kv() const;
 
     // note that this function uses different SWA parameters from those in the hparams
+    // note: inlined on purpose for performance reasons
     // TODO: think of a better place for this function
     // TODO: pack the SWA params in a struct?
-    static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
+    static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
+        assert(p0 >= 0 && p1 >= 0);
+
+        switch (swa_type) {
+            case LLAMA_SWA_TYPE_NONE:
+                {
+                } break;
+            case LLAMA_SWA_TYPE_STANDARD:
+                {
+                    if (p1 - p0 >= (int32_t) n_swa) {
+                        return true;
+                    }
+                } break;
+            case LLAMA_SWA_TYPE_CHUNKED:
+                {
+                    const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
+
+                    if (p0 < pos_chunk_start) {
+                        return true;
+                    }
+                } break;
+            case LLAMA_SWA_TYPE_SYMMETRIC:
+                {
+                    const int32_t half_n_swa = (int32_t) n_swa / 2;
+                    const int32_t pos_diff = p1 - p0;
+
+                    // Mask if outside the symmetric window
+                    if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
+                        return true;
+                    }
+                } break;
+        }
+
+        return false;
+    }
+
     bool use_mrope() const;
 };
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 1201fd08f..fe409d08c 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -852,7 +852,7 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
                 const llama_seq_id seq_id_cell = cells.seq_get(idx);
 
                 // SWA mask
-                if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
+                if (llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
                     can_use = true;
                 }
             }
@@ -1237,6 +1237,197 @@ void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
     }
 }
 
+struct args_set_input_kq_mask {
+    const llama_hparams & hparams;
+    const llama_ubatch  * ubatch;
+
+    const std::vector & v_cells;
+    const std::vector & seq_to_stream;
+
+    uint32_t       n_swa;
+    llama_swa_type swa_type;
+
+    int64_t n_kv;
+    int64_t n_stream;
+    int64_t n_tps;
+};
+
+template <bool causal, bool swa, bool is_2d, bool alibi>
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
+    //const auto & hparams = args.hparams;
+    const auto & ubatch = args.ubatch;
+
+    const auto & v_cells       = args.v_cells;
+    const auto & seq_to_stream = args.seq_to_stream;
+
+    const uint32_t       n_swa    = args.n_swa;
+    const llama_swa_type swa_type = args.swa_type;
+
+    const int64_t n_kv     = args.n_kv;
+    const int64_t n_stream = args.n_stream;
+    const int64_t n_tps    = args.n_tps;
+
+    // the min position in the batch for each sequence
+    llama_pos seq_pos_min[LLAMA_MAX_SEQ];
+    std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX);
+
+    for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
+        const llama_seq_id seq_id = ubatch->seq_id[i][0];
+
+        seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]);
+    }
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        // bookkeeping of the KQ mask cells that could change for other tokens of the same sequence
+        std::unordered_map<llama_seq_id, uint32_t> seq_srct;
+        std::unordered_map<llama_seq_id, std::vector<uint32_t>> seq_idxs;
+
+        for (uint32_t ii = 0; ii < n_tps; ++ii) {
+            const uint32_t i = s*n_tps + ii;
+
+            const llama_seq_id seq_id = ubatch->seq_id[i][0];
+
+            const auto & cells = v_cells.at(seq_to_stream[seq_id]);
+
+            llama_pos p0 = -1;
+
+            const llama_pos p1 = ubatch->pos[i];
+
+            // for M-RoPE
+            const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
+            const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens]   : 0;
+
+            const uint64_t idst = n_kv*i;
+
+            // for tokens of the same sequence, the mask is mostly the same, so we can reuse it
+            // the only cells that could change are the ones that are with similar positions as the
+            // ones in the batch (i.e. due to causal masking, SWA, etc.)
+            // keep track of those cells and shortcut the loop to save time
+            // note: this optimization is not compatible with Alibi position encoding
+            // ref: https://github.com/ggml-org/llama.cpp/pull/18842
+            bool prev = false;
+
+            auto & idxs = seq_idxs[seq_id];
+
+            if (!alibi) {
+                if (seq_srct.find(seq_id) != seq_srct.end()) {
+                    const uint32_t srct = seq_srct[seq_id];
+
+                    const uint64_t idst_prev = n_kv*srct;
+
+                    std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst);
+
+                    prev = true;
+                } else {
+                    idxs.clear();
+                    idxs.reserve(ubatch->n_tokens + n_swa + 32);
+
+                    seq_srct[seq_id] = i;
+                }
+            }
+
+            for (uint32_t jj = 0; jj < n_kv; ++jj) {
+                uint32_t j = jj;
+
+                // we have an existing mask for this sequence -> update just seq_idxs
+                if (!alibi) {
+                    if (prev) {
+                        if (jj >= idxs.size()) {
+                            break;
+                        }
+
+                        j = idxs[jj];
+                    }
+                }
+
+                if (cells.is_empty(j)) {
+                    goto skip;
+                }
+
+                // mask the token if not the same sequence
+                if (!cells.seq_has(j, seq_id)) {
+                    goto skip;
+                }
+
+                p0 = cells.pos_get(j);
+
+                if (!alibi) {
+                    if (!prev) {
+                        // record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32
+                        if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) {
+                            idxs.push_back(j);
+                        }
+                    }
+                }
+
+                if (causal) {
+                    // mask future tokens
+                    if (p0 > p1) {
+                        goto skip;
+                    }
+
+                    // M-RoPE causal mask
+                    if (is_2d) {
+                        if (p0 == p1) {
+                            const auto & p0_ext = cells.ext_get(j);
+
+                            if (p0_ext.is_2d_gt(p1_x, p1_y)) {
+                                goto skip;
+                            }
+                        }
+                    }
+                }
+
+                // apply SWA if any
+                if (swa) {
+                    if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
+                        goto skip;
+                    }
+                }
+
+                if (alibi) {
+                    data[idst + j] = -std::abs(p0 - p1);
+                } else {
+                    data[idst + j] = 0.0f;
+                }
+
+                continue;
+skip:
+                data[idst + j] = -INFINITY;
+            }
+        }
+    }
+}
+
+template <bool causal, bool swa, bool is_2d>
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
+    const bool alibi = args.hparams.use_alibi;
+    if (alibi) {
+        set_input_kq_mask_impl<causal, swa, is_2d, true >(args, data);
+    } else {
+        set_input_kq_mask_impl<causal, swa, is_2d, false>(args, data);
+    }
+}
+
+template <bool causal, bool swa>
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
+    const bool is_2d = args.ubatch->is_pos_2d();
+    if (is_2d) {
+        set_input_kq_mask_impl<causal, swa, true >(args, data);
+    } else {
+        set_input_kq_mask_impl<causal, swa, false>(args, data);
+    }
+}
+
+template <bool causal>
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
+    const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE;
+    if (swa) {
+        set_input_kq_mask_impl<causal, true >(args, data);
+    } else {
+        set_input_kq_mask_impl<causal, false>(args, data);
+    }
+}
+
 void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const uint32_t n_tokens = ubatch->n_tokens;
@@ -1251,74 +1442,29 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
     // n_tps == n_tokens_per_stream
     const int64_t n_tps = n_tokens/n_stream;
 
-    std::fill(data, data + ggml_nelements(dst), -INFINITY);
+    //const int64_t t_start = ggml_time_us();
 
-    // Use only the previous KV cells of the correct sequence for each token of the ubatch.
-    // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-    // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
-    //   Causal mask:
-    //      xxx-------
-    //      xxxx------
-    //      xxxxx-----
-    //   Non-causal mask:
-    //      xxxxx-----
-    //      xxxxx-----
-    //      xxxxx-----
-    // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-    // TODO: optimize this section
-    for (uint32_t h = 0; h < 1; ++h) {
-        for (uint32_t s = 0; s < n_stream; ++s) {
-            for (uint32_t ii = 0; ii < n_tps; ++ii) {
-                const uint32_t i = s*n_tps + ii;
+    const args_set_input_kq_mask args = {
+        /*.hparams       =*/ hparams,
+        /*.ubatch        =*/ ubatch,
+        /*.v_cells       =*/ v_cells,
+        /*.seq_to_stream =*/ seq_to_stream,
+        /*.n_swa         =*/ n_swa,
+        /*.swa_type      =*/ swa_type,
+        /*.n_kv          =*/ n_kv,
+        /*.n_stream      =*/ n_stream,
+        /*.n_tps         =*/ n_tps,
+    };
 
-                const llama_seq_id seq_id = ubatch->seq_id[i][0];
-
-                const auto & cells = v_cells[seq_to_stream[seq_id]];
-
-                const llama_pos p1 = ubatch->pos[i];
-
-                // for M-RoPE
-                const bool is_2d = ubatch->is_pos_2d();
-                const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
-                const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens]   : 0;
-
-                const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii);
-
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    if (cells.is_empty(j)) {
-                        continue;
-                    }
-
-                    // mask the token if not the same sequence
-                    if (!cells.seq_has(j, seq_id)) {
-                        continue;
-                    }
-
-                    const llama_pos p0 = cells.pos_get(j);
-
-                    // mask future tokens
-                    if (causal_attn && p0 > p1) {
-                        continue;
-                    }
-
-                    // M-RoPE causal mask
-                    if (causal_attn && is_2d && p0 == p1) {
-                        const auto & p0_ext = cells.ext_get(j);
-                        if (p0_ext.is_2d_gt(p1_x, p1_y)) {
-                            continue;
-                        }
-                    }
-
-                    // apply SWA if any
-                    if (is_masked_swa(p0, p1)) {
-                        continue;
-                    }
-
-                    data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
-                }
-            }
-        }
+    if (causal_attn) {
+        set_input_kq_mask_impl<true >(args, data);
+    } else {
+        set_input_kq_mask_impl<false>(args, data);
     }
+
+    //const int64_t t_end = ggml_time_us();
+
+    //LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0);
 }
 
 void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
@@ -1483,10 +1629,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
     return gf;
 }
 
-bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
-    return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
-}
-
 void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
     GGML_UNUSED(flags);
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index 0c4ed6484..e194bf3e2 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -257,8 +257,6 @@ private:
     size_t size_k_bytes() const;
     size_t size_v_bytes() const;
 
-    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
-
     ggml_tensor * build_rope_shift(
         const llama_cparams & cparams,
               ggml_context * ctx,
diff --git a/tests/test-jinja.cpp b/tests/test-jinja.cpp
index 7adb302ff..13818381e 100644
--- a/tests/test-jinja.cpp
+++ b/tests/test-jinja.cpp
@@ -4,6 +4,7 @@
 
 #include
 #include
+#include
 
 #include "jinja/runtime.h"
 #include "jinja/parser.h"
@@ -31,12 +32,24 @@ static void test_array_methods(testing & t);
 static void test_object_methods(testing & t);
 static void test_fuzzing(testing & t);
 
+static bool g_python_mode = false;
+
 int main(int argc, char *argv[]) {
     testing t(std::cout);
     t.verbose = true;
 
-    if (argc >= 2) {
-        t.set_filter(argv[1]);
+    // usage: test-jinja [-py] [filter_regex]
+    //   -py : enable python mode (use python jinja2 for rendering expected output)
+    //         only use this for cross-checking, not for correctness
+    // note: the implementation of this flag is basic, only intended to be used by maintainers
+    for (int i = 1; i < argc; i++) {
+        std::string arg = argv[i];
+        if (arg == "-py") {
+            g_python_mode = true;
+        } else {
+            t.set_filter(arg);
+        }
     }
 
     t.test("whitespace control", test_whitespace_control);
@@ -53,7 +66,9 @@ int main(int argc, char *argv[]) {
     t.test("string methods", test_string_methods);
     t.test("array methods", test_array_methods);
     t.test("object methods", test_object_methods);
-    t.test("fuzzing", test_fuzzing);
+    if (!g_python_mode) {
+        t.test("fuzzing", test_fuzzing);
+    }
 
     return t.summary();
 }
@@ -247,6 +262,12 @@ static void test_expressions(testing & t) {
         "Bob"
     );
 
+    test_template(t, "negative float (not dot notation)",
+        "{{ -1.0 }}",
+        json::object(),
+        "-1.0"
+    );
+
     test_template(t, "bracket notation",
         "{{ user['name'] }}",
         {{"user", {{"name", "Bob"}}}},
@@ -383,6 +404,32 @@ static void test_filters(testing & t) {
         "123"
     );
 
+    test_template(t, "sort reverse",
+        "{% for i in items|sort(true) %}{{ i }}{% endfor %}",
+        {{"items", json::array({3, 1, 2})}},
+        "321"
+    );
+
+    test_template(t, "sort with attribute",
+        "{{ items|sort(attribute='name')|join(attribute='age') }}",
+        {{"items", json::array({
+            json({{"name", "c"}, {"age", 3}}),
+            json({{"name", "a"}, {"age", 1}}),
+            json({{"name", "b"}, {"age", 2}}),
+        })}},
+        "123"
+    );
+
+    test_template(t, "sort with numeric attribute",
+        "{{ items|sort(attribute=0)|join(attribute=1) }}",
+        {{"items", json::array({
+            json::array({3, "z"}),
+            json::array({1, "x"}),
+            json::array({2, "y"}),
+        })}},
+        "xyz"
+    );
+
     test_template(t, "join",
         "{{ items|join(', ') }}",
         {{"items", json::array({"a", "b", "c"})}},
@@ -534,6 +581,66 @@ static void test_literals(testing & t) {
         json::object(),
         "1"
     );
+
+    test_template(t, "integer|abs",
+        "{{ -42 | abs }}",
+        json::object(),
+        "42"
+    );
+
+    test_template(t, "integer|float",
+        "{{ 42 | float }}",
+        json::object(),
+        "42.0"
+    );
+
+    test_template(t, "integer|tojson",
+        "{{ 42 | tojson }}",
+        json::object(),
+        "42"
+    );
+
+    test_template(t, "float|abs",
+        "{{ -3.14 | abs }}",
+        json::object(),
+        "3.14"
+    );
+
+    test_template(t, "float|int",
+        "{{ 3.14 | int }}",
+        json::object(),
+        "3"
+    );
+
+    test_template(t, "float|tojson",
+        "{{ 3.14 | tojson }}",
+        json::object(),
+        "3.14"
+    );
+
+    test_template(t, "string|tojson",
+        "{{ 'hello' | tojson }}",
+        json::object(),
+        "\"hello\""
+    );
+
+    test_template(t, "boolean|int",
+        "{{ true | int }}",
+        json::object(),
+        "1"
+    );
+
+    test_template(t, "boolean|float",
+        "{{ true | float }}",
+        json::object(),
+        "1.0"
+    );
+
+    test_template(t, "boolean|tojson",
+        "{{ true | tojson }}",
+        json::object(),
+        "true"
+    );
 }
 
 static void test_comments(testing & t) {
@@ -934,7 +1041,17 @@ static void test_array_methods(testing & t) {
     );
 
     test_template(t, "array|join attribute",
-        "{{ arr|join(attribute=0) }}",
+        "{{ arr|join(attribute='age') }}",
+        {{"arr", json::array({
+            json({{"name", "a"}, {"age", 1}}),
+            json({{"name", "b"}, {"age", 2}}),
+            json({{"name", "c"}, {"age", 3}}),
+        })}},
+        "123"
+    );
+
+    test_template(t, "array|join numeric attribute",
+        "{{ arr|join(attribute=-1) }}",
         {{"arr", json::array({json::array({1}), json::array({2}), json::array({3})})}},
         "123"
     );
@@ -957,8 +1074,8 @@ static void test_array_methods(testing & t) {
         "a,b,c,d"
     );
 
-    test_template(t, "array.map() with attribute",
-        "{% for v in arr.map('age') %}{{ v }} {% endfor %}",
+    test_template(t, "array|map with attribute",
+        "{% for v in arr|map(attribute='age') %}{{ v }} {% endfor %}",
         {{"arr", json::array({
             json({{"name", "a"}, {"age", 1}}),
             json({{"name", "b"}, {"age", 2}}),
             json({{"name", "c"}, {"age", 3}}),
         })}},
         "1 2 3 "
     );
 
-    test_template(t, "array.map() with numeric attribute",
-        "{% for v in arr.map(0) %}{{ v }} {% endfor %}",
+    test_template(t, "array|map with attribute default",
+        "{% for v in arr|map(attribute='age', default=3) %}{{ v }} {% endfor %}",
+        {{"arr", json::array({
+            json({{"name", "a"}, {"age", 1}}),
+            json({{"name", "b"}, {"age", 2}}),
+            json({{"name", "c"}}),
+        })}},
+        "1 2 3 "
+    );
+
+    test_template(t, "array|map without attribute default",
+        "{% for v in arr|map(attribute='age') %}{{ v }} {% endfor %}",
+        {{"arr", json::array({
+            json({{"name", "a"}, {"age", 1}}),
+            json({{"name", "b"}, {"age", 2}}),
+            json({{"name", "c"}}),
+        })}},
+        "1 2 "
+    );
+
+    test_template(t, "array|map with numeric attribute",
+        "{% for v in arr|map(attribute=0) %}{{ v }} {% endfor %}",
         {{"arr", json::array({
             json::array({10, "x"}),
             json::array({20, "y"}),
             json::array({30, "z"}),
         })}},
         "10 20 30 "
     );
 
+    test_template(t, "array|map with negative attribute",
+        "{% for v in arr|map(attribute=-1) %}{{ v }} {% endfor %}",
+        {{"arr", json::array({
+            json::array({10, "x"}),
+            json::array({20, "y"}),
+            json::array({30, "z"}),
+        })}},
+        "x y z "
+    );
+
+    test_template(t, "array|map with filter",
+        "{{ arr|map('int')|sum }}",
+        {{"arr", json::array({"1", "2", "3"})}},
+        "6"
+    );
+
     // not used by any chat templates
     // test_template(t, "array.insert()",
     //     "{% set _ = arr.insert(1, 'x') %}{{ arr|join(',') }}",
@@ -1063,9 +1216,21 @@
         {{"obj", {{"items", json::array({1, 2, 3})}}}},
         "{\"items\": [1, 2, 3]}"
     );
+
+    test_template(t, "object attribute and key access",
+        "{{ obj.keys()|join(',') }} vs {{ obj['keys'] }} vs {{ obj.test }}",
+        {{"obj", {{"keys", "value"}, {"test", "attr_value"}}}},
+        "keys,test vs value vs attr_value"
+    );
+
+    test_template(t, "env should not have object methods",
+        "{{ keys is undefined }} {{ obj.keys is defined }}",
+        {{"obj", {{"a", "b"}}}},
+        "True True"
+    );
 }
 
-static void test_template(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
+static void test_template_cpp(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
     t.test(name, [&tmpl, &vars, &expect](testing & t) {
         jinja::lexer lexer;
         auto lexer_res = lexer.tokenize(tmpl);
@@ -1098,6 +1263,99 @@
     });
 }
 
+// keep this in-sync with https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py
+// note: we use SandboxedEnvironment instead of ImmutableSandboxedEnvironment to allow usage of in-place array methods like append() and pop()
+static std::string py_script = R"(
+import jinja2
+import jinja2.ext as jinja2_ext
+import json
+import sys
+from datetime import datetime
+from jinja2.sandbox import SandboxedEnvironment
+
+tmpl = json.loads(sys.argv[1])
+vars_json = json.loads(sys.argv[2])
+
+env = SandboxedEnvironment(
+    trim_blocks=True,
+    lstrip_blocks=True,
+    extensions=[jinja2_ext.loopcontrols],
+)
+
+def raise_exception(message):
+    raise jinja2.exceptions.TemplateError(message)
+
+env.filters["tojson"] = lambda x, ensure_ascii=False, indent=None, separators=None, sort_keys=False: json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
+env.globals["strftime_now"] = lambda format: datetime.now().strftime(format)
+env.globals["raise_exception"] = raise_exception
+
+template = env.from_string(tmpl)
+result = template.render(**vars_json)
+print(result, end='')
+)";
+
+static void test_template_py(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
+    t.test(name, [&tmpl, &vars, &expect](testing & t) {
+        // Prepare arguments
+        std::string tmpl_json = json(tmpl).dump();
+        std::string vars_json = vars.dump();
+
+#ifdef _WIN32
+        const char * python_executable = "python.exe";
+#else
+        const char * python_executable = "python3";
+#endif
+
+        const char * command_line[] = {python_executable, "-c", py_script.c_str(), tmpl_json.c_str(), vars_json.c_str(), NULL};
+
+        struct subprocess_s subprocess;
+        int options = subprocess_option_combined_stdout_stderr
+                    | subprocess_option_no_window
+                    | subprocess_option_inherit_environment
+                    | subprocess_option_search_user_path;
+        int result = subprocess_create(command_line, options, &subprocess);
+
+        if (result != 0) {
+            t.log("Failed to create subprocess, error code: " + std::to_string(result));
+            t.assert_true("subprocess creation", false);
+            return;
+        }
+
+        // Read output
+        std::string output;
+        char buffer[1024];
+        FILE * p_stdout = subprocess_stdout(&subprocess);
+        while (fgets(buffer, sizeof(buffer), p_stdout)) {
+            output += buffer;
+        }
+
+        int process_return;
+        subprocess_join(&subprocess, &process_return);
+        subprocess_destroy(&subprocess);
+
+        if (process_return != 0) {
+            t.log("Python script failed with exit code: " + std::to_string(process_return));
+            t.log("Output: " + output);
+            t.assert_true("python execution", false);
+            return;
+        }
+
+        if (!t.assert_true("Template render mismatch", expect == output)) {
+            t.log("Template: " + json(tmpl).dump());
+            t.log("Expected: " + json(expect).dump());
+            t.log("Python  : " + json(output).dump());
+        }
+    });
+}
+
+static void test_template(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
+    if (g_python_mode) {
+        test_template_py(t, name, tmpl, vars, expect);
+    } else {
+        test_template_cpp(t, name, tmpl, vars, expect);
+    }
+}
+
 //
 // fuzz tests to ensure no crashes occur on malformed inputs
 //