Merge commit 'f1e3eb4249' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	README.md
#	docs/backend/SYCL.md
#	examples/llava/clip.cpp
#	ggml/src/ggml-sycl/CMakeLists.txt
#	ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in
This commit is contained in:
Concedo 2025-04-08 20:48:53 +08:00
commit 822cf2430e
26 changed files with 1822 additions and 2088 deletions

121
.github/workflows/build-linux-cross.yml vendored Normal file
View file

@ -0,0 +1,121 @@
name: Build on Linux using cross-compiler
on:
workflow_dispatch:
workflow_call:
jobs:
ubuntu-latest-riscv64-cpu-cross:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Setup Riscv
run: |
sudo dpkg --add-architecture riscv64
sudo sed -i 's|http://azure.archive.ubuntu.com/ubuntu|http://ports.ubuntu.com/ubuntu-ports|g' \
/etc/apt/sources.list /etc/apt/apt-mirrors.txt
sudo apt-get clean
sudo apt-get update
sudo apt-get install -y --no-install-recommends \
build-essential \
gcc-14-riscv64-linux-gnu \
g++-14-riscv64-linux-gnu
- name: Build
run: |
cmake -B build -DCMAKE_BUILD_TYPE=Release \
-DGGML_OPENMP=OFF \
-DLLAMA_BUILD_EXAMPLES=ON \
-DLLAMA_BUILD_TESTS=OFF \
-DCMAKE_SYSTEM_NAME=Linux \
-DCMAKE_SYSTEM_PROCESSOR=riscv64 \
-DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
-DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
-DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
cmake --build build --config Release -j $(nproc)
ubuntu-latest-riscv64-vulkan-cross:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Riscv
run: |
sudo dpkg --add-architecture riscv64
sudo sed -i 's|http://azure.archive.ubuntu.com/ubuntu|http://ports.ubuntu.com/ubuntu-ports|g' \
/etc/apt/sources.list /etc/apt/apt-mirrors.txt
sudo apt-get clean
sudo apt-get update
sudo apt-get install -y --no-install-recommends \
build-essential \
glslc \
gcc-14-riscv64-linux-gnu \
g++-14-riscv64-linux-gnu \
libvulkan-dev:riscv64
- name: Build
run: |
cmake -B build -DCMAKE_BUILD_TYPE=Release \
-DGGML_VULKAN=ON \
-DGGML_OPENMP=OFF \
-DLLAMA_BUILD_EXAMPLES=ON \
-DLLAMA_BUILD_TESTS=OFF \
-DCMAKE_SYSTEM_NAME=Linux \
-DCMAKE_SYSTEM_PROCESSOR=riscv64 \
-DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
-DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
-DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
cmake --build build --config Release -j $(nproc)
ubuntu-latest-arm64-vulkan-cross:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Arm64
run: |
sudo dpkg --add-architecture arm64
sudo sed -i 's|http://azure.archive.ubuntu.com/ubuntu|http://ports.ubuntu.com/ubuntu-ports|g' \
/etc/apt/sources.list /etc/apt/apt-mirrors.txt
sudo apt-get clean
sudo apt-get update
sudo apt-get install -y --no-install-recommends \
build-essential \
glslc \
crossbuild-essential-arm64 \
libvulkan-dev:arm64
- name: Build
run: |
cmake -B build -DCMAKE_BUILD_TYPE=Release \
-DGGML_VULKAN=ON \
-DGGML_OPENMP=OFF \
-DLLAMA_BUILD_EXAMPLES=ON \
-DLLAMA_BUILD_TESTS=OFF \
-DCMAKE_SYSTEM_NAME=Linux \
-DCMAKE_SYSTEM_PROCESSOR=aarch64 \
-DCMAKE_C_COMPILER=aarch64-linux-gnu-gcc \
-DCMAKE_CXX_COMPILER=aarch64-linux-gnu-g++ \
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
-DCMAKE_FIND_ROOT_PATH=/usr/lib/aarch64-linux-gnu \
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
cmake --build build --config Release -j $(nproc)

View file

@ -19,6 +19,7 @@
#include <algorithm> #include <algorithm>
#include <climits> #include <climits>
#include <cstdarg> #include <cstdarg>
#include <filesystem>
#include <fstream> #include <fstream>
#include <regex> #include <regex>
#include <set> #include <set>
@ -657,9 +658,13 @@ static void common_params_handle_model(
} }
} }
// TODO: allow custom host std::string hf_endpoint = "https://huggingface.co/";
model.url = "https://huggingface.co/" + model.hf_repo + "/resolve/main/" + model.hf_file; const char * hf_endpoint_env = getenv("HF_ENDPOINT");
if (hf_endpoint_env) {
hf_endpoint = hf_endpoint_env;
if (hf_endpoint.back() != '/') hf_endpoint += '/';
}
model.url = hf_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file;
// make sure model path is present (for caching purposes) // make sure model path is present (for caching purposes)
if (model.path.empty()) { if (model.path.empty()) {
// this is to avoid different repo having same file name, or same file name in different subdirs // this is to avoid different repo having same file name, or same file name in different subdirs

View file

@ -9,10 +9,19 @@
#pragma once #pragma once
#include "minja.hpp" #include "minja.hpp"
#include <json.hpp>
#include <chrono>
#include <cstddef>
#include <cstdio>
#include <exception>
#include <iomanip>
#include <memory>
#include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
#include <json.hpp>
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
namespace minja { namespace minja {
@ -425,7 +434,7 @@ class chat_template {
auto obj = json { auto obj = json {
{"tool_calls", tool_calls}, {"tool_calls", tool_calls},
}; };
if (!content.is_null() && content != "") { if (!content.is_null() && !content.empty()) {
obj["content"] = content; obj["content"] = content;
} }
message["content"] = obj.dump(2); message["content"] = obj.dump(2);
@ -435,13 +444,12 @@ class chat_template {
if (polyfill_tool_responses && role == "tool") { if (polyfill_tool_responses && role == "tool") {
message["role"] = "user"; message["role"] = "user";
auto obj = json { auto obj = json {
{"tool_response", { {"tool_response", json::object()},
{"content", message.at("content")},
}},
}; };
if (message.contains("name")) { if (message.contains("name")) {
obj["tool_response"]["name"] = message.at("name"); obj["tool_response"]["tool"] = message.at("name");
} }
obj["tool_response"]["content"] = message.at("content");
if (message.contains("tool_call_id")) { if (message.contains("tool_call_id")) {
obj["tool_response"]["tool_call_id"] = message.at("tool_call_id"); obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
} }
@ -510,7 +518,7 @@ class chat_template {
static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) { static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
json messages_with_system = messages; json messages_with_system = messages;
if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") { if (!messages_with_system.empty() && messages_with_system[0].at("role") == "system") {
std::string existing_system = messages_with_system.at(0).at("content"); std::string existing_system = messages_with_system.at(0).at("content");
messages_with_system[0] = json { messages_with_system[0] = json {
{"role", "system"}, {"role", "system"},

View file

@ -8,14 +8,26 @@
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
#pragma once #pragma once
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <cmath>
#include <exception>
#include <functional>
#include <iostream> #include <iostream>
#include <string> #include <iterator>
#include <vector> #include <limits>
#include <regex> #include <map>
#include <memory> #include <memory>
#include <stdexcept> #include <regex>
#include <sstream> #include <sstream>
#include <string>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility>
#include <vector>
#include <json.hpp> #include <json.hpp>
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
@ -240,7 +252,7 @@ public:
auto index = key.get<int>(); auto index = key.get<int>();
return array_->at(index < 0 ? array_->size() + index : index); return array_->at(index < 0 ? array_->size() + index : index);
} else if (object_) { } else if (object_) {
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump()); if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
auto it = object_->find(key.primitive_); auto it = object_->find(key.primitive_);
if (it == object_->end()) return Value(); if (it == object_->end()) return Value();
return it->second; return it->second;
@ -249,7 +261,7 @@ public:
} }
void set(const Value& key, const Value& value) { void set(const Value& key, const Value& value) {
if (!object_) throw std::runtime_error("Value is not an object: " + dump()); if (!object_) throw std::runtime_error("Value is not an object: " + dump());
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump()); if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
(*object_)[key.primitive_] = value; (*object_)[key.primitive_] = value;
} }
Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const { Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const {
@ -731,51 +743,51 @@ public:
struct TextTemplateToken : public TemplateToken { struct TextTemplateToken : public TemplateToken {
std::string text; std::string text;
TextTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Text, location, pre, post), text(t) {} TextTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Text, loc, pre, post), text(t) {}
}; };
struct ExpressionTemplateToken : public TemplateToken { struct ExpressionTemplateToken : public TemplateToken {
std::shared_ptr<Expression> expr; std::shared_ptr<Expression> expr;
ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {} ExpressionTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e) : TemplateToken(Type::Expression, loc, pre, post), expr(std::move(e)) {}
}; };
struct IfTemplateToken : public TemplateToken { struct IfTemplateToken : public TemplateToken {
std::shared_ptr<Expression> condition; std::shared_ptr<Expression> condition;
IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {} IfTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::If, loc, pre, post), condition(std::move(c)) {}
}; };
struct ElifTemplateToken : public TemplateToken { struct ElifTemplateToken : public TemplateToken {
std::shared_ptr<Expression> condition; std::shared_ptr<Expression> condition;
ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {} ElifTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::Elif, loc, pre, post), condition(std::move(c)) {}
}; };
struct ElseTemplateToken : public TemplateToken { struct ElseTemplateToken : public TemplateToken {
ElseTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Else, location, pre, post) {} ElseTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Else, loc, pre, post) {}
}; };
struct EndIfTemplateToken : public TemplateToken { struct EndIfTemplateToken : public TemplateToken {
EndIfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, location, pre, post) {} EndIfTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, loc, pre, post) {}
}; };
struct MacroTemplateToken : public TemplateToken { struct MacroTemplateToken : public TemplateToken {
std::shared_ptr<VariableExpr> name; std::shared_ptr<VariableExpr> name;
Expression::Parameters params; Expression::Parameters params;
MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p) MacroTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p)
: TemplateToken(Type::Macro, location, pre, post), name(std::move(n)), params(std::move(p)) {} : TemplateToken(Type::Macro, loc, pre, post), name(std::move(n)), params(std::move(p)) {}
}; };
struct EndMacroTemplateToken : public TemplateToken { struct EndMacroTemplateToken : public TemplateToken {
EndMacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndMacro, location, pre, post) {} EndMacroTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndMacro, loc, pre, post) {}
}; };
struct FilterTemplateToken : public TemplateToken { struct FilterTemplateToken : public TemplateToken {
std::shared_ptr<Expression> filter; std::shared_ptr<Expression> filter;
FilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && filter) FilterTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && filter)
: TemplateToken(Type::Filter, location, pre, post), filter(std::move(filter)) {} : TemplateToken(Type::Filter, loc, pre, post), filter(std::move(filter)) {}
}; };
struct EndFilterTemplateToken : public TemplateToken { struct EndFilterTemplateToken : public TemplateToken {
EndFilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFilter, location, pre, post) {} EndFilterTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFilter, loc, pre, post) {}
}; };
struct ForTemplateToken : public TemplateToken { struct ForTemplateToken : public TemplateToken {
@ -783,38 +795,38 @@ struct ForTemplateToken : public TemplateToken {
std::shared_ptr<Expression> iterable; std::shared_ptr<Expression> iterable;
std::shared_ptr<Expression> condition; std::shared_ptr<Expression> condition;
bool recursive; bool recursive;
ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector<std::string> & vns, std::shared_ptr<Expression> && iter, ForTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, const std::vector<std::string> & vns, std::shared_ptr<Expression> && iter,
std::shared_ptr<Expression> && c, bool r) std::shared_ptr<Expression> && c, bool r)
: TemplateToken(Type::For, location, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {} : TemplateToken(Type::For, loc, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {}
}; };
struct EndForTemplateToken : public TemplateToken { struct EndForTemplateToken : public TemplateToken {
EndForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, location, pre, post) {} EndForTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, loc, pre, post) {}
}; };
struct GenerationTemplateToken : public TemplateToken { struct GenerationTemplateToken : public TemplateToken {
GenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Generation, location, pre, post) {} GenerationTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Generation, loc, pre, post) {}
}; };
struct EndGenerationTemplateToken : public TemplateToken { struct EndGenerationTemplateToken : public TemplateToken {
EndGenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndGeneration, location, pre, post) {} EndGenerationTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndGeneration, loc, pre, post) {}
}; };
struct SetTemplateToken : public TemplateToken { struct SetTemplateToken : public TemplateToken {
std::string ns; std::string ns;
std::vector<std::string> var_names; std::vector<std::string> var_names;
std::shared_ptr<Expression> value; std::shared_ptr<Expression> value;
SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v) SetTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v)
: TemplateToken(Type::Set, location, pre, post), ns(ns), var_names(vns), value(std::move(v)) {} : TemplateToken(Type::Set, loc, pre, post), ns(ns), var_names(vns), value(std::move(v)) {}
}; };
struct EndSetTemplateToken : public TemplateToken { struct EndSetTemplateToken : public TemplateToken {
EndSetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndSet, location, pre, post) {} EndSetTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndSet, loc, pre, post) {}
}; };
struct CommentTemplateToken : public TemplateToken { struct CommentTemplateToken : public TemplateToken {
std::string text; std::string text;
CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {} CommentTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, loc, pre, post), text(t) {}
}; };
enum class LoopControlType { Break, Continue }; enum class LoopControlType { Break, Continue };
@ -830,7 +842,7 @@ public:
struct LoopControlTemplateToken : public TemplateToken { struct LoopControlTemplateToken : public TemplateToken {
LoopControlType control_type; LoopControlType control_type;
LoopControlTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, location, pre, post), control_type(control_type) {} LoopControlTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, loc, pre, post), control_type(control_type) {}
}; };
class TemplateNode { class TemplateNode {
@ -868,8 +880,8 @@ public:
class SequenceNode : public TemplateNode { class SequenceNode : public TemplateNode {
std::vector<std::shared_ptr<TemplateNode>> children; std::vector<std::shared_ptr<TemplateNode>> children;
public: public:
SequenceNode(const Location & location, std::vector<std::shared_ptr<TemplateNode>> && c) SequenceNode(const Location & loc, std::vector<std::shared_ptr<TemplateNode>> && c)
: TemplateNode(location), children(std::move(c)) {} : TemplateNode(loc), children(std::move(c)) {}
void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override { void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
for (const auto& child : children) child->render(out, context); for (const auto& child : children) child->render(out, context);
} }
@ -878,7 +890,7 @@ public:
class TextNode : public TemplateNode { class TextNode : public TemplateNode {
std::string text; std::string text;
public: public:
TextNode(const Location & location, const std::string& t) : TemplateNode(location), text(t) {} TextNode(const Location & loc, const std::string& t) : TemplateNode(loc), text(t) {}
void do_render(std::ostringstream & out, const std::shared_ptr<Context> &) const override { void do_render(std::ostringstream & out, const std::shared_ptr<Context> &) const override {
out << text; out << text;
} }
@ -887,7 +899,7 @@ public:
class ExpressionNode : public TemplateNode { class ExpressionNode : public TemplateNode {
std::shared_ptr<Expression> expr; std::shared_ptr<Expression> expr;
public: public:
ExpressionNode(const Location & location, std::shared_ptr<Expression> && e) : TemplateNode(location), expr(std::move(e)) {} ExpressionNode(const Location & loc, std::shared_ptr<Expression> && e) : TemplateNode(loc), expr(std::move(e)) {}
void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override { void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
if (!expr) throw std::runtime_error("ExpressionNode.expr is null"); if (!expr) throw std::runtime_error("ExpressionNode.expr is null");
auto result = expr->evaluate(context); auto result = expr->evaluate(context);
@ -904,8 +916,8 @@ public:
class IfNode : public TemplateNode { class IfNode : public TemplateNode {
std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade; std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade;
public: public:
IfNode(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> && c) IfNode(const Location & loc, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> && c)
: TemplateNode(location), cascade(std::move(c)) {} : TemplateNode(loc), cascade(std::move(c)) {}
void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override { void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
for (const auto& branch : cascade) { for (const auto& branch : cascade) {
auto enter_branch = true; auto enter_branch = true;
@ -924,7 +936,7 @@ public:
class LoopControlNode : public TemplateNode { class LoopControlNode : public TemplateNode {
LoopControlType control_type_; LoopControlType control_type_;
public: public:
LoopControlNode(const Location & location, LoopControlType control_type) : TemplateNode(location), control_type_(control_type) {} LoopControlNode(const Location & loc, LoopControlType control_type) : TemplateNode(loc), control_type_(control_type) {}
void do_render(std::ostringstream &, const std::shared_ptr<Context> &) const override { void do_render(std::ostringstream &, const std::shared_ptr<Context> &) const override {
throw LoopControlException(control_type_); throw LoopControlException(control_type_);
} }
@ -938,9 +950,9 @@ class ForNode : public TemplateNode {
bool recursive; bool recursive;
std::shared_ptr<TemplateNode> else_body; std::shared_ptr<TemplateNode> else_body;
public: public:
ForNode(const Location & location, std::vector<std::string> && var_names, std::shared_ptr<Expression> && iterable, ForNode(const Location & loc, std::vector<std::string> && var_names, std::shared_ptr<Expression> && iterable,
std::shared_ptr<Expression> && condition, std::shared_ptr<TemplateNode> && body, bool recursive, std::shared_ptr<TemplateNode> && else_body) std::shared_ptr<Expression> && condition, std::shared_ptr<TemplateNode> && body, bool recursive, std::shared_ptr<TemplateNode> && else_body)
: TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {} : TemplateNode(loc), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {}
void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override { void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
// https://jinja.palletsprojects.com/en/3.0.x/templates/#for // https://jinja.palletsprojects.com/en/3.0.x/templates/#for
@ -1025,8 +1037,8 @@ class MacroNode : public TemplateNode {
std::shared_ptr<TemplateNode> body; std::shared_ptr<TemplateNode> body;
std::unordered_map<std::string, size_t> named_param_positions; std::unordered_map<std::string, size_t> named_param_positions;
public: public:
MacroNode(const Location & location, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p, std::shared_ptr<TemplateNode> && b) MacroNode(const Location & loc, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p, std::shared_ptr<TemplateNode> && b)
: TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) { : TemplateNode(loc), name(std::move(n)), params(std::move(p)), body(std::move(b)) {
for (size_t i = 0; i < params.size(); ++i) { for (size_t i = 0; i < params.size(); ++i) {
const auto & name = params[i].first; const auto & name = params[i].first;
if (!name.empty()) { if (!name.empty()) {
@ -1072,8 +1084,8 @@ class FilterNode : public TemplateNode {
std::shared_ptr<TemplateNode> body; std::shared_ptr<TemplateNode> body;
public: public:
FilterNode(const Location & location, std::shared_ptr<Expression> && f, std::shared_ptr<TemplateNode> && b) FilterNode(const Location & loc, std::shared_ptr<Expression> && f, std::shared_ptr<TemplateNode> && b)
: TemplateNode(location), filter(std::move(f)), body(std::move(b)) {} : TemplateNode(loc), filter(std::move(f)), body(std::move(b)) {}
void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override { void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
if (!filter) throw std::runtime_error("FilterNode.filter is null"); if (!filter) throw std::runtime_error("FilterNode.filter is null");
@ -1095,8 +1107,8 @@ class SetNode : public TemplateNode {
std::vector<std::string> var_names; std::vector<std::string> var_names;
std::shared_ptr<Expression> value; std::shared_ptr<Expression> value;
public: public:
SetNode(const Location & location, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v) SetNode(const Location & loc, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v)
: TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)) {} : TemplateNode(loc), ns(ns), var_names(vns), value(std::move(v)) {}
void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override { void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
if (!value) throw std::runtime_error("SetNode.value is null"); if (!value) throw std::runtime_error("SetNode.value is null");
if (!ns.empty()) { if (!ns.empty()) {
@ -1118,8 +1130,8 @@ class SetTemplateNode : public TemplateNode {
std::string name; std::string name;
std::shared_ptr<TemplateNode> template_value; std::shared_ptr<TemplateNode> template_value;
public: public:
SetTemplateNode(const Location & location, const std::string & name, std::shared_ptr<TemplateNode> && tv) SetTemplateNode(const Location & loc, const std::string & name, std::shared_ptr<TemplateNode> && tv)
: TemplateNode(location), name(name), template_value(std::move(tv)) {} : TemplateNode(loc), name(name), template_value(std::move(tv)) {}
void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override { void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
if (!template_value) throw std::runtime_error("SetTemplateNode.template_value is null"); if (!template_value) throw std::runtime_error("SetTemplateNode.template_value is null");
Value value { template_value->render(context) }; Value value { template_value->render(context) };
@ -1132,8 +1144,8 @@ class IfExpr : public Expression {
std::shared_ptr<Expression> then_expr; std::shared_ptr<Expression> then_expr;
std::shared_ptr<Expression> else_expr; std::shared_ptr<Expression> else_expr;
public: public:
IfExpr(const Location & location, std::shared_ptr<Expression> && c, std::shared_ptr<Expression> && t, std::shared_ptr<Expression> && e) IfExpr(const Location & loc, std::shared_ptr<Expression> && c, std::shared_ptr<Expression> && t, std::shared_ptr<Expression> && e)
: Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {} : Expression(loc), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override { Value do_evaluate(const std::shared_ptr<Context> & context) const override {
if (!condition) throw std::runtime_error("IfExpr.condition is null"); if (!condition) throw std::runtime_error("IfExpr.condition is null");
if (!then_expr) throw std::runtime_error("IfExpr.then_expr is null"); if (!then_expr) throw std::runtime_error("IfExpr.then_expr is null");
@ -1150,16 +1162,16 @@ public:
class LiteralExpr : public Expression { class LiteralExpr : public Expression {
Value value; Value value;
public: public:
LiteralExpr(const Location & location, const Value& v) LiteralExpr(const Location & loc, const Value& v)
: Expression(location), value(v) {} : Expression(loc), value(v) {}
Value do_evaluate(const std::shared_ptr<Context> &) const override { return value; } Value do_evaluate(const std::shared_ptr<Context> &) const override { return value; }
}; };
class ArrayExpr : public Expression { class ArrayExpr : public Expression {
std::vector<std::shared_ptr<Expression>> elements; std::vector<std::shared_ptr<Expression>> elements;
public: public:
ArrayExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && e) ArrayExpr(const Location & loc, std::vector<std::shared_ptr<Expression>> && e)
: Expression(location), elements(std::move(e)) {} : Expression(loc), elements(std::move(e)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override { Value do_evaluate(const std::shared_ptr<Context> & context) const override {
auto result = Value::array(); auto result = Value::array();
for (const auto& e : elements) { for (const auto& e : elements) {
@ -1173,8 +1185,8 @@ public:
class DictExpr : public Expression { class DictExpr : public Expression {
std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements; std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements;
public: public:
DictExpr(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> && e) DictExpr(const Location & loc, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> && e)
: Expression(location), elements(std::move(e)) {} : Expression(loc), elements(std::move(e)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override { Value do_evaluate(const std::shared_ptr<Context> & context) const override {
auto result = Value::object(); auto result = Value::object();
for (const auto& [key, value] : elements) { for (const auto& [key, value] : elements) {
@ -1189,8 +1201,8 @@ public:
class SliceExpr : public Expression { class SliceExpr : public Expression {
public: public:
std::shared_ptr<Expression> start, end; std::shared_ptr<Expression> start, end;
SliceExpr(const Location & location, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e) SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
: Expression(location), start(std::move(s)), end(std::move(e)) {} : Expression(loc), start(std::move(s)), end(std::move(e)) {}
Value do_evaluate(const std::shared_ptr<Context> &) const override { Value do_evaluate(const std::shared_ptr<Context> &) const override {
throw std::runtime_error("SliceExpr not implemented"); throw std::runtime_error("SliceExpr not implemented");
} }
@ -1200,8 +1212,8 @@ class SubscriptExpr : public Expression {
std::shared_ptr<Expression> base; std::shared_ptr<Expression> base;
std::shared_ptr<Expression> index; std::shared_ptr<Expression> index;
public: public:
SubscriptExpr(const Location & location, std::shared_ptr<Expression> && b, std::shared_ptr<Expression> && i) SubscriptExpr(const Location & loc, std::shared_ptr<Expression> && b, std::shared_ptr<Expression> && i)
: Expression(location), base(std::move(b)), index(std::move(i)) {} : Expression(loc), base(std::move(b)), index(std::move(i)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override { Value do_evaluate(const std::shared_ptr<Context> & context) const override {
if (!base) throw std::runtime_error("SubscriptExpr.base is null"); if (!base) throw std::runtime_error("SubscriptExpr.base is null");
if (!index) throw std::runtime_error("SubscriptExpr.index is null"); if (!index) throw std::runtime_error("SubscriptExpr.index is null");
@ -1243,8 +1255,8 @@ public:
enum class Op { Plus, Minus, LogicalNot, Expansion, ExpansionDict }; enum class Op { Plus, Minus, LogicalNot, Expansion, ExpansionDict };
std::shared_ptr<Expression> expr; std::shared_ptr<Expression> expr;
Op op; Op op;
UnaryOpExpr(const Location & location, std::shared_ptr<Expression> && e, Op o) UnaryOpExpr(const Location & loc, std::shared_ptr<Expression> && e, Op o)
: Expression(location), expr(std::move(e)), op(o) {} : Expression(loc), expr(std::move(e)), op(o) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override { Value do_evaluate(const std::shared_ptr<Context> & context) const override {
if (!expr) throw std::runtime_error("UnaryOpExpr.expr is null"); if (!expr) throw std::runtime_error("UnaryOpExpr.expr is null");
auto e = expr->evaluate(context); auto e = expr->evaluate(context);
@ -1269,8 +1281,8 @@ private:
std::shared_ptr<Expression> right; std::shared_ptr<Expression> right;
Op op; Op op;
public: public:
BinaryOpExpr(const Location & location, std::shared_ptr<Expression> && l, std::shared_ptr<Expression> && r, Op o) BinaryOpExpr(const Location & loc, std::shared_ptr<Expression> && l, std::shared_ptr<Expression> && r, Op o)
: Expression(location), left(std::move(l)), right(std::move(r)), op(o) {} : Expression(loc), left(std::move(l)), right(std::move(r)), op(o) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override { Value do_evaluate(const std::shared_ptr<Context> & context) const override {
if (!left) throw std::runtime_error("BinaryOpExpr.left is null"); if (!left) throw std::runtime_error("BinaryOpExpr.left is null");
if (!right) throw std::runtime_error("BinaryOpExpr.right is null"); if (!right) throw std::runtime_error("BinaryOpExpr.right is null");
@ -1427,8 +1439,8 @@ class MethodCallExpr : public Expression {
std::shared_ptr<VariableExpr> method; std::shared_ptr<VariableExpr> method;
ArgumentsExpression args; ArgumentsExpression args;
public: public:
MethodCallExpr(const Location & location, std::shared_ptr<Expression> && obj, std::shared_ptr<VariableExpr> && m, ArgumentsExpression && a) MethodCallExpr(const Location & loc, std::shared_ptr<Expression> && obj, std::shared_ptr<VariableExpr> && m, ArgumentsExpression && a)
: Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {} : Expression(loc), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override { Value do_evaluate(const std::shared_ptr<Context> & context) const override {
if (!object) throw std::runtime_error("MethodCallExpr.object is null"); if (!object) throw std::runtime_error("MethodCallExpr.object is null");
if (!method) throw std::runtime_error("MethodCallExpr.method is null"); if (!method) throw std::runtime_error("MethodCallExpr.method is null");
@ -1526,8 +1538,8 @@ class CallExpr : public Expression {
public: public:
std::shared_ptr<Expression> object; std::shared_ptr<Expression> object;
ArgumentsExpression args; ArgumentsExpression args;
CallExpr(const Location & location, std::shared_ptr<Expression> && obj, ArgumentsExpression && a) CallExpr(const Location & loc, std::shared_ptr<Expression> && obj, ArgumentsExpression && a)
: Expression(location), object(std::move(obj)), args(std::move(a)) {} : Expression(loc), object(std::move(obj)), args(std::move(a)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override { Value do_evaluate(const std::shared_ptr<Context> & context) const override {
if (!object) throw std::runtime_error("CallExpr.object is null"); if (!object) throw std::runtime_error("CallExpr.object is null");
auto obj = object->evaluate(context); auto obj = object->evaluate(context);
@ -1542,8 +1554,8 @@ public:
class FilterExpr : public Expression { class FilterExpr : public Expression {
std::vector<std::shared_ptr<Expression>> parts; std::vector<std::shared_ptr<Expression>> parts;
public: public:
FilterExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && p) FilterExpr(const Location & loc, std::vector<std::shared_ptr<Expression>> && p)
: Expression(location), parts(std::move(p)) {} : Expression(loc), parts(std::move(p)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override { Value do_evaluate(const std::shared_ptr<Context> & context) const override {
Value result; Value result;
bool first = true; bool first = true;
@ -2460,7 +2472,7 @@ private:
static std::regex leading_space_regex(R"(^\s+)"); static std::regex leading_space_regex(R"(^\s+)");
text = std::regex_replace(text, leading_space_regex, ""); text = std::regex_replace(text, leading_space_regex, "");
} else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast<ExpressionTemplateToken*>((*(it - 2)).get())) { } else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast<ExpressionTemplateToken*>((*(it - 2)).get())) {
if (text.length() > 0 && text[0] == '\n') { if (!text.empty() && text[0] == '\n') {
text.erase(0, 1); text.erase(0, 1);
} }
} }
@ -2538,7 +2550,7 @@ public:
TemplateTokenIterator begin = tokens.begin(); TemplateTokenIterator begin = tokens.begin();
auto it = begin; auto it = begin;
TemplateTokenIterator end = tokens.end(); TemplateTokenIterator end = tokens.end();
return parser.parseTemplate(begin, it, end, /* full= */ true); return parser.parseTemplate(begin, it, end, /* fully= */ true);
} }
}; };
@ -2577,7 +2589,7 @@ inline std::shared_ptr<Context> Context::builtins() {
throw std::runtime_error(args.at("message").get<std::string>()); throw std::runtime_error(args.at("message").get<std::string>());
})); }));
globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr<Context> &, Value & args) { globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr<Context> &, Value & args) {
return Value(args.at("value").dump(args.get<int64_t>("indent", -1), /* tojson= */ true)); return Value(args.at("value").dump(args.get<int64_t>("indent", -1), /* to_json= */ true));
})); }));
globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr<Context> &, Value & args) { globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr<Context> &, Value & args) {
auto items = Value::array(); auto items = Value::array();
@ -2599,7 +2611,7 @@ inline std::shared_ptr<Context> Context::builtins() {
globals.set("last", simple_function("last", { "items" }, [](const std::shared_ptr<Context> &, Value & args) { globals.set("last", simple_function("last", { "items" }, [](const std::shared_ptr<Context> &, Value & args) {
auto items = args.at("items"); auto items = args.at("items");
if (!items.is_array()) throw std::runtime_error("object is not a list"); if (!items.is_array()) throw std::runtime_error("object is not a list");
if (items.size() == 0) return Value(); if (items.empty()) return Value();
return items.at(items.size() - 1); return items.at(items.size() - 1);
})); }));
globals.set("trim", simple_function("trim", { "text" }, [](const std::shared_ptr<Context> &, Value & args) { globals.set("trim", simple_function("trim", { "text" }, [](const std::shared_ptr<Context> &, Value & args) {
@ -2747,12 +2759,17 @@ inline std::shared_ptr<Context> Context::builtins() {
return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) { return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0}); args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
auto & items = args.args[0]; auto & items = args.args[0];
if (items.is_null()) if (items.is_null()) {
return Value::array(); return Value::array();
if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); }
if (!items.is_array()) {
throw std::runtime_error("object is not iterable: " + items.dump());
}
auto filter_fn = context->get(args.args[1]); auto filter_fn = context->get(args.args[1]);
if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); if (filter_fn.is_null()) {
throw std::runtime_error("Undefined filter: " + args.args[1].dump());
}
auto filter_args = Value::array(); auto filter_args = Value::array();
for (size_t i = 2, n = args.args.size(); i < n; i++) { for (size_t i = 2, n = args.args.size(); i < n; i++) {
@ -2878,10 +2895,15 @@ inline std::shared_ptr<Context> Context::builtins() {
} }
for (auto & [name, value] : args.kwargs) { for (auto & [name, value] : args.kwargs) {
size_t i; size_t i;
if (name == "start") i = 0; if (name == "start") {
else if (name == "end") i = 1; i = 0;
else if (name == "step") i = 2; } else if (name == "end") {
else throw std::runtime_error("Unknown argument " + name + " for function range"); i = 1;
} else if (name == "step") {
i = 2;
} else {
throw std::runtime_error("Unknown argument " + name + " for function range");
}
if (param_set[i]) { if (param_set[i]) {
throw std::runtime_error("Duplicate argument " + name + " for function range"); throw std::runtime_error("Duplicate argument " + name + " for function range");

View file

@ -409,8 +409,6 @@ static void gguf_merge(const split_params & split_params) {
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
std::ofstream fout(split_params.output.c_str(), std::ios::binary);
fout.exceptions(std::ofstream::failbit); // fail fast on write errors
auto * ctx_out = gguf_init_empty(); auto * ctx_out = gguf_init_empty();
@ -454,7 +452,6 @@ static void gguf_merge(const split_params & split_params) {
gguf_free(ctx_gguf); gguf_free(ctx_gguf);
ggml_free(ctx_meta); ggml_free(ctx_meta);
gguf_free(ctx_out); gguf_free(ctx_out);
fout.close();
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
@ -467,7 +464,6 @@ static void gguf_merge(const split_params & split_params) {
gguf_free(ctx_gguf); gguf_free(ctx_gguf);
ggml_free(ctx_meta); ggml_free(ctx_meta);
gguf_free(ctx_out); gguf_free(ctx_out);
fout.close();
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
@ -480,7 +476,6 @@ static void gguf_merge(const split_params & split_params) {
gguf_free(ctx_gguf); gguf_free(ctx_gguf);
ggml_free(ctx_meta); ggml_free(ctx_meta);
gguf_free(ctx_out); gguf_free(ctx_out);
fout.close();
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
@ -501,9 +496,11 @@ static void gguf_merge(const split_params & split_params) {
fprintf(stderr, "\033[3Ddone\n"); fprintf(stderr, "\033[3Ddone\n");
} }
std::ofstream fout;
if (!split_params.dry_run) {
fout.open(split_params.output.c_str(), std::ios::binary);
fout.exceptions(std::ofstream::failbit); // fail fast on write errors
// placeholder for the meta data // placeholder for the meta data
{
auto meta_size = gguf_get_meta_size(ctx_out); auto meta_size = gguf_get_meta_size(ctx_out);
::zeros(fout, meta_size); ::zeros(fout, meta_size);
} }
@ -519,7 +516,9 @@ static void gguf_merge(const split_params & split_params) {
ggml_free(ctx_metas[i]); ggml_free(ctx_metas[i]);
} }
gguf_free(ctx_out); gguf_free(ctx_out);
if (!split_params.dry_run) {
fout.close(); fout.close();
}
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
fprintf(stderr, "%s: writing tensors %s ...", __func__, split_path); fprintf(stderr, "%s: writing tensors %s ...", __func__, split_path);
@ -541,11 +540,12 @@ static void gguf_merge(const split_params & split_params) {
auto offset = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i_tensor); auto offset = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i_tensor);
f_input.seekg(offset); f_input.seekg(offset);
f_input.read((char *)read_data.data(), n_bytes); f_input.read((char *)read_data.data(), n_bytes);
if (!split_params.dry_run) {
// write tensor data + padding // write tensor data + padding
fout.write((const char *)read_data.data(), n_bytes); fout.write((const char *)read_data.data(), n_bytes);
zeros(fout, GGML_PAD(n_bytes, GGUF_DEFAULT_ALIGNMENT) - n_bytes); zeros(fout, GGML_PAD(n_bytes, GGUF_DEFAULT_ALIGNMENT) - n_bytes);
} }
}
gguf_free(ctx_gguf); gguf_free(ctx_gguf);
ggml_free(ctx_meta); ggml_free(ctx_meta);
@ -553,16 +553,15 @@ static void gguf_merge(const split_params & split_params) {
fprintf(stderr, "\033[3Ddone\n"); fprintf(stderr, "\033[3Ddone\n");
} }
{ if (!split_params.dry_run) {
// go back to beginning of file and write the updated metadata // go back to beginning of file and write the updated metadata
fout.seekp(0); fout.seekp(0);
std::vector<uint8_t> data(gguf_get_meta_size(ctx_out)); std::vector<uint8_t> data(gguf_get_meta_size(ctx_out));
gguf_get_meta_data(ctx_out, data.data()); gguf_get_meta_data(ctx_out, data.data());
fout.write((const char *)data.data(), data.size()); fout.write((const char *)data.data(), data.size());
fout.close(); fout.close();
gguf_free(ctx_out);
} }
gguf_free(ctx_out);
fprintf(stderr, "%s: %s merged from %d split with %d tensors.\n", fprintf(stderr, "%s: %s merged from %d split with %d tensors.\n",
__func__, split_params.output.c_str(), n_split, total_tensors); __func__, split_params.output.c_str(), n_split, total_tensors);

273
examples/llava/clip-impl.h Normal file
View file

@ -0,0 +1,273 @@
#include "ggml.h"
#include "gguf.h"
#include <climits>
#include <cstdarg>
#include <string>
#include <map>
#include <sstream>
#include <vector>
// Internal header for clip.cpp
#define KEY_FTYPE "general.file_type"
#define KEY_NAME "general.name"
#define KEY_DESCRIPTION "general.description"
#define KEY_HAS_TEXT_ENC "clip.has_text_encoder"
#define KEY_HAS_VIS_ENC "clip.has_vision_encoder"
#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector"
#define KEY_HAS_GLM_PROJ "clip.has_glm_projector"
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
#define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger"
#define KEY_USE_GELU "clip.use_gelu"
#define KEY_USE_SILU "clip.use_silu"
#define KEY_N_EMBD "clip.%s.embedding_length"
#define KEY_N_FF "clip.%s.feed_forward_length"
#define KEY_N_BLOCK "clip.%s.block_count"
#define KEY_N_HEAD "clip.%s.attention.head_count"
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
#define KEY_PROJ_DIM "clip.%s.projection_dim"
#define KEY_TOKENS "tokenizer.ggml.tokens"
#define KEY_N_POSITIONS "clip.text.context_length"
#define KEY_IMAGE_SIZE "clip.vision.image_size"
#define KEY_PATCH_SIZE "clip.vision.patch_size"
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
#define KEY_IMAGE_STD "clip.vision.image_std"
#define KEY_PROJ_TYPE "clip.projector_type"
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
//
// tensor name constants
//
#define TN_TOKEN_EMBD "%s.token_embd.weight"
#define TN_POS_EMBD "%s.position_embd.weight"
#define TN_CLASS_EMBD "v.class_embd"
#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat
#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1"
#define TN_PATCH_BIAS "v.patch_embd.bias"
#define TN_ATTN_K "%s.blk.%d.attn_k.%s"
#define TN_ATTN_Q "%s.blk.%d.attn_q.%s"
#define TN_ATTN_V "%s.blk.%d.attn_v.%s"
#define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s"
#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s"
#define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
#define TN_LN_1 "%s.blk.%d.ln1.%s"
#define TN_LN_2 "%s.blk.%d.ln2.%s"
#define TN_LN_PRE "%s.pre_ln.%s"
#define TN_LN_POST "%s.post_ln.%s"
#define TN_TEXT_PROJ "text_projection.weight"
#define TN_VIS_PROJ "visual_projection.weight"
#define TN_LLAVA_PROJ "mm.%d.%s"
#define TN_MVLM_PROJ_MLP "mm.model.mlp.%d.%s"
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
#define TN_IMAGE_NEWLINE "model.image_newline"
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
// mimicpmv
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
#define TN_MINICPMV_QUERY "resampler.query"
#define TN_MINICPMV_PROJ "resampler.proj.weight"
#define TN_MINICPMV_KV_PROJ "resampler.kv.weight"
#define TN_MINICPMV_ATTN "resampler.attn.%s.%s"
#define TN_MINICPMV_LN "resampler.ln_%s.%s"
#define TN_GLM_ADAPER_CONV "adapter.conv.%s"
#define TN_GLM_ADAPTER_LINEAR "adapter.linear.linear.%s"
#define TN_GLM_ADAPTER_NORM_1 "adapter.linear.norm1.%s"
#define TN_GLM_ADAPTER_D_H_2_4H "adapter.linear.dense_h_to_4h.%s"
#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s"
#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s"
#define TN_GLM_BOI_W "adapter.boi"
#define TN_GLM_EOI_W "adapter.eoi"
enum projector_type {
PROJECTOR_TYPE_MLP,
PROJECTOR_TYPE_MLP_NORM,
PROJECTOR_TYPE_LDP,
PROJECTOR_TYPE_LDPV2,
PROJECTOR_TYPE_RESAMPLER,
PROJECTOR_TYPE_GLM_EDGE,
PROJECTOR_TYPE_MERGER,
PROJECTOR_TYPE_GEMMA3,
PROJECTOR_TYPE_UNKNOWN,
};
static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_MLP, "mlp" },
{ PROJECTOR_TYPE_LDP, "ldp" },
{ PROJECTOR_TYPE_LDPV2, "ldpv2"},
{ PROJECTOR_TYPE_RESAMPLER, "resampler"},
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
};
static projector_type clip_projector_type_from_string(const std::string & str) {
for (const auto & pair : PROJECTOR_TYPE_NAMES) {
if (pair.second == str) {
return pair.first;
}
}
return PROJECTOR_TYPE_UNKNOWN;
}
//
// logging
//
static void clip_log_callback_default(enum ggml_log_level level, const char * text, void * user_data) {
(void) level;
(void) user_data;
fputs(text, stderr);
fflush(stderr);
}
struct clip_logger_state {
ggml_log_level verbosity_thold;
ggml_log_callback log_callback;
void * log_callback_user_data;
};
extern struct clip_logger_state g_logger_state;
static void clip_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
if (format == NULL) {
return;
}
va_list args_copy;
va_copy(args_copy, args);
char buffer[128];
int len = vsnprintf(buffer, 128, format, args);
if (len < 128) {
g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data);
} else {
char * buffer2 = (char *) calloc(len + 1, sizeof(char));
vsnprintf(buffer2, len + 1, format, args_copy);
buffer2[len] = 0;
g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data);
free(buffer2);
}
va_end(args_copy);
}
static void clip_log_internal(enum ggml_log_level level, const char * format, ...) {
va_list args;
va_start(args, format);
clip_log_internal_v(level, format, args);
va_end(args);
}
#define LOG_TMPL(level, ...) \
do { \
if ((level) >= g_logger_state.verbosity_thold) { \
clip_log_internal((level), __VA_ARGS__); \
} \
} while (0)
#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, __VA_ARGS__)
//
// common utils
//
static std::string string_format(const char * fmt, ...) {
va_list ap;
va_list ap2;
va_start(ap, fmt);
va_copy(ap2, ap);
int size = vsnprintf(NULL, 0, fmt, ap);
GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
std::vector<char> buf(size + 1);
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
GGML_ASSERT(size2 == size);
va_end(ap2);
va_end(ap);
return std::string(buf.data(), buf.size());
}
static void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
if (search.empty()) {
return;
}
std::string builder;
builder.reserve(s.length());
size_t pos = 0;
size_t last_pos = 0;
while ((pos = s.find(search, last_pos)) != std::string::npos) {
builder.append(s, last_pos, pos - last_pos);
builder.append(replace);
last_pos = pos + search.length();
}
builder.append(s, last_pos, std::string::npos);
s = std::move(builder);
}
//
// gguf utils
//
static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) {
switch (type) {
case GGUF_TYPE_UINT8: return std::to_string(((const uint8_t *)data)[i]);
case GGUF_TYPE_INT8: return std::to_string(((const int8_t *)data)[i]);
case GGUF_TYPE_UINT16: return std::to_string(((const uint16_t *)data)[i]);
case GGUF_TYPE_INT16: return std::to_string(((const int16_t *)data)[i]);
case GGUF_TYPE_UINT32: return std::to_string(((const uint32_t *)data)[i]);
case GGUF_TYPE_INT32: return std::to_string(((const int32_t *)data)[i]);
case GGUF_TYPE_UINT64: return std::to_string(((const uint64_t *)data)[i]);
case GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]);
case GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]);
case GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]);
case GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? "true" : "false";
default: return string_format("unknown type %d", type);
}
}
static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
switch (type) {
case GGUF_TYPE_STRING:
return gguf_get_val_str(ctx_gguf, i);
case GGUF_TYPE_ARRAY:
{
const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i);
int arr_n = gguf_get_arr_n(ctx_gguf, i);
const void * data = arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx_gguf, i);
std::stringstream ss;
ss << "[";
for (int j = 0; j < arr_n; j++) {
if (arr_type == GGUF_TYPE_STRING) {
std::string val = gguf_get_arr_str(ctx_gguf, i, j);
// escape quotes
string_replace_all(val, "\\", "\\\\");
string_replace_all(val, "\"", "\\\"");
ss << '"' << val << '"';
} else if (arr_type == GGUF_TYPE_ARRAY) {
ss << "???";
} else {
ss << gguf_data_to_str(arr_type, data, j);
}
if (j < arr_n - 1) {
ss << ", ";
}
}
ss << "]";
return ss.str();
}
default:
return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0);
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,7 @@
#ifndef CLIP_H #ifndef CLIP_H
#define CLIP_H #define CLIP_H
#include "ggml.h"
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
@ -41,7 +42,7 @@ struct clip_image_f32_batch {
struct clip_context_params { struct clip_context_params {
bool use_gpu; bool use_gpu;
int verbosity; ggml_log_level verbosity;
}; };
// deprecated, use clip_init // deprecated, use clip_init

View file

@ -10,7 +10,7 @@
#include <vector> #include <vector>
#include <limits.h> #include <limits.h>
#include <inttypes.h> #include <cinttypes>
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h> #include <signal.h>
@ -79,7 +79,11 @@ struct gemma3_context {
void init_clip_model(common_params & params) { void init_clip_model(common_params & params) {
const char * clip_path = params.mmproj.path.c_str(); const char * clip_path = params.mmproj.path.c_str();
ctx_clip = clip_model_load(clip_path, params.verbosity > 1); ctx_clip = clip_model_load(clip_path, GGML_LOG_LEVEL_INFO);
if (!ctx_clip) {
LOG_ERR("Failed to load CLIP model from %s\n", clip_path);
exit(1);
}
} }
~gemma3_context() { ~gemma3_context() {

View file

@ -241,7 +241,7 @@ static struct llava_context * llava_init_context(common_params * params, llama_m
prompt = "describe the image in detail."; prompt = "describe the image in detail.";
} }
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); auto ctx_clip = clip_model_load(clip_path, GGML_LOG_LEVEL_INFO);
llama_context_params ctx_params = common_context_params_to_llama(*params); llama_context_params ctx_params = common_context_params_to_llama(*params);
ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings

View file

@ -88,7 +88,7 @@ static struct clip_ctx * clip_init_context(common_params * params) {
} }
struct clip_context_params clip_params = { struct clip_context_params clip_params = {
/* use_gpu */ params->n_gpu_layers != 0, /* use_gpu */ params->n_gpu_layers != 0,
/* verbosity */ params->verbosity, /* verbosity */ GGML_LOG_LEVEL_INFO, // TODO: make this configurable
}; };
auto * ctx_clip = clip_init(clip_path, clip_params); auto * ctx_clip = clip_init(clip_path, clip_params);
return ctx_clip; return ctx_clip;

View file

@ -330,7 +330,7 @@ static struct llava_context * llava_init_context(common_params * params, llama_m
prompt = "describe the image in detail."; prompt = "describe the image in detail.";
} }
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); auto ctx_clip = clip_model_load(clip_path, GGML_LOG_LEVEL_INFO);
llama_context_params ctx_params = common_context_params_to_llama(*params); llama_context_params ctx_params = common_context_params_to_llama(*params);
ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings

BIN
examples/llava/test-1.jpeg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 121 KiB

81
examples/llava/tests.sh Executable file
View file

@ -0,0 +1,81 @@
#!/bin/bash
# make sure we are in the right directory
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
cd $SCRIPT_DIR
#export LLAMA_CACHE="$SCRIPT_DIR/tmp"
set -eux
mkdir -p $SCRIPT_DIR/output
PROJ_ROOT="$SCRIPT_DIR/../.."
cd $PROJ_ROOT
###############
arr_bin=()
arr_hf=()
add_test() {
local bin=$1
local hf=$2
arr_bin+=("$bin")
arr_hf+=("$hf")
}
add_test "llama-gemma3-cli" "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"
add_test "llama-llava-cli" "cmp-nct/Yi-VL-6B-GGUF:Q5_K"
add_test "llama-llava-cli" "guinmoon/MobileVLM-3B-GGUF:Q4_K_M"
add_test "llama-llava-cli" "THUDM/glm-edge-v-5b-gguf:Q4_K_M"
add_test "llama-llava-cli" "second-state/Llava-v1.5-7B-GGUF:Q2_K"
add_test "llama-llava-cli" "cjpais/llava-1.6-mistral-7b-gguf:Q3_K"
add_test "llama-llava-cli" "ibm-research/granite-vision-3.2-2b-GGUF:Q4_K_M"
add_test "llama-minicpmv-cli" "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K" # model from openbmb is corrupted
add_test "llama-minicpmv-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
add_test "llama-minicpmv-cli" "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
add_test "llama-qwen2vl-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
###############
cmake --build build -j --target "${arr_bin[@]}"
arr_res=()
for i in "${!arr_bin[@]}"; do
bin="${arr_bin[$i]}"
hf="${arr_hf[$i]}"
echo "Running test with binary: $bin and HF model: $hf"
echo ""
echo ""
output=$("$PROJ_ROOT/build/bin/$bin" -hf "$hf" --image $SCRIPT_DIR/test-1.jpeg -p "what is the publisher name of the newspaper?" --temp 0 2>&1 | tee /dev/tty)
echo "$output" > $SCRIPT_DIR/output/$bin-$(echo "$hf" | tr '/' '-').log
if echo "$output" | grep -iq "new york"; then
result="\033[32mOK\033[0m: $bin $hf"
else
result="\033[31mFAIL\033[0m: $bin $hf"
fi
echo -e "$result"
arr_res+=("$result")
echo ""
echo ""
echo ""
echo "#################################################"
echo "#################################################"
echo ""
echo ""
done
set +x
for i in "${!arr_res[@]}"; do
echo -e "${arr_res[$i]}"
done
echo ""
echo "Output logs are saved in $SCRIPT_DIR/output"

Binary file not shown.

File diff suppressed because it is too large Load diff

View file

@ -13,9 +13,11 @@
"dependencies": { "dependencies": {
"@heroicons/react": "^2.2.0", "@heroicons/react": "^2.2.0",
"@sec-ant/readable-stream": "^0.6.0", "@sec-ant/readable-stream": "^0.6.0",
"@tailwindcss/postcss": "^4.1.1",
"@tailwindcss/vite": "^4.1.1",
"@vscode/markdown-it-katex": "^1.1.1", "@vscode/markdown-it-katex": "^1.1.1",
"autoprefixer": "^10.4.20", "autoprefixer": "^10.4.20",
"daisyui": "^4.12.14", "daisyui": "^5.0.12",
"dexie": "^4.0.11", "dexie": "^4.0.11",
"highlight.js": "^11.10.0", "highlight.js": "^11.10.0",
"katex": "^0.16.15", "katex": "^0.16.15",
@ -29,7 +31,7 @@
"remark-breaks": "^4.0.0", "remark-breaks": "^4.0.0",
"remark-gfm": "^4.0.0", "remark-gfm": "^4.0.0",
"remark-math": "^6.0.0", "remark-math": "^6.0.0",
"tailwindcss": "^3.4.15", "tailwindcss": "^4.1.1",
"textlinestream": "^1.1.1", "textlinestream": "^1.1.1",
"vite-plugin-singlefile": "^2.0.3" "vite-plugin-singlefile": "^2.0.3"
}, },

View file

@ -1,6 +1,5 @@
export default { export default {
plugins: { plugins: {
tailwindcss: {}, "@tailwindcss/postcss": {},
autoprefixer: {},
}, },
} }

View file

@ -28,7 +28,7 @@ function AppLayout() {
<> <>
<Sidebar /> <Sidebar />
<div <div
className="drawer-content grow flex flex-col h-screen w-screen mx-auto px-4 overflow-auto" className="drawer-content grow flex flex-col h-screen w-screen mx-auto px-4 overflow-auto bg-base-100"
id="main-scroll" id="main-scroll"
> >
<Header /> <Header />

View file

@ -1,4 +1,4 @@
import daisyuiThemes from 'daisyui/src/theming/themes'; import daisyuiThemes from 'daisyui/theme/object';
import { isNumeric } from './utils/misc'; import { isNumeric } from './utils/misc';
export const isDev = import.meta.env.MODE === 'development'; export const isDev = import.meta.env.MODE === 'development';

View file

@ -2,7 +2,7 @@ import { useEffect, useState } from 'react';
import StorageUtils from '../utils/storage'; import StorageUtils from '../utils/storage';
import { useAppContext } from '../utils/app.context'; import { useAppContext } from '../utils/app.context';
import { classNames } from '../utils/misc'; import { classNames } from '../utils/misc';
import daisyuiThemes from 'daisyui/src/theming/themes'; import daisyuiThemes from 'daisyui/theme/object';
import { THEMES } from '../Config'; import { THEMES } from '../Config';
import { useNavigate } from 'react-router'; import { useNavigate } from 'react-router';
@ -20,7 +20,6 @@ export default function Header() {
document.body.setAttribute('data-theme', selectedTheme); document.body.setAttribute('data-theme', selectedTheme);
document.body.setAttribute( document.body.setAttribute(
'data-color-scheme', 'data-color-scheme',
// @ts-expect-error daisyuiThemes complains about index type, but it should work
daisyuiThemes[selectedTheme]?.['color-scheme'] ?? 'auto' daisyuiThemes[selectedTheme]?.['color-scheme'] ?? 'auto'
); );
}, [selectedTheme]); }, [selectedTheme]);

View file

@ -1,8 +1,13 @@
@use 'sass:meta'; @use 'sass:meta';
@use 'tailwindcss';
@tailwind base; @plugin 'daisyui' {
@tailwind components; themes: all;
@tailwind utilities; }
html {
scrollbar-gutter: auto;
}
.markdown { .markdown {
h1, h1,

View file

@ -2475,7 +2475,12 @@ int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
} }
int32_t llama_kv_self_n_tokens(const llama_context * ctx) { int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
return llama_kv_cache_n_tokens(ctx->get_kv_self()); const auto * kv = ctx->get_kv_self();
if (!kv) {
return 0;
}
return kv->get_n_tokens();
} }
// deprecated // deprecated
@ -2484,7 +2489,12 @@ int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
} }
int32_t llama_kv_self_used_cells(const llama_context * ctx) { int32_t llama_kv_self_used_cells(const llama_context * ctx) {
return llama_kv_cache_used_cells(ctx->get_kv_self()); const auto * kv = ctx->get_kv_self();
if (!kv) {
return 0;
}
return kv->get_used_cells();
} }
// deprecated // deprecated
@ -2493,7 +2503,12 @@ void llama_kv_cache_clear(llama_context * ctx) {
} }
void llama_kv_self_clear(llama_context * ctx) { void llama_kv_self_clear(llama_context * ctx) {
llama_kv_cache_clear(ctx->get_kv_self()); auto * kv = ctx->get_kv_self();
if (!kv) {
return;
}
kv->clear();
} }
// deprecated // deprecated
@ -2510,7 +2525,12 @@ bool llama_kv_self_seq_rm(
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1) { llama_pos p1) {
return llama_kv_cache_seq_rm(ctx->get_kv_self(), seq_id, p0, p1); auto * kv = ctx->get_kv_self();
if (!kv) {
return true;
}
return kv->seq_rm(seq_id, p0, p1);
} }
// deprecated // deprecated
@ -2529,7 +2549,12 @@ void llama_kv_self_seq_cp(
llama_seq_id seq_id_dst, llama_seq_id seq_id_dst,
llama_pos p0, llama_pos p0,
llama_pos p1) { llama_pos p1) {
return llama_kv_cache_seq_cp(ctx->get_kv_self(), seq_id_src, seq_id_dst, p0, p1); auto * kv = ctx->get_kv_self();
if (!kv) {
return;
}
return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
} }
// deprecated // deprecated
@ -2540,7 +2565,12 @@ void llama_kv_cache_seq_keep(
} }
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_cache_seq_keep(ctx->get_kv_self(), seq_id); auto * kv = ctx->get_kv_self();
if (!kv) {
return;
}
return kv->seq_keep(seq_id);
} }
// deprecated // deprecated
@ -2559,7 +2589,12 @@ void llama_kv_self_seq_add(
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
llama_pos delta) { llama_pos delta) {
return llama_kv_cache_seq_add(ctx->get_kv_self(), seq_id, p0, p1, delta); auto * kv = ctx->get_kv_self();
if (!kv) {
return;
}
return kv->seq_add(seq_id, p0, p1, delta);
} }
// deprecated // deprecated
@ -2578,7 +2613,12 @@ void llama_kv_self_seq_div(
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
int d) { int d) {
return llama_kv_cache_seq_div(ctx->get_kv_self(), seq_id, p0, p1, d); auto * kv = ctx->get_kv_self();
if (!kv) {
return;
}
return kv->seq_div(seq_id, p0, p1, d);
} }
// deprecated // deprecated
@ -2587,7 +2627,12 @@ llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
} }
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_cache_seq_pos_max(ctx->get_kv_self(), seq_id); const auto * kv = ctx->get_kv_self();
if (!kv) {
return 0;
}
return kv->seq_pos_max(seq_id);
} }
// deprecated // deprecated
@ -2596,7 +2641,12 @@ void llama_kv_cache_defrag(llama_context * ctx) {
} }
void llama_kv_self_defrag(llama_context * ctx) { void llama_kv_self_defrag(llama_context * ctx) {
llama_kv_cache_defrag(ctx->get_kv_self()); auto * kv = ctx->get_kv_self();
if (!kv) {
return;
}
return kv->defrag();
} }
// deprecated // deprecated
@ -2605,7 +2655,12 @@ bool llama_kv_cache_can_shift(const llama_context * ctx) {
} }
bool llama_kv_self_can_shift(const llama_context * ctx) { bool llama_kv_self_can_shift(const llama_context * ctx) {
return llama_kv_cache_can_shift(ctx->get_kv_self()); const auto * kv = ctx->get_kv_self();
if (!kv) {
return false;
}
return kv->get_can_shift();
} }
// deprecated // deprecated

View file

@ -131,7 +131,7 @@ int32_t llama_kv_cache_unified::get_n_tokens() const {
return result; return result;
} }
uint32_t llama_kv_cache_unified::get_used_cells() const { int32_t llama_kv_cache_unified::get_used_cells() const {
return used; return used;
} }
@ -428,7 +428,7 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
} }
} }
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) { llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
llama_pos result = 0; llama_pos result = 0;
for (uint32_t i = 0; i < size; ++i) { for (uint32_t i = 0; i < size; ++i) {
@ -481,6 +481,11 @@ void llama_kv_cache_unified::restore() {
} }
void llama_kv_cache_unified::commit() { void llama_kv_cache_unified::commit() {
// TODO: tmp - move to llama_kv_cache_recurrent
if (recurrent) {
return;
}
if (pending.ranges.empty()) { if (pending.ranges.empty()) {
LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n", LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
__func__, "https://github.com/ggml-org/llama.cpp/pull/12695"); __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
@ -1273,117 +1278,6 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
return true; return true;
} }
//
// interface implementation
//
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
if (!kv) {
return 0;
}
return kv->get_n_tokens();
}
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
if (!kv) {
return 0;
}
return kv->get_used_cells();
}
void llama_kv_cache_clear(llama_kv_cache * kv) {
if (!kv) {
return;
}
kv->clear();
}
bool llama_kv_cache_seq_rm(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
if (!kv) {
return true;
}
return kv->seq_rm(seq_id, p0, p1);
}
void llama_kv_cache_seq_cp(
llama_kv_cache * kv,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
if (!kv) {
return;
}
kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}
void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id) {
if (!kv) {
return;
}
kv->seq_keep(seq_id);
}
void llama_kv_cache_seq_add(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta) {
if (!kv) {
return;
}
kv->seq_add(seq_id, p0, p1, delta);
}
void llama_kv_cache_seq_div(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d) {
if (!kv) {
return;
}
kv->seq_div(seq_id, p0, p1, d);
}
llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id) {
if (!kv) {
return 0;
}
return kv->seq_pos_max(seq_id);
}
void llama_kv_cache_defrag(llama_kv_cache * kv) {
if (!kv) {
return;
}
kv->defrag();
}
bool llama_kv_cache_can_shift(const llama_kv_cache * kv) {
if (!kv) {
return false;
}
return kv->get_can_shift();
}
// //
// kv cache view // kv cache view
// //
@ -1393,7 +1287,7 @@ llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t
/*.n_cells = */ 0, /*.n_cells = */ 0,
/*.n_seq_max = */ n_seq_max, /*.n_seq_max = */ n_seq_max,
/*.token_count = */ 0, /*.token_count = */ 0,
/*.used_cells = */ llama_kv_cache_used_cells(&kv), /*.used_cells = */ kv.get_used_cells(),
/*.max_contiguous = */ 0, /*.max_contiguous = */ 0,
/*.max_contiguous_idx = */ -1, /*.max_contiguous_idx = */ -1,
/*.cells = */ nullptr, /*.cells = */ nullptr,

View file

@ -21,7 +21,7 @@ struct llama_kv_cache : public llama_memory_i {
virtual void commit() = 0; // call after successful batch processing - clears any pending state virtual void commit() = 0; // call after successful batch processing - clears any pending state
virtual int32_t get_n_tokens() const = 0; virtual int32_t get_n_tokens() const = 0;
virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
virtual bool get_can_shift() const = 0; virtual bool get_can_shift() const = 0;
@ -90,7 +90,7 @@ public:
bool offload); bool offload);
int32_t get_n_tokens() const override; int32_t get_n_tokens() const override;
uint32_t get_used_cells() const override; int32_t get_used_cells() const override;
size_t total_size() const; size_t total_size() const;
@ -109,7 +109,7 @@ public:
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_max(llama_seq_id seq_id) override; llama_pos seq_pos_max(llama_seq_id seq_id) const override;
bool get_can_shift() const override; bool get_can_shift() const override;
@ -204,48 +204,6 @@ private:
// using llama_kv_cache_unified::llama_kv_cache_unified; // using llama_kv_cache_unified::llama_kv_cache_unified;
//}; //};
// TODO: maybe become part of the public llama_kv_cache in the future
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv);
void llama_kv_cache_clear(llama_kv_cache * kv);
bool llama_kv_cache_seq_rm(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1);
void llama_kv_cache_seq_cp(
llama_kv_cache * kv,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1);
void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id);
void llama_kv_cache_seq_add(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta);
void llama_kv_cache_seq_div(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d);
llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id);
void llama_kv_cache_defrag(llama_kv_cache * kv);
bool llama_kv_cache_can_shift(const llama_kv_cache * kv);
// //
// kv cache view // kv cache view
// //

View file

@ -15,7 +15,7 @@ public:
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0; virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0; virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0; virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
virtual bool get_can_edit() const = 0; virtual bool get_can_edit() const = 0;
}; };