mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
Merge commit 'e121edc432
' into concedo_experimental
# Conflicts: # .github/workflows/release.yml # common/CMakeLists.txt # docs/function-calling.md # ggml/src/ggml-sycl/binbcast.cpp # models/templates/README.md # scripts/tool_bench.py # src/llama-kv-cache.cpp # tests/CMakeLists.txt # tests/test-chat.cpp # tools/mtmd/clip.h # tools/rpc/rpc-server.cpp # tools/server/README.md
This commit is contained in:
commit
868cb6aff7
45 changed files with 3521 additions and 1182 deletions
|
@ -2849,15 +2849,24 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-format"}, "FORMAT",
|
||||
"reasoning format (default: deepseek; allowed values: deepseek, none)\n"
|
||||
"controls whether thought tags are extracted from the response, and in which format they're returned. 'none' leaves thoughts unparsed in `message.content`, 'deepseek' puts them in `message.reasoning_content` (for DeepSeek R1 & Command R7B only).\n"
|
||||
"only supported for non-streamed responses",
|
||||
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
|
||||
"- none: leaves thoughts unparsed in `message.content`\n"
|
||||
"- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n"
|
||||
"(default: deepseek)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
/**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
|
||||
else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
|
||||
else { std::invalid_argument("invalid value"); }
|
||||
else { throw std::invalid_argument("invalid value"); }
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-budget"}, "N",
|
||||
"controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)",
|
||||
[](common_params & params, int value) {
|
||||
if (value != 0 && value != -1) { throw std::invalid_argument("invalid value"); }
|
||||
params.reasoning_budget = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK_BUDGET"));
|
||||
add_opt(common_arg(
|
||||
{"--chat-template"}, "JINJA_TEMPLATE",
|
||||
string_format(
|
||||
|
@ -2956,7 +2965,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
[](common_params & params, const std::string & value) {
|
||||
/**/ if (value == "jsonl") { params.batched_bench_output_jsonl = true; }
|
||||
else if (value == "md") { params.batched_bench_output_jsonl = false; }
|
||||
else { std::invalid_argument("invalid value"); }
|
||||
else { throw std::invalid_argument("invalid value"); }
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_BENCH}));
|
||||
add_opt(common_arg(
|
||||
|
|
376
common/chat-parser.cpp
Normal file
376
common/chat-parser.cpp
Normal file
|
@ -0,0 +1,376 @@
|
|||
#include "chat-parser.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
#include <optional>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
|
||||
: input_(input), is_partial_(is_partial), syntax_(syntax)
|
||||
{
|
||||
result_.role = "assistant";
|
||||
|
||||
while (true) {
|
||||
std::string id = std::to_string(std::rand());
|
||||
if (input.find(id) == std::string::npos) {
|
||||
healing_marker_ = id;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string common_chat_msg_parser::str(const common_string_range & rng) const {
|
||||
GGML_ASSERT(rng.begin <= rng.end);
|
||||
return input_.substr(rng.begin, rng.end - rng.begin);
|
||||
}
|
||||
|
||||
void common_chat_msg_parser::add_content(const std::string &content) {
|
||||
result_.content += content;
|
||||
}
|
||||
|
||||
void common_chat_msg_parser::add_reasoning_content(const std::string &reasoning_content) {
|
||||
result_.reasoning_content += reasoning_content;
|
||||
}
|
||||
|
||||
bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments) {
|
||||
if (name.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
common_chat_tool_call tool_call;
|
||||
tool_call.name = name;
|
||||
tool_call.arguments = arguments;
|
||||
tool_call.id = id;
|
||||
|
||||
// LOG_DBG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str());
|
||||
result_.tool_calls.emplace_back(tool_call);
|
||||
return true;
|
||||
}
|
||||
bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
|
||||
std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
|
||||
std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
|
||||
std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
|
||||
return add_tool_call(name, id, arguments);
|
||||
}
|
||||
|
||||
bool common_chat_msg_parser::add_tool_calls(const json & arr) {
|
||||
for (const auto & item : arr) {
|
||||
if (!add_tool_call(item)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
void common_chat_msg_parser::finish() {
|
||||
if (!is_partial_ && pos_ != input_.size()) {
|
||||
throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_));
|
||||
}
|
||||
}
|
||||
|
||||
bool common_chat_msg_parser::consume_spaces() {
|
||||
const auto length = input_.size();
|
||||
auto consumed = false;
|
||||
while (pos_ < length && std::isspace(input_[pos_])) {
|
||||
++pos_;
|
||||
consumed = true;
|
||||
}
|
||||
return consumed;
|
||||
}
|
||||
|
||||
bool common_chat_msg_parser::try_consume_literal(const std::string & literal) {
|
||||
auto pos = pos_;
|
||||
for (auto i = 0u; i < literal.size(); ++i) {
|
||||
if (pos >= input_.size()) {
|
||||
return false;
|
||||
}
|
||||
if (input_[pos] != literal[i]) {
|
||||
return false;
|
||||
}
|
||||
++pos;
|
||||
}
|
||||
pos_ = pos;
|
||||
return true;
|
||||
}
|
||||
|
||||
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_literal(const std::string & literal) {
|
||||
auto idx = input_.find(literal, pos_);
|
||||
if (idx != std::string::npos) {
|
||||
find_regex_result res;
|
||||
res.prelude = input_.substr(pos_, idx - pos_);
|
||||
auto end = idx + literal.size();
|
||||
res.groups.emplace_back(common_string_range{idx, end});
|
||||
move_to(end);
|
||||
return res;
|
||||
}
|
||||
if (is_partial_) {
|
||||
idx = string_find_partial_stop(input_, literal);
|
||||
if (idx != std::string::npos && idx >= pos_) {
|
||||
find_regex_result res;
|
||||
res.prelude = input_.substr(pos_, idx - pos_);
|
||||
auto end = input_.size();
|
||||
res.groups.emplace_back(common_string_range{idx, end});
|
||||
move_to(end);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
void common_chat_msg_parser::consume_literal(const std::string & literal) {
|
||||
if (!try_consume_literal(literal)) {
|
||||
throw common_chat_msg_partial_exception(literal);
|
||||
}
|
||||
}
|
||||
|
||||
bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) {
|
||||
auto handle_reasoning = [&](const std::string & reasoning, bool closed) {
|
||||
auto stripped_reasoning = string_strip(reasoning);
|
||||
if (stripped_reasoning.empty()) {
|
||||
return;
|
||||
}
|
||||
if (syntax_.reasoning_in_content) {
|
||||
add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "<think>" : start_think);
|
||||
add_content(stripped_reasoning);
|
||||
if (closed) {
|
||||
add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "</think>" : end_think);
|
||||
}
|
||||
} else {
|
||||
add_reasoning_content(stripped_reasoning);
|
||||
}
|
||||
};
|
||||
if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
|
||||
if (syntax_.thinking_forced_open || try_consume_literal(start_think)) {
|
||||
if (auto res = try_find_literal(end_think)) {
|
||||
handle_reasoning(res->prelude, /* closed */ true);
|
||||
consume_spaces();
|
||||
return true;
|
||||
}
|
||||
auto rest = consume_rest();
|
||||
if (!rest.empty()) {
|
||||
handle_reasoning(rest, /* closed */ !is_partial());
|
||||
}
|
||||
if (!syntax_.thinking_forced_open) {
|
||||
throw common_chat_msg_partial_exception(end_think);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string common_chat_msg_parser::consume_rest() {
|
||||
auto rest = input_.substr(pos_);
|
||||
pos_ = input_.size();
|
||||
return rest;
|
||||
}
|
||||
|
||||
// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback.
|
||||
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from) {
|
||||
auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
|
||||
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
|
||||
if (is_partial()) {
|
||||
throw common_chat_msg_partial_exception(regex.str());
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
|
||||
pos_ = m.groups[0].end;
|
||||
|
||||
return find_regex_result{prelude, m.groups};
|
||||
}
|
||||
|
||||
common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) {
|
||||
if (auto result = try_consume_regex(regex)) {
|
||||
return *result;
|
||||
}
|
||||
throw common_chat_msg_partial_exception(regex.str());
|
||||
}
|
||||
|
||||
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_consume_regex(const common_regex & regex) {
|
||||
auto m = regex.search(input_, pos_);
|
||||
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
|
||||
if (is_partial()) {
|
||||
throw common_chat_msg_partial_exception(regex.str());
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
if (m.groups[0].begin != pos_) {
|
||||
// Didn't match at the current position.
|
||||
return std::nullopt;
|
||||
}
|
||||
pos_ = m.groups[0].end;
|
||||
|
||||
return find_regex_result {
|
||||
/* .prelude = */ "",
|
||||
m.groups,
|
||||
};
|
||||
}
|
||||
|
||||
std::optional<common_json> common_chat_msg_parser::try_consume_json() {
|
||||
auto it = input_.cbegin() + pos_;
|
||||
const auto end = input_.cend();
|
||||
common_json result;
|
||||
if (!common_json_parse(it, end, healing_marker_, result)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
pos_ = std::distance(input_.cbegin(), it);
|
||||
if (result.healing_marker.marker.empty()) {
|
||||
// No healing marker, just return the parsed json
|
||||
return result;
|
||||
}
|
||||
if (!is_partial()) {
|
||||
throw common_chat_msg_partial_exception("JSON");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
common_json common_chat_msg_parser::consume_json() {
|
||||
if (auto result = try_consume_json()) {
|
||||
return *result;
|
||||
}
|
||||
throw common_chat_msg_partial_exception("JSON");
|
||||
}
|
||||
|
||||
common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json_with_dumped_args(
|
||||
const std::vector<std::vector<std::string>> & args_paths,
|
||||
const std::vector<std::vector<std::string>> & content_paths
|
||||
) {
|
||||
if (auto result = try_consume_json_with_dumped_args(args_paths, content_paths)) {
|
||||
return *result;
|
||||
}
|
||||
throw common_chat_msg_partial_exception("JSON");
|
||||
}
|
||||
|
||||
std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parser::try_consume_json_with_dumped_args(
|
||||
const std::vector<std::vector<std::string>> & args_paths,
|
||||
const std::vector<std::vector<std::string>> & content_paths
|
||||
) {
|
||||
auto partial = try_consume_json();
|
||||
if (!partial) {
|
||||
return std::nullopt;
|
||||
}
|
||||
auto is_arguments_path = [&](const std::vector<std::string> & path) {
|
||||
return std::find(args_paths.begin(), args_paths.end(), path) != args_paths.end();
|
||||
};
|
||||
auto is_content_path = [&](const std::vector<std::string> & path) {
|
||||
return std::find(content_paths.begin(), content_paths.end(), path) != content_paths.end();
|
||||
};
|
||||
|
||||
if (partial->healing_marker.marker.empty()) {
|
||||
if (args_paths.empty()) {
|
||||
// No arguments to dump, and JSON was parsed fully.
|
||||
return consume_json_result {
|
||||
partial->json,
|
||||
/* .is_partial = */ false,
|
||||
};
|
||||
}
|
||||
if (is_arguments_path({})) {
|
||||
// Entire JSON is the arguments and was parsed fully.
|
||||
return consume_json_result {
|
||||
partial->json.dump(),
|
||||
/* .is_partial = */ false,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
LOG_DBG("Parsed partial JSON: %s (json_healing_marker: %s)\n", partial->json.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
|
||||
|
||||
auto found_healing_marker = false;
|
||||
std::vector<std::string> path;
|
||||
std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
|
||||
if (is_arguments_path(path)) {
|
||||
auto arguments = j.dump();
|
||||
if (is_partial() && !partial->healing_marker.marker.empty()) {
|
||||
auto idx = arguments.find(partial->healing_marker.json_dump_marker);
|
||||
if (idx != std::string::npos) {
|
||||
arguments.resize(idx);
|
||||
found_healing_marker = true;
|
||||
}
|
||||
if (arguments == "\"") {
|
||||
// This happens because of completing `:"$magic` after `"arguments"`
|
||||
arguments = "";
|
||||
}
|
||||
}
|
||||
return arguments;
|
||||
}
|
||||
if (is_content_path(path)) {
|
||||
if (!j.is_string()) {
|
||||
throw std::runtime_error("Content path must be a string");
|
||||
}
|
||||
std::string str = j;
|
||||
auto idx = str.find(partial->healing_marker.marker); // not using json_dump_marker as we're inside a string
|
||||
if (idx != std::string::npos) {
|
||||
str.resize(idx);
|
||||
found_healing_marker = true;
|
||||
}
|
||||
return str;
|
||||
}
|
||||
if (j.is_object()) {
|
||||
auto obj = json::object();
|
||||
for (const auto & p : j.items()) {
|
||||
const auto & key = p.key();
|
||||
const auto & value = p.value();
|
||||
const std::string key_str = key; // NOLINT
|
||||
auto idx = key_str.find(healing_marker_);
|
||||
if (idx != std::string::npos) {
|
||||
found_healing_marker = true;
|
||||
break;
|
||||
}
|
||||
path.push_back(key_str);
|
||||
if (value.is_string()) {
|
||||
const std::string value_str = value;
|
||||
if (value_str.find(healing_marker_) != std::string::npos) {
|
||||
found_healing_marker = true;
|
||||
if (is_content_path(path)) {
|
||||
if (partial->healing_marker.marker == partial->healing_marker.json_dump_marker) {
|
||||
// The healing occurred inside the string: good. Otherwise we just ditch the entire key/value pair.
|
||||
obj[key] = remove_unsupported_healings_and_dump_args(value);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
obj[key] = value;
|
||||
} else {
|
||||
obj[key] = remove_unsupported_healings_and_dump_args(value);
|
||||
}
|
||||
path.pop_back();
|
||||
}
|
||||
return obj;
|
||||
}
|
||||
if (j.is_array()) {
|
||||
auto arr = json::array();
|
||||
for (const auto & value : j) {
|
||||
if (value.is_string()) {
|
||||
std::string str = value;
|
||||
auto idx = str.find(healing_marker_);
|
||||
if (idx != std::string::npos) {
|
||||
// Don't heal array values that aren't in the arguments.
|
||||
found_healing_marker = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
arr.push_back(remove_unsupported_healings_and_dump_args(value));
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
return j;
|
||||
};
|
||||
|
||||
auto cleaned = remove_unsupported_healings_and_dump_args(partial->json);
|
||||
LOG_DBG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", partial->json.dump().c_str(), cleaned.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
|
||||
return consume_json_result {
|
||||
cleaned,
|
||||
/* .is_partial = */ found_healing_marker,
|
||||
};
|
||||
}
|
116
common/chat-parser.h
Normal file
116
common/chat-parser.h
Normal file
|
@ -0,0 +1,116 @@
|
|||
#pragma once
|
||||
|
||||
#include "chat.h"
|
||||
#include "json-partial.h"
|
||||
#include "json.hpp"
|
||||
#include "regex-partial.h"
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
class common_chat_msg_partial_exception : public std::runtime_error {
|
||||
public:
|
||||
common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
|
||||
};
|
||||
|
||||
class common_chat_msg_parser {
|
||||
std::string input_;
|
||||
bool is_partial_;
|
||||
common_chat_syntax syntax_;
|
||||
std::string healing_marker_;
|
||||
|
||||
size_t pos_ = 0;
|
||||
common_chat_msg result_;
|
||||
|
||||
public:
|
||||
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
const std::string & input() const { return input_; }
|
||||
size_t pos() const { return pos_; }
|
||||
const std::string & healing_marker() const { return healing_marker_; }
|
||||
const bool & is_partial() const { return is_partial_; }
|
||||
const common_chat_msg & result() const { return result_; }
|
||||
|
||||
void move_to(size_t pos) {
|
||||
if (pos > input_.size()) {
|
||||
throw std::runtime_error("Invalid position!");
|
||||
}
|
||||
pos_ = pos;
|
||||
}
|
||||
void move_back(size_t n) {
|
||||
if (pos_ < n) {
|
||||
throw std::runtime_error("Can't move back that far!");
|
||||
}
|
||||
pos_ -= n;
|
||||
}
|
||||
|
||||
// Get the substring of the input at the given range
|
||||
std::string str(const common_string_range & rng) const;
|
||||
|
||||
// Appends to the result.content field
|
||||
void add_content(const std::string & content);
|
||||
|
||||
// Appends to the result.reasoning_content field
|
||||
void add_reasoning_content(const std::string & reasoning_content);
|
||||
|
||||
// Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything.
|
||||
bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);
|
||||
|
||||
// Adds a tool call using the "name", "id" and "arguments" fields of the json object
|
||||
bool add_tool_call(const nlohmann::ordered_json & tool_call);
|
||||
|
||||
// Adds an array of tool calls using their "name", "id" and "arguments" fields.
|
||||
bool add_tool_calls(const nlohmann::ordered_json & arr);
|
||||
|
||||
void finish();
|
||||
|
||||
bool consume_spaces();
|
||||
|
||||
void consume_literal(const std::string & literal);
|
||||
|
||||
bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);
|
||||
|
||||
std::string consume_rest();
|
||||
|
||||
struct find_regex_result {
|
||||
std::string prelude;
|
||||
std::vector<common_string_range> groups;
|
||||
};
|
||||
|
||||
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos);
|
||||
|
||||
bool try_consume_literal(const std::string & literal);
|
||||
|
||||
std::optional<find_regex_result> try_find_literal(const std::string & literal);
|
||||
|
||||
find_regex_result consume_regex(const common_regex & regex);
|
||||
|
||||
std::optional<find_regex_result> try_consume_regex(const common_regex & regex);
|
||||
|
||||
std::optional<common_json> try_consume_json();
|
||||
common_json consume_json();
|
||||
|
||||
struct consume_json_result {
|
||||
nlohmann::ordered_json value;
|
||||
bool is_partial;
|
||||
};
|
||||
|
||||
/*
|
||||
Consume (possibly partial) json and converts specific subtrees to (possibly truncated) JSON strings.
|
||||
|
||||
By default, object keys can't be truncated, nor can string values (their corresponding key is removed,
|
||||
e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}`
|
||||
|
||||
But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings
|
||||
- with `content_paths={{"foo"}}` -> `{"foo": "b` -> {"foo": "b"}`
|
||||
- with `args_paths={{"foo"}}` -> `{"foo": {"b` -> `{"foo": "{b"}`
|
||||
*/
|
||||
consume_json_result consume_json_with_dumped_args(
|
||||
const std::vector<std::vector<std::string>> & args_paths = {},
|
||||
const std::vector<std::vector<std::string>> & content_paths = {}
|
||||
);
|
||||
std::optional<consume_json_result> try_consume_json_with_dumped_args(
|
||||
const std::vector<std::vector<std::string>> & args_paths = {},
|
||||
const std::vector<std::vector<std::string>> & content_paths = {}
|
||||
);
|
||||
};
|
1309
common/chat.cpp
1309
common/chat.cpp
File diff suppressed because it is too large
Load diff
|
@ -3,6 +3,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include <functional>
|
||||
#include <chrono>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
@ -13,11 +14,19 @@ struct common_chat_tool_call {
|
|||
std::string name;
|
||||
std::string arguments;
|
||||
std::string id;
|
||||
|
||||
bool operator==(const common_chat_tool_call & other) const {
|
||||
return name == other.name && arguments == other.arguments && id == other.id;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_content_part {
|
||||
std::string type;
|
||||
std::string text;
|
||||
|
||||
bool operator==(const common_chat_msg_content_part & other) const {
|
||||
return type == other.type && text == other.text;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg {
|
||||
|
@ -28,6 +37,51 @@ struct common_chat_msg {
|
|||
std::string reasoning_content;
|
||||
std::string tool_name;
|
||||
std::string tool_call_id;
|
||||
|
||||
template <class T> T to_json_oaicompat() const;
|
||||
|
||||
bool empty() const {
|
||||
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
|
||||
}
|
||||
void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
|
||||
for (auto i = 0u; i < tool_calls.size(); i++) {
|
||||
if (ids_cache.size() <= i) {
|
||||
auto id = tool_calls[i].id;
|
||||
if (id.empty()) {
|
||||
id = gen_tool_call_id();
|
||||
}
|
||||
ids_cache.push_back(id);
|
||||
}
|
||||
tool_calls[i].id = ids_cache[i];
|
||||
}
|
||||
}
|
||||
bool operator==(const common_chat_msg & other) const {
|
||||
return role == other.role
|
||||
&& content == other.content
|
||||
&& content_parts == other.content_parts
|
||||
&& tool_calls == other.tool_calls
|
||||
&& reasoning_content == other.reasoning_content
|
||||
&& tool_name == other.tool_name
|
||||
&& tool_call_id == other.tool_call_id;
|
||||
}
|
||||
bool operator!=(const common_chat_msg & other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_diff {
|
||||
// std::string reasoning_content_delta;
|
||||
std::string content_delta;
|
||||
size_t tool_call_index = std::string::npos;
|
||||
common_chat_tool_call tool_call_delta;
|
||||
|
||||
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
|
||||
|
||||
bool operator==(const common_chat_msg_diff & other) const {
|
||||
return content_delta == other.content_delta
|
||||
&& tool_call_index == other.tool_call_index
|
||||
&& tool_call_delta == other.tool_call_delta;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_tool {
|
||||
|
@ -49,14 +103,11 @@ enum common_chat_format {
|
|||
COMMON_CHAT_FORMAT_LLAMA_3_X,
|
||||
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
|
||||
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO,
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING,
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B,
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
@ -71,7 +122,8 @@ struct common_chat_templates_inputs {
|
|||
std::vector<common_chat_tool> tools;
|
||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
bool parallel_tool_calls = false;
|
||||
bool extract_reasoning = true;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
bool enable_thinking = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
};
|
||||
|
||||
|
@ -80,11 +132,20 @@ struct common_chat_params {
|
|||
std::string prompt;
|
||||
std::string grammar;
|
||||
bool grammar_lazy = false;
|
||||
bool thinking_forced_open = false;
|
||||
std::vector<common_grammar_trigger> grammar_triggers;
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
};
|
||||
|
||||
struct common_chat_syntax {
|
||||
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
|
||||
bool reasoning_in_content = false;
|
||||
bool thinking_forced_open = false;
|
||||
};
|
||||
|
||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
|
||||
|
||||
|
@ -121,8 +182,9 @@ std::string common_chat_format_example(
|
|||
const struct common_chat_templates * tmpls,
|
||||
bool use_jinja);
|
||||
|
||||
std::string common_chat_format_name(common_chat_format format);
|
||||
common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);
|
||||
const char* common_chat_format_name(common_chat_format format);
|
||||
const char* common_reasoning_format_name(common_reasoning_format format);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
|
||||
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
|
||||
|
||||
|
@ -135,3 +197,5 @@ template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common
|
|||
// T can be std::string containing JSON or nlohmann::ordered_json
|
||||
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
|
||||
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
|
||||
|
||||
template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|
||||
|
|
|
@ -857,7 +857,7 @@ std::string fs_get_cache_directory() {
|
|||
if (getenv("LLAMA_CACHE")) {
|
||||
cache_directory = std::getenv("LLAMA_CACHE");
|
||||
} else {
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX)
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__)
|
||||
if (std::getenv("XDG_CACHE_HOME")) {
|
||||
cache_directory = std::getenv("XDG_CACHE_HOME");
|
||||
} else {
|
||||
|
|
|
@ -111,7 +111,7 @@ enum common_grammar_trigger_type {
|
|||
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
|
||||
};
|
||||
|
||||
struct common_grammar_trigger {
|
||||
|
@ -364,6 +364,7 @@ struct common_params {
|
|||
bool use_jinja = false; // NOLINT
|
||||
bool enable_chat_template = true;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
int reasoning_budget = -1;
|
||||
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
|
||||
|
||||
std::vector<std::string> api_keys;
|
||||
|
|
255
common/json-partial.cpp
Normal file
255
common/json-partial.cpp
Normal file
|
@ -0,0 +1,255 @@
|
|||
#include <json-partial.h>
|
||||
#include "ggml.h"
|
||||
#include "log.h"
|
||||
#include <string>
|
||||
|
||||
#include <json.hpp>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
enum common_json_stack_element_type {
|
||||
COMMON_JSON_STACK_ELEMENT_OBJECT,
|
||||
COMMON_JSON_STACK_ELEMENT_KEY,
|
||||
COMMON_JSON_STACK_ELEMENT_ARRAY,
|
||||
};
|
||||
|
||||
struct common_json_stack_element {
|
||||
common_json_stack_element_type type;
|
||||
std::string key;
|
||||
};
|
||||
|
||||
bool common_json_parse(
|
||||
const std::string & input,
|
||||
const std::string & healing_marker,
|
||||
common_json & out)
|
||||
{
|
||||
std::string::const_iterator it = input.begin();
|
||||
const auto end = input.end();
|
||||
return common_json_parse(it, end, healing_marker, out);
|
||||
}
|
||||
|
||||
bool common_json_parse(
|
||||
std::string::const_iterator & it,
|
||||
const std::string::const_iterator & end,
|
||||
const std::string & healing_marker,
|
||||
common_json & out)
|
||||
{
|
||||
// // https://json.nlohmann.me/features/parsing/sax_interface/
|
||||
struct json_error_locator : public nlohmann::json_sax<json> {
|
||||
std::size_t position;
|
||||
bool found_error;
|
||||
std::string last_token;
|
||||
std::string exception_message;
|
||||
std::vector<common_json_stack_element> stack;
|
||||
|
||||
json_error_locator() : position(0), found_error(false) {}
|
||||
|
||||
bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
|
||||
this->position = position - 1;
|
||||
this->found_error = true;
|
||||
this->last_token = last_token;
|
||||
this->exception_message = ex.what();
|
||||
return false;
|
||||
}
|
||||
void close_value() {
|
||||
if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
|
||||
stack.pop_back();
|
||||
}
|
||||
}
|
||||
bool null() override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool boolean(bool) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool number_integer(number_integer_t) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool number_unsigned(number_unsigned_t) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool number_float(number_float_t, const string_t &) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool string(string_t &) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool binary(binary_t &) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool start_object(std::size_t) override { // NOLINT
|
||||
stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
|
||||
return true;
|
||||
}
|
||||
bool end_object() override {
|
||||
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
|
||||
stack.pop_back();
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool key(string_t & key) override { // NOLINT
|
||||
stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
|
||||
return true;
|
||||
}
|
||||
bool start_array(std::size_t) override { // NOLINT
|
||||
stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
|
||||
return true;
|
||||
}
|
||||
bool end_array() override {
|
||||
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
|
||||
stack.pop_back();
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
};
|
||||
json_error_locator err_loc;
|
||||
auto start = it;
|
||||
json::sax_parse(it, end, &err_loc);
|
||||
|
||||
if (err_loc.found_error) {
|
||||
it = start;
|
||||
auto temptative_end = it + err_loc.position;
|
||||
// LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
|
||||
|
||||
auto input = std::string(it, temptative_end);
|
||||
try {
|
||||
out.json = json::parse(input);
|
||||
// out.json = json::parse(it, temptative_end);
|
||||
it = temptative_end;
|
||||
return true;
|
||||
} catch (const std::exception & ex) {
|
||||
// No, needs healing.
|
||||
LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
|
||||
}
|
||||
auto can_parse = [](const std::string & str) {
|
||||
try {
|
||||
auto _ = json::parse(str); // NOLINT
|
||||
return true;
|
||||
} catch (const std::exception &) {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
if (!healing_marker.empty() && !err_loc.stack.empty()) {
|
||||
std::string str(it, temptative_end);
|
||||
auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
|
||||
if (last_non_sp_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
|
||||
}
|
||||
auto last_non_sp_char = str[last_non_sp_pos];
|
||||
// Used to detect stops on a number, which may not be complete.
|
||||
auto was_maybe_number = [&]() {
|
||||
if (!str.empty() && std::isspace(str.back())) {
|
||||
return false;
|
||||
}
|
||||
return std::isdigit(last_non_sp_char) ||
|
||||
last_non_sp_char == '.' ||
|
||||
last_non_sp_char == 'e' ||
|
||||
last_non_sp_char == 'E' ||
|
||||
last_non_sp_char == '-';
|
||||
};
|
||||
|
||||
std::string closing;
|
||||
for (size_t i = err_loc.stack.size(); i > 0; i--) {
|
||||
auto & el = err_loc.stack[i - 1];
|
||||
if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
|
||||
closing += "}";
|
||||
} else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
|
||||
closing += "]";
|
||||
} else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
|
||||
throw std::runtime_error("Unexpected stack element type");
|
||||
}
|
||||
}
|
||||
|
||||
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
|
||||
|
||||
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
|
||||
// We're inside an object value
|
||||
if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
|
||||
// Was about to create an object value
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + ": 1" + closing)) {
|
||||
str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
|
||||
} else if (last_non_sp_char == '{' && can_parse(str + closing)) {
|
||||
// Was about to create an object
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
|
||||
} else if (can_parse(str + "\"" + closing)) {
|
||||
// Was inside an object value string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an object value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else {
|
||||
// find last :
|
||||
auto last_pos = str.find_last_of(':');
|
||||
if (last_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
|
||||
}
|
||||
// Cutting back to opening : for object value
|
||||
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
}
|
||||
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
|
||||
if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
|
||||
// Was about to create an array value
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + "\"" + closing)) {
|
||||
// Was inside an array value string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an array value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
|
||||
// Had just finished a value
|
||||
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
|
||||
} else {
|
||||
auto last_pos = str.find_last_of("[,");
|
||||
if (last_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
|
||||
}
|
||||
// Cutting back to last [ or , for array value
|
||||
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
}
|
||||
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
|
||||
if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
|
||||
(last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
|
||||
// Was about to create an object key+value
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
|
||||
} else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
|
||||
// Was about to create an object key+value
|
||||
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
|
||||
} else if (can_parse(str + "\": 1" + closing)) {
|
||||
// Was inside an object key string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
|
||||
// Was inside an object key string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
|
||||
} else {
|
||||
auto last_pos = str.find_last_of(':');
|
||||
if (last_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
|
||||
}
|
||||
// fprintf(stderr, "Cutting back to last : for object key+value\n");
|
||||
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
|
||||
}
|
||||
// fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
|
||||
out.json = json::parse(str);
|
||||
it = temptative_end;
|
||||
return true;
|
||||
}
|
||||
// TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
|
||||
// fprintf(stderr, "Closing: TODO\n");
|
||||
return false;
|
||||
}
|
||||
out.json = json::parse(it, end);
|
||||
it = end;
|
||||
return true;
|
||||
}
|
37
common/json-partial.h
Normal file
37
common/json-partial.h
Normal file
|
@ -0,0 +1,37 @@
|
|||
#pragma once
|
||||
#include <json.hpp>
|
||||
|
||||
// Healing marker (empty if the JSON was fully parsed / wasn't healed).
|
||||
struct common_healing_marker {
|
||||
// Raw marker.
|
||||
std::string marker;
|
||||
|
||||
// Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format).
|
||||
std::string json_dump_marker;
|
||||
};
|
||||
|
||||
// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string)
|
||||
struct common_json {
|
||||
nlohmann::ordered_json json;
|
||||
|
||||
common_healing_marker healing_marker;
|
||||
};
|
||||
|
||||
// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty.
|
||||
//
|
||||
// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON.
|
||||
// This allows to parse the resulting healed JSON string, yet be able to cut it again if needed at the healing marker.
|
||||
// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format).
|
||||
//
|
||||
// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again).
|
||||
bool common_json_parse(
|
||||
const std::string & input,
|
||||
const std::string & healing_marker,
|
||||
common_json & out);
|
||||
|
||||
// Parse the JSON string (see overload above), but advancing an iterator to the end of the input when the (potentially partial) parsing succeeds.
|
||||
bool common_json_parse(
|
||||
std::string::const_iterator & it,
|
||||
const std::string::const_iterator & end,
|
||||
const std::string & healing_marker,
|
||||
common_json & out);
|
|
@ -161,7 +161,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
} else {
|
||||
std::vector<std::string> patterns_at_start;
|
||||
std::vector<std::string> trigger_patterns;
|
||||
std::vector<std::string> patterns_anywhere;
|
||||
std::vector<llama_token> trigger_tokens;
|
||||
for (const auto & trigger : params.grammar_triggers) {
|
||||
|
@ -173,10 +173,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
|
||||
{
|
||||
const auto & pattern = trigger.value;
|
||||
(trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
|
||||
patterns_anywhere.push_back(trigger.value);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
|
||||
{
|
||||
trigger_patterns.push_back(trigger.value);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
|
||||
|
@ -190,10 +193,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> trigger_patterns;
|
||||
if (!patterns_at_start.empty()) {
|
||||
trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
|
||||
}
|
||||
if (!patterns_anywhere.empty()) {
|
||||
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
|
||||
}
|
||||
|
|
|
@ -2643,7 +2643,7 @@ class QwenModel(TextModel):
|
|||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM")
|
||||
@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration")
|
||||
class Qwen2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN2
|
||||
|
||||
|
@ -2667,8 +2667,9 @@ class Qwen2Model(TextModel):
|
|||
name = f"model.{name}" # map to Qwen2ForCausalLM tensors
|
||||
if "language_model." in name:
|
||||
name = name.replace("language_model.", "") # for InternVL
|
||||
if name.startswith("mlp") or name.startswith("vision_model"):
|
||||
# skip visual tensors
|
||||
if name.startswith("mlp") or name.startswith("multi_modal_projector") \
|
||||
or name.startswith("vision_model") or name.startswith("audio_tower"):
|
||||
# skip vision and audio tensors
|
||||
return []
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
@ -5993,11 +5994,11 @@ class UltravoxModel(TextModel):
|
|||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
raise NotImplementedError("Ultravox does not have text decoder. Please use --mmproj argument")
|
||||
raise NotImplementedError("Ultravox does not have text decoder. Instead, it uses Llama or other models for text. If you want to get the audio encoder, please use --mmproj argument")
|
||||
|
||||
|
||||
@ModelBase.register("UltravoxModel")
|
||||
class UltravoxAudioModel(MmprojModel):
|
||||
@ModelBase.register("Qwen2AudioForConditionalGeneration")
|
||||
class WhisperEncoderModel(MmprojModel):
|
||||
has_vision_encoder = False # no vision encoder
|
||||
has_audio_encoder = True
|
||||
|
||||
|
@ -6009,10 +6010,9 @@ class UltravoxAudioModel(MmprojModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.ULTRAVOX)
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2A)
|
||||
self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"])
|
||||
self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
|
||||
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
|
||||
|
||||
def tensor_force_quant(self, name, new_name, bid, n_dims):
|
||||
del bid, new_name, n_dims # unused
|
||||
|
@ -6023,6 +6023,10 @@ class UltravoxAudioModel(MmprojModel):
|
|||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
if name.startswith("language_model."):
|
||||
# skip language model tensors
|
||||
return []
|
||||
|
||||
# prevent clash naming with vision tensors
|
||||
if name.startswith("multi_modal_projector"):
|
||||
name = "audio." + name
|
||||
|
@ -6033,6 +6037,16 @@ class UltravoxAudioModel(MmprojModel):
|
|||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register("UltravoxModel")
|
||||
class UltravoxWhisperEncoderModel(WhisperEncoderModel):
|
||||
has_vision_encoder = False # no vision encoder
|
||||
has_audio_encoder = True
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ llama-server -hf ggml-org/gemma-3-4b-it-GGUF --no-mmproj-offload
|
|||
|
||||
## Pre-quantized models
|
||||
|
||||
These are ready-to-use models, most of them come with `Q4_K_M` quantization by default. They can be found at the Hugging Face page of the ggml-org: https://huggingface.co/ggml-org
|
||||
These are ready-to-use models, most of them come with `Q4_K_M` quantization by default. They can be found at the Hugging Face page of the ggml-org: https://huggingface.co/collections/ggml-org/multimodal-ggufs-68244e01ff1f39e5bebeeedc
|
||||
|
||||
Replaces the `(tool_name)` with the name of binary you want to use. For example, `llama-mtmd-cli` or `llama-server`
|
||||
|
||||
|
@ -81,6 +81,10 @@ NOTE: some models may require large context window, for example: `-c 8192`
|
|||
|
||||
# Llama 4 Scout
|
||||
(tool_name) -hf ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF
|
||||
|
||||
# Moondream2 20250414 version
|
||||
(tool_name) -hf ggml-org/moondream2-20250414-GGUF
|
||||
|
||||
```
|
||||
|
||||
**Audio models**:
|
||||
|
@ -89,4 +93,8 @@ NOTE: some models may require large context window, for example: `-c 8192`
|
|||
# Ultravox 0.5
|
||||
(tool_name) -hf ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF
|
||||
(tool_name) -hf ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF
|
||||
|
||||
# Qwen2-Audio and SeaLLM-Audio
|
||||
# note: no pre-quantized GGUF this model, as they have very poor result
|
||||
# ref: https://github.com/ggml-org/llama.cpp/pull/13760
|
||||
```
|
||||
|
|
|
@ -3504,6 +3504,19 @@ void ggml_cpu_init(void) {
|
|||
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
|
||||
|
||||
GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
|
||||
|
||||
#ifdef GGML_USE_OPENMP
|
||||
//if (!getenv("OMP_WAIT_POLICY")) {
|
||||
// // set the wait policy to active, so that OpenMP threads don't sleep
|
||||
// putenv("OMP_WAIT_POLICY=active");
|
||||
//}
|
||||
|
||||
if (!getenv("KMP_BLOCKTIME")) {
|
||||
// set the time to wait before sleeping a thread
|
||||
// this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases
|
||||
putenv("KMP_BLOCKTIME=200"); // 200ms
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(__ARM_ARCH)
|
||||
|
|
|
@ -546,6 +546,7 @@ class MODEL_TENSOR(IntEnum):
|
|||
A_ENC_FFN_GATE = auto()
|
||||
A_ENC_FFN_DOWN = auto()
|
||||
A_MMPROJ = auto()
|
||||
A_MMPROJ_FC = auto()
|
||||
A_MM_NORM_PRE = auto()
|
||||
A_MM_NORM_MID = auto()
|
||||
|
||||
|
@ -825,6 +826,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
|||
MODEL_TENSOR.A_ENC_FFN_GATE: "a.blk.{bid}.ffn_gate",
|
||||
MODEL_TENSOR.A_ENC_FFN_DOWN: "a.blk.{bid}.ffn_down",
|
||||
MODEL_TENSOR.A_MMPROJ: "mm.a.mlp.{bid}",
|
||||
MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc",
|
||||
MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre",
|
||||
MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid",
|
||||
}
|
||||
|
@ -885,6 +887,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.A_ENC_FFN_GATE,
|
||||
MODEL_TENSOR.A_ENC_FFN_DOWN,
|
||||
MODEL_TENSOR.A_MMPROJ,
|
||||
MODEL_TENSOR.A_MMPROJ_FC,
|
||||
MODEL_TENSOR.A_MM_NORM_PRE,
|
||||
MODEL_TENSOR.A_MM_NORM_MID,
|
||||
],
|
||||
|
@ -2256,6 +2259,7 @@ class VisionProjectorType:
|
|||
QWEN25VL = "qwen2.5vl_merger"
|
||||
ULTRAVOX = "ultravox"
|
||||
INTERNVL = "internvl"
|
||||
QWEN2A = "qwen2a" # audio
|
||||
|
||||
|
||||
# Items here are (block size, type size)
|
||||
|
|
|
@ -1165,6 +1165,10 @@ class TensorNameMap:
|
|||
"audio.multi_modal_projector.linear_{bid}", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_MMPROJ_FC: (
|
||||
"audio.multi_modal_projector.linear", # qwen2audio
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_MM_NORM_PRE: (
|
||||
"audio.multi_modal_projector.ln_pre", # ultravox
|
||||
),
|
||||
|
|
|
@ -474,6 +474,7 @@ extern "C" {
|
|||
LLAMA_API int64_t llama_time_us(void);
|
||||
|
||||
LLAMA_API size_t llama_max_devices(void);
|
||||
LLAMA_API size_t llama_max_parallel_sequences(void);
|
||||
|
||||
LLAMA_API bool llama_supports_mmap (void);
|
||||
LLAMA_API bool llama_supports_mlock (void);
|
||||
|
|
BIN
models/ggml-vocab-nomic-bert-moe.gguf
Normal file
BIN
models/ggml-vocab-nomic-bert-moe.gguf
Normal file
Binary file not shown.
112
models/ggml-vocab-nomic-bert-moe.gguf.inp
Normal file
112
models/ggml-vocab-nomic-bert-moe.gguf.inp
Normal file
|
@ -0,0 +1,112 @@
|
|||
ied 4 ½ months
|
||||
__ggml_vocab_test__
|
||||
Führer
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
Hello world
|
||||
__ggml_vocab_test__
|
||||
Hello world
|
||||
__ggml_vocab_test__
|
||||
Hello World
|
||||
__ggml_vocab_test__
|
||||
Hello World
|
||||
__ggml_vocab_test__
|
||||
Hello World!
|
||||
__ggml_vocab_test__
|
||||
Hello, world!
|
||||
__ggml_vocab_test__
|
||||
Hello, world!
|
||||
__ggml_vocab_test__
|
||||
this is 🦙.cpp
|
||||
__ggml_vocab_test__
|
||||
w048 7tuijk dsdfhu
|
||||
__ggml_vocab_test__
|
||||
нещо на Български
|
||||
__ggml_vocab_test__
|
||||
កាន់តែពិសេសអាចខលចេញ
|
||||
__ggml_vocab_test__
|
||||
🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
(
|
||||
__ggml_vocab_test__
|
||||
|
||||
=
|
||||
__ggml_vocab_test__
|
||||
' era
|
||||
__ggml_vocab_test__
|
||||
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
|
||||
__ggml_vocab_test__
|
||||
!!!!!!
|
||||
__ggml_vocab_test__
|
||||
3
|
||||
__ggml_vocab_test__
|
||||
33
|
||||
__ggml_vocab_test__
|
||||
333
|
||||
__ggml_vocab_test__
|
||||
3333
|
||||
__ggml_vocab_test__
|
||||
33333
|
||||
__ggml_vocab_test__
|
||||
333333
|
||||
__ggml_vocab_test__
|
||||
3333333
|
||||
__ggml_vocab_test__
|
||||
33333333
|
||||
__ggml_vocab_test__
|
||||
333333333
|
||||
__ggml_vocab_test__
|
||||
Cửa Việt
|
||||
__ggml_vocab_test__
|
||||
discards
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
|
||||
__ggml_vocab_test__
|
46
models/ggml-vocab-nomic-bert-moe.gguf.out
Normal file
46
models/ggml-vocab-nomic-bert-moe.gguf.out
Normal file
|
@ -0,0 +1,46 @@
|
|||
17 297 201 78660 21775
|
||||
72805 4097 56
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
35378 8999
|
||||
35378 8999
|
||||
35378 6661
|
||||
35378 6661
|
||||
35378 6661 38
|
||||
35378 4 8999 38
|
||||
35378 4 8999 38
|
||||
903 83 6 3 5 238 6366
|
||||
148 7709 1019 361 458 134362 104 7 71 420 1132
|
||||
14271 29 117152
|
||||
6 149561 78270 48967 64254 7616 81705
|
||||
6 247206 15 33176 16 6 247442 6 3 15755 15 144227 8705 18255 40292 158 4460 33 27686 16 6 142325 15 191 538 28 121505 450 1556 6863 10002 47 1098 16
|
||||
35378
|
||||
35378
|
||||
35378
|
||||
35378
|
||||
35378
|
||||
35378 35378
|
||||
15
|
||||
2203
|
||||
242 1615
|
||||
35378 4 113 25 5584 38 11249 621 398 6 201344 705 23638 213 9007 133 1879 2681 2592 135224 1906 6087
|
||||
6 90827
|
||||
138
|
||||
3912
|
||||
6 66000
|
||||
138 66000
|
||||
3912 66000
|
||||
6 66000 66000
|
||||
138 66000 66000
|
||||
3912 66000 66000
|
||||
6 66000 66000 66000
|
||||
199152 3763
|
||||
17116 99397
|
||||
6 247206 15 33176 16 6 247442 6 3 15755 15 144227 8705 18255 40292 158 4460 33 27686 16 6 142325 6 3 138 3912 6 66000 138 66000 3912 66000 6 66000 66000 138 66000 66000 3912 66000 66000 80308 1031 5 363 138 27 363 6 149561 78270 48967 201344 705 23638 213 9007 133 1879 2681 2592 135224 1906 6087 6 110405 1369 69112 69112 69112 14271 29 117152 5106 4765 4765 1135 164721 164721 164721 58 58 58 58 2551 90827 32 85908 87 25 272 2809 242 18 18345 764 25 7 2685 4 242 11766 398 9077 32 242 594 959 9077 87 25 1181 3249 442 4 242 397 398 1884 3060 26156 32 1401 25 26455 10 25 141 866
|
62
models/templates/Qwen-QwQ-32B.jinja
Normal file
62
models/templates/Qwen-QwQ-32B.jinja
Normal file
|
@ -0,0 +1,62 @@
|
|||
{%- if tools %}
|
||||
{{- '<|im_start|>system\n' }}
|
||||
{%- if messages[0]['role'] == 'system' %}
|
||||
{{- messages[0]['content'] }}
|
||||
{%- else %}
|
||||
{{- '' }}
|
||||
{%- endif %}
|
||||
{{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
||||
{%- for tool in tools %}
|
||||
{{- "\n" }}
|
||||
{{- tool | tojson }}
|
||||
{%- endfor %}
|
||||
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
||||
{%- else %}
|
||||
{%- if messages[0]['role'] == 'system' %}
|
||||
{{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- for message in messages %}
|
||||
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
|
||||
{%- elif message.role == "assistant" and not message.tool_calls %}
|
||||
{%- set content = message.content %}
|
||||
{%- if not loop.last %}
|
||||
{%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
|
||||
{%- endif %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
||||
{%- elif message.role == "assistant" %}
|
||||
{%- set content = message.content %}
|
||||
{%- if not loop.last %}
|
||||
{%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
|
||||
{%- endif %}
|
||||
{{- '<|im_start|>' + message.role }}
|
||||
{%- if message.content %}
|
||||
{{- '\n' + content }}
|
||||
{%- endif %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if tool_call.function is defined %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '\n<tool_call>\n{"name": "' }}
|
||||
{{- tool_call.name }}
|
||||
{{- '", "arguments": ' }}
|
||||
{{- tool_call.arguments | tojson }}
|
||||
{{- '}\n</tool_call>' }}
|
||||
{%- endfor %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif message.role == "tool" %}
|
||||
{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
|
||||
{{- '<|im_start|>user' }}
|
||||
{%- endif %}
|
||||
{{- '\n<tool_response>\n' }}
|
||||
{{- message.content }}
|
||||
{{- '\n</tool_response>' }}
|
||||
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n<think>\n' }}
|
||||
{%- endif %}
|
85
models/templates/Qwen-Qwen3-0.6B.jinja
Normal file
85
models/templates/Qwen-Qwen3-0.6B.jinja
Normal file
|
@ -0,0 +1,85 @@
|
|||
{%- if tools %}
|
||||
{{- '<|im_start|>system\n' }}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{{- messages[0].content + '\n\n' }}
|
||||
{%- endif %}
|
||||
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
||||
{%- for tool in tools %}
|
||||
{{- "\n" }}
|
||||
{{- tool | tojson }}
|
||||
{%- endfor %}
|
||||
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
||||
{%- else %}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
||||
{%- for message in messages[::-1] %}
|
||||
{%- set index = (messages|length - 1) - loop.index0 %}
|
||||
{%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
|
||||
{%- set ns.multi_step_tool = false %}
|
||||
{%- set ns.last_query_index = index %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- for message in messages %}
|
||||
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
|
||||
{%- elif message.role == "assistant" %}
|
||||
{%- set content = message.content %}
|
||||
{%- set reasoning_content = '' %}
|
||||
{%- if message.reasoning_content is defined and message.reasoning_content is not none %}
|
||||
{%- set reasoning_content = message.reasoning_content %}
|
||||
{%- else %}
|
||||
{%- if '</think>' in message.content %}
|
||||
{%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
|
||||
{%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if loop.index0 > ns.last_query_index %}
|
||||
{%- if loop.last or (not loop.last and reasoning_content) %}
|
||||
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content }}
|
||||
{%- endif %}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content }}
|
||||
{%- endif %}
|
||||
{%- if message.tool_calls %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if (loop.first and content) or (not loop.first) %}
|
||||
{{- '\n' }}
|
||||
{%- endif %}
|
||||
{%- if tool_call.function %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '<tool_call>\n{"name": "' }}
|
||||
{{- tool_call.name }}
|
||||
{{- '", "arguments": ' }}
|
||||
{%- if tool_call.arguments is string %}
|
||||
{{- tool_call.arguments }}
|
||||
{%- else %}
|
||||
{{- tool_call.arguments | tojson }}
|
||||
{%- endif %}
|
||||
{{- '}\n</tool_call>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif message.role == "tool" %}
|
||||
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
||||
{{- '<|im_start|>user' }}
|
||||
{%- endif %}
|
||||
{{- '\n<tool_response>\n' }}
|
||||
{{- message.content }}
|
||||
{{- '\n</tool_response>' }}
|
||||
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- if enable_thinking is defined and enable_thinking is false %}
|
||||
{{- '<think>\n\n</think>\n\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
|
@ -25,7 +25,11 @@ llama_context::llama_context(
|
|||
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
cparams.n_seq_max = std::max(1u, params.n_seq_max);
|
||||
cparams.n_seq_max = std::max(1u, params.n_seq_max);
|
||||
if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
|
||||
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
|
||||
}
|
||||
|
||||
cparams.n_threads = params.n_threads;
|
||||
cparams.n_threads_batch = params.n_threads_batch;
|
||||
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
||||
|
|
|
@ -1 +1,5 @@
|
|||
#include "llama-cparams.h"
|
||||
|
||||
size_t llama_max_parallel_sequences(void) {
|
||||
return LLAMA_MAX_PARALLEL_SEQUENCES;
|
||||
}
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
|
||||
#include <cstdint>
|
||||
|
||||
#define LLAMA_MAX_PARALLEL_SEQUENCES 64
|
||||
|
||||
struct llama_cparams {
|
||||
uint32_t n_ctx; // context size used during inference
|
||||
uint32_t n_batch;
|
||||
|
|
|
@ -1177,8 +1177,18 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
|||
for (const auto & trigger_pattern : grammar.trigger_patterns) {
|
||||
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
|
||||
grammar.awaiting_trigger = false;
|
||||
// get from the first match to the end of the string
|
||||
auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
|
||||
// get from the first matched capturing group to the end of the string
|
||||
size_t start = std::string::npos;
|
||||
for (auto i = 1u; i < match.size(); i++) {
|
||||
if (match.length(i) > 0) {
|
||||
start = match.position(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (start == std::string::npos) {
|
||||
start = match.position(0);
|
||||
}
|
||||
auto constrained_str = grammar.trigger_buffer.substr(start);
|
||||
// std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
|
||||
grammar.trigger_buffer.clear();
|
||||
llama_grammar_accept_str(grammar, constrained_str);
|
||||
|
|
|
@ -65,8 +65,6 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
|||
};
|
||||
|
||||
head = 0;
|
||||
size = kv_size;
|
||||
used = 0;
|
||||
|
||||
cells.resize(kv_size);
|
||||
|
||||
|
@ -138,13 +136,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
|||
}
|
||||
|
||||
void llama_kv_cache_unified::clear() {
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
cells[i].pos = -1;
|
||||
cells[i].seq_id.clear();
|
||||
}
|
||||
cells.reset();
|
||||
|
||||
head = 0;
|
||||
used = 0;
|
||||
|
||||
for (auto & buf : bufs) {
|
||||
ggml_backend_buffer_clear(buf.get(), 0);
|
||||
|
@ -152,7 +146,7 @@ void llama_kv_cache_unified::clear() {
|
|||
}
|
||||
|
||||
bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
uint32_t new_head = size;
|
||||
uint32_t new_head = cells.size();
|
||||
|
||||
if (p0 < 0) {
|
||||
p0 = 0;
|
||||
|
@ -162,33 +156,20 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
|
|||
p1 = std::numeric_limits<llama_pos>::max();
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].pos >= p0 && cells[i].pos < p1) {
|
||||
if (seq_id < 0) {
|
||||
cells[i].seq_id.clear();
|
||||
} else if (cells[i].has_seq_id(seq_id)) {
|
||||
cells[i].seq_id.erase(seq_id);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (!cells.pos_in(i, p0, p1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (cells[i].is_empty()) {
|
||||
// keep count of the number of used cells
|
||||
if (cells[i].pos >= 0) {
|
||||
used--;
|
||||
}
|
||||
|
||||
cells[i].pos = -1;
|
||||
|
||||
if (new_head == size) {
|
||||
new_head = i;
|
||||
}
|
||||
if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
|
||||
if (new_head == cells.size()) {
|
||||
new_head = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we freed up a slot, set head to it so searching can start there.
|
||||
if (new_head != size && new_head < head) {
|
||||
if (new_head != cells.size() && new_head < head) {
|
||||
head = new_head;
|
||||
}
|
||||
|
||||
|
@ -208,49 +189,40 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
|
|||
p1 = std::numeric_limits<llama_pos>::max();
|
||||
}
|
||||
|
||||
// otherwise, this is the KV of a Transformer-like model
|
||||
head = 0;
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (!cells.pos_in(i, p0, p1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) {
|
||||
cells[i].seq_id.insert(seq_id_dst);
|
||||
if (cells.seq_has(i, seq_id_src)) {
|
||||
cells.seq_add(i, seq_id_dst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
|
||||
uint32_t new_head = size;
|
||||
uint32_t new_head = cells.size();
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (!cells[i].has_seq_id(seq_id)) {
|
||||
if (cells[i].pos >= 0) {
|
||||
used--;
|
||||
}
|
||||
|
||||
cells[i].pos = -1;
|
||||
cells[i].seq_id.clear();
|
||||
|
||||
if (new_head == size){
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (cells.seq_keep(i, seq_id)) {
|
||||
if (new_head == cells.size()) {
|
||||
new_head = i;
|
||||
}
|
||||
} else {
|
||||
cells[i].seq_id.clear();
|
||||
cells[i].seq_id.insert(seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
// If we freed up a slot, set head to it so searching can start there.
|
||||
if (new_head != size && new_head < head) {
|
||||
if (new_head != cells.size() && new_head < head) {
|
||||
head = new_head;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
|
||||
if (delta == 0) {
|
||||
void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
if (shift == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t new_head = size;
|
||||
uint32_t new_head = cells.size();
|
||||
|
||||
if (p0 < 0) {
|
||||
p0 = 0;
|
||||
|
@ -260,25 +232,19 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
|
|||
p1 = std::numeric_limits<llama_pos>::max();
|
||||
}
|
||||
|
||||
// If there is no range then return early to avoid looping over the
|
||||
// If there is no range then return early to avoid looping over all cells.
|
||||
if (p0 == p1) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
|
||||
has_shift = true;
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (!cells.pos_in(i, p0, p1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
cells[i].pos += delta;
|
||||
cells[i].delta += delta;
|
||||
|
||||
if (cells[i].pos < 0) {
|
||||
if (!cells[i].is_empty()) {
|
||||
used--;
|
||||
}
|
||||
cells[i].pos = -1;
|
||||
cells[i].seq_id.clear();
|
||||
if (new_head == size) {
|
||||
if (cells.seq_has(i, seq_id)) {
|
||||
if (cells.pos_add(i, shift)) {
|
||||
if (new_head == cells.size()) {
|
||||
new_head = i;
|
||||
}
|
||||
}
|
||||
|
@ -287,7 +253,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
|
|||
|
||||
// If we freed up a slot, set head to it so searching can start there.
|
||||
// Otherwise we just start the next search from the beginning.
|
||||
head = new_head != size ? new_head : 0;
|
||||
head = new_head != cells.size() ? new_head : 0;
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
|
@ -308,15 +274,13 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
|
|||
return;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
|
||||
has_shift = true;
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (!cells.pos_in(i, p0, p1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
{
|
||||
llama_pos p_old = cells[i].pos;
|
||||
cells[i].pos /= d;
|
||||
cells[i].delta += cells[i].pos - p_old;
|
||||
}
|
||||
if (cells.seq_has(i, seq_id)) {
|
||||
cells.pos_div(i, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -324,9 +288,9 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
|
|||
llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
|
||||
llama_pos result = std::numeric_limits<llama_pos>::max();
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].has_seq_id(seq_id)) {
|
||||
result = std::min(result, cells[i].pos);
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (cells.seq_has(i, seq_id)) {
|
||||
result = std::min(result, cells.pos_get(i));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -340,9 +304,9 @@ llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
|
|||
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
||||
llama_pos result = -1;
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].has_seq_id(seq_id)) {
|
||||
result = std::max(result, cells[i].pos);
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (cells.seq_has(i, seq_id)) {
|
||||
result = std::max(result, cells.pos_get(i));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -350,25 +314,15 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
|||
}
|
||||
|
||||
void llama_kv_cache_unified::restore() {
|
||||
for (const auto & [id, cell] : recovery.cells) {
|
||||
// TODO: move to new `struct kv_cells`
|
||||
const bool is_empty0 = cells[id].is_empty();
|
||||
const bool is_empty1 = cell.is_empty();
|
||||
|
||||
if (!is_empty0 && is_empty1) {
|
||||
used--;
|
||||
} else if (is_empty0 && !is_empty1) {
|
||||
used++;
|
||||
}
|
||||
|
||||
cells[id] = cell;
|
||||
for (auto & state : recovery.states) {
|
||||
cells.set(state.i, state.cells);
|
||||
}
|
||||
|
||||
recovery.clear();
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::commit() {
|
||||
if (recovery.cells.empty()) {
|
||||
if (recovery.states.empty()) {
|
||||
LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
|
||||
__func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
|
||||
return;
|
||||
|
@ -382,7 +336,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|||
|
||||
auto * sched = lctx.get_sched();
|
||||
|
||||
if (has_shift) {
|
||||
if (cells.get_has_shift()) {
|
||||
if (!get_can_shift()) {
|
||||
printf("\nWARNING: The current KV cache / model configuration does not support K-shift");
|
||||
} else {
|
||||
|
@ -406,13 +360,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|||
need_reserve = true;
|
||||
}
|
||||
|
||||
{
|
||||
has_shift = false;
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
cells[i].delta = 0;
|
||||
}
|
||||
}
|
||||
cells.reset_shift();
|
||||
}}
|
||||
|
||||
if (do_defrag) {
|
||||
|
@ -443,7 +391,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|||
void llama_kv_cache_unified::defrag_sched(float thold) {
|
||||
// - do not defrag small contexts (i.e. < 2048 tokens)
|
||||
// - count the padding towards the number of used tokens
|
||||
const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f;
|
||||
const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f;
|
||||
|
||||
// queue defragmentation for next llama_kv_cache_update
|
||||
if (fragmentation > thold) {
|
||||
|
@ -454,7 +402,7 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
|
|||
}
|
||||
|
||||
void llama_kv_cache_unified::set_full() {
|
||||
n = size;
|
||||
n = cells.size();
|
||||
|
||||
// when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
|
||||
// affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
|
||||
|
@ -478,14 +426,14 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
|
|||
|
||||
// if we have enough unused cells before the current head ->
|
||||
// better to start searching from the beginning of the cache, hoping to fill it
|
||||
if (head > used + 2*ubatch.n_tokens) {
|
||||
if (head > cells.get_used() + 2*ubatch.n_tokens) {
|
||||
head = 0;
|
||||
}
|
||||
|
||||
// otherwise, one cell per token.
|
||||
|
||||
if (n_tokens > size) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
|
||||
if (n_tokens > cells.size()) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -498,10 +446,10 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
|
|||
std::string ss;
|
||||
if (n_swa > 0) {
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].pos == -1) {
|
||||
if (cells.is_empty(i)) {
|
||||
ss += '.';
|
||||
} else {
|
||||
ss += std::to_string(*cells[i].seq_id.begin());
|
||||
ss += 'x';
|
||||
}
|
||||
if (i%256 == 255) {
|
||||
ss += '\n';
|
||||
|
@ -515,15 +463,16 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
|
|||
uint32_t n_tested = 0;
|
||||
|
||||
while (true) {
|
||||
if (head + n_tokens > size) {
|
||||
n_tested += size - head;
|
||||
if (head + n_tokens > cells.size()) {
|
||||
n_tested += cells.size() - head;
|
||||
head = 0;
|
||||
continue;
|
||||
}
|
||||
|
||||
bool found = true;
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
if (cells[head + i].pos >= 0) {
|
||||
// TODO: improve to accept cells that are masked by the SWA
|
||||
if (!cells.is_empty(head + i)) {
|
||||
found = false;
|
||||
head += i + 1;
|
||||
n_tested += i + 1;
|
||||
|
@ -535,31 +484,27 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
|
|||
break;
|
||||
}
|
||||
|
||||
if (n_tested >= size) {
|
||||
if (n_tested >= cells.size()) {
|
||||
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < n_tokens; ++i) {
|
||||
// remember the original state
|
||||
if (recovery.cells.find(head + i) == recovery.cells.end()) {
|
||||
recovery.cells[head + i] = cells[head + i];
|
||||
}
|
||||
// store the old state of the cells in the recovery stack
|
||||
recovery.states.push_back({head, cells.cp(head, n_tokens)});
|
||||
|
||||
cells[head + i].pos = ubatch.pos[i];
|
||||
for (uint32_t i = 0; i < n_tokens; ++i) {
|
||||
cells.pos_set(head + i, ubatch.pos[i]);
|
||||
|
||||
for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
|
||||
cells[head + i].seq_id.insert(ubatch.seq_id[i][j]);
|
||||
cells.seq_add(head + i, ubatch.seq_id[i][j]);
|
||||
}
|
||||
}
|
||||
|
||||
used += n_tokens;
|
||||
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// after enough generations, the benefit from this heuristic disappears
|
||||
// if we start defragmenting the cache, the benefit from this will be more important
|
||||
n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
|
||||
n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
|
||||
|
||||
#ifdef FIND_SLOT_DEBUG
|
||||
LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
|
||||
|
@ -577,7 +522,7 @@ uint32_t llama_kv_cache_unified::get_n() const {
|
|||
}
|
||||
|
||||
uint32_t llama_kv_cache_unified::get_size() const {
|
||||
return size;
|
||||
return cells.size();
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
|
||||
|
@ -661,30 +606,19 @@ void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llam
|
|||
|
||||
int n_attended = 0;
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
const llama_pos p0 = cells[i].pos;
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (!cells.seq_has(i, seq_id)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const llama_pos p0 = cells.pos_get(i);
|
||||
|
||||
if (p0 <= pmin && !is_masked_swa(p0, pmin)) {
|
||||
n_attended++;
|
||||
}
|
||||
|
||||
if (is_masked_swa(p0, pmax)) {
|
||||
if (seq_id < 0) {
|
||||
cells[i].seq_id.clear();
|
||||
} else if (cells[i].has_seq_id(seq_id)) {
|
||||
cells[i].seq_id.erase(seq_id);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (cells[i].is_empty()) {
|
||||
// keep count of the number of used cells
|
||||
if (cells[i].pos >= 0) {
|
||||
used--;
|
||||
}
|
||||
|
||||
cells[i].pos = -1;
|
||||
}
|
||||
cells.seq_rm(i, seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -723,25 +657,31 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
|||
const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
|
||||
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
const llama_pos p0 = cells[i].pos;
|
||||
float f = 0.0f;
|
||||
|
||||
bool masked = false;
|
||||
|
||||
// mask the token if not the same sequence
|
||||
masked = masked || (!cells[i].has_seq_id(seq_id));
|
||||
if (cells.is_empty(i)) {
|
||||
masked = true;
|
||||
} else {
|
||||
const llama_pos p0 = cells.pos_get(i);
|
||||
|
||||
// mask future tokens
|
||||
masked = masked || (causal_attn && p0 > p1);
|
||||
// mask the token if not the same sequence
|
||||
masked = masked || (!cells.seq_has(i, seq_id));
|
||||
|
||||
// apply SWA if any
|
||||
masked = masked || (is_masked_swa(p0, p1));
|
||||
// mask future tokens
|
||||
masked = masked || (causal_attn && p0 > p1);
|
||||
|
||||
float f = 0.0f;
|
||||
// apply SWA if any
|
||||
masked = masked || (is_masked_swa(p0, p1));
|
||||
|
||||
if (!masked && hparams.use_alibi) {
|
||||
f = -std::abs(p0 - p1);
|
||||
}
|
||||
}
|
||||
|
||||
if (masked) {
|
||||
f = -INFINITY;
|
||||
} else if (hparams.use_alibi) {
|
||||
f = -std::abs(p0 - p1);
|
||||
}
|
||||
|
||||
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
||||
|
@ -765,8 +705,8 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
|
|||
|
||||
int32_t * data = (int32_t *) dst->data;
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
data[i] = cells[i].delta;
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -783,7 +723,10 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
|
|||
for (int h = 0; h < 1; ++h) {
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
|
||||
// the position when the cells is empty is irrelevant - it will be masked out later in the attention
|
||||
const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
|
||||
|
||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -910,7 +853,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
|
|||
|
||||
ggml_tensor * k =
|
||||
ggml_view_3d(ctx, layer.k,
|
||||
n_embd_head_k, n_head_kv, size,
|
||||
n_embd_head_k, n_head_kv, cells.size(),
|
||||
ggml_row_size(layer.k->type, n_embd_head_k),
|
||||
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
||||
0);
|
||||
|
@ -1050,12 +993,12 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
|||
} else {
|
||||
view_v_src = ggml_view_2d(ctx, layer.v,
|
||||
nm, n_embd_v_gqa,
|
||||
ggml_row_size(layer.v->type, size),
|
||||
ggml_row_size(layer.v->type, cells.size()),
|
||||
ggml_row_size(layer.v->type, i));
|
||||
|
||||
view_v_dst = ggml_view_2d(ctx, layer.v,
|
||||
nm, n_embd_v_gqa,
|
||||
ggml_row_size(layer.v->type, size),
|
||||
ggml_row_size(layer.v->type, cells.size()),
|
||||
ggml_row_size(layer.v->type, id));
|
||||
}
|
||||
|
||||
|
@ -1076,7 +1019,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|||
const uint32_t n_layer = layers.size();
|
||||
|
||||
const uint32_t n_kv = cell_max();
|
||||
const uint32_t n_used = used;
|
||||
const uint32_t n_used = cells.get_used();
|
||||
|
||||
assert(n_used <= n_kv);
|
||||
|
||||
|
@ -1104,9 +1047,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|||
ids.resize(n_kv, n_kv);
|
||||
|
||||
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
||||
const auto & cell0 = cells[i0];
|
||||
|
||||
if (!cell0.is_empty()) {
|
||||
if (!cells.is_empty(i0)) {
|
||||
ids[i0] = i0;
|
||||
|
||||
continue;
|
||||
|
@ -1117,7 +1058,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|||
uint32_t nh = 1;
|
||||
|
||||
// determine the size of the hole
|
||||
while (i0 + nh < n_used && cells[i0 + nh].is_empty()) {
|
||||
while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
|
||||
nh++;
|
||||
}
|
||||
|
||||
|
@ -1126,9 +1067,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|||
|
||||
// starting from the end, find nh non-empty cells
|
||||
for (; is > i0; --is) {
|
||||
const auto & cell1 = cells[is];
|
||||
|
||||
if (cell1.is_empty() || ids[is] != n_kv) {
|
||||
if (cells.is_empty(is) || ids[is] != n_kv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -1155,9 +1094,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|||
|
||||
// go back and move the nf cells to the hole
|
||||
for (; i1 < n_kv; ++i1) {
|
||||
auto & cell1 = cells[i1];
|
||||
|
||||
if (cell1.is_empty() || ids[i1] != n_kv) {
|
||||
if (cells.is_empty(i1) || ids[i1] != n_kv) {
|
||||
if (n_moves == max_moves) {
|
||||
stop = true;
|
||||
break;
|
||||
|
@ -1171,10 +1108,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|||
ids[i1] = i0 + nf;
|
||||
|
||||
// move the cell meta data
|
||||
cells[i0 + nf] = cell1;
|
||||
cells.mv(i1, i0 + nf);
|
||||
|
||||
// clear the old cell and move the head there
|
||||
cell1 = kv_cell();
|
||||
head = n_used;
|
||||
|
||||
if (!cont) {
|
||||
|
@ -1210,10 +1145,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|||
}
|
||||
|
||||
uint32_t llama_kv_cache_unified::cell_max() const {
|
||||
for (uint32_t i = size; i > 0; --i) {
|
||||
const kv_cell & cell = cells[i - 1];
|
||||
|
||||
if (cell.pos >= 0 && !cell.is_empty()) {
|
||||
for (uint32_t i = cells.size(); i > 0; --i) {
|
||||
if (!cells.is_empty(i - 1)) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
@ -1222,9 +1155,7 @@ uint32_t llama_kv_cache_unified::cell_max() const {
|
|||
}
|
||||
|
||||
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
||||
if (p0 < 0) {
|
||||
return true;
|
||||
}
|
||||
assert(p0 >= 0 && p1 >= 0);
|
||||
|
||||
switch (swa_type) {
|
||||
case LLAMA_SWA_TYPE_NONE:
|
||||
|
@ -1255,23 +1186,24 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
|
|||
|
||||
// Count the number of cells with the specified seq_id
|
||||
// Find all the ranges of cells with this seq id (or all, when -1)
|
||||
uint32_t cell_range_begin = size;
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
const auto & cell = cells[i];
|
||||
if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
|
||||
uint32_t cell_range_begin = cells.size();
|
||||
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
|
||||
++cell_count;
|
||||
if (cell_range_begin == size) {
|
||||
if (cell_range_begin == cells.size()) {
|
||||
cell_range_begin = i;
|
||||
}
|
||||
} else {
|
||||
if (cell_range_begin != size) {
|
||||
if (cell_range_begin != cells.size()) {
|
||||
cell_ranges.emplace_back(cell_range_begin, i);
|
||||
cell_range_begin = size;
|
||||
cell_range_begin = cells.size();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (cell_range_begin != size) {
|
||||
cell_ranges.emplace_back(cell_range_begin, size);
|
||||
|
||||
if (cell_range_begin != cells.size()) {
|
||||
cell_ranges.emplace_back(cell_range_begin, cells.size());
|
||||
}
|
||||
|
||||
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
||||
|
@ -1308,17 +1240,24 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
|
|||
void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
|
||||
for (const auto & range : cell_ranges) {
|
||||
for (uint32_t i = range.first; i < range.second; ++i) {
|
||||
const auto & cell = cells[i];
|
||||
const llama_pos pos = cell.pos;
|
||||
const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
|
||||
std::vector<llama_seq_id> seq_ids;
|
||||
|
||||
for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
|
||||
if (cur == seq_id || seq_id == -1) {
|
||||
if (cells.seq_has(i, cur)) {
|
||||
seq_ids.push_back(cur);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const llama_pos pos = cells.pos_get(i);
|
||||
const uint32_t n_seq_id = seq_ids.size();
|
||||
|
||||
io.write(&pos, sizeof(pos));
|
||||
io.write(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
if (n_seq_id) {
|
||||
for (auto seq_id : cell.seq_id) {
|
||||
io.write(&seq_id, sizeof(seq_id));
|
||||
}
|
||||
for (const auto & seq_id : seq_ids) {
|
||||
io.write(&seq_id, sizeof(seq_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1379,7 +1318,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
|
|||
}
|
||||
} else {
|
||||
// When v is transposed, we also need the element size and get the element ranges from each row
|
||||
const uint32_t kv_size = size;
|
||||
const uint32_t kv_size = cells.size();
|
||||
|
||||
for (const auto & layer : layers) {
|
||||
const uint32_t il = layer.il;
|
||||
|
@ -1429,14 +1368,20 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|||
io.read_to(&pos, sizeof(pos));
|
||||
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
if (n_seq_id != 0) {
|
||||
if (n_seq_id != 1) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
batch.pos[i] = pos;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i] = &dest_seq_id;
|
||||
// read the sequence id, but directly discard it - we will use dest_seq_id instead
|
||||
{
|
||||
llama_seq_id seq_id;
|
||||
io.read_to(&seq_id, sizeof(seq_id));
|
||||
}
|
||||
|
||||
batch.pos[i] = pos;
|
||||
batch.n_seq_id[i] = n_seq_id;
|
||||
batch.seq_id[i] = &dest_seq_id;
|
||||
}
|
||||
|
||||
if (!find_slot(batch)) {
|
||||
|
@ -1448,15 +1393,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|||
|
||||
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
||||
// Assume that this is one contiguous block of cells
|
||||
GGML_ASSERT(head + cell_count <= size);
|
||||
GGML_ASSERT(cells[head].pos == batch.pos[0]);
|
||||
GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
|
||||
GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
|
||||
GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
|
||||
GGML_ASSERT(head + cell_count <= cells.size());
|
||||
GGML_ASSERT(cells.pos_get(head) == batch.pos[0]);
|
||||
GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]);
|
||||
GGML_ASSERT(cells.seq_has(head, dest_seq_id));
|
||||
GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id));
|
||||
} else {
|
||||
// whole KV cache restore
|
||||
|
||||
if (cell_count > size) {
|
||||
if (cell_count > cells.size()) {
|
||||
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
@ -1464,15 +1409,13 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|||
clear();
|
||||
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
kv_cell & cell = cells[i];
|
||||
|
||||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
|
||||
io.read_to(&pos, sizeof(pos));
|
||||
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
cell.pos = pos;
|
||||
cells.pos_set(i, pos);
|
||||
|
||||
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
||||
llama_seq_id seq_id;
|
||||
|
@ -1483,12 +1426,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|||
return false;
|
||||
}
|
||||
|
||||
cell.seq_id.insert(seq_id);
|
||||
cells.seq_add(i, seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
head = 0;
|
||||
used = cell_count;
|
||||
}
|
||||
|
||||
return true;
|
||||
|
@ -1505,8 +1447,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
|||
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
|
||||
return false;
|
||||
}
|
||||
if (cell_count > size) {
|
||||
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
|
||||
if (cell_count > cells.size()) {
|
||||
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
|
||||
return false;
|
||||
}
|
||||
if (this->v_trans != (bool) v_trans) {
|
||||
|
@ -1609,7 +1551,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
|||
if (cell_count) {
|
||||
// For each row in the transposed matrix, read the values for the whole cell range
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
const size_t dst_offset = (head + j * size) * v_size_el;
|
||||
const size_t dst_offset = (head + j * cells.size()) * v_size_el;
|
||||
ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
||||
}
|
||||
}
|
||||
|
@ -1689,9 +1631,9 @@ void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
|
|||
kv_swa ->seq_keep(seq_id);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
|
||||
kv_base->seq_add(seq_id, p0, p1, delta);
|
||||
kv_swa ->seq_add(seq_id, p0, p1, delta);
|
||||
void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
kv_base->seq_add(seq_id, p0, p1, shift);
|
||||
kv_swa ->seq_add(seq_id, p0, p1, shift);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
|
@ -2063,8 +2005,8 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
|
|||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
|
||||
if (delta == 0) {
|
||||
void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
if (shift == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -2087,7 +2029,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
|
|||
if (tail_id >= 0) {
|
||||
kv_cell & cell = cells[tail_id];
|
||||
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
||||
cell.pos += delta;
|
||||
cell.pos += shift;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
#include "llama-io.h"
|
||||
#include "llama-graph.h"
|
||||
#include "llama-memory.h"
|
||||
#include "llama-kv-cells.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
|
||||
|
@ -35,6 +36,7 @@ struct llama_kv_cache : public llama_memory_i {
|
|||
virtual void defrag_sched(float thold) = 0;
|
||||
|
||||
// simulate full cache, used for allocating worst-case compute buffers
|
||||
// TODO: remove
|
||||
virtual void set_full() = 0;
|
||||
|
||||
//
|
||||
|
@ -42,7 +44,7 @@ struct llama_kv_cache : public llama_memory_i {
|
|||
//
|
||||
|
||||
// =============================================================================================================
|
||||
// TODO: refactor and simplify this
|
||||
// TODO: refactor and simplify this [TAG: KV_API]
|
||||
|
||||
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
|
||||
|
||||
|
@ -121,7 +123,7 @@ public:
|
|||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
|
@ -159,7 +161,7 @@ public:
|
|||
// llama_kv_cache_unified specific API
|
||||
//
|
||||
|
||||
uint32_t get_n() const;
|
||||
uint32_t get_n() const;
|
||||
uint32_t get_size() const;
|
||||
|
||||
// get views of the current state of the cache
|
||||
|
@ -180,26 +182,6 @@ private:
|
|||
const llama_model & model;
|
||||
const llama_hparams & hparams;
|
||||
|
||||
struct kv_cell {
|
||||
llama_pos pos = -1;
|
||||
llama_pos delta = 0;
|
||||
|
||||
// TODO: replace with bitset uint64_t
|
||||
std::set<llama_seq_id> seq_id;
|
||||
|
||||
bool has_seq_id(const llama_seq_id & id) const {
|
||||
return seq_id.find(id) != seq_id.end();
|
||||
}
|
||||
|
||||
bool is_empty() const {
|
||||
return seq_id.empty();
|
||||
}
|
||||
|
||||
bool is_same_seq(const kv_cell & other) const {
|
||||
return seq_id == other.seq_id;
|
||||
}
|
||||
};
|
||||
|
||||
struct kv_layer {
|
||||
// layer index in the model
|
||||
// note: can be different from the layer index in the KV cache
|
||||
|
@ -209,15 +191,13 @@ private:
|
|||
ggml_tensor * v;
|
||||
};
|
||||
|
||||
bool has_shift = false;
|
||||
bool do_defrag = false;
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
|
||||
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
|
||||
uint32_t size = 0; // total number of cells, shared across all sequences
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id) (TODO: add `struct kv_cells` and keep track automaticallt)
|
||||
|
||||
// computed before each graph build
|
||||
// TODO: cells should start to maintain this value dynamically based on the edits
|
||||
uint32_t n = 0;
|
||||
|
||||
const uint32_t n_seq_max = 1;
|
||||
|
@ -233,19 +213,29 @@ private:
|
|||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
std::vector<kv_cell> cells; // TODO: replace with `struct kv_cells`
|
||||
llama_kv_cells_unified cells;
|
||||
|
||||
std::vector<kv_layer> layers;
|
||||
|
||||
// model layer id -> KV cache layer id
|
||||
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
||||
|
||||
// recovery information used to restore the KV cells to their original state in case of a failure
|
||||
// TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation
|
||||
// to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API]
|
||||
struct {
|
||||
void clear() {
|
||||
cells.clear();
|
||||
states.clear();
|
||||
}
|
||||
|
||||
std::unordered_map<uint32_t, kv_cell> cells;
|
||||
struct state {
|
||||
uint32_t i;
|
||||
|
||||
llama_kv_cells_unified cells;
|
||||
};
|
||||
|
||||
// stack with the partial states before each ubatch
|
||||
std::vector<state> states;
|
||||
} recovery;
|
||||
|
||||
// defrag
|
||||
|
@ -257,6 +247,7 @@ private:
|
|||
bool defrag_prepare(int32_t n_max_nodes);
|
||||
|
||||
// find how many cells are currently in use
|
||||
// TODO: optimize
|
||||
uint32_t cell_max() const;
|
||||
|
||||
size_t total_size() const;
|
||||
|
@ -325,7 +316,7 @@ public:
|
|||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
|
@ -431,7 +422,7 @@ public:
|
|||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
|
|
273
src/llama-kv-cells.h
Normal file
273
src/llama-kv-cells.h
Normal file
|
@ -0,0 +1,273 @@
|
|||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
#include "llama-cparams.h"
|
||||
|
||||
#include <bitset>
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
|
||||
// meta information about KV cells that can be part of multiple sequences at the same time
|
||||
// TODO: add unit tests
|
||||
class llama_kv_cells_unified {
|
||||
public:
|
||||
void reset() {
|
||||
for (uint32_t i = 0; i < pos.size(); ++i) {
|
||||
pos[i] = -1;
|
||||
shift[i] = 0;
|
||||
seq[i].reset();
|
||||
}
|
||||
|
||||
used = 0;
|
||||
has_shift = false;
|
||||
}
|
||||
|
||||
void reset_shift() {
|
||||
has_shift = false;
|
||||
|
||||
for (uint32_t i = 0; i < shift.size(); ++i) {
|
||||
shift[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t size() const {
|
||||
return pos.size();
|
||||
}
|
||||
|
||||
void resize(uint32_t n) {
|
||||
pos.resize(n);
|
||||
shift.resize(n);
|
||||
seq.resize(n);
|
||||
|
||||
reset();
|
||||
}
|
||||
|
||||
bool is_empty(uint32_t i) const {
|
||||
assert(i < pos.size());
|
||||
assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0);
|
||||
|
||||
return pos[i] == -1;
|
||||
}
|
||||
|
||||
uint32_t get_used() const {
|
||||
return used;
|
||||
}
|
||||
|
||||
bool get_has_shift() const {
|
||||
return has_shift;
|
||||
}
|
||||
|
||||
// move cell isrc to idst (used during defrag)
|
||||
void mv(uint32_t isrc, uint32_t idst) {
|
||||
assert(isrc < pos.size());
|
||||
assert(idst < pos.size());
|
||||
|
||||
pos [idst] = pos [isrc];
|
||||
shift[idst] = shift[isrc];
|
||||
seq [idst] = seq [isrc];
|
||||
|
||||
pos [isrc] = -1;
|
||||
shift[isrc] = 0;
|
||||
seq [isrc].reset();
|
||||
}
|
||||
|
||||
// copy the state of cells [i, i + n) (used for save/restore the state of the cells)
|
||||
llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
|
||||
assert(i + n <= pos.size());
|
||||
|
||||
llama_kv_cells_unified res;
|
||||
|
||||
res.resize(n);
|
||||
|
||||
for (uint32_t j = 0; j < n; ++j) {
|
||||
res.pos[j] = pos[i + j];
|
||||
res.seq[j] = seq[i + j];
|
||||
|
||||
assert(shift[i + j] == 0);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
|
||||
void set(uint32_t i, const llama_kv_cells_unified & other) {
|
||||
assert(i + other.pos.size() <= pos.size());
|
||||
|
||||
for (uint32_t j = 0; j < other.pos.size(); ++j) {
|
||||
if (pos[i + j] == -1 && other.pos[j] != -1) {
|
||||
used++;
|
||||
}
|
||||
|
||||
if (pos[i + j] != -1 && other.pos[j] == -1) {
|
||||
used--;
|
||||
}
|
||||
|
||||
pos[i + j] = other.pos[j];
|
||||
seq[i + j] = other.seq[j];
|
||||
|
||||
assert(shift[i + j] == 0);
|
||||
}
|
||||
}
|
||||
|
||||
// note: call only if the cell has seq_id
|
||||
// return true if the cell becomes empty
|
||||
bool seq_rm(uint32_t i, llama_seq_id seq_id) {
|
||||
assert(i < pos.size());
|
||||
assert(seq[i].test(seq_id));
|
||||
assert(pos[i] != -1);
|
||||
assert(seq_id >= 0);
|
||||
|
||||
seq[i].reset(seq_id);
|
||||
|
||||
if (seq[i].none()) {
|
||||
pos[i] = -1;
|
||||
|
||||
used--;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// return true if the cell becomes empty (i.e. it did not contain seq_id before the call)
|
||||
bool seq_keep(uint32_t i, llama_seq_id seq_id) {
|
||||
assert(i < pos.size());
|
||||
|
||||
if (seq[i].test(seq_id)) {
|
||||
seq[i].reset();
|
||||
seq[i].set(seq_id);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
if (seq[i].any()) {
|
||||
seq[i].reset();
|
||||
pos[i] = -1;
|
||||
|
||||
used--;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
assert(pos[i] == -1);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool seq_has(uint32_t i, llama_seq_id seq_id) const {
|
||||
assert(i < pos.size());
|
||||
assert(seq_id >= 0);
|
||||
|
||||
return seq[i].test(seq_id);
|
||||
}
|
||||
|
||||
// note: call only if the cell is not empty and the seq_id is not in the cell
|
||||
void seq_add(uint32_t i, llama_seq_id seq_id) {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
assert(!seq[i].test(seq_id));
|
||||
|
||||
seq[i].set(seq_id);
|
||||
}
|
||||
|
||||
// note: call only if the cell is not empty
|
||||
llama_pos pos_get(uint32_t i) const {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
|
||||
return pos[i];
|
||||
}
|
||||
|
||||
// note: call only if the cell is not empty
|
||||
llama_pos get_shift(uint32_t i) const {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
|
||||
return shift[i];
|
||||
}
|
||||
|
||||
// check if a cell is not empty and its position is within [p0, p1)
|
||||
bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
|
||||
assert(i < pos.size());
|
||||
|
||||
return pos[i] >= p0 && pos[i] < p1;
|
||||
}
|
||||
|
||||
// set the position of an empty cell
|
||||
// does not modify "has_shift"
|
||||
// note: call only if the cell is empty
|
||||
void pos_set(uint32_t i, llama_pos p) {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] == -1);
|
||||
|
||||
pos[i] = p;
|
||||
used++;
|
||||
}
|
||||
|
||||
// pos[i] = pos[i] + d
|
||||
// sets "has_shift" to true
|
||||
// note: call only if the cell is not empty
|
||||
bool pos_add(uint32_t i, llama_pos d) {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
|
||||
pos[i] += d;
|
||||
shift[i] += d;
|
||||
|
||||
has_shift = true;
|
||||
|
||||
if (pos[i] < 0) {
|
||||
pos[i] = -1;
|
||||
seq[i].reset();
|
||||
|
||||
used--;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// pos[i] = pos[i] / d
|
||||
// sets "has_shift" to true
|
||||
// note: call only if the cell is not empty
|
||||
void pos_div(uint32_t i, int d) {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
|
||||
const llama_pos p_old = pos[i];
|
||||
|
||||
pos[i] /= d;
|
||||
shift[i] += p_old - pos[i];
|
||||
|
||||
has_shift = true;
|
||||
}
|
||||
|
||||
private:
|
||||
uint32_t used = 0; // used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
|
||||
|
||||
bool has_shift = false;
|
||||
|
||||
std::vector<llama_pos> pos;
|
||||
|
||||
// this array accumulates any applied shifts to the pos array since the last reset_shift() call
|
||||
// this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
|
||||
//
|
||||
// cells.pos_add(x, shift_x);
|
||||
// cells.pos_div(y, shift_y);
|
||||
// ...
|
||||
//
|
||||
// if (cells.has_shift()) {
|
||||
// for (int i = 0; i < n; ++i) {
|
||||
// auto shift_i = cells.get_shift(i);
|
||||
// ...
|
||||
// }
|
||||
// cells.reset_shift();
|
||||
// }
|
||||
//
|
||||
std::vector<llama_pos> shift;
|
||||
|
||||
std::vector<std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>> seq;
|
||||
};
|
||||
|
|
@ -22,7 +22,7 @@ public:
|
|||
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
|
||||
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
|
||||
virtual void seq_keep(llama_seq_id seq_id) = 0;
|
||||
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
|
||||
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
|
||||
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
|
||||
|
||||
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
|
||||
|
|
|
@ -2585,7 +2585,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
|
355
tests/test-chat-parser.cpp
Normal file
355
tests/test-chat-parser.cpp
Normal file
|
@ -0,0 +1,355 @@
|
|||
// Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
|
||||
//
|
||||
// Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
|
||||
// e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
|
||||
//
|
||||
// cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
|
||||
//
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
#include <json.hpp>
|
||||
#include <string>
|
||||
|
||||
#include "chat-parser.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
template <class T>
|
||||
static void assert_equals(const T & expected, const T & actual) {
|
||||
if (expected != actual) {
|
||||
std::cerr << "Expected: " << expected << std::endl;
|
||||
std::cerr << "Actual: " << actual << std::endl;
|
||||
std::cerr << std::flush;
|
||||
throw std::runtime_error("Test failed");
|
||||
}
|
||||
}
|
||||
static void assert_equals(const char * expected, const std::string & actual) {
|
||||
return assert_equals<std::string>(expected, actual);
|
||||
}
|
||||
|
||||
static void assert_throws(const std::function<void()> & fn, const std::string & expected_exception_pattern = "") {
|
||||
try {
|
||||
fn();
|
||||
} catch (const std::exception & e) {
|
||||
if (expected_exception_pattern.empty()) {
|
||||
return;
|
||||
}
|
||||
std::regex expected_exception_regex(expected_exception_pattern);
|
||||
std::string actual_message = e.what();
|
||||
if (std::regex_search(actual_message, expected_exception_regex)) {
|
||||
return;
|
||||
}
|
||||
throw std::runtime_error("Exception doesn't match expected pattern: " + actual_message + " (pattern: " + expected_exception_pattern + ")");
|
||||
throw std::runtime_error("Exception of unexpected type: " + std::string(e.what()));
|
||||
}
|
||||
throw std::runtime_error("Exception was expected but not thrown");
|
||||
}
|
||||
|
||||
static void test_reasoning() {
|
||||
{
|
||||
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
});
|
||||
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals("<tnk>Cogito</tnk>Ergo sum", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
});
|
||||
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals(std::string("Cogito"), builder.result().reasoning_content);
|
||||
assert_equals("Ergo sum", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
});
|
||||
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals("Cogito</tnk>Ergo sum", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
});
|
||||
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals(std::string("Cogito"), builder.result().reasoning_content);
|
||||
assert_equals("Ergo sum", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ true,
|
||||
/* .thinking_forced_open = */ true,
|
||||
});
|
||||
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals("<think>Cogito</think>", builder.result().content);
|
||||
assert_equals("Ergo sum", builder.consume_rest());
|
||||
}
|
||||
}
|
||||
|
||||
static void test_regex() {
|
||||
auto test_throws = [](const std::string & input, const std::string & regex, const std::string & expected_exception_pattern = "") {
|
||||
common_chat_msg_parser builder(input, /* is_partial= */ false, {});
|
||||
assert_throws([&]() { builder.consume_regex(common_regex(regex)); }, expected_exception_pattern);
|
||||
};
|
||||
|
||||
test_throws("Hello, world!", "abc", "^abc$");
|
||||
test_throws("Hello, world!", "e", "^e$");
|
||||
|
||||
{
|
||||
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
|
||||
builder.consume_regex(common_regex("Hello"));
|
||||
assert_equals(", world!", builder.consume_rest());
|
||||
}
|
||||
|
||||
{
|
||||
// When in non partial mode, we can say whether the regex was consumed or not.
|
||||
common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
|
||||
assert_equals(false, builder.try_consume_regex(common_regex("Hello, world!")).has_value());
|
||||
}
|
||||
{
|
||||
common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
|
||||
auto res = builder.try_consume_regex(common_regex("H(el)l(?:o, world!)?"));
|
||||
assert_equals(true, res.has_value());
|
||||
// Verify captures
|
||||
assert_equals<size_t>(2, res->groups.size());
|
||||
assert_equals("Hell", builder.str(res->groups[0]));
|
||||
assert_equals("el", builder.str(res->groups[1]));
|
||||
// Verify position is after the match
|
||||
assert_equals<size_t>(4, builder.pos());
|
||||
assert_equals("o,", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
// But in partial mode, we have a partial final match / can't decide, so we throw a partial exception.
|
||||
common_chat_msg_parser builder("Hello,", /* is_partial= */ true, {});
|
||||
assert_throws([&]() {
|
||||
builder.try_consume_regex(common_regex("Hello, world!"));
|
||||
}, "^Hello, world!$");
|
||||
}
|
||||
|
||||
// Now regardless of the mode, we can tell these aren't a match.
|
||||
for (const auto is_partial : {false, true}) {
|
||||
common_chat_msg_parser builder("Hello,", is_partial, {});
|
||||
assert_equals(false, builder.try_consume_regex(common_regex("a(b|c)(d|e)f")).has_value());
|
||||
}
|
||||
for (const auto is_partial : {false, true}) {
|
||||
common_chat_msg_parser builder("Hello,", is_partial, {});
|
||||
assert_equals(false, builder.try_consume_literal("Oh"));
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<std::string> barely_healable_jsons = {
|
||||
"{",
|
||||
"{\"",
|
||||
"{\"\\",
|
||||
"{\"n",
|
||||
"{\"name\"",
|
||||
"{\"name\":",
|
||||
"{\"name\":\"",
|
||||
"{\"name\":\"\\",
|
||||
"{\"name\":\"python",
|
||||
"{\"name\":\"python\\",
|
||||
"{\",",
|
||||
"{\":",
|
||||
"{\"[",
|
||||
"{\"]",
|
||||
"{\"{",
|
||||
"{\"}",
|
||||
"{\"1",
|
||||
"{\"name\":\",",
|
||||
"{\"name\":\":",
|
||||
"{\"name\":\"[",
|
||||
"{\"name\":\"]",
|
||||
"{\"name\":\"{",
|
||||
"{\"name\":\"}",
|
||||
"{\"name\":\"1",
|
||||
};
|
||||
|
||||
static void test(const std::string & input, bool is_partial, const std::vector<std::vector<std::string>> & args_paths, const std::vector<std::vector<std::string>> & content_paths, const std::string & expected) {
|
||||
common_chat_msg_parser builder(input, is_partial, {});
|
||||
auto js = builder.try_consume_json_with_dumped_args(args_paths, content_paths);
|
||||
assert_equals(true, js.has_value());
|
||||
assert_equals(is_partial, js->is_partial);
|
||||
assert_equals(expected, args_paths.size() == 1 && args_paths[0].empty() ? js->value.get<std::string>() : js->value.dump());
|
||||
}
|
||||
static void test_with_args(const std::string & input, const std::string & expected, bool parse_as_partial = true, bool is_partial = true) {
|
||||
common_chat_msg_parser builder(input, parse_as_partial, {});
|
||||
auto js = builder.try_consume_json_with_dumped_args({{"args"}}, {});
|
||||
assert_equals(true, js.has_value());
|
||||
assert_equals(is_partial, js->is_partial);
|
||||
assert_equals(expected, js->value.dump());
|
||||
}
|
||||
|
||||
static void test_json_with_dumped_args_no_args() {
|
||||
// Normal JSON, nothing to heal, nothing to dump
|
||||
test("{\"name\": \"python\"}", false, {}, {}, "{\"name\":\"python\"}");
|
||||
// Full json is args
|
||||
test("{\"name\": \"python\"}", false, {{}}, {}, "{\"name\":\"python\"}");
|
||||
|
||||
// If the arguments are further down, don't heal partial content.
|
||||
for (const auto & src : barely_healable_jsons) {
|
||||
test(src, true, {{"arguments"}}, {}, "{}");
|
||||
}
|
||||
// But heal content that isn't partial.
|
||||
test("{\"name\": \"python\"", true, {{"arguments"}}, {}, "{\"name\":\"python\"}");
|
||||
}
|
||||
|
||||
static void test_json_with_dumped_args() {
|
||||
|
||||
// Partial content.
|
||||
test("{\"content\": \"t", true, {}, {{"content"}}, "{\"content\":\"t\"}");
|
||||
test("{\"content\": \"", true, {}, {{"content"}}, "{\"content\":\"\"}");
|
||||
test("{\"content\": ", true, {}, {{"content"}}, "{}");
|
||||
|
||||
// If the entire JSON is the arguments, healing it them dumping it produces the same output as the input (just reformatted).
|
||||
test("{\"name\": \"python", true, {{}}, {}, "{\"name\":\"python");
|
||||
for (const auto & src : barely_healable_jsons) {
|
||||
test(src, true, {{}}, {}, src);
|
||||
}
|
||||
|
||||
// Full JSON w/ args
|
||||
for (auto parse_as_partial : {true, false}) {
|
||||
test_with_args(
|
||||
R"({"name": "python", "args": {"arg1": 1}})",
|
||||
R"({"name":"python","args":"{\"arg1\":1}"})",
|
||||
parse_as_partial,
|
||||
/* is_partial= */ false
|
||||
);
|
||||
}
|
||||
|
||||
// Partial JSON w/ partial args
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {")",
|
||||
R"({"foo":"bar","args":"{\""})"
|
||||
);
|
||||
// Partial args broken in object key
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"ar)",
|
||||
R"({"foo":"bar","args":"{\"ar"})"
|
||||
);
|
||||
// Partial args broken after object key
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1")",
|
||||
R"({"foo":"bar","args":"{\"arg1\""})"
|
||||
);
|
||||
// Partial args broken before object value
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1":)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":"})"
|
||||
);
|
||||
// Partial args broken before object value (space)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": )",
|
||||
R"({"foo":"bar","args":"{\"arg1\":"})"
|
||||
);
|
||||
// Partial args broken in object value that may not be complete (int)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": 1)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":"})"
|
||||
);
|
||||
// Partial args broken in object value that is complete (int)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": 1 )",
|
||||
R"({"foo":"bar","args":"{\"arg1\":1"})"
|
||||
);
|
||||
// Partial args broken in object value that is incomplete (string)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": ")",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\""})"
|
||||
);
|
||||
// Partial args broken in object value that is complete (string)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "1")",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"1\""})"
|
||||
);
|
||||
// Partial args broken on array opening
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": [)",
|
||||
R"({"foo":"bar","args":"["})"
|
||||
);
|
||||
// Partial args broken on array value that is incomplete (int)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": [1)",
|
||||
R"({"foo":"bar","args":"["})"
|
||||
);
|
||||
// Partial args broken on array value that is complete (int)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": [1 )",
|
||||
R"({"foo":"bar","args":"[1"})"
|
||||
);
|
||||
// Partial args broken on array value that is complete (string)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": ["1")",
|
||||
R"({"foo":"bar","args":"[\"1\""})"
|
||||
);
|
||||
// Partial args broken after array value
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": [1,)",
|
||||
R"({"foo":"bar","args":"[1,"})"
|
||||
);
|
||||
// Partial args broken on nested array
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": [)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":["})"
|
||||
);
|
||||
}
|
||||
|
||||
static void test_positions() {
|
||||
{
|
||||
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
|
||||
assert_equals<size_t>(0, builder.pos());
|
||||
assert_throws([&]() { builder.move_to(100); });
|
||||
assert_equals<size_t>(0, builder.pos());
|
||||
assert_throws([&]() { builder.move_back(1); });
|
||||
assert_equals<size_t>(0, builder.pos());
|
||||
|
||||
builder.move_to(8);
|
||||
assert_equals<size_t>(8, builder.pos());
|
||||
builder.move_back(1);
|
||||
assert_equals<size_t>(7, builder.pos());
|
||||
assert_equals("world!", builder.consume_rest());
|
||||
|
||||
builder.move_to(0);
|
||||
assert_equals<size_t>(0, builder.pos());
|
||||
|
||||
assert_throws([&]() { builder.finish(); });
|
||||
assert_equals<size_t>(0, builder.pos());
|
||||
|
||||
builder.move_to(builder.input().size());
|
||||
builder.finish();
|
||||
}
|
||||
{
|
||||
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ true, {});
|
||||
|
||||
builder.move_to(builder.input().size());
|
||||
assert_equals<size_t>(builder.input().size(), builder.pos());
|
||||
builder.finish();
|
||||
}
|
||||
}
|
||||
|
||||
int main() {
|
||||
test_positions();
|
||||
test_json_with_dumped_args_no_args();
|
||||
test_json_with_dumped_args();
|
||||
test_reasoning();
|
||||
test_regex();
|
||||
std::cout << "All tests passed!\n";
|
||||
return 0;
|
||||
}
|
237
tests/test-json-partial.cpp
Normal file
237
tests/test-json-partial.cpp
Normal file
|
@ -0,0 +1,237 @@
|
|||
#include "common.h"
|
||||
#include "json-partial.h"
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
|
||||
template <class T> static void assert_equals(const T & expected, const T & actual) {
|
||||
if (expected != actual) {
|
||||
std::cerr << "Expected: " << expected << std::endl;
|
||||
std::cerr << "Actual: " << actual << std::endl;
|
||||
std::cerr << std::flush;
|
||||
throw std::runtime_error("Test failed");
|
||||
}
|
||||
}
|
||||
|
||||
static void test_json_healing() {
|
||||
auto parse = [](const std::string & str) {
|
||||
std::cerr << "# Parsing: " << str << '\n';
|
||||
std::string::const_iterator it = str.begin();
|
||||
const auto end = str.end();
|
||||
common_json out;
|
||||
std::string healing_marker = "$llama.cpp.json$";
|
||||
if (common_json_parse(it, end, healing_marker, out)) {
|
||||
auto dump = out.json.dump();
|
||||
std::cerr << "Parsed: " << dump << '\n';
|
||||
std::cerr << "Magic: " << out.healing_marker.json_dump_marker << '\n';
|
||||
std::string result;
|
||||
if (!out.healing_marker.json_dump_marker.empty()) {
|
||||
auto i = dump.find(out.healing_marker.json_dump_marker);
|
||||
if (i == std::string::npos) {
|
||||
throw std::runtime_error("Failed to find magic in dump " + dump + " (magic: " + out.healing_marker.json_dump_marker + ")");
|
||||
}
|
||||
result = dump.substr(0, i);
|
||||
} else {
|
||||
result = dump;
|
||||
}
|
||||
std::cerr << "Result: " << result << '\n';
|
||||
if (string_starts_with(str, result)) {
|
||||
std::cerr << "Failure!\n";
|
||||
}
|
||||
// return dump;
|
||||
} else {
|
||||
throw std::runtime_error("Failed to parse: " + str);
|
||||
}
|
||||
|
||||
};
|
||||
auto parse_all = [&](const std::string & str) {
|
||||
for (size_t i = 1; i < str.size(); i++) {
|
||||
parse(str.substr(0, i));
|
||||
}
|
||||
};
|
||||
parse_all("{\"a\": \"b\"}");
|
||||
parse_all("{\"hey\": 1, \"ho\\\"ha\": [1]}");
|
||||
|
||||
parse_all("[{\"a\": \"b\"}]");
|
||||
|
||||
auto test = [&](const std::vector<std::string> & inputs, const std::string & expected, const std::string & expected_marker) {
|
||||
for (const auto & input : inputs) {
|
||||
common_json out;
|
||||
assert_equals(true, common_json_parse(input, "$foo", out));
|
||||
assert_equals<std::string>(expected, out.json.dump());
|
||||
assert_equals<std::string>(expected_marker, out.healing_marker.json_dump_marker);
|
||||
}
|
||||
};
|
||||
// No healing needed:
|
||||
test(
|
||||
{
|
||||
R"([{"a":"b"}, "y"])",
|
||||
},
|
||||
R"([{"a":"b"},"y"])",
|
||||
""
|
||||
);
|
||||
// Partial literals can't be healed:
|
||||
test(
|
||||
{
|
||||
R"([1)",
|
||||
R"([tru)",
|
||||
R"([n)",
|
||||
R"([nul)",
|
||||
R"([23.2)",
|
||||
},
|
||||
R"(["$foo"])",
|
||||
R"("$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a": 1)",
|
||||
R"({"a": tru)",
|
||||
R"({"a": n)",
|
||||
R"({"a": nul)",
|
||||
R"({"a": 23.2)",
|
||||
},
|
||||
R"({"a":"$foo"})",
|
||||
R"("$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({)",
|
||||
},
|
||||
R"({"$foo":1})",
|
||||
R"("$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([)",
|
||||
},
|
||||
R"(["$foo"])",
|
||||
R"("$foo)"
|
||||
);
|
||||
// Healing right after a full literal
|
||||
test(
|
||||
{
|
||||
R"(1 )",
|
||||
},
|
||||
R"(1)",
|
||||
""
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"(true)",
|
||||
R"(true )",
|
||||
},
|
||||
R"(true)",
|
||||
""
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"(null)",
|
||||
R"(null )",
|
||||
},
|
||||
R"(null)",
|
||||
""
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([1 )",
|
||||
},
|
||||
R"([1,"$foo"])",
|
||||
R"(,"$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([{})",
|
||||
R"([{} )",
|
||||
},
|
||||
R"([{},"$foo"])",
|
||||
R"(,"$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([true)",
|
||||
},
|
||||
// TODO: detect the true/false/null literal was complete
|
||||
R"(["$foo"])",
|
||||
R"("$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([true )",
|
||||
},
|
||||
R"([true,"$foo"])",
|
||||
R"(,"$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([true,)",
|
||||
},
|
||||
R"([true,"$foo"])",
|
||||
R"("$foo)"
|
||||
);
|
||||
// Test nesting
|
||||
test(
|
||||
{
|
||||
R"([{"a": [{"b": [{)",
|
||||
},
|
||||
R"([{"a":[{"b":[{"$foo":1}]}]}])",
|
||||
R"("$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([{"a": [{"b": [)",
|
||||
},
|
||||
R"([{"a":[{"b":["$foo"]}]}])",
|
||||
R"("$foo)"
|
||||
);
|
||||
|
||||
test(
|
||||
{
|
||||
R"([{"a": "b"})",
|
||||
R"([{"a": "b"} )",
|
||||
},
|
||||
R"([{"a":"b"},"$foo"])",
|
||||
R"(,"$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([{"a": "b"},)",
|
||||
R"([{"a": "b"}, )",
|
||||
},
|
||||
R"([{"a":"b"},"$foo"])",
|
||||
R"("$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({ "code)",
|
||||
},
|
||||
R"({"code$foo":1})",
|
||||
R"($foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({ "code\)",
|
||||
},
|
||||
R"({"code\\$foo":1})",
|
||||
R"(\$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({ "code")",
|
||||
},
|
||||
R"({"code":"$foo"})",
|
||||
R"(:"$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({ "key")",
|
||||
},
|
||||
R"({"key":"$foo"})",
|
||||
R"(:"$foo)"
|
||||
);
|
||||
}
|
||||
|
||||
int main() {
|
||||
test_json_healing();
|
||||
std::cerr << "All tests passed.\n";
|
||||
return 0;
|
||||
}
|
|
@ -107,6 +107,7 @@
|
|||
// ultravox
|
||||
#define TN_CONV1D "a.conv1d.%d.%s"
|
||||
#define TN_MM_AUDIO_MLP "mm.a.mlp.%d.%s"
|
||||
#define TN_MM_AUDIO_FC "mm.a.fc.%s" // fully connected layer
|
||||
#define TN_MM_NORM_PRE "mm.a.norm_pre.%s"
|
||||
#define TN_MM_NORM_MID "mm.a.norm_mid.%s"
|
||||
|
||||
|
@ -128,6 +129,7 @@ enum projector_type {
|
|||
PROJECTOR_TYPE_ULTRAVOX,
|
||||
PROJECTOR_TYPE_INTERNVL,
|
||||
PROJECTOR_TYPE_LLAMA4,
|
||||
PROJECTOR_TYPE_QWEN2A,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
|
@ -145,6 +147,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
|||
{ PROJECTOR_TYPE_ULTRAVOX, "ultravox"},
|
||||
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
|
||||
{ PROJECTOR_TYPE_LLAMA4, "llama4"},
|
||||
{ PROJECTOR_TYPE_QWEN2A, "qwen2a"},
|
||||
};
|
||||
|
||||
static projector_type clip_projector_type_from_string(const std::string & str) {
|
||||
|
|
|
@ -269,7 +269,9 @@ struct clip_vision_model {
|
|||
ggml_tensor * post_ln_w;
|
||||
ggml_tensor * post_ln_b;
|
||||
|
||||
ggml_tensor * projection;
|
||||
ggml_tensor * projection; // TODO: rename it to fc (fully connected layer)
|
||||
ggml_tensor * mm_fc_w;
|
||||
ggml_tensor * mm_fc_b;
|
||||
|
||||
// LLaVA projection
|
||||
ggml_tensor * mm_input_norm_w = nullptr;
|
||||
|
@ -1493,48 +1495,58 @@ struct clip_graph {
|
|||
|
||||
cb(cur, "after_transformer", -1);
|
||||
|
||||
// StackAudioFrames
|
||||
// https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
|
||||
{
|
||||
int64_t stride = n_embd * hparams.proj_stack_factor;
|
||||
int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
|
||||
int64_t pad = padded_len - ggml_nelements(cur);
|
||||
if (pad > 0) {
|
||||
cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
|
||||
cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
|
||||
}
|
||||
cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
|
||||
ggml_row_size(cur->type, stride), 0);
|
||||
}
|
||||
|
||||
cb(cur, "after_stacked", -1);
|
||||
|
||||
// UltravoxProjector
|
||||
{
|
||||
// pre-norm
|
||||
cur = ggml_rms_norm(ctx0, cur, 1e-6);
|
||||
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
|
||||
|
||||
// ffn in
|
||||
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
|
||||
|
||||
// swiglu
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX) {
|
||||
// StackAudioFrames
|
||||
// https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
|
||||
{
|
||||
int64_t split_point = cur->ne[0] / 2;
|
||||
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
||||
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
||||
|
||||
// see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
|
||||
x1 = ggml_silu(ctx0, x1);
|
||||
cur = ggml_mul(ctx0, x0, x1);
|
||||
int64_t stride = n_embd * hparams.proj_stack_factor;
|
||||
int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
|
||||
int64_t pad = padded_len - ggml_nelements(cur);
|
||||
if (pad > 0) {
|
||||
cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
|
||||
cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
|
||||
}
|
||||
cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
|
||||
ggml_row_size(cur->type, stride), 0);
|
||||
}
|
||||
|
||||
// mid-norm
|
||||
cur = ggml_rms_norm(ctx0, cur, 1e-6);
|
||||
cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
|
||||
cb(cur, "after_stacked", -1);
|
||||
|
||||
// ffn out
|
||||
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
|
||||
// UltravoxProjector
|
||||
{
|
||||
// pre-norm
|
||||
cur = ggml_rms_norm(ctx0, cur, 1e-6);
|
||||
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
|
||||
|
||||
// ffn in
|
||||
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
|
||||
|
||||
// swiglu
|
||||
{
|
||||
int64_t split_point = cur->ne[0] / 2;
|
||||
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
||||
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
||||
|
||||
// see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
|
||||
x1 = ggml_silu(ctx0, x1);
|
||||
cur = ggml_mul(ctx0, x0, x1);
|
||||
}
|
||||
|
||||
// mid-norm
|
||||
cur = ggml_rms_norm(ctx0, cur, 1e-6);
|
||||
cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
|
||||
|
||||
// ffn out
|
||||
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
|
||||
}
|
||||
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) {
|
||||
// projector
|
||||
cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur);
|
||||
cur = ggml_add(ctx0, cur, model.mm_fc_b);
|
||||
|
||||
} else {
|
||||
GGML_ABORT("%s: unknown projector type", __func__);
|
||||
}
|
||||
|
||||
cb(cur, "projected", -1);
|
||||
|
@ -1677,6 +1689,17 @@ private:
|
|||
inpL = cur;
|
||||
}
|
||||
|
||||
// TODO @ngxson : find a way to move this outside
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) {
|
||||
ggml_tensor * cur = inpL;
|
||||
cur = ggml_transpose(ctx0, cur);
|
||||
cur = ggml_cont(ctx0, cur);
|
||||
cur = ggml_pool_1d(ctx0, cur, GGML_OP_POOL_AVG, 2, 2, 0);
|
||||
cur = ggml_transpose(ctx0, cur);
|
||||
cur = ggml_cont(ctx0, cur);
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
// post-layernorm
|
||||
if (model.post_ln_w) {
|
||||
inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, -1);
|
||||
|
@ -1974,6 +1997,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
res = graph.build_llama4();
|
||||
} break;
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
{
|
||||
res = graph.build_whisper_enc();
|
||||
} break;
|
||||
|
@ -2234,8 +2258,10 @@ struct clip_model_loader {
|
|||
};
|
||||
} break;
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
{
|
||||
get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor);
|
||||
bool require_stack = ctx_clip.proj_type == PROJECTOR_TYPE_ULTRAVOX;
|
||||
get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack);
|
||||
if (hparams.n_mel_bins != 128) {
|
||||
throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__));
|
||||
}
|
||||
|
@ -2314,7 +2340,7 @@ struct clip_model_loader {
|
|||
return cur;
|
||||
};
|
||||
|
||||
auto & vision_model = ctx_clip.vision_model;
|
||||
auto & vision_model = ctx_clip.vision_model; // TODO: rename this to just "model"
|
||||
|
||||
vision_model.class_embedding = get_tensor(TN_CLASS_EMBD, false);
|
||||
|
||||
|
@ -2511,6 +2537,15 @@ struct clip_model_loader {
|
|||
vision_model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
|
||||
vision_model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
{
|
||||
vision_model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
|
||||
vision_model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
|
||||
vision_model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
|
||||
vision_model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
|
||||
vision_model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight"));
|
||||
vision_model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_INTERNVL:
|
||||
{
|
||||
vision_model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
|
||||
|
@ -3594,6 +3629,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
|||
const int proj_stack_factor = ctx->vision_model.hparams.proj_stack_factor;
|
||||
const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
|
||||
n_patches = n_len / proj_stack_factor / 2;
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) {
|
||||
// divide by 2 because of whisper
|
||||
// another divide by 2 because of nn.AvgPool1d(2, stride=2)
|
||||
n_patches = img->nx / 4;
|
||||
}
|
||||
|
||||
return n_patches;
|
||||
|
@ -3994,6 +4033,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
case PROJECTOR_TYPE_GEMMA3:
|
||||
case PROJECTOR_TYPE_IDEFICS3:
|
||||
case PROJECTOR_TYPE_INTERNVL:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
{
|
||||
// do nothing
|
||||
|
@ -4048,7 +4088,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
const int n_tokens_out = embeddings->ne[1];
|
||||
const int expected_n_tokens_out = clip_n_output_tokens(ctx, imgs.entries[0].get());
|
||||
if (n_tokens_out != expected_n_tokens_out) {
|
||||
LOG_ERR("%s: expected %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out);
|
||||
LOG_ERR("%s: expected output %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out);
|
||||
GGML_ABORT("Invalid number of output tokens");
|
||||
}
|
||||
|
||||
|
@ -4276,6 +4316,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
|||
return ctx->vision_model.mm_3_w->ne[1];
|
||||
case PROJECTOR_TYPE_LLAMA4:
|
||||
return ctx->vision_model.mm_model_proj->ne[1];
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
return ctx->vision_model.mm_fc_w->ne[1];
|
||||
default:
|
||||
GGML_ABORT("Unknown projector type");
|
||||
}
|
||||
|
@ -4316,6 +4358,10 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
|
|||
return ctx->vision_model.hparams.has_audio;
|
||||
}
|
||||
|
||||
bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
|
||||
return ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX || ctx->proj_type == PROJECTOR_TYPE_QWEN2A;
|
||||
}
|
||||
|
||||
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
|
||||
clip_image_f32 clip_img;
|
||||
clip_img.buf.resize(h * w * 3);
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
// !!! Internal header, to be used by mtmd only !!!
|
||||
|
||||
struct clip_ctx;
|
||||
|
||||
struct clip_image_size {
|
||||
|
@ -101,5 +103,6 @@ void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel
|
|||
|
||||
bool clip_has_vision_encoder(const struct clip_ctx * ctx);
|
||||
bool clip_has_audio_encoder(const struct clip_ctx * ctx);
|
||||
bool clip_has_whisper_encoder(const struct clip_ctx * ctx);
|
||||
|
||||
bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype) ;
|
||||
|
|
|
@ -146,6 +146,13 @@ struct mtmd_context {
|
|||
throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
|
||||
}
|
||||
|
||||
if (llama_model_n_embd(text_model) != clip_n_mmproj_embd(ctx_clip)) {
|
||||
throw std::runtime_error(string_format(
|
||||
"mismatch between text model (n_embd = %d) and mmproj (n_embd = %d)\n"
|
||||
"hint: you may be using wrong mmproj\n",
|
||||
llama_model_n_embd(text_model), clip_n_mmproj_embd(ctx_clip)));
|
||||
}
|
||||
|
||||
has_vision = clip_has_vision_encoder(ctx_clip);
|
||||
has_audio = clip_has_audio_encoder(ctx_clip);
|
||||
use_mrope = clip_is_qwen2vl(ctx_clip);
|
||||
|
@ -196,7 +203,7 @@ struct mtmd_context {
|
|||
ov_img_first = false; // overview image is last
|
||||
}
|
||||
|
||||
if (proj == PROJECTOR_TYPE_ULTRAVOX) {
|
||||
if (clip_has_whisper_encoder(ctx_clip)) {
|
||||
// TODO @ngxson : check if model n_mel is 128 or 80
|
||||
w_filters = whisper_precalc_filters::get_128_bins();
|
||||
}
|
||||
|
@ -208,7 +215,7 @@ struct mtmd_context {
|
|||
}
|
||||
if (has_audio) {
|
||||
LOG_WRN("%s: audio input is in experimental stage and may have reduced quality:\n"
|
||||
" https://github.com/ggml-org/llama.cpp/pull/13623\n", __func__);
|
||||
" https://github.com/ggml-org/llama.cpp/discussions/13759\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -327,6 +334,11 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
|
|||
marker_modified = "<img>" + ctx->media_marker + "</img>";
|
||||
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
|
||||
|
||||
} else if (proj_type == PROJECTOR_TYPE_QWEN2A) {
|
||||
// <|audio_bos|> ... (embeddings) ... <|audio_eos|>
|
||||
marker_modified = "<|audio_bos|>" + ctx->media_marker + "<|audio_eos|>";
|
||||
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
|
||||
|
||||
}
|
||||
|
||||
// llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
|
||||
|
|
|
@ -203,6 +203,8 @@ MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx,
|
|||
const mtmd_input_chunk * chunk);
|
||||
|
||||
// get output embeddings from the last encode pass
|
||||
// the reading size (in bytes) is equal to:
|
||||
// llama_model_n_embd(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float)
|
||||
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
|
||||
|
||||
/////////////////////////////////////////
|
||||
|
|
Binary file not shown.
|
@ -1,3 +1,4 @@
|
|||
#include "chat.h"
|
||||
#include "utils.hpp"
|
||||
|
||||
#include "arg.h"
|
||||
|
@ -114,11 +115,11 @@ struct slot_params {
|
|||
struct common_params_speculative speculative;
|
||||
|
||||
// OAI-compat fields
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_syntax oaicompat_chat_syntax;
|
||||
|
||||
json to_json() const {
|
||||
std::vector<std::string> samplers;
|
||||
|
@ -176,7 +177,10 @@ struct slot_params {
|
|||
{"grammar_lazy", sampling.grammar_lazy},
|
||||
{"grammar_triggers", grammar_triggers},
|
||||
{"preserved_tokens", sampling.preserved_tokens},
|
||||
{"chat_format", common_chat_format_name(oaicompat_chat_format)},
|
||||
{"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)},
|
||||
{"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)},
|
||||
{"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content},
|
||||
{"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open},
|
||||
{"samplers", samplers},
|
||||
{"speculative.n_max", speculative.n_max},
|
||||
{"speculative.n_min", speculative.n_min},
|
||||
|
@ -352,11 +356,14 @@ struct server_task {
|
|||
{
|
||||
auto it = data.find("chat_format");
|
||||
if (it != data.end()) {
|
||||
params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<int>());
|
||||
SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
|
||||
params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>());
|
||||
SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format));
|
||||
} else {
|
||||
params.oaicompat_chat_format = defaults.oaicompat_chat_format;
|
||||
params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
|
||||
}
|
||||
params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format;
|
||||
params.oaicompat_chat_syntax.reasoning_in_content = params.stream;
|
||||
params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -396,7 +403,14 @@ struct server_task {
|
|||
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
|
||||
}
|
||||
} else {
|
||||
params.sampling.grammar_triggers.push_back(std::move(ct.value));
|
||||
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
|
||||
SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
|
||||
} else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
|
||||
SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
|
||||
} else {
|
||||
throw std::runtime_error("Unknown grammar trigger type");
|
||||
}
|
||||
params.sampling.grammar_triggers.emplace_back(std::move(ct.value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -639,11 +653,12 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
slot_params generation_params;
|
||||
|
||||
// OAI-compat fields
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_msg oaicompat_msg;
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
|
||||
|
||||
virtual int get_index() override {
|
||||
return index;
|
||||
|
@ -738,47 +753,20 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
json to_json_oaicompat_chat() {
|
||||
std::string finish_reason = "length";
|
||||
common_chat_msg msg;
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
SRV_DBG("Parsing chat message: %s\n", content.c_str());
|
||||
msg = common_chat_parse(content, oaicompat_chat_format);
|
||||
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
|
||||
if (!oaicompat_msg.empty()) {
|
||||
msg = oaicompat_msg;
|
||||
} else {
|
||||
msg.role = "assistant";
|
||||
msg.content = content;
|
||||
}
|
||||
|
||||
json message {
|
||||
{"role", "assistant"},
|
||||
};
|
||||
if (!msg.reasoning_content.empty()) {
|
||||
message["reasoning_content"] = msg.reasoning_content;
|
||||
}
|
||||
if (msg.content.empty() && !msg.tool_calls.empty()) {
|
||||
message["content"] = json();
|
||||
} else {
|
||||
message["content"] = msg.content;
|
||||
}
|
||||
if (!msg.tool_calls.empty()) {
|
||||
auto tool_calls = json::array();
|
||||
for (const auto & tc : msg.tool_calls) {
|
||||
tool_calls.push_back({
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", tc.name},
|
||||
{"arguments", tc.arguments},
|
||||
}},
|
||||
// Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
|
||||
// We only generate a random id for the ones that don't generate one by themselves
|
||||
// (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
|
||||
{"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
|
||||
});
|
||||
}
|
||||
message["tool_calls"] = tool_calls;
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
|
||||
}
|
||||
|
||||
json choice {
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"message", message},
|
||||
{"message", msg.to_json_oaicompat<json>()},
|
||||
};
|
||||
|
||||
if (!stream && probs_output.size() > 0) {
|
||||
|
@ -818,17 +806,35 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
std::time_t t = std::time(0);
|
||||
std::string finish_reason = "length";
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
finish_reason = "stop";
|
||||
finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls";
|
||||
}
|
||||
|
||||
json choice = json {
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"delta", json::object()}
|
||||
};
|
||||
json deltas = json::array();
|
||||
for (const auto & diff : oaicompat_msg_diffs) {
|
||||
deltas.push_back({
|
||||
{"choices", json::array({
|
||||
json {
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
|
||||
},
|
||||
})},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"system_fingerprint", build_info},
|
||||
{"object", "chat.completion.chunk"},
|
||||
});
|
||||
}
|
||||
|
||||
json ret = json {
|
||||
{"choices", json::array({choice})},
|
||||
deltas.push_back({
|
||||
{"choices", json::array({
|
||||
json {
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"delta", json::object()},
|
||||
},
|
||||
})},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
|
@ -839,18 +845,18 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens},
|
||||
}},
|
||||
};
|
||||
});
|
||||
|
||||
if (timings.prompt_n >= 0) {
|
||||
ret.push_back({"timings", timings.to_json()});
|
||||
deltas.back().push_back({"timings", timings.to_json()});
|
||||
}
|
||||
|
||||
// extra fields for debugging purposes
|
||||
if (verbose) {
|
||||
ret["__verbose"] = to_json_non_oaicompat();
|
||||
if (verbose && !deltas.empty()) {
|
||||
deltas.front()["__verbose"] = to_json_non_oaicompat();
|
||||
}
|
||||
|
||||
return ret;
|
||||
return deltas;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -868,10 +874,11 @@ struct server_task_result_cmpl_partial : server_task_result {
|
|||
result_timings timings;
|
||||
|
||||
// OAI-compat fields
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
|
||||
|
||||
virtual int get_index() override {
|
||||
return index;
|
||||
|
@ -955,84 +962,50 @@ struct server_task_result_cmpl_partial : server_task_result {
|
|||
std::time_t t = std::time(0);
|
||||
json choices;
|
||||
|
||||
if (first) {
|
||||
if (content.empty()) {
|
||||
choices = json::array({json{{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", json{{"role", "assistant"}}}}});
|
||||
} else {
|
||||
// We have to send this as two updates to conform to openai behavior
|
||||
// initial_ret is the role message for stream=True
|
||||
json initial_ret = json{{"choices", json::array({json{
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", json{
|
||||
{"role", "assistant"},
|
||||
{"content", ""}
|
||||
}}}})},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"system_fingerprint", build_info},
|
||||
{"object", "chat.completion.chunk"}};
|
||||
|
||||
json second_ret = json{
|
||||
{"choices", json::array({json{{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", json {
|
||||
{"content", content}}}
|
||||
}})},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"system_fingerprint", build_info},
|
||||
{"object", "chat.completion.chunk"}};
|
||||
|
||||
if (prob_output.probs.size() > 0) {
|
||||
second_ret["choices"][0]["logprobs"] = json{
|
||||
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
|
||||
};
|
||||
}
|
||||
|
||||
if (timings.prompt_n >= 0) {
|
||||
second_ret.push_back({"timings", timings.to_json()});
|
||||
}
|
||||
|
||||
return std::vector<json>({initial_ret, second_ret});
|
||||
}
|
||||
} else {
|
||||
choices = json::array({json{
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta",
|
||||
json {
|
||||
{"content", content},
|
||||
}},
|
||||
}});
|
||||
}
|
||||
|
||||
GGML_ASSERT(choices.size() >= 1);
|
||||
|
||||
if (prob_output.probs.size() > 0) {
|
||||
choices[0]["logprobs"] = json{
|
||||
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
|
||||
};
|
||||
}
|
||||
|
||||
json ret = json {
|
||||
{"choices", choices},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"system_fingerprint", build_info},
|
||||
{"object", "chat.completion.chunk"}
|
||||
std::vector<json> deltas;
|
||||
auto add_delta = [&](const json & delta) {
|
||||
deltas.push_back({
|
||||
{"choices", json::array({
|
||||
json {
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", delta},
|
||||
},
|
||||
})},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"system_fingerprint", build_info},
|
||||
{"object", "chat.completion.chunk"},
|
||||
});
|
||||
};
|
||||
|
||||
if (timings.prompt_n >= 0) {
|
||||
ret.push_back({"timings", timings.to_json()});
|
||||
// We have to send an initial update to conform to openai behavior
|
||||
if (first) {
|
||||
add_delta({
|
||||
{"role", "assistant"},
|
||||
{"content", nullptr},
|
||||
});
|
||||
}
|
||||
|
||||
return std::vector<json>({ret});
|
||||
for (const auto & diff : oaicompat_msg_diffs) {
|
||||
add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
|
||||
}
|
||||
|
||||
if (!deltas.empty()) {
|
||||
GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1);
|
||||
|
||||
if (prob_output.probs.size() > 0) {
|
||||
deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json {
|
||||
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
|
||||
};
|
||||
}
|
||||
|
||||
if (timings.prompt_n >= 0) {
|
||||
deltas[deltas.size() - 1].push_back({"timings", timings.to_json()});
|
||||
}
|
||||
}
|
||||
|
||||
return deltas;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1293,6 +1266,7 @@ struct server_slot {
|
|||
|
||||
std::string generated_text;
|
||||
llama_tokens generated_tokens;
|
||||
common_chat_msg chat_msg;
|
||||
|
||||
server_tokens cache_tokens;
|
||||
|
||||
|
@ -1313,6 +1287,7 @@ struct server_slot {
|
|||
llama_token sampled;
|
||||
|
||||
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
std::vector<std::string> generated_tool_call_ids;
|
||||
|
||||
// stats
|
||||
size_t n_sent_text = 0; // number of sent text character
|
||||
|
@ -1342,9 +1317,13 @@ struct server_slot {
|
|||
n_past = 0;
|
||||
n_sent_text = 0;
|
||||
task_type = SERVER_TASK_TYPE_COMPLETION;
|
||||
chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
|
||||
generated_tokens.clear();
|
||||
generated_token_probs.clear();
|
||||
chat_msg = {};
|
||||
json_schema = json();
|
||||
generated_tool_call_ids.clear();
|
||||
|
||||
// clear speculative decoding stats
|
||||
n_draft_total = 0;
|
||||
|
@ -1424,6 +1403,21 @@ struct server_slot {
|
|||
return timings;
|
||||
}
|
||||
|
||||
const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
|
||||
auto previous_msg = chat_msg;
|
||||
SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
|
||||
auto new_msg = common_chat_parse(
|
||||
generated_text,
|
||||
/* is_partial= */ stop != STOP_TYPE_EOS,
|
||||
params.oaicompat_chat_syntax);
|
||||
if (!new_msg.empty()) {
|
||||
new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id);
|
||||
chat_msg = new_msg;
|
||||
diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
|
||||
}
|
||||
return chat_msg;
|
||||
}
|
||||
|
||||
size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
|
||||
size_t stop_pos = std::string::npos;
|
||||
|
||||
|
@ -2095,6 +2089,7 @@ struct server_context {
|
|||
/* common_chat_templates */ chat_templates.get(),
|
||||
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
|
||||
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
|
||||
/* enable_thinking */ params_base.reasoning_budget != 0,
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -2475,10 +2470,12 @@ struct server_context {
|
|||
res->n_prompt_tokens = slot.n_prompt_tokens;
|
||||
res->post_sampling_probs = slot.params.post_sampling_probs;
|
||||
|
||||
res->verbose = slot.params.verbose;
|
||||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
res->verbose = slot.params.verbose;
|
||||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
|
||||
slot.update_chat_msg(res->oaicompat_msg_diffs);
|
||||
|
||||
// populate res.probs_output
|
||||
if (slot.params.sampling.n_probs > 0) {
|
||||
|
@ -2499,7 +2496,7 @@ struct server_context {
|
|||
res->id_slot = slot.id;
|
||||
|
||||
res->index = slot.index;
|
||||
res->content = std::move(slot.generated_text);
|
||||
res->content = slot.generated_text;
|
||||
res->tokens = std::move(slot.generated_tokens);
|
||||
res->timings = slot.get_timings();
|
||||
res->prompt = slot.prompt_tokens.detokenize(ctx, true);
|
||||
|
@ -2519,7 +2516,8 @@ struct server_context {
|
|||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
res->oaicompat_chat_format = slot.params.oaicompat_chat_format;
|
||||
res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
|
||||
|
||||
// populate res.probs_output
|
||||
if (slot.params.sampling.n_probs > 0) {
|
||||
if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
|
||||
|
|
|
@ -75,7 +75,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
|
|||
choice = data["choices"][0]
|
||||
if i == 0:
|
||||
# Check first role message for stream=True
|
||||
assert choice["delta"]["content"] == ""
|
||||
assert choice["delta"]["content"] is None
|
||||
assert choice["delta"]["role"] == "assistant"
|
||||
else:
|
||||
assert "role" not in choice["delta"]
|
||||
|
@ -92,7 +92,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
|
|||
assert choice["finish_reason"] == finish_reason
|
||||
else:
|
||||
assert choice["finish_reason"] is None
|
||||
content += choice["delta"]["content"]
|
||||
content += choice["delta"]["content"] or ''
|
||||
|
||||
|
||||
def test_chat_completion_with_openai_library():
|
||||
|
@ -251,8 +251,9 @@ def test_chat_completion_with_timings_per_token():
|
|||
for i, data in enumerate(res):
|
||||
if i == 0:
|
||||
# Check first role message for stream=True
|
||||
assert data["choices"][0]["delta"]["content"] == ""
|
||||
assert data["choices"][0]["delta"]["content"] is None
|
||||
assert data["choices"][0]["delta"]["role"] == "assistant"
|
||||
assert "timings" not in data, f'First event should not have timings: {data}'
|
||||
else:
|
||||
assert "role" not in data["choices"][0]["delta"]
|
||||
assert "timings" in data
|
||||
|
@ -311,7 +312,7 @@ def test_logprobs_stream():
|
|||
choice = data.choices[0]
|
||||
if i == 0:
|
||||
# Check first role message for stream=True
|
||||
assert choice.delta.content == ""
|
||||
assert choice.delta.content is None
|
||||
assert choice.delta.role == "assistant"
|
||||
else:
|
||||
assert choice.delta.role is None
|
||||
|
|
|
@ -25,6 +25,40 @@ def create_server():
|
|||
server.n_slots = 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
|
||||
@pytest.mark.parametrize("template_name,reasoning_budget,expected_end", [
|
||||
("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", None, "<think>\n"),
|
||||
("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", -1, "<think>\n"),
|
||||
("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", 0, "<think>\n</think>"),
|
||||
|
||||
("Qwen-Qwen3-0.6B", -1, "<|im_start|>assistant\n"),
|
||||
("Qwen-Qwen3-0.6B", 0, "<|im_start|>assistant\n<think>\n\n</think>\n\n"),
|
||||
|
||||
("Qwen-QwQ-32B", -1, "<|im_start|>assistant\n<think>\n"),
|
||||
("Qwen-QwQ-32B", 0, "<|im_start|>assistant\n<think>\n</think>"),
|
||||
|
||||
("CohereForAI-c4ai-command-r7b-12-2024-tool_use", -1, "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"),
|
||||
("CohereForAI-c4ai-command-r7b-12-2024-tool_use", 0, "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|><|END_THINKING|>"),
|
||||
])
|
||||
def test_reasoning_budget(template_name: str, reasoning_budget: int | None, expected_end: str, tools: list[dict]):
|
||||
global server
|
||||
server.jinja = True
|
||||
server.reasoning_budget = reasoning_budget
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
|
||||
res = server.make_request("POST", "/apply-template", data={
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is today?"},
|
||||
],
|
||||
"tools": tools,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
prompt = res.body["prompt"]
|
||||
|
||||
assert prompt.endswith(expected_end), f"Expected prompt to end with '{expected_end}', got '{prompt}'"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
|
||||
@pytest.mark.parametrize("template_name,format", [
|
||||
("meta-llama-Llama-3.3-70B-Instruct", "%d %b %Y"),
|
||||
|
@ -47,3 +81,28 @@ def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
|
|||
|
||||
today_str = datetime.date.today().strftime(format)
|
||||
assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("add_generation_prompt", [False, True])
|
||||
@pytest.mark.parametrize("template_name,expected_generation_prompt", [
|
||||
("meta-llama-Llama-3.3-70B-Instruct", "<|start_header_id|>assistant<|end_header_id|>"),
|
||||
])
|
||||
def test_add_generation_prompt(template_name: str, expected_generation_prompt: str, add_generation_prompt: bool):
|
||||
global server
|
||||
server.jinja = True
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
|
||||
res = server.make_request("POST", "/apply-template", data={
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is today?"},
|
||||
],
|
||||
"add_generation_prompt": add_generation_prompt,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
prompt = res.body["prompt"]
|
||||
|
||||
if add_generation_prompt:
|
||||
assert expected_generation_prompt in prompt, f"Expected generation prompt ({expected_generation_prompt}) in content ({prompt})"
|
||||
else:
|
||||
assert expected_generation_prompt not in prompt, f"Did not expect generation prompt ({expected_generation_prompt}) in content ({prompt})"
|
||||
|
|
|
@ -8,6 +8,7 @@ path = Path(__file__).resolve().parents[1]
|
|||
sys.path.insert(0, str(path))
|
||||
|
||||
from utils import *
|
||||
from enum import Enum
|
||||
|
||||
server: ServerProcess
|
||||
|
||||
|
@ -20,7 +21,11 @@ def create_server():
|
|||
server = ServerPreset.tinyllama2()
|
||||
server.model_alias = "tinyllama-2-tool-call"
|
||||
server.server_port = 8081
|
||||
server.n_slots = 1
|
||||
|
||||
class CompletionMode(Enum):
|
||||
NORMAL = "normal"
|
||||
STREAMED = "streamed"
|
||||
|
||||
TEST_TOOL = {
|
||||
"type":"function",
|
||||
|
@ -73,9 +78,8 @@ WEATHER_TOOL = {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs):
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
|
@ -86,13 +90,13 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
|
|||
"parallel_tool_calls": False,
|
||||
**kwargs,
|
||||
})
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
# assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
|
||||
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
|
||||
assert expected_function_name == tool_call["function"]["name"]
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
|
@ -102,12 +106,16 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
|
|||
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("template_name,tool,argument_key", [
|
||||
("google-gemma-2-2b-it", TEST_TOOL, "success"),
|
||||
("google-gemma-2-2b-it", TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
|
||||
])
|
||||
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
|
||||
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
|
||||
global server
|
||||
n_predict = 1024
|
||||
# server = ServerPreset.stories15m_moe()
|
||||
|
@ -115,31 +123,43 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict,
|
|||
server.n_predict = n_predict
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, temperature=0.0, top_k=1, top_p=1.0)
|
||||
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("template_name,tool,argument_key", [
|
||||
("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
|
||||
|
||||
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
|
||||
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
|
||||
|
||||
("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
|
||||
("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
|
||||
# Functionary v3.2 format supports raw python content, which w/ a dummy stories model will never end on its own.
|
||||
# ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
|
||||
|
||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
|
||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
|
||||
|
||||
("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
|
||||
|
||||
("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
|
||||
("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
|
||||
|
||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
|
||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
|
||||
|
||||
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
|
||||
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
|
||||
|
||||
("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"),
|
||||
# ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "codeFalse), True),
|
||||
# ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
|
||||
|
||||
])
|
||||
def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None):
|
||||
def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
|
||||
global server
|
||||
n_predict = 512
|
||||
# server = ServerPreset.stories15m_moe()
|
||||
|
@ -147,10 +167,11 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
|
|||
server.n_predict = n_predict
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict)
|
||||
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [
|
||||
(TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
|
@ -184,9 +205,9 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
|
|||
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
# (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
# (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
# (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
|
||||
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
|
||||
|
@ -203,10 +224,9 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
|
|||
(TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
])
|
||||
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
|
||||
global server
|
||||
n_predict = 512
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192
|
||||
server.n_predict = n_predict
|
||||
|
@ -219,7 +239,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
|
|||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
|
@ -228,12 +248,12 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
|
|||
"tool_choice": "required",
|
||||
"tools": [tool],
|
||||
"parallel_tool_calls": False,
|
||||
"stream": stream == CompletionMode.STREAMED,
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"top_p": 1.0,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
choice = body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
|
@ -248,7 +268,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
|
|||
|
||||
|
||||
def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs):
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
|
@ -258,26 +278,27 @@ def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int,
|
|||
"tool_choice": tool_choice,
|
||||
**kwargs,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
choice = body["choices"][0]
|
||||
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
|
||||
("meta-llama-Llama-3.3-70B-Instruct", 128, [], None),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'),
|
||||
])
|
||||
def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
|
||||
def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
|
||||
global server
|
||||
server.jinja = True
|
||||
server.n_predict = n_predict
|
||||
server.jinja = True
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
|
||||
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
|
||||
("meetkai-functionary-medium-v3.2", 256, [], None),
|
||||
("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None),
|
||||
|
@ -289,16 +310,17 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
|
|||
("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'),
|
||||
])
|
||||
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
|
||||
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
|
||||
global server
|
||||
server.jinja = True
|
||||
server.n_predict = n_predict
|
||||
server.jinja = True
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
|
||||
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("hf_repo,template_override", [
|
||||
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
@ -321,11 +343,11 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
|
|||
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
# ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
# ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
|
||||
("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
|
||||
# ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
|
||||
# ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
|
||||
|
||||
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
@ -339,10 +361,9 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
|
|||
|
||||
# ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
])
|
||||
def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
|
||||
global server
|
||||
n_predict = 512
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192
|
||||
server.n_predict = n_predict
|
||||
|
@ -355,11 +376,11 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
|
|||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
do_test_weather(server, max_tokens=n_predict)
|
||||
do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
|
||||
|
||||
|
||||
def do_test_weather(server: ServerProcess, **kwargs):
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
|
||||
{"role": "user", "content": "What is the weather in Istanbul?"},
|
||||
|
@ -367,14 +388,13 @@ def do_test_weather(server: ServerProcess, **kwargs):
|
|||
"tools": [WEATHER_TOOL],
|
||||
**kwargs,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
choice = body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
|
||||
assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"], f'Expected weather tool call, got {tool_call["function"]["name"]}'
|
||||
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||
actual_arguments = json.loads(tool_call["function"]["arguments"])
|
||||
assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
|
||||
location = actual_arguments["location"]
|
||||
|
@ -383,6 +403,7 @@ def do_test_weather(server: ServerProcess, **kwargs):
|
|||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
|
||||
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
|
||||
(None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
|
||||
|
@ -400,9 +421,8 @@ def do_test_weather(server: ServerProcess, **kwargs):
|
|||
# (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
# ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
])
|
||||
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
|
||||
global server
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192 * 2
|
||||
server.n_predict = n_predict
|
||||
|
@ -415,11 +435,11 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
|
|||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
do_test_calc_result(server, result_override, n_predict)
|
||||
do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED)
|
||||
|
||||
|
||||
def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs):
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."},
|
||||
|
@ -466,8 +486,7 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
|
|||
],
|
||||
**kwargs,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
choice = body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls is None, f'Expected no tool call in {choice["message"]}'
|
||||
content = choice["message"].get("content")
|
||||
|
@ -480,18 +499,18 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
|
|||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("n_predict,reasoning_format,expect_content,expect_reasoning_content,hf_repo,template_override", [
|
||||
(128, 'deepseek', "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(128, None, "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
|
||||
(1024, 'deepseek', "To find the sum of[\\s\\S]*", "I need to calculate the sum of 102 and 7[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
(1024, 'none', "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
|
||||
(1024, 'deepseek', "To find the sum of[\\s\\S]*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||
@pytest.mark.parametrize("n_predict,reasoning_format,stream,expect_reasoning_content,expect_content,hf_repo,template_override", [
|
||||
(128, 'deepseek', CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(128, None, CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(1024, 'deepseek', CompletionMode.NORMAL, "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
(1024, 'deepseek', CompletionMode.STREAMED, None, "^<think>I need to calculate [\\s\\S]*?</think>To find the sum of [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
(1024, 'deepseek', CompletionMode.NORMAL, "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||
(1024, 'deepseek', CompletionMode.STREAMED, None, "^<think>First, I [\\s\\S]*?</think>To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||
# (1024, 'none', CompletionMode.NORMAL, None, "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
# (128, 'deepseek', None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*", "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M", None),
|
||||
])
|
||||
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
|
||||
global server
|
||||
server.n_slots = 1
|
||||
server.reasoning_format = reasoning_format
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192 * 2
|
||||
|
@ -505,14 +524,14 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
|
|||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the sum of 102 and 7?"},
|
||||
]
|
||||
],
|
||||
"stream": stream == CompletionMode.STREAMED,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
choice = body["choices"][0]
|
||||
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
|
||||
|
||||
content = choice["message"].get("content")
|
||||
|
@ -529,6 +548,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
|
|||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("hf_repo,template_override", [
|
||||
("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
|
||||
|
@ -562,10 +582,9 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
|
|||
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
|
||||
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"),
|
||||
])
|
||||
def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
|
||||
global server
|
||||
n_predict = 512 # High because of DeepSeek R1
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192
|
||||
server.n_predict = n_predict
|
||||
|
@ -579,11 +598,11 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non
|
|||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
|
||||
do_test_hello_world(server, max_tokens=n_predict)
|
||||
do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
|
||||
|
||||
|
||||
def do_test_hello_world(server: ServerProcess, **kwargs):
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a tool-calling agent."},
|
||||
{"role": "user", "content": "say hello world with python"},
|
||||
|
@ -591,16 +610,15 @@ def do_test_hello_world(server: ServerProcess, **kwargs):
|
|||
"tools": [PYTHON_TOOL],
|
||||
**kwargs,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
choice = body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
|
||||
assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
|
||||
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||
actual_arguments = json.loads(tool_call["function"]["arguments"])
|
||||
assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
|
||||
code = actual_arguments["code"]
|
||||
assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
|
||||
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
|
||||
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', re.sub(r'#.*\n?', '', code)), f'Expected hello world, got {code}'
|
||||
|
|
|
@ -84,7 +84,8 @@ class ServerProcess:
|
|||
draft_max: int | None = None
|
||||
no_webui: bool | None = None
|
||||
jinja: bool | None = None
|
||||
reasoning_format: Literal['deepseek', 'none'] | None = None
|
||||
reasoning_format: Literal['deepseek', 'none', 'nothink'] | None = None
|
||||
reasoning_budget: int | None = None
|
||||
chat_template: str | None = None
|
||||
chat_template_file: str | None = None
|
||||
server_path: str | None = None
|
||||
|
@ -191,6 +192,8 @@ class ServerProcess:
|
|||
server_args.append("--jinja")
|
||||
if self.reasoning_format is not None:
|
||||
server_args.extend(("--reasoning-format", self.reasoning_format))
|
||||
if self.reasoning_budget is not None:
|
||||
server_args.extend(("--reasoning-budget", self.reasoning_budget))
|
||||
if self.chat_template:
|
||||
server_args.extend(["--chat-template", self.chat_template])
|
||||
if self.chat_template_file:
|
||||
|
@ -294,6 +297,77 @@ class ServerProcess:
|
|||
print("Partial response from server", json.dumps(data, indent=2))
|
||||
yield data
|
||||
|
||||
def make_any_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: dict | None = None,
|
||||
headers: dict | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> dict:
|
||||
stream = data.get('stream', False)
|
||||
if stream:
|
||||
content: list[str] = []
|
||||
tool_calls: list[dict] = []
|
||||
finish_reason: Optional[str] = None
|
||||
|
||||
content_parts = 0
|
||||
tool_call_parts = 0
|
||||
arguments_parts = 0
|
||||
|
||||
for chunk in self.make_stream_request(method, path, data, headers):
|
||||
assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
|
||||
choice = chunk['choices'][0]
|
||||
if choice['delta'].get('content') is not None:
|
||||
assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
|
||||
content.append(choice['delta']['content'])
|
||||
content_parts += 1
|
||||
if choice['delta'].get('finish_reason') is not None:
|
||||
finish_reason = choice['delta']['finish_reason']
|
||||
for tc in choice['delta'].get('tool_calls', []):
|
||||
if 'function' not in tc:
|
||||
raise ValueError(f"Expected function type, got {tc['type']}")
|
||||
if tc['index'] >= len(tool_calls):
|
||||
tool_calls.append(dict(
|
||||
id="",
|
||||
type="function",
|
||||
function=dict(
|
||||
name="",
|
||||
arguments="",
|
||||
)
|
||||
))
|
||||
tool_call = tool_calls[tc['index']]
|
||||
if tc.get('id') is not None:
|
||||
tool_call['id'] = tc['id']
|
||||
fct = tc['function']
|
||||
if fct.get('name') is not None:
|
||||
tool_call['function']['name'] = fct['name']
|
||||
if fct.get('arguments') is not None:
|
||||
assert len(fct['arguments']) > 0, f'Expected non empty arguments delta!'
|
||||
tool_call['function']['arguments'] += fct['arguments']
|
||||
|
||||
print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
|
||||
result = dict(
|
||||
choices=[
|
||||
dict(
|
||||
index=0,
|
||||
finish_reason=finish_reason,
|
||||
message=dict(
|
||||
role='assistant',
|
||||
content=''.join(content) if content else None,
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
print("Final response from server", json.dumps(result, indent=2))
|
||||
return result
|
||||
else:
|
||||
response = self.make_request(method, path, data, headers, timeout=timeout)
|
||||
assert response.status_code == 200, f"Server returned error: {response.status_code}"
|
||||
return response.body
|
||||
|
||||
|
||||
|
||||
server_instances: Set[ServerProcess] = set()
|
||||
|
||||
|
|
|
@ -474,26 +474,6 @@ static std::string gen_tool_call_id() {
|
|||
// other common utils
|
||||
//
|
||||
|
||||
static bool ends_with(const std::string & str, const std::string & suffix) {
|
||||
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
|
||||
}
|
||||
|
||||
static size_t find_partial_stop_string(const std::string &stop, const std::string &text) {
|
||||
if (!text.empty() && !stop.empty()) {
|
||||
const char text_last_char = text.back();
|
||||
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
|
||||
if (stop[char_index] == text_last_char) {
|
||||
const std::string current_partial = stop.substr(0, char_index + 1);
|
||||
if (ends_with(text, current_partial)) {
|
||||
return text.size() - char_index - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::string::npos;
|
||||
}
|
||||
|
||||
// TODO: reuse llama_detokenize
|
||||
template <class Iter>
|
||||
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
|
||||
|
@ -588,6 +568,7 @@ struct oaicompat_parser_options {
|
|||
common_chat_templates * tmpls;
|
||||
bool allow_image;
|
||||
bool allow_audio;
|
||||
bool enable_thinking = true;
|
||||
};
|
||||
|
||||
// used by /chat/completions endpoint
|
||||
|
@ -599,19 +580,16 @@ static json oaicompat_chat_params_parse(
|
|||
json llama_params;
|
||||
|
||||
auto tools = json_value(body, "tools", json());
|
||||
auto has_tools = tools.is_array() && !tools.empty();
|
||||
auto stream = json_value(body, "stream", false);
|
||||
auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
|
||||
|
||||
if (tools.is_array() && !tools.empty()) {
|
||||
if (stream) {
|
||||
throw std::runtime_error("Cannot use tools with stream");
|
||||
}
|
||||
if (!opt.use_jinja) {
|
||||
if (!opt.use_jinja) {
|
||||
if (has_tools) {
|
||||
throw std::runtime_error("tools param requires --jinja flag");
|
||||
}
|
||||
}
|
||||
if (!opt.use_jinja) {
|
||||
if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) {
|
||||
throw std::runtime_error("Unsupported param: tool_choice");
|
||||
if (tool_choice != "auto") {
|
||||
throw std::runtime_error("tool_choice param requires --jinja flag");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -749,14 +727,14 @@ static json oaicompat_chat_params_parse(
|
|||
common_chat_templates_inputs inputs;
|
||||
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
|
||||
inputs.tools = common_chat_tools_parse_oaicompat(tools);
|
||||
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
|
||||
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice);
|
||||
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
|
||||
inputs.grammar = grammar;
|
||||
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
|
||||
inputs.use_jinja = opt.use_jinja;
|
||||
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
|
||||
inputs.extract_reasoning = opt.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
|
||||
inputs.reasoning_format = opt.reasoning_format;
|
||||
inputs.enable_thinking = opt.enable_thinking;
|
||||
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
|
||||
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
|
||||
}
|
||||
|
@ -774,7 +752,8 @@ static json oaicompat_chat_params_parse(
|
|||
throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list.");
|
||||
}
|
||||
|
||||
inputs.extract_reasoning = false;
|
||||
/* TODO: test this properly */
|
||||
inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
inputs.add_generation_prompt = true;
|
||||
}
|
||||
|
||||
|
@ -799,6 +778,7 @@ static json oaicompat_chat_params_parse(
|
|||
}
|
||||
llama_params["grammar_triggers"] = grammar_triggers;
|
||||
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
|
||||
llama_params["thinking_forced_open"] = chat_params.thinking_forced_open;
|
||||
for (const auto & stop : chat_params.additional_stops) {
|
||||
llama_params["stop"].push_back(stop);
|
||||
}
|
||||
|
@ -812,6 +792,9 @@ static json oaicompat_chat_params_parse(
|
|||
// Handle "logprobs" field
|
||||
// TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
|
||||
if (json_value(body, "logprobs", false)) {
|
||||
if (has_tools && stream) {
|
||||
throw std::runtime_error("logprobs is not supported with tools + stream");
|
||||
}
|
||||
llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
|
||||
} else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
|
||||
throw std::runtime_error("top_logprobs requires logprobs to be set to true");
|
||||
|
|
|
@ -46,8 +46,11 @@ export function useChatExtraContext(): ChatExtraContextApi {
|
|||
try {
|
||||
for (const file of files) {
|
||||
const mimeType = file.type;
|
||||
if (file.size > 10 * 1024 * 1024) {
|
||||
toast.error('File is too large. Maximum size is 10MB.');
|
||||
|
||||
// this limit is only to prevent accidental uploads of huge files
|
||||
// it can potentially crashes the browser because we read the file as base64
|
||||
if (file.size > 500 * 1024 * 1024) {
|
||||
toast.error('File is too large. Maximum size is 500MB.');
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue