Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2026-05-07 17:22:04 +00:00)

commit 7f618454ff
Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/labeler.yml
#	CODEOWNERS
#	docs/backend/OPENCL.md
#	docs/ops.md
#	docs/ops/CANN.csv
#	docs/ops/WebGPU.csv
#	ggml/src/ggml-blas/CMakeLists.txt
#	ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl
#	ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
#	ggml/src/ggml-webgpu/ggml-webgpu.cpp
#	ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl
#	ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl
#	tests/test-backend-ops.cpp

20 changed files with 879 additions and 235 deletions

@@ -91,6 +91,16 @@ lexer_result lexer::tokenize(const std::string & source) {
         return str;
     };
 
+    auto consume_numeric = [&]() -> std::string {
+        std::string num = consume_while(is_integer);
+        if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) {
+            ++pos; // Consume '.'
+            std::string frac = consume_while(is_integer);
+            num += "." + frac;
+        }
+        return num;
+    };
+
     auto next_pos_is = [&](std::initializer_list<char> chars, size_t n = 1) -> bool {
         if (pos + n >= src.size()) return false;
         for (char c : chars) {

@@ -258,7 +268,7 @@ lexer_result lexer::tokenize(const std::string & source) {
             ++pos; // Consume the operator
 
             // Check for numbers following the unary operator
-            std::string num = consume_while(is_integer);
+            std::string num = consume_numeric();
             std::string value = std::string(1, ch) + num;
             token::type t = num.empty() ? token::unary_operator : token::numeric_literal;
             // JJ_DEBUG("consumed unary operator or numeric literal: '%s'", value.c_str());

@@ -307,12 +317,7 @@ lexer_result lexer::tokenize(const std::string & source) {
         // Numbers
         if (is_integer(ch)) {
             start_pos = pos;
-            std::string num = consume_while(is_integer);
-            if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) {
-                ++pos; // Consume '.'
-                std::string frac = consume_while(is_integer);
-                num += "." + frac;
-            }
+            std::string num = consume_numeric();
             // JJ_DEBUG("consumed numeric literal: '%s'", num.c_str());
             tokens.push_back({token::numeric_literal, num, start_pos});
             continue;
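One subtlety worth noting (my reading of the lambda above, not stated in the patch): the lookahead requires a digit after the '.', so a trailing dot is not folded into the number. For example:

    jinja::lexer lexer;
    auto r1 = lexer.tokenize("{{ 1.5 }}"); // one numeric_literal token: "1.5"
    auto r2 = lexer.tokenize("{{ 1. }}");  // numeric_literal "1", then a separate '.' token
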
@ -268,8 +268,7 @@ value binary_expression::execute_impl(context & ctx) {
|
|||
// String in object
|
||||
if (is_val<value_string>(left_val) && is_val<value_object>(right_val)) {
|
||||
auto key = left_val->as_string().str();
|
||||
auto & obj = right_val->as_object();
|
||||
bool has_key = obj.find(key) != obj.end();
|
||||
bool has_key = right_val->has_key(key);
|
||||
if (op.value == "in") {
|
||||
return mk_val<value_bool>(has_key);
|
||||
} else if (op.value == "not in") {
|
||||
|
|
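The rewritten has_key path is what the template-level in / not in operators compile down to. A minimal check in the style of this repo's test-jinja helpers (a sketch, mirroring the tests further down this diff):

    test_template(t, "string in object",
        "{{ 'a' in obj }} {{ 'z' not in obj }}",
        {{"obj", {{"a", 1}}}},
        "True True"
    );
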
@@ -464,7 +463,7 @@ value for_statement::execute_impl(context & ctx) {
     std::vector<value> items;
     if (is_val<value_object>(iterable_val)) {
         JJ_DEBUG("%s", "For loop over object keys");
-        auto & obj = iterable_val->as_object();
+        auto & obj = iterable_val->as_ordered_object();
         for (auto & p : obj) {
             auto tuple = mk_val<value_array>();
             if (iterable_val->val_obj.is_key_numeric) {

@@ -560,6 +559,7 @@ value for_statement::execute_impl(context & ctx) {
     for (size_t i = 0; i < filtered_items.size(); i++) {
         JJ_DEBUG("For loop iteration %zu/%zu", i + 1, filtered_items.size());
         value_object loop_obj = mk_val<value_object>();
+        loop_obj->has_builtins = false; // loop object has no builtins
         loop_obj->insert("index", mk_val<value_int>(i + 1));
         loop_obj->insert("index0", mk_val<value_int>(i));
         loop_obj->insert("revindex", mk_val<value_int>(filtered_items.size() - i));

@@ -717,6 +717,7 @@ value member_expression::execute_impl(context & ctx) {
 
+    value property;
     if (this->computed) {
         // syntax: obj[expr]
         JJ_DEBUG("Member expression, computing property type %s", this->property->type().c_str());
 
         int64_t arr_size = 0;

@@ -745,10 +746,24 @@ value member_expression::execute_impl(context & ctx) {
             property = this->property->execute(ctx);
         }
     } else {
         // syntax: obj.prop
         if (!is_stmt<identifier>(this->property)) {
-            throw std::runtime_error("Non-computed member property must be an identifier");
+            throw std::runtime_error("Static member property must be an identifier");
         }
         property = mk_val<value_string>(cast_stmt<identifier>(this->property)->val);
+        std::string prop = property->as_string().str();
+        JJ_DEBUG("Member expression, object type %s, static property '%s'", object->type().c_str(), prop.c_str());
+
+        // behavior of jinja2: if obj has prop as a built-in function AND 'prop' as an object key,
+        // then obj.prop returns the built-in function, not the property value,
+        // while obj['prop'] returns the property value.
+        // example: {"obj": {"items": 123}} -> obj.items is the built-in function, obj['items'] is 123
+
+        value val = try_builtin_func(ctx, prop, object, true);
+        if (!is_val<value_undefined>(val)) {
+            return val;
+        }
+        // else, fallthrough to normal property access below
     }
 
     JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str());

@@ -763,11 +778,8 @@ value member_expression::execute_impl(context & ctx) {
         throw std::runtime_error("Cannot access object with non-string: got " + property->type());
     }
     auto key = property->as_string().str();
-    auto & obj = object->as_object();
-    auto it = obj.find(key);
-    if (it != obj.end()) {
-        val = it->second;
-    } else {
+    val = object->at(key, val);
+    if (is_val<value_undefined>(val)) {
+        val = try_builtin_func(ctx, key, object, true);
+    }
     JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str());
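The built-in-versus-key precedence described in the comment above can be pinned down with a template test (a sketch in the test-jinja style, following the comment's own {"obj": {"items": 123}} example):

    test_template(t, "dot vs bracket on a colliding key",
        "{{ obj['items'] }}",   // bracket access -> the stored value
        {{"obj", {{"items", 123}}}},
        "123"
    );
    // obj.items, by contrast, resolves to the built-in items() method.
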
@@ -56,6 +56,7 @@ struct context {
     // src is optional, used for error reporting
     context(std::string src = "") : src(std::make_shared<std::string>(std::move(src))) {
         env = mk_val<value_object>();
+        env->has_builtins = false; // context object has no builtins
         env->insert("true", mk_val<value_bool>(true));
         env->insert("True", mk_val<value_bool>(true));
         env->insert("false", mk_val<value_bool>(false));

@@ -68,7 +69,7 @@ struct context {
 
     context(const context & parent) : context() {
         // inherit variables (for example, when entering a new scope)
-        auto & pvar = parent.env->as_object();
+        auto & pvar = parent.env->as_ordered_object();
         for (const auto & pair : pvar) {
             set_val(pair.first, pair.second);
         }

@@ -265,7 +266,7 @@ struct comment_statement : public statement {
 struct member_expression : public expression {
     statement_ptr object;
     statement_ptr property;
-    bool computed;
+    bool computed; // true if obj[expr] and false if obj.prop
 
     member_expression(statement_ptr && object, statement_ptr && property, bool computed)
         : object(std::move(object)), property(std::move(property)), computed(computed) {
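Both has_builtins = false assignments (the context env above, the loop object earlier) enforce the same rule: these are plain mappings whose keys must never shadow, or be shadowed by, object built-ins such as keys or items. The test added near the end of this diff pins the behavior down:

    test_template(t, "env should not have object methods",
        "{{ keys is undefined }} {{ obj.keys is defined }}",
        {{"obj", {{"a", "b"}}}},
        "True True"
    );
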
@@ -698,6 +698,7 @@ const func_builtins & value_bool_t::get_builtins() const {
             bool val = args.get_pos(0)->as_bool();
             return mk_val<value_string>(val ? "True" : "False");
         }},
+        {"tojson", tojson},
     };
     return builtins;
 }

@@ -775,19 +776,30 @@ const func_builtins & value_array_t::get_builtins() const {
             if (!is_val<value_array>(args.get_pos(0))) {
                 throw raised_exception("join() first argument must be an array");
             }
-            value val_delim = args.get_kwarg_or_pos("d", 1);
-            value val_attribute = args.get_kwarg_or_pos("attribute", 2);
-            if (!val_attribute->is_undefined()) {
-                throw not_implemented_exception("array attribute join not implemented");
-            }
+            value val_delim = args.get_kwarg_or_pos("d", 1);
+            value attribute = args.get_kwarg_or_pos("attribute", 2);
             const auto & arr = args.get_pos(0)->as_array();
-            std::string delim = is_val<value_string>(val_delim) ? val_delim->as_string().str() : "";
+            const bool attr_is_int = is_val<value_int>(attribute);
+            if (!attribute->is_undefined() && !is_val<value_string>(attribute) && !attr_is_int) {
+                throw raised_exception("join() attribute must be string or integer");
+            }
+            const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
+            const std::string delim = val_delim->is_undefined() ? "" : val_delim->as_string().str();
+            const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
             std::string result;
             for (size_t i = 0; i < arr.size(); ++i) {
-                if (!is_val<value_string>(arr[i]) && !is_val<value_int>(arr[i]) && !is_val<value_float>(arr[i])) {
+                value val_arr = arr[i];
+                if (!attribute->is_undefined()) {
+                    if (attr_is_int && is_val<value_array>(val_arr)) {
+                        val_arr = val_arr->at(attr_int);
+                    } else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(val_arr)) {
+                        val_arr = val_arr->at(attr_name);
+                    }
+                }
+                if (!is_val<value_string>(val_arr) && !is_val<value_int>(val_arr) && !is_val<value_float>(val_arr)) {
                     throw raised_exception("join() can only join arrays of strings or numerics");
                 }
-                result += arr[i]->as_string().str();
+                result += val_arr->as_string().str();
                 if (i < arr.size() - 1) {
                     result += delim;
                 }

@@ -802,26 +814,30 @@ const func_builtins & value_array_t::get_builtins() const {
         }},
         {"tojson", tojson},
         {"map", [](const func_args & args) -> value {
-            args.ensure_count(2, 3);
+            args.ensure_count(2);
             if (!is_val<value_array>(args.get_pos(0))) {
                 throw raised_exception("map: first argument must be an array");
             }
-            value attribute = args.get_kwarg_or_pos("attribute", 1);
-            if (is_val<value_int>(attribute)) {
-                throw not_implemented_exception("map: integer attribute not implemented");
+            if (!is_val<value_kwarg>(args.get_args().at(1))) {
+                throw not_implemented_exception("map: filter-mapping not implemented");
             }
-            if (!is_val<value_string>(attribute)) {
+            value attribute = args.get_kwarg_or_pos("attribute", 1);
+            const bool attr_is_int = is_val<value_int>(attribute);
+            if (!is_val<value_string>(attribute) && !attr_is_int) {
                 throw raised_exception("map: attribute must be string or integer");
             }
-            std::string attr_name = attribute->as_string().str();
+            const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
+            const std::string attr_name = attribute->as_string().str();
             value default_val = args.get_kwarg("default", mk_val<value_undefined>());
             auto out = mk_val<value_array>();
             auto arr = args.get_pos(0)->as_array();
             for (const auto & item : arr) {
-                if (!is_val<value_object>(item)) {
-                    throw raised_exception("map: item is not an object");
+                value attr_val;
+                if (attr_is_int) {
+                    attr_val = is_val<value_array>(item) ? item->at(attr_int, default_val) : default_val;
+                } else {
+                    attr_val = is_val<value_object>(item) ? item->at(attr_name, default_val) : default_val;
                 }
-                value attr_val = item->at(attr_name, default_val);
                 out->push_back(attr_val);
             }
             return out;

@@ -847,29 +863,35 @@ const func_builtins & value_array_t::get_builtins() const {
             return arr_editable->pop_at(index);
         }},
         {"sort", [](const func_args & args) -> value {
-            args.ensure_count(1, 3);
+            args.ensure_count(1, 4);
             if (!is_val<value_array>(args.get_pos(0))) {
                 throw raised_exception("sort: first argument must be an array");
             }
-            bool reverse = args.get_kwarg("reverse", mk_val<value_undefined>())->as_bool();
-            value attribute = args.get_kwarg("attribute", mk_val<value_undefined>());
-            std::string attr = attribute->is_undefined() ? "" : attribute->as_string().str();
+            value val_reverse = args.get_kwarg_or_pos("reverse", 1);
+            value val_case = args.get_kwarg_or_pos("case_sensitive", 2);
+            value attribute = args.get_kwarg_or_pos("attribute", 3);
+            // FIXME: sorting is currently always case sensitive
+            //const bool case_sensitive = val_case->as_bool(); // undefined == false
+            const bool reverse = val_reverse->as_bool(); // undefined == false
+            const bool attr_is_int = is_val<value_int>(attribute);
+            const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
+            const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
             std::vector<value> arr = cast_val<value_array>(args.get_pos(0))->as_array(); // copy
             std::sort(arr.begin(), arr.end(), [&](const value & a, const value & b) {
                 value val_a = a;
                 value val_b = b;
                 if (!attribute->is_undefined()) {
-                    if (!is_val<value_object>(a) || !is_val<value_object>(b)) {
-                        throw raised_exception("sort: items are not objects");
+                    if (attr_is_int && is_val<value_array>(a) && is_val<value_array>(b)) {
+                        val_a = a->at(attr_int);
+                        val_b = b->at(attr_int);
+                    } else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(a) && is_val<value_object>(b)) {
+                        val_a = a->at(attr_name);
+                        val_b = b->at(attr_name);
+                    } else {
+                        throw raised_exception("sort: unsupported object attribute comparison");
                     }
-                    val_a = attr.empty() ? a : a->at(attr);
-                    val_b = attr.empty() ? b : b->at(attr);
                 }
-                if (reverse) {
-                    return value_compare(val_a, val_b, value_compare_op::gt);
-                } else {
-                    return !value_compare(val_a, val_b, value_compare_op::gt);
-                }
+                return value_compare(val_a, val_b, reverse ? value_compare_op::gt : value_compare_op::lt);
             });
             return mk_val<value_array>(arr);
         }},
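Aside from the new attribute handling, the comparator rewrite also fixes a latent contract issue (my observation, not stated in the patch): the old reverse path returned !(a > b), i.e. a <= b, which is not a strict weak ordering, and passing such a comparator to std::sort is undefined behavior. The new form always returns a strict comparison:

    // OK: strict comparison, comp(x, x) == false in both directions
    return value_compare(val_a, val_b, reverse ? value_compare_op::gt : value_compare_op::lt);
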
@@ -888,6 +910,11 @@ const func_builtins & value_array_t::get_builtins() const {
 
 
 const func_builtins & value_object_t::get_builtins() const {
+    if (!has_builtins) {
+        static const func_builtins no_builtins = {};
+        return no_builtins;
+    }
+
     static const func_builtins builtins = {
         // {"default", default_value}, // cause issue with gpt-oss
         {"get", [](const func_args & args) -> value {

@@ -902,18 +929,13 @@ const func_builtins & value_object_t::get_builtins() const {
             if (args.count() == 3) {
                 default_val = args.get_pos(2);
             }
-            const auto & obj = args.get_pos(0)->as_object();
+            const value obj = args.get_pos(0);
             std::string key = args.get_pos(1)->as_string().str();
-            auto it = obj.find(key);
-            if (it != obj.end()) {
-                return it->second;
-            } else {
-                return default_val;
-            }
+            return obj->at(key, default_val);
         }},
         {"keys", [](const func_args & args) -> value {
             args.ensure_vals<value_object>();
-            const auto & obj = args.get_pos(0)->as_object();
+            const auto & obj = args.get_pos(0)->as_ordered_object();
             auto result = mk_val<value_array>();
             for (const auto & pair : obj) {
                 result->push_back(mk_val<value_string>(pair.first));

@@ -922,7 +944,7 @@ const func_builtins & value_object_t::get_builtins() const {
         }},
         {"values", [](const func_args & args) -> value {
             args.ensure_vals<value_object>();
-            const auto & obj = args.get_pos(0)->as_object();
+            const auto & obj = args.get_pos(0)->as_ordered_object();
             auto result = mk_val<value_array>();
             for (const auto & pair : obj) {
                 result->push_back(pair.second);

@@ -931,7 +953,7 @@ const func_builtins & value_object_t::get_builtins() const {
         }},
         {"items", [](const func_args & args) -> value {
             args.ensure_vals<value_object>();
-            const auto & obj = args.get_pos(0)->as_object();
+            const auto & obj = args.get_pos(0)->as_ordered_object();
             auto result = mk_val<value_array>();
             for (const auto & pair : obj) {
                 auto item = mk_val<value_array>();

@@ -945,7 +967,7 @@ const func_builtins & value_object_t::get_builtins() const {
         {"string", tojson},
         {"length", [](const func_args & args) -> value {
             args.ensure_vals<value_object>();
-            const auto & obj = args.get_pos(0)->as_object();
+            const auto & obj = args.get_pos(0)->as_ordered_object();
             return mk_val<value_int>(static_cast<int64_t>(obj.size()));
         }},
         {"tojson", [](const func_args & args) -> value {

@@ -958,21 +980,18 @@ const func_builtins & value_object_t::get_builtins() const {
             value val_case = args.get_kwarg_or_pos("case_sensitive", 1);
             value val_by = args.get_kwarg_or_pos("by", 2);
             value val_reverse = args.get_kwarg_or_pos("reverse", 3);
-            // FIXME: sorting is case sensitive
+            // FIXME: sorting is currently always case sensitive
             //const bool case_sensitive = val_case->as_bool(); // undefined == false
             const bool reverse = val_reverse->as_bool(); // undefined == false
-            if (!val_by->is_undefined()) {
-                throw not_implemented_exception("dictsort by key not implemented");
-            }
-            if (reverse) {
-                throw not_implemented_exception("dictsort reverse not implemented");
-            }
-            value_t::map obj = val_input->val_obj; // copy
-            std::sort(obj.ordered.begin(), obj.ordered.end(), [&](const auto & a, const auto & b) {
-                return a.first < b.first;
+            const bool by_value = is_val<value_string>(val_by) && val_by->as_string().str() == "value" ? true : false;
+            auto result = mk_val<value_object>(val_input); // copy
+            std::sort(result->val_obj.ordered.begin(), result->val_obj.ordered.end(), [&](const auto & a, const auto & b) {
+                if (by_value) {
+                    return value_compare(a.second, b.second, reverse ? value_compare_op::gt : value_compare_op::lt);
+                } else {
+                    return reverse ? a.first > b.first : a.first < b.first;
+                }
             });
-            auto result = mk_val<value_object>();
-            result->val_obj = std::move(obj);
             return result;
         }},
         {"join", [](const func_args &) -> value {
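With by and reverse now implemented, dictsort behaves closer to Jinja2's filter. A sketch of the by-value path (my example; the dictsort(case_sensitive, by, reverse) positions are read from the kwarg handling above):

    test_template(t, "dictsort by value",
        "{% for k, v in d|dictsort(false, 'value') %}{{ k }}{% endfor %}",
        {{"d", {{"a", 3}, {"b", 1}, {"c", 2}}}},
        "bca"
    );
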
@@ -1169,7 +1188,7 @@ static void value_to_json_internal(std::ostringstream & oss, const value & val,
         }
         oss << "]";
     } else if (is_val<value_object>(val)) {
-        const auto & obj = val->val_obj.ordered; // IMPORTANT: need to keep exact order
+        const auto & obj = val->as_ordered_object(); // IMPORTANT: need to keep exact order
         oss << "{";
         if (!obj.empty()) {
             oss << newline();
@@ -146,7 +146,7 @@ struct value_t {
     virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); }
     virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); }
     virtual const std::vector<value> & as_array() const { throw std::runtime_error(type() + " is not an array value"); }
-    virtual const std::map<std::string, value> & as_object() const { throw std::runtime_error(type() + " is not an object value"); }
+    virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); }
     virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); }
     virtual bool is_none() const { return false; }
     virtual bool is_undefined() const { return false; }

@@ -154,6 +154,9 @@ struct value_t {
         throw std::runtime_error("No builtins available for type " + type());
     }
 
+    virtual bool has_key(const std::string & key) {
+        return val_obj.unordered.find(key) != val_obj.unordered.end();
+    }
     virtual value & at(const std::string & key, value & default_val) {
         auto it = val_obj.unordered.find(key);
         if (it == val_obj.unordered.end()) {

@@ -168,8 +171,20 @@ struct value_t {
         }
         return val_obj.unordered.at(key);
     }
-    virtual value & at(size_t index) {
-        if (index >= val_arr.size()) {
+    virtual value & at(int64_t index, value & default_val) {
+        if (index < 0) {
+            index += val_arr.size();
+        }
+        if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
+            return default_val;
+        }
+        return val_arr[index];
+    }
+    virtual value & at(int64_t index) {
+        if (index < 0) {
+            index += val_arr.size();
+        }
+        if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
             throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
         }
         return val_arr[index];

@@ -286,6 +301,7 @@ using value_array = std::shared_ptr<value_array_t>;
 
 
 struct value_object_t : public value_t {
+    bool has_builtins = true; // context and loop objects do not have builtins
     value_object_t() = default;
     value_object_t(value & v) {
         val_obj = v->val_obj;

@@ -295,11 +311,16 @@ struct value_object_t : public value_t {
             val_obj.insert(pair.first, pair.second);
         }
     }
+    value_object_t(const std::vector<std::pair<std::string, value>> & obj) {
+        for (const auto & pair : obj) {
+            val_obj.insert(pair.first, pair.second);
+        }
+    }
     void insert(const std::string & key, const value & val) {
         val_obj.insert(key, val);
     }
     virtual std::string type() const override { return "Object"; }
-    virtual const std::map<std::string, value> & as_object() const override { return val_obj.unordered; }
+    virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const override { return val_obj.ordered; }
     virtual bool as_bool() const override {
         return !val_obj.unordered.empty();
     }
@@ -7,10 +7,9 @@
 #include "unary-ops.h"
 #include "vec.h"
 
-#include <cfloat>
+#include <algorithm>
+#include <cfloat>
 #include <cmath>
 #include <functional>
 
 // ggml_compute_forward_dup

@@ -7110,12 +7109,13 @@ void ggml_compute_forward_conv_2d_dw(
         }
     }
 
-// ggml_compute_forward_pool_1d_sk_p0
-
-static void ggml_compute_forward_pool_1d_sk_p0(
+// ggml_compute_forward_pool_1d_ksp
+static void ggml_compute_forward_pool_1d_ksp(
         const ggml_compute_params * params,
         const ggml_op_pool op,
         const int k,
+        const int s,
+        const int p,
         ggml_tensor * dst) {
 
     const ggml_tensor * src = dst->src[0];

@@ -7126,39 +7126,56 @@ static void ggml_compute_forward_pool_1d_sk_p0(
         return;
     }
 
-    const char * cdata = (const char *)src->data;
-    const char * const data_end = cdata + ggml_nbytes(src);
-    float * drow = (float *)dst->data;
+    const int64_t IW = src->ne[0];
+    const int64_t OW = dst->ne[0];
 
-    const int64_t rs = dst->ne[0];
+    const int64_t nr = ggml_nrows(src);
 
-    while (cdata < data_end) {
-        const void * srow = (const void *)cdata;
-        int j = 0;
-        for (int64_t i = 0; i < rs; ++i) {
+    for (int64_t ir = 0; ir < nr; ++ir) {
+        const char * srow_bytes = (const char *) src->data + ir * src->nb[1];
+        float * drow = (float *) ((char *) dst->data + ir * dst->nb[1]);
+
+        for (int64_t ow = 0; ow < OW; ++ow) {
+            float res = 0;
             switch (op) {
-                case GGML_OP_POOL_AVG: drow[i] = 0;        break;
-                case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
+                case GGML_OP_POOL_AVG: res = 0.0f;     break;
+                case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
                 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
             }
 
+            int count = 0;
+            const int base = (int) ow * s - p;
+
             for (int ki = 0; ki < k; ++ki) {
-                const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
-                switch (op) {
-                    case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
-                    case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
-                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
+                const int j = base + ki;
+                if (j < 0 || j >= (int) IW) {
+                    continue;
                 }
-                ++j;
+
+                float v;
+                if (src->type == GGML_TYPE_F32) {
+                    v = ((const float *) srow_bytes)[j];
+                } else {
+                    v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);
+                }
+
+                switch (op) {
+                    case GGML_OP_POOL_AVG: res += v;               break;
+                    case GGML_OP_POOL_MAX: res = std::max(v, res); break;
+                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
+                }
+
+                ++count;
             }
 
             switch (op) {
-                case GGML_OP_POOL_AVG: drow[i] /= k; break;
-                case GGML_OP_POOL_MAX: break;
+                case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;
+                case GGML_OP_POOL_MAX: break;
                 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
             }
-        }
 
-        cdata += src->nb[1];
-        drow += rs;
+            drow[ow] = res;
+        }
     }
 }

@@ -7173,10 +7190,8 @@ void ggml_compute_forward_pool_1d(
     const int k0 = opts[1];
     const int s0 = opts[2];
     const int p0 = opts[3];
-    GGML_ASSERT(p0 == 0); // padding not supported
-    GGML_ASSERT(k0 == s0); // only s = k supported
 
-    ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
+    ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);
 }
 
 // ggml_compute_forward_pool_2d
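For reference, the output width the generalized kernel produces follows the usual pooling arithmetic (my summary, consistent with the GGML_ASSERT(ne[0] > 0) guard added to ggml_pool_1d further down): OW = (IW + 2*p - k)/s + 1, so the removed k == s, p == 0 restriction was the special case OW = IW/k.

    // sketch of the output-size arithmetic assumed above
    static inline int64_t pool_1d_out_width(int64_t IW, int k, int s, int p) {
        return (IW + 2*p - k) / s + 1;
    }
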
@@ -7194,6 +7209,7 @@ void ggml_compute_forward_pool_2d(
     }
 
     const int32_t * opts = (const int32_t *)dst->op_params;
 
+    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
     const int k0 = opts[1];
     const int k1 = opts[2];

@@ -7217,11 +7233,13 @@ void ggml_compute_forward_pool_2d(
     while (cdata < data_end) {
         for (int oy = 0; oy < py; ++oy) {
             float * const drow = dplane + oy * px;
+            float * const out = drow;
 
             for (int ox = 0; ox < px; ++ox) {
-                float * const out = drow + ox;
+                float res = 0;
                 switch (op) {
-                    case GGML_OP_POOL_AVG: *out = 0;        break;
-                    case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
+                    case GGML_OP_POOL_AVG: res = 0;        break;
+                    case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
                 }
 

@@ -7229,24 +7247,32 @@ void ggml_compute_forward_pool_2d(
                 const int iy = offset1 + oy * s1;
 
                 for (int ky = 0; ky < k1; ++ky) {
-                    if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
+                    if (iy + ky < 0 || iy + ky >= src->ne[1]) {
+                        continue;
+                    }
+
                     const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
                     for (int kx = 0; kx < k0; ++kx) {
                         int j = ix + kx;
-                        if (j < 0 || j >= src->ne[0]) continue;
+                        if (j < 0 || j >= src->ne[0]) {
+                            continue;
+                        }
+
                         const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
                         switch (op) {
-                            case GGML_OP_POOL_AVG: *out += srow_j; break;
-                            case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
+                            case GGML_OP_POOL_AVG: res += srow_j;               break;
+                            case GGML_OP_POOL_MAX: res = std::max(srow_j, res); break;
                             case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
                         }
                     }
                 }
                 switch (op) {
-                    case GGML_OP_POOL_AVG: *out /= ka; break;
-                    case GGML_OP_POOL_MAX: break;
+                    case GGML_OP_POOL_AVG: res /= ka; break;
+                    case GGML_OP_POOL_MAX: break;
                     case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
                 }
 
+                out[ox] = res;
             }
         }
@@ -94,6 +94,31 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_l
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
+    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
+
+    const char * pool_str = "undefined";
+    switch (op_pool) {
+        case GGML_OP_POOL_AVG: pool_str = "avg"; break;
+        case GGML_OP_POOL_MAX: pool_str = "max"; break;
+        default: GGML_ASSERT(false && "not implemented");
+    };
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, sizeof(base), "kernel_pool_1d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
+    snprintf(name, sizeof(name), "%s", base);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
     GGML_ASSERT(ggml_is_contiguous(op->src[0]));
     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);

@@ -104,6 +104,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_
 
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base     (ggml_metal_library_t lib, enum ggml_op op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy      (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d  (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d  (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);

@@ -1050,10 +1050,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 op->src[1]->type == GGML_TYPE_F32 &&
                 op->type == GGML_TYPE_F32 &&
                 (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
-        case GGML_OP_POOL_1D:
-            return false;
         case GGML_OP_UPSCALE:
             return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
+        case GGML_OP_POOL_1D:
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_POOL_2D:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_PAD:

@@ -928,6 +928,15 @@ typedef struct {
     int64_t np;
 } ggml_metal_kargs_pool_2d;
 
+typedef struct {
+    int32_t k0;
+    int32_t s0;
+    int32_t p0;
+    int64_t IW;
+    int64_t OW;
+    int64_t np;
+} ggml_metal_kargs_pool_1d;
+
 typedef struct {
     int64_t  ne00;
     uint64_t nb01;

@@ -432,6 +432,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_cpy(ctx, idx);
             } break;
+        case GGML_OP_POOL_1D:
+            {
+                n_fuse = ggml_metal_op_pool_1d(ctx, idx);
+            } break;
         case GGML_OP_POOL_2D:
             {
                 n_fuse = ggml_metal_op_pool_2d(ctx, idx);

@@ -1622,6 +1626,54 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    const int32_t * opts = op->op_params;
+    ggml_op_pool op_pool = (ggml_op_pool) opts[0];
+
+    const int32_t k0 = opts[1];
+    const int32_t s0 = opts[2];
+    const int32_t p0 = opts[3];
+
+    const int64_t IW = op->src[0]->ne[0];
+    const int64_t OW = op->ne[0];
+
+    const int64_t np = ggml_nelements(op);
+
+    ggml_metal_kargs_pool_1d args_pool_1d = {
+        /* .k0 = */ k0,
+        /* .s0 = */ s0,
+        /* .p0 = */ p0,
+        /* .IW = */ IW,
+        /* .OW = */ OW,
+        /* .np = */ np
+    };
+
+    auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);
+
+    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
+    const int ntg = (np + nth - 1) / nth;
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args_pool_1d, sizeof(args_pool_1d), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
+
+    return 1;
+}
+
+
 int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
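The dispatch above launches one thread per destination element (np = ggml_nelements(op)), split into ntg threadgroups of nth threads each. A kernel thread then recovers its row and output column from the flat grid index, which is exactly the mapping the new shaders below use:

    // gid -> (row, ow) decomposition used by the pool_1d kernels
    const int ow  = (int) gid % args.OW;
    const int row = (int) gid / args.OW;
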
@@ -61,6 +61,7 @@ int ggml_metal_op_ssm_conv   (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_scan   (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rwkv       (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_cpy        (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_pool_1d    (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pool_2d    (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_mul_mat    (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx);

@@ -9869,6 +9869,74 @@ kernel void kernel_pool_2d_avg_f32(
     o_ptr[cur_oh * args.OW + cur_ow] = res;
 }
 
+kernel void kernel_pool_1d_max_f32(
+    constant ggml_metal_kargs_pool_1d & args,
+    device const float * src,
+    device float * dst,
+    uint gid [[thread_position_in_grid]]
+) {
+
+    if (gid >= args.np) {
+        return;
+    }
+
+    const int ow  = (int)gid % args.OW;
+    const int row = (int)gid / args.OW;
+
+    const int base = ow * args.s0 - args.p0;
+
+    float acc = -INFINITY;
+
+    const int src_off = row * args.IW;
+    const int dst_off = row * args.OW;
+
+    for (int ki = 0; ki < args.k0; ++ki) {
+        int j = base + ki;
+        if (j < 0 || j >= args.IW) {
+            continue;
+        }
+        float v = src[src_off + j];
+        acc = max(acc, v);
+    }
+
+    dst[dst_off + ow] = acc;
+}
+
+kernel void kernel_pool_1d_avg_f32(
+    constant ggml_metal_kargs_pool_1d & args,
+    device const float * src,
+    device float * dst,
+    uint gid [[thread_position_in_grid]]
+) {
+
+    if (gid >= args.np) {
+        return;
+    }
+
+    const int ow  = (int)gid % args.OW;
+    const int row = (int)gid / args.OW;
+
+    const int base = ow * args.s0 - args.p0;
+
+    float acc = 0.0f;
+    int   cnt = 0;
+
+    const int src_off = row * args.IW;
+    const int dst_off = row * args.OW;
+
+    for (int ki = 0; ki < args.k0; ++ki) {
+        const int j = base + ki;
+        if (j < 0 || j >= args.IW) {
+            continue;
+        }
+        acc += src[src_off + j];
+        cnt += 1;
+    }
+
+    dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f;
+}
+
 kernel void kernel_opt_step_adamw_f32(
     constant ggml_metal_kargs_opt_step_adamw & args,
     device float * x,
@@ -4854,6 +4854,8 @@ struct ggml_tensor * ggml_pool_1d(
         a->ne[2],
         a->ne[3],
     };
+    GGML_ASSERT(ne[0] > 0);
 
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     int32_t params[] = { op, k0, s0, p0 };

@@ -4884,6 +4886,9 @@ struct ggml_tensor * ggml_pool_2d(
         a->ne[2],
         a->ne[3],
     };
+    GGML_ASSERT(ne[0] > 0);
+    GGML_ASSERT(ne[1] > 0);
 
     result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
@@ -67,7 +67,7 @@ dry_seq_break_max = 128
 extra_images_max = 4 # for kontext/qwen img
 
 # global vars
-KcppVersion = "1.106.2"
+KcppVersion = "1.107"
 showdebug = True
 kcpp_instance = None #global running instance
 global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False, "restart_override_config_target":""}
@@ -200,42 +200,6 @@ uint32_t llama_hparams::n_layer_kv() const {
     return res;
 }
 
-bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
-    assert(p0 >= 0 && p1 >= 0);
-
-    switch (swa_type) {
-        case LLAMA_SWA_TYPE_NONE:
-            {
-            } break;
-        case LLAMA_SWA_TYPE_STANDARD:
-            {
-                if (p1 - p0 >= (int32_t) n_swa) {
-                    return true;
-                }
-            } break;
-        case LLAMA_SWA_TYPE_CHUNKED:
-            {
-                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
-
-                if (p0 < pos_chunk_start) {
-                    return true;
-                }
-            } break;
-        case LLAMA_SWA_TYPE_SYMMETRIC:
-            {
-                const int32_t half_n_swa = (int32_t) n_swa / 2;
-                const int32_t pos_diff = p1 - p0;
-
-                // Mask if outside the symmetric window
-                if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
-                    return true;
-                }
-            } break;
-    }
-
-    return false;
-}
-
 bool llama_hparams::use_mrope() const {
     return rope_sections[0] > 0 && rope_sections[1] > 0;
 }
@@ -3,6 +3,7 @@
 #include "llama.h"
 
 #include <array>
+#include <cassert>
 
 // bump if necessary
 #define LLAMA_MAX_LAYERS 512

@@ -274,9 +275,45 @@ struct llama_hparams {
     uint32_t n_layer_kv() const;
 
     // note that this function uses different SWA parameters from those in the hparams
+    // note: inlined on purpose for performance reasons
     // TODO: think of a better place for this function
     // TODO: pack the SWA params in a struct?
-    static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
+    static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
+        assert(p0 >= 0 && p1 >= 0);
+
+        switch (swa_type) {
+            case LLAMA_SWA_TYPE_NONE:
+                {
+                } break;
+            case LLAMA_SWA_TYPE_STANDARD:
+                {
+                    if (p1 - p0 >= (int32_t) n_swa) {
+                        return true;
+                    }
+                } break;
+            case LLAMA_SWA_TYPE_CHUNKED:
+                {
+                    const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
+
+                    if (p0 < pos_chunk_start) {
+                        return true;
+                    }
+                } break;
+            case LLAMA_SWA_TYPE_SYMMETRIC:
+                {
+                    const int32_t half_n_swa = (int32_t) n_swa / 2;
+                    const int32_t pos_diff = p1 - p0;
+
+                    // Mask if outside the symmetric window
+                    if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
+                        return true;
+                    }
+                } break;
+        }
+
+        return false;
+    }
 
     bool use_mrope() const;
 };
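A quick worked check of the STANDARD case (my arithmetic, using the inlined helper above): with n_swa = 4, a key at p0 = 10 is still visible to a query at p1 = 13 (13 - 10 = 3 < 4) but masked once p1 reaches 14:

    assert(!llama_hparams::is_masked_swa(4, LLAMA_SWA_TYPE_STANDARD, 10, 13));
    assert( llama_hparams::is_masked_swa(4, LLAMA_SWA_TYPE_STANDARD, 10, 14));
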
@@ -852,7 +852,7 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
         const llama_seq_id seq_id_cell = cells.seq_get(idx);
 
         // SWA mask
-        if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
+        if (llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
             can_use = true;
         }
     }

@@ -1237,6 +1237,197 @@ void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
     }
 }
 
+struct args_set_input_kq_mask {
+    const llama_hparams & hparams;
+    const llama_ubatch  * ubatch;
+
+    const std::vector<llama_kv_cells> & v_cells;
+    const std::vector<uint32_t>       & seq_to_stream;
+
+    uint32_t       n_swa;
+    llama_swa_type swa_type;
+
+    int64_t n_kv;
+    int64_t n_stream;
+    int64_t n_tps;
+};
+
+template<bool causal, bool swa, bool is_2d, bool alibi>
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
+    //const auto & hparams = args.hparams;
+    const auto & ubatch = args.ubatch;
+
+    const auto & v_cells       = args.v_cells;
+    const auto & seq_to_stream = args.seq_to_stream;
+
+    const uint32_t       n_swa    = args.n_swa;
+    const llama_swa_type swa_type = args.swa_type;
+
+    const int64_t n_kv     = args.n_kv;
+    const int64_t n_stream = args.n_stream;
+    const int64_t n_tps    = args.n_tps;
+
+    // the min position in the batch for each sequence
+    llama_pos seq_pos_min[LLAMA_MAX_SEQ];
+    std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX);
+
+    for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
+        const llama_seq_id seq_id = ubatch->seq_id[i][0];
+
+        seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]);
+    }
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        // bookkeeping of the KQ mask cells that could change for other tokens of the same sequence
+        std::unordered_map<llama_seq_id, uint32_t> seq_srct;
+        std::unordered_map<llama_seq_id, std::vector<uint32_t>> seq_idxs;
+
+        for (uint32_t ii = 0; ii < n_tps; ++ii) {
+            const uint32_t i = s*n_tps + ii;
+
+            const llama_seq_id seq_id = ubatch->seq_id[i][0];
+
+            const auto & cells = v_cells.at(seq_to_stream[seq_id]);
+
+            llama_pos p0 = -1;
+            const llama_pos p1 = ubatch->pos[i];
+
+            // for M-RoPE
+            const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
+            const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens]   : 0;
+
+            const uint64_t idst = n_kv*i;
+
+            // for tokens of the same sequence, the mask is mostly the same, so we can reuse it
+            // the only cells that could change are the ones that are with similar positions as the
+            //   ones in the batch (i.e. due to causal masking, SWA, etc.)
+            // keep track of those cells and shortcut the loop to save time
+            // note: this optimization is not compatible with Alibi position encoding
+            // ref: https://github.com/ggml-org/llama.cpp/pull/18842
+            bool prev = false;
+
+            auto & idxs = seq_idxs[seq_id];
+
+            if (!alibi) {
+                if (seq_srct.find(seq_id) != seq_srct.end()) {
+                    const uint32_t srct = seq_srct[seq_id];
+
+                    const uint64_t idst_prev = n_kv*srct;
+
+                    std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst);
+
+                    prev = true;
+                } else {
+                    idxs.clear();
+                    idxs.reserve(ubatch->n_tokens + n_swa + 32);
+
+                    seq_srct[seq_id] = i;
+                }
+            }
+
+            for (uint32_t jj = 0; jj < n_kv; ++jj) {
+                uint32_t j = jj;
+
+                // we have an existing mask for this sequence -> update just seq_idxs
+                if (!alibi) {
+                    if (prev) {
+                        if (jj >= idxs.size()) {
+                            break;
+                        }
+
+                        j = idxs[jj];
+                    }
+                }
+
+                if (cells.is_empty(j)) {
+                    goto skip;
+                }
+
+                // mask the token if not the same sequence
+                if (!cells.seq_has(j, seq_id)) {
+                    goto skip;
+                }
+
+                p0 = cells.pos_get(j);
+
+                if (!alibi) {
+                    if (!prev) {
+                        // record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32
+                        if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) {
+                            idxs.push_back(j);
+                        }
+                    }
+                }
+
+                if (causal) {
+                    // mask future tokens
+                    if (p0 > p1) {
+                        goto skip;
+                    }
+
+                    // M-RoPE causal mask
+                    if (is_2d) {
+                        if (p0 == p1) {
+                            const auto & p0_ext = cells.ext_get(j);
+
+                            if (p0_ext.is_2d_gt(p1_x, p1_y)) {
+                                goto skip;
+                            }
+                        }
+                    }
+                }
+
+                // apply SWA if any
+                if (swa) {
+                    if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
+                        goto skip;
+                    }
+                }
+
+                if (alibi) {
+                    data[idst + j] = -std::abs(p0 - p1);
+                } else {
+                    data[idst + j] = 0.0f;
+                }
+
+                continue;
+skip:
+                data[idst + j] = -INFINITY;
+            }
+        }
+    }
+}
+
+template<bool causal, bool swa, bool is_2d>
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
+    const bool alibi = args.hparams.use_alibi;
+    if (alibi) {
+        set_input_kq_mask_impl<causal, swa, is_2d, true> (args, data);
+    } else {
+        set_input_kq_mask_impl<causal, swa, is_2d, false>(args, data);
+    }
+}
+
+template<bool causal, bool swa>
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
+    const bool is_2d = args.ubatch->is_pos_2d();
+    if (is_2d) {
+        set_input_kq_mask_impl<causal, swa, true> (args, data);
+    } else {
+        set_input_kq_mask_impl<causal, swa, false>(args, data);
+    }
+}
+
+template<bool causal>
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
+    const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE;
+    if (swa) {
+        set_input_kq_mask_impl<causal, true> (args, data);
+    } else {
+        set_input_kq_mask_impl<causal, false>(args, data);
+    }
+}
+
 void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const uint32_t n_tokens = ubatch->n_tokens;
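The three thin overloads at the end form a dispatch ladder: each layer converts one runtime flag (alibi, is_2d, swa, causal) into a template parameter, so the hot per-cell loop is compiled into 16 specializations with all flag checks folded away. The same pattern in miniature (a sketch, not part of the patch):

    // peel one runtime bool per layer into a compile-time bool
    template <bool causal> static void impl(float * data);
    static void dispatch(bool causal, float * data) {
        causal ? impl<true>(data) : impl<false>(data);
    }
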
@@ -1251,74 +1442,29 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
     // n_tps == n_tokens_per_stream
     const int64_t n_tps = n_tokens/n_stream;
 
-    std::fill(data, data + ggml_nelements(dst), -INFINITY);
+    //const int64_t t_start = ggml_time_us();
 
-    // Use only the previous KV cells of the correct sequence for each token of the ubatch.
-    // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-    // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
-    //   Causal mask:
-    //      xxx-------
-    //      xxxx------
-    //      xxxxx-----
-    //   Non-causal mask:
-    //      xxxxx-----
-    //      xxxxx-----
-    //      xxxxx-----
-    // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-    // TODO: optimize this section
-    for (uint32_t h = 0; h < 1; ++h) {
-        for (uint32_t s = 0; s < n_stream; ++s) {
-            for (uint32_t ii = 0; ii < n_tps; ++ii) {
-                const uint32_t i = s*n_tps + ii;
+    const args_set_input_kq_mask args = {
+        /*.hparams       =*/ hparams,
+        /*.ubatch        =*/ ubatch,
+        /*.v_cells       =*/ v_cells,
+        /*.seq_to_stream =*/ seq_to_stream,
+        /*.n_swa         =*/ n_swa,
+        /*.swa_type      =*/ swa_type,
+        /*.n_kv          =*/ n_kv,
+        /*.n_stream      =*/ n_stream,
+        /*.n_tps         =*/ n_tps,
+    };
 
-                const llama_seq_id seq_id = ubatch->seq_id[i][0];
-
-                const auto & cells = v_cells[seq_to_stream[seq_id]];
-
-                const llama_pos p1 = ubatch->pos[i];
-
-                // for M-RoPE
-                const bool is_2d = ubatch->is_pos_2d();
-                const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
-                const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
-
-                const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii);
-
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    if (cells.is_empty(j)) {
-                        continue;
-                    }
-
-                    // mask the token if not the same sequence
-                    if (!cells.seq_has(j, seq_id)) {
-                        continue;
-                    }
-
-                    const llama_pos p0 = cells.pos_get(j);
-
-                    // mask future tokens
-                    if (causal_attn && p0 > p1) {
-                        continue;
-                    }
-
-                    // M-RoPE causal mask
-                    if (causal_attn && is_2d && p0 == p1) {
-                        const auto & p0_ext = cells.ext_get(j);
-                        if (p0_ext.is_2d_gt(p1_x, p1_y)) {
-                            continue;
-                        }
-                    }
-
-                    // apply SWA if any
-                    if (is_masked_swa(p0, p1)) {
-                        continue;
-                    }
-
-                    data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
-                }
-            }
-        }
+    if (causal_attn) {
+        set_input_kq_mask_impl<true> (args, data);
+    } else {
+        set_input_kq_mask_impl<false>(args, data);
     }
+
+    //const int64_t t_end = ggml_time_us();
+
+    //LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0);
 }
 
 void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {

@@ -1483,10 +1629,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
     return gf;
 }
 
-bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
-    return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
-}
-
 void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
     GGML_UNUSED(flags);

@@ -257,8 +257,6 @@ private:
     size_t size_k_bytes() const;
     size_t size_v_bytes() const;
 
-    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
-
     ggml_tensor * build_rope_shift(
         const llama_cparams & cparams,
         ggml_context * ctx,
@@ -4,6 +4,7 @@
 #include <cstdlib>
 
 #include <nlohmann/json.hpp>
+#include <sheredom/subprocess.h>
 
 #include "jinja/runtime.h"
 #include "jinja/parser.h"

@@ -31,12 +32,24 @@ static void test_array_methods(testing & t);
 static void test_object_methods(testing & t);
 static void test_fuzzing(testing & t);
 
+static bool g_python_mode = false;
+
 int main(int argc, char *argv[]) {
     testing t(std::cout);
     t.verbose = true;
 
-    if (argc >= 2) {
-        t.set_filter(argv[1]);
+    // usage: test-jinja [-py] [filter_regex]
+    //   -py : enable python mode (use python jinja2 for rendering expected output)
+    //         only use this for cross-checking, not for correctness
+    //         note: the implementation of this flag is basic, only intended to be used by maintainers
+
+    for (int i = 1; i < argc; i++) {
+        std::string arg = argv[i];
+        if (arg == "-py") {
+            g_python_mode = true;
+        } else {
+            t.set_filter(arg);
+        }
     }
 
     t.test("whitespace control", test_whitespace_control);

@@ -53,7 +66,9 @@ int main(int argc, char *argv[]) {
     t.test("string methods", test_string_methods);
     t.test("array methods", test_array_methods);
     t.test("object methods", test_object_methods);
-    t.test("fuzzing", test_fuzzing);
+    if (!g_python_mode) {
+        t.test("fuzzing", test_fuzzing);
+    }
 
     return t.summary();
 }

@@ -247,6 +262,12 @@ static void test_expressions(testing & t) {
         "Bob"
     );
 
+    test_template(t, "negative float (not dot notation)",
+        "{{ -1.0 }}",
+        json::object(),
+        "-1.0"
+    );
+
     test_template(t, "bracket notation",
         "{{ user['name'] }}",
         {{"user", {{"name", "Bob"}}}},

@@ -383,6 +404,32 @@ static void test_filters(testing & t) {
         "123"
     );
 
+    test_template(t, "sort reverse",
+        "{% for i in items|sort(true) %}{{ i }}{% endfor %}",
+        {{"items", json::array({3, 1, 2})}},
+        "321"
+    );
+
+    test_template(t, "sort with attribute",
+        "{{ items|sort(attribute='name')|join(attribute='age') }}",
+        {{"items", json::array({
+            json({{"name", "c"}, {"age", 3}}),
+            json({{"name", "a"}, {"age", 1}}),
+            json({{"name", "b"}, {"age", 2}}),
+        })}},
+        "123"
+    );
+
+    test_template(t, "sort with numeric attribute",
+        "{{ items|sort(attribute=0)|join(attribute=1) }}",
+        {{"items", json::array({
+            json::array({3, "z"}),
+            json::array({1, "x"}),
+            json::array({2, "y"}),
+        })}},
+        "xyz"
+    );
+
     test_template(t, "join",
         "{{ items|join(', ') }}",
         {{"items", json::array({"a", "b", "c"})}},

@@ -534,6 +581,66 @@ static void test_literals(testing & t) {
         json::object(),
         "1"
     );
+
+    test_template(t, "integer|abs",
+        "{{ -42 | abs }}",
+        json::object(),
+        "42"
+    );
+
+    test_template(t, "integer|float",
+        "{{ 42 | float }}",
+        json::object(),
+        "42.0"
+    );
+
+    test_template(t, "integer|tojson",
+        "{{ 42 | tojson }}",
+        json::object(),
+        "42"
+    );
+
+    test_template(t, "float|abs",
+        "{{ -3.14 | abs }}",
+        json::object(),
+        "3.14"
+    );
+
+    test_template(t, "float|int",
+        "{{ 3.14 | int }}",
+        json::object(),
+        "3"
+    );
+
+    test_template(t, "float|tojson",
+        "{{ 3.14 | tojson }}",
+        json::object(),
+        "3.14"
+    );
+
+    test_template(t, "string|tojson",
+        "{{ 'hello' | tojson }}",
+        json::object(),
+        "\"hello\""
+    );
+
+    test_template(t, "boolean|int",
+        "{{ true | int }}",
+        json::object(),
+        "1"
+    );
+
+    test_template(t, "boolean|float",
+        "{{ true | float }}",
+        json::object(),
+        "1.0"
+    );
+
+    test_template(t, "boolean|tojson",
+        "{{ true | tojson }}",
+        json::object(),
+        "true"
+    );
 }
 
 static void test_comments(testing & t) {

@@ -934,7 +1041,17 @@ static void test_array_methods(testing & t) {
     );
 
     test_template(t, "array|join attribute",
-        "{{ arr|join(attribute=0) }}",
+        "{{ arr|join(attribute='age') }}",
+        {{"arr", json::array({
+            json({{"name", "a"}, {"age", 1}}),
+            json({{"name", "b"}, {"age", 2}}),
+            json({{"name", "c"}, {"age", 3}}),
+        })}},
+        "123"
+    );
+
+    test_template(t, "array|join numeric attribute",
+        "{{ arr|join(attribute=-1) }}",
         {{"arr", json::array({json::array({1}), json::array({2}), json::array({3})})}},
         "123"
     );

@@ -957,8 +1074,8 @@ static void test_array_methods(testing & t) {
         "a,b,c,d"
     );
 
-    test_template(t, "array.map() with attribute",
-        "{% for v in arr.map('age') %}{{ v }} {% endfor %}",
+    test_template(t, "array|map with attribute",
+        "{% for v in arr|map(attribute='age') %}{{ v }} {% endfor %}",
         {{"arr", json::array({
             json({{"name", "a"}, {"age", 1}}),
            json({{"name", "b"}, {"age", 2}}),

@@ -967,8 +1084,28 @@ static void test_array_methods(testing & t) {
         "1 2 3 "
     );
 
-    test_template(t, "array.map() with numeric attribute",
-        "{% for v in arr.map(0) %}{{ v }} {% endfor %}",
+    test_template(t, "array|map with attribute default",
+        "{% for v in arr|map(attribute='age', default=3) %}{{ v }} {% endfor %}",
+        {{"arr", json::array({
+            json({{"name", "a"}, {"age", 1}}),
+            json({{"name", "b"}, {"age", 2}}),
+            json({{"name", "c"}}),
+        })}},
+        "1 2 3 "
+    );
+
+    test_template(t, "array|map without attribute default",
+        "{% for v in arr|map(attribute='age') %}{{ v }} {% endfor %}",
+        {{"arr", json::array({
+            json({{"name", "a"}, {"age", 1}}),
+            json({{"name", "b"}, {"age", 2}}),
+            json({{"name", "c"}}),
+        })}},
+        "1 2 "
+    );
+
+    test_template(t, "array|map with numeric attribute",
+        "{% for v in arr|map(attribute=0) %}{{ v }} {% endfor %}",
         {{"arr", json::array({
             json::array({10, "x"}),
             json::array({20, "y"}),

@@ -977,6 +1114,22 @@ static void test_array_methods(testing & t) {
         "10 20 30 "
     );
 
+    test_template(t, "array|map with negative attribute",
+        "{% for v in arr|map(attribute=-1) %}{{ v }} {% endfor %}",
+        {{"arr", json::array({
+            json::array({10, "x"}),
+            json::array({20, "y"}),
+            json::array({30, "z"}),
+        })}},
+        "x y z "
+    );
+
+    test_template(t, "array|map with filter",
+        "{{ arr|map('int')|sum }}",
+        {{"arr", json::array({"1", "2", "3"})}},
+        "6"
+    );
+
     // not used by any chat templates
     // test_template(t, "array.insert()",
     //     "{% set _ = arr.insert(1, 'x') %}{{ arr|join(',') }}",

@@ -1063,9 +1216,21 @@ static void test_object_methods(testing & t) {
         {{"obj", {{"items", json::array({1, 2, 3})}}}},
         "{\"items\": [1, 2, 3]}"
     );
+
+    test_template(t, "object attribute and key access",
+        "{{ obj.keys()|join(',') }} vs {{ obj['keys'] }} vs {{ obj.test }}",
+        {{"obj", {{"keys", "value"}, {"test", "attr_value"}}}},
+        "keys,test vs value vs attr_value"
+    );
+
+    test_template(t, "env should not have object methods",
+        "{{ keys is undefined }} {{ obj.keys is defined }}",
+        {{"obj", {{"a", "b"}}}},
+        "True True"
+    );
 }
 
-static void test_template(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
+static void test_template_cpp(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
     t.test(name, [&tmpl, &vars, &expect](testing & t) {
         jinja::lexer lexer;
         auto lexer_res = lexer.tokenize(tmpl);

@@ -1098,6 +1263,99 @@ static void test_template(testing & t, const std::string & name, const std::stri
     });
 }
 
+// keep this in-sync with https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py
+// note: we use SandboxedEnvironment instead of ImmutableSandboxedEnvironment to allow usage of in-place array methods like append() and pop()
+static std::string py_script = R"(
+import jinja2
+import jinja2.ext as jinja2_ext
+import json
+import sys
+from datetime import datetime
+from jinja2.sandbox import SandboxedEnvironment
+
+tmpl = json.loads(sys.argv[1])
+vars_json = json.loads(sys.argv[2])
+
+env = SandboxedEnvironment(
+    trim_blocks=True,
+    lstrip_blocks=True,
+    extensions=[jinja2_ext.loopcontrols],
+)
+
+def raise_exception(message):
+    raise jinja2.exceptions.TemplateError(message)
+
+env.filters["tojson"] = lambda x, ensure_ascii=False, indent=None, separators=None, sort_keys=False: json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
+env.globals["strftime_now"] = lambda format: datetime.now().strftime(format)
+env.globals["raise_exception"] = raise_exception
+
+template = env.from_string(tmpl)
+result = template.render(**vars_json)
+print(result, end='')
+)";
+
+static void test_template_py(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
+    t.test(name, [&tmpl, &vars, &expect](testing & t) {
+        // Prepare arguments
+        std::string tmpl_json = json(tmpl).dump();
+        std::string vars_json = vars.dump();
+
+#ifdef _WIN32
+        const char * python_executable = "python.exe";
+#else
+        const char * python_executable = "python3";
+#endif
+
+        const char * command_line[] = {python_executable, "-c", py_script.c_str(), tmpl_json.c_str(), vars_json.c_str(), NULL};
+
+        struct subprocess_s subprocess;
+        int options = subprocess_option_combined_stdout_stderr
+                    | subprocess_option_no_window
+                    | subprocess_option_inherit_environment
+                    | subprocess_option_search_user_path;
+        int result = subprocess_create(command_line, options, &subprocess);
+
+        if (result != 0) {
+            t.log("Failed to create subprocess, error code: " + std::to_string(result));
+            t.assert_true("subprocess creation", false);
+            return;
+        }
+
+        // Read output
+        std::string output;
+        char buffer[1024];
+        FILE * p_stdout = subprocess_stdout(&subprocess);
+        while (fgets(buffer, sizeof(buffer), p_stdout)) {
+            output += buffer;
+        }
+
+        int process_return;
+        subprocess_join(&subprocess, &process_return);
+        subprocess_destroy(&subprocess);
+
+        if (process_return != 0) {
+            t.log("Python script failed with exit code: " + std::to_string(process_return));
+            t.log("Output: " + output);
+            t.assert_true("python execution", false);
+            return;
+        }
+
+        if (!t.assert_true("Template render mismatch", expect == output)) {
+            t.log("Template: " + json(tmpl).dump());
+            t.log("Expected: " + json(expect).dump());
+            t.log("Python  : " + json(output).dump());
+        }
+    });
+}
+
+static void test_template(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
+    if (g_python_mode) {
+        test_template_py(t, name, tmpl, vars, expect);
+    } else {
+        test_template_cpp(t, name, tmpl, vars, expect);
+    }
+}
+
 //
 // fuzz tests to ensure no crashes occur on malformed inputs
 //