From 7fc1c4ef78d72bf2753fe2cff04ae03846bfff62 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 21 Apr 2026 17:24:55 +0300 Subject: [PATCH 01/20] metal : workaround macOS GPU interactivity watchdog (#22216) --- ggml/src/ggml-metal/ggml-metal.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 4dbf8e6fe..6a836e459 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -918,6 +918,10 @@ ggml_backend_reg_t ggml_backend_metal_reg(void) { static std::vector devs; if (!initialized) { + // workaround macOS limitation (kIOGPUCommandBufferCallbackErrorImpactingInteractivity) until proper fix becomes possible + // ref: https://github.com/ggml-org/llama.cpp/issues/20141#issuecomment-4272947703 + setenv("AGX_RELAX_CDM_CTXSTORE_TIMEOUT", "1", true); + static ggml_backend_metal_reg_ptr reg_ctx(ggml_backend_metal_reg_init()); for (int i = 0; i < g_devices; ++i) { From 606fa42f5de58ef74e0eb9a0a55bb6c0a689e2eb Mon Sep 17 00:00:00 2001 From: "Alessandro de Oliveira Faria (A.K.A.CABELO)" Date: Tue, 21 Apr 2026 11:45:48 -0300 Subject: [PATCH 02/20] vendor : update cpp-httplib to 0.43.1 (#22143) * vendor : update cpp-httplib to 0.43.0 * vendor : update cpp-httplib to 0.43.0 --- scripts/sync_vendor.py | 2 +- vendor/cpp-httplib/httplib.cpp | 595 ++++++++------------------------- vendor/cpp-httplib/httplib.h | 139 ++------ 3 files changed, 179 insertions(+), 557 deletions(-) diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 7ce093323..ff1dd0753 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -5,7 +5,7 @@ import os import sys import subprocess -HTTPLIB_VERSION = "refs/tags/v0.42.0" +HTTPLIB_VERSION = "refs/tags/v0.43.1" vendor = { "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp", diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index 1cd71c2ec..95bf0eb1b 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -872,7 +872,8 @@ bool write_websocket_frame(Stream &strm, ws::Opcode opcode, if (strm.write(reinterpret_cast(header), 2) < 0) { return false; } uint8_t ext[8]; for (int i = 7; i >= 0; i--) { - ext[7 - i] = static_cast((len >> (i * 8)) & 0xFF); + ext[7 - i] = + static_cast((static_cast(len) >> (i * 8)) & 0xFF); } if (strm.write(reinterpret_cast(ext), 8) < 0) { return false; } } @@ -1034,10 +1035,15 @@ bool canonicalize_path(const char *path, std::string &resolved) { char buf[_MAX_PATH]; if (_fullpath(buf, path, _MAX_PATH) == nullptr) { return false; } resolved = buf; -#else +#elif defined(PATH_MAX) char buf[PATH_MAX]; if (realpath(path, buf) == nullptr) { return false; } resolved = buf; +#else + auto buf = realpath(path, nullptr); + auto guard = scope_exit([&]() { std::free(buf); }); + if (buf == nullptr) { return false; } + resolved = buf; #endif return true; } @@ -2765,6 +2771,35 @@ EncodingType encoding_type(const Request &req, const Response &res) { return best; } +std::unique_ptr make_compressor(EncodingType type) { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (type == EncodingType::Gzip) { + return detail::make_unique(); + } +#endif +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + if (type == EncodingType::Brotli) { + return detail::make_unique(); + } +#endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + if (type == EncodingType::Zstd) { + return detail::make_unique(); + } +#endif + (void)type; + return nullptr; +} + +const char *encoding_name(EncodingType type) { + 
switch (type) { + case EncodingType::Gzip: return "gzip"; + case EncodingType::Brotli: return "br"; + case EncodingType::Zstd: return "zstd"; + default: return ""; + } +} + bool nocompressor::compress(const char *data, size_t data_length, bool /*last*/, Callback callback) { if (!data_length) { return true; } @@ -3097,6 +3132,29 @@ const char *get_header_value(const Headers &headers, return def; } +size_t get_header_value_count(const Headers &headers, + const std::string &key) { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +template +typename Map::mapped_type +get_multimap_value(const Map &m, const std::string &key, size_t id) { + auto rng = m.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second; } + return typename Map::mapped_type(); +} + +void set_header(Headers &headers, const std::string &key, + const std::string &val) { + if (fields::is_field_name(key) && fields::is_field_value(val)) { + headers.emplace(key, val); + } +} + bool read_headers(Stream &strm, Headers &headers) { const auto bufsiz = 2048; char buf[bufsiz]; @@ -5791,16 +5849,12 @@ std::string Request::get_header_value(const std::string &key, } size_t Request::get_header_value_count(const std::string &key) const { - auto r = headers.equal_range(key); - return static_cast(std::distance(r.first, r.second)); + return detail::get_header_value_count(headers, key); } void Request::set_header(const std::string &key, const std::string &val) { - if (detail::fields::is_field_name(key) && - detail::fields::is_field_value(val)) { - headers.emplace(key, val); - } + detail::set_header(headers, key, val); } bool Request::has_trailer(const std::string &key) const { @@ -5809,11 +5863,7 @@ bool Request::has_trailer(const std::string &key) const { std::string Request::get_trailer_value(const std::string &key, size_t id) const { - auto rng = trailers.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { return it->second; } - return std::string(); + return detail::get_multimap_value(trailers, key, id); } size_t Request::get_trailer_value_count(const std::string &key) const { @@ -5827,11 +5877,7 @@ bool Request::has_param(const std::string &key) const { std::string Request::get_param_value(const std::string &key, size_t id) const { - auto rng = params.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { return it->second; } - return std::string(); + return detail::get_multimap_value(params, key, id); } std::vector @@ -5886,11 +5932,7 @@ size_t MultipartFormData::get_field_count(const std::string &key) const { FormData MultipartFormData::get_file(const std::string &key, size_t id) const { - auto rng = files.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { return it->second; } - return FormData(); + return detail::get_multimap_value(files, key, id); } std::vector @@ -5929,16 +5971,12 @@ std::string Response::get_header_value(const std::string &key, } size_t Response::get_header_value_count(const std::string &key) const { - auto r = headers.equal_range(key); - return static_cast(std::distance(r.first, r.second)); + return detail::get_header_value_count(headers, key); } void Response::set_header(const std::string &key, const std::string &val) { - if (detail::fields::is_field_name(key) && - detail::fields::is_field_value(val)) { - headers.emplace(key, val); - } + 
detail::set_header(headers, key, val); } bool Response::has_trailer(const std::string &key) const { return trailers.find(key) != trailers.end(); @@ -5946,11 +5984,7 @@ bool Response::has_trailer(const std::string &key) const { std::string Response::get_trailer_value(const std::string &key, size_t id) const { - auto rng = trailers.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { return it->second; } - return std::string(); + return detail::get_multimap_value(trailers, key, id); } size_t Response::get_trailer_value_count(const std::string &key) const { @@ -6253,15 +6287,6 @@ void ThreadPool::worker(bool is_dynamic) { assert(true == static_cast(fn)); fn(); - - // Dynamic thread: exit if queue is empty after task completion - if (is_dynamic) { - std::unique_lock lock(mutex_); - if (jobs_.empty()) { - move_to_finished(std::this_thread::get_id()); - break; - } - } } #if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \ @@ -6791,61 +6816,51 @@ Server::make_matcher(const std::string &pattern) { } Server &Server::Get(const std::string &pattern, Handler handler) { - get_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); - return *this; + return add_handler(get_handlers_, pattern, std::move(handler)); } Server &Server::Post(const std::string &pattern, Handler handler) { - post_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); - return *this; + return add_handler(post_handlers_, pattern, std::move(handler)); } Server &Server::Post(const std::string &pattern, HandlerWithContentReader handler) { - post_handlers_for_content_reader_.emplace_back(make_matcher(pattern), - std::move(handler)); - return *this; + return add_handler(post_handlers_for_content_reader_, pattern, + std::move(handler)); } Server &Server::Put(const std::string &pattern, Handler handler) { - put_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); - return *this; + return add_handler(put_handlers_, pattern, std::move(handler)); } Server &Server::Put(const std::string &pattern, HandlerWithContentReader handler) { - put_handlers_for_content_reader_.emplace_back(make_matcher(pattern), - std::move(handler)); - return *this; + return add_handler(put_handlers_for_content_reader_, pattern, + std::move(handler)); } Server &Server::Patch(const std::string &pattern, Handler handler) { - patch_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); - return *this; + return add_handler(patch_handlers_, pattern, std::move(handler)); } Server &Server::Patch(const std::string &pattern, HandlerWithContentReader handler) { - patch_handlers_for_content_reader_.emplace_back(make_matcher(pattern), - std::move(handler)); - return *this; + return add_handler(patch_handlers_for_content_reader_, pattern, + std::move(handler)); } Server &Server::Delete(const std::string &pattern, Handler handler) { - delete_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); - return *this; + return add_handler(delete_handlers_, pattern, std::move(handler)); } Server &Server::Delete(const std::string &pattern, HandlerWithContentReader handler) { - delete_handlers_for_content_reader_.emplace_back(make_matcher(pattern), - std::move(handler)); - return *this; + return add_handler(delete_handlers_for_content_reader_, pattern, + std::move(handler)); } Server &Server::Options(const std::string &pattern, Handler handler) { - options_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); - return *this; + return 
add_handler(options_handlers_, pattern, std::move(handler)); } Server &Server::WebSocket(const std::string &pattern, @@ -7054,6 +7069,11 @@ Server &Server::set_payload_max_length(size_t length) { return *this; } +Server &Server::set_websocket_max_missed_pongs(int count) { + websocket_max_missed_pongs_ = count; + return *this; +} + Server &Server::set_websocket_ping_interval(time_t sec) { websocket_ping_interval_sec_ = sec; return *this; @@ -7279,23 +7299,10 @@ Server::write_content_with_provider(Stream &strm, const Request &req, if (res.is_chunked_content_provider_) { auto type = detail::encoding_type(req, res); - std::unique_ptr compressor; - if (type == detail::EncodingType::Gzip) { -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - compressor = detail::make_unique(); -#endif - } else if (type == detail::EncodingType::Brotli) { -#ifdef CPPHTTPLIB_BROTLI_SUPPORT - compressor = detail::make_unique(); -#endif - } else if (type == detail::EncodingType::Zstd) { -#ifdef CPPHTTPLIB_ZSTD_SUPPORT - compressor = detail::make_unique(); -#endif - } else { + auto compressor = detail::make_compressor(type); + if (!compressor) { compressor = detail::make_unique(); } - assert(compressor != nullptr); return detail::write_content_chunked(strm, res.content_provider_, is_shutting_down, *compressor); @@ -7917,14 +7924,8 @@ void Server::apply_ranges(const Request &req, Response &res, if (res.content_provider_) { if (res.is_chunked_content_provider_) { res.set_header("Transfer-Encoding", "chunked"); - if (type == detail::EncodingType::Gzip) { - res.set_header("Content-Encoding", "gzip"); - res.set_header("Vary", "Accept-Encoding"); - } else if (type == detail::EncodingType::Brotli) { - res.set_header("Content-Encoding", "br"); - res.set_header("Vary", "Accept-Encoding"); - } else if (type == detail::EncodingType::Zstd) { - res.set_header("Content-Encoding", "zstd"); + if (type != detail::EncodingType::None) { + res.set_header("Content-Encoding", detail::encoding_name(type)); res.set_header("Vary", "Accept-Encoding"); } } @@ -7955,27 +7956,7 @@ void Server::apply_ranges(const Request &req, Response &res, if (type != detail::EncodingType::None) { output_pre_compression_log(req, res); - std::unique_ptr compressor; - std::string content_encoding; - - if (type == detail::EncodingType::Gzip) { -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - compressor = detail::make_unique(); - content_encoding = "gzip"; -#endif - } else if (type == detail::EncodingType::Brotli) { -#ifdef CPPHTTPLIB_BROTLI_SUPPORT - compressor = detail::make_unique(); - content_encoding = "br"; -#endif - } else if (type == detail::EncodingType::Zstd) { -#ifdef CPPHTTPLIB_ZSTD_SUPPORT - compressor = detail::make_unique(); - content_encoding = "zstd"; -#endif - } - - if (compressor) { + if (auto compressor = detail::make_compressor(type)) { std::string compressed; if (compressor->compress(res.body.data(), res.body.size(), true, [&](const char *data, size_t data_len) { @@ -7983,7 +7964,7 @@ void Server::apply_ranges(const Request &req, Response &res, return true; })) { res.body.swap(compressed); - res.set_header("Content-Encoding", content_encoding); + res.set_header("Content-Encoding", detail::encoding_name(type)); res.set_header("Vary", "Accept-Encoding"); } } @@ -8231,7 +8212,8 @@ Server::process_request(Stream &strm, const std::string &remote_addr, { // Use WebSocket-specific read timeout instead of HTTP timeout strm.set_read_timeout(CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND, 0); - ws::WebSocket ws(strm, req, true, websocket_ping_interval_sec_); + ws::WebSocket ws(strm, req, 
true, websocket_ping_interval_sec_, + websocket_max_missed_pongs_); entry.handler(req, ws); } return true; @@ -10822,38 +10804,6 @@ void ClientImpl::enable_server_hostname_verification(bool enabled) { } #endif -// ClientImpl::set_ca_cert_store is defined after TLS namespace (uses helpers) -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert, - std::size_t size) const { - auto mem = BIO_new_mem_buf(ca_cert, static_cast(size)); - auto se = detail::scope_exit([&] { BIO_free_all(mem); }); - if (!mem) { return nullptr; } - - auto inf = PEM_X509_INFO_read_bio(mem, nullptr, nullptr, nullptr); - if (!inf) { return nullptr; } - - auto cts = X509_STORE_new(); - if (cts) { - for (auto i = 0; i < static_cast(sk_X509_INFO_num(inf)); i++) { - auto itmp = sk_X509_INFO_value(inf, i); - if (!itmp) { continue; } - - if (itmp->x509) { X509_STORE_add_cert(cts, itmp->x509); } - if (itmp->crl) { X509_STORE_add_crl(cts, itmp->crl); } - } - } - - sk_X509_INFO_pop_free(inf, X509_INFO_free); - return cts; -} - -void ClientImpl::set_server_certificate_verifier( - std::function /*verifier*/) { - // Base implementation does nothing - SSLClient overrides this -} -#endif - void ClientImpl::set_logger(Logger logger) { logger_ = std::move(logger); } @@ -10927,10 +10877,10 @@ Client::Client(const std::string &scheme_host_port, cli_ = detail::make_unique(scheme_host_port, 80, client_cert_path, client_key_path); } -} // namespace detail +} Client::Client(const std::string &host, int port) - : cli_(detail::make_unique(host, port)) {} + : Client(host, port, std::string(), std::string()) {} Client::Client(const std::string &host, int port, const std::string &client_cert_path, @@ -11505,12 +11455,6 @@ void Client::set_follow_location(bool on) { void Client::set_path_encode(bool on) { cli_->set_path_encode(on); } -[[deprecated("Use set_path_encode() instead. " - "This function will be removed by v1.0.0.")]] -void Client::set_url_encode(bool on) { - cli_->set_path_encode(on); -} - void Client::set_compress(bool on) { cli_->set_compress(on); } void Client::set_decompress(bool on) { cli_->set_decompress(on); } @@ -11893,24 +11837,31 @@ SSLClient::SSLClient(const std::string &host) SSLClient::SSLClient(const std::string &host, int port) : SSLClient(host, port, std::string(), std::string()) {} +void SSLClient::init_ctx() { + ctx_ = tls::create_client_context(); + if (ctx_) { tls::set_min_version(ctx_, tls::Version::TLS1_2); } +} + +void SSLClient::reset_ctx_on_error() { + last_backend_error_ = tls::get_error(); + tls::free_context(ctx_); + ctx_ = nullptr; +} + SSLClient::SSLClient(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path, const std::string &private_key_password) : ClientImpl(host, port, client_cert_path, client_key_path) { - ctx_ = tls::create_client_context(); + init_ctx(); if (!ctx_) { return; } - tls::set_min_version(ctx_, tls::Version::TLS1_2); - if (!client_cert_path.empty() && !client_key_path.empty()) { const char *password = private_key_password.empty() ? 
nullptr : private_key_password.c_str(); if (!tls::set_client_cert_file(ctx_, client_cert_path.c_str(), client_key_path.c_str(), password)) { - last_backend_error_ = tls::get_error(); - tls::free_context(ctx_); - ctx_ = nullptr; + reset_ctx_on_error(); } } } @@ -11918,17 +11869,13 @@ SSLClient::SSLClient(const std::string &host, int port, SSLClient::SSLClient(const std::string &host, int port, const PemMemory &pem) : ClientImpl(host, port) { - ctx_ = tls::create_client_context(); + init_ctx(); if (!ctx_) { return; } - tls::set_min_version(ctx_, tls::Version::TLS1_2); - if (pem.cert_pem && pem.key_pem) { if (!tls::set_client_cert_pem(ctx_, pem.cert_pem, pem.key_pem, pem.private_key_password)) { - last_backend_error_ = tls::get_error(); - tls::free_context(ctx_); - ctx_ = nullptr; + reset_ctx_on_error(); } } } @@ -12479,41 +12426,6 @@ std::string Request::sni() const { * Group 8: TLS abstraction layer - OpenSSL backend */ -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -// These wrappers forward to deprecated APIs that will be removed by v1.0.0. -// Suppress C4996 / -Wdeprecated-declarations so that MSVC /sdl builds (which -// promote C4996 to an error) compile cleanly even though the wrappers -// themselves are also marked [[deprecated]]. -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#elif defined(__GNUC__) || defined(__clang__) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" -#endif - -SSL_CTX *Client::ssl_context() const { - if (is_ssl_) { return static_cast(*cli_).ssl_context(); } - return nullptr; -} - -void Client::set_server_certificate_verifier( - std::function verifier) { - cli_->set_server_certificate_verifier(verifier); -} - -long Client::get_verify_result() const { - if (is_ssl_) { return static_cast(*cli_).get_verify_result(); } - return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? 
-} - -#if defined(_MSC_VER) -#pragma warning(pop) -#elif defined(__GNUC__) || defined(__clang__) -#pragma GCC diagnostic pop -#endif -#endif // CPPHTTPLIB_OPENSSL_SUPPORT - /* * OpenSSL Backend Implementation */ @@ -12523,54 +12435,6 @@ namespace tls { namespace impl { -// OpenSSL-specific helpers for converting native types to PEM -std::string x509_to_pem(X509 *cert) { - if (!cert) return {}; - BIO *bio = BIO_new(BIO_s_mem()); - if (!bio) return {}; - if (PEM_write_bio_X509(bio, cert) != 1) { - BIO_free(bio); - return {}; - } - char *data = nullptr; - long len = BIO_get_mem_data(bio, &data); - std::string pem(data, static_cast(len)); - BIO_free(bio); - return pem; -} - -std::string evp_pkey_to_pem(EVP_PKEY *key) { - if (!key) return {}; - BIO *bio = BIO_new(BIO_s_mem()); - if (!bio) return {}; - if (PEM_write_bio_PrivateKey(bio, key, nullptr, nullptr, 0, nullptr, - nullptr) != 1) { - BIO_free(bio); - return {}; - } - char *data = nullptr; - long len = BIO_get_mem_data(bio, &data); - std::string pem(data, static_cast(len)); - BIO_free(bio); - return pem; -} - -std::string x509_store_to_pem(X509_STORE *store) { - if (!store) return {}; - std::string pem; - auto objs = X509_STORE_get0_objects(store); - if (!objs) return {}; - auto count = sk_X509_OBJECT_num(objs); - for (decltype(count) i = 0; i < count; i++) { - auto obj = sk_X509_OBJECT_value(objs, i); - if (X509_OBJECT_get_type(obj) == X509_LU_X509) { - auto cert = X509_OBJECT_get0_X509(obj); - if (cert) { pem += x509_to_pem(cert); } - } - } - return pem; -} - // Helper to map OpenSSL SSL_get_error to ErrorCode ErrorCode map_ssl_error(int ssl_error, int &out_errno) { switch (ssl_error) { @@ -12603,8 +12467,10 @@ STACK_OF(X509_NAME) * X509 *cert = nullptr; while ((cert = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr)) != nullptr) { - X509_NAME *name = X509_get_subject_name(cert); - if (name) { sk_X509_NAME_push(ca_list, X509_NAME_dup(name)); } + const X509_NAME *name = X509_get_subject_name(cert); + if (name) { + sk_X509_NAME_push(ca_list, X509_NAME_dup(const_cast(name))); + } X509_free(cert); } BIO_free(bio); @@ -12612,45 +12478,6 @@ STACK_OF(X509_NAME) * return ca_list; } -// Helper: Extract CA names from X509_STORE -// Returns a new STACK_OF(X509_NAME)* or nullptr on failure -// Caller takes ownership of returned list -STACK_OF(X509_NAME) * - extract_client_ca_list_from_store(X509_STORE *store) { - if (!store) { return nullptr; } - - auto ca_list = sk_X509_NAME_new_null(); - if (!ca_list) { return nullptr; } - - auto objs = X509_STORE_get0_objects(store); - if (!objs) { - sk_X509_NAME_free(ca_list); - return nullptr; - } - - auto count = sk_X509_OBJECT_num(objs); - for (decltype(count) i = 0; i < count; i++) { - auto obj = sk_X509_OBJECT_value(objs, i); - if (X509_OBJECT_get_type(obj) == X509_LU_X509) { - auto cert = X509_OBJECT_get0_X509(obj); - if (cert) { - auto subject = X509_get_subject_name(cert); - if (subject) { - auto name_dup = X509_NAME_dup(subject); - if (name_dup) { sk_X509_NAME_push(ca_list, name_dup); } - } - } - } - } - - if (sk_X509_NAME_num(ca_list) == 0) { - sk_X509_NAME_free(ca_list); - return nullptr; - } - - return ca_list; -} - // OpenSSL verify callback wrapper int openssl_verify_callback(int preverify_ok, X509_STORE_CTX *ctx) { auto &callback = get_verify_callback(); @@ -13086,6 +12913,9 @@ ssize_t read(session_t session, void *buf, size_t len, TlsError &err) { auto ssl_err = SSL_get_error(ssl, ret); err.code = impl::map_ssl_error(ssl_err, err.sys_errno); + if (err.code == ErrorCode::PeerClosed) { + return 0; 
+ } // Gracefully handle the peer closed state. if (err.code == ErrorCode::Fatal) { err.backend_code = ERR_get_error(); } return -1; } @@ -13523,164 +13353,8 @@ std::string verify_error_string(long error_code) { return str ? str : "unknown error"; } -namespace impl { - -// OpenSSL-specific helpers for public API wrappers -ctx_t create_server_context_from_x509(X509 *cert, EVP_PKEY *key, - X509_STORE *client_ca_store, - int &out_error) { - out_error = 0; - auto cert_pem = x509_to_pem(cert); - auto key_pem = evp_pkey_to_pem(key); - if (cert_pem.empty() || key_pem.empty()) { - out_error = static_cast(ERR_get_error()); - return nullptr; - } - - auto ctx = create_server_context(); - if (!ctx) { - out_error = static_cast(get_error()); - return nullptr; - } - - if (!set_server_cert_pem(ctx, cert_pem.c_str(), key_pem.c_str(), nullptr)) { - out_error = static_cast(get_error()); - free_context(ctx); - return nullptr; - } - - if (client_ca_store) { - // Set cert store for verification (SSL_CTX_set_cert_store takes ownership) - SSL_CTX_set_cert_store(static_cast(ctx), client_ca_store); - - // Extract and set client CA list directly from store (more efficient than - // PEM conversion) - auto ca_list = extract_client_ca_list_from_store(client_ca_store); - if (ca_list) { - SSL_CTX_set_client_CA_list(static_cast(ctx), ca_list); - } - - set_verify_client(ctx, true); - } - - return ctx; -} - -void update_server_certs_from_x509(ctx_t ctx, X509 *cert, EVP_PKEY *key, - X509_STORE *client_ca_store) { - auto cert_pem = x509_to_pem(cert); - auto key_pem = evp_pkey_to_pem(key); - - if (!cert_pem.empty() && !key_pem.empty()) { - update_server_cert(ctx, cert_pem.c_str(), key_pem.c_str(), nullptr); - } - - if (client_ca_store) { - auto ca_pem = x509_store_to_pem(client_ca_store); - if (!ca_pem.empty()) { update_server_client_ca(ctx, ca_pem.c_str()); } - X509_STORE_free(client_ca_store); - } -} - -ctx_t create_client_context_from_x509(X509 *cert, EVP_PKEY *key, - const char *password, - uint64_t &out_error) { - out_error = 0; - auto ctx = create_client_context(); - if (!ctx) { - out_error = get_error(); - return nullptr; - } - - if (cert && key) { - auto cert_pem = x509_to_pem(cert); - auto key_pem = evp_pkey_to_pem(key); - if (cert_pem.empty() || key_pem.empty()) { - out_error = ERR_get_error(); - free_context(ctx); - return nullptr; - } - if (!set_client_cert_pem(ctx, cert_pem.c_str(), key_pem.c_str(), - password)) { - out_error = get_error(); - free_context(ctx); - return nullptr; - } - } - - return ctx; -} - -} // namespace impl - } // namespace tls -// ClientImpl::set_ca_cert_store - defined here to use -// tls::impl::x509_store_to_pem Deprecated: converts X509_STORE to PEM and -// stores for redirect transfer -void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) { - if (ca_cert_store) { - ca_cert_pem_ = tls::impl::x509_store_to_pem(ca_cert_store); - } -} - -SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store) { - ctx_ = tls::impl::create_server_context_from_x509( - cert, private_key, client_ca_cert_store, last_ssl_error_); -} - -SSLServer::SSLServer( - const std::function &setup_ssl_ctx_callback) { - // Use abstract API to create context - ctx_ = tls::create_server_context(); - if (ctx_) { - // Pass to OpenSSL-specific callback (ctx_ is SSL_CTX* internally) - auto ssl_ctx = static_cast(ctx_); - if (!setup_ssl_ctx_callback(*ssl_ctx)) { - tls::free_context(ctx_); - ctx_ = nullptr; - } - } -} - -SSL_CTX *SSLServer::ssl_context() const { - return static_cast(ctx_); -} - 
-void SSLServer::update_certs(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store) { - std::lock_guard guard(ctx_mutex_); - tls::impl::update_server_certs_from_x509(ctx_, cert, private_key, - client_ca_cert_store); -} - -SSLClient::SSLClient(const std::string &host, int port, - X509 *client_cert, EVP_PKEY *client_key, - const std::string &private_key_password) - : ClientImpl(host, port) { - const char *password = - private_key_password.empty() ? nullptr : private_key_password.c_str(); - ctx_ = tls::impl::create_client_context_from_x509( - client_cert, client_key, password, last_backend_error_); -} - -long SSLClient::get_verify_result() const { return verify_result_; } - -void SSLClient::set_server_certificate_verifier( - std::function verifier) { - // Wrap SSL* callback into backend-independent session_verifier_ - auto v = std::make_shared>( - std::move(verifier)); - session_verifier_ = [v](tls::session_t session) { - return (*v)(static_cast(session)); - }; -} - -SSL_CTX *SSLClient::ssl_context() const { - return static_cast(ctx_); -} - bool SSLClient::verify_host(X509 *server_cert) const { /* Quote from RFC2818 section 3.1 "Server Identity" @@ -16194,7 +15868,11 @@ ReadResult WebSocket::read(std::string &msg) { payload.size(), true, !is_server_); continue; } - case Opcode::Pong: continue; + case Opcode::Pong: { + std::lock_guard lock(ping_mutex_); + unacked_pings_ = 0; + continue; + } case Opcode::Close: { if (!closed_.exchange(true)) { // Echo close frame back @@ -16228,7 +15906,11 @@ ReadResult WebSocket::read(std::string &msg) { true, !is_server_); continue; } - if (cont_opcode == Opcode::Pong) { continue; } + if (cont_opcode == Opcode::Pong) { + std::lock_guard lock(ping_mutex_); + unacked_pings_ = 0; + continue; + } if (cont_opcode == Opcode::Close) { if (!closed_.exchange(true)) { std::lock_guard lock(write_mutex_); @@ -16316,12 +15998,22 @@ void WebSocket::start_heartbeat() { while (!closed_) { ping_cv_.wait_for(lock, std::chrono::seconds(ping_interval_sec_)); if (closed_) { break; } + // If the peer has failed to respond to the previous pings, give up. + // RFC 6455 does not define a pong-timeout mechanism; this is an + // opt-in liveness check controlled by max_missed_pongs_. 
+ if (max_missed_pongs_ > 0 && unacked_pings_ >= max_missed_pongs_) { + lock.unlock(); + close(CloseStatus::GoingAway, "pong timeout"); + return; + } lock.unlock(); if (!send_frame(Opcode::Ping, nullptr, 0)) { + lock.lock(); closed_ = true; break; } lock.lock(); + unacked_pings_++; } }); } @@ -16449,8 +16141,9 @@ bool WebSocketClient::connect() { Request req; req.method = "GET"; req.path = path_; - ws_ = std::unique_ptr( - new WebSocket(std::move(strm), req, false, websocket_ping_interval_sec_)); + ws_ = std::unique_ptr(new WebSocket(std::move(strm), req, false, + websocket_ping_interval_sec_, + websocket_max_missed_pongs_)); return true; } @@ -16494,6 +16187,10 @@ void WebSocketClient::set_websocket_ping_interval(time_t sec) { websocket_ping_interval_sec_ = sec; } +void WebSocketClient::set_websocket_max_missed_pongs(int count) { + websocket_max_missed_pongs_ = count; +} + void WebSocketClient::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } void WebSocketClient::set_address_family(int family) { diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h index 1d12f2d2b..8581d1695 100644 --- a/vendor/cpp-httplib/httplib.h +++ b/vendor/cpp-httplib/httplib.h @@ -8,8 +8,8 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.42.0" -#define CPPHTTPLIB_VERSION_NUM "0x002a00" +#define CPPHTTPLIB_VERSION "0.43.1" +#define CPPHTTPLIB_VERSION_NUM "0x002b01" #ifdef _WIN32 #if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00 @@ -205,6 +205,10 @@ #define CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND 30 #endif +#ifndef CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS +#define CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS 0 +#endif + /* * Headers */ @@ -1720,6 +1724,8 @@ public: Server &set_websocket_ping_interval( const std::chrono::duration &duration); + Server &set_websocket_max_missed_pongs(int count); + bool bind_to_port(const std::string &host, int port, int socket_flags = 0); int bind_to_any_port(const std::string &host, int socket_flags = 0); bool listen_after_bind(); @@ -1756,6 +1762,7 @@ protected: size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; time_t websocket_ping_interval_sec_ = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND; + int websocket_max_missed_pongs_ = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS; private: using Handlers = @@ -1767,6 +1774,14 @@ private: static std::unique_ptr make_matcher(const std::string &pattern); + template + Server &add_handler( + std::vector, H>> &handlers, + const std::string &pattern, H handler) { + handlers.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; + } + Server &set_error_handler_core(HandlerWithResponse handler, std::true_type); Server &set_error_handler_core(Handler handler, std::false_type); @@ -1928,15 +1943,6 @@ private: int ssl_error_ = 0; uint64_t ssl_backend_error_ = 0; #endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -public: - [[deprecated("Use ssl_backend_error() instead. " - "This function will be removed by v1.0.0.")]] - uint64_t ssl_openssl_error() const { - return ssl_backend_error_; - } -#endif }; struct ClientConnection { @@ -2409,22 +2415,6 @@ protected: int last_ssl_error_ = 0; uint64_t last_backend_error_ = 0; #endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -public: - [[deprecated("Use load_ca_cert_store() instead. " - "This function will be removed by v1.0.0.")]] - void set_ca_cert_store(X509_STORE *ca_cert_store); - - [[deprecated("Use tls::create_ca_store() instead. 
" - "This function will be removed by v1.0.0.")]] - X509_STORE *create_ca_cert_store(const char *ca_cert, std::size_t size) const; - - [[deprecated("Use set_server_certificate_verifier(VerifyCallback) instead. " - "This function will be removed by v1.0.0.")]] - virtual void set_server_certificate_verifier( - std::function verifier); -#endif }; class Client { @@ -2599,7 +2589,6 @@ public: void set_follow_location(bool on); void set_path_encode(bool on); - void set_url_encode(bool on); void set_compress(bool on); @@ -2647,22 +2636,6 @@ public: private: bool is_ssl_ = false; #endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -public: - [[deprecated("Use tls_context() instead. " - "This function will be removed by v1.0.0.")]] - SSL_CTX *ssl_context() const; - - [[deprecated("Use set_session_verifier(session_t) instead. " - "This function will be removed by v1.0.0.")]] - void set_server_certificate_verifier( - std::function verifier); - - [[deprecated("Use Result::ssl_backend_error() instead. " - "This function will be removed by v1.0.0.")]] - long get_verify_result() const; -#endif }; #ifdef CPPHTTPLIB_SSL_ENABLED @@ -2708,29 +2681,6 @@ private: std::mutex ctx_mutex_; int last_ssl_error_ = 0; - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -public: - [[deprecated("Use SSLServer(PemMemory) or " - "SSLServer(ContextSetupCallback) instead. " - "This constructor will be removed by v1.0.0.")]] - SSLServer(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store = nullptr); - - [[deprecated("Use SSLServer(ContextSetupCallback) instead. " - "This constructor will be removed by v1.0.0.")]] - SSLServer( - const std::function &setup_ssl_ctx_callback); - - [[deprecated("Use tls_context() instead. " - "This function will be removed by v1.0.0.")]] - SSL_CTX *ssl_context() const; - - [[deprecated("Use update_certs_pem() instead. " - "This function will be removed by v1.0.0.")]] - void update_certs(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store = nullptr); -#endif }; class SSLClient final : public ClientImpl { @@ -2794,6 +2744,9 @@ private: Response &res, bool &success, Error &error); bool initialize_ssl(Socket &socket, Error &error); + void init_ctx(); + void reset_ctx_on_error(); + bool load_certs(); tls::ctx_t ctx_ = nullptr; @@ -2811,42 +2764,6 @@ private: friend class ClientImpl; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT -public: - [[deprecated("Use SSLClient(host, port, PemMemory) instead. " - "This constructor will be removed by v1.0.0.")]] - explicit SSLClient(const std::string &host, int port, X509 *client_cert, - EVP_PKEY *client_key, - const std::string &private_key_password = std::string()); - - [[deprecated("Use Result::ssl_backend_error() instead. " - "This function will be removed by v1.0.0.")]] - long get_verify_result() const; - - [[deprecated("Use tls_context() instead. " - "This function will be removed by v1.0.0.")]] - SSL_CTX *ssl_context() const; - - // Override of a deprecated virtual in ClientImpl. Suppress C4996 / - // -Wdeprecated-declarations on the override declaration itself so that - // MSVC /sdl builds compile cleanly. Will be removed together with the - // base virtual by v1.0.0. -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#elif defined(__GNUC__) || defined(__clang__) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" -#endif - [[deprecated("Use set_session_verifier(session_t) instead. 
" - "This function will be removed by v1.0.0.")]] - void set_server_certificate_verifier( - std::function verifier) override; -#if defined(_MSC_VER) -#pragma warning(pop) -#elif defined(__GNUC__) || defined(__clang__) -#pragma GCC diagnostic pop -#endif - private: bool verify_host(X509 *server_cert) const; bool verify_host_with_subject_alt_name(X509 *server_cert) const; @@ -3818,17 +3735,21 @@ private: WebSocket( Stream &strm, const Request &req, bool is_server, - time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND) + time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND, + int max_missed_pongs = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS) : strm_(strm), req_(req), is_server_(is_server), - ping_interval_sec_(ping_interval_sec) { + ping_interval_sec_(ping_interval_sec), + max_missed_pongs_(max_missed_pongs) { start_heartbeat(); } WebSocket( std::unique_ptr &&owned_strm, const Request &req, bool is_server, - time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND) + time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND, + int max_missed_pongs = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS) : strm_(*owned_strm), owned_strm_(std::move(owned_strm)), req_(req), - is_server_(is_server), ping_interval_sec_(ping_interval_sec) { + is_server_(is_server), ping_interval_sec_(ping_interval_sec), + max_missed_pongs_(max_missed_pongs) { start_heartbeat(); } @@ -3840,6 +3761,8 @@ private: Request req_; bool is_server_; time_t ping_interval_sec_; + int max_missed_pongs_; + int unacked_pings_ = 0; std::atomic closed_{false}; std::mutex write_mutex_; std::thread ping_thread_; @@ -3869,6 +3792,7 @@ public: void set_read_timeout(time_t sec, time_t usec = 0); void set_write_timeout(time_t sec, time_t usec = 0); void set_websocket_ping_interval(time_t sec); + void set_websocket_max_missed_pongs(int count); void set_tcp_nodelay(bool on); void set_address_family(int family); void set_ipv6_v6only(bool on); @@ -3900,6 +3824,7 @@ private: time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND; time_t websocket_ping_interval_sec_ = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND; + int websocket_max_missed_pongs_ = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS; int address_family_ = AF_UNSPEC; bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY; From 52f1096f21b0e12e4532f34224991e06cf21a69a Mon Sep 17 00:00:00 2001 From: Zijun Yu Date: Tue, 21 Apr 2026 23:58:34 +0800 Subject: [PATCH 03/20] openvino: driver setup, CI split, thread safety, and NPU optimizations (#21944) * Thread safety per request only * Fix ROPE yarn case * Fix sticky stateful config * Use i4/i8 directly for symmetric quant * Use weightless caching * Add WeightlessCacheAttribute to reduce NPU memory usage * Gelu tanh support (#125) * Imrope support (#126) * fix(openvino): explicit ov::Tensor frees in ggml_backend_openvino_free * add GPU,NPU support in OV Dockerfile * add build-openvino.yml ci * Fix sticky stateful config * add concurrency to ov-gpu ci runs. 
Move OV CI to build-openvino.yml * fix thread-safety of shared runtime context * rope type abstraction for frontend translations * fix editorconfig --------- Co-authored-by: Mustafa Cavus Co-authored-by: Dan Hoffman Co-authored-by: Ravi Panchumarthy --- .devops/openvino.Dockerfile | 50 +- .github/workflows/build-openvino.yml | 120 +++++ .github/workflows/build-self-hosted.yml | 4 + .github/workflows/build.yml | 80 --- docs/backend/OPENVINO.md | 3 - ggml/src/ggml-openvino/ggml-decoder.cpp | 20 +- .../src/ggml-openvino/ggml-openvino-extra.cpp | 29 +- ggml/src/ggml-openvino/ggml-openvino.cpp | 42 +- ggml/src/ggml-openvino/ggml-quants.cpp | 458 ++++++++++-------- ggml/src/ggml-openvino/openvino/op/rope.cpp | 40 +- .../ggml-openvino/openvino/op/unary_gelu.cpp | 25 + ggml/src/ggml-openvino/openvino/op_table.cpp | 1 + ggml/src/ggml-openvino/openvino/op_table.h | 1 + .../openvino/pass/eliminate_zp.cpp | 123 ----- .../openvino/pass/eliminate_zp.h | 17 - .../rt_info/weightless_caching_attributes.hpp | 41 ++ .../openvino/translate_session.cpp | 30 +- ggml/src/ggml-openvino/openvino/utils.cpp | 109 +++-- ggml/src/ggml-openvino/openvino/utils.h | 1 + ggml/src/ggml-openvino/utils.cpp | 147 ++++-- ggml/src/ggml-openvino/utils.h | 26 +- 21 files changed, 823 insertions(+), 544 deletions(-) create mode 100644 .github/workflows/build-openvino.yml create mode 100644 ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp delete mode 100644 ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp delete mode 100644 ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h create mode 100644 ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp diff --git a/.devops/openvino.Dockerfile b/.devops/openvino.Dockerfile index 3ee4dd201..31b58736d 100644 --- a/.devops/openvino.Dockerfile +++ b/.devops/openvino.Dockerfile @@ -2,7 +2,19 @@ ARG OPENVINO_VERSION_MAJOR=2026.0 ARG OPENVINO_VERSION_FULL=2026.0.0.20965.c6d6a13a886 ARG UBUNTU_VERSION=24.04 -# Optional proxy build arguments - empty by default +# Intel GPU driver versions. https://github.com/intel/compute-runtime/releases +ARG IGC_VERSION=v2.30.1 +ARG IGC_VERSION_FULL=2_2.30.1+20950 +ARG COMPUTE_RUNTIME_VERSION=26.09.37435.1 +ARG COMPUTE_RUNTIME_VERSION_FULL=26.09.37435.1-0 +ARG IGDGMM_VERSION=22.9.0 + +# Intel NPU driver versions. 
https://github.com/intel/linux-npu-driver/releases +ARG NPU_DRIVER_VERSION=v1.32.0 +ARG NPU_DRIVER_FULL=v1.32.0.20260402-23905121947 +ARG LIBZE1_VERSION=1.27.0-1~24.04~ppa2 + +# Optional proxy build arguments ARG http_proxy= ARG https_proxy= @@ -78,13 +90,47 @@ ARG http_proxy ARG https_proxy RUN apt-get update \ - && apt-get install -y libgomp1 libtbb12 curl \ + && apt-get install -y libgomp1 libtbb12 curl wget ocl-icd-libopencl1 \ && apt autoremove -y \ && apt clean -y \ && rm -rf /tmp/* /var/tmp/* \ && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ && find /var/cache -type f -delete +# Install GPU drivers +ARG IGC_VERSION +ARG IGC_VERSION_FULL +ARG COMPUTE_RUNTIME_VERSION +ARG COMPUTE_RUNTIME_VERSION_FULL +ARG IGDGMM_VERSION +RUN mkdir /tmp/neo/ && cd /tmp/neo/ \ + && wget https://github.com/intel/intel-graphics-compiler/releases/download/${IGC_VERSION}/intel-igc-core-${IGC_VERSION_FULL}_amd64.deb \ + && wget https://github.com/intel/intel-graphics-compiler/releases/download/${IGC_VERSION}/intel-igc-opencl-${IGC_VERSION_FULL}_amd64.deb \ + && wget https://github.com/intel/compute-runtime/releases/download/${COMPUTE_RUNTIME_VERSION}/intel-ocloc-dbgsym_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.ddeb \ + && wget https://github.com/intel/compute-runtime/releases/download/${COMPUTE_RUNTIME_VERSION}/intel-ocloc_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.deb \ + && wget https://github.com/intel/compute-runtime/releases/download/${COMPUTE_RUNTIME_VERSION}/intel-opencl-icd-dbgsym_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.ddeb \ + && wget https://github.com/intel/compute-runtime/releases/download/${COMPUTE_RUNTIME_VERSION}/intel-opencl-icd_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.deb \ + && wget https://github.com/intel/compute-runtime/releases/download/${COMPUTE_RUNTIME_VERSION}/libigdgmm12_${IGDGMM_VERSION}_amd64.deb \ + && wget https://github.com/intel/compute-runtime/releases/download/${COMPUTE_RUNTIME_VERSION}/libze-intel-gpu1-dbgsym_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.ddeb \ + && wget https://github.com/intel/compute-runtime/releases/download/${COMPUTE_RUNTIME_VERSION}/libze-intel-gpu1_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.deb \ + && dpkg --install *.deb \ + && rm -rf /tmp/neo/ + +# Install NPU drivers +ARG NPU_DRIVER_VERSION +ARG NPU_DRIVER_FULL +ARG LIBZE1_VERSION +RUN mkdir /tmp/npu/ && cd /tmp/npu/ \ + && wget https://github.com/intel/linux-npu-driver/releases/download/${NPU_DRIVER_VERSION}/linux-npu-driver-${NPU_DRIVER_FULL}-ubuntu2404.tar.gz \ + && tar -xf linux-npu-driver-${NPU_DRIVER_FULL}-ubuntu2404.tar.gz \ + && dpkg --install *.deb \ + && rm -rf /tmp/npu/ + +RUN cd /tmp \ + && wget https://snapshot.ppa.launchpadcontent.net/kobuk-team/intel-graphics/ubuntu/20260324T100000Z/pool/main/l/level-zero-loader/libze1_${LIBZE1_VERSION}_amd64.deb \ + && dpkg --install libze1_${LIBZE1_VERSION}_amd64.deb \ + && rm libze1_${LIBZE1_VERSION}_amd64.deb + COPY --from=build /app/lib/ /app/ ### Full (all binaries) diff --git a/.github/workflows/build-openvino.yml b/.github/workflows/build-openvino.yml new file mode 100644 index 000000000..f7177f6be --- /dev/null +++ b/.github/workflows/build-openvino.yml @@ -0,0 +1,120 @@ +name: CI (openvino) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: [ + '.github/workflows/build-openvino.yml', + '**/CMakeLists.txt', + '**/.cmake', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + ] + + pull_request: + types: [opened, synchronize, reopened] + paths: [ + '.github/workflows/build-openvino.yml', 
+ 'ggml/src/ggml-openvino/**' + ] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +env: + GGML_NLOOP: 3 + GGML_N_THREADS: 1 + LLAMA_LOG_COLORS: 1 + LLAMA_LOG_PREFIX: 1 + LLAMA_LOG_TIMESTAMPS: 1 + +jobs: + ubuntu-24-openvino: + name: ubuntu-24-openvino-${{ matrix.openvino_device }} + + concurrency: + group: openvino-${{ matrix.variant }}-${{ github.head_ref || github.ref }} + cancel-in-progress: false + + strategy: + matrix: + include: + - variant: cpu + runner: '"ubuntu-24.04"' + openvino_device: "CPU" + - variant: gpu + runner: '["self-hosted","Linux","Intel","OpenVINO"]' + openvino_device: "GPU" + + runs-on: ${{ fromJSON(matrix.runner) }} + + env: + # Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile + OPENVINO_VERSION_MAJOR: "2026.0" + OPENVINO_VERSION_FULL: "2026.0.0.20965.c6d6a13a886" + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: ccache + if: runner.environment == 'github-hosted' + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ubuntu-24-openvino-${{ matrix.variant }}-no-preset-v1 + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install -y build-essential libssl-dev libtbb12 cmake ninja-build python3-pip + sudo apt-get install -y ocl-icd-opencl-dev opencl-headers opencl-clhpp-headers intel-opencl-icd + + - name: Use OpenVINO Toolkit Cache + if: runner.environment == 'github-hosted' + uses: actions/cache@v5 + id: cache-openvino + with: + path: ./openvino_toolkit + key: openvino-toolkit-v${{ env.OPENVINO_VERSION_FULL }}-${{ runner.os }} + + - name: Setup OpenVINO Toolkit + if: steps.cache-openvino.outputs.cache-hit != 'true' + uses: ./.github/actions/linux-setup-openvino + with: + path: ./openvino_toolkit + version_major: ${{ env.OPENVINO_VERSION_MAJOR }} + version_full: ${{ env.OPENVINO_VERSION_FULL }} + + - name: Install OpenVINO dependencies + run: | + cd ./openvino_toolkit + chmod +x ./install_dependencies/install_openvino_dependencies.sh + echo "Y" | sudo -E ./install_dependencies/install_openvino_dependencies.sh + + - name: Build + id: cmake_build + run: | + source ./openvino_toolkit/setupvars.sh + cmake -B build/ReleaseOV -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENVINO=ON + time cmake --build build/ReleaseOV --config Release -j $(nproc) + + - name: Test + id: cmake_test + # TODO: fix and re-enable the `test-llama-archs` test below + run: | + cd ${{ github.workspace }} + if [ "${{ matrix.openvino_device }}" = "GPU" ]; then + export GGML_OPENVINO_DEVICE=GPU + fi + ctest --test-dir build/ReleaseOV -L main -E "test-llama-archs" --verbose --timeout 2000 diff --git a/.github/workflows/build-self-hosted.yml b/.github/workflows/build-self-hosted.yml index 52624a46d..e9148dd73 100644 --- a/.github/workflows/build-self-hosted.yml +++ b/.github/workflows/build-self-hosted.yml @@ -265,6 +265,10 @@ jobs: ggml-ci-intel-openvino-gpu-low-perf: runs-on: [self-hosted, Linux, Intel, OpenVINO] + concurrency: + group: openvino-gpu-${{ github.head_ref || github.ref }} + cancel-in-progress: false + env: # Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile OPENVINO_VERSION_MAJOR: "2026.0" diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 28c8665bd..c7f00e359 100644 --- 
a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -656,86 +656,6 @@ jobs: -DGGML_SYCL_F16=ON time cmake --build build --config Release -j $(nproc) - ubuntu-24-openvino: - name: ubuntu-24-openvino-${{ matrix.openvino_device }} - strategy: - matrix: - include: - - variant: cpu - runner: '"ubuntu-24.04"' - openvino_device: "CPU" - - variant: gpu - runner: '["self-hosted","Linux","X64","Intel"]' - openvino_device: "GPU" - - runs-on: ${{ fromJSON(matrix.runner) }} - - env: - # Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile - OPENVINO_VERSION_MAJOR: "2026.0" - OPENVINO_VERSION_FULL: "2026.0.0.20965.c6d6a13a886" - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: ccache - if: runner.environment == 'github-hosted' - uses: ggml-org/ccache-action@v1.2.21 - with: - key: ubuntu-24-openvino-${{ matrix.variant }}-no-preset-v1 - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install -y build-essential libssl-dev libtbb12 cmake ninja-build python3-pip - sudo apt-get install -y ocl-icd-opencl-dev opencl-headers opencl-clhpp-headers intel-opencl-icd - - - name: Use OpenVINO Toolkit Cache - if: runner.environment == 'github-hosted' - uses: actions/cache@v5 - id: cache-openvino - with: - path: ./openvino_toolkit - key: openvino-toolkit-v${{ env.OPENVINO_VERSION_FULL }}-${{ runner.os }} - - - name: Setup OpenVINO Toolkit - if: steps.cache-openvino.outputs.cache-hit != 'true' - uses: ./.github/actions/linux-setup-openvino - with: - path: ./openvino_toolkit - version_major: ${{ env.OPENVINO_VERSION_MAJOR }} - version_full: ${{ env.OPENVINO_VERSION_FULL }} - - - name: Install OpenVINO dependencies - run: | - cd ./openvino_toolkit - chmod +x ./install_dependencies/install_openvino_dependencies.sh - echo "Y" | sudo -E ./install_dependencies/install_openvino_dependencies.sh - - - name: Build - id: cmake_build - run: | - source ./openvino_toolkit/setupvars.sh - cmake -B build/ReleaseOV -G Ninja \ - -DCMAKE_BUILD_TYPE=Release \ - -DGGML_OPENVINO=ON - time cmake --build build/ReleaseOV --config Release -j $(nproc) - - - name: Test - id: cmake_test - # TODO: fix and re-enable the `test-llama-archs` test below - run: | - cd ${{ github.workspace }} - if [ "${{ matrix.openvino_device }}" = "GPU" ]; then - export GGML_OPENVINO_DEVICE=GPU - fi - ctest --test-dir build/ReleaseOV -L main -E "test-llama-archs" --verbose --timeout 2000 - windows-latest: runs-on: windows-2025 diff --git a/docs/backend/OPENVINO.md b/docs/backend/OPENVINO.md index 96d0f672e..c9c005a99 100644 --- a/docs/backend/OPENVINO.md +++ b/docs/backend/OPENVINO.md @@ -244,7 +244,6 @@ build\ReleaseOV\bin\llama-cli.exe -m "C:\models\Llama-3.2-1B-Instruct-Q4_0.gguf" - `-fa 1` is required when running llama-bench with the OpenVINO backend. - `GGML_OPENVINO_STATEFUL_EXECUTION=1 GGML_OPENVINO_DEVICE=GPU ./llama-bench -fa 1` - `llama-server` with OpenVINO backend supports only one chat session/thread, when `GGML_OPENVINO_STATEFUL_EXECUTION=1` is enabled. -- For Intel GPU, NPU detection in containers, GPU, NPU user-space drivers/libraries must be present inside the image. We will include in a future PR. Until then, you can use this reference Dockerfile: [openvino.Dockerfile](https://github.com/ravi9/llama.cpp/blob/ov-docker-update/.devops/openvino.Dockerfile) > [!NOTE] > The OpenVINO backend is actively under development. 
Fixes are underway, and this document will continue to be updated as issues are resolved. @@ -274,8 +273,6 @@ docker build --build-arg http_proxy=$http_proxy --build-arg https_proxy=$https_p Run llama.cpp with OpenVINO backend Docker container. Save sample models in `~/models` as [shown above](#3-download-sample-model). It will be mounted to the container in the examples below. -> [!NOTE] -> Intel GPU, NPU detection in containers will be included in a future PR. Until then, you can use this reference Dockerfile: [openvino.Dockerfile](https://github.com/ravi9/llama.cpp/blob/ov-docker-update/.devops/openvino.Dockerfile). ```bash # Run Docker container diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 0938d2273..5095e7998 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -207,8 +206,22 @@ int GgmlOvDecoder::compute_op_case(const ggml_tensor * node) const { break; } case GGML_OP_ROPE: { + const int mode = node->op_params[2]; + switch (mode) { + case GGML_ROPE_TYPE_NEOX: { + op_case = 0x00010000; + break; + } + case GGML_ROPE_TYPE_IMROPE: { + op_case = 0x00020000; + break; + } + default: + op_case = 0x00000000; + break; + } if (node->src[0]->op == GGML_OP_VIEW) { - op_case = 2; + op_case = (op_case | 0x00000002); } break; } @@ -573,9 +586,6 @@ std::map GgmlOvDecoder::get_kv_param_res_names() const } std::map> GgmlOvDecoder::create_weight_nodes(ggml_cgraph * cgraph, bool naive) { - static std::mutex weights_mutex; - std::lock_guard lock(weights_mutex); - std::map> model_weights; auto * nodes = cgraph->nodes; auto n_nodes = cgraph->n_nodes; diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp index cc3cb4583..4140136ac 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include ov::Core & ov_singleton_core() { @@ -42,11 +43,13 @@ void ggml_openvino_device_config::init() { {"NPUW_DQ", "YES" }, {"NPUW_DQ_FULL", "NO" }, }; - if (cache_dir) { + if (cache_dir && strlen(cache_dir) > 0) { compile_config["NPUW_CACHE_DIR"] = cache_dir; + compile_config.insert(ov::cache_mode(ov::CacheMode::OPTIMIZE_SIZE)); } - } else if (cache_dir) { - ov_singleton_core().set_property(ov::cache_dir(cache_dir)); + } else if (cache_dir && strlen(cache_dir) > 0) { + compile_config.insert(ov::cache_dir(cache_dir)); + compile_config.insert(ov::cache_mode(ov::CacheMode::OPTIMIZE_SIZE)); } // Initialize remote context with queue sharing for GPU @@ -259,10 +262,12 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten layout.weights_size = layout.is_u4 ? (n_elements / 2) : n_elements; int64_t n_blocks = n_elements / layout.weights_per_block; layout.scales_size = n_blocks * sizeof(uint16_t); - // For symmetric quantization, we only need one zp value (not one per block) - // Zero points are stored in U4 or U8 format matching the weight type - size_t n_zp_elements = layout.is_symmetric ? 1 : n_blocks; - layout.zp_size = layout.is_u4 ? ((n_zp_elements + 1) / 2) : n_zp_elements; + // For symmetric quantization, no zp needed (weights stored as signed) + if (layout.is_symmetric) { + layout.zp_size = 0; + } else { + layout.zp_size = layout.is_u4 ? 
((n_blocks + 1) / 2) : n_blocks; + } layout.weights_offset = 0; layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment; @@ -313,10 +318,12 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten // Scales: F16 per block int64_t n_blocks = n_elements / layout.weights_per_block; layout.scales_size = n_blocks * sizeof(uint16_t); // F16 = 2 bytes - // Zero points: U4 or U8 matching weight type - // For symmetric quantization, we only need one zp value (not one per block) - size_t n_zp_elements = layout.is_symmetric ? 1 : n_blocks; - layout.zp_size = layout.is_u4 ? ((n_zp_elements + 1) / 2) : n_zp_elements; + // For symmetric quantization, no zp needed (weights stored as signed) + if (layout.is_symmetric) { + layout.zp_size = 0; + } else { + layout.zp_size = layout.is_u4 ? ((n_blocks + 1) / 2) : n_blocks; + } // Layout in buffer: [weights | scales | zp] with alignment layout.weights_offset = 0; diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 0c8d3508e..4f3ebf253 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -145,13 +145,18 @@ static void * ggml_backend_openvino_buffer_get_base(ggml_backend_buffer_t buffer return ctx->data; } +static bool is_stateful_enabled() { + static const auto * stateful = getenv("GGML_OPENVINO_STATEFUL_EXECUTION"); + return stateful && *stateful != '\0' && strcmp(stateful, "0") != 0; +} + static enum ggml_status ggml_backend_openvino_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; // Put kvcache on device memory for GPU (NPU memory is too small even for kvcache) if (strncmp(tensor->name, "cache_", 6) == 0 && !ctx->is_remote && ggml_openvino_get_device_name() == "GPU" && - !getenv("GGML_OPENVINO_STATEFUL_EXECUTION")) { + !is_stateful_enabled()) { GGML_ASSERT(ctx->tensor_extras.empty()); auto device = ctx->device; auto size = ctx->size; @@ -600,6 +605,14 @@ bool ggml_backend_buft_is_openvino_host(ggml_backend_buffer_type_t buft) { static void ggml_backend_openvino_free(ggml_backend_t backend) { ggml_backend_openvino_context * ctx = (ggml_backend_openvino_context *) backend->context; + + if (ctx->runtime_context) { + auto r_ctx = std::static_pointer_cast(ctx->runtime_context); + if (--r_ctx->backend_count == 0) { + r_ctx->clear_caches(); + } + } + delete ctx; delete backend; } @@ -644,7 +657,12 @@ static ggml_guid_t ggml_backend_openvino_guid(void) { } static std::shared_ptr get_ov_runtime_context_ptr() { - static std::shared_ptr r_ctx = std::make_shared(); + static std::shared_ptr r_ctx = [] { + auto ctx = std::make_shared(); + ctx->device = ggml_openvino_get_device_name(); + ctx->stateful = is_stateful_enabled() && !ggml_openvino_is_npu(); + return ctx; + }(); return r_ctx; } @@ -669,8 +687,7 @@ GGML_BACKEND_API ggml_backend_t ggml_backend_openvino_init(int device) { } std::shared_ptr r_ctx = std::static_pointer_cast(ctx->runtime_context); - r_ctx->device = ggml_openvino_get_device_name(); - r_ctx->stateful = getenv("GGML_OPENVINO_STATEFUL_EXECUTION") && !ggml_openvino_is_npu(); + r_ctx->backend_count++; ggml_backend_t openvino_backend = new ggml_backend{ /* .guid = */ ggml_backend_openvino_guid(), @@ -883,7 +900,7 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { 
const int32_t * op_params = op->op_params; const int n_dims = op_params[1]; const int mode = op_params[2]; - if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) { + if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_IMROPE) { // GGML_LOG_WARN("OpenVINO backend does not support ROPE with mode %d\n", mode); return true; } @@ -896,14 +913,6 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { // GGML_LOG_WARN("OpenVINO backend does not support ROPE with type %s\n", ggml_type_name(op->type)); return true; } - float freq_scale; - float ext_factor; - memcpy(&freq_scale, op_params + 6, sizeof(float)); - memcpy(&ext_factor, op_params + 7, sizeof(float)); - if (ext_factor != 0.0f) { - // GGML_LOG_WARN("OpenVINO backend does not support ROPE with ext_factor %f != 0.0f\n", ext_factor); - return true; - } if (op->src[0]->op == GGML_OP_VIEW) { if (op->src[0]->view_src->ne[1] != op->src[0]->ne[2]) { // GGML_LOG_WARN( @@ -913,6 +922,12 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { return true; } } + if (mode == GGML_ROPE_TYPE_IMROPE && + (op->src[2] != 0 || ((const float *) op_params)[6] != 1 || ((const float *) op_params)[7] != 0 || + ((const float *) op_params)[8] != 1)) { + // GGML_LOG_WARN("OpenVINO backend does not support IMROPE with freq_factors, freq_scale, ext_factor, and attn_factor\n"); + return true; + } break; } default: @@ -942,6 +957,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con // GGML_OP_SOFT_MAX, GGML_OP_SET_ROWS, GGML_OP_FLASH_ATTN_EXT, GGML_OP_CPY}; static const std::set supported_unary_ops{ + GGML_UNARY_OP_GELU, GGML_UNARY_OP_SILU, }; static const std::set supported_glu_ops{ diff --git a/ggml/src/ggml-openvino/ggml-quants.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp index dbf38646d..57d66df4f 100644 --- a/ggml/src/ggml-openvino/ggml-quants.cpp +++ b/ggml/src/ggml-openvino/ggml-quants.cpp @@ -46,6 +46,7 @@ void unpack_32_4(const uint8_t * data, uint8_t * dst) { // Extracts (weight, scales, zp) from Q4_0 tensors. // Data layout is: |16 bit scale|32 x 4bit weights|. +// When zp_arr is empty (symmetric), weights are stored as signed i4 (value - 8). 
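As a brief aside before the extraction routine that follows: the symmetric path works because subtracting 8 from a 4-bit value is equivalent to XOR-ing it with 8 once the result is reinterpreted as 4-bit two's complement, so a whole packed byte can be converted with a single `^ 0x88`. A minimal standalone check of that identity (an illustration only, not part of the patch):

```cpp
// Verifies the u4 -> i4 trick used by the symmetric Q4_0 path:
// XOR-ing a packed byte with 0x88 subtracts 8 from each nibble when the
// nibbles are read back as 4-bit two's complement values.
#include <cassert>
#include <cstdint>

static int8_t i4_from_nibble(uint8_t n) {
    // sign-extend a 4-bit field to int8
    return (n & 0x8) ? static_cast<int8_t>(n | 0xF0) : static_cast<int8_t>(n);
}

int main() {
    for (int lo = 0; lo < 16; ++lo) {
        for (int hi = 0; hi < 16; ++hi) {
            const uint8_t packed = static_cast<uint8_t>(lo | (hi << 4));
            const uint8_t conv   = static_cast<uint8_t>(packed ^ 0x88);
            assert(i4_from_nibble(conv & 0x0F) == lo - 8);
            assert(i4_from_nibble(conv >> 4)   == hi - 8);
        }
    }
    return 0;
}
```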
void extract_q4_0_data(const ggml_tensor * tensor, ov::Tensor & weights_arr, ov::Tensor & scales_arr, @@ -55,28 +56,32 @@ void extract_q4_0_data(const ggml_tensor * tensor, auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i4); // Signed i4 path - // For Q4_0, zero point is always 8 - if (is_scalar_zp) { - zp[0] = 8 | (8 << 4); // Pack two 4-bit values - } - - ov::parallel_for(scales_arr.get_size(), [&](size_t i) { - scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); - // For asymmetric quantization, compute per-block zero points - if (!is_scalar_zp) { + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); // Pack two 4-bit zero points per byte if (i % 2 == 0) { zp[i / 2] = 8; // Lower nibble } else { zp[i / 2] |= (8 << 4); // Upper nibble } - } - unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); - }); + unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); + }); + } else { + // Symmetric: unpack as u4 then convert to i4 by subtracting 8 (XOR each nibble) + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); + unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); + // Convert u4 to i4: subtract 8 from each nibble. XOR 0x88 flips each nibble by 8. + for (int j = 0; j < 16; ++j) { + weights[i * 16 + j] ^= 0x88; + } + }); + } } // Extracts (weight, scales, zp) from Q4_1 tensors. @@ -123,6 +128,7 @@ void extract_q4_1_data(const ggml_tensor * tensor, // Extracts (weight, scales, zp) from Q8_0 tensors. // Data layout is: |16 bit scale|32 x 8bit weights|. +// When zp_arr is empty (symmetric), weights are stored as signed i8 directly. void extract_q8_0_data(const ggml_tensor * tensor, ov::Tensor & weights_arr, ov::Tensor & scales_arr, @@ -133,29 +139,30 @@ void extract_q8_0_data(const ggml_tensor * tensor, auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i8); // Signed i8 path - // For Q8_0, zero point is always 128 - if (is_scalar_zp) { - zp[0] = 128; - } - - ov::parallel_for(scales_arr.get_size(), [&](size_t i) { - uint8_t * block_data = data + i * bytes_per_block; - scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); zp[i] = 128; - } - for (size_t j = 0; j < weights_per_block; ++j) { - uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes. - // Original data is in int8_t, so we add a bias of -128 and invert the first bit. 
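The "invert the first bit" comment above is the usual offset-binary conversion: flipping the sign bit of an int8 value is the same as adding 128 when the byte is read as unsigned, which is why the asymmetric Q8_0 path keeps a per-block zero point of 128 while the symmetric path can drop it. A tiny check of that equivalence (illustration only, not part of the patch):

```cpp
// Confirms that flipping the sign bit maps int8 to uint8 with a +128 bias,
// matching the zero point of 128 used by the asymmetric Q8_0 path.
#include <cassert>
#include <cstdint>

int main() {
    for (int v = -128; v <= 127; ++v) {
        const uint8_t x = static_cast<uint8_t>(static_cast<int8_t>(v)) ^ (1 << 7);
        assert(static_cast<int>(x) == v + 128);
    }
    return 0;
}
```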
- x ^= 1 << 7; - weights[i * weights_per_block + j] = x; - } - }); + for (size_t j = 0; j < weights_per_block; ++j) { + uint8_t x = block_data[j + 2]; + x ^= 1 << 7; // Convert int8 to uint8 by flipping sign bit + weights[i * weights_per_block + j] = x; + } + }); + } else { + // Symmetric: store original int8 values directly (no unsigned bias) + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); + // Copy int8 weights as-is (the tensor element type is i8) + memcpy(weights + i * weights_per_block, block_data + 2, weights_per_block); + }); + } } void unpack_256_4(const uint8_t * data, uint8_t * dst) { @@ -256,44 +263,62 @@ void extract_q6_k_data(const ggml_tensor * tensor, auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i8); // Signed i8 path - // For Q6_K, zero point is always 32 - if (is_scalar_zp) { - zp[0] = 32; - } - - ov::parallel_for(n_super_block, [&](size_t i) { - uint8_t * block_data = data + i * bytes_per_block; - - float scale_factor = - static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 104))); // (128+64+16)/2 - - for (size_t j = 0; j < 16; j++) { - scales[j + i * 16] = - ov::float16(scale_factor * static_cast(*((int8_t *) (block_data + 128 + 64 + j)))); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + float scale_factor = static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 104))); + for (size_t j = 0; j < 16; j++) { + scales[j + i * 16] = + ov::float16(scale_factor * static_cast(*((int8_t *) (block_data + 128 + 64 + j)))); zp[j + i * 16] = 32; } - } - - uint8_t * ql = block_data; - uint8_t * qh = block_data + 128; - - for (int64_t j = 0; j < 32; ++j) { - weights[i * 256 + j] = (ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4); - weights[i * 256 + j + 32] = (ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4); - weights[i * 256 + j + 64] = (ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4); - weights[i * 256 + j + 96] = (ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4); - weights[i * 256 + j + 128] = (ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4); - weights[i * 256 + j + 160] = (ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4); - weights[i * 256 + j + 192] = (ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4); - weights[i * 256 + j + 224] = (ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4); - } - }); + uint8_t * ql = block_data; + uint8_t * qh = block_data + 128; + for (int64_t j = 0; j < 32; ++j) { + weights[i * 256 + j] = (ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4); + weights[i * 256 + j + 32] = (ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4); + weights[i * 256 + j + 64] = (ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4); + weights[i * 256 + j + 96] = (ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4); + weights[i * 256 + j + 128] = (ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4); + weights[i * 256 + j + 160] = (ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4); + weights[i * 256 + j + 192] = (ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4); + weights[i * 256 + j + 224] = (ql[96 + j] >> 4) 
| (((qh[32 + j] >> 6) & 3) << 4); + } + }); + } else { + // Symmetric: subtract 32 from each weight to store as signed i8 + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + float scale_factor = static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 104))); + for (size_t j = 0; j < 16; j++) { + scales[j + i * 16] = + ov::float16(scale_factor * static_cast(*((int8_t *) (block_data + 128 + 64 + j)))); + } + uint8_t * ql = block_data; + uint8_t * qh = block_data + 128; + auto * signed_weights = reinterpret_cast(weights); + for (int64_t j = 0; j < 32; ++j) { + signed_weights[i * 256 + j] = static_cast((ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 32] = + static_cast((ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 64] = static_cast((ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 96] = + static_cast((ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 128] = + static_cast((ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 160] = + static_cast((ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 192] = + static_cast((ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 224] = + static_cast((ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4)) - 32; + } + }); + } } static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) { @@ -389,11 +414,10 @@ ov::Output make_int8_weights(ov::Tensor & weight, size_t group_size, bool use_bias) { ov::Shape orig_shape = weight.get_shape(); + bool is_signed = (weight.get_element_type() == ov::element::i8); // Symmetric: signed weights, no ZP // Expand dimensions for scales and zp/bias auto scale_shape = scales.get_shape(); - auto zp_shape = zp.get_shape(); - bool is_scalar_zp = zp_shape.empty(); // Symmetric quantization ov::Shape packed_shape = {orig_shape[0], orig_shape[1] / group_size, group_size}; @@ -403,37 +427,48 @@ ov::Output make_int8_weights(ov::Tensor & weight, } else { scale_shape.push_back(1); scales.set_shape(scale_shape); - // For symmetric quantization, zp remains scalar (don't resize) - if (!is_scalar_zp) { + if (!is_signed && zp.get_size() > 0) { + auto zp_shape = zp.get_shape(); zp_shape.push_back(1); zp.set_shape(zp_shape); } } - // Create graph nodes - auto weights_node = std::make_shared(ov::element::u8, packed_shape, - static_cast(weight.data()), nullptr); - weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; auto scales_f16 = std::make_shared(scales); - auto weights_f16 = std::make_shared(weights_node, ov::element::f16); ov::Output result; - if (use_bias && !is_scalar_zp) { - // Bias path: w * s + b (zp tensor holds f16 bias values) - auto bias_f16 = std::make_shared(zp); - auto w_s = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + if (is_signed) { + // Signed path: q * s (no zero point subtraction needed) + auto weights_node = std::make_shared(ov::element::i8, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + result = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); } else { - // Zero point path: (w - zp) * s - auto zero_point = 
std::make_shared(zp); - float zp_value; - if (ov::op::util::get_single_value(zero_point, zp_value)) { - zero_point = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value}); + // Unsigned path + auto weights_node = std::make_shared(ov::element::u8, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + + if (use_bias && zp.get_size() > 0) { + // Bias path: w * s + b (zp tensor holds f16 bias values) + auto bias_f16 = std::make_shared(zp); + auto w_s = + std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + } else { + // Zero point path: (w - zp) * s + auto zero_point = std::make_shared(zp); + float zp_value; + if (ov::op::util::get_single_value(zero_point, zp_value)) { + zero_point = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value}); + } + auto zero_point_f16 = std::make_shared(zero_point, ov::element::f16); + auto w_zp = + std::make_shared(weights_f16, zero_point_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } - auto zero_point_f16 = std::make_shared(zero_point, ov::element::f16); - auto w_zp = - std::make_shared(weights_f16, zero_point_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } if (packed_shape.size() != 2) { @@ -452,11 +487,10 @@ ov::Output make_int4_weights(ov::Tensor & weight, size_t group_size, bool use_bias) { ov::Shape orig_weight_shape = weight.get_shape(); + bool is_signed = (weight.get_element_type() == ov::element::i4); // Symmetric: signed weights, no ZP // Expand dimensions for scales and zp/bias ov::Shape scale_shape = scales.get_shape(); - auto zp_shape = zp.get_shape(); - bool is_scalar_zp = zp_shape.empty(); // Symmetric quantization // Create INT4 weight tensor ov::Shape packed_shape = {orig_weight_shape[0], orig_weight_shape[1] / group_size, group_size}; @@ -467,36 +501,48 @@ ov::Output make_int4_weights(ov::Tensor & weight, } else { scale_shape.push_back(1); scales.set_shape(scale_shape); - // For symmetric quantization, zp remains scalar (don't resize) - if (!is_scalar_zp) { + if (!is_signed && zp.get_size() > 0) { + auto zp_shape = zp.get_shape(); zp_shape.push_back(1); zp.set_shape(zp_shape); } } - auto weights_node = std::make_shared(ov::element::u4, packed_shape, - static_cast(weight.data()), nullptr); - weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; - auto weights_f16 = std::make_shared(weights_node, ov::element::f16); auto scales_f16 = std::make_shared(scales); ov::Output result; - if (use_bias && !is_scalar_zp) { - // Bias path: w * s + b (zp tensor holds f16 bias values) - auto bias_f16 = std::make_shared(zp); - auto w_s = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + if (is_signed) { + // Signed path: q * s (no zero point subtraction needed) + auto weights_node = std::make_shared(ov::element::i4, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + result = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); } else { - // Zero point path: (w 
- zp) * s - auto zero_points_node = std::make_shared(zp); - float zp_value; - if (ov::op::util::get_single_value(zero_points_node, zp_value)) { - zero_points_node = ov::op::v0::Constant::create(zero_points_node->get_element_type(), {}, {zp_value}); + // Unsigned path + auto weights_node = std::make_shared(ov::element::u4, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + + if (use_bias && zp.get_size() > 0) { + // Bias path: w * s + b (zp tensor holds f16 bias values) + auto bias_f16 = std::make_shared(zp); + auto w_s = + std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + } else { + // Zero point path: (w - zp) * s + auto zero_points_node = std::make_shared(zp); + float zp_value; + if (ov::op::util::get_single_value(zero_points_node, zp_value)) { + zero_points_node = ov::op::v0::Constant::create(zero_points_node->get_element_type(), {}, {zp_value}); + } + auto zero_points_f16 = std::make_shared(zero_points_node, ov::element::f16); + auto w_zp = + std::make_shared(weights_f16, zero_points_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } - auto zero_points_f16 = std::make_shared(zero_points_node, ov::element::f16); - auto w_zp = - std::make_shared(weights_f16, zero_points_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } if (packed_shape.size() != 2) { @@ -699,24 +745,32 @@ OvWeight process_weight_tensor(const ggml_tensor * tensor, const void * data, vo // Quantized path (normal extraction or quantized requant) // Create weight/scale/zp tensors - shared between both paths - ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + // For symmetric quantization, use signed types (i4/i8) and no ZP tensor + ov::element::Type weight_type = layout.is_symmetric ? (layout.is_u4 ? ov::element::i4 : ov::element::i8) : + (layout.is_u4 ? ov::element::u4 : ov::element::u8); ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block}; - ov::Shape zp_shape = layout.is_symmetric ? ov::Shape{} : scale_shape; if (output_base_ptr) { uint8_t * buf_base = static_cast(output_base_ptr); result.weights = ov::Tensor(weight_type, node_shape, buf_base + layout.weights_offset); result.scales = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.scales_offset); - result.zp = ov::Tensor(weight_type, zp_shape, buf_base + layout.zp_offset); + if (!layout.is_symmetric) { + ov::element::Type zp_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + result.zp = ov::Tensor(zp_type, scale_shape, buf_base + layout.zp_offset); + } + // else: result.zp remains default-constructed (empty) for symmetric } else { result.weights = ov::Tensor(weight_type, node_shape); result.scales = ov::Tensor(ov::element::f16, scale_shape); - if (use_bias && !layout.is_symmetric) { - // bias only has effect for asymmetric quant - result.zp = ov::Tensor(ov::element::f16, zp_shape); - } else { - result.zp = ov::Tensor(weight_type, zp_shape); + if (!layout.is_symmetric) { + if (use_bias) { + result.zp = ov::Tensor(ov::element::f16, scale_shape); + } else { + ov::element::Type zp_type = layout.is_u4 ? 
ov::element::u4 : ov::element::u8; + result.zp = ov::Tensor(zp_type, scale_shape); + } } + // else: result.zp remains default-constructed (empty) for symmetric } if (layout.is_requant && layout.requant_type.has_value()) { @@ -741,59 +795,75 @@ void quantize_q4_0(const float * x, auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i4); // Signed i4 path - // For Q4_0, zero point is always 8 - if (is_scalar_zp) { - zp[0] = 8 | (8 << 4); // Pack two 4-bit values - } - - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < qk; j++) { - const float v = x[i * qk + j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - max = v; + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + float max = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } } - } - - const float d = max / -8; - - if (d == 0) { - scales[i] = ov::float16(1.0f); - // zp is already set to 8 for symmetric, or set per-block for asymmetric - if (!is_scalar_zp) { + const float d = max / -8; + if (d == 0) { + scales[i] = ov::float16(1.0f); if (i % 2 == 0) { zp[i / 2] = 8; } else { zp[i / 2] |= (8 << 4); } + memset(weights + i * qk / 2, 8 | (8 << 4), qk / 2); + continue; } - memset(weights + i * qk / 2, 8 | (8 << 4), qk / 2); - continue; - } - - const float id = 1.0f / d; - scales[i] = ov::float16(d); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { + const float id = 1.0f / d; + scales[i] = ov::float16(d); if (i % 2 == 0) { zp[i / 2] = 8; } else { zp[i / 2] |= (8 << 4); } + for (int j = 0; j < qk / 2; ++j) { + const float x0 = x[i * qk + 2 * j] * id; + const float x1 = x[i * qk + 2 * j + 1] * id; + const uint8_t xi0 = MIN(15, (int8_t) (x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t) (x1 + 8.5f)); + weights[i * qk / 2 + j] = xi0 | (xi1 << 4); + } } - - for (int j = 0; j < qk / 2; ++j) { - const float x0 = x[i * qk + 2 * j] * id; - const float x1 = x[i * qk + 2 * j + 1] * id; - const uint8_t xi0 = MIN(15, (int8_t) (x0 + 8.5f)); - const uint8_t xi1 = MIN(15, (int8_t) (x1 + 8.5f)); - weights[i * qk / 2 + j] = xi0 | (xi1 << 4); + } else { + // Symmetric: produce signed i4 values in [-8, 7] + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + float max = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + const float d = max / -8; + if (d == 0) { + scales[i] = ov::float16(1.0f); + // i4 value 0 packed: 0x00 + memset(weights + i * qk / 2, 0, qk / 2); + continue; + } + const float id = 1.0f / d; + scales[i] = ov::float16(d); + for (int j = 0; j < qk / 2; ++j) { + const float x0 = x[i * qk + 2 * j] * id; + const float x1 = x[i * qk + 2 * j + 1] * id; + // Signed i4: range [-8, 7]. Quantize as round(x*id), then pack as 4-bit two's complement. 
+ int8_t si0 = (int8_t) std::max(-8, std::min(7, (int) roundf(x0))); + int8_t si1 = (int8_t) std::max(-8, std::min(7, (int) roundf(x1))); + weights[i * qk / 2 + j] = (si0 & 0x0F) | ((si1 & 0x0F) << 4); + } } } } @@ -809,36 +879,42 @@ void quantize_q8_0(const float * x, auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i8); // Signed i8 path - // For Q8_0, zero point is always 128 - if (is_scalar_zp) { - zp[0] = 128; - } - - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - - for (int j = 0; j < qk; j++) { - const float v = x[i * qk + j]; - if (amax < fabsf(v)) { - amax = fabsf(v); + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + amax = std::max(amax, fabsf(v)); + } + const float d = amax / 127.0f; + const float id = d ? 1.0f / d : 0.0f; + scales[i] = ov::float16(d); + zp[i] = 128; + for (int j = 0; j < qk; ++j) { + const float x0 = x[i * qk + j] * id; + const int8_t xi0 = roundf(x0); + weights[i * qk + j] = (uint8_t) (xi0 + 128); } } - - const float d = amax / 127.0f; - const float id = d ? 1.0f / d : 0.0f; - scales[i] = ov::float16(d); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { - zp[i] = 128; - } - - for (int j = 0; j < qk; ++j) { - const float x0 = x[i * qk + j] * id; - const int8_t xi0 = roundf(x0); - weights[i * qk + j] = (uint8_t) (xi0 + 128); + } else { + // Symmetric: store signed int8 values directly + auto * signed_weights = reinterpret_cast(weights); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + amax = std::max(amax, fabsf(v)); + } + const float d = amax / 127.0f; + const float id = d ? 
1.0f / d : 0.0f; + scales[i] = ov::float16(d); + for (int j = 0; j < qk; ++j) { + const float x0 = x[i * qk + j] * id; + signed_weights[i * qk + j] = (int8_t) roundf(x0); + } } } } @@ -861,12 +937,8 @@ void quantize_q8_1(const float * x, for (int j = 0; j < qk; j++) { const float v = x[i * qk + j]; - if (v < min) { - min = v; - } - if (v > max) { - max = v; - } + min = std::min(v, min); + max = std::max(v, max); } const float d = (max - min) / ((1 << 8) - 1); diff --git a/ggml/src/ggml-openvino/openvino/op/rope.cpp b/ggml/src/ggml-openvino/openvino/op/rope.cpp index 26dc2d24f..a8db9b389 100644 --- a/ggml/src/ggml-openvino/openvino/op/rope.cpp +++ b/ggml/src/ggml-openvino/openvino/op/rope.cpp @@ -9,12 +9,17 @@ #include #include #include +#include +#include +#include #include #include #include +#include #include #include #include +#include #include #include @@ -33,6 +38,12 @@ OutputVector translate_rope(const NodeContext & context) { auto data_node = context.get_input(0).get_node_shared_ptr(); auto output_shape = context.get_output_shape().to_shape(); int32_t * op_params = context.get_output_op_params(); + const int mode = (op_case & 0xFFFF0000) >> 16; + op_case = (op_case & 0x0000FFFF); + + constexpr int TYPE_NORMAL = 0; + constexpr int TYPE_NEOX = 1; + constexpr int TYPE_IMROPE = 2; Output cos_theta_node; Output sin_theta_node; @@ -45,7 +56,7 @@ OutputVector translate_rope(const NodeContext & context) { if (context.get_input_size() == 3) { rope_freqs_weight = context.get_input(2).get_node_shared_ptr(); } - auto sin_cos = make_sin_cos(op_params, inp_pos, rope_freqs_weight); + auto sin_cos = make_sin_cos(op_params, inp_pos, rope_freqs_weight, mode == TYPE_IMROPE); sin_theta_node = sin_cos.first; cos_theta_node = sin_cos.second; } @@ -65,11 +76,7 @@ OutputVector translate_rope(const NodeContext & context) { } } - const int mode = op_params[2]; - constexpr int ROPE_TYPE_NORMAL = 0; - constexpr int ROPE_TYPE_NEOX = 2; - - if (mode == ROPE_TYPE_NORMAL) { + if (mode == TYPE_NORMAL) { auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); @@ -97,7 +104,7 @@ OutputVector translate_rope(const NodeContext & context) { auto data_shape = ov::op::v0::Constant::create( ov::element::i64, {4}, std::vector{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); res = std::make_shared(stack, data_shape, false); - } else if (mode == ROPE_TYPE_NEOX) { + } else if (mode == TYPE_NEOX) { auto data_split = std::make_shared( data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}), 2); Output slice_data_node_0 = data_split->outputs()[0]; @@ -112,6 +119,25 @@ OutputVector translate_rope(const NodeContext & context) { std::make_shared(slice_data_node_1, cos_theta_node)); res = std::make_shared(ov::OutputVector{first_half_node, second_half_node}, -1); + } else if (mode == TYPE_IMROPE) { + int64_t n_dims = data_node->get_shape()[3]; + auto cos_sin_shape = std::make_shared(ov::element::i64, ov::Shape{4}, std::vector{1,-1,1,(n_dims >> 1)}); + auto cos_reshaped = std::make_shared(cos_theta_node, cos_sin_shape, true); + auto sin_reshaped = std::make_shared(sin_theta_node, cos_sin_shape, true); + + auto split_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {3}); + auto split_a = std::make_shared(data_node, split_axis, 2); + auto x0 = split_a->output(0); + auto x1 = split_a->output(1); + auto mul_a = std::make_shared(x0, 
cos_reshaped); + auto mul_b = std::make_shared(x1, sin_reshaped); + auto sub = std::make_shared(mul_a, mul_b); + + auto mul_c = std::make_shared(x0, sin_reshaped); + auto mul_d = std::make_shared(x1, cos_reshaped); + auto add = std::make_shared(mul_c, mul_d); + + res = std::make_shared(ov::OutputVector{sub, add}, 3); } return rename_outputs_with_suffix({res}, context.get_name()); diff --git a/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp b/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp new file mode 100644 index 000000000..d1e9efc33 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp @@ -0,0 +1,25 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_unary_gelu(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + auto input = context.get_input(0); + auto res = std::make_shared(input); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op_table.cpp b/ggml/src/ggml-openvino/openvino/op_table.cpp index beadafe81..138553927 100644 --- a/ggml/src/ggml-openvino/openvino/op_table.cpp +++ b/ggml/src/ggml-openvino/openvino/op_table.cpp @@ -31,6 +31,7 @@ std::unordered_map get_supported_ops() { {"GGML_OP_SOFT_MAX", op::translate_soft_max }, {"GGML_OP_SUB", op::translate_1to1_match_2_inputs}, {"GGML_OP_TRANSPOSE", op::translate_transpose }, + {"GGML_UNARY_OP_GELU", op::translate_unary_gelu }, {"GGML_UNARY_OP_SILU", op::translate_unary_silu }, {"GGML_OP_VIEW", op::translate_view }, {"GGML_GLU_OP_SWIGLU", op::translate_glu_swiglu }, diff --git a/ggml/src/ggml-openvino/openvino/op_table.h b/ggml/src/ggml-openvino/openvino/op_table.h index 37f763117..f546796d2 100644 --- a/ggml/src/ggml-openvino/openvino/op_table.h +++ b/ggml/src/ggml-openvino/openvino/op_table.h @@ -21,6 +21,7 @@ GGML_OP_CONVERTER(translate_rms_norm); GGML_OP_CONVERTER(translate_rope); GGML_OP_CONVERTER(translate_scale); GGML_OP_CONVERTER(translate_unary_silu); +GGML_OP_CONVERTER(translate_unary_gelu); GGML_OP_CONVERTER(translate_soft_max); GGML_OP_CONVERTER(translate_transpose); GGML_OP_CONVERTER(translate_view); diff --git a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp deleted file mode 100644 index ed2a3ab6d..000000000 --- a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +++ /dev/null @@ -1,123 +0,0 @@ -#include "eliminate_zp.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace ov { -namespace frontend { -namespace ggml { -namespace pass { - -EliminateZeroPoints::EliminateZeroPoints() { - // Find pattern: - // (Multiply Any(scale) - // (Subtract (Convert Constant(data))) - // (Convert Constant(zero_point))) - // where zero_point is a scalar - // If data is u4 and zp value is 8 (q4_0), Replace the Subtract with an i4 Constant whose value is data - zp_val - // If data is u8 and zp value is 128 (q8_0) or 32 (q6_k), Replace the Subtract with an i8 Constant - - auto m_data_constant = ov::pass::pattern::wrap_type(); - auto m_data_convert = ov::pass::pattern::wrap_type({m_data_constant}); - - auto m_zp_constant = ov::pass::pattern::wrap_type(); - auto m_zp_convert = ov::pass::pattern::wrap_type({m_zp_constant}); - - auto m_subtract = 
ov::pass::pattern::wrap_type({m_data_convert, m_zp_convert}); - auto m_scale = ov::pass::pattern::any_input(); - auto m_multiply = ov::pass::pattern::wrap_type({m_scale, m_subtract}); - - const auto callback = [=](ov::pass::pattern::Matcher & m) { - const auto & pattern_map = m.get_pattern_value_map(); - - auto multiply_node = - std::dynamic_pointer_cast(pattern_map.at(m_multiply).get_node_shared_ptr()); - auto subtract_node = - std::dynamic_pointer_cast(pattern_map.at(m_subtract).get_node_shared_ptr()); - auto data_constant = - std::dynamic_pointer_cast(pattern_map.at(m_data_constant).get_node_shared_ptr()); - auto zp_constant = - std::dynamic_pointer_cast(pattern_map.at(m_zp_constant).get_node_shared_ptr()); - - if (!multiply_node || !subtract_node || !data_constant || !zp_constant) { - return false; - } - - if (ov::shape_size(zp_constant->get_shape()) != 1) { - return false; - } - - auto data_type = data_constant->get_element_type(); - auto zp_data = zp_constant->cast_vector(); - - if (zp_data.empty()) { - return false; - } - - int zp_value = zp_data[0]; - - bool should_eliminate = false; - ov::element::Type target_type; - - if (data_type == ov::element::u4 && zp_value == 8) { - should_eliminate = true; - target_type = ov::element::i4; - } else if (data_type == ov::element::u8 && (zp_value == 128 || zp_value == 32)) { - should_eliminate = true; - target_type = ov::element::i8; - } - - if (!should_eliminate) { - return false; - } - - auto data_shape = data_constant->get_shape(); - size_t total_elements = ov::shape_size(data_shape); - - std::shared_ptr new_constant; - - // TODO improve performance - if (data_type == ov::element::u4) { - auto data_values = data_constant->cast_vector(); - std::vector adjusted_values(total_elements); - - ov::parallel_for(total_elements, [&](size_t i) { - adjusted_values[i] = static_cast(static_cast(data_values[i]) - 8); - }); - - new_constant = std::make_shared(target_type, data_shape, adjusted_values); - } else if (data_type == ov::element::u8) { - auto data_values = data_constant->cast_vector(); - std::vector adjusted_values(total_elements); - - ov::parallel_for(total_elements, [&, zp_value](size_t i) { - adjusted_values[i] = static_cast(static_cast(data_values[i]) - zp_value); - }); - - new_constant = std::make_shared(target_type, data_shape, adjusted_values); - } - - auto new_convert = - std::make_shared(new_constant, subtract_node->get_output_element_type(0)); - ov::replace_node(subtract_node, new_convert); - - return true; - }; - - register_matcher( - std::make_shared(m_multiply, "ov::frontend::ggml::pass::EliminateZeroPoints"), - callback); -} - -} // namespace pass -} // namespace ggml -} // namespace frontend -} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h deleted file mode 100644 index edd3cd718..000000000 --- a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +++ /dev/null @@ -1,17 +0,0 @@ -#include "openvino/pass/matcher_pass.hpp" - -namespace ov { -namespace frontend { -namespace ggml { -namespace pass { - -class EliminateZeroPoints : public ov::pass::MatcherPass { -public: - OPENVINO_MATCHER_PASS_RTTI("ov::frontend::ggml::pass::EliminateZeroPoints") - EliminateZeroPoints(); -}; - -} // namespace pass -} // namespace ggml -} // namespace frontend -} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp b/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp new file mode 100644 index 
000000000..f051891c4 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp @@ -0,0 +1,41 @@ +// Copyright (C) 2018-2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace ov { + +/** + * @brief Holds weightless caching attributes of a single constant. + * + * WeightlessCacheAttribute class represents runtime info attribute that holds + * the values of original size of the constant in bytes and the binary offset of the + * constant's data in the weights file used by the weightless caching mechanism. It's + * not copyable in case the data was changed (the original node was replaced by a new + * one produced during the tranformation pipeline) - in that case weightless caching + * can't be used for that constant. + */ +class OPENVINO_API WeightlessCacheAttribute : public RuntimeAttribute { +public: + OPENVINO_RTTI("WeightlessCacheAttribute", "0", RuntimeAttribute) + + WeightlessCacheAttribute() = delete; + + WeightlessCacheAttribute(size_t original_size, size_t bin_offset, ov::element::Type original_dtype) + : original_size(original_size), + bin_offset(bin_offset), + original_dtype(original_dtype) {} + + bool is_copyable() const override; + + size_t original_size; + size_t bin_offset; + ov::element::Type original_dtype; +}; + +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp index 23a1dea24..0f68a1f50 100644 --- a/ggml/src/ggml-openvino/openvino/translate_session.cpp +++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp @@ -3,15 +3,16 @@ #include "ggml-openvino/openvino/node_context.h" #include "ggml-openvino/openvino/utils.h" #include "input_model.h" -#include "pass/eliminate_zp.h" #include "pass/mark_decompression_convert_constant_folding.h" #include "pass/squeeze_matmul.h" +#include "rt_info/weightless_caching_attributes.hpp" #include #include #include #include #include +#include #include #include #include @@ -33,7 +34,6 @@ #include #include #include -#include namespace ov { namespace frontend { @@ -240,6 +240,31 @@ std::shared_ptr TranslateSession::translate_graph(const frontend::InputMo resulting_model = std::make_shared(results, used_params); apply_transformations(resulting_model); + + // Set WeightlessCacheAttribute on large constants to avoid unnecessary memory copies + // in the NPUW plugin. Without this attribute, NPUW's LazyTensor constructor + // (lazy_tensor.cpp, op::Const::Const) will memcpy every constant "in case export + // occurs", doubling memory usage per compile_model call. + // + // The bin_offset field serves as a unique key (not a real file offset) — this is + // the same convention the GPU plugin uses for non-IR models (see + // Plugin::set_weightless_cache_attributes in intel_gpu/src/plugin/plugin.cpp). + // Each constant must have a distinct bin_offset, otherwise GPU's weightless cache + // import will map multiple constants to the same data. + // + // Small constants (< 16 elements) are excluded since they may be introduced by + // optimization patterns and the overhead is negligible. 
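The loop that follows is what writes the attribute; for completeness, here is a sketch of how the unique-key convention could be inspected afterwards. This is a hypothetical helper, not part of the patch, and it assumes the tagging loop below has already run and that `ov::WeightlessCacheAttribute` from the new rt_info header is in scope:

```cpp
// Hypothetical verification helper (not part of the patch): prints the
// bin_offset "key" of every constant tagged with WeightlessCacheAttribute.
#include <iostream>
#include <memory>
#include <openvino/openvino.hpp>

static void dump_weightless_cache_keys(const std::shared_ptr<ov::Model> & model) {
    for (const auto & node : model->get_ordered_ops()) {
        auto cnst = ov::as_type_ptr<ov::op::v0::Constant>(node);
        if (!cnst) {
            continue;
        }
        const auto & rt_info = cnst->get_rt_info();
        auto it = rt_info.find(ov::WeightlessCacheAttribute::get_type_info_static());
        if (it != rt_info.end()) {
            const auto & attr = it->second.as<ov::WeightlessCacheAttribute>();
            // bin_offset is used as a unique key here, not a real file offset
            std::cout << cnst->get_friendly_name() << " -> key " << attr.bin_offset << "\n";
        }
    }
}
```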
+ size_t offset = 0; + for (auto & node : resulting_model->get_ordered_ops()) { + if (auto cnst = ov::as_type_ptr(node); + cnst && cnst->get_byte_size() / cnst->get_element_type().size() >= 16) { + auto & rt_info = cnst->get_rt_info(); + if (rt_info.find(ov::WeightlessCacheAttribute::get_type_info_static()) == rt_info.end()) { + rt_info[ov::WeightlessCacheAttribute::get_type_info_static()] = + ov::WeightlessCacheAttribute(cnst->get_byte_size(), offset++, cnst->get_element_type()); + } + } + } return resulting_model; } @@ -257,7 +282,6 @@ std::shared_ptr TranslateSession::apply_transformations(std::shared_ptris_static()) { - manager.register_pass(); manager.register_pass(); } manager.run_passes(model); diff --git a/ggml/src/ggml-openvino/openvino/utils.cpp b/ggml/src/ggml-openvino/openvino/utils.cpp index 65356a51b..0baaf88e1 100644 --- a/ggml/src/ggml-openvino/openvino/utils.cpp +++ b/ggml/src/ggml-openvino/openvino/utils.cpp @@ -2,6 +2,7 @@ #include "ggml-impl.h" +#include #include #include #include @@ -13,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -87,8 +89,11 @@ ov::Output rope_yarn_ramp_mix(int n_dims, const float corr_dims[2], fl auto ramp_y = std::make_shared(std::make_shared(dim_ids, corr_low), denom); auto ramp_clamped = std::make_shared(ramp_y, 0.0f, 1.0f); + // rope_yarn_ramp returns (1 - clamp(y)), so invert before scaling + auto one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + auto ramp_inverted = std::make_shared(one, ramp_clamped); auto ext_factor_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {ext_factor}); - auto ramp_mix = std::make_shared(ramp_clamped, ext_factor_node); + auto ramp_mix = std::make_shared(ramp_inverted, ext_factor_node); return ramp_mix; } @@ -115,6 +120,7 @@ void ggml_rope_yarn_corr_dims(int n_dims, std::pair, ov::Output> make_sin_cos(int32_t * rope_params, std::shared_ptr inp_pos, std::shared_ptr rope_freqs_weight, + bool imrope, bool stateful) { if (stateful) { inp_pos = std::make_shared(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); @@ -122,6 +128,13 @@ std::pair, ov::Output> make_sin_cos(int32_t * rope_params auto pos_perm = std::make_shared(ov::element::i64, ov::Shape{3}, std::vector{2, 1, 0}); inp_pos = std::make_shared(inp_pos, pos_perm); + } else if (imrope) { + inp_pos = std::make_shared(inp_pos, ov::element::f32); + auto pos_shape = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{5}, {0, 0, 0, 4, -1}); + inp_pos = std::make_shared(inp_pos, pos_shape, true); + auto pos_transpose_shape = + std::make_shared(ov::element::i64, ov::Shape{5}, std::vector{0, 1, 2, 4, 3}); + inp_pos = std::make_shared(inp_pos, pos_transpose_shape); } else { inp_pos = std::make_shared(inp_pos, ov::element::f32); auto pos_perm = @@ -136,6 +149,7 @@ std::pair, ov::Output> make_sin_cos(int32_t * rope_params float beta_fast; float beta_slow; const int n_dims = rope_params[1]; + const size_t n_dims_half = n_dims >> 1; const int n_ctx_orig = rope_params[4]; memcpy(&freq_base, rope_params + 5, sizeof(float)); memcpy(&freq_scale, rope_params + 6, sizeof(float)); @@ -146,57 +160,74 @@ std::pair, ov::Output> make_sin_cos(int32_t * rope_params const float theta_scale = powf(freq_base, -2.0f / n_dims); - float corr_dims[2]; - ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - - std::vector factor(n_dims / 2); - factor[0] = 1.0f; - for (size_t i = 1; i < factor.size(); i++) { - factor[i] = theta_scale * factor[i - 1]; - } + 
std::vector factor(n_dims_half); Output freq_factors; - if (stateful) { - freq_factors = - std::make_shared(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor); - } else { - freq_factors = - std::make_shared(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor); - } - if (rope_freqs_weight) { - freq_factors = std::make_shared(freq_factors, rope_freqs_weight); - } - - auto theta_extrap = std::make_shared(freq_factors, inp_pos); - auto theta_interp = std::make_shared( - theta_extrap, ov::op::v0::Constant::create(ov::element::f32, {1}, {freq_scale})); Output theta; float mscale = attn_factor; - if (ext_factor == 0.0f) { - theta = theta_interp; - } else { - auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor); - Output one; - if (stateful) { - one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f}); - } else { - one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + if (imrope) { + std::vector gather_indices(n_dims_half); + for (size_t j = 0; j < n_dims_half; j++) { + gather_indices[j] = j % 3; + factor[j] = std::pow(theta_scale, j); + } + auto gather_indices_const = + std::make_shared(ov::element::i64, ov::Shape{n_dims_half}, gather_indices); + auto gather_axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {4}); + inp_pos = std::make_shared(inp_pos, gather_indices_const, gather_axis); + auto factor_const = std::make_shared(ov::element::f32, ov::Shape{n_dims_half}, factor); + theta = std::make_shared(inp_pos, factor_const); + } else { + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + factor[0] = 1.0f; + for (size_t i = 1; i < factor.size(); i++) { + factor[i] = theta_scale * factor[i - 1]; + } + if (stateful) { + freq_factors = + std::make_shared(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor); + } else { + freq_factors = + std::make_shared(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor); + } + if (rope_freqs_weight) { + freq_factors = std::make_shared(freq_factors, rope_freqs_weight); } - auto one_minus_ramp = std::make_shared(one, ramp_mix); - theta = std::make_shared(std::make_shared(theta_interp, one_minus_ramp), - std::make_shared(theta_extrap, ramp_mix)); - mscale *= (1.0f + 0.1f * std::log(1.0f / freq_scale)); + auto theta_extrap = std::make_shared(freq_factors, inp_pos); + auto theta_interp = std::make_shared( + theta_extrap, ov::op::v0::Constant::create(ov::element::f32, {1}, {freq_scale})); + + if (ext_factor == 0.0f) { + theta = theta_interp; + } else { + auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor); + Output one; + if (stateful) { + one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f}); + } else { + one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + } + auto one_minus_ramp = std::make_shared(one, ramp_mix); + + theta = std::make_shared(std::make_shared(theta_interp, one_minus_ramp), + std::make_shared(theta_extrap, ramp_mix)); + mscale *= (1.0f + 0.1f * std::log(1.0f / freq_scale)); + } } Output cos_theta = std::make_shared(theta); Output sin_theta = std::make_shared(theta); - auto mscale_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {mscale}); + if (!imrope) { + auto mscale_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {mscale}); + + cos_theta = std::make_shared(cos_theta, mscale_node); + sin_theta = std::make_shared(sin_theta, mscale_node); + } - cos_theta = std::make_shared(cos_theta, 
mscale_node); - sin_theta = std::make_shared(sin_theta, mscale_node); return std::make_pair(sin_theta, cos_theta); } diff --git a/ggml/src/ggml-openvino/openvino/utils.h b/ggml/src/ggml-openvino/openvino/utils.h index 88dcad4c9..767dd4c53 100644 --- a/ggml/src/ggml-openvino/openvino/utils.h +++ b/ggml/src/ggml-openvino/openvino/utils.h @@ -67,6 +67,7 @@ OutputVector rename_outputs_with_suffix(const OutputVector& outputs, const std:: std::pair, ov::Output> make_sin_cos(int32_t* rope_params, std::shared_ptr inp_pos, std::shared_ptr rope_freqs_weight = nullptr, + bool imrope = false, bool stateful = false); ov::Output process_view_input(const NodeContext& context, int input_index, int slice_len = 0); diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 1b553a0de..998ef7c9e 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -81,8 +81,8 @@ ov::Tensor create_ov_output_tensor(std::shared_ptr ggml_decoder, enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr r_ctx) { auto & core = ov_singleton_core(); const auto & config = ggml_openvino_get_compile_config(); - auto device = r_ctx->device; - bool stateful = r_ctx->stateful; + const auto & device = r_ctx->device; + const auto & stateful = r_ctx->stateful; static auto is_static = false; if (is_naive(cgraph)) { @@ -106,14 +106,26 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< int64_t infer_end_time; { - std::lock_guard lock(r_ctx->ov_compute_mutex); - - auto it = r_ctx->decoder_cache.find(key); - - cache_hit = it != r_ctx->decoder_cache.end(); + std::shared_ptr entry; ModelParams old_m_params; + + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + auto it = r_ctx->decoder_cache.find(key); + cache_hit = it != r_ctx->decoder_cache.end(); + if (cache_hit) { + entry = it->second; + } else { + auto mutex = std::make_shared(); + entry = std::make_shared(mutex); + r_ctx->decoder_cache[key] = entry; + } + } + + std::lock_guard lock(*(entry->mutex)); + if (cache_hit) { - ggml_decoder = it->second; + ggml_decoder = entry->ptr; old_m_params = ggml_decoder->get_model_params(); cache_hit = old_m_params.can_reuse_dynamically(m_params); } @@ -126,7 +138,10 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< ggml_decoder->update_io(cgraph); } ggml_decoder->add_extra_inputs(); - infer_request = r_ctx->infer_request_cache.at(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + infer_request = r_ctx->infer_request_cache.at(key); + } if (stateful) { const auto * inp_pos = get_inp_pos_tensor(cgraph); @@ -170,7 +185,10 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< conversion_end_time = decoder_end_time; compile_end_time = decoder_end_time; } else { - r_ctx->infer_request_cache.erase(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache.erase(key); + } std::shared_ptr model; auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); @@ -199,8 +217,7 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< } compile_end_time = ggml_time_us(); infer_request = std::make_shared(compiled_model.create_infer_request()); - r_ctx->infer_request_cache[key] = infer_request; - r_ctx->decoder_cache[key] = ggml_decoder; + entry->ptr = ggml_decoder; std::vector ov_input_names; std::vector ov_output_names; @@ -210,8 +227,13 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< for (const auto 
& ov_output : model->get_results()) { ov_output_names.push_back(ov_output->get_friendly_name()); } - r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); - r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache[key] = infer_request; + r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); + r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + } if (stateful) { const auto * inp_pos = get_inp_pos_tensor(cgraph); @@ -224,8 +246,13 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< } } - auto ov_input_names = r_ctx->ov_input_names_cache[key]; - auto ov_output_names = r_ctx->ov_output_names_cache[key]; + std::vector ov_input_names; + std::vector ov_output_names; + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + ov_input_names = r_ctx->ov_input_names_cache[key]; + ov_output_names = r_ctx->ov_output_names_cache[key]; + } for (size_t i = 0; i < ov_input_names.size(); i++) { auto param_name = ov_input_names[i]; @@ -306,12 +333,26 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrdecoder_cache.find(key); - - cache_hit = it != r_ctx->decoder_cache.end(); + std::shared_ptr entry; ModelParams old_m_params; + + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + auto it = r_ctx->decoder_cache.find(key); + cache_hit = it != r_ctx->decoder_cache.end(); + if (cache_hit) { + entry = it->second; + } else { + auto mutex = std::make_shared(); + entry = std::make_shared(mutex); + r_ctx->decoder_cache[key] = entry; + } + } + + std::lock_guard lock(*(entry->mutex)); + if (cache_hit) { - ggml_decoder = it->second; + ggml_decoder = entry->ptr; old_m_params = ggml_decoder->get_model_params(); cache_hit = old_m_params.can_reuse_statically(m_params); } @@ -325,14 +366,21 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrupdate_io(cgraph); } ggml_decoder->add_extra_inputs(); - infer_request = is_prefill ? r_ctx->infer_request_cache_prefill.at(key) : r_ctx->infer_request_cache.at(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + infer_request = + is_prefill ? r_ctx->infer_request_cache_prefill.at(key) : r_ctx->infer_request_cache.at(key); + } decoder_end_time = ggml_time_us(); conversion_end_time = decoder_end_time; compile_end_time = decoder_end_time; } else { - r_ctx->infer_request_cache.erase(key); - r_ctx->infer_request_cache_prefill.erase(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache.erase(key); + r_ctx->infer_request_cache_prefill.erase(key); + } std::shared_ptr model; auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); @@ -372,16 +420,14 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrinfer_request_cache_prefill[key] = - std::make_shared(compiled_model_prefill.create_infer_request()); - r_ctx->infer_request_cache[key] = - std::make_shared(compiled_model_decode.create_infer_request()); + auto infer_request_prefill = std::make_shared(compiled_model_prefill.create_infer_request()); + auto infer_request_decode = std::make_shared(compiled_model_decode.create_infer_request()); compile_end_time = ggml_time_us(); model = is_prefill ? model_prefill : model_decode; ggml_decoder = is_prefill ? ggml_decoder_prefill : ggml_decoder_decode; - infer_request = is_prefill ? r_ctx->infer_request_cache_prefill[key] : r_ctx->infer_request_cache[key]; - r_ctx->decoder_cache[key] = ggml_decoder; + infer_request = is_prefill ? 
infer_request_prefill : infer_request_decode; + entry->ptr = ggml_decoder; std::vector ov_input_names; std::vector ov_output_names; @@ -391,18 +437,29 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_results()) { ov_output_names.push_back(ov_output->get_friendly_name()); } - r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); - r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache_prefill[key] = infer_request_prefill; + r_ctx->infer_request_cache[key] = infer_request_decode; + r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); + r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + } } - auto ov_input_names = r_ctx->ov_input_names_cache[key]; - auto ov_output_names = r_ctx->ov_output_names_cache[key]; + std::vector ov_input_names_local; + std::vector ov_output_names_local; + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + ov_input_names_local = r_ctx->ov_input_names_cache[key]; + ov_output_names_local = r_ctx->ov_output_names_cache[key]; + } if (is_prefill) { auto inp_len = inp_pos->ne[0]; for (int chunk_index = 0; chunk_index * prefill_chunk_size < inp_len; chunk_index++) { - for (size_t i = 0; i < ov_input_names.size(); i++) { - auto param_name = ov_input_names[i]; + for (size_t i = 0; i < ov_input_names_local.size(); i++) { + auto param_name = ov_input_names_local[i]; auto input_tensor = get_ov_input_tensor_static_prefill(ggml_decoder, param_name, chunk_index); infer_request->set_input_tensor(i, input_tensor); @@ -412,8 +469,8 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_model_outputs().at(ov_output_names[i]); + for (size_t i = 0; i < ov_output_names_local.size(); i++) { + auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names_local[i]); auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor); infer_request->set_output_tensor(i, output_tensor); } @@ -421,16 +478,16 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrinfer(); if (getenv("GGML_OPENVINO_DEBUG_OUTPUT")) { - for (size_t i = 0; i < ov_output_names.size(); i++) { + for (size_t i = 0; i < ov_output_names_local.size(); i++) { const auto output_tensor = infer_request->get_output_tensor(i); - print_output_tensor_info(ov_output_names[i], output_tensor, output_tensor.data()); + print_output_tensor_info(ov_output_names_local[i], output_tensor, output_tensor.data()); } } } infer_end_time = ggml_time_us(); } else { - for (size_t i = 0; i < ov_input_names.size(); i++) { - auto param_name = ov_input_names[i]; + for (size_t i = 0; i < ov_input_names_local.size(); i++) { + auto param_name = ov_input_names_local[i]; auto input_tensor = get_ov_input_tensor_static_decode(ggml_decoder, param_name); infer_request->set_input_tensor(i, input_tensor); @@ -440,8 +497,8 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_model_outputs().at(ov_output_names[i]); + for (size_t i = 0; i < ov_output_names_local.size(); i++) { + auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names_local[i]); auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor); infer_request->set_output_tensor(i, output_tensor); } @@ -450,9 +507,9 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_output_tensor(i); - print_output_tensor_info(ov_output_names[i], output_tensor, 
output_tensor.data()); + print_output_tensor_info(ov_output_names_local[i], output_tensor, output_tensor.data()); } } } diff --git a/ggml/src/ggml-openvino/utils.h b/ggml/src/ggml-openvino/utils.h index 656573d13..2c72e33c3 100644 --- a/ggml/src/ggml-openvino/utils.h +++ b/ggml/src/ggml-openvino/utils.h @@ -3,12 +3,15 @@ #include "ggml-impl.h" #include +#include #include #include +#include #include #include #include #include +#include #include struct graph_key { @@ -40,11 +43,17 @@ struct graph_key_hash { } }; +struct decoder_runtime_ctx { + decoder_runtime_ctx(std::shared_ptr mutex) : mutex(std::move(mutex)) {} + std::shared_ptr mutex; + std::shared_ptr ptr; +}; + struct ov_runtime_context { - std::mutex ov_compute_mutex; + mutable std::mutex ctx_mutex; std::string device; bool stateful; - std::unordered_map, graph_key_hash> decoder_cache; + std::unordered_map, graph_key_hash> decoder_cache; std::unordered_map, graph_key_hash> infer_request_cache; std::unordered_map, graph_key_hash> infer_request_cache_prefill; std::unordered_map, graph_key_hash> ov_input_names_cache; @@ -53,11 +62,22 @@ struct ov_runtime_context { // Simultanous stateful inference request support to be added. size_t stateful_kv_size; std::map kv_state_input_name_map; + std::atomic backend_count; ov_runtime_context() : device("CPU"), stateful(false), - stateful_kv_size(0) {} + stateful_kv_size(0), + backend_count(0) {} + + void clear_caches() { + std::lock_guard lock(ctx_mutex); + decoder_cache.clear(); + infer_request_cache.clear(); + infer_request_cache_prefill.clear(); + ov_input_names_cache.clear(); + ov_output_names_cache.clear(); + } }; enum ggml_status ov_graph_compute(struct ggml_cgraph * cgraph, ggml_backend_t backend); From 84652b80cfd165c1b6f8c0541803e1967e5c4c34 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 21 Apr 2026 19:52:02 +0300 Subject: [PATCH 04/20] arg : add --spec-default (#22223) --- common/arg.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/common/arg.cpp b/common/arg.cpp index 6751a55ab..5c0bbe38e 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3902,6 +3902,17 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); + add_opt(common_arg( + {"--spec-default"}, + string_format("enable default speculative decoding config"), + [](common_params & params) { + params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD; + params.speculative.ngram_size_n = 24; + params.speculative.n_min = 48; + params.speculative.n_max = 64; + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); + return ctx_arg; } From 98d2d2884e28ca8718774caad691ddd525f2f7b2 Mon Sep 17 00:00:00 2001 From: Kwa Jie Hao <31984694+kwajiehao@users.noreply.github.com> Date: Wed, 22 Apr 2026 02:02:49 +0800 Subject: [PATCH 05/20] mtmd: Add support for Reka Edge 2603 (#21616) * feat: (vocab) fix stray text appended in llama_decode_text Remove accidental concatenation of the full `text` string when formatting UNK_BYTE hex escapes. Only the closing "]" should be appended. 
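For the first item above, a rough before/after of the escape construction may make the fix easier to follow; the snippet is a schematic reconstruction with approximate identifiers, not the actual llama-vocab.cpp source:

```cpp
// Schematic of the UNK_BYTE escape fix (illustrative only).
// Before: the full remaining `text` was concatenated before the closing bracket.
// After:  only the closing "]" is appended.
#include <cstdio>
#include <string>

static std::string format_unk_byte(const std::string & utf8, const std::string & text) {
    (void) text; // only referenced by the buggy line shown below
    std::string out = "[UNK_BYTE_0x";
    char buf[3];
    for (const unsigned char c : utf8) {
        std::snprintf(buf, sizeof(buf), "%02x", c);
        out += buf;
    }
    // out += text + "]";   // buggy: duplicated the input after the hex escape
    out += "]";             // fixed: append only the closing bracket
    return out;
}
```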
* feat(mtmd): add Yasa2 vision encoder support Add a Yasa2 (ConvNeXtV2-based) vision encoder for reka-edge: - Register PROJECTOR_TYPE_YASA2 and tensor name definitions - Add yasa2_block/yasa2_stage model structs - Implement graph builder with ConvNeXt stages, GRN, adaptive pooling - Wire into clip.cpp switch statements and mtmd.cpp init_vision - Use mtmd_image_preprocessor_fixed_size for image preprocessing * feat(chat): add reka-edge template handler (tools, thinking) - Add chat-reka.cpp/h implementing PEG-based parser for reka-edge format - Add Reka-Edge.jinja chat template - Detect reka-edge template in try_specialized_template() - Add LLAMA_EXAMPLE_MTMD to chat-template-file arg * feat: add reka vlm to gguf conversion script Converts Reka Yasa2 hf checkpoints to GGUF format: - Text decoder: Llama-arch with tiktoken/BPE vocab - Mmproj (--mmproj): ConvNeXt vision backbone + language_projection - Generates 2D sincos positional embeddings for vision encoder * test: add Reka Edge chat template and parser tests - test-chat-template: oracle tests comparing Jinja engine output vs common_chat_templates_apply for text, tools, thinking, images, video - test-chat: PEG parser tests for Reka Edge format, round-trip tests for image/video content parts, common path integration tests * scripts: add Reka Edge mixed quantization helper Q4_0 base quantization with Q8_0 override for the last 8 transformer blocks (layers 24-31) via --tensor-type regex. * fix: adapt chat-reka and tests to upstream API - Use autoparser::generation_params (not templates_params) - Add p.prefix(generation_prompt) to PEG parser - Simplify reasoning parser to match LFM2 pattern - Remove image/video oracle tests (unsupported by oaicompat parser; no other multimodal models test this path) * fix: avoid duplicate tensor loading in yasa2 vision encoder TN_YASA_PATCH_W and TN_PATCH_EMBD both resolve to "v.patch_embd.weight", causing the same tensor to be loaded twice into ctx_data and overflowing the memory pool. Reuse the tensors already loaded by the common section. * chore: update image pre-processing settings The reka-edge model depends on the following settings in an older fork of llama.cpp: 1. Fixed square resize 2. BICUBIC 3. add_padding=false In current llama.cpp, this means setting: - image_resize_algo = RESIZE_ALGO_BICUBIC - image_resize_pad = false * chore: remove reka gguf conversion script * chore: remove reka quantization script * chore: remove unnecessary changes from PR scope This commit removes a couple of unnecessary changes for the PR scope: 1. BPE decoder bug fix - this affects reka edge because there's a bug in our tokenization that doesn't represent tokens as special tokens. However this isn't meant to be a thinking model so when run with --reasoning off the edge case does not affect us 2. --chat-template-file support from llama-mtmd-cli - the focus is on llama-server and the reka edge gguf contains the necessary metadata to detect the chat template 3. reka edge oracle test cases - no other model has similar test cases, so I removed it for standardization * chore: remove unnecessary ggml_cast This commit removes unnecessary ggml_cast after updating the reka vlm -> gguf conversion script on hugging face. * chore: remove redundant code * chore: remove unnecessary ggml_cont calls This commit removes all ggml_cont calls except the four that precede ggml_reshape_3d/ggml_reshape_4d. Those are necessary because ggml_reshape recomputes strides assuming contiguous layout and asserts ggml_is_contiguous. 
Other operations (ggml_mean, ggml_add, ggml_mul etc.) use stride-based indexing and handle non-contiguous inputs correctly and so we are ok to remove ggml_cont for those. * chore: remove unnecessary ggml_repeat calls This commit removes unnecessary ggml_repeat calls because the underlying ops already broadcast automatically. Every ggml_repeat in yasa2.cpp was expanding a smaller tensor to match a larger one's shape before passing both to an elementwise op (ggml_add, ggml_sub, ggml_mul, or ggml_div). This is unnecessary because all four of these ops already support broadcasting internally. * chore: restore ggml_cont needed for cpu operations * refactor: locate reka chat template handler in chat.cpp * chore: remove unnecessary warmup tokens * chore: add code comments on image_resize_pad * chore: remove custom reka parsing code * chore: revert common/chat.cpp * Uncomment debug logging for PEG input parsing --------- Co-authored-by: Piotr Wilkin (ilintar) --- tests/test-chat.cpp | 95 ++++++++++++++++++ tools/mtmd/CMakeLists.txt | 1 + tools/mtmd/clip-impl.h | 11 +++ tools/mtmd/clip-model.h | 30 ++++++ tools/mtmd/clip.cpp | 69 +++++++++++++ tools/mtmd/models/models.h | 8 ++ tools/mtmd/models/yasa2.cpp | 191 ++++++++++++++++++++++++++++++++++++ tools/mtmd/mtmd.cpp | 13 +++ 8 files changed, 418 insertions(+) create mode 100644 tools/mtmd/models/yasa2.cpp diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index eee28271b..79e1891b8 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -3595,6 +3595,51 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .run(); } + // Reka Edge + { + auto tst = peg_tester("models/templates/Reka-Edge.jinja", detailed_debug); + tst.test("Hello, world!\nWhat's up?") + .enable_thinking(false) + .expect(message_assist) + .run(); + tst.test("I'm\nthinking\n\nHello, world!\nWhat's up?") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .expect(message_assist_thoughts) + .run(); + tst.test("\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n") + .enable_thinking(false) + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + tst.test("Hello, world!\nWhat's up?\n\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n") + .enable_thinking(false) + .tools({ special_function_tool }) + .expect(message_assist_call_content) + .run(); + tst.test("I'm\nthinking\n\n\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ special_function_tool }) + .expect(message_assist_call_thoughts) + .run(); + tst.test("\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n\n\n{\"name\": \"special_function_with_opt\", \"arguments\": {\"arg1\": 1, \"arg2\": 2}}\n") + .enable_thinking(false) + .parallel_tool_calls(true) + .tools({ special_function_tool, special_function_tool_with_optional_param }) + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", {} }, + { "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} }, + }) + .run(); + tst.test("\n{\"name\": \"special_function\", \"arguments\": {\"arg") + .enable_thinking(false) + .tools({ special_function_tool }) + .is_partial(true) + .expect(message_assist_call_cutoff_args) + .run(); + } + // Apriel 1.5 { auto tst = peg_tester("models/templates/unsloth-Apriel-1.5.jinja", detailed_debug); @@ -4077,6 +4122,55 @@ static void test_template_output_peg_parsers(bool detailed_debug) { } } +static void 
test_reka_edge_common_path() { + auto tmpls = read_templates("models/templates/Reka-Edge.jinja"); + + { + common_chat_templates_inputs inputs; + common_chat_msg system_msg; + system_msg.role = "system"; + system_msg.content = "Use tools when needed."; + + common_chat_msg tool_call_msg = simple_assist_msg("", "", "special_function", "{\"arg1\": 1}"); + + common_chat_msg tool_msg; + tool_msg.role = "tool"; + tool_msg.tool_name = "special_function"; + tool_msg.tool_call_id = "call0"; + tool_msg.content = "Sunny"; + + inputs.messages = { system_msg, message_user, tool_call_msg, tool_msg, message_user }; + inputs.tools = { special_function_tool }; + inputs.enable_thinking = true; + inputs.add_generation_prompt = true; + + auto params = common_chat_templates_apply(tmpls.get(), inputs); + + if (params.prompt.find("\nSunny\n") == std::string::npos) { + throw std::runtime_error("Reka Edge prompt did not render tool response history"); + } + if (params.prompt.rfind("assistant: \n") == std::string::npos) { + throw std::runtime_error("Reka Edge prompt did not render thinking generation prompt"); + } + } + + { + common_chat_templates_inputs inputs; + inputs.messages = { + message_user, + simple_assist_msg("The first point is") + }; + inputs.add_generation_prompt = false; + inputs.enable_thinking = false; + inputs.chat_template_kwargs["continue_final_message"] = "true"; + + auto params = common_chat_templates_apply(tmpls.get(), inputs); + if (string_ends_with(params.prompt, "")) { + throw std::runtime_error("Reka Edge continue_final_message unexpectedly closed the assistant turn"); + } + } +} + // Test the developer role to system workaround with a simple mock template static void test_developer_role_to_system_workaround() { LOG_DBG("%s\n", __func__); @@ -4256,6 +4350,7 @@ int main(int argc, char ** argv) { test_msgs_oaicompat_json_conversion(); test_tools_oaicompat_json_conversion(); test_developer_role_to_system_workaround(); + test_reka_edge_common_path(); test_template_output_peg_parsers(detailed_debug); std::cout << "\n[chat] All tests passed!" 
<< '\n'; } diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 399876128..35d721d5a 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -40,6 +40,7 @@ add_library(mtmd models/deepseekocr.cpp models/mobilenetv5.cpp models/youtuvl.cpp + models/yasa2.cpp ) set_target_properties(mtmd PROPERTIES diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 17cb703f7..61fe82439 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -242,6 +242,15 @@ #define TN_STD_BIAS "v.std_bias" #define TN_STD_SCALE "v.std_scale" +// yasa2 +#define TN_YASA_PATCH_LN_W "v.patch_ln.weight" +#define TN_YASA_PATCH_LN_B "v.patch_ln.bias" +#define TN_YASA_BACKBONE_LN_W "v.backbone_ln.weight" +#define TN_YASA_BACKBONE_LN_B "v.backbone_ln.bias" +#define TN_YASA_POS_EMBD "v.vision_pos_embed" +#define TN_YASA_STAGE_DOWN_LN "v.stage.%d.down.ln.%s" +#define TN_YASA_STAGE_DOWN_CONV "v.stage.%d.down.conv.%s" +#define TN_YASA_STAGE_BLK "v.stage.%d.blk.%d.%s.%s" // align x to upper multiple of n #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) @@ -290,6 +299,7 @@ enum projector_type { PROJECTOR_TYPE_LFM2A, PROJECTOR_TYPE_GLM4V, PROJECTOR_TYPE_YOUTUVL, + PROJECTOR_TYPE_YASA2, PROJECTOR_TYPE_KIMIK25, PROJECTOR_TYPE_NEMOTRON_V2_VL, PROJECTOR_TYPE_HUNYUANOCR, @@ -335,6 +345,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LFM2A, "lfm2a"}, { PROJECTOR_TYPE_GLM4V, "glm4v"}, { PROJECTOR_TYPE_YOUTUVL, "youtuvl"}, + { PROJECTOR_TYPE_YASA2, "yasa2"}, { PROJECTOR_TYPE_KIMIK25, "kimik25"}, { PROJECTOR_TYPE_NEMOTRON_V2_VL, "nemotron_v2_vl"}, { PROJECTOR_TYPE_HUNYUANOCR, "hunyuanocr"}, diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index 9a93584d9..bf8031b55 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -268,6 +268,27 @@ struct mobilenetv5_block { ggml_tensor * attn_norm_w = nullptr; }; +struct yasa2_block { + ggml_tensor * dw_w = nullptr; + ggml_tensor * dw_b = nullptr; + ggml_tensor * ln_w = nullptr; + ggml_tensor * ln_b = nullptr; + ggml_tensor * pw1_w = nullptr; + ggml_tensor * pw1_b = nullptr; + ggml_tensor * grn_w = nullptr; + ggml_tensor * grn_b = nullptr; + ggml_tensor * pw2_w = nullptr; + ggml_tensor * pw2_b = nullptr; +}; + +struct yasa2_stage { + ggml_tensor * down_ln_w = nullptr; + ggml_tensor * down_ln_b = nullptr; + ggml_tensor * down_conv_w = nullptr; + ggml_tensor * down_conv_b = nullptr; + std::vector blocks; +}; + struct clip_model { clip_modality modality = CLIP_MODALITY_VISION; projector_type proj_type = PROJECTOR_TYPE_MLP; @@ -402,6 +423,15 @@ struct clip_model { ggml_tensor * msfa_ffn_expand_bn = nullptr; ggml_tensor * msfa_ffn_project_bn = nullptr; + // yasa2 + ggml_tensor * yasa_patch_w = nullptr; + ggml_tensor * yasa_patch_b = nullptr; + ggml_tensor * yasa_patch_ln_w = nullptr; + ggml_tensor * yasa_patch_ln_b = nullptr; + ggml_tensor * yasa_backbone_ln_w = nullptr; + ggml_tensor * yasa_backbone_ln_b = nullptr; + ggml_tensor * yasa_vision_pos_embed = nullptr; + std::vector yasa_stages; // pixtral, glm4v ggml_tensor * token_embd_img_break = nullptr; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index f0e8786b6..540b0ea41 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -947,6 +947,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_YASA2: + { + builder = std::make_unique(ctx, img); + } break; default: GGML_ABORT("missing cgraph builder"); } @@ -1389,6 +1393,16 @@ struct 
clip_model_loader { hparams.set_limit_image_tokens(1, 62500); hparams.set_warmup_n_tokens(16*16); // avoid OOM on warmup } break; + case PROJECTOR_TYPE_YASA2: + { + hparams.ffn_op = FFN_GELU_ERF; + log_ffn_op = "gelu_erf"; + hparams.image_resize_algo = RESIZE_ALGO_BICUBIC; + + // reka model performs better when using resize_bicubic, which stretches + // the image to fit fixed square size + hparams.image_resize_pad = false; + } break; case PROJECTOR_TYPE_GLM4V: { hparams.rope_theta = 10000.0f; @@ -1839,6 +1853,55 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); // merger.mlp.2 model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; + case PROJECTOR_TYPE_YASA2: + { + // reuse tensors already loaded by the common section + // (TN_PATCH_EMBD and TN_PATCH_BIAS have the same tensor names) + GGML_ASSERT(model.patch_embeddings_0 && "yasa2 requires v.patch_embd.weight"); + model.yasa_patch_w = model.patch_embeddings_0; + model.yasa_patch_b = model.patch_bias; + model.yasa_patch_ln_w = get_tensor(TN_YASA_PATCH_LN_W, false); + model.yasa_patch_ln_b = get_tensor(TN_YASA_PATCH_LN_B, false); + model.yasa_backbone_ln_w = get_tensor(TN_YASA_BACKBONE_LN_W, false); + model.yasa_backbone_ln_b = get_tensor(TN_YASA_BACKBONE_LN_B, false); + model.yasa_vision_pos_embed = get_tensor(TN_YASA_POS_EMBD, false); + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false); + model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); + + model.yasa_stages.clear(); + for (int s = 0; ; ++s) { + yasa2_stage stage; + stage.down_ln_w = get_tensor(string_format(TN_YASA_STAGE_DOWN_LN, s, "weight"), false); + stage.down_ln_b = get_tensor(string_format(TN_YASA_STAGE_DOWN_LN, s, "bias"), false); + stage.down_conv_w = get_tensor(string_format(TN_YASA_STAGE_DOWN_CONV, s, "weight"), false); + stage.down_conv_b = get_tensor(string_format(TN_YASA_STAGE_DOWN_CONV, s, "bias"), false); + + for (int bi = 0; ; ++bi) { + yasa2_block blk; + blk.dw_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "dw", "weight"), false); + if (!blk.dw_w) { + break; + } + blk.dw_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "dw", "bias"), false); + blk.ln_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "ln", "weight"), false); + blk.ln_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "ln", "bias"), false); + blk.pw1_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw1", "weight"), false); + blk.pw1_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw1", "bias"), false); + blk.grn_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "grn", "weight"), false); + blk.grn_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "grn", "bias"), false); + blk.pw2_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw2", "weight"), false); + blk.pw2_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw2", "bias"), false); + stage.blocks.push_back(blk); + } + + if (!stage.down_conv_w && stage.blocks.empty()) { + break; + } + model.yasa_stages.push_back(std::move(stage)); + } + } break; case PROJECTOR_TYPE_GLM4V: { model.mm_fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight")); @@ -2843,6 +2906,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im { // do nothing } break; + case PROJECTOR_TYPE_YASA2: + { + n_patches = 64; // adaptive 
average pooling to 8x8 tokens + } break; case PROJECTOR_TYPE_LDP: case PROJECTOR_TYPE_LDPV2: case PROJECTOR_TYPE_GLM_EDGE: @@ -3463,6 +3530,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima case PROJECTOR_TYPE_PHI4: case PROJECTOR_TYPE_COGVLM: case PROJECTOR_TYPE_HUNYUANOCR: + case PROJECTOR_TYPE_YASA2: { // do nothing } break; @@ -3689,6 +3757,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_KIMIVL: case PROJECTOR_TYPE_PADDLEOCR: case PROJECTOR_TYPE_KIMIK25: + case PROJECTOR_TYPE_YASA2: return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_HUNYUANOCR: return ctx->model.mm_model_proj->ne[1]; diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 03d99e15b..c30d79133 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -43,6 +43,14 @@ struct clip_graph_youtuvl : clip_graph { ggml_cgraph * build() override; }; +struct clip_graph_yasa2 : clip_graph { + clip_graph_yasa2(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; + + ggml_tensor * layer_norm_channels(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b, float eps = 1e-6f); + ggml_tensor * convnext_grn(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b); +}; + struct clip_graph_minicpmv : clip_graph { clip_graph_minicpmv(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; diff --git a/tools/mtmd/models/yasa2.cpp b/tools/mtmd/models/yasa2.cpp new file mode 100644 index 000000000..e8cd3dacb --- /dev/null +++ b/tools/mtmd/models/yasa2.cpp @@ -0,0 +1,191 @@ +// ABOUTME: Yasa2 vision encoder graph builder for ConvNeXt-based architecture. +// ABOUTME: Implements patch embedding, ConvNeXt stages with GRN, and adaptive pooling. + +#include "models.h" + +static ggml_tensor * add_channel_bias( + ggml_context * ctx0, + ggml_tensor * x_whcb, + ggml_tensor * b_c) { + if (!b_c) { + return x_whcb; + } + ggml_tensor * b4 = ggml_reshape_4d(ctx0, b_c, 1, 1, b_c->ne[0], 1); + return ggml_add(ctx0, x_whcb, b4); +} + +static ggml_tensor * mul_channel_weight( + ggml_context * ctx0, + ggml_tensor * x_whcb, + ggml_tensor * w_c) { + if (!w_c) { + return x_whcb; + } + ggml_tensor * w4 = ggml_reshape_4d(ctx0, w_c, 1, 1, w_c->ne[0], 1); + return ggml_mul(ctx0, x_whcb, w4); +} + +ggml_tensor * clip_graph_yasa2::layer_norm_channels(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b, float eps) { + // Match HF ConvNextLayerNorm(channels_first): + // u = mean_c(x), s = mean_c((x-u)^2), x = (x-u)/sqrt(s+eps) + // cast back to input dtype before affine. 
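+    // In ggml terms: permute so the channel dim lands in ne[0] (the axis ggml_mean
+    // reduces over), normalize there, then permute back to [W,H,C,B] and apply the
+    // per-channel affine via mul_channel_weight / add_channel_bias.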
+ ggml_tensor * cur = ggml_permute(ctx0, inp, 2, 1, 0, 3); // [W,H,C,B] -> [C,H,W,B] + cur = ggml_cont(ctx0, cur); + + ggml_tensor * u = ggml_mean(ctx0, cur); // [1,H,W,B] + ggml_tensor * xm = ggml_sub(ctx0, cur, u); // [C,H,W,B] + + ggml_tensor * s = ggml_mul(ctx0, xm, xm); // [C,H,W,B] + s = ggml_mean(ctx0, s); // [1,H,W,B] + s = ggml_clamp(ctx0, s, eps, 1e30f); // avoid div-by-zero in no-alloc warmup + s = ggml_sqrt(ctx0, s); // [1,H,W,B] + + ggml_tensor * xhat = ggml_div(ctx0, xm, s); // [C,H,W,B] + xhat = ggml_permute(ctx0, xhat, 2, 1, 0, 3); // [W,H,C,B] + xhat = ggml_cont(ctx0, xhat); + xhat = mul_channel_weight(ctx0, xhat, w); + xhat = add_channel_bias(ctx0, xhat, b); + return xhat; +} + +ggml_tensor * clip_graph_yasa2::convnext_grn(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b) { + // Exact ConvNeXtV2 GRN: + // Gx = ||x||_2 over spatial dims (W,H), Nx = Gx / (mean_c(Gx) + eps) + // y = w * (x * Nx) + b + x + const int64_t wdim = inp->ne[0]; + const int64_t hdim = inp->ne[1]; + const int64_t cdim = inp->ne[2]; + const int64_t bdim = inp->ne[3]; + + // Keep GRN math in fp32 for stability; fp16/bf16 accumulation can drift. + ggml_tensor * sq = ggml_mul(ctx0, inp, inp); + ggml_tensor * sq_flat = ggml_reshape_4d(ctx0, sq, wdim * hdim, cdim, 1, bdim); // [WH,C,1,B] + ggml_tensor * gx = ggml_sum_rows(ctx0, sq_flat); // [1,C,1,B] + gx = ggml_sqrt(ctx0, gx); // [1,C,1,B] + + ggml_tensor * gx_ch_first = ggml_permute(ctx0, gx, 1, 0, 2, 3); // [C,1,1,B] + gx_ch_first = ggml_cont(ctx0, gx_ch_first); + ggml_tensor * gx_mean = ggml_mean(ctx0, gx_ch_first); // [1,1,1,B] + + gx_mean = ggml_clamp(ctx0, gx_mean, 1e-6f, 1e30f); // approx +eps, warmup-safe + ggml_tensor * nx = ggml_div(ctx0, gx, gx_mean); // [1,C,1,B] + nx = ggml_permute(ctx0, nx, 0, 2, 1, 3); // [1,1,C,B] + nx = ggml_cont(ctx0, nx); + + ggml_tensor * xnx = ggml_mul(ctx0, inp, nx); + xnx = mul_channel_weight(ctx0, xnx, w); + xnx = add_channel_bias(ctx0, xnx, b); + return ggml_add(ctx0, inp, xnx); +} + +ggml_cgraph * clip_graph_yasa2::build() { + ggml_tensor * cur = build_inp_raw(); + + // Patch embedding Conv2d(kernel=4, stride=4) + cur = ggml_conv_2d(ctx0, model.yasa_patch_w, cur, patch_size, patch_size, 0, 0, 1, 1); + cur = add_channel_bias(ctx0, cur, model.yasa_patch_b); + ggml_set_name(cur, "yasa2_patch_conv_out"); + cb(cur, "yasa2_patch_conv_out", -1); + cur = layer_norm_channels(cur, model.yasa_patch_ln_w, model.yasa_patch_ln_b, eps); + ggml_set_name(cur, "yasa2_patch_ln_out"); + cb(cur, "yasa2_patch_ln_out", -1); + + // ConvNeXt stages + for (size_t s = 0; s < model.yasa_stages.size(); ++s) { + const auto & stage = model.yasa_stages[s]; + + if (stage.down_conv_w) { + cur = layer_norm_channels(cur, stage.down_ln_w, stage.down_ln_b, eps); + cur = ggml_conv_2d(ctx0, stage.down_conv_w, cur, 2, 2, 0, 0, 1, 1); + cur = add_channel_bias(ctx0, cur, stage.down_conv_b); + ggml_format_name(cur, "yasa2_stage%zu_down_out", s); + } + + for (size_t bi = 0; bi < stage.blocks.size(); ++bi) { + const auto & blk = stage.blocks[bi]; + ggml_tensor * res = cur; + + ggml_tensor * x = ggml_conv_2d_dw(ctx0, blk.dw_w, cur, 1, 1, 3, 3, 1, 1); + x = add_channel_bias(ctx0, x, blk.dw_b); + x = layer_norm_channels(x, blk.ln_w, blk.ln_b, eps); + + // pwconv1/pwconv2 are HF Linear layers over channels; implement via matmul on tokens. 
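+            // i.e. flatten the feature map into per-pixel tokens with channels in ne[0],
+            // so ggml_mul_mat applies the pointwise (1x1) weight C -> 4C to every token,
+            // mirroring nn.Linear over the channel dimension.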
+ const int64_t w = x->ne[0]; + const int64_t h = x->ne[1]; + const int64_t b = x->ne[3]; + + ggml_tensor * tok = ggml_reshape_3d(ctx0, x, w * h, x->ne[2], b); // [T,C,B] + tok = ggml_permute(ctx0, tok, 1, 0, 2, 3); // [C,T,B] + tok = ggml_cont(ctx0, tok); + + tok = ggml_mul_mat(ctx0, blk.pw1_w, tok); // [4C,T,B] + if (blk.pw1_b) { + ggml_tensor * b1 = ggml_reshape_3d(ctx0, blk.pw1_b, blk.pw1_b->ne[0], 1, 1); // [4C,1,1] + tok = ggml_add(ctx0, tok, b1); + } + x = ggml_permute(ctx0, tok, 1, 0, 2, 3); // [T,4C,B] + x = ggml_cont(ctx0, x); + x = ggml_reshape_4d(ctx0, x, w, h, tok->ne[0], b); // [W,H,4C,B] + x = ggml_gelu_erf(ctx0, x); + x = convnext_grn(x, blk.grn_w, blk.grn_b); + + tok = ggml_reshape_3d(ctx0, x, w * h, x->ne[2], b); // [T,4C,B] + tok = ggml_permute(ctx0, tok, 1, 0, 2, 3); // [4C,T,B] + tok = ggml_cont(ctx0, tok); + + tok = ggml_mul_mat(ctx0, blk.pw2_w, tok); // [C,T,B] + if (blk.pw2_b) { + ggml_tensor * b2 = ggml_reshape_3d(ctx0, blk.pw2_b, blk.pw2_b->ne[0], 1, 1); // [C,1,1] + tok = ggml_add(ctx0, tok, b2); + } + x = ggml_permute(ctx0, tok, 1, 0, 2, 3); // [T,C,B] + x = ggml_cont(ctx0, x); + x = ggml_reshape_4d(ctx0, x, w, h, tok->ne[0], b); // [W,H,C,B] + + cur = ggml_add(ctx0, res, x); + ggml_format_name(cur, "yasa2_stage%zu_blk%zu_out", s, bi); + } + } + + // HF path adds vision position embeddings BEFORE adaptive pooling. + const int64_t pre_w = cur->ne[0]; + const int64_t pre_h = cur->ne[1]; + ggml_tensor * tokens_pre = ggml_reshape_3d(ctx0, cur, pre_w * pre_h, cur->ne[2], cur->ne[3]); // [T,C,B] + tokens_pre = ggml_permute(ctx0, tokens_pre, 1, 0, 2, 3); // [C,T,B] + tokens_pre = ggml_cont(ctx0, tokens_pre); + if (model.yasa_vision_pos_embed && tokens_pre->ne[1] == model.yasa_vision_pos_embed->ne[1]) { + const int64_t n_ch = model.yasa_vision_pos_embed->ne[0]; + const int64_t n_tokens = model.yasa_vision_pos_embed->ne[1]; + ggml_tensor * pos = ggml_reshape_3d(ctx0, model.yasa_vision_pos_embed, (int) n_ch, (int) n_tokens, 1); + tokens_pre = ggml_add(ctx0, tokens_pre, pos); + } + cur = ggml_permute(ctx0, tokens_pre, 1, 0, 2, 3); // [T,C,B] + cur = ggml_cont(ctx0, cur); + cur = ggml_reshape_4d(ctx0, cur, pre_w, pre_h, cur->ne[1], cur->ne[2]); // [W,H,C,B] + + // AdaptiveAvgPool2d target is 8x8 for real inputs, but warmup can use tiny images. 
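+    // e.g. a 32x32 feature map here yields kw = kh = 32/8 = 4 and an 8x8 pooled grid
+    // (the 64 output tokens reported by clip_n_output_tokens); warmup inputs smaller
+    // than 8x8 clamp the kernel to 1 via the max(1, ...) below.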
+ const int pooled_w = std::min(8, (int) cur->ne[0]); + const int pooled_h = std::min(8, (int) cur->ne[1]); + const int kw = std::max(1, (int) cur->ne[0] / pooled_w); + const int kh = std::max(1, (int) cur->ne[1] / pooled_h); + cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, kw, kh, kw, kh, 0, 0); + + // [W,H,C,B] -> [C,T,B] + ggml_tensor * tokens = ggml_reshape_3d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2], cur->ne[3]); + tokens = ggml_permute(ctx0, tokens, 1, 0, 2, 3); + tokens = ggml_cont(ctx0, tokens); + cb(tokens, "yasa2_tokens", -1); + + GGML_ASSERT(model.mm_0_w && model.mm_2_w); + ggml_tensor * embeddings = build_ffn( + tokens, + model.mm_0_w, model.mm_0_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + FFN_GELU_ERF, + -1); + cb(embeddings, "yasa2_emb", -1); + + ggml_build_forward_expand(gf, embeddings); + return gf; +} diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 854ac81e0..cc3de6a85 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -316,6 +316,19 @@ struct mtmd_context { img_end = "<|vision_end|>"; image_preproc = std::make_unique(ctx_v); } break; + case PROJECTOR_TYPE_YASA2: + { + img_beg = ""; + img_end = ""; + // Currently only supprots single-tile preprocessing: any input is downscaled + // to one image_size x image_size tile (64 output tokens via 8x8 adaptive avg + // pool). + // However, the model itself supports llava-uhd multi-tile tiling for high-res + // images. This will be implemented in a future PR (dispatch on has_pinpoints + // - see LDP/COGVLM branch above) and emit image_grid_pinpoints in the conversion + // script. + image_preproc = std::make_unique(ctx_v); + } break; case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_GEMMA3NV: { From 72d693e4fbf0a188172951fe667d5e2f354ee438 Mon Sep 17 00:00:00 2001 From: Paul Dubs Date: Tue, 21 Apr 2026 20:29:07 +0200 Subject: [PATCH 06/20] spec : reset i_last when low acceptance streak occurs (#22168) By resetting i_last to zero, we will include the current context when rebuilding the speculative map. 
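A minimal illustrative sketch (simplified, hypothetical names; not the actual
common_speculative internals) of why the scan index has to be rewound together with
the map reset:

    // tokens[0..n_cur) is the accepted context; i_last marks how far the
    // n-gram map has already been built (simplified, illustrative state)
    if (n_low >= n_low_max) {     // sustained low-acceptance streak
        map.reset();              // drop the stale n-gram statistics
        n_low  = 0;
        i_last = 0;               // rewind so the rebuild re-ingests tokens[0..n_cur)
    }
    for (size_t i = i_last; i < n_cur; ++i) {
        map.update(tokens, i);    // (re)build from the full current context
    }
    i_last = n_cur;

Without resetting i_last, the rebuilt map would only ever see tokens appended after
the reset point and would miss the context that is already present.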
--- common/speculative.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/common/speculative.cpp b/common/speculative.cpp index daa2b5a8a..20e38bec4 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -749,6 +749,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { mod.reset(); n_low = 0; + i_last = 0; } } else { n_low = 0; From 2248799a5810c4f09f31d361ea57fb050554f377 Mon Sep 17 00:00:00 2001 From: Mengsheng Wu Date: Wed, 22 Apr 2026 04:53:44 +0800 Subject: [PATCH 07/20] hexagon: fix missing v79 entry in libggml-htp.inf (#22194) --- ggml/src/ggml-hexagon/libggml-htp.inf | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-hexagon/libggml-htp.inf b/ggml/src/ggml-hexagon/libggml-htp.inf index 656d2d9ab..360d8b122 100644 --- a/ggml/src/ggml-hexagon/libggml-htp.inf +++ b/ggml/src/ggml-hexagon/libggml-htp.inf @@ -18,6 +18,7 @@ libggml-htp-v68.so = 1 libggml-htp-v69.so = 1 libggml-htp-v73.so = 1 libggml-htp-v75.so = 1 +libggml-htp-v79.so = 1 libggml-htp-v81.so = 1 [ControlFlags] @@ -31,6 +32,7 @@ libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE +libggml-htp-v79.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v81.so,,,0x10 ;COPYFLG_NO_OVERWRITE [Strings] From 5a4cd6741fc33227cdacb329f355ab21f8481de2 Mon Sep 17 00:00:00 2001 From: Shreya Jain Date: Tue, 21 Apr 2026 14:16:04 -0700 Subject: [PATCH 08/20] Hexagon: DAIG op (#22195) * hexagon: Add DIAG op * hexagon: add HVX support and DMA double buffering * hexagon: fix fatal error * hexagon: remove as many pragma(s) as possible --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 28 +++ ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/diag-ops.c | 216 +++++++++++++++++++++++ ggml/src/ggml-hexagon/htp/htp-ctx.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 3 + 6 files changed, 250 insertions(+) create mode 100644 ggml/src/ggml-hexagon/htp/diag-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 3d68b8004..5e206c5e9 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2596,6 +2596,29 @@ static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * se return true; } +static bool ggml_hexagon_supported_diag(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + // diag only supports F32 currently + if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + // Input must have ne[1] == 1 (vector input) + if (src0->ne[1] != 1) { + return false; + } + + // Output must be square in first two dimensions + if (dst->ne[0] != dst->ne[1] || dst->ne[0] != src0->ne[0]) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); return sess->c_name(); @@ -2632,6 +2655,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_ROPE: return HTP_OP_ROPE; case GGML_OP_REPEAT: return HTP_OP_REPEAT; case GGML_OP_CUMSUM: return HTP_OP_CUMSUM; + case GGML_OP_DIAG: return HTP_OP_DIAG; case GGML_OP_UNARY: switch (ggml_get_unary_op(t)) { @@ -3159,6 +3183,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, 
cons supp = ggml_hexagon_supported_cumsum(sess, op); break; + case GGML_OP_DIAG: + supp = ggml_hexagon_supported_diag(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 9ca759459..82c10b57b 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -34,6 +34,7 @@ add_library(${HTP_LIB} SHARED argsort-ops.c ssm-conv.c cumsum-ops.c + diag-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/diag-ops.c b/ggml/src/ggml-hexagon/htp/diag-ops.c new file mode 100644 index 000000000..9b3194d90 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/diag-ops.c @@ -0,0 +1,216 @@ +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-types.h" +#include "hex-utils.h" +#include "hvx-copy.h" +#include "hex-dma.h" + +#define htp_diag_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict dst = octx->dst; \ + \ + const uint32_t ne02 = src0->ne[2]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +struct htp_diag_context { + struct htp_ops_context * octx; + size_t src_batch_size; + size_t dst_row_size; + size_t src_batch_size_aligned; + size_t dst_row_size_aligned; + uint32_t batches_per_thread; + uint32_t total_batches; +}; + +#define htp_diag_preamble \ + struct htp_diag_context * dctx = (struct htp_diag_context *) data; \ + struct htp_ops_context * octx = dctx->octx; \ + htp_diag_tensors_preamble; + +static inline void hvx_diag_row_f32(const float * restrict src, float * restrict dst, + uint32_t row_idx, uint32_t n) { + hvx_splat_f32_a((uint8_t *) dst, 0.0f, n); + dst[row_idx] = src[row_idx]; +} + +// --------------------------------------------------------------------------- +// Per thread worker: DMA src fetch, compute in VTCM, DMA dst writeback +// --------------------------------------------------------------------------- + +static void diag_thread_f32_dma(unsigned int nth, unsigned int ith, void * data) { + htp_diag_preamble; + dma_queue * dma_queue = octx->ctx->dma[ith]; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ib0 = dctx->batches_per_thread * ith; + const uint32_t ib1 = MIN(ib0 + dctx->batches_per_thread, dctx->total_batches); + + if (ib0 >= ib1) { + return; + } + + const size_t src_batch_size = dctx->src_batch_size; + const size_t dst_row_size = dctx->dst_row_size; + const size_t src_batch_size_aligned = dctx->src_batch_size_aligned; + const size_t dst_row_size_aligned = dctx->dst_row_size_aligned; + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + // 1 src buffer + 1 dst row buffer per thread in VTCM + uint8_t * src_spad = octx->src0_spad.data + (ith * src_batch_size_aligned); + uint8_t * dst_spad = octx->dst_spad.data + (ith * dst_row_size_aligned); + + for (uint32_t ib = ib0; ib < ib1; ib++) { + const uint32_t i3 = ib / ne02; + const uint32_t i2 = ib % ne02; + + const uint8_t * src_batch = src_data + i3 * nb03 + i2 * nb02; + + // Fetch source vector into VTCM + dma_queue_push_ddr_to_vtcm(dma_queue, + 
dma_make_ptr(src_spad, src_batch), + src_batch_size_aligned, src_batch_size, 1); + dma_queue_flush(dma_queue); + + const float * src_spad_f32 = (const float *) src_spad; + float * dst_spad_f32 = (float *) dst_spad; + + for (uint32_t i1 = 0; i1 < ne1; i1++) { + // Compute row in VTCM + hvx_diag_row_f32(src_spad_f32, dst_spad_f32, i1, ne0); + + // Write completed row back to DDR + uint8_t * dst_row = dst_data + i3 * nb3 + i2 * nb2 + i1 * nb1; + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(dst_row, dst_spad), + dst_row_size, dst_row_size_aligned, 1); + dma_queue_flush(dma_queue); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "diag-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ib0, ib1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// Per thread worker: Direct HVX (no DMA) +// --------------------------------------------------------------------------- + +static void diag_thread_f32(unsigned int nth, unsigned int ith, void * data) { + htp_diag_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + const uint32_t ib0 = dctx->batches_per_thread * ith; + const uint32_t ib1 = MIN(ib0 + dctx->batches_per_thread, dctx->total_batches); + + for (uint32_t ib = ib0; ib < ib1; ib++) { + const uint32_t i3 = ib / ne02; + const uint32_t i2 = ib % ne02; + + const float * restrict src_batch = (const float *)(src_data + i3 * nb03 + i2 * nb02); + + for (uint32_t i1 = 0; i1 < ne1; i1++) { + float * restrict dst_row = (float *)(dst_data + i3 * nb3 + i2 * nb2 + i1 * nb1); + hvx_diag_row_f32(src_batch, dst_row, i1, ne0); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "diag-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ib0, ib1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_diag_f32(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const uint32_t total_batches = src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_batches); + + const size_t src_batch_size = src0->ne[0] * sizeof(float); + const size_t dst_row_size = dst->ne[0] * sizeof(float); + const size_t src_batch_size_aligned = hex_round_up(src_batch_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + // 1 src buffer + 1 dst row buffer per thread + const size_t spad_per_thread = src_batch_size_aligned + dst_row_size_aligned; + + octx->src0_spad.size_per_thread = src_batch_size_aligned; + octx->dst_spad.size_per_thread = dst_row_size_aligned; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL; + + struct htp_diag_context dctx = { + .octx = octx, + .src_batch_size = src_batch_size, + .dst_row_size = dst_row_size, + .src_batch_size_aligned = src_batch_size_aligned, + 
.dst_row_size_aligned = dst_row_size_aligned, + .batches_per_thread = (total_batches + n_threads - 1) / n_threads, + .total_batches = total_batches, + }; + + if (octx->ctx->vtcm_size < spad_per_thread * n_threads) { + worker_pool_run_func(octx->ctx->worker_pool, diag_thread_f32, &dctx, n_threads); + } else { + worker_pool_run_func(octx->ctx->worker_pool, diag_thread_f32_dma, &dctx, n_threads); + } + + return HTP_STATUS_OK; +} + +int op_diag(struct htp_ops_context * octx) { + const struct htp_tensor * dst = octx->dst; + + int err = HTP_STATUS_OK; + + switch (dst->type) { + case HTP_TYPE_F32: + err = op_diag_f32(octx); + break; + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 8b5e47ade..038941af0 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -98,5 +98,6 @@ int op_repeat(struct htp_ops_context * octx); int op_argsort(struct htp_ops_context * octx); int op_ssm_conv(struct htp_ops_context * octx); int op_cumsum(struct htp_ops_context * octx); +int op_diag(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 79b5ecd22..002dd1c12 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -80,6 +80,7 @@ enum htp_op_code { HTP_OP_SSM_CONV, HTP_OP_REPEAT, HTP_OP_CUMSUM, + HTP_OP_DIAG, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 5091623a6..d633145c9 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -514,6 +514,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_CUMSUM: return op_cumsum(octx); + case HTP_OP_DIAG: + return op_diag(octx); + case HTP_OP_INVALID: break; From 04fe84b69dd20084ea5b4ed4da578fb05c95eb10 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Wed, 22 Apr 2026 00:26:09 +0200 Subject: [PATCH 09/20] server: allow cancel loading model (#21814) --- tools/server/server-models.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 6066611f5..15c11c3c9 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -712,6 +712,11 @@ void server_models::unload(const std::string & name) { if (it->second.meta.is_running()) { SRV_INF("stopping model instance name=%s\n", name.c_str()); stopping_models.insert(name); + if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) { + // special case: if model is in loading state, unloading means force-killing it + SRV_WRN("model name=%s is still loading, force-killing\n", name.c_str()); + subprocess_terminate(it->second.subproc.get()); + } cv_stop.notify_all(); // status change will be handled by the managing thread } else { From 2799d933b5b5a0a54b151ea213f4c12c73b15866 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Wed, 22 Apr 2026 08:05:21 +0900 Subject: [PATCH 10/20] ggml-webgpu: reset CPU/GPU profiling time when freeing context (#22050) * Reset the CPU/GPU profiling time when freeing context. * move GPU profiling time from global context to webgpu_context. 
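In rough terms (a sketch of the intended ownership, not the exact struct layout), the
per-shader timings now live on the per-backend webgpu_context rather than the
process-wide global context, so freeing one backend does not leak timings into the next:

    struct webgpu_context_struct_sketch {
        // owned by the backend context: reported and discarded when the context
        // is freed, so a fresh context starts its GPU profile from 0 ms
        std::unordered_map<std::string, double> shader_gpu_time_ms;
    };
    // the long-lived global context no longer accumulates GPU timings across contexts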
--- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index aa20a745e..a29231452 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -211,6 +211,7 @@ struct webgpu_global_context_struct { wgpu::Buffer memset_params_buf; webgpu_pipeline memset_pipeline; + // TODO: We should rework the CPU profiling time handling to make it more useful. ref: https://github.com/ggml-org/llama.cpp/pull/22050 #ifdef GGML_WEBGPU_CPU_PROFILE // Profiling: labeled CPU time in ms (total) std::unordered_map cpu_time_ms; @@ -218,11 +219,6 @@ struct webgpu_global_context_struct { std::unordered_map cpu_detail_ms; #endif -#ifdef GGML_WEBGPU_GPU_PROFILE - // Profiling: per-shader GPU time in ms - std::unordered_map shader_gpu_time_ms; -#endif - #ifdef GGML_WEBGPU_DEBUG wgpu::Buffer debug_host_buf; wgpu::Buffer debug_dev_buf; @@ -268,10 +264,12 @@ struct webgpu_context_struct { size_t memset_bytes_per_thread; #ifdef GGML_WEBGPU_GPU_PROFILE - wgpu::Buffer profile_timestamp_dev_buf; - wgpu::Buffer profile_timestamp_host_buf; - wgpu::QuerySet profile_timestamp_query_set; - uint32_t profile_timestamp_query_count = 0; + // Profiling: per-shader GPU time in ms + std::unordered_map shader_gpu_time_ms; + wgpu::Buffer profile_timestamp_dev_buf; + wgpu::Buffer profile_timestamp_host_buf; + wgpu::QuerySet profile_timestamp_query_set; + uint32_t profile_timestamp_query_count = 0; #endif ~webgpu_context_struct() { @@ -713,12 +711,12 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { #ifdef GGML_WEBGPU_GPU_PROFILE std::cout << "\n[ggml_webgpu gpu profiling summary]\n"; double total_gpu = 0.0; - for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { total_gpu += kv.second; } std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n"; std::cout << "\nggml_webgpu: gpu breakdown:\n"; - for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0; std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2) << pct << "%)\n"; @@ -2511,7 +2509,7 @@ static void ggml_backend_webgpu_collect_profile_results(webgpu_context & for (size_t i = 0; i < pipeline_names.size(); ++i) { // WebGPU timestamps are in ns; convert to ms. 
const double elapsed_ms = double(ts_data[2 * i + 1] - ts_data[2 * i]) * 1e-6; - ctx->global_ctx->shader_gpu_time_ms[pipeline_names[i]] += elapsed_ms; + ctx->shader_gpu_time_ms[pipeline_names[i]] += elapsed_ms; } ctx->profile_timestamp_host_buf.Unmap(); From 0dedb9ef7a71fcebfa6fb17e0d6e6abd6e893376 Mon Sep 17 00:00:00 2001 From: Aparna M P Date: Wed, 22 Apr 2026 04:54:20 +0530 Subject: [PATCH 11/20] hexagon: add support for FILL op (#22198) Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 16 +++ ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/fill-ops.c | 123 +++++++++++++++++++++++ ggml/src/ggml-hexagon/htp/htp-ctx.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 3 + 6 files changed, 145 insertions(+) create mode 100644 ggml/src/ggml-hexagon/htp/fill-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 5e206c5e9..cdd9fcf59 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2655,6 +2655,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_ROPE: return HTP_OP_ROPE; case GGML_OP_REPEAT: return HTP_OP_REPEAT; case GGML_OP_CUMSUM: return HTP_OP_CUMSUM; + case GGML_OP_FILL: return HTP_OP_FILL; case GGML_OP_DIAG: return HTP_OP_DIAG; case GGML_OP_UNARY: @@ -3053,6 +3054,17 @@ static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * se return true; } +static bool ggml_hexagon_supported_fill(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * dst = op; + + if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { auto sess = static_cast(dev->context); @@ -3183,6 +3195,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_cumsum(sess, op); break; + case GGML_OP_FILL: + supp = ggml_hexagon_supported_fill(sess, op); + break; + case GGML_OP_DIAG: supp = ggml_hexagon_supported_diag(sess, op); break; diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 82c10b57b..b1ae60a9c 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -34,6 +34,7 @@ add_library(${HTP_LIB} SHARED argsort-ops.c ssm-conv.c cumsum-ops.c + fill-ops.c diag-ops.c ) diff --git a/ggml/src/ggml-hexagon/htp/fill-ops.c b/ggml/src/ggml-hexagon/htp/fill-ops.c new file mode 100644 index 000000000..3ccfbe74e --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/fill-ops.c @@ -0,0 +1,123 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#include + +#include "hvx-copy.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" + +// ggml op_params layout for FILL: +// op_params[0] (as float) - the scalar fill value + +#define fill_preamble \ + const struct htp_tensor * dst = octx->dst; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const 
uint32_t nb3 = dst->nb[3]; \ + \ + const uint32_t nr = ne1 * ne2 * ne3; + +struct htp_fill_context { + struct htp_ops_context * octx; + uint32_t nrows_per_thread; + uint32_t total_rows; // ne1 * ne2 * ne3 + bool opt_path; + HVX_Vector splat_vec; + uint32_t elem_size; +}; + +static void fill_thread(unsigned int nth, unsigned int ith, void * data) { + const struct htp_fill_context * fctx = (const struct htp_fill_context *) data; + struct htp_ops_context * octx = fctx->octx; + fill_preamble; + + // Parallelise over the flat row index spanning ne1*ne2*ne3 + const uint32_t ir0 = fctx->nrows_per_thread * ith; + const uint32_t ir1 = MIN(ir0 + fctx->nrows_per_thread, fctx->total_rows); + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + if (fctx->opt_path) { + // Opt path: tensor is fully contiguous, treat as flat array + const uint32_t elem_start = ir0 * ne0; + const uint32_t elem_end = ir1 * ne0; + uint8_t * dst_ptr = (uint8_t *) dst->data + elem_start * fctx->elem_size; + hvx_splat_u(dst_ptr, fctx->splat_vec, elem_end - elem_start, fctx->elem_size); + } else { + // Non-contiguous path: must respect strides + for (uint32_t ir = ir0; ir < ir1; ++ir) { + const uint32_t i1 = ir % ne1; + const uint32_t i2 = (ir / ne1) % ne2; + const uint32_t i3 = ir / (ne1 * ne2); + uint8_t * dst_ptr = (uint8_t *) dst->data + i1*nb1 + i2*nb2 + i3*nb3; + hvx_splat_u(dst_ptr, fctx->splat_vec, ne0, fctx->elem_size); + } + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + FARF(HIGH, "fill %u/%u: rows %u:%u usec %u\n", + ith, nth, ir0, ir1, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_fill(struct htp_ops_context * octx) { + fill_preamble; + + if (dst->type != HTP_TYPE_F32 && dst->type != HTP_TYPE_F16) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + // nr = ne1*ne2*ne3 (flat row count across all outer dims); parallelise over it. 
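+    // e.g. a [4096, 32, 4, 1] destination gives nr = 32*4*1 = 128 rows, so the worker
+    // count below is min(128, octx->n_threads).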
+ const uint32_t n_threads = MIN(nr, octx->n_threads); + + // Optimize if fully contiguous: skip stride arithmetic, treat as flat array + const bool opt_path = (nb2 == nb1 * ne1) && (nb3 == nb2 * ne2); + + FARF(HIGH, "fill: (%ux%ux%ux%u) type=%u opt=%d\n", + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->type, (int) opt_path); + + float val_f32 = 0.f; + memcpy(&val_f32, &octx->op_params[0], sizeof(float)); + + struct htp_fill_context fctx = { + .octx = octx, + .nrows_per_thread = (nr + n_threads - 1) / n_threads, + .total_rows = nr, + .opt_path = opt_path, + }; + + switch (dst->type) { + case HTP_TYPE_F32: + fctx.splat_vec = hvx_vec_splat_f32(val_f32); + fctx.elem_size = sizeof(float); + break; + case HTP_TYPE_F16: + fctx.splat_vec = hvx_vec_splat_f16((_Float16) val_f32); + fctx.elem_size = sizeof(_Float16); + break; + default: + return HTP_STATUS_NO_SUPPORT; + } + + worker_pool_run_func(octx->ctx->worker_pool, fill_thread, &fctx, n_threads); + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 038941af0..78455e6b0 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -98,6 +98,7 @@ int op_repeat(struct htp_ops_context * octx); int op_argsort(struct htp_ops_context * octx); int op_ssm_conv(struct htp_ops_context * octx); int op_cumsum(struct htp_ops_context * octx); +int op_fill(struct htp_ops_context * octx); int op_diag(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 002dd1c12..62d6ec022 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -80,6 +80,7 @@ enum htp_op_code { HTP_OP_SSM_CONV, HTP_OP_REPEAT, HTP_OP_CUMSUM, + HTP_OP_FILL, HTP_OP_DIAG, HTP_OP_INVALID diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index d633145c9..9185c9ffe 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -514,6 +514,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_CUMSUM: return op_cumsum(octx); + case HTP_OP_FILL: + return op_fill(octx); + case HTP_OP_DIAG: return op_diag(octx); From ca7f7b7b947842384cd8dda4a17a1868f1493a3e Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Tue, 21 Apr 2026 23:18:57 -0400 Subject: [PATCH 12/20] ggml-webgpu(shader): support conv2d kernels. 
(#21964) * ggml(webgpu): fix the busy-polls in Emscripten in the waitAny after #20618, and remove the busy webgpu log * Merge with upstream * Fix GET_ROWS packed integer NaN when using f16 as memory buffer in shader quants * Update Unary wgsl EXP and EXPM1 for f16 stability * Fix GET_ROWS IQ4_XS strcut for NaN f16 canonicalization * Fix numerical percision for unary sqrt when working with f16 * Fix NaN canonicalization for packed integers using f16 * Update err threshold for binary div ops when using f16 * backend: Keep one Dawn/WebGPU instance alive for the lifetime of the static backend * clean: uncomment existing code logs * clean: clean the unncessary debug info * Refactor and generalize dequant helpers * Remove deprecated quant structs * Refactor shader defines to reduce repetition * Remove error override for F16 type * fix: fix the accidential removal of the proper initialization of ctx * clean: clean legacy and format code * fix: did not modify tests ops * shader(conv2d): add conv2d shader kernels and pass f32 and f16 tests * shader(conv2d): fix the out of bounds memory access in the weight indexing * shader(conv2d): clean unused variables and optimize the computation * merge: use the new entries function * clean: address the formatting issues * clean: address the warning issues * clear: clean the shader editorconfig-checker issues * clear: clean the shader editorconfig-checker with utf-8 --------- Co-authored-by: Jeremy J. Hartmann --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 63 +++++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 89 ++++++++++ ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl | 165 ++++++++++++++++++ 3 files changed, 317 insertions(+) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 9d88f9805..f84dfee9d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -240,6 +240,27 @@ struct ggml_webgpu_ssm_conv_pipeline_key { } }; +/** CONV 2D */ +struct ggml_webgpu_conv2d_pipeline_key { + ggml_type weight_type; + ggml_type input_type; + ggml_type output_type; + + bool operator==(const ggml_webgpu_conv2d_pipeline_key & other) const { + return weight_type == other.weight_type && input_type == other.input_type && output_type == other.output_type; + } +}; + +struct ggml_webgpu_conv2d_pipeline_key_hash { + size_t operator()(const ggml_webgpu_conv2d_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.weight_type); + ggml_webgpu_hash_combine(seed, key.input_type); + ggml_webgpu_hash_combine(seed, key.output_type); + return seed; + } +}; + /** Gated Delta Net **/ struct ggml_webgpu_gated_delta_net_pipeline_key { int type; @@ -789,6 +810,8 @@ class ggml_webgpu_shader_lib { rope_pipelines; std::unordered_map soft_max_pipelines; + std::unordered_map + conv2d_pipelines; public: ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } @@ -2382,6 +2405,46 @@ class ggml_webgpu_shader_lib { return soft_max_pipelines[key]; } + webgpu_pipeline get_conv2d_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_conv2d_pipeline_key key = {}; + key.weight_type = context.src0->type; + key.input_type = context.src1->type; + key.output_type = context.dst->type; + + auto it = conv2d_pipelines.find(key); + if (it != conv2d_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "conv_2d"; + + auto 
push_type_defines = [&](const char * prefix, ggml_type type) { + std::string s_prefix = prefix; + if (type == GGML_TYPE_F32) { + defines.push_back(s_prefix + "_F32"); + } else if (type == GGML_TYPE_F16) { + defines.push_back(s_prefix + "_F16"); + } else { + GGML_ABORT("Unsupported type for CONV_2D shader"); + } + }; + + push_type_defines("WEIGHT", key.weight_type); + push_type_defines("INPUT", key.input_type); + push_type_defines("OUTPUT", key.output_type); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_conv2d, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + conv2d_pipelines[key] = pipeline; + return conv2d_pipelines[key]; + } + private: static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, std::string shader_code, diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index a29231452..551586751 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -8,6 +8,7 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" #include "ggml-webgpu-shader-lib.hpp" +#include "ggml.h" #ifdef __EMSCRIPTEN__ # include @@ -921,6 +922,87 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } +static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + const int32_t s0 = ggml_get_op_params_i32(dst, 0); + const int32_t s1 = ggml_get_op_params_i32(dst, 1); + const int32_t p0 = ggml_get_op_params_i32(dst, 2); + const int32_t p1 = ggml_get_op_params_i32(dst, 3); + const int32_t d0 = ggml_get_op_params_i32(dst, 4); + const int32_t d1 = ggml_get_op_params_i32(dst, 5); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], + + (uint32_t) src1->ne[0], + (uint32_t) src1->ne[1], + + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + + (uint32_t) s0, + (uint32_t) s1, + (uint32_t) p0, + (uint32_t) p1, + (uint32_t) d0, + (uint32_t) d1, + }; + + std::vector entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), + }; + + uint32_t max_wg_size = + 
std::min((uint32_t) WEBGPU_MAX_WG_SIZE, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupSizeX); + uint32_t wg_size = + std::min((uint32_t) ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, max_wg_size); + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = wg_size; + + webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + uint32_t n_out = ggml_nelements(dst); + uint32_t total_wg = CEIL_DIV(n_out, decisions->wg_size); + uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + uint32_t wg_x = std::min(total_wg, max_wg); + uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, @@ -2477,6 +2559,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context c case GGML_OP_SUM: case GGML_OP_SUM_ROWS: return ggml_webgpu_sum_rows(ctx, src0, node); + case GGML_OP_CONV_2D: + return ggml_webgpu_conv_2d(ctx, src0, src1, node); default: return std::nullopt; } @@ -3495,6 +3579,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SOLVE_TRI: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; break; + case GGML_OP_CONV_2D: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) && + (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); + break; case GGML_OP_SSM_CONV: supports_op = op->type == GGML_TYPE_F32; break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl new file mode 100644 index 000000000..9eb131dc2 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl @@ -0,0 +1,165 @@ +#include "common_decls.tmpl" +enable f16; + +@group(0) @binding(0) +#if defined(WEIGHT_F32) +var weights: array; +#elif defined(WEIGHT_F16) +var weights: array; +#endif + +@group(0) @binding(1) +#if defined(INPUT_F32) +var input: array; +#elif defined(INPUT_F16) +var input: array; +#endif + +@group(0) @binding(2) +#if defined(OUTPUT_F32) +var output: array; +#elif defined(OUTPUT_F16) +var output: array; +#endif + +struct Params { + offset_w: u32, + offset_i: u32, + offset_o: u32, + + // element strides + sw0: u32, sw1: u32, sw2: u32, sw3: u32, + si0: u32, si1: u32, si2: u32, si3: u32, + so0: u32, so1: u32, so2: u32, so3: u32, + + // kernel dimensions + KW: u32, KH: u32, IC: u32, + // input dimensions + IW: u32, IH: u32, + // output dimensions + OW: u32, OH: u32, OC_out: u32, N_out: u32, + + // stride + s0: u32, s1: u32, + // padding + p0: u32, p1: u32, + // dilation + d0: u32, d1: u32, +}; + +@group(0) @binding(3) +var params: Params; + +fn load_weight(idx: u32) -> f32 { + #if defined(WEIGHT_F32) + return weights[idx]; + #elif defined(WEIGHT_F16) + return f32(weights[idx]); + #endif +} + +fn load_input(idx: u32) -> f32 { + #if defined(INPUT_F32) + return input[idx]; + #elif defined(INPUT_F16) + return f32(input[idx]); + #endif +} + +fn store_output(idx: u32, val: f32) { + #if defined(OUTPUT_F32) + output[idx] = val; + #elif defined(OUTPUT_F16) + output[idx] = f16(val); + #endif +} + +fn ceil_div_u32(x: u32, y: u32) -> u32 { 
+ return (x + y - 1) / y; +} + +// returns the first valid kernel index k such that base + k * step >= 0 +fn first_valid_k(base: i32, step: u32) -> u32 { + if (base >= 0) { + return 0; + } + + return ceil_div_u32(u32(-base), step); +} + +// returns the first invalid kernel index k such that base + k * step >= limit so valid k are in [0, end_valid_k) +fn end_valid_k(base: i32, step: u32, limit: u32, k_max: u32) -> u32 { + let remaining = i32(limit) - base; + if (remaining <= 0) { + return 0; + } + + return min(k_max, ceil_div_u32(u32(remaining), step)); +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_wg: vec3 +) { + + let threads_per_group = u32(WG_SIZE); + let i_out = gid.x + (num_wg.x * threads_per_group) * gid.y; + let n_out = params.OW * params.OH * params.OC_out * params.N_out; + + var sum: f32 = 0.0; + if (i_out >= n_out) { + return; + } + + // Kernel layout: [KW, KH, IC, ..] + // Input layout: [IW, IH, .., ..] + // Output layout: [OW, OH, OC, N] + + var i = i_out; + let n = i / (params.OC_out * params.OH * params.OW); + i = i % (params.OC_out * params.OH * params.OW); + let oc = i / (params.OH * params.OW); + i = i % (params.OH * params.OW); + let oh = i / params.OW; + let ow = i % params.OW; + + let ow_base = i32(ow * params.s0) - i32(params.p0); + let oh_base = i32(oh * params.s1) - i32(params.p1); + + // clip the valid kernel window once + let kw_begin = first_valid_k(ow_base, params.d0); + let kw_end = end_valid_k(ow_base, params.d0, params.IW, params.KW); + let kh_begin = first_valid_k(oh_base, params.d1); + let kh_end = end_valid_k(oh_base, params.d1, params.IH, params.KH); + + // entire receptive field is out of bounds + if (kw_begin >= kw_end || kh_begin >= kh_end) { + let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3; + store_output(out_idx, 0.0); + return; + } + + let weight_oc_base = params.offset_w + oc * params.sw3; + let input_n_base = params.offset_i + n * params.si3; + + for (var ic: u32 = 0; ic < params.IC; ic += 1) { + let w_base_ic = ic * params.sw2 + weight_oc_base; + let in_base = ic * params.si2 + input_n_base; + + for (var kh: u32 = kh_begin; kh < kh_end; kh += 1) { + let ih = u32(oh_base + i32(kh * params.d1)); + let w_row_base = w_base_ic + kh * params.sw1; + let in_row_base = in_base + ih * params.si1; + for (var kw: u32 = kw_begin; kw < kw_end; kw += 1) { + let iw = u32(ow_base + i32(kw * params.d0)); + let w_idx = w_row_base + kw * params.sw0; + let in_idx = in_row_base + iw * params.si0; + sum += load_weight(w_idx) * load_input(in_idx); + } + } + } + + let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3; + store_output(out_idx, sum); +} From 134d6e54d4cd1311360bf6beeb6c779d90e09b87 Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Wed, 22 Apr 2026 10:28:45 +0200 Subject: [PATCH 13/20] common/chat, server: refactor, move all conversion functions to common, add tests (#20690) * Refactor conversion functions --- common/chat.cpp | 69 ++-- common/chat.h | 5 +- tests/CMakeLists.txt | 2 + tests/test-chat.cpp | 130 ++++++- tools/server/CMakeLists.txt | 2 + tools/server/server-chat.cpp | 588 ++++++++++++++++++++++++++++++++ tools/server/server-chat.h | 24 ++ tools/server/server-common.cpp | 567 ------------------------------ tools/server/server-common.h | 12 - tools/server/server-context.cpp | 7 +- tools/server/server-task.cpp | 5 +- 11 files changed, 772 insertions(+), 639 
deletions(-) create mode 100644 tools/server/server-chat.cpp create mode 100644 tools/server/server-chat.h diff --git a/common/chat.cpp b/common/chat.cpp index e424206af..7c071560f 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -397,6 +397,25 @@ json common_chat_msgs_to_json_oaicompat(const std::vector & msg return render_message_to_json(msgs, c); } +json common_chat_tools_to_json_oaicompat(const std::vector & tools) { + if (tools.empty()) { + return json(); + } + + auto result = json::array(); + for (const auto & tool : tools) { + result.push_back({ + { "type", "function" }, + { "function", { + { "name", tool.name }, + { "description", tool.description }, + { "parameters", json::parse(tool.parameters) }, + }}, + }); + } + return result; +} + std::vector common_chat_tools_parse_oaicompat(const json & tools) { std::vector result; @@ -432,56 +451,6 @@ std::vector common_chat_tools_parse_oaicompat(const json & too return result; } -json common_chat_tools_to_json_oaicompat(const std::vector & tools) { - if (tools.empty()) { - return json(); - } - - auto result = json::array(); - for (const auto & tool : tools) { - result.push_back({ - { "type", "function" }, - { "function", - { - { "name", tool.name }, - { "description", tool.description }, - { "parameters", json::parse(tool.parameters) }, - } }, - }); - } - return result; -} - -json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) { - json delta = json::object(); - if (!diff.reasoning_content_delta.empty()) { - delta["reasoning_content"] = diff.reasoning_content_delta; - } - if (!diff.content_delta.empty()) { - delta["content"] = diff.content_delta; - } - if (diff.tool_call_index != std::string::npos) { - json tool_call; - tool_call["index"] = diff.tool_call_index; - if (!diff.tool_call_delta.id.empty()) { - tool_call["id"] = diff.tool_call_delta.id; - tool_call["type"] = "function"; - } - if (!diff.tool_call_delta.name.empty() || !diff.tool_call_delta.arguments.empty()) { - json function = json::object(); - if (!diff.tool_call_delta.name.empty()) { - function["name"] = diff.tool_call_delta.name; - } - if (!diff.tool_call_delta.arguments.empty()) { - function["arguments"] = diff.tool_call_delta.arguments; - } - tool_call["function"] = function; - } - delta["tool_calls"] = json::array({ tool_call }); - } - return delta; -} - bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { if (use_jinja) { try { diff --git a/common/chat.h b/common/chat.h index b06ca37fd..9122f2967 100644 --- a/common/chat.h +++ b/common/chat.h @@ -256,14 +256,13 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates * // Parses a JSON array of messages in OpenAI's chat completion API format. 
std::vector common_chat_msgs_parse_oaicompat(const nlohmann::ordered_json & messages); +std::vector common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools); + // DEPRECATED: only used in tests nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text = false); -std::vector common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools); nlohmann::ordered_json common_chat_tools_to_json_oaicompat(const std::vector & tools); -nlohmann::ordered_json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff); - // get template caps, useful for reporting to server /props endpoint std::map common_chat_templates_get_caps(const common_chat_templates * chat_templates); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b282c3239..edb585b9f 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -155,6 +155,8 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS) llama_build_and_test(test-grammar-integration.cpp) llama_build_and_test(test-llama-grammar.cpp) llama_build_and_test(test-chat.cpp WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}) + target_include_directories(test-chat PRIVATE ${PROJECT_SOURCE_DIR}/tools/server) + target_link_libraries(test-chat PRIVATE server-context) # TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8 if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") llama_build_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}) diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 79e1891b8..52b480c24 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -7,6 +7,7 @@ // #include "../src/llama-grammar.h" #include "../src/unicode.h" +#include "../tools/server/server-chat.h" #include "chat-auto-parser.h" #include "chat.h" #include "common.h" @@ -1514,6 +1515,117 @@ static void test_tools_oaicompat_json_conversion() { common_chat_tools_to_json_oaicompat({ special_function_tool }).dump(2)); } +static void test_convert_responses_to_chatcmpl() { + LOG_DBG("%s\n", __func__); + + // Test basic conversion with input messages (user/assistant alternating) + { + json input = json::parse(R"({ + "input": [ + { + "type": "message", + "role": "user", + "content": "hi wassup" + }, + { + "type": "message", + "role": "assistant", + "content": "Hey! 👋 Not much, just here ready to chat. What's up with you? Anything I can help you with today?" 
+ }, + { + "type": "message", + "role": "user", + "content": "hi" + } + ], + "model": "gpt-5-mini", + "stream": false, + "text": {}, + "reasoning": { + "effort": "medium" + } + })"); + + json result = server_chat_convert_responses_to_chatcmpl(input); + + // Verify messages were converted correctly + assert_equals(true, result.contains("messages")); + assert_equals(true, result.at("messages").is_array()); + assert_equals((size_t)3, result.at("messages").size()); + + // Check first message (user) + const auto & msg0 = result.at("messages")[0]; + assert_equals(std::string("user"), msg0.at("role").get()); + assert_equals(true, msg0.at("content").is_array()); + assert_equals(std::string("text"), msg0.at("content")[0].at("type").get()); + assert_equals(std::string("hi wassup"), msg0.at("content")[0].at("text").get()); + + // Check second message (assistant) + const auto & msg1 = result.at("messages")[1]; + assert_equals(std::string("assistant"), msg1.at("role").get()); + assert_equals(true, msg1.at("content").is_array()); + assert_equals(std::string("text"), msg1.at("content")[0].at("type").get()); + assert_equals(std::string("Hey! 👋 Not much, just here ready to chat. What's up with you? Anything I can help you with today?"), msg1.at("content")[0].at("text").get()); + + // Check third message (user) + const auto & msg2 = result.at("messages")[2]; + assert_equals(std::string("user"), msg2.at("role").get()); + assert_equals(true, msg2.at("content").is_array()); + assert_equals(std::string("text"), msg2.at("content")[0].at("type").get()); + assert_equals(std::string("hi"), msg2.at("content")[0].at("text").get()); + + // Verify other fields preserved + assert_equals(std::string("gpt-5-mini"), result.at("model").get()); + assert_equals(false, result.at("stream").get()); + } + + // Test string input + { + json input = json::parse(R"({ + "input": "Hello, world!", + "model": "test-model" + })"); + + json result = server_chat_convert_responses_to_chatcmpl(input); + + assert_equals((size_t)1, result.at("messages").size()); + const auto & msg = result.at("messages")[0]; + assert_equals(std::string("user"), msg.at("role").get()); + assert_equals(std::string("Hello, world!"), msg.at("content").get()); + } + + // Test with instructions (system message) + { + json input = json::parse(R"({ + "input": "Hello", + "instructions": "You are a helpful assistant.", + "model": "test-model" + })"); + + json result = server_chat_convert_responses_to_chatcmpl(input); + + assert_equals((size_t)2, result.at("messages").size()); + const auto & sys_msg = result.at("messages")[0]; + assert_equals(std::string("system"), sys_msg.at("role").get()); + assert_equals(std::string("You are a helpful assistant."), sys_msg.at("content").get()); + } + + // Test with max_output_tokens conversion + { + json input = json::parse(R"({ + "input": "Hello", + "model": "test-model", + "max_output_tokens": 100 + })"); + + json result = server_chat_convert_responses_to_chatcmpl(input); + + assert_equals(true, result.contains("max_tokens")); + assert_equals(false, result.contains("max_output_tokens")); + assert_equals(100, result.at("max_tokens").get()); + } +} + static void test_template_output_peg_parsers(bool detailed_debug) { LOG_DBG("%s\n", __func__); @@ -4291,7 +4403,7 @@ int main(int argc, char ** argv) { bool detailed_debug = false; bool only_run_filtered = false; - // Check for --template flag + // Check for --template and --detailed flags for (int i = 1; i < argc; i++) { std::string arg = argv[i]; if (arg == "--template" && i + 1 < argc) { 
@@ -4316,7 +4428,20 @@ int main(int argc, char ** argv) { } #ifndef _WIN32 - if (argc > 1) { + // Check if any argument is a .jinja file (for template format detection mode) + bool has_jinja_files = false; + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + if (arg == "--detailed") { + continue; + } + if (arg.size() >= 6 && arg.rfind(".jinja") == arg.size() - 6) { + has_jinja_files = true; + break; + } + } + + if (has_jinja_files) { common_chat_templates_inputs inputs; common_chat_msg msg; msg.role = "user"; @@ -4349,6 +4474,7 @@ int main(int argc, char ** argv) { test_msg_diffs_compute(); test_msgs_oaicompat_json_conversion(); test_tools_oaicompat_json_conversion(); + test_convert_responses_to_chatcmpl(); test_developer_role_to_system_workaround(); test_reka_edge_common_path(); test_template_output_peg_parsers(detailed_debug); diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index 0cce99f59..71cc0e7a8 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -5,6 +5,8 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}) set(TARGET server-context) add_library(${TARGET} STATIC + server-chat.cpp + server-chat.h server-task.cpp server-task.h server-queue.cpp diff --git a/tools/server/server-chat.cpp b/tools/server/server-chat.cpp new file mode 100644 index 000000000..4fe81553c --- /dev/null +++ b/tools/server/server-chat.cpp @@ -0,0 +1,588 @@ +#include "server-chat.h" +#include "server-common.h" + +#include + +json server_chat_convert_responses_to_chatcmpl(const json & response_body) { + if (!response_body.contains("input")) { + throw std::invalid_argument("'input' is required"); + } + if (!json_value(response_body, "previous_response_id", std::string{}).empty()) { + throw std::invalid_argument("llama.cpp does not support 'previous_response_id'."); + } + + const json input_value = response_body.at("input"); + json chatcmpl_body = response_body; + chatcmpl_body.erase("input"); + std::vector chatcmpl_messages; + + if (response_body.contains("instructions")) { + chatcmpl_messages.push_back({ + {"role", "system"}, + {"content", json_value(response_body, "instructions", std::string())}, + }); + chatcmpl_body.erase("instructions"); + } + + if (input_value.is_string()) { + // #responses_create-input-text_input + chatcmpl_messages.push_back({ + {"role", "user"}, + {"content", input_value}, + }); + } else if (input_value.is_array()) { + // #responses_create-input-input_item_list + + static auto exists_and_is_array = [](const json & j, const char * key) -> bool { + return j.contains(key) && j.at(key).is_array(); + }; + static auto exists_and_is_string = [](const json & j, const char * key) -> bool { + return j.contains(key) && j.at(key).is_string(); + }; + + for (json item : input_value) { + bool merge_prev = !chatcmpl_messages.empty() && chatcmpl_messages.back().value("role", "") == "assistant"; + + if (exists_and_is_string(item, "content")) { + // #responses_create-input-input_item_list-input_message-content-text_input + // Only "Input message" contains item["content"]::string + // After converting item["content"]::string to item["content"]::array, + // we can treat "Input message" as sum of "Item-Input message" and "Item-Output message" + item["content"] = json::array({ + json { + {"text", item.at("content")}, + {"type", "input_text"} + } + }); + } + + if (exists_and_is_array(item, "content") && + exists_and_is_string(item, "role") && + (item.at("role") == "user" || + item.at("role") == "system" || + item.at("role") 
== "developer") + ) { + // #responses_create-input-input_item_list-item-input_message + std::vector chatcmpl_content; + + for (const json & input_item : item.at("content")) { + const std::string type = json_value(input_item, "type", std::string()); + + if (type == "input_text") { + if (!input_item.contains("text")) { + throw std::invalid_argument("'Input text' requires 'text'"); + } + chatcmpl_content.push_back({ + {"text", input_item.at("text")}, + {"type", "text"}, + }); + } else if (type == "input_image") { + // While `detail` is marked as required, + // it has default value("auto") and can be omitted. + + if (!input_item.contains("image_url")) { + throw std::invalid_argument("'image_url' is required"); + } + chatcmpl_content.push_back({ + {"image_url", json { + {"url", input_item.at("image_url")} + }}, + {"type", "image_url"}, + }); + } else if (type == "input_file") { + throw std::invalid_argument("'input_file' is not supported by llamacpp at this moment"); + } else { + throw std::invalid_argument("'type' must be one of 'input_text', 'input_image', or 'input_file'"); + } + } + + if (item.contains("type")) { + item.erase("type"); + } + if (item.contains("status")) { + item.erase("status"); + } + item["content"] = chatcmpl_content; + + chatcmpl_messages.push_back(item); + } else if (exists_and_is_string(item, "role") && + item.at("role") == "assistant" && + exists_and_is_string(item, "type") && + item.at("type") == "message" + ) { + // #responses_create-input-input_item_list-item-output_message + auto chatcmpl_content = json::array(); + + // Handle both string content and array content + if (item.contains("content") && item.at("content").is_string()) { + // String content - convert to text content part + chatcmpl_content.push_back({ + {"text", item.at("content")}, + {"type", "text"}, + }); + } else if (exists_and_is_array(item, "content")) { + // Array content - process each item + for (const auto & output_text : item.at("content")) { + const std::string type = json_value(output_text, "type", std::string()); + if (type == "output_text" || type == "input_text") { + // Accept both output_text and input_text (string content gets converted to input_text) + if (!exists_and_is_string(output_text, "text")) { + throw std::invalid_argument("'Output text' requires 'text'"); + } + chatcmpl_content.push_back({ + {"text", output_text.at("text")}, + {"type", "text"}, + }); + } else if (type == "refusal") { + if (!exists_and_is_string(output_text, "refusal")) { + throw std::invalid_argument("'Refusal' requires 'refusal'"); + } + chatcmpl_content.push_back({ + {"refusal", output_text.at("refusal")}, + {"type", "refusal"}, + }); + } else { + throw std::invalid_argument("'type' must be one of 'output_text' or 'refusal'"); + } + } + } + + if (merge_prev) { + auto & prev_msg = chatcmpl_messages.back(); + if (!exists_and_is_array(prev_msg, "content")) { + prev_msg["content"] = json::array(); + } + auto & prev_content = prev_msg["content"]; + prev_content.insert(prev_content.end(), chatcmpl_content.begin(), chatcmpl_content.end()); + } else { + item.erase("status"); + item.erase("type"); + item["content"] = chatcmpl_content; + chatcmpl_messages.push_back(item); + } + } else if (exists_and_is_string(item, "arguments") && + exists_and_is_string(item, "call_id") && + exists_and_is_string(item, "name") && + exists_and_is_string(item, "type") && + item.at("type") == "function_call" + ) { + // #responses_create-input-input_item_list-item-function_tool_call + json tool_call = { + {"function", json { + {"arguments", 
item.at("arguments")}, + {"name", item.at("name")}, + }}, + {"id", item.at("call_id")}, + {"type", "function"}, + }; + + if (merge_prev) { + auto & prev_msg = chatcmpl_messages.back(); + if (!exists_and_is_array(prev_msg, "tool_calls")) { + prev_msg["tool_calls"] = json::array(); + } + prev_msg["tool_calls"].push_back(tool_call); + } else { + chatcmpl_messages.push_back(json { + {"role", "assistant"}, + {"tool_calls", json::array({tool_call})} + }); + } + } else if (exists_and_is_string(item, "call_id") && + (exists_and_is_string(item, "output") || exists_and_is_array(item, "output")) && + exists_and_is_string(item, "type") && + item.at("type") == "function_call_output" + ) { + // #responses_create-input-input_item_list-item-function_tool_call_output + if (item.at("output").is_string()) { + chatcmpl_messages.push_back(json { + {"content", item.at("output")}, + {"role", "tool"}, + {"tool_call_id", item.at("call_id")}, + }); + } else { + json chatcmpl_outputs = item.at("output"); + for (json & chatcmpl_output : chatcmpl_outputs) { + if (!chatcmpl_output.contains("type") || chatcmpl_output.at("type") != "input_text") { + throw std::invalid_argument("Output of tool call should be 'Input text'"); + } + chatcmpl_output["type"] = "text"; + } + chatcmpl_messages.push_back(json { + {"content", chatcmpl_outputs}, + {"role", "tool"}, + {"tool_call_id", item.at("call_id")}, + }); + } + } else if (exists_and_is_array(item, "summary") && + exists_and_is_string(item, "type") && + item.at("type") == "reasoning") { + // #responses_create-input-input_item_list-item-reasoning + + if (!exists_and_is_array(item, "content")) { + throw std::invalid_argument("item['content'] is not an array"); + } + if (item.at("content").empty()) { + throw std::invalid_argument("item['content'] is empty"); + } + if (!exists_and_is_string(item.at("content")[0], "text")) { + throw std::invalid_argument("item['content']['text'] is not a string"); + } + + if (merge_prev) { + auto & prev_msg = chatcmpl_messages.back(); + prev_msg["reasoning_content"] = item.at("content")[0].at("text"); + } else { + chatcmpl_messages.push_back(json { + {"role", "assistant"}, + {"content", json::array()}, + {"reasoning_content", item.at("content")[0].at("text")}, + }); + } + } else { + throw std::invalid_argument("Cannot determine type of 'item'"); + } + } + } else { + throw std::invalid_argument("'input' must be a string or array of objects"); + } + + chatcmpl_body["messages"] = chatcmpl_messages; + + if (response_body.contains("tools")) { + if (!response_body.at("tools").is_array()) { + throw std::invalid_argument("'tools' must be an array of objects"); + } + std::vector chatcmpl_tools; + for (json resp_tool : response_body.at("tools")) { + json chatcmpl_tool; + + if (json_value(resp_tool, "type", std::string()) != "function") { + throw std::invalid_argument("'type' of tool must be 'function'"); + } + resp_tool.erase("type"); + chatcmpl_tool["type"] = "function"; + + if (!resp_tool.contains("strict")) { + resp_tool["strict"] = true; + } + chatcmpl_tool["function"] = resp_tool; + chatcmpl_tools.push_back(chatcmpl_tool); + } + chatcmpl_body.erase("tools"); + chatcmpl_body["tools"] = chatcmpl_tools; + } + + if (response_body.contains("max_output_tokens")) { + chatcmpl_body.erase("max_output_tokens"); + chatcmpl_body["max_tokens"] = response_body["max_output_tokens"]; + } + + return chatcmpl_body; +} + +json server_chat_convert_anthropic_to_oai(const json & body) { + json oai_body; + + // Convert system prompt + json oai_messages = json::array(); + auto 
system_param = json_value(body, "system", json()); + if (!system_param.is_null()) { + std::string system_content; + + if (system_param.is_string()) { + system_content = system_param.get(); + } else if (system_param.is_array()) { + for (const auto & block : system_param) { + if (json_value(block, "type", std::string()) == "text") { + system_content += json_value(block, "text", std::string()); + } + } + } + + oai_messages.push_back({ + {"role", "system"}, + {"content", system_content} + }); + } + + // Convert messages + if (!body.contains("messages")) { + throw std::runtime_error("'messages' is required"); + } + const json & messages = body.at("messages"); + if (messages.is_array()) { + for (const auto & msg : messages) { + std::string role = json_value(msg, "role", std::string()); + + if (!msg.contains("content")) { + if (role == "assistant") { + continue; + } + oai_messages.push_back(msg); + continue; + } + + const json & content = msg.at("content"); + + if (content.is_string()) { + oai_messages.push_back(msg); + continue; + } + + if (!content.is_array()) { + oai_messages.push_back(msg); + continue; + } + + json tool_calls = json::array(); + json converted_content = json::array(); + json tool_results = json::array(); + std::string reasoning_content; + bool has_tool_calls = false; + + for (const auto & block : content) { + std::string type = json_value(block, "type", std::string()); + + if (type == "text") { + converted_content.push_back(block); + } else if (type == "thinking") { + reasoning_content += json_value(block, "thinking", std::string()); + } else if (type == "image") { + json source = json_value(block, "source", json::object()); + std::string source_type = json_value(source, "type", std::string()); + + if (source_type == "base64") { + std::string media_type = json_value(source, "media_type", std::string("image/jpeg")); + std::string data = json_value(source, "data", std::string()); + std::ostringstream ss; + ss << "data:" << media_type << ";base64," << data; + + converted_content.push_back({ + {"type", "image_url"}, + {"image_url", { + {"url", ss.str()} + }} + }); + } else if (source_type == "url") { + std::string url = json_value(source, "url", std::string()); + converted_content.push_back({ + {"type", "image_url"}, + {"image_url", { + {"url", url} + }} + }); + } + } else if (type == "tool_use") { + tool_calls.push_back({ + {"id", json_value(block, "id", std::string())}, + {"type", "function"}, + {"function", { + {"name", json_value(block, "name", std::string())}, + {"arguments", json_value(block, "input", json::object()).dump()} + }} + }); + has_tool_calls = true; + } else if (type == "tool_result") { + std::string tool_use_id = json_value(block, "tool_use_id", std::string()); + + auto result_content = json_value(block, "content", json()); + std::string result_text; + if (result_content.is_string()) { + result_text = result_content.get(); + } else if (result_content.is_array()) { + for (const auto & c : result_content) { + if (json_value(c, "type", std::string()) == "text") { + result_text += json_value(c, "text", std::string()); + } + } + } + + tool_results.push_back({ + {"role", "tool"}, + {"tool_call_id", tool_use_id}, + {"content", result_text} + }); + } + } + + if (!converted_content.empty() || has_tool_calls || !reasoning_content.empty()) { + json new_msg = {{"role", role}}; + if (!converted_content.empty()) { + new_msg["content"] = converted_content; + } else if (has_tool_calls || !reasoning_content.empty()) { + new_msg["content"] = ""; + } + if (!tool_calls.empty()) { + 
new_msg["tool_calls"] = tool_calls; + } + if (!reasoning_content.empty()) { + new_msg["reasoning_content"] = reasoning_content; + } + oai_messages.push_back(new_msg); + } + + for (const auto & tool_msg : tool_results) { + oai_messages.push_back(tool_msg); + } + } + } + + oai_body["messages"] = oai_messages; + + // Convert tools + if (body.contains("tools")) { + const json & tools = body.at("tools"); + if (tools.is_array()) { + json oai_tools = json::array(); + for (const auto & tool : tools) { + oai_tools.push_back({ + {"type", "function"}, + {"function", { + {"name", json_value(tool, "name", std::string())}, + {"description", json_value(tool, "description", std::string())}, + {"parameters", tool.contains("input_schema") ? tool.at("input_schema") : json::object()} + }} + }); + } + oai_body["tools"] = oai_tools; + } + } + + // Convert tool_choice + if (body.contains("tool_choice")) { + const json & tc = body.at("tool_choice"); + if (tc.is_object()) { + std::string type = json_value(tc, "type", std::string()); + if (type == "auto") { + oai_body["tool_choice"] = "auto"; + } else if (type == "any" || type == "tool") { + oai_body["tool_choice"] = "required"; + } + } + } + + // Convert stop_sequences to stop + if (body.contains("stop_sequences")) { + oai_body["stop"] = body.at("stop_sequences"); + } + + // Handle max_tokens (required in Anthropic, but we're permissive) + if (body.contains("max_tokens")) { + oai_body["max_tokens"] = body.at("max_tokens"); + } else { + oai_body["max_tokens"] = 4096; + } + + // Pass through common params + for (const auto & key : {"temperature", "top_p", "top_k", "stream"}) { + if (body.contains(key)) { + oai_body[key] = body.at(key); + } + } + + // Handle Anthropic-specific thinking param + if (body.contains("thinking")) { + json thinking = json_value(body, "thinking", json::object()); + std::string thinking_type = json_value(thinking, "type", std::string()); + if (thinking_type == "enabled") { + int budget_tokens = json_value(thinking, "budget_tokens", 10000); + oai_body["thinking_budget_tokens"] = budget_tokens; + } + } + + // Handle Anthropic-specific metadata param + if (body.contains("metadata")) { + json metadata = json_value(body, "metadata", json::object()); + std::string user_id = json_value(metadata, "user_id", std::string()); + if (!user_id.empty()) { + oai_body["__metadata_user_id"] = user_id; + } + } + + return oai_body; +} + +json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) { + json delta = json::object(); + if (!diff.reasoning_content_delta.empty()) { + delta["reasoning_content"] = diff.reasoning_content_delta; + } + if (!diff.content_delta.empty()) { + delta["content"] = diff.content_delta; + } + if (diff.tool_call_index != std::string::npos) { + json tool_call; + tool_call["index"] = diff.tool_call_index; + if (!diff.tool_call_delta.id.empty()) { + tool_call["id"] = diff.tool_call_delta.id; + tool_call["type"] = "function"; + } + if (!diff.tool_call_delta.name.empty() || !diff.tool_call_delta.arguments.empty()) { + json function = json::object(); + if (!diff.tool_call_delta.name.empty()) { + function["name"] = diff.tool_call_delta.name; + } + if (!diff.tool_call_delta.arguments.empty()) { + function["arguments"] = diff.tool_call_delta.arguments; + } + tool_call["function"] = function; + } + delta["tool_calls"] = json::array({ tool_call }); + } + return delta; +} + +json convert_transcriptions_to_chatcmpl( + const json & inp_body, + const std::map & in_files, + std::vector & out_files) { + // TODO @ngxson : this 
function may need to be improved in the future + // handle input files + out_files.clear(); + auto it = in_files.find("file"); + if (it != in_files.end()) { + out_files.push_back(it->second); + } else { + throw std::invalid_argument("No input file found for transcription"); + } + + // handle input data + std::string prompt = json_value(inp_body, "prompt", std::string()); + std::string language = json_value(inp_body, "language", std::string()); + std::string response_format = json_value(inp_body, "response_format", std::string("json")); + if (response_format != "json") { + throw std::invalid_argument("Only 'json' response_format is supported for transcription"); + } + if (prompt.empty()) { + prompt = "Transcribe audio to text"; + } + if (!language.empty()) { + prompt += string_format(" (language: %s)", language.c_str()); + } + prompt += get_media_marker(); + + json chatcmpl_body = inp_body; // copy all fields + chatcmpl_body["messages"] = json::array({ + { + {"role", "user"}, + {"content", prompt}, + }, + }); + + // because input from form-data, everything is string, we need to correct the types here + std::string stream = json_value(inp_body, "stream", std::string("false")); + chatcmpl_body["stream"] = stream == "true"; + + if (inp_body.contains("max_tokens")) { + std::string inp = inp_body["max_tokens"].get(); + chatcmpl_body["max_tokens"] = std::stoul(inp); + } + + if (inp_body.contains("temperature")) { + std::string inp = inp_body["temperature"].get(); + chatcmpl_body["temperature"] = std::stof(inp); + } + + return chatcmpl_body; +} diff --git a/tools/server/server-chat.h b/tools/server/server-chat.h new file mode 100644 index 000000000..ecb8907c4 --- /dev/null +++ b/tools/server/server-chat.h @@ -0,0 +1,24 @@ +// Chat conversion functions for server (Responses API, Anthropic API, OAI streaming diffs) + +#pragma once + +#include "chat.h" +#include "server-common.h" + +#include + +using json = nlohmann::ordered_json; + +// Convert OpenAI Responses API format to OpenAI Chat Completions API format +json server_chat_convert_responses_to_chatcmpl(const json & body); + +// Convert Anthropic Messages API format to OpenAI Chat Completions API format +json server_chat_convert_anthropic_to_oai(const json & body); + +// convert OpenAI transcriptions API format to OpenAI Chat Completions API format +json convert_transcriptions_to_chatcmpl( + const json & body, + const std::map & in_files, + std::vector & out_files); + +json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff); diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index cae64884b..18a317e1d 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1164,573 +1164,6 @@ json oaicompat_chat_params_parse( return llama_params; } -json convert_responses_to_chatcmpl(const json & response_body) { - if (!response_body.contains("input")) { - throw std::invalid_argument("'input' is required"); - } - if (!json_value(response_body, "previous_response_id", std::string{}).empty()) { - throw std::invalid_argument("llama.cpp does not support 'previous_response_id'."); - } - - const json input_value = response_body.at("input"); - json chatcmpl_body = response_body; - chatcmpl_body.erase("input"); - std::vector chatcmpl_messages; - - if (response_body.contains("instructions")) { - chatcmpl_messages.push_back({ - {"role", "system"}, - {"content", json_value(response_body, "instructions", std::string())}, - }); - chatcmpl_body.erase("instructions"); - } - - if 
(input_value.is_string()) { - // #responses_create-input-text_input - chatcmpl_messages.push_back({ - {"role", "user"}, - {"content", input_value}, - }); - } else if (input_value.is_array()) { - // #responses_create-input-input_item_list - - static auto exists_and_is_array = [](const json & j, const char * key) -> bool { - return j.contains(key) && j.at(key).is_array(); - }; - static auto exists_and_is_string = [](const json & j, const char * key) -> bool { - return j.contains(key) && j.at(key).is_string(); - }; - - for (json item : input_value) { - bool merge_prev = !chatcmpl_messages.empty() && chatcmpl_messages.back().value("role", "") == "assistant"; - - if (exists_and_is_string(item, "content")) { - // #responses_create-input-input_item_list-input_message-content-text_input - // Only "Input message" contains item["content"]::string - // After converting item["content"]::string to item["content"]::array, - // we can treat "Input message" as sum of "Item-Input message" and "Item-Output message" - item["content"] = json::array({ - json { - {"text", item.at("content")}, - {"type", "input_text"} - } - }); - } - - if (exists_and_is_array(item, "content") && - exists_and_is_string(item, "role") && - (item.at("role") == "user" || - item.at("role") == "system" || - item.at("role") == "developer") - ) { - // #responses_create-input-input_item_list-item-input_message - std::vector chatcmpl_content; - - for (const json & input_item : item.at("content")) { - const std::string type = json_value(input_item, "type", std::string()); - - if (type == "input_text") { - if (!input_item.contains("text")) { - throw std::invalid_argument("'Input text' requires 'text'"); - } - chatcmpl_content.push_back({ - {"text", input_item.at("text")}, - {"type", "text"}, - }); - } else if (type == "input_image") { - // While `detail` is marked as required, - // it has default value("auto") and can be omitted. 
- - if (!input_item.contains("image_url")) { - throw std::invalid_argument("'image_url' is required"); - } - chatcmpl_content.push_back({ - {"image_url", json { - {"url", input_item.at("image_url")} - }}, - {"type", "image_url"}, - }); - } else if (type == "input_file") { - throw std::invalid_argument("'input_file' is not supported by llamacpp at this moment"); - // if (input_item.contains("file_url")) { - // // chat completion API does not support file_url - // throw std::invalid_argument("'file_url' is not supported"); - // } - // if (!input_item.contains("file_data") || !input_item.contains("filename")) { - // throw std::invalid_argument("Both 'file_data' and 'filename' are required"); - // } - // chatcmpl_content.push_back({ - // {"file", json { - // {"file_data", input_item.at("file_data")}, - // {"filename", input_item.at("filename")}, - // }}, - // {"type", "file"}, - // }); - } else { - throw std::invalid_argument("'type' must be one of 'input_text', 'input_image', or 'input_file'"); - } - } - - if (item.contains("type")) { - item.erase("type"); - } - if (item.contains("status")) { - item.erase("status"); - } - item["content"] = chatcmpl_content; - - chatcmpl_messages.push_back(item); - } else if (exists_and_is_array(item, "content") && - exists_and_is_string(item, "role") && - item.at("role") == "assistant" && - // exists_and_is_string(item, "status") && - // (item.at("status") == "in_progress" || - // item.at("status") == "completed" || - // item.at("status") == "incomplete") && - // item["status"] not sent by codex-cli - exists_and_is_string(item, "type") && - item.at("type") == "message" - ) { - // #responses_create-input-input_item_list-item-output_message - auto chatcmpl_content = json::array(); - - for (const auto & output_text : item.at("content")) { - const std::string type = json_value(output_text, "type", std::string()); - if (type == "output_text") { - if (!exists_and_is_string(output_text, "text")) { - throw std::invalid_argument("'Output text' requires 'text'"); - // Ignore annotations and logprobs for now - chatcmpl_content.push_back({ - {"text", output_text.at("text")}, - {"type", "text"}, - }); - } - } else if (type == "refusal") { - if (!exists_and_is_string(output_text, "refusal")) { - throw std::invalid_argument("'Refusal' requires 'refusal'"); - // Ignore annotations and logprobs for now - chatcmpl_content.push_back({ - {"refusal", output_text.at("refusal")}, - {"type", "refusal"}, - }); - } - } else { - throw std::invalid_argument("'type' must be one of 'output_text' or 'refusal'"); - } - } - - if (merge_prev) { - auto & prev_msg = chatcmpl_messages.back(); - if (!exists_and_is_array(prev_msg, "content")) { - prev_msg["content"] = json::array(); - } - auto & prev_content = prev_msg["content"]; - prev_content.insert(prev_content.end(), chatcmpl_content.begin(), chatcmpl_content.end()); - } else { - item.erase("status"); - item.erase("type"); - item["content"] = chatcmpl_content; - chatcmpl_messages.push_back(item); - } - } else if (exists_and_is_string(item, "arguments") && - exists_and_is_string(item, "call_id") && - exists_and_is_string(item, "name") && - exists_and_is_string(item, "type") && - item.at("type") == "function_call" - ) { - // #responses_create-input-input_item_list-item-function_tool_call - json tool_call = { - {"function", json { - {"arguments", item.at("arguments")}, - {"name", item.at("name")}, - }}, - {"id", item.at("call_id")}, - {"type", "function"}, - }; - - if (merge_prev) { - auto & prev_msg = chatcmpl_messages.back(); - if 
(!exists_and_is_array(prev_msg, "tool_calls")) { - prev_msg["tool_calls"] = json::array(); - } - prev_msg["tool_calls"].push_back(tool_call); - } else { - chatcmpl_messages.push_back(json { - {"role", "assistant"}, - {"tool_calls", json::array({tool_call})} - }); - } - } else if (exists_and_is_string(item, "call_id") && - (exists_and_is_string(item, "output") || exists_and_is_array(item, "output")) && - exists_and_is_string(item, "type") && - item.at("type") == "function_call_output" - ) { - // #responses_create-input-input_item_list-item-function_tool_call_output - if (item.at("output").is_string()) { - chatcmpl_messages.push_back(json { - {"content", item.at("output")}, - {"role", "tool"}, - {"tool_call_id", item.at("call_id")}, - }); - } else { - json chatcmpl_outputs = item.at("output"); - for (json & chatcmpl_output : chatcmpl_outputs) { - if (!chatcmpl_output.contains("type") || chatcmpl_output.at("type") != "input_text") { - throw std::invalid_argument("Output of tool call should be 'Input text'"); - } - chatcmpl_output["type"] = "text"; - } - chatcmpl_messages.push_back(json { - {"content", chatcmpl_outputs}, - {"role", "tool"}, - {"tool_call_id", item.at("call_id")}, - }); - } - } else if (// exists_and_is_string(item, "id") && - // item["id"] not sent by codex-cli - exists_and_is_array(item, "summary") && - exists_and_is_string(item, "type") && - item.at("type") == "reasoning") { - // #responses_create-input-input_item_list-item-reasoning - - if (!exists_and_is_array(item, "content")) { - throw std::invalid_argument("item['content'] is not an array"); - } - if (item.at("content").empty()) { - throw std::invalid_argument("item['content'] is empty"); - } - if (!exists_and_is_string(item.at("content")[0], "text")) { - throw std::invalid_argument("item['content']['text'] is not a string"); - } - - if (merge_prev) { - auto & prev_msg = chatcmpl_messages.back(); - prev_msg["reasoning_content"] = item.at("content")[0].at("text"); - } else { - chatcmpl_messages.push_back(json { - {"role", "assistant"}, - {"content", json::array()}, - {"reasoning_content", item.at("content")[0].at("text")}, - }); - } - } else { - throw std::invalid_argument("Cannot determine type of 'item'"); - } - } - } else { - throw std::invalid_argument("'input' must be a string or array of objects"); - } - - chatcmpl_body["messages"] = chatcmpl_messages; - - if (response_body.contains("tools")) { - if (!response_body.at("tools").is_array()) { - throw std::invalid_argument("'tools' must be an array of objects"); - } - std::vector chatcmpl_tools; - for (json resp_tool : response_body.at("tools")) { - json chatcmpl_tool; - - if (json_value(resp_tool, "type", std::string()) != "function") { - throw std::invalid_argument("'type' of tool must be 'function'"); - } - resp_tool.erase("type"); - chatcmpl_tool["type"] = "function"; - - if (!resp_tool.contains("strict")) { - resp_tool["strict"] = true; - } - chatcmpl_tool["function"] = resp_tool; - chatcmpl_tools.push_back(chatcmpl_tool); - } - chatcmpl_body.erase("tools"); - chatcmpl_body["tools"] = chatcmpl_tools; - } - - if (response_body.contains("max_output_tokens")) { - chatcmpl_body.erase("max_output_tokens"); - chatcmpl_body["max_tokens"] = response_body["max_output_tokens"]; - } - - return chatcmpl_body; -} - -json convert_transcriptions_to_chatcmpl( - const json & inp_body, - const std::map & in_files, - std::vector & out_files) { - // TODO @ngxson : this function may need to be improved in the future - // handle input files - out_files.clear(); - auto it = 
in_files.find("file"); - if (it != in_files.end()) { - out_files.push_back(it->second); - } else { - throw std::invalid_argument("No input file found for transcription"); - } - - // handle input data - std::string prompt = json_value(inp_body, "prompt", std::string()); - std::string language = json_value(inp_body, "language", std::string()); - std::string response_format = json_value(inp_body, "response_format", std::string("json")); - if (response_format != "json") { - throw std::invalid_argument("Only 'json' response_format is supported for transcription"); - } - if (prompt.empty()) { - prompt = "Transcribe audio to text"; - } - if (!language.empty()) { - prompt += string_format(" (language: %s)", language.c_str()); - } - prompt += get_media_marker(); - - json chatcmpl_body = inp_body; // copy all fields - chatcmpl_body["messages"] = json::array({ - { - {"role", "user"}, - {"content", prompt}, - }, - }); - - // because input from form-data, everything is string, we need to correct the types here - std::string stream = json_value(inp_body, "stream", std::string("false")); - chatcmpl_body["stream"] = stream == "true"; - - if (inp_body.contains("max_tokens")) { - std::string inp = inp_body["max_tokens"].get(); - chatcmpl_body["max_tokens"] = std::stoul(inp); - } - - if (inp_body.contains("temperature")) { - std::string inp = inp_body["temperature"].get(); - chatcmpl_body["temperature"] = std::stof(inp); - } - - return chatcmpl_body; -} - -json convert_anthropic_to_oai(const json & body) { - json oai_body; - - // Convert system prompt - json oai_messages = json::array(); - auto system_param = json_value(body, "system", json()); - if (!system_param.is_null()) { - std::string system_content; - - if (system_param.is_string()) { - system_content = system_param.get(); - } else if (system_param.is_array()) { - for (const auto & block : system_param) { - if (json_value(block, "type", std::string()) == "text") { - system_content += json_value(block, "text", std::string()); - } - } - } - - oai_messages.push_back({ - {"role", "system"}, - {"content", system_content} - }); - } - - // Convert messages - if (!body.contains("messages")) { - throw std::runtime_error("'messages' is required"); - } - const json & messages = body.at("messages"); - if (messages.is_array()) { - for (const auto & msg : messages) { - std::string role = json_value(msg, "role", std::string()); - - if (!msg.contains("content")) { - if (role == "assistant") { - continue; - } - oai_messages.push_back(msg); - continue; - } - - const json & content = msg.at("content"); - - if (content.is_string()) { - oai_messages.push_back(msg); - continue; - } - - if (!content.is_array()) { - oai_messages.push_back(msg); - continue; - } - - json tool_calls = json::array(); - json converted_content = json::array(); - json tool_results = json::array(); - std::string reasoning_content; - bool has_tool_calls = false; - - for (const auto & block : content) { - std::string type = json_value(block, "type", std::string()); - - if (type == "text") { - converted_content.push_back(block); - } else if (type == "thinking") { - reasoning_content += json_value(block, "thinking", std::string()); - } else if (type == "image") { - json source = json_value(block, "source", json::object()); - std::string source_type = json_value(source, "type", std::string()); - - if (source_type == "base64") { - std::string media_type = json_value(source, "media_type", std::string("image/jpeg")); - std::string data = json_value(source, "data", std::string()); - std::ostringstream ss; - 
ss << "data:" << media_type << ";base64," << data; - - converted_content.push_back({ - {"type", "image_url"}, - {"image_url", { - {"url", ss.str()} - }} - }); - } else if (source_type == "url") { - std::string url = json_value(source, "url", std::string()); - converted_content.push_back({ - {"type", "image_url"}, - {"image_url", { - {"url", url} - }} - }); - } - } else if (type == "tool_use") { - tool_calls.push_back({ - {"id", json_value(block, "id", std::string())}, - {"type", "function"}, - {"function", { - {"name", json_value(block, "name", std::string())}, - {"arguments", json_value(block, "input", json::object()).dump()} - }} - }); - has_tool_calls = true; - } else if (type == "tool_result") { - std::string tool_use_id = json_value(block, "tool_use_id", std::string()); - - auto result_content = json_value(block, "content", json()); - std::string result_text; - if (result_content.is_string()) { - result_text = result_content.get(); - } else if (result_content.is_array()) { - for (const auto & c : result_content) { - if (json_value(c, "type", std::string()) == "text") { - result_text += json_value(c, "text", std::string()); - } - } - } - - tool_results.push_back({ - {"role", "tool"}, - {"tool_call_id", tool_use_id}, - {"content", result_text} - }); - } - } - - if (!converted_content.empty() || has_tool_calls || !reasoning_content.empty()) { - json new_msg = {{"role", role}}; - if (!converted_content.empty()) { - new_msg["content"] = converted_content; - } else if (has_tool_calls || !reasoning_content.empty()) { - new_msg["content"] = ""; - } - if (!tool_calls.empty()) { - new_msg["tool_calls"] = tool_calls; - } - if (!reasoning_content.empty()) { - new_msg["reasoning_content"] = reasoning_content; - } - oai_messages.push_back(new_msg); - } - - for (const auto & tool_msg : tool_results) { - oai_messages.push_back(tool_msg); - } - } - } - - oai_body["messages"] = oai_messages; - - // Convert tools - if (body.contains("tools")) { - const json & tools = body.at("tools"); - if (tools.is_array()) { - json oai_tools = json::array(); - for (const auto & tool : tools) { - oai_tools.push_back({ - {"type", "function"}, - {"function", { - {"name", json_value(tool, "name", std::string())}, - {"description", json_value(tool, "description", std::string())}, - {"parameters", tool.contains("input_schema") ? 
tool.at("input_schema") : json::object()} - }} - }); - } - oai_body["tools"] = oai_tools; - } - } - - // Convert tool_choice - if (body.contains("tool_choice")) { - const json & tc = body.at("tool_choice"); - if (tc.is_object()) { - std::string type = json_value(tc, "type", std::string()); - if (type == "auto") { - oai_body["tool_choice"] = "auto"; - } else if (type == "any" || type == "tool") { - oai_body["tool_choice"] = "required"; - } - } - } - - // Convert stop_sequences to stop - if (body.contains("stop_sequences")) { - oai_body["stop"] = body.at("stop_sequences"); - } - - // Handle max_tokens (required in Anthropic, but we're permissive) - if (body.contains("max_tokens")) { - oai_body["max_tokens"] = body.at("max_tokens"); - } else { - oai_body["max_tokens"] = 4096; - } - - // Pass through common params - for (const auto & key : {"temperature", "top_p", "top_k", "stream"}) { - if (body.contains(key)) { - oai_body[key] = body.at(key); - } - } - - // Handle Anthropic-specific thinking param - if (body.contains("thinking")) { - json thinking = json_value(body, "thinking", json::object()); - std::string thinking_type = json_value(thinking, "type", std::string()); - if (thinking_type == "enabled") { - int budget_tokens = json_value(thinking, "budget_tokens", 10000); - oai_body["thinking_budget_tokens"] = budget_tokens; - } - } - - // Handle Anthropic-specific metadata param - if (body.contains("metadata")) { - json metadata = json_value(body, "metadata", json::object()); - std::string user_id = json_value(metadata, "user_id", std::string()); - if (!user_id.empty()) { - oai_body["__metadata_user_id"] = user_id; - } - } - - return oai_body; -} - json format_embeddings_response_oaicompat( const json & request, const std::string & model_name, diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 093a43453..4681f9c51 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -307,18 +307,6 @@ json oaicompat_chat_params_parse( const server_chat_params & opt, std::vector & out_files); -// convert OpenAI Responses API format to OpenAI Chat Completions API format -json convert_responses_to_chatcmpl(const json & body); - -// convert OpenAI transcriptions API format to OpenAI Chat Completions API format -json convert_transcriptions_to_chatcmpl( - const json & body, - const std::map & in_files, - std::vector & out_files); - -// convert Anthropic Messages API format to OpenAI Chat Completions API format -json convert_anthropic_to_oai(const json & body); - // TODO: move it to server-task.cpp json format_embeddings_response_oaicompat( const json & request, diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index a5372572f..a55c6356f 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1,5 +1,6 @@ #include "server-context.h" +#include "server-chat.h" #include "server-common.h" #include "server-http.h" #include "server-task.h" @@ -3774,7 +3775,7 @@ void server_routes::init_routes() { this->post_responses_oai = [this](const server_http_req & req) { auto res = create_response(); std::vector files; - json body = convert_responses_to_chatcmpl(json::parse(req.body)); + json body = server_chat_convert_responses_to_chatcmpl(json::parse(req.body)); SRV_DBG("%s\n", "Request converted: OpenAI Responses -> OpenAI Chat Completions"); SRV_DBG("converted request: %s\n", body.dump().c_str()); json body_parsed = oaicompat_chat_params_parse( @@ -3819,7 +3820,7 @@ void server_routes::init_routes() { 
this->post_anthropic_messages = [this](const server_http_req & req) { auto res = create_response(); std::vector files; - json body = convert_anthropic_to_oai(json::parse(req.body)); + json body = server_chat_convert_anthropic_to_oai(json::parse(req.body)); SRV_DBG("%s\n", "Request converted: Anthropic -> OpenAI Chat Completions"); SRV_DBG("converted request: %s\n", body.dump().c_str()); json body_parsed = oaicompat_chat_params_parse( @@ -3837,7 +3838,7 @@ void server_routes::init_routes() { this->post_anthropic_count_tokens = [this](const server_http_req & req) { auto res = create_response(); std::vector files; - json body = convert_anthropic_to_oai(json::parse(req.body)); + json body = server_chat_convert_anthropic_to_oai(json::parse(req.body)); SRV_DBG("%s\n", "Request converted: Anthropic -> OpenAI Chat Completions"); SRV_DBG("converted request: %s\n", body.dump().c_str()); json body_parsed = oaicompat_chat_params_parse( diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 2187b8d21..ac1e77615 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -1,6 +1,7 @@ #include "server-task.h" #include "build-info.h" +#include "server-chat.h" #include "chat.h" #include "common.h" #include "json-schema-to-grammar.h" @@ -873,7 +874,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() { json { {"finish_reason", nullptr}, {"index", index}, - {"delta", common_chat_msg_diff_to_json_oaicompat(diff)}, + {"delta", server_chat_msg_diff_to_json_oaicompat(diff)}, }, })}, {"created", t}, @@ -1522,7 +1523,7 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat() { } for (const auto & diff : oaicompat_msg_diffs) { - add_delta(common_chat_msg_diff_to_json_oaicompat(diff)); + add_delta(server_chat_msg_diff_to_json_oaicompat(diff)); } if (!deltas.empty()) { From 750579ff14198fe964ab7fc5565b1d77600deab4 Mon Sep 17 00:00:00 2001 From: Ethan Turner Date: Wed, 22 Apr 2026 01:40:19 -0700 Subject: [PATCH 14/20] common: Refactoring sampler parameters (#20429) (#22233) This change refactors the reasoning_budget_message parameter from the common params into the sampling parameters specifically. It also removes the reasoning_budget common parameter and standardizes on the existing reasoning_budget_tokens parameter in the sampling configuration. 
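For orientation, a minimal sketch of the resulting layout (field declarations are assumptions pieced together from the diff below; only the members touched by this change are shown, and the -1 default mirrors the existing help text in common/arg.cpp):

    struct common_params_sampling {
        // ... existing sampling fields ...
        int         reasoning_budget_tokens = -1;   // -1 = unrestricted, 0 = immediate end, N>0 = token budget
        std::string reasoning_budget_message;       // injected before the end-of-thinking tag when the budget is exhausted
    };

    struct common_params {
        common_params_sampling sampling;
        // reasoning_budget / reasoning_budget_message no longer live here
    };

    // Callers now read both values through the sampling sub-struct, e.g.
    //   budget  = params.sampling.reasoning_budget_tokens;
    //   message = params.sampling.reasoning_budget_message;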
Issue: https://github.com/ggml-org/llama.cpp/issues/20429 Original PR: https://github.com/ggml-org/llama.cpp/pull/20297 --- common/arg.cpp | 4 ++-- common/common.h | 3 +-- tools/cli/cli.cpp | 4 ++-- tools/server/server-context.cpp | 4 ++-- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 5c0bbe38e..85d84e5cc 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3122,14 +3122,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "token budget for thinking: -1 for unrestricted, 0 for immediate end, N>0 for token budget (default: -1)", [](common_params & params, int value) { if (value < -1) { throw std::invalid_argument("invalid value"); } - params.reasoning_budget = value; + params.sampling.reasoning_budget_tokens = value; } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET")); add_opt(common_arg( {"--reasoning-budget-message"}, "MESSAGE", "message injected before the end-of-thinking tag when reasoning budget is exhausted (default: none)", [](common_params & params, const std::string & value) { - params.reasoning_budget_message = value; + params.sampling.reasoning_budget_message = value; } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE")); add_opt(common_arg( diff --git a/common/common.h b/common/common.h index 4137a87f1..9a8218433 100644 --- a/common/common.h +++ b/common/common.h @@ -274,6 +274,7 @@ struct common_params_sampling { std::vector reasoning_budget_start; // start tag token sequence std::vector reasoning_budget_end; // end tag token sequence std::vector reasoning_budget_forced; // forced sequence (message + end tag) + std::string reasoning_budget_message; // message injected before end tag when budget exhausted bool backend_sampling = false; @@ -581,8 +582,6 @@ struct common_params { bool force_pure_content_parser = false; common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; int enable_reasoning = -1; // -1 = auto, 0 = disable, 1 = enable - int reasoning_budget = -1; - std::string reasoning_budget_message; // message injected before end tag when budget exhausted bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time diff --git a/tools/cli/cli.cpp b/tools/cli/cli.cpp index 6ee00012d..5bdb1e78f 100644 --- a/tools/cli/cli.cpp +++ b/tools/cli/cli.cpp @@ -77,8 +77,8 @@ struct cli_context { // defaults.return_progress = true; // TODO: show progress verbose_prompt = params.verbose_prompt; - reasoning_budget = params.reasoning_budget; - reasoning_budget_message = params.reasoning_budget_message; + reasoning_budget = params.sampling.reasoning_budget_tokens; + reasoning_budget_message = params.sampling.reasoning_budget_message; } std::string generate_completion(result_timings & out_timings) { diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index a55c6356f..53f61b5a9 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1045,8 +1045,8 @@ private: /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, /* allow_audio */ mctx ? 
mtmd_support_audio (mctx) : false, /* enable_thinking */ enable_thinking, - /* reasoning_budget */ params_base.reasoning_budget, - /* reasoning_budget_msg */ params_base.reasoning_budget_message, + /* reasoning_budget */ params_base.sampling.reasoning_budget_tokens, + /* reasoning_budget_msg */ params_base.sampling.reasoning_budget_message, /* media_path */ params_base.media_path, /* force_pure_content */ params_base.force_pure_content_parser }; From 7bfe60fdf929ae569b81bbbce7ff7be5a1f8e354 Mon Sep 17 00:00:00 2001 From: manayang Date: Wed, 22 Apr 2026 17:58:43 +0800 Subject: [PATCH 15/20] mtmd, llama : Update HunyuanVL vision-language model support (#22037) * mtmd, llama : add HunyuanVL vision-language model support - add LLM_ARCH_HUNYUAN_VL with M-RoPE (XD-RoPE) support - add PROJECTOR_TYPE_HUNYUANVL with PatchMerger vision encoder - add HunyuanVL-specific M-RoPE position encoding for image tokens - add GGUF conversion for HunyuanVL vision and text models - add smoke test in tools/mtmd/tests.sh * fix: fix HunyuanVL XD-RoPE h/w section order * fix: Remove redundant code * convert : fix HunyuanOCR / HunyuanVL conversion - Tested locally: both HunyuanOCR and HunyuanVL-4B convert to GGUF - successfully and produce correct inference output on Metal (F16 / Q8_0). * clip : fix -Werror=misleading-indentation in bilinear resize * fix CI: convert_hf_to_gguf type check error - convert_hf_to_gguf.py: give HunyuanVLTextModel.__init__ an explicit `dir_model: Path` parameter so ty can infer the type for load_hparams instead of reporting `Unknown | None`. --------- Co-authored-by: wendadawen --- convert_hf_to_gguf.py | 107 +++++++++++++++++++++++++++---- gguf-py/gguf/constants.py | 20 ++++++ gguf-py/gguf/gguf_writer.py | 3 + src/llama-arch.cpp | 2 + src/llama-arch.h | 2 + src/llama-hparams.h | 1 + src/llama-model.cpp | 22 +++++++ src/models/hunyuan-dense.cpp | 41 ++++++++---- tools/mtmd/clip-impl.h | 4 +- tools/mtmd/clip.cpp | 80 +++++++++++++++++++++++ tools/mtmd/models/hunyuanocr.cpp | 16 ++++- tools/mtmd/mtmd.cpp | 64 +++++++++++++++++- tools/mtmd/tests.sh | 1 + 13 files changed, 336 insertions(+), 27 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 5b4fb79fc..090686b15 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -11855,7 +11855,7 @@ class LLaDAMoEModel(TextModel): raise ValueError(f"Unprocessed experts: {experts}") -@ModelBase.register("HunYuanDenseV1ForCausalLM", "HunYuanVLForConditionalGeneration") +@ModelBase.register("HunYuanDenseV1ForCausalLM") class HunYuanModel(TextModel): model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE @@ -11994,28 +11994,58 @@ class HunYuanModel(TextModel): @ModelBase.register("HunYuanVLForConditionalGeneration") -class HunyuanOCRVisionModel(MmprojModel): +class HunyuanVLVisionModel(MmprojModel): + # Handles both HunyuanOCR and HunyuanVL, which share the HF architecture name + # "HunYuanVLForConditionalGeneration" and the `vit.perceive.*` vision layout. + # Each variant maps to a different projector type in clip.cpp so image + # preprocessing follows the correct code path. 
+ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) assert self.hparams_vision is not None - # HunyuanOCR uses max_image_size instead of image_size + # HunyuanOCR / HunyuanVL uses max_image_size instead of image_size if "image_size" not in self.hparams_vision: self.hparams_vision["image_size"] = self.hparams_vision.get("max_image_size", 2048) + @staticmethod + def is_ocr_variant(hparams: dict) -> bool: + """Return True for HunyuanOCR, False for HunyuanVL. + + The projector's output dim must equal the text model's hidden_size by + construction (that's what "projector" means). HunyuanOCR pairs a 1B text + backbone (hidden=1024); HunyuanVL pairs a 4B one (hidden=3072). So the + ViT -> LLM projection dim is a hard architectural signature, not a + magic number. + """ + vision_out = int((hparams.get("vision_config") or {}).get("out_hidden_size", 0)) + return vision_out == 1024 + def set_gguf_parameters(self): super().set_gguf_parameters() assert self.hparams_vision is not None - hparams = self.hparams_vision - self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANOCR) - self.gguf_writer.add_vision_use_gelu(True) - self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("rms_norm_eps", 1e-5)) - self.gguf_writer.add_vision_spatial_merge_size(hparams.get("spatial_merge_size", 2)) - self.gguf_writer.add_vision_min_pixels(self.preprocessor_config["min_pixels"]) - self.gguf_writer.add_vision_max_pixels(self.preprocessor_config["max_pixels"]) + vcfg = self.hparams_vision + + if self.is_ocr_variant(self.global_config): + # --- HunyuanOCR --- + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANOCR) + self.gguf_writer.add_vision_use_gelu(True) + self.gguf_writer.add_vision_attention_layernorm_eps(vcfg.get("rms_norm_eps", 1e-5)) + self.gguf_writer.add_vision_spatial_merge_size(vcfg.get("spatial_merge_size", 2)) + self.gguf_writer.add_vision_min_pixels(self.preprocessor_config["min_pixels"]) + self.gguf_writer.add_vision_max_pixels(self.preprocessor_config["max_pixels"]) + return + + # --- HunyuanVL --- + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANVL) + self.gguf_writer.add_vision_use_gelu(str(vcfg["hidden_act"]).lower() == "gelu") + self.gguf_writer.add_vision_attention_layernorm_eps(float(vcfg["rms_norm_eps"])) + self.gguf_writer.add_vision_spatial_merge_size(int(vcfg["spatial_merge_size"])) + self.gguf_writer.add_vision_min_pixels(int(self.preprocessor_config["min_pixels"])) + self.gguf_writer.add_vision_max_pixels(int(self.preprocessor_config["max_pixels"])) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if not name.startswith("vit."): - return # skip text tensors + return # strip CLS token (row 0) from position embeddings so resize_position_embeddings works if "position_embedding" in name: data_torch = data_torch[1:] # [n_patches+1, n_embd] -> [n_patches, n_embd] @@ -12023,11 +12053,66 @@ class HunyuanOCRVisionModel(MmprojModel): def tensor_force_quant(self, name, new_name, bid, n_dims): # force conv weights to F32 or F16 to avoid BF16 IM2COL issues on Metal + # Both HunyuanOCR and HunyuanVL emit the ViT -> LLM projection as mm.0/mm.2. if ("mm.0." in new_name or "mm.2." 
in new_name) and new_name.endswith(".weight"): return gguf.GGMLQuantizationType.F16 if self.ftype == gguf.LlamaFileType.MOSTLY_F16 else gguf.GGMLQuantizationType.F32 return super().tensor_force_quant(name, new_name, bid, n_dims) +@ModelBase.register("HunYuanVLForConditionalGeneration") +class HunyuanVLTextModel(HunYuanModel): + # The "HunYuanVLForConditionalGeneration" HF architecture covers both HunyuanOCR + # and HunyuanVL. HunyuanOCR reuses the HunYuan-Dense text backbone (standard RoPE), + # while HunyuanVL introduces a new LLM arch with XD-RoPE. Detect the variant from + # the config and pick the matching GGUF architecture. + model_arch = gguf.MODEL_ARCH.HUNYUAN_VL + + @staticmethod + def _is_ocr_config(hparams: dict) -> bool: + # OCR pairs a 1B text backbone (hidden=1024) with a ViT projector that + # outputs 1024-d; HunyuanVL uses 3072-d. Keep in sync with + # HunyuanVLVisionModel.is_ocr_variant. + return int((hparams.get("vision_config") or {}).get("out_hidden_size", 0)) == 1024 + + def __init__(self, dir_model: Path, *args, **kwargs): + raw_hparams = kwargs.get("hparams") or ModelBase.load_hparams(dir_model, is_mistral_format=False) + if self._is_ocr_config(raw_hparams): + self.model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE + else: + self.model_arch = gguf.MODEL_ARCH.HUNYUAN_VL + super().__init__(dir_model, *args, **kwargs) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + # Only emit XD-RoPE metadata for the HunyuanVL backbone; HunyuanOCR uses + # the HunYuan-Dense arch which already handles standard rope in super(). + if self.model_arch != gguf.MODEL_ARCH.HUNYUAN_VL: + return + + if self.rope_parameters.get("rope_type") != "xdrope": + return + + # defaults for HunyuanVL. The C++ side later computes: + # freq_base = rope_theta * alpha ** (head_dim / (head_dim - 2)) + self.gguf_writer.add_rope_freq_base(float(self.rope_parameters["rope_theta"])) + self.gguf_writer.add_rope_scaling_alpha(float(self.rope_parameters["alpha"])) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_rope_scaling_factor(float(self.rope_parameters.get("factor", 1))) + + ctx_len = int(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(ctx_len) + self.gguf_writer.add_context_length(ctx_len) + + self.gguf_writer.add_rope_dimension_sections(list(self.rope_parameters["xdrope_section"])) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Skip vision tensors — they are written by HunyuanVLVisionModel + if name.startswith("vit."): + return + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("SmolLM3ForCausalLM") class SmolLM3Model(LlamaModel): model_arch = gguf.MODEL_ARCH.SMOLLM3 diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c5297a2f4..83ae51ce9 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -197,6 +197,7 @@ class Keys: FREQ_BASE_SWA = "{arch}.rope.freq_base_swa" SCALING_TYPE = "{arch}.rope.scaling.type" SCALING_FACTOR = "{arch}.rope.scaling.factor" + SCALING_ALPHA = "{arch}.rope.scaling.alpha" SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" @@ -471,6 +472,7 @@ class MODEL_ARCH(IntEnum): ERNIE4_5_MOE = auto() HUNYUAN_MOE = auto() HUNYUAN_DENSE = auto() + HUNYUAN_VL = auto() SMOLLM3 = auto() GPT_OSS = auto() LFM2 = auto() @@ -957,6 +959,7 @@ 
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.FALCON_H1: "falcon-h1", MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe", MODEL_ARCH.HUNYUAN_DENSE: "hunyuan-dense", + MODEL_ARCH.HUNYUAN_VL: "hunyuan_vl", MODEL_ARCH.SMOLLM3: "smollm3", MODEL_ARCH.GPT_OSS: "gpt-oss", MODEL_ARCH.LFM2: "lfm2", @@ -3489,6 +3492,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.HUNYUAN_VL: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.SMOLLM3: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -4138,6 +4157,7 @@ class VisionProjectorType: YOUTUVL = "youtuvl" NEMOTRON_V2_VL = "nemotron_v2_vl" HUNYUANOCR = "hunyuanocr" + HUNYUANVL = "hunyuanvl" # Items here are (block size, type size) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 90d500dc7..6a81ca37d 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -973,6 +973,9 @@ class GGUFWriter: def add_rope_scaling_factor(self, value: float) -> None: self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value) + def add_rope_scaling_alpha(self, value: float) -> None: + self.add_float32(Keys.Rope.SCALING_ALPHA.format(arch=self.arch), value) + def add_rope_scaling_attn_factors(self, value: float) -> None: self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 6904b9c1a..633a66fc6 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -109,6 +109,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" }, { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, { LLM_ARCH_HUNYUAN_DENSE, "hunyuan-dense" }, + { LLM_ARCH_HUNYUAN_VL, "hunyuan_vl" }, { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_OPENAI_MOE, "gpt-oss" }, { LLM_ARCH_LFM2, "lfm2" }, @@ -250,6 +251,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ALPHA, "%s.rope.scaling.alpha" }, { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index c4aabab7e..8f335f5c7 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -113,6 +113,7 @@ enum llm_arch { LLM_ARCH_ERNIE4_5_MOE, LLM_ARCH_HUNYUAN_MOE, LLM_ARCH_HUNYUAN_DENSE, + LLM_ARCH_HUNYUAN_VL, LLM_ARCH_SMOLLM3, LLM_ARCH_OPENAI_MOE, LLM_ARCH_LFM2, @@ -254,6 +255,7 @@ enum llm_kv { LLM_KV_ROPE_SCALE_LINEAR, LLM_KV_ROPE_SCALING_TYPE, LLM_KV_ROPE_SCALING_FACTOR, + LLM_KV_ROPE_SCALING_ALPHA, LLM_KV_ROPE_SCALING_ATTN_FACTOR, LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, LLM_KV_ROPE_SCALING_FINETUNED, diff --git a/src/llama-hparams.h b/src/llama-hparams.h index c2000c77c..ac7f9ee86 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -116,6 +116,7 @@ struct llama_hparams { float rope_freq_base_train_swa = 10000.0f; float rope_freq_scale_train; float rope_freq_scale_train_swa = 1.0f; + float rope_scaling_alpha = 0.0f; // NTK-aware alpha for XDRoPE 
uint32_t n_ctx_orig_yarn; float rope_yarn_log_mul = 0.0f; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f77b2e921..9e2a13cbd 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -737,6 +737,13 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups, false); ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false); + if (arch == LLM_ARCH_HUNYUAN_VL || arch == LLM_ARCH_HUNYUAN_DENSE) { + if (hparams.n_expert <= 1) { + hparams.n_expert = 0; + hparams.n_expert_used = 0; + } + } + if (arch == LLM_ARCH_WAVTOKENIZER_DEC) { ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd_out_impl); @@ -815,6 +822,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); + ml.get_key(LLM_KV_ROPE_SCALING_ALPHA, hparams.rope_scaling_alpha, false); // non-transformer models do not have attention heads if (hparams.n_head() > 0) { @@ -2592,9 +2600,18 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_HUNYUAN_VL: case LLM_ARCH_HUNYUAN_DENSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // XDRoPE / NTK-aware scaling: base = rope_theta * alpha^(dim / (dim - 2)) + if (hparams.rope_scaling_alpha > 0.0f) { + const int dim = hparams.n_embd_head_k(); + hparams.rope_freq_base_train = hparams.rope_freq_base_train + * powf(hparams.rope_scaling_alpha, (float)dim / (float)(dim - 2)); + } switch (hparams.n_embd) { case 1024: type = LLM_TYPE_0_5B; break; @@ -6947,6 +6964,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); } } break; + case LLM_ARCH_HUNYUAN_VL: case LLM_ARCH_HUNYUAN_DENSE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -8967,6 +8985,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_HUNYUAN_VL: case LLM_ARCH_HUNYUAN_DENSE: { llm = std::make_unique(*this, params); @@ -9316,6 +9335,9 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GLM4_MOE: return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX; + case LLM_ARCH_HUNYUAN_VL: + return model->hparams.use_mrope() ? 
LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX; + // all model arches should be listed explicitly here case LLM_ARCH_UNKNOWN: GGML_ABORT("unknown architecture"); diff --git a/src/models/hunyuan-dense.cpp b/src/models/hunyuan-dense.cpp index e4e837eb4..1cd85d6d9 100644 --- a/src/models/hunyuan-dense.cpp +++ b/src/models/hunyuan-dense.cpp @@ -6,6 +6,11 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); GGML_ASSERT(n_embd_head == n_rot); + const bool use_mrope = hparams.use_mrope(); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + ggml_tensor * cur; ggml_tensor * inpL; @@ -37,22 +42,36 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, n_head, n_head_kv, il); - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + if (use_mrope) { + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } else { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 61fe82439..7d6484eea 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -150,7 +150,7 @@ #define TN_TOK_BOI "v.boi" #define TN_TOK_EOI "v.eoi" -// hunyuanocr +// hunyuanocr / hunyuanvl (shared GGUF tensor names) #define TN_MM_PRE_NORM "mm.pre_norm.%s" #define TN_TOK_IMG_BEGIN "mm.image_begin" #define TN_TOK_IMG_END "mm.image_end" @@ -303,6 +303,7 @@ enum projector_type { PROJECTOR_TYPE_KIMIK25, PROJECTOR_TYPE_NEMOTRON_V2_VL, PROJECTOR_TYPE_HUNYUANOCR, + PROJECTOR_TYPE_HUNYUANVL, PROJECTOR_TYPE_UNKNOWN, }; @@ -349,6 +350,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_KIMIK25, "kimik25"}, { PROJECTOR_TYPE_NEMOTRON_V2_VL, "nemotron_v2_vl"}, { PROJECTOR_TYPE_HUNYUANOCR, "hunyuanocr"}, + { PROJECTOR_TYPE_HUNYUANVL, "hunyuanvl"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 540b0ea41..45e39898d 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -912,6 +912,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 builder = std::make_unique(ctx, img); } break; case PROJECTOR_TYPE_HUNYUANOCR: + case PROJECTOR_TYPE_HUNYUANVL: { builder = std::make_unique(ctx, img); } break; @@ -1473,6 +1474,16 @@ struct clip_model_loader { get_u32(KEY_IMAGE_MAX_PIXELS, 
hparams.image_max_pixels); hparams.set_warmup_n_tokens(28*28); } break; + case PROJECTOR_TYPE_HUNYUANVL: + { + hparams.n_merge = 2; + hparams.image_resize_algo = RESIZE_ALGO_BICUBIC_PILLOW; + hparams.image_resize_pad = false; + hparams.ffn_op = FFN_GELU; + get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false); + hparams.set_limit_image_tokens(256, 16384); + hparams.set_warmup_n_tokens(32*32); + } break; case PROJECTOR_TYPE_LFM2A: { // audio preprocessing params @@ -2222,6 +2233,7 @@ struct clip_model_loader { model.mm_eoi = get_tensor(TN_TOK_EOI); } break; case PROJECTOR_TYPE_HUNYUANOCR: + case PROJECTOR_TYPE_HUNYUANVL: { // proj.0 -> mm.0 (conv1), proj.2 -> mm.2 (conv2), mlp -> mm.model.fc (linear) model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); @@ -2860,6 +2872,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_PADDLEOCR: case PROJECTOR_TYPE_HUNYUANOCR: + case PROJECTOR_TYPE_HUNYUANVL: case PROJECTOR_TYPE_YOUTUVL: return (img->nx / params.patch_size) / 2; case PROJECTOR_TYPE_STEP3VL: @@ -2879,6 +2892,7 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_PADDLEOCR: + case PROJECTOR_TYPE_HUNYUANVL: case PROJECTOR_TYPE_YOUTUVL: return (img->ny / params.patch_size) / 2; case PROJECTOR_TYPE_STEP3VL: @@ -3070,6 +3084,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im n_patches = h * (h + 1) + 1; } break; case PROJECTOR_TYPE_HUNYUANOCR: + case PROJECTOR_TYPE_HUNYUANVL: { int merge = ctx->model.hparams.n_merge; int ow = (img->nx / patch_size) / merge; @@ -3534,6 +3549,70 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima { // do nothing } break; + case PROJECTOR_TYPE_HUNYUANVL: + { + // Compute the HunyuanVL 2D position embedding on CPU (with the + // custom sf=(target+0.1)/n_grid bilinear sampling that the + // reference implementation uses) and upload it to the graph + // input declared in clip_graph_hunyuanocr::build(). + GGML_ASSERT(model.position_embeddings != nullptr); + ggml_tensor * src_t = model.position_embeddings; + const int64_t n_embd = src_t->ne[0]; + const int64_t n_pos = src_t->ne[1]; // = n_grid * n_grid + const int n_grid = (int)std::lround(std::sqrt((double)n_pos)); + GGML_ASSERT((int64_t)n_grid * n_grid == n_pos); + const int out_w = pos_w; // pw + const int out_h = pos_h; // ph + + // Pull weight to host. + std::vector src(n_embd * n_pos); + ggml_backend_tensor_get(src_t, src.data(), 0, ggml_nbytes(src_t)); + + // Output layout matches ggml_new_tensor_2d(F32, n_embd, out_h*out_w): + // ne[0] = n_embd (fastest), ne[1] = out_h*out_w + // dst[(y*out_w + x) * n_embd + c] + std::vector dst((size_t)n_embd * out_h * out_w); + + const float sx = (float)(out_w + 0.1f) / (float)n_grid; + const float sy = (float)(out_h + 0.1f) / (float)n_grid; + + for (int y = 0; y < out_h; ++y) { + // Match ggml_compute_forward_upscale_f32 pixel-center + // convention (align_corners=False): src_y = (y+0.5)/sy - 0.5. 
+ const float fy = ((float)y + 0.5f) / sy - 0.5f; + int y0 = (int)std::floor(fy); + int y1 = y0 + 1; + y0 = std::clamp(y0, 0, n_grid - 1); + y1 = std::clamp(y1, 0, n_grid - 1); + float wy1 = std::clamp(fy - (float)y0, 0.0f, 1.0f); + const float wy0 = 1.0f - wy1; + for (int x = 0; x < out_w; ++x) { + const float fx = ((float)x + 0.5f) / sx - 0.5f; + int x0 = (int)std::floor(fx); + int x1 = x0 + 1; + x0 = std::clamp(x0, 0, n_grid - 1); + x1 = std::clamp(x1, 0, n_grid - 1); + float wx1 = std::clamp(fx - (float)x0, 0.0f, 1.0f); + const float wx0 = 1.0f - wx1; + + const float w00 = wy0 * wx0; + const float w01 = wy0 * wx1; + const float w10 = wy1 * wx0; + const float w11 = wy1 * wx1; + + const float * s00 = &src[((size_t)y0 * n_grid + x0) * n_embd]; + const float * s01 = &src[((size_t)y0 * n_grid + x1) * n_embd]; + const float * s10 = &src[((size_t)y1 * n_grid + x0) * n_embd]; + const float * s11 = &src[((size_t)y1 * n_grid + x1) * n_embd]; + float * d = &dst[((size_t)y * out_w + x) * n_embd]; + for (int c = 0; c < n_embd; ++c) { + d[c] = w00 * s00[c] + w01 * s01[c] + w10 * s10[c] + w11 * s11[c]; + } + } + } + + set_input_f32("hunyuanvl_pos_embd", dst); + } break; case PROJECTOR_TYPE_LLAMA4: { // set the 2D positions @@ -3760,6 +3839,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_YASA2: return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_HUNYUANOCR: + case PROJECTOR_TYPE_HUNYUANVL: return ctx->model.mm_model_proj->ne[1]; case PROJECTOR_TYPE_COGVLM: return ctx->model.mm_4h_to_h_w->ne[1]; diff --git a/tools/mtmd/models/hunyuanocr.cpp b/tools/mtmd/models/hunyuanocr.cpp index 37d1e2b86..45ed684f7 100644 --- a/tools/mtmd/models/hunyuanocr.cpp +++ b/tools/mtmd/models/hunyuanocr.cpp @@ -5,7 +5,21 @@ ggml_cgraph * clip_graph_hunyuanocr::build() { const int pw = n_patches_x; const int ph = n_patches_y; - ggml_tensor * pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BILINEAR); + // Position embedding interpolation. + // HunyuanVL needs scale factors sf=(target+0.1)/n_grid, which the standard + // ggml_interpolate cannot express. To avoid adding a new ggml op, the + // resize is computed on CPU in clip_image_batch_encode and uploaded here + // as a graph input (named "hunyuanvl_pos_embd"). + // HunyuanOCR uses the same square layout and the standard ratio-based + // interpolation provided by resize_position_embeddings(). 
+ ggml_tensor * pos_embd = nullptr; + if (proj_type == PROJECTOR_TYPE_HUNYUANVL && model.position_embeddings) { + pos_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ph * pw); + ggml_set_name(pos_embd, "hunyuanvl_pos_embd"); + ggml_set_input(pos_embd); + } else { + pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BILINEAR); + } ggml_tensor * inp = build_inp(); ggml_tensor * cur = build_vit(inp, n_patches, NORM_TYPE_NORMAL, hparams.ffn_op, pos_embd, nullptr); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index cc3de6a85..626361b92 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -35,15 +35,23 @@ struct mtmd_bitmap { // position indexing for decoder model enum mtmd_pos_type { - MTMD_POS_TYPE_NORMAL, // number of positions equals to number of tokens - MTMD_POS_TYPE_MROPE, // qwen-vl mrope style, each image takes max(t,h,w) position indexes + MTMD_POS_TYPE_NORMAL, // number of positions equals to number of tokens + MTMD_POS_TYPE_MROPE, // qwen-vl mrope style, each image takes max(t,h,w) position indexes + MTMD_POS_TYPE_HUNYUANVL, // HunyuanVL mrope + BOI/EOI/newline layout with XD-RoPE dim-3 }; struct mtmd_image_tokens { uint32_t nx; // number of tokens in x direction uint32_t ny; // number of tokens in y direction mtmd_pos_type pos = MTMD_POS_TYPE_NORMAL; - uint32_t n_tokens() const { return nx * ny; } + uint32_t image_idx = 0; // 0-based position of this image among image chunks in the prompt(used by pos == MTMD_POS_TYPE_HUNYUANVL) + uint32_t n_tokens() const { + if (pos == MTMD_POS_TYPE_HUNYUANVL) { + // [BOI] [row0 tokens + newline] ... [row(ny-1) tokens + newline] [EOI] + return (nx + 1) * ny + 2; + } + return nx * ny; + } clip_image_f32_batch batch_f32; // preprocessed image patches std::string id; // optional user-defined ID, useful for KV cache tracking @@ -52,6 +60,7 @@ struct mtmd_image_tokens { nx, ny, pos, + image_idx, batch_f32.clone(), id }; @@ -466,6 +475,7 @@ struct mtmd_context { image_preproc = std::make_unique(ctx_v); } break; case PROJECTOR_TYPE_HUNYUANOCR: + case PROJECTOR_TYPE_HUNYUANVL: { // note: these use fullwidth | (U+FF5C) and ▁ (U+2581) to match the tokenizer vocabulary img_beg = "<|hy_place▁holder▁no▁100|>"; @@ -611,6 +621,7 @@ struct mtmd_tokenizer { const llama_vocab * vocab; mtmd_input_chunks cur; + uint32_t n_images_added = 0; // 0-based index assigned to the next image chunk mtmd_tokenizer(mtmd_context * ctx, const mtmd_input_text * text, @@ -819,6 +830,14 @@ struct mtmd_tokenizer { image_tokens->ny = 1; } image_tokens->pos = ctx->pos_type; + // HunyuanVL wraps the image grid with BOI/EOI and adds one newline per row, + // and uses XD-RoPE (dim-3 = image index). Override the position type so that + // n_tokens() and mtmd_image_tokens_get_decoder_pos pick the HunyuanVL layout. 
+ if (ctx->proj_type_v() == PROJECTOR_TYPE_HUNYUANVL) { + image_tokens->pos = MTMD_POS_TYPE_HUNYUANVL; + image_tokens->image_idx = n_images_added; + GGML_ASSERT(n_tokens == (size_t)image_tokens->n_tokens()); + } image_tokens->batch_f32 = std::move(batch_f32); image_tokens->id = bitmap->id; // optional @@ -839,6 +858,9 @@ struct mtmd_tokenizer { add_text(ctx->img_end, true); // add image end token } + // advance image-chunk counter so the next image gets the next XD-RoPE dim-3 slot + n_images_added++; + } else { // handle audio @@ -1286,6 +1308,38 @@ mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * ima pos.y = pos_0 + i; pos.z = pos_0 + i; } break; + case MTMD_POS_TYPE_HUNYUANVL: + { + // HunyuanVL layout: [BOI] [row0 tokens + newline] ... [row(ny-1) tokens + newline] [EOI] + // Total = 1 + ny*(nx+1) + 1. BOI and EOI use sequential positions in every dim; + // content and row-newline tokens use (row, col) with XD-RoPE dim-3 = image_idx. + const uint32_t nx = image_tokens->nx; + const uint32_t n_total = image_tokens->n_tokens(); + if (i == 0) { + // BOI + pos.t = pos_0 + i; + pos.x = pos_0 + i; + pos.y = pos_0 + i; + pos.z = pos_0 + i; + } else if (i == n_total - 1) { + // EOI + pos.t = pos_0 + i; + pos.x = pos_0 + i; + pos.y = pos_0 + i; + pos.z = pos_0 + i; + } else { + // content token at (row, col), or the trailing newline of a row (col == nx) + // section 0 = sequential, section 1 = w(col), section 2 = h(row), section 3 = image_count. + // set_position_mrope_2d writes .y -> section 1 and .x -> section 2 + const uint32_t offset = (uint32_t)i - 1; + const uint32_t row = offset / (nx + 1); + const uint32_t col = offset % (nx + 1); + pos.t = pos_0 + i; + pos.x = row; + pos.y = col; + pos.z = image_tokens->image_idx; + } + } break; default: GGML_ABORT("invalid position type"); } @@ -1302,6 +1356,10 @@ llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) { return std::max(image_tokens->nx, image_tokens->ny); case MTMD_POS_TYPE_NORMAL: return image_tokens->n_tokens(); + case MTMD_POS_TYPE_HUNYUANVL: + // HunyuanVL: the sequential (dim-0) position advances by the full token count + // (includes BOI/EOI and row newline tokens), not by max(nx, ny) + return image_tokens->n_tokens(); default: GGML_ABORT("invalid position type"); } diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index 5da48d61b..83416fb27 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -91,6 +91,7 @@ add_test_vision "ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0" add_test_vision "ggml-org/DeepSeek-OCR-GGUF:Q8_0" -p "Free OCR." 
--chat-template deepseek-ocr add_test_vision "ggml-org/dots.ocr-GGUF:Q8_0" -p "OCR" add_test_vision "ggml-org/HunyuanOCR-GGUF:Q8_0" -p "OCR" +add_test_vision "ggml-org/HunyuanVL-4B-GGUF:Q8_0" add_test_vision "ggml-org/gemma-4-E2B-it-GGUF:Q8_0" --jinja add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0" From 17f624516858b1a95b59076b0367b1f26f37ecd5 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Wed, 22 Apr 2026 12:10:50 +0200 Subject: [PATCH 16/20] server: ignore reasoning content from transcription api (#21905) --- tools/server/server-task.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index ac1e77615..9380792c0 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -1111,7 +1111,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() { json server_task_result_cmpl_final::to_json_oaicompat_asr() { json event = json { {"type", "transcript.text.done"}, - {"text", content}, + {"text", oaicompat_msg.content}, {"usage", json { {"type", "tokens"}, {"input_tokens", n_prompt_tokens}, From 82d3f4d3b2fcff84d61f8a4f660f8aee71a4ea39 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Wed, 22 Apr 2026 12:16:29 +0200 Subject: [PATCH 17/20] mtmd: also support LLAMA_ROPE_TYPE_NONE (#22242) --- tools/mtmd/mtmd.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 626361b92..599077867 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -195,6 +195,7 @@ struct mtmd_context { auto decoder_rope_type = llama_model_rope_type(text_model); switch (decoder_rope_type) { + case LLAMA_ROPE_TYPE_NONE: case LLAMA_ROPE_TYPE_NORM: case LLAMA_ROPE_TYPE_NEOX: { From 225088ea76687c005e282d3a48d73e9c0c8c5091 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Wed, 22 Apr 2026 18:02:56 +0530 Subject: [PATCH 18/20] sycl: Improve mul_mat_id memory efficiency and add BF16 fast path (#22119) * sycl: size mul_mat_id staging buffers by routed rows Previously src1_contiguous/dst_contiguous in ggml_sycl_mul_mat_id were sized to ggml_nelements(src1/dst), which over-allocates when ne12 > 1 and can fail with UR_RESULT_ERROR_OUT_OF_HOST_MEMORY on Level Zero for MoE models (notably with --cpu-moe). Size them by the actual number of routed rows (ids->ne[1] * n_ids) instead. * sycl: add bf16 mul_mat fast path via DNNL When src0 is BF16 (commonly the case for lm_head / output.weight), the existing f16 path is skipped because bf16 isn't covered, and the f32 fallback dequantizes the entire src0 slab to f32 in a single pool alloc (row_diff*ne00 floats). For large-vocab models this can reach several GB and fail with UR_RESULT_ERROR_OUT_OF_HOST_MEMORY on Level Zero. Add a bf16xbf16 -> f32 DNNL matmul fast path that uses the bf16 storage in place and only materializes a small src1 bf16 conversion buffer. bf16 matmul accumulates in f32, so it's correct even when the op requests GGML_PREC_F32 (as lm_head does). - gemm.hpp: map bfloat16 to dnnl::memory::data_type::bf16. - convert.{hpp,cpp}: expose ggml_get_to_bf16_sycl for f32/f16/bf16 -> bf16. - ggml-sycl.cpp: take the bf16 path early in ggml_sycl_op_mul_mat_sycl when DNNL and GGML_SYCL_HAS_BF16 are both available. 
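To make the "several GB" figure concrete, a back-of-the-envelope example with
made-up but plausible lm_head shapes (these numbers are illustrative and not
taken from any specific model in the tree):

    n_vocab (row_diff) = 262144, n_embd (ne00) = 4096, batch (src1_ncols) = 512

    old f32 fallback : dequantize all of src0 to f32
                       262144 * 4096 * 4 bytes  = 4 GiB in a single pool alloc
    bf16 fast path   : convert only src1 to bf16, use src0 bf16 storage in place
                       512 * 4096 * 2 bytes     = 4 MiB staging buffer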
--- ggml/src/ggml-sycl/common.hpp | 7 +++++++ ggml/src/ggml-sycl/convert.cpp | 23 ++++++++++++++++------- ggml/src/ggml-sycl/convert.hpp | 9 +++++++++ ggml/src/ggml-sycl/gemm.hpp | 3 +++ ggml/src/ggml-sycl/ggml-sycl.cpp | 30 ++++++++++++++++++++++++++++-- ggml/src/ggml-sycl/set_rows.cpp | 8 +++++++- 6 files changed, 70 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index fd84c9178..0101b2764 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -28,6 +28,13 @@ namespace syclexp = sycl::ext::oneapi::experimental; +#if defined(__INTEL_LLVM_COMPILER) && __has_include() + #include + #ifndef GGML_SYCL_HAS_BF16 + #define GGML_SYCL_HAS_BF16 + #endif +#endif + #if GGML_SYCL_DNNL #include "dnnl.hpp" #include "dnnl_sycl.hpp" diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index f3c521b45..67b9c06f3 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -2,13 +2,6 @@ #include "dequantize.hpp" #include "presets.hpp" -#if defined(__INTEL_LLVM_COMPILER) - #if __has_include() - #include - #define GGML_SYCL_HAS_BF16 - #endif -#endif - template static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, const sycl::nd_item<3> &item_ct1) { @@ -767,6 +760,22 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { } +#ifdef GGML_SYCL_HAS_BF16 +to_bf16_sycl_t ggml_get_to_bf16_sycl(ggml_type type, ggml_tensor * /*dst*/) { + switch (type) { + case GGML_TYPE_F32: + return convert_unary_sycl; + case GGML_TYPE_F16: + return convert_unary_sycl; + case GGML_TYPE_BF16: + return convert_unary_sycl; + default: + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type)); + return nullptr; + } +} +#endif + to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type) { switch (type) { case GGML_TYPE_F32: diff --git a/ggml/src/ggml-sycl/convert.hpp b/ggml/src/ggml-sycl/convert.hpp index 6e621f215..8de79d10f 100644 --- a/ggml/src/ggml-sycl/convert.hpp +++ b/ggml/src/ggml-sycl/convert.hpp @@ -23,6 +23,11 @@ typedef to_t_sycl_t to_fp16_sycl_t; to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst); to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor * dst); +#ifdef GGML_SYCL_HAS_BF16 +typedef to_t_sycl_t to_bf16_sycl_t; +to_bf16_sycl_t ggml_get_to_bf16_sycl(ggml_type type, ggml_tensor * dst); +#endif + // Nc = Non-contiguous template using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, @@ -35,15 +40,19 @@ template inline dst_t ggml_sycl_cast(src_t x) { if constexpr (std::is_same_v) { return x; +#ifdef GGML_SYCL_HAS_BF16 } else if constexpr (std::is_same_v) { return sycl::ext::oneapi::bfloat16(float(x)); } else if constexpr (std::is_same_v) { return static_cast(x); +#endif } else if constexpr (std::is_same_v && std::is_same_v) { return x.template convert(); +#ifdef GGML_SYCL_HAS_BF16 } else if constexpr (std::is_same_v && std::is_same_v>) { return {x.x, x.y}; +#endif } else if constexpr(std::is_same_v) { return int32_t(x); } else { diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp index dcf6c7aee..c202da110 100644 --- a/ggml/src/ggml-sycl/gemm.hpp +++ b/ggml/src/ggml-sycl/gemm.hpp @@ -29,6 +29,9 @@ public: static constexpr dt to_dt() { if constexpr (std::is_same_v) return dt::f32; else if constexpr (std::is_same_v) return dt::f16; +#ifdef GGML_SYCL_HAS_BF16 + else if constexpr (std::is_same_v) return 
dt::bf16; +#endif else static_assert(0); } diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index c02a41ad8..3829da879 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2176,6 +2176,31 @@ inline void ggml_sycl_op_mul_mat_sycl( #else bool use_fp16 = false; #endif + +#if GGML_SYCL_DNNL && defined(GGML_SYCL_HAS_BF16) + // Fast path for bf16 src0 + if (src0->type == GGML_TYPE_BF16 && !g_ggml_sycl_disable_dnn && ggml_is_contiguous(src0) && + row_diff == src0->ne[1]) { + using bf16_t = sycl::ext::oneapi::bfloat16; + ggml_sycl_pool_alloc src1_as_bf16(ctx.pool(), src1_ncols*ne10); + if (src1->type != GGML_TYPE_BF16) { + const to_bf16_sycl_t to_bf16_sycl = ggml_get_to_bf16_sycl(src1->type, dst); + GGML_ASSERT(to_bf16_sycl != nullptr); + to_bf16_sycl(src1_ddf_i, src1_as_bf16.get(), src1_ncols*ne10, stream); + } else { + stream->memcpy(src1_as_bf16.get(), src1_ddf_i, src1_ncols*ne10*sizeof(bf16_t)); + } + DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, + src0_dd_i, DnnlGemmWrapper::to_dt(), + src1_as_bf16.get(), DnnlGemmWrapper::to_dt(), + dst_dd_i, DnnlGemmWrapper::to_dt(), stream); + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddq_i); + GGML_UNUSED(src1_padded_row_size); + return; + } +#endif + if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { ggml_sycl_pool_alloc src0_as_f16(ctx.pool()); @@ -3848,8 +3873,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, } } } else { - ggml_sycl_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); - ggml_sycl_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + const int64_t n_routed_rows = ids->ne[1] * n_ids; + ggml_sycl_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*n_routed_rows*ne10); + ggml_sycl_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*n_routed_rows*ne0); src1_row.data = src1_contiguous.get(); dst_row.data = dst_contiguous.get(); diff --git a/ggml/src/ggml-sycl/set_rows.cpp b/ggml/src/ggml-sycl/set_rows.cpp index a641c1009..8fb419435 100644 --- a/ggml/src/ggml-sycl/set_rows.cpp +++ b/ggml/src/ggml-sycl/set_rows.cpp @@ -4,7 +4,11 @@ namespace utils { template static constexpr bool is_arithmetic_v() { - return std::is_arithmetic_v || std::is_same_v || std::is_same_v; + return std::is_arithmetic_v || std::is_same_v +#ifdef GGML_SYCL_HAS_BF16 + || std::is_same_v +#endif + ; } } @@ -181,6 +185,7 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s stream ); break; +#ifdef GGML_SYCL_HAS_BF16 case GGML_TYPE_BF16: set_rows_sycl( src0_d, src1_d, (char *)dst->data, @@ -193,6 +198,7 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s stream ); break; +#endif case GGML_TYPE_Q8_0: set_rows_sycl_q(src0_d, src1_d, (block_q8_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream); break; From bcb5eeb64529806b8d1cb80eccbc22c7d0897cb2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 22 Apr 2026 15:44:45 +0300 Subject: [PATCH 19/20] speculative-simple : add checkpoint support (#22227) * speculative-simple : add checkpoint support * cont : fix build --- .../speculative-simple/speculative-simple.cpp | 112 ++++++++++++++++-- tools/server/server-context.cpp | 10 +- 2 files changed, 112 insertions(+), 10 deletions(-) diff --git 
a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index a03dbce88..73394b74e 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -8,8 +8,24 @@ #include #include #include +#include #include #include +#include + +struct spec_checkpoint { + int64_t n_tokens = 0; + + std::vector data; + + size_t size() const { + return data.size(); + } + + bool empty() const { + return data.empty(); + } +}; int main(int argc, char ** argv) { std::setlocale(LC_NUMERIC, "C"); @@ -46,6 +62,14 @@ int main(int argc, char ** argv) { model_tgt = llama_init_tgt->model(); ctx_tgt = llama_init_tgt->context(); + // check if the context supports partial sequence removal + const auto ctx_seq_rm = common_context_can_seq_rm(ctx_tgt); + const bool use_ckpt = (ctx_seq_rm == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); + + if (use_ckpt) { + LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n"); + } + const llama_vocab * vocab = llama_model_get_vocab(model_tgt); // load the draft model @@ -119,7 +143,7 @@ int main(int argc, char ** argv) { const auto t_enc_start = ggml_time_us(); // target model sampling context - struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling); + common_sampler_ptr smpl(common_sampler_init(model_tgt, params.sampling)); // eval the prompt llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); @@ -142,21 +166,61 @@ int main(int argc, char ** argv) { llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); + size_t n_draft = 0; + + llama_tokens draft; + spec_checkpoint spec_ckpt; + const auto t_enc_end = ggml_time_us(); const auto t_dec_start = ggml_time_us(); while (true) { - // optionally, generate draft tokens that can be appended to the target batch + // generate or reuse draft tokens // // this is the most important part of the speculation. the more probable tokens that are provided here // the better the performance will be. in theory, this computation can be performed asynchronously and even // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens // from a cache or lookup tables. 
// - llama_tokens draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last); + if (draft.empty()) { + // generate a new draft + draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last); - //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str()); + if ((int) draft.size() > params_spec.n_max) { + LOG_WRN("draft size %zu exceeds max %d, truncating\n", draft.size(), params_spec.n_max); + draft.resize(params_spec.n_max); + } + + if ((int) draft.size() < params_spec.n_min) { + LOG_DBG("ignoring small draft: %zu < %d\n", draft.size(), params_spec.n_min); + draft.clear(); + } + + // save the original draft size + n_draft = draft.size(); + + // save a checkpoint of the target context before evaluating the draft + // this allows us to restore the state if partial draft acceptance occurs + if (!draft.empty() && use_ckpt) { + const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_tgt, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + spec_ckpt.data.resize(ckpt_size); + + const size_t n = llama_state_seq_get_data_ext(ctx_tgt, spec_ckpt.data.data(), ckpt_size, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + GGML_ASSERT(n == ckpt_size); + + spec_ckpt.n_tokens = (int64_t) prompt_tgt.size(); + LOG_DBG("created speculative checkpoint (n_tokens = %" PRId64 ", size = %.3f MiB)\n", + spec_ckpt.n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024); + } + } else { + // we have a previous (partial) draft to reuse from checkpoint restoration + if (use_ckpt) { + GGML_ASSERT(!spec_ckpt.empty()); + } + } + + GGML_ASSERT(n_draft > 0); // always have a token to evaluate from before - id_last common_batch_clear(batch_tgt); @@ -178,6 +242,12 @@ int main(int argc, char ** argv) { llama_decode(ctx_tgt, batch_tgt); } + // only save the sampler sampler state if we use checkpoints + common_sampler_ptr smpl_save; + if (use_ckpt) { + smpl_save.reset(common_sampler_clone(smpl.get())); + } + // sample from the full target batch and return the accepted tokens based on the target sampler // // for each token to be accepted, the sampler would have to sample that same token @@ -185,14 +255,38 @@ int main(int argc, char ** argv) { // available logits from the batch and sample the next token until we run out of logits or the sampler // disagrees with the draft // - const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft); + auto ids = common_sampler_sample_and_accept_n(smpl.get(), ctx_tgt, draft); //LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str()); GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token + // check for partial draft acceptance: + // if the context doesn't support partial sequence removal, restore the checkpoint + // and make the accepted tokens the new partial draft for the next iteration + if (use_ckpt && ids.size() - 1 < draft.size()) { + LOG_DBG("partial acceptance: %zu < %zu, restoring checkpoint\n", ids.size() - 1, draft.size()); + + draft = std::move(ids); + + const size_t n = llama_state_seq_set_data_ext(ctx_tgt, spec_ckpt.data.data(), spec_ckpt.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + GGML_ASSERT(n == spec_ckpt.size()); + + llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, spec_ckpt.n_tokens, -1); + + prompt_tgt.resize(spec_ckpt.n_tokens); + smpl = std::move(smpl_save); + + n_past = (int) prompt_tgt.size(); + + continue; + } + + common_speculative_accept(spec, ids.size() - 1); + + // full acceptance: consume the draft and commit accepted tokens n_past += ids.size() - 1; - n_drafted += draft.size(); // note: we ignore the 
discarded small drafts + n_drafted += n_draft; // note: we ignore the discarded small drafts n_accept += ids.size() - 1; n_predict += ids.size(); @@ -222,6 +316,9 @@ int main(int argc, char ** argv) { LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last); + // clear the draft since it has been consumed + draft.clear(); + { LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); @@ -254,11 +351,10 @@ int main(int argc, char ** argv) { LOG_INF("\n"); LOG_INF("target:\n\n"); - common_perf_print(ctx_tgt, smpl); + common_perf_print(ctx_tgt, smpl.get()); llama_batch_free(batch_tgt); - common_sampler_free(smpl); common_speculative_free(spec); llama_backend_free(); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 53f61b5a9..b8c05cd80 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2961,7 +2961,13 @@ private: // verify and try to accept the draft { - common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get())); + const bool use_ckpt = slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + + // only save the sampler sampler state if we use checkpoints + common_sampler_ptr smpl_save; + if (use_ckpt) { + smpl_save.reset(common_sampler_clone(slot.smpl.get())); + } GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft); @@ -2973,7 +2979,7 @@ private: // check for partial draft acceptance if (accepted.size() < slot.spec_draft.size() + 1) { - if (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { + if (use_ckpt) { // partial acceptance is not supported by the context -> truncate the draft and restore the state slot.spec_draft = std::move(accepted); From 8bccdbbff9d0d91d54838471f6eea182b9ab1b79 Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Wed, 22 Apr 2026 18:10:56 +0200 Subject: [PATCH 20/20] chat: fix parallel_tool_calls default setting based on model capabilities, add tests for parallel tool calls and structured outputs (#22217) * chat: fix parallel_tool_calls default setting based on model capabilities, add tests for parallel tool calls and structured outputs * Fix ty errors. * Fix flake8 err --- scripts/server-test-parallel-tc.py | 991 +++++++++++++++++++++++++++++ scripts/server-test-structured.py | 980 ++++++++++++++++++++++++++++ tools/cli/cli.cpp | 4 +- tools/server/server-common.cpp | 4 +- 4 files changed, 1977 insertions(+), 2 deletions(-) create mode 100755 scripts/server-test-parallel-tc.py create mode 100755 scripts/server-test-structured.py diff --git a/scripts/server-test-parallel-tc.py b/scripts/server-test-parallel-tc.py new file mode 100755 index 000000000..a166c6d72 --- /dev/null +++ b/scripts/server-test-parallel-tc.py @@ -0,0 +1,991 @@ +#!/usr/bin/env python3 +""" +Test parallel tool-calling capability via chat completions endpoint. + +Only run this against models that actually support parallel tool calls — this +script does not attempt to toggle that setting on the server. Each scenario is +explicitly worded so that a capable model SHOULD emit multiple tool calls in a +single assistant turn (either the same tool N times, or several different +tools at once). 
+ +Each test case contains: + - tools: list of tool definitions (OpenAI-compatible) + - messages: initial conversation messages + - mock_tool_responses: dict mapping tool_name -> callable(arguments) -> str (JSON) + - expected_parallel: dict describing what constitutes a successful parallel turn + {"min_parallel": int, # minimum tool_calls in one turn + "require_same_tool": Optional[str], # all parallel calls must be this tool + "require_distinct_tools": Optional[int], # >= N distinct tool names in one turn + "min_distinct_args_key": Optional[str]} # parallel calls must span this + # many distinct values of this arg key + - validate: callable(turns, all_tool_calls, final_content) -> (passed, reason) +""" + +import argparse +import json +import requests +import sys + +# --------------------------------------------------------------------------- +# Color / formatting helpers +# --------------------------------------------------------------------------- + +RESET = "\x1b[0m" +BOLD = "\x1b[1m" +DIM = "\x1b[2m" +CYAN = "\x1b[36m" +YELLOW = "\x1b[33m" +GREEN = "\x1b[32m" +RED = "\x1b[31m" +BLUE = "\x1b[34m" +WHITE = "\x1b[97m" +MAGENTA = "\x1b[35m" + + +def _print(text="", end="\n"): + sys.stdout.write(text + end) + sys.stdout.flush() + + +def print_header(title): + bar = "─" * 60 + _print(f"\n{BOLD}{CYAN}┌{bar}┐{RESET}") + _print( + f"{BOLD}{CYAN}│ {WHITE}{title}{CYAN}{' ' * max(0, 58 - len(title))}│{RESET}" + ) + _print(f"{BOLD}{CYAN}└{bar}┘{RESET}") + + +def print_turn_banner(turn_idx, n_calls): + color = MAGENTA if n_calls >= 2 else DIM + _print(f"\n {BOLD}{color}▶ turn {turn_idx} — {n_calls} tool call(s){RESET}") + + +def print_tool_call(name, args): + args_str = json.dumps(args) + _print( + f" {BOLD}{YELLOW}⚙ {name}{RESET}{DIM}({args_str}){RESET}" + ) + + +def print_tool_result(result): + preview = result[:140] + ("…" if len(result) > 140 else "") + _print(f" {DIM}{BLUE}↳ {preview}{RESET}") + + +def print_model_output(text): + sys.stdout.write(text) + sys.stdout.flush() + + +def print_pass(reason): + _print(f"\n{BOLD}{GREEN}✔ PASS{RESET} {reason}") + + +def print_fail(reason): + _print(f"\n{BOLD}{RED}✘ FAIL{RESET} {reason}") + + +def print_info(msg): + _print(f"{DIM}{msg}{RESET}") + + +def print_warn(msg): + _print(f"{BOLD}{YELLOW}⚠ {msg}{RESET}") + + +# --------------------------------------------------------------------------- +# HTTP helpers +# --------------------------------------------------------------------------- + + +def chat_completion(url, messages, tools=None, stream=False): + payload = { + "messages": messages, + "stream": stream, + "max_tokens": 4096, + } + if tools: + payload["tools"] = tools + payload["tool_choice"] = "auto" + + try: + response = requests.post(url, json=payload, stream=stream) + response.raise_for_status() + except requests.exceptions.RequestException as e: + body = e.response.content if (e.response is not None) else b"" + print_fail(f"Request error: {e} | body: {body}") + return None + + full_content = "" + reasoning_content = "" + tool_calls: list[dict] = [] + + if stream: + for line in response.iter_lines(): + if not line: + continue + decoded = line.decode("utf-8") + if not decoded.startswith("data: "): + continue + data_str = decoded[6:] + if data_str == "[DONE]": + break + try: + data = json.loads(data_str) + except json.JSONDecodeError: + continue + choices = data.get("choices", []) + if not choices: + continue + delta = choices[0].get("delta", {}) + if delta.get("reasoning_content"): + reasoning_content += delta["reasoning_content"] + if 
delta.get("content"): + full_content += delta["content"] + print_model_output(delta["content"]) + for tc in delta.get("tool_calls", []): + idx = tc.get("index", 0) + while len(tool_calls) <= idx: + tool_calls.append( + { + "id": "", + "type": "function", + "function": {"name": "", "arguments": ""}, + } + ) + if "id" in tc: + tool_calls[idx]["id"] += tc["id"] + if "function" in tc: + if "name" in tc["function"]: + tool_calls[idx]["function"]["name"] += tc["function"]["name"] + if "arguments" in tc["function"]: + tool_calls[idx]["function"]["arguments"] += tc["function"][ + "arguments" + ] + else: + data = response.json() + choices = data.get("choices", []) + if choices: + msg = choices[0].get("message", {}) + full_content = msg.get("content") or "" + reasoning_content = msg.get("reasoning_content") or "" + tool_calls = msg.get("tool_calls") or [] + if full_content: + print_model_output(full_content) + + result = {"content": full_content, "tool_calls": tool_calls} + if reasoning_content: + result["reasoning_content"] = reasoning_content + return result + + +def run_agentic_loop(url, messages, tools, mock_tool_responses, stream, max_turns=6): + """ + Drive the multi-turn tool-call loop, but record each turn's tool calls + separately so parallelism can be validated. + + Returns (turns, all_tool_calls, final_content) where `turns` is a list + of dicts: {"index": int, "tool_calls": [...], "content": str}. + """ + msgs = list(messages) + turns: list[dict] = [] + all_tool_calls: list[dict] = [] + + for turn_idx in range(max_turns): + result = chat_completion(url, msgs, tools=tools, stream=stream) + if result is None: + return turns, all_tool_calls, None + + tcs = result.get("tool_calls") or [] + content = result.get("content") or "" + + turns.append( + {"index": turn_idx, "tool_calls": list(tcs), "content": content} + ) + + if not tcs: + if content: + _print(f"\n{DIM}{'·' * 60}{RESET}") + _print(f"{DIM} model response:{RESET}\n") + return turns, all_tool_calls, content + + print_turn_banner(turn_idx, len(tcs)) + all_tool_calls.extend(tcs) + + assistant_msg: dict = { + "role": "assistant", + "content": content, + "tool_calls": tcs, + } + reasoning = result.get("reasoning_content") + if reasoning: + assistant_msg["reasoning_content"] = reasoning + msgs.append(assistant_msg) + + for tc in tcs: + tool_name = tc["function"]["name"] + try: + args = json.loads(tc["function"]["arguments"]) + except json.JSONDecodeError: + args = {} + + print_tool_call(tool_name, args) + + mock_fn = mock_tool_responses.get(tool_name) + if mock_fn: + tool_result = mock_fn(args) + else: + tool_result = json.dumps({"error": f"Unknown tool: {tool_name}"}) + + print_tool_result(tool_result) + + msgs.append( + { + "role": "tool", + "tool_call_id": tc.get("id", ""), + "content": tool_result, + } + ) + + return turns, all_tool_calls, None + + +# --------------------------------------------------------------------------- +# Parallelism helpers +# --------------------------------------------------------------------------- + + +def _best_parallel_turn(turns): + """Return the turn (dict) with the most tool calls, or None if no tools.""" + tool_turns = [t for t in turns if t["tool_calls"]] + if not tool_turns: + return None + return max(tool_turns, key=lambda t: len(t["tool_calls"])) + + +def _distinct_tool_names(turn): + return {tc["function"]["name"] for tc in turn["tool_calls"]} + + +def _distinct_arg_values(turn, key): + values = set() + for tc in turn["tool_calls"]: + try: + args = json.loads(tc["function"]["arguments"]) + except 
json.JSONDecodeError: + continue + v = args.get(key) + if v is not None: + if isinstance(v, str): + values.add(v.strip().lower()) + else: + values.add(v) + return values + + +def _check_parallel(turns, expected): + """ + Check that at least one turn satisfies the parallel-call expectations. + Returns (ok, reason). + """ + best = _best_parallel_turn(turns) + if best is None: + return False, "No tool calls were made at all" + + min_parallel = expected.get("min_parallel", 2) + if len(best["tool_calls"]) < min_parallel: + by_turn = [len(t["tool_calls"]) for t in turns] + return False, ( + f"No turn had >= {min_parallel} parallel tool calls " + f"(per-turn counts: {by_turn})" + ) + + require_same = expected.get("require_same_tool") + if require_same is not None: + names = [tc["function"]["name"] for tc in best["tool_calls"]] + if any(n != require_same for n in names): + return False, ( + f"Parallel turn mixed tools; expected all {require_same!r}, got {names}" + ) + + require_distinct = expected.get("require_distinct_tools") + if require_distinct is not None: + distinct = _distinct_tool_names(best) + if len(distinct) < require_distinct: + return False, ( + f"Parallel turn had only {len(distinct)} distinct tool names " + f"({distinct}); need >= {require_distinct}" + ) + + distinct_key = expected.get("min_distinct_args_key") + distinct_count = expected.get("min_distinct_args_count", min_parallel) + if distinct_key is not None: + values = _distinct_arg_values(best, distinct_key) + if len(values) < distinct_count: + return False, ( + f"Parallel turn had only {len(values)} distinct {distinct_key!r} " + f"values ({values}); need >= {distinct_count}" + ) + + return True, ( + f"Parallel turn had {len(best['tool_calls'])} calls across " + f"{len(_distinct_tool_names(best))} distinct tool(s)" + ) + + +# --------------------------------------------------------------------------- +# Test case runner +# --------------------------------------------------------------------------- + + +def run_test(url, test_case, stream): + name = test_case["name"] + mode = f"{'stream' if stream else 'non-stream'}" + print_header(f"{name} [{mode}]") + + turns, all_tool_calls, final_content = run_agentic_loop( + url, + messages=test_case["messages"], + tools=test_case["tools"], + mock_tool_responses=test_case["mock_tool_responses"], + stream=stream, + ) + + if not turns: + print_fail("No response from server.") + return False + + parallel_ok, parallel_reason = _check_parallel(turns, test_case["expected_parallel"]) + if not parallel_ok: + print_fail(parallel_reason) + return False + + passed, reason = test_case["validate"](turns, all_tool_calls, final_content) + if passed: + print_pass(f"{parallel_reason}; {reason}") + else: + print_fail(reason) + return passed + + +# --------------------------------------------------------------------------- +# Test case definitions +# --------------------------------------------------------------------------- + +# ---- Test 1: Multi-file read (same tool, multiple distinct paths) ---- + +_FILE_TOOLS = [ + { + "type": "function", + "function": { + "name": "read_file", + "description": ( + "Read the full contents of a file from the local filesystem. " + "Call this tool in parallel when asked to read several files — " + "each path needs its own call." 
+ ), + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Absolute or repo-relative path to a file", + }, + }, + "required": ["path"], + }, + }, + }, +] + +_FILE_CONTENTS = { + "config/database.yml": "host: db.internal\nport: 5432\nuser: svc_app\n", + "config/redis.yml": "host: cache.internal\nport: 6379\ndb: 0\n", + "config/queue.yml": "broker: rabbitmq.internal\nport: 5672\nvhost: prod\n", + "config/auth.yml": "provider: oidc\nissuer: https://auth.internal\n", +} + + +def _read_file_mock(args): + path = args.get("path", "") + norm = path.lstrip("./").lstrip("/") + content = _FILE_CONTENTS.get(norm) + if content is None: + for k, v in _FILE_CONTENTS.items(): + if path.endswith(k): + content = v + break + if content is None: + return json.dumps({"path": path, "error": "not found"}) + return json.dumps({"path": path, "content": content}) + + +MULTIFILE_READ_TEST = { + "name": "Parallel multi-file read (same tool, 4 distinct paths)", + "tools": _FILE_TOOLS, + "messages": [ + { + "role": "user", + "content": ( + "Please read all four of these config files so I can review them " + "together: config/database.yml, config/redis.yml, config/queue.yml, " + "and config/auth.yml. Call read_file for every path in parallel in " + "a single batch — do NOT read them one by one sequentially across " + "turns. After you have all four, give me a one-line summary of each." + ), + } + ], + "mock_tool_responses": {"read_file": _read_file_mock}, + "expected_parallel": { + "min_parallel": 4, + "require_same_tool": "read_file", + "min_distinct_args_key": "path", + "min_distinct_args_count": 4, + }, + "validate": lambda turns, tcs, content: _validate_multifile(turns, tcs, content), +} + + +def _validate_multifile(turns, tcs, content): + del turns + if not content: + return False, "No final summary produced" + return True, f"{len(tcs)} total read_file calls; content length={len(content)}" + + +# ---- Test 2: Batch TODO marking (same tool, N calls in one turn) ---- + +_TODO_TOOLS = [ + { + "type": "function", + "function": { + "name": "mark_todo_complete", + "description": ( + "Mark a single TODO item as complete by ID. When the user wants " + "several items marked at once, call this tool in parallel — " + "one call per item — rather than sequentially across turns." + ), + "parameters": { + "type": "object", + "properties": { + "todo_id": { + "type": "string", + "description": "Identifier of the TODO item", + }, + "note": { + "type": "string", + "description": "Optional completion note", + }, + }, + "required": ["todo_id"], + }, + }, + }, +] + +_TODO_DB = { + "T-101": "Draft onboarding doc", + "T-102": "Update dependency lockfile", + "T-103": "Fix flaky login test", + "T-104": "Rotate service credentials", + "T-105": "Archive Q4 reports", +} + + +def _mark_todo_mock(args): + tid = args.get("todo_id", "") + if tid in _TODO_DB: + return json.dumps({"todo_id": tid, "title": _TODO_DB[tid], "status": "done"}) + return json.dumps({"todo_id": tid, "error": "unknown id"}) + + +TODO_BATCH_TEST = { + "name": "Batch TODO completion (same tool, 5 IDs in one turn)", + "tools": _TODO_TOOLS, + "messages": [ + { + "role": "user", + "content": ( + "I finished every item on today's list. Please mark all of the " + "following TODOs as complete, in one parallel batch: T-101, T-102, " + "T-103, T-104, T-105. Don't mark them one at a time across separate " + "turns — issue all five mark_todo_complete calls at once. Afterwards " + "confirm which ones succeeded." 
+ ), + } + ], + "mock_tool_responses": {"mark_todo_complete": _mark_todo_mock}, + "expected_parallel": { + "min_parallel": 5, + "require_same_tool": "mark_todo_complete", + "min_distinct_args_key": "todo_id", + "min_distinct_args_count": 5, + }, + "validate": lambda turns, tcs, content: _validate_todo(turns, tcs, content), +} + + +def _validate_todo(turns, tcs, content): + del turns + if not content: + return False, "No confirmation summary produced" + return True, f"{len(tcs)} total mark_todo_complete calls" + + +# ---- Test 3: Multi-city weather (same tool, N parallel locations) ---- + +_WEATHER_TOOLS = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": ( + "Fetch current weather for ONE city. When the user asks about " + "several cities, call this tool in parallel — one call per city — " + "instead of sequentially." + ), + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + "units": { + "type": "string", + "enum": ["metric", "imperial"], + "default": "metric", + }, + }, + "required": ["city"], + }, + }, + }, +] + +_WEATHER_DB = { + "tokyo": {"city": "Tokyo", "temp_c": 18.4, "condition": "partly cloudy", "humidity": 64}, + "london": {"city": "London", "temp_c": 9.1, "condition": "overcast", "humidity": 81}, + "new york": {"city": "New York", "temp_c": 12.7, "condition": "clear", "humidity": 55}, + "paris": {"city": "Paris", "temp_c": 11.3, "condition": "light rain", "humidity": 78}, +} + + +def _weather_mock(args): + city = args.get("city", "").strip().lower() + if city.startswith("new york"): + city = "new york" + if city in _WEATHER_DB: + return json.dumps(_WEATHER_DB[city]) + return json.dumps({"city": args.get("city", ""), "error": "unknown city"}) + + +MULTI_WEATHER_TEST = { + "name": "Parallel multi-city weather (same tool, 4 cities)", + "tools": _WEATHER_TOOLS, + "messages": [ + { + "role": "user", + "content": ( + "I'm comparing today's weather across four cities for a travel " + "decision: Tokyo, London, New York, and Paris. Please call " + "get_weather for all four in parallel in a single turn — don't " + "fetch them one at a time. Then rank them from warmest to coolest." + ), + } + ], + "mock_tool_responses": {"get_weather": _weather_mock}, + "expected_parallel": { + "min_parallel": 4, + "require_same_tool": "get_weather", + "min_distinct_args_key": "city", + "min_distinct_args_count": 4, + }, + "validate": lambda turns, tcs, content: _validate_weather(turns, tcs, content), +} + + +def _validate_weather(turns, tcs, content): + del turns + if not content or not any( + kw in content.lower() for kw in ("warmest", "rank", "hot", "cool") + ): + return False, f"Final content missing a ranking: {content!r}" + return True, f"{len(tcs)} total get_weather calls; ranking produced" + + +# ---- Test 4: Trip planning (different tools, parallel in one turn) ---- + +_TRIP_TOOLS = [ + { + "type": "function", + "function": { + "name": "search_flights", + "description": "Search one-way flights between two airports on a given date.", + "parameters": { + "type": "object", + "properties": { + "from_airport": {"type": "string", "description": "IATA code, e.g. SFO"}, + "to_airport": {"type": "string", "description": "IATA code, e.g. 
JFK"}, + "date": {"type": "string", "description": "YYYY-MM-DD"}, + }, + "required": ["from_airport", "to_airport", "date"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "search_hotels", + "description": "Search hotels in a city for a date range.", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "check_in": {"type": "string", "description": "YYYY-MM-DD"}, + "check_out": {"type": "string", "description": "YYYY-MM-DD"}, + "max_price": {"type": "integer"}, + }, + "required": ["city", "check_in", "check_out"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "search_restaurants", + "description": "Search restaurants in a city by cuisine.", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "cuisine": {"type": "string"}, + }, + "required": ["city"], + }, + }, + }, +] + +_FLIGHTS_RESULT = { + "results": [ + {"flight": "UA 1552", "depart": "08:15", "arrive": "16:45", "price": 389}, + {"flight": "AA 20", "depart": "10:00", "arrive": "18:35", "price": 412}, + ] +} +_HOTELS_RESULT = { + "results": [ + {"name": "Midtown Grand", "nightly_rate": 245, "rating": 4.3}, + {"name": "Harbour Boutique", "nightly_rate": 312, "rating": 4.6}, + ] +} +_RESTAURANTS_RESULT = { + "results": [ + {"name": "Trattoria Nona", "cuisine": "italian", "rating": 4.5}, + {"name": "Osteria Blu", "cuisine": "italian", "rating": 4.4}, + ] +} + +TRIP_PLAN_TEST = { + "name": "Trip planning (3 different tools in parallel)", + "tools": _TRIP_TOOLS, + "messages": [ + { + "role": "user", + "content": ( + "I'm flying from SFO to JFK on 2026-06-12 and staying four nights " + "(check out 2026-06-16). I'd also like some Italian restaurant " + "suggestions in New York. Please call search_flights, search_hotels, " + "and search_restaurants in parallel — all three in a single turn, " + "since they don't depend on each other. Then give me a concise " + "travel summary." + ), + } + ], + "mock_tool_responses": { + "search_flights": lambda _: json.dumps(_FLIGHTS_RESULT), + "search_hotels": lambda _: json.dumps(_HOTELS_RESULT), + "search_restaurants": lambda _: json.dumps(_RESTAURANTS_RESULT), + }, + "expected_parallel": { + "min_parallel": 3, + "require_distinct_tools": 3, + }, + "validate": lambda turns, tcs, content: _validate_trip(turns, tcs, content), +} + + +def _validate_trip(turns, tcs, content): + del turns + names = {tc["function"]["name"] for tc in tcs} + required = {"search_flights", "search_hotels", "search_restaurants"} + missing = required - names + if missing: + return False, f"Missing tool calls: {missing}" + if not content: + return False, "No travel summary produced" + return True, f"All three tools called; summary length={len(content)}" + + +# ---- Test 5: Portfolio check (same tool, parallel tickers) ---- + +_STOCK_TOOLS = [ + { + "type": "function", + "function": { + "name": "get_stock_quote", + "description": ( + "Get the latest quote for ONE ticker. When the user asks about " + "multiple tickers, call this tool in parallel — one per symbol — " + "rather than sequentially." 
+ ), + "parameters": { + "type": "object", + "properties": { + "symbol": {"type": "string", "description": "Ticker symbol"}, + }, + "required": ["symbol"], + }, + }, + }, +] + +_STOCK_DB = { + "AAPL": {"symbol": "AAPL", "price": 218.45, "change_pct": "+0.8%"}, + "MSFT": {"symbol": "MSFT", "price": 421.10, "change_pct": "+1.2%"}, + "GOOGL":{"symbol": "GOOGL","price": 175.22, "change_pct": "-0.3%"}, + "AMZN": {"symbol": "AMZN", "price": 189.76, "change_pct": "+0.5%"}, + "NVDA": {"symbol": "NVDA", "price": 140.88, "change_pct": "+2.4%"}, +} + + +def _stock_mock(args): + sym = args.get("symbol", "").strip().upper() + if sym in _STOCK_DB: + return json.dumps(_STOCK_DB[sym]) + return json.dumps({"symbol": sym, "error": "unknown ticker"}) + + +PORTFOLIO_TEST = { + "name": "Portfolio check (same tool, 5 tickers in parallel)", + "tools": _STOCK_TOOLS, + "messages": [ + { + "role": "user", + "content": ( + "Pull the latest quote for every ticker in my portfolio — AAPL, " + "MSFT, GOOGL, AMZN, and NVDA — in a single parallel batch. These " + "lookups are independent, so please don't chain them across turns. " + "Once you have all five, tell me which ticker had the biggest " + "percentage change today." + ), + } + ], + "mock_tool_responses": {"get_stock_quote": _stock_mock}, + "expected_parallel": { + "min_parallel": 5, + "require_same_tool": "get_stock_quote", + "min_distinct_args_key": "symbol", + "min_distinct_args_count": 5, + }, + "validate": lambda turns, tcs, content: _validate_portfolio(turns, tcs, content), +} + + +def _validate_portfolio(turns, tcs, content): + del turns + if not content or ("nvda" not in content.lower() and "NVDA" not in content): + return False, f"Expected NVDA to be identified as the biggest mover: {content!r}" + return True, f"{len(tcs)} total quotes pulled" + + +# ---- Test 6: Mixed — translate + dictionary in parallel for the same word ---- + +_LANG_TOOLS = [ + { + "type": "function", + "function": { + "name": "translate_text", + "description": "Translate a short text into a target language.", + "parameters": { + "type": "object", + "properties": { + "text": {"type": "string"}, + "target_language": {"type": "string", + "description": "ISO 639-1 language code, e.g. 
'es'"}, + }, + "required": ["text", "target_language"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_definition", + "description": "Get the English dictionary definition of a word.", + "parameters": { + "type": "object", + "properties": { + "word": {"type": "string"}, + }, + "required": ["word"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_synonyms", + "description": "Get English synonyms for a word.", + "parameters": { + "type": "object", + "properties": { + "word": {"type": "string"}, + }, + "required": ["word"], + }, + }, + }, +] + + +def _translate_mock(args): + t = args.get("text", "") + lang = args.get("target_language", "") + return json.dumps({"source": t, "target_language": lang, "translation": f"[{lang}] {t}"}) + + +def _definition_mock(args): + w = args.get("word", "") + return json.dumps({ + "word": w, + "definition": f"A standard dictionary definition of {w!r}.", + }) + + +def _synonyms_mock(args): + w = args.get("word", "") + return json.dumps({ + "word": w, + "synonyms": ["synonym_a", "synonym_b", "synonym_c"], + }) + + +LANG_TOOLKIT_TEST = { + "name": "Language toolkit (translate + definition + synonyms in parallel)", + "tools": _LANG_TOOLS, + "messages": [ + { + "role": "user", + "content": ( + "For the English word 'resilient', I need three independent " + "look-ups at once: (a) translate it into Spanish, (b) fetch its " + "dictionary definition, and (c) list its synonyms. These three " + "calls don't depend on each other — please issue them in parallel " + "in a single turn. Then present the combined results as a short " + "language note." + ), + } + ], + "mock_tool_responses": { + "translate_text": _translate_mock, + "get_definition": _definition_mock, + "get_synonyms": _synonyms_mock, + }, + "expected_parallel": { + "min_parallel": 3, + "require_distinct_tools": 3, + }, + "validate": lambda turns, tcs, content: _validate_lang(turns, tcs, content), +} + + +def _validate_lang(turns, tcs, content): + del turns + names = {tc["function"]["name"] for tc in tcs} + required = {"translate_text", "get_definition", "get_synonyms"} + missing = required - names + if missing: + return False, f"Missing tool calls: {missing}" + if not content: + return False, "No language note produced" + return True, f"All three lookup tools called; note length={len(content)}" + + +# --------------------------------------------------------------------------- +# All test cases +# --------------------------------------------------------------------------- + +ALL_TEST_CASES = [ + MULTIFILE_READ_TEST, + TODO_BATCH_TEST, + MULTI_WEATHER_TEST, + TRIP_PLAN_TEST, + PORTFOLIO_TEST, + LANG_TOOLKIT_TEST, +] + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser( + description=( + "Test llama-server parallel tool-calling capability. Run this only " + "against models configured for parallel tool calls — this script " + "does not configure that itself." 
+ ) + ) + parser.add_argument("--host", default="localhost") + parser.add_argument("--port", default=8080, type=int) + parser.add_argument( + "--no-stream", action="store_true", help="Disable streaming mode tests" + ) + parser.add_argument( + "--stream-only", action="store_true", help="Only run streaming mode tests" + ) + parser.add_argument( + "--test", + help="Run only the test whose name contains this substring (case-insensitive)", + ) + args = parser.parse_args() + + url = f"http://{args.host}:{args.port}/v1/chat/completions" + print_info(f"Testing server at {url}") + print_warn( + "This script expects the target model to emit multiple tool calls in a " + "single assistant turn. Run it only against parallel-tool-capable models." + ) + + modes: list[bool] = [] + if not args.stream_only: + modes.append(False) + if not args.no_stream: + modes.append(True) + + cases: list[dict] = ALL_TEST_CASES + if args.test: + name_filter = args.test.lower() + cases = [c for c in cases if name_filter in str(c["name"]).lower()] + if not cases: + print_fail(f"No test cases matched '{args.test}'") + sys.exit(1) + + total = 0 + passed = 0 + for stream in modes: + for case in cases: + total += 1 + if run_test(url, case, stream=stream): + passed += 1 + + color = GREEN if passed == total else RED + _print(f"\n{BOLD}{color}{'─' * 60}{RESET}") + _print(f"{BOLD}{color} Results: {passed}/{total} passed{RESET}") + _print(f"{BOLD}{color}{'─' * 60}{RESET}\n") + sys.exit(0 if passed == total else 1) + + +if __name__ == "__main__": + main() diff --git a/scripts/server-test-structured.py b/scripts/server-test-structured.py new file mode 100755 index 000000000..98ff473b9 --- /dev/null +++ b/scripts/server-test-structured.py @@ -0,0 +1,980 @@ +#!/usr/bin/env python3 +""" +Test structured output capability via chat completions endpoint. + +Each test case contains: + - response_format: OpenAI-compatible response_format specification + (json_schema only — llama.cpp does not support json_object) + - messages: initial conversation messages + - tools (optional): tool definitions (for mixed tool + structured tests) + - mock_tool_responses (optional): dict mapping tool_name -> callable(arguments) -> str (JSON) + - apply_stage: "always" to apply response_format to every request, + "after_tools" to run the tool loop plain, then request a + structured summary in a follow-up user turn. + - followup (optional, for after_tools): user message appended before the + final structured call. 
+ - validate: callable(parsed_json, tool_calls_history, raw_content) -> (passed: bool, reason: str) +""" + +import argparse +import json +import requests +import sys +from typing import Any, cast + +# --------------------------------------------------------------------------- +# Color / formatting helpers +# --------------------------------------------------------------------------- + +RESET = "\x1b[0m" +BOLD = "\x1b[1m" +DIM = "\x1b[2m" +CYAN = "\x1b[36m" +YELLOW = "\x1b[33m" +GREEN = "\x1b[32m" +RED = "\x1b[31m" +BLUE = "\x1b[34m" +WHITE = "\x1b[97m" +MAGENTA = "\x1b[35m" + + +def _print(text="", end="\n"): + sys.stdout.write(text + end) + sys.stdout.flush() + + +def print_header(title): + bar = "─" * 60 + _print(f"\n{BOLD}{CYAN}┌{bar}┐{RESET}") + _print( + f"{BOLD}{CYAN}│ {WHITE}{title}{CYAN}{' ' * max(0, 58 - len(title))}│{RESET}" + ) + _print(f"{BOLD}{CYAN}└{bar}┘{RESET}") + + +def print_tool_call(name, args): + args_str = json.dumps(args) + _print( + f"\n {BOLD}{YELLOW}⚙ tool call{RESET} {CYAN}{name}{RESET}{DIM}({args_str}){RESET}" + ) + + +def print_tool_result(result): + preview = result[:160] + ("…" if len(result) > 160 else "") + _print(f" {DIM}{BLUE}↳ result{RESET} {DIM}{preview}{RESET}") + + +def print_model_output(text): + sys.stdout.write(text) + sys.stdout.flush() + + +def print_pass(reason): + _print(f"\n{BOLD}{GREEN}✔ PASS{RESET} {reason}") + + +def print_fail(reason): + _print(f"\n{BOLD}{RED}✘ FAIL{RESET} {reason}") + + +def print_info(msg): + _print(f"{DIM}{msg}{RESET}") + + +def print_schema_note(label, rf): + kind = rf.get("type", "?") + name = "" + if kind == "json_schema": + name = rf.get("json_schema", {}).get("name", "") + _print(f"{DIM}{MAGENTA} ⟐ response_format [{label}]: {kind}" + f"{(' / ' + name) if name else ''}{RESET}") + + +# --------------------------------------------------------------------------- +# HTTP helpers +# --------------------------------------------------------------------------- + + +def chat_completion(url, messages, tools=None, response_format=None, stream=False): + payload = { + "messages": messages, + "stream": stream, + "max_tokens": 4096, + } + if tools: + payload["tools"] = tools + payload["tool_choice"] = "auto" + if response_format is not None: + payload["response_format"] = response_format + + try: + response = requests.post(url, json=payload, stream=stream) + response.raise_for_status() + except requests.exceptions.RequestException as e: + body = e.response.content if (e.response is not None) else b"" + print_fail(f"Request error: {e} | body: {body}") + return None + + full_content = "" + reasoning_content = "" + tool_calls: list[dict] = [] + + if stream: + for line in response.iter_lines(): + if not line: + continue + decoded = line.decode("utf-8") + if not decoded.startswith("data: "): + continue + data_str = decoded[6:] + if data_str == "[DONE]": + break + try: + data = json.loads(data_str) + except json.JSONDecodeError: + continue + choices = data.get("choices", []) + if not choices: + continue + delta = choices[0].get("delta", {}) + if delta.get("reasoning_content"): + reasoning_content += delta["reasoning_content"] + if delta.get("content"): + full_content += delta["content"] + print_model_output(delta["content"]) + for tc in delta.get("tool_calls", []): + idx = tc.get("index", 0) + while len(tool_calls) <= idx: + tool_calls.append( + { + "id": "", + "type": "function", + "function": {"name": "", "arguments": ""}, + } + ) + if "id" in tc: + tool_calls[idx]["id"] += tc["id"] + if "function" in tc: + if "name" in 
tc["function"]: + tool_calls[idx]["function"]["name"] += tc["function"]["name"] + if "arguments" in tc["function"]: + tool_calls[idx]["function"]["arguments"] += tc["function"][ + "arguments" + ] + else: + data = response.json() + choices = data.get("choices", []) + if choices: + msg = choices[0].get("message", {}) + full_content = msg.get("content") or "" + reasoning_content = msg.get("reasoning_content") or "" + tool_calls = msg.get("tool_calls") or [] + if full_content: + print_model_output(full_content) + + result = {"content": full_content, "tool_calls": tool_calls} + if reasoning_content: + result["reasoning_content"] = reasoning_content + return result + + +def run_tool_loop( + url, messages, tools, mock_tool_responses, stream, response_format=None, + max_turns=6, +): + """ + Drive the tool-call loop. If response_format is provided it is applied to + every request. Returns (all_tool_calls, final_messages, final_content). + """ + msgs = list(messages) + all_tool_calls: list[dict] = [] + + for _ in range(max_turns): + result = chat_completion( + url, msgs, tools=tools, response_format=response_format, stream=stream + ) + if result is None: + return all_tool_calls, msgs, None + + tcs = result.get("tool_calls") or [] + content = result.get("content") or "" + + if not tcs: + if content: + _print(f"\n{DIM}{'·' * 60}{RESET}") + return all_tool_calls, msgs, content + + all_tool_calls.extend(tcs) + + assistant_msg: dict = { + "role": "assistant", + "content": content, + "tool_calls": tcs, + } + reasoning = result.get("reasoning_content") + if reasoning: + assistant_msg["reasoning_content"] = reasoning + msgs.append(assistant_msg) + + for tc in tcs: + tool_name = tc["function"]["name"] + try: + args = json.loads(tc["function"]["arguments"]) + except json.JSONDecodeError: + args = {} + + print_tool_call(tool_name, args) + + mock_fn = mock_tool_responses.get(tool_name) if mock_tool_responses else None + if mock_fn: + tool_result = mock_fn(args) + else: + tool_result = json.dumps({"error": f"Unknown tool: {tool_name}"}) + + print_tool_result(tool_result) + + msgs.append( + { + "role": "tool", + "tool_call_id": tc.get("id", ""), + "content": tool_result, + } + ) + + return all_tool_calls, msgs, None + + +# --------------------------------------------------------------------------- +# Test case runner +# --------------------------------------------------------------------------- + + +def _try_parse_json(text): + """Attempt to parse text as JSON, trimming common markdown fences.""" + if text is None: + return None + stripped = text.strip() + if stripped.startswith("```"): + lines = stripped.splitlines() + if lines and lines[0].startswith("```"): + lines = lines[1:] + if lines and lines[-1].strip().startswith("```"): + lines = lines[:-1] + stripped = "\n".join(lines).strip() + try: + return json.loads(stripped) + except json.JSONDecodeError: + return None + + +def run_test(url, test_case, stream): + name = test_case["name"] + mode = f"{'stream' if stream else 'non-stream'}" + apply_stage = test_case.get("apply_stage", "always") + print_header(f"{name} [{mode}] ({apply_stage})") + + response_format = test_case["response_format"] + print_schema_note(apply_stage, response_format) + + tools = test_case.get("tools") + mocks = test_case.get("mock_tool_responses") or {} + + all_tcs: list[dict] = [] + final_content = None + + if apply_stage == "always": + all_tcs, _msgs, final_content = run_tool_loop( + url, + messages=list(test_case["messages"]), + tools=tools, + mock_tool_responses=mocks, + stream=stream, 
+ response_format=response_format, + ) + elif apply_stage == "after_tools": + # Phase 1: plain tool loop, no response_format applied yet. + all_tcs, msgs, interim_content = run_tool_loop( + url, + messages=list(test_case["messages"]), + tools=tools, + mock_tool_responses=mocks, + stream=stream, + response_format=None, + ) + if interim_content: + msgs.append({"role": "assistant", "content": interim_content}) + followup = test_case.get( + "followup", + "Now output the answer strictly as JSON matching the provided schema. " + "Do not include commentary.", + ) + msgs.append({"role": "user", "content": followup}) + + # Phase 2: request final structured output. Tools are not passed so the + # model focuses on producing the schema-constrained answer. + _print(f"\n{DIM}{MAGENTA} ⟐ follow-up turn with response_format applied{RESET}") + result = chat_completion( + url, msgs, tools=None, response_format=response_format, stream=stream + ) + final_content = result["content"] if result else None + else: + print_fail(f"Unknown apply_stage: {apply_stage}") + return False + + if final_content is None: + print_fail("No final content from server.") + return False + + parsed = _try_parse_json(final_content) + if parsed is None: + print_fail(f"Final content is not valid JSON: {final_content[:200]!r}") + return False + + passed, reason = test_case["validate"](parsed, all_tcs, final_content) + if passed: + print_pass(reason) + else: + print_fail(reason) + return passed + + +# --------------------------------------------------------------------------- +# Test case definitions +# --------------------------------------------------------------------------- + +# ---- Test 1: Book metadata extraction (always / json_schema) ---- + +_BOOK_SCHEMA = { + "type": "json_schema", + "json_schema": { + "name": "book_metadata", + "strict": True, + "schema": { + "type": "object", + "additionalProperties": False, + "properties": { + "title": {"type": "string"}, + "author": {"type": "string"}, + "year": {"type": "integer"}, + "genre": { + "type": "string", + "enum": [ + "fiction", + "non-fiction", + "fantasy", + "sci-fi", + "mystery", + "biography", + "history", + "other", + ], + }, + "page_count": {"type": "integer"}, + }, + "required": ["title", "author", "year", "genre", "page_count"], + }, + }, +} + +BOOK_TEST_CASE = { + "name": "Book metadata extraction (json_schema, always)", + "response_format": _BOOK_SCHEMA, + "apply_stage": "always", + "messages": [ + { + "role": "user", + "content": ( + "Extract book metadata from this description: " + "'Dune is a 1965 science fiction epic by Frank Herbert, spanning roughly " + "688 pages in its first edition, set on the desert planet Arrakis.' " + "Return the data as JSON." 
+ ), + } + ], + "validate": lambda parsed, tcs, raw: _validate_book(parsed), +} + + +def _validate_book(parsed): + required = {"title", "author", "year", "genre", "page_count"} + missing = required - parsed.keys() + if missing: + return False, f"Missing fields: {missing}" + if not isinstance(parsed["title"], str) or not parsed["title"]: + return False, "title must be a non-empty string" + if not isinstance(parsed["author"], str) or "herbert" not in parsed["author"].lower(): + return False, f"author unexpected: {parsed['author']!r}" + if not isinstance(parsed["year"], int) or parsed["year"] != 1965: + return False, f"year should be 1965, got {parsed['year']!r}" + if parsed["genre"] not in { + "fiction", "non-fiction", "fantasy", "sci-fi", "mystery", + "biography", "history", "other", + }: + return False, f"genre not in enum: {parsed['genre']!r}" + if not isinstance(parsed["page_count"], int) or parsed["page_count"] <= 0: + return False, f"page_count should be positive int: {parsed['page_count']!r}" + return True, f"Book: {parsed['title']} ({parsed['year']}) / {parsed['genre']}" + + +# ---- Test 2: Sentiment classification (always / enum-constrained) ---- + +_SENTIMENT_SCHEMA = { + "type": "json_schema", + "json_schema": { + "name": "sentiment_analysis", + "strict": True, + "schema": { + "type": "object", + "additionalProperties": False, + "properties": { + "sentiment": { + "type": "string", + "enum": ["positive", "negative", "neutral"], + }, + "confidence": {"type": "number"}, + "keywords": { + "type": "array", + "items": {"type": "string"}, + "minItems": 1, + "maxItems": 5, + }, + }, + "required": ["sentiment", "confidence", "keywords"], + }, + }, +} + +SENTIMENT_TEST_CASE = { + "name": "Sentiment analysis with enum and array", + "response_format": _SENTIMENT_SCHEMA, + "apply_stage": "always", + "messages": [ + { + "role": "user", + "content": ( + "Analyse the sentiment of this review and return JSON with the " + "detected sentiment label, a confidence score between 0 and 1, " + "and up to five keyword strings that drove the classification:\n\n" + "'This product completely exceeded my expectations. 
The build " + "quality is phenomenal, it arrived a day early, and customer " + "support was delightful when I had a setup question.'" + ), + } + ], + "validate": lambda parsed, tcs, raw: _validate_sentiment(parsed), +} + + +def _validate_sentiment(parsed): + if parsed.get("sentiment") not in {"positive", "negative", "neutral"}: + return False, f"sentiment not in enum: {parsed.get('sentiment')!r}" + if parsed["sentiment"] != "positive": + return False, f"expected positive sentiment, got {parsed['sentiment']}" + conf = parsed.get("confidence") + if not isinstance(conf, (int, float)) or not (0.0 <= conf <= 1.0): + return False, f"confidence not in [0,1]: {conf!r}" + kws = parsed.get("keywords") + if not isinstance(kws, list) or not (1 <= len(kws) <= 5): + return False, f"keywords length out of range: {kws!r}" + if not all(isinstance(k, str) and k for k in kws): + return False, f"keywords must be non-empty strings: {kws!r}" + return True, f"sentiment={parsed['sentiment']} conf={conf} kws={kws}" + + +# ---- Test 3: Nested recipe schema (always) ---- + +_RECIPE_SCHEMA = { + "type": "json_schema", + "json_schema": { + "name": "recipe", + "strict": True, + "schema": { + "type": "object", + "additionalProperties": False, + "properties": { + "name": {"type": "string"}, + "servings": {"type": "integer"}, + "ingredients": { + "type": "array", + "minItems": 2, + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "item": {"type": "string"}, + "quantity": {"type": "string"}, + }, + "required": ["item", "quantity"], + }, + }, + "steps": { + "type": "array", + "minItems": 2, + "items": {"type": "string"}, + }, + "prep_time_minutes": {"type": "integer"}, + }, + "required": ["name", "servings", "ingredients", "steps", "prep_time_minutes"], + }, + }, +} + +RECIPE_TEST_CASE = { + "name": "Nested recipe with arrays of objects", + "response_format": _RECIPE_SCHEMA, + "apply_stage": "always", + "messages": [ + { + "role": "user", + "content": ( + "Give me a simple 4-serving scrambled eggs recipe as structured JSON. " + "Include the recipe name, servings, ingredients (each with item and " + "quantity), preparation steps, and total prep time in minutes." 
+ ), + } + ], + "validate": lambda parsed, tcs, raw: _validate_recipe(parsed), +} + + +def _validate_recipe(parsed): + required = {"name", "servings", "ingredients", "steps", "prep_time_minutes"} + missing = required - parsed.keys() + if missing: + return False, f"Missing fields: {missing}" + if not isinstance(parsed["name"], str) or not parsed["name"]: + return False, "name must be a non-empty string" + if not isinstance(parsed["servings"], int) or parsed["servings"] <= 0: + return False, f"servings must be positive int: {parsed['servings']!r}" + ings = parsed["ingredients"] + if not isinstance(ings, list) or len(ings) < 2: + return False, f"ingredients must be array of >=2: got {ings!r}" + for i, ing in enumerate(ings): + if not isinstance(ing, dict): + return False, f"ingredient[{i}] is not an object: {ing!r}" + ing_d = cast(dict[str, Any], ing) + item_val = ing_d.get("item") + qty_val = ing_d.get("quantity") + if item_val is None or qty_val is None: + return False, f"ingredient[{i}] missing item/quantity: {ing!r}" + if not isinstance(item_val, str) or not isinstance(qty_val, str): + return False, f"ingredient[{i}] fields must be strings: {ing!r}" + steps = parsed["steps"] + if not isinstance(steps, list) or len(steps) < 2: + return False, f"steps must be array of >=2 strings: got {steps!r}" + if not all(isinstance(s, str) and s for s in steps): + return False, "all steps must be non-empty strings" + pt = parsed["prep_time_minutes"] + if not isinstance(pt, int) or pt <= 0: + return False, f"prep_time_minutes must be positive int: {pt!r}" + return True, f"recipe '{parsed['name']}' with {len(ings)} ingredients, {len(steps)} steps" + + +# ---- Test 4: Tool call -> structured product comparison (after_tools) ---- + +_SHOP_TOOLS = [ + { + "type": "function", + "function": { + "name": "search_products", + "description": "Search a product catalogue by keyword.", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string"}, + }, + "required": ["query"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_product_details", + "description": "Get detailed specs for a product by ID.", + "parameters": { + "type": "object", + "properties": { + "product_id": {"type": "string"}, + }, + "required": ["product_id"], + }, + }, + }, +] + +_SHOP_SEARCH_RESULT = { + "results": [ + {"product_id": "LAP-001", "title": "AeroBook 13 Pro", "price": 1399.0, "rating": 4.7}, + {"product_id": "LAP-002", "title": "QuantumSlim 14", "price": 1199.0, "rating": 4.4}, + {"product_id": "LAP-003", "title": "NimbusWork Ultra 15", "price": 999.0, "rating": 4.2}, + ], +} +_SHOP_PRODUCT_DETAILS = { + "LAP-001": { + "product_id": "LAP-001", + "title": "AeroBook 13 Pro", + "cpu": "M-series 10-core", + "ram_gb": 16, + "storage_gb": 512, + "battery_hours": 18, + "weight_kg": 1.24, + "price": 1399.0, + }, + "LAP-002": { + "product_id": "LAP-002", + "title": "QuantumSlim 14", + "cpu": "Core i7 12-core", + "ram_gb": 16, + "storage_gb": 512, + "battery_hours": 12, + "weight_kg": 1.35, + "price": 1199.0, + }, + "LAP-003": { + "product_id": "LAP-003", + "title": "NimbusWork Ultra 15", + "cpu": "Ryzen 7 8-core", + "ram_gb": 16, + "storage_gb": 1024, + "battery_hours": 10, + "weight_kg": 1.70, + "price": 999.0, + }, +} + + +def _shop_details_mock(args): + pid = args.get("product_id", "") + if pid in _SHOP_PRODUCT_DETAILS: + return json.dumps(_SHOP_PRODUCT_DETAILS[pid]) + return json.dumps({"error": f"unknown product_id: {pid}"}) + + +_SHOP_COMPARISON_SCHEMA = { + "type": "json_schema", + 
"json_schema": { + "name": "laptop_comparison", + "strict": True, + "schema": { + "type": "object", + "additionalProperties": False, + "properties": { + "recommendation": {"type": "string"}, + "ranked_candidates": { + "type": "array", + "minItems": 2, + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "product_id": {"type": "string"}, + "title": {"type": "string"}, + "score": {"type": "number"}, + "reason": {"type": "string"}, + }, + "required": ["product_id", "title", "score", "reason"], + }, + }, + }, + "required": ["recommendation", "ranked_candidates"], + }, + }, +} + +SHOP_COMPARISON_TEST_CASE = { + "name": "Tool calls then structured laptop comparison (after_tools)", + "response_format": _SHOP_COMPARISON_SCHEMA, + "apply_stage": "after_tools", + "tools": _SHOP_TOOLS, + "mock_tool_responses": { + "search_products": lambda _: json.dumps(_SHOP_SEARCH_RESULT), + "get_product_details": _shop_details_mock, + }, + "messages": [ + { + "role": "user", + "content": ( + "I need a lightweight laptop for travel. Please search the catalogue " + "for 'ultraportable laptop', then fetch detailed specs for at least two " + "of the top candidates. Once you've gathered the data I'll ask you to " + "produce a structured comparison." + ), + } + ], + "followup": ( + "Thanks. Now produce the final comparison strictly as JSON matching the " + "laptop_comparison schema: your single best recommendation (the product_id), " + "and a ranked_candidates array of at least two laptops, each with " + "product_id, title, a numeric score, and a short reason." + ), + "validate": lambda parsed, tcs, raw: _validate_shop_comparison(parsed, tcs), +} + + +def _validate_shop_comparison(parsed, tcs): + names = [tc["function"]["name"] for tc in tcs] + if "search_products" not in names: + return False, f"expected search_products tool call, got {names}" + if "get_product_details" not in names: + return False, f"expected get_product_details tool call, got {names}" + if "recommendation" not in parsed or not isinstance(parsed["recommendation"], str): + return False, f"recommendation missing or not a string: {parsed!r}" + cands = parsed.get("ranked_candidates") + if not isinstance(cands, list) or len(cands) < 2: + return False, f"ranked_candidates must be >=2: {cands!r}" + valid_ids = set(_SHOP_PRODUCT_DETAILS.keys()) + candidate_pids: list = [] + for i, c in enumerate(cands): + if not isinstance(c, dict): + return False, f"candidate[{i}] not an object: {c!r}" + c_d = cast(dict[str, Any], c) + pid = c_d.get("product_id") + title = c_d.get("title") + score = c_d.get("score") + reason = c_d.get("reason") + for k, v in (("product_id", pid), ("title", title), + ("score", score), ("reason", reason)): + if v is None: + return False, f"candidate[{i}] missing {k}: {c!r}" + if pid not in valid_ids: + return False, f"candidate[{i}].product_id not in catalogue: {pid!r}" + if not isinstance(score, (int, float)): + return False, f"candidate[{i}].score not numeric: {score!r}" + candidate_pids.append(pid) + recommendation = parsed["recommendation"] + if recommendation not in valid_ids and recommendation not in candidate_pids: + return False, f"recommendation {recommendation!r} not in candidates" + return True, ( + f"tools={names}; recommended={parsed['recommendation']}; " + f"{len(cands)} ranked candidates" + ) + + +# ---- Test 5: Multi-step research then structured report (after_tools) ---- + +_RESEARCH_TOOLS = [ + { + "type": "function", + "function": { + "name": "get_country_stats", + "description": "Fetch basic 
statistics for a country (population, GDP, capital).", + "parameters": { + "type": "object", + "properties": { + "country": {"type": "string"}, + }, + "required": ["country"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_climate_info", + "description": "Fetch climate information for a country.", + "parameters": { + "type": "object", + "properties": { + "country": {"type": "string"}, + }, + "required": ["country"], + }, + }, + }, +] + +_COUNTRY_STATS = { + "norway": { + "country": "Norway", + "capital": "Oslo", + "population": 5_480_000, + "gdp_usd_trillion": 0.48, + "currency": "NOK", + } +} +_CLIMATE_INFO = { + "norway": { + "country": "Norway", + "climate_zone": "subarctic / temperate coastal", + "avg_winter_temp_c": -4.5, + "avg_summer_temp_c": 16.0, + "annual_precipitation_mm": 1400, + } +} + + +def _country_stats_mock(args): + c = args.get("country", "").strip().lower() + if c in _COUNTRY_STATS: + return json.dumps(_COUNTRY_STATS[c]) + return json.dumps({"error": f"unknown country: {c}"}) + + +def _climate_info_mock(args): + c = args.get("country", "").strip().lower() + if c in _CLIMATE_INFO: + return json.dumps(_CLIMATE_INFO[c]) + return json.dumps({"error": f"unknown country: {c}"}) + + +_RESEARCH_REPORT_SCHEMA = { + "type": "json_schema", + "json_schema": { + "name": "country_report", + "strict": True, + "schema": { + "type": "object", + "additionalProperties": False, + "properties": { + "country": {"type": "string"}, + "capital": {"type": "string"}, + "population": {"type": "integer"}, + "climate_summary": {"type": "string"}, + "highlights": { + "type": "array", + "minItems": 2, + "maxItems": 5, + "items": {"type": "string"}, + }, + "suitable_for_tourism": {"type": "boolean"}, + }, + "required": [ + "country", "capital", "population", + "climate_summary", "highlights", "suitable_for_tourism", + ], + }, + }, +} + +COUNTRY_REPORT_TEST_CASE = { + "name": "Research pipeline then structured country report (after_tools)", + "response_format": _RESEARCH_REPORT_SCHEMA, + "apply_stage": "after_tools", + "tools": _RESEARCH_TOOLS, + "mock_tool_responses": { + "get_country_stats": _country_stats_mock, + "get_climate_info": _climate_info_mock, + }, + "messages": [ + { + "role": "user", + "content": ( + "I'm preparing a short briefing on Norway. Please call the " + "get_country_stats and get_climate_info tools to gather data " + "first. Afterwards I'll ask for a structured summary." + ), + } + ], + "followup": ( + "Based on the tool results, produce the briefing as JSON matching the " + "country_report schema. Populate every required field and provide between " + "two and five highlights." 
+ ), + "validate": lambda parsed, tcs, raw: _validate_country_report(parsed, tcs), +} + + +def _validate_country_report(parsed, tcs): + names = [tc["function"]["name"] for tc in tcs] + for required_tool in ("get_country_stats", "get_climate_info"): + if required_tool not in names: + return False, f"missing tool call {required_tool!r}: got {names}" + required = { + "country", "capital", "population", + "climate_summary", "highlights", "suitable_for_tourism", + } + missing = required - parsed.keys() + if missing: + return False, f"missing report fields: {missing}" + if "norway" not in parsed["country"].lower(): + return False, f"country should reference Norway: {parsed['country']!r}" + if "oslo" not in parsed["capital"].lower(): + return False, f"capital should be Oslo: {parsed['capital']!r}" + if not isinstance(parsed["population"], int) or parsed["population"] < 1_000_000: + return False, f"population implausible: {parsed['population']!r}" + if not isinstance(parsed["climate_summary"], str) or not parsed["climate_summary"]: + return False, "climate_summary must be a non-empty string" + hls = parsed["highlights"] + if not isinstance(hls, list) or not (2 <= len(hls) <= 5): + return False, f"highlights length out of range: {hls!r}" + if not all(isinstance(h, str) and h for h in hls): + return False, "each highlight must be a non-empty string" + if not isinstance(parsed["suitable_for_tourism"], bool): + return False, f"suitable_for_tourism must be bool: {parsed['suitable_for_tourism']!r}" + return True, ( + f"tools={names}; report for {parsed['country']} " + f"(pop {parsed['population']}, {len(hls)} highlights)" + ) + + +# --------------------------------------------------------------------------- +# All test cases +# --------------------------------------------------------------------------- + +ALL_TEST_CASES = [ + BOOK_TEST_CASE, + SENTIMENT_TEST_CASE, + RECIPE_TEST_CASE, + SHOP_COMPARISON_TEST_CASE, + COUNTRY_REPORT_TEST_CASE, +] + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser( + description="Test llama-server structured-output capability." 
+ ) + parser.add_argument("--host", default="localhost") + parser.add_argument("--port", default=8080, type=int) + parser.add_argument( + "--no-stream", action="store_true", help="Disable streaming mode tests" + ) + parser.add_argument( + "--stream-only", action="store_true", help="Only run streaming mode tests" + ) + parser.add_argument( + "--test", + help="Run only the test whose name contains this substring (case-insensitive)", + ) + args = parser.parse_args() + + url = f"http://{args.host}:{args.port}/v1/chat/completions" + print_info(f"Testing server at {url}") + + modes: list[bool] = [] + if not args.stream_only: + modes.append(False) + if not args.no_stream: + modes.append(True) + + cases: list[dict] = ALL_TEST_CASES + if args.test: + name_filter = args.test.lower() + cases = [c for c in cases if name_filter in str(c["name"]).lower()] + if not cases: + print_fail(f"No test cases matched '{args.test}'") + sys.exit(1) + + total = 0 + passed = 0 + for stream in modes: + for case in cases: + total += 1 + if run_test(url, case, stream=stream): + passed += 1 + + color = GREEN if passed == total else RED + _print(f"\n{BOLD}{color}{'─' * 60}{RESET}") + _print(f"{BOLD}{color} Results: {passed}/{total} passed{RESET}") + _print(f"{BOLD}{color}{'─' * 60}{RESET}\n") + sys.exit(0 if passed == total else 1) + + +if __name__ == "__main__": + main() diff --git a/tools/cli/cli.cpp b/tools/cli/cli.cpp index 5bdb1e78f..5136e52a7 100644 --- a/tools/cli/cli.cpp +++ b/tools/cli/cli.cpp @@ -207,6 +207,8 @@ struct cli_context { auto meta = ctx_server.get_meta(); auto & chat_params = meta.chat_params; + auto caps = common_chat_templates_get_caps(chat_params.tmpls.get()); + common_chat_templates_inputs inputs; inputs.messages = common_chat_msgs_parse_oaicompat(messages); inputs.tools = {}; // TODO @@ -214,7 +216,7 @@ struct cli_context { inputs.json_schema = ""; // TODO inputs.grammar = ""; // TODO inputs.use_jinja = chat_params.use_jinja; - inputs.parallel_tool_calls = false; + inputs.parallel_tool_calls = caps["supports_parallel_tool_calls"]; inputs.add_generation_prompt = true; inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; inputs.force_pure_content = chat_params.force_pure_content; diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 18a317e1d..ad8834e31 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1027,6 +1027,8 @@ json oaicompat_chat_params_parse( } } + auto caps = common_chat_templates_get_caps(opt.tmpls.get()); + common_chat_templates_inputs inputs; inputs.messages = common_chat_msgs_parse_oaicompat(messages); inputs.tools = common_chat_tools_parse_oaicompat(tools); @@ -1034,7 +1036,7 @@ json oaicompat_chat_params_parse( inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); inputs.grammar = grammar; inputs.use_jinja = opt.use_jinja; - inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); + inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", caps["supports_parallel_tool_calls"]); inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); inputs.reasoning_format = opt.reasoning_format; if (body.contains("reasoning_format")) {