Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.devops/openvino.Dockerfile
#	.github/workflows/build-self-hosted.yml
#	.github/workflows/build.yml
#	common/chat.cpp
#	docs/backend/OPENVINO.md
#	examples/speculative-simple/speculative-simple.cpp
#	ggml/src/ggml-hexagon/ggml-hexagon.cpp
#	ggml/src/ggml-hexagon/htp/CMakeLists.txt
#	ggml/src/ggml-hexagon/htp/htp-ctx.h
#	ggml/src/ggml-hexagon/htp/htp-ops.h
#	ggml/src/ggml-hexagon/htp/main.c
#	ggml/src/ggml-hexagon/libggml-htp.inf
#	ggml/src/ggml-openvino/ggml-decoder.cpp
#	ggml/src/ggml-openvino/ggml-openvino-extra.cpp
#	ggml/src/ggml-openvino/ggml-openvino.cpp
#	ggml/src/ggml-openvino/ggml-quants.cpp
#	ggml/src/ggml-openvino/openvino/op/rope.cpp
#	ggml/src/ggml-openvino/openvino/op_table.cpp
#	ggml/src/ggml-openvino/openvino/op_table.h
#	ggml/src/ggml-openvino/openvino/translate_session.cpp
#	ggml/src/ggml-openvino/openvino/utils.cpp
#	ggml/src/ggml-openvino/openvino/utils.h
#	ggml/src/ggml-openvino/utils.cpp
#	ggml/src/ggml-openvino/utils.h
#	ggml/src/ggml-sycl/common.hpp
#	ggml/src/ggml-sycl/convert.cpp
#	ggml/src/ggml-sycl/convert.hpp
#	ggml/src/ggml-sycl/gemm.hpp
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	ggml/src/ggml-sycl/set_rows.cpp
#	ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
#	ggml/src/ggml-webgpu/ggml-webgpu.cpp
#	scripts/sync_vendor.py
#	tests/CMakeLists.txt
#	tests/test-chat.cpp
#	tools/cli/cli.cpp
#	tools/mtmd/CMakeLists.txt
#	tools/server/CMakeLists.txt
This commit is contained in:
Concedo 2026-04-23 00:55:05 +08:00
commit 0755f27372
42 changed files with 1531 additions and 3199 deletions

View file

@ -3124,14 +3124,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"token budget for thinking: -1 for unrestricted, 0 for immediate end, N>0 for token budget (default: -1)",
[](common_params & params, int value) {
if (value < -1) { throw std::invalid_argument("invalid value"); }
params.reasoning_budget = value;
params.sampling.reasoning_budget_tokens = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET"));
add_opt(common_arg(
{"--reasoning-budget-message"}, "MESSAGE",
"message injected before the end-of-thinking tag when reasoning budget is exhausted (default: none)",
[](common_params & params, const std::string & value) {
params.reasoning_budget_message = value;
params.sampling.reasoning_budget_message = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
add_opt(common_arg(
@ -3904,6 +3904,17 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--spec-default"},
string_format("enable default speculative decoding config"),
[](common_params & params) {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
params.speculative.ngram_size_n = 24;
params.speculative.n_min = 48;
params.speculative.n_max = 64;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
return ctx_arg;
}

View file

@ -408,6 +408,25 @@ json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msg
return render_message_to_json(msgs, c);
}
json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
if (tools.empty()) {
return json();
}
auto result = json::array();
for (const auto & tool : tools) {
result.push_back({
{ "type", "function" },
{ "function", {
{ "name", tool.name },
{ "description", tool.description },
{ "parameters", json::parse(tool.parameters) },
}},
});
}
return result;
}
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
std::vector<common_chat_tool> result;
@ -443,60 +462,9 @@ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & too
return result;
}
json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
if (tools.empty()) {
return json();
}
auto result = json::array();
for (const auto & tool : tools) {
result.push_back({
{ "type", "function" },
{ "function",
{
{ "name", tool.name },
{ "description", tool.description },
{ "parameters", json::parse(tool.parameters) },
} },
});
}
return result;
}
json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
json delta = json::object();
if (!diff.reasoning_content_delta.empty()) {
delta["reasoning_content"] = diff.reasoning_content_delta;
}
if (!diff.content_delta.empty()) {
delta["content"] = diff.content_delta;
}
if (diff.tool_call_index != std::string::npos) {
json tool_call;
tool_call["index"] = diff.tool_call_index;
if (!diff.tool_call_delta.id.empty()) {
tool_call["id"] = diff.tool_call_delta.id;
tool_call["type"] = "function";
}
if (!diff.tool_call_delta.name.empty() || !diff.tool_call_delta.arguments.empty()) {
json function = json::object();
if (!diff.tool_call_delta.name.empty()) {
function["name"] = diff.tool_call_delta.name;
}
if (!diff.tool_call_delta.arguments.empty()) {
function["arguments"] = diff.tool_call_delta.arguments;
}
tool_call["function"] = function;
}
delta["tool_calls"] = json::array({ tool_call });
}
return delta;
}
#include "common/unicode.h"
#include "peg-parser.cpp"
#include "chat-peg-parser.cpp"
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
if (use_jinja) {
try {

View file

@ -256,14 +256,13 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates *
// Parses a JSON array of messages in OpenAI's chat completion API format.
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const nlohmann::ordered_json & messages);
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools);
// DEPRECATED: only used in tests
nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools);
nlohmann::ordered_json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
nlohmann::ordered_json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
// get template caps, useful for reporting to server /props endpoint
std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates);

View file

@ -275,6 +275,7 @@ struct common_params_sampling {
std::vector<llama_token> reasoning_budget_start; // start tag token sequence
std::vector<llama_token> reasoning_budget_end; // end tag token sequence
std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)
std::string reasoning_budget_message; // message injected before end tag when budget exhausted
bool backend_sampling = false;
@ -582,8 +583,6 @@ struct common_params {
bool force_pure_content_parser = false;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
int enable_reasoning = -1; // -1 = auto, 0 = disable, 1 = enable
int reasoning_budget = -1;
std::string reasoning_budget_message; // message injected before end tag when budget exhausted
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time

View file

@ -749,6 +749,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
mod.reset();
n_low = 0;
i_last = 0;
}
} else {
n_low = 0;

View file

@ -11855,7 +11855,7 @@ class LLaDAMoEModel(TextModel):
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("HunYuanDenseV1ForCausalLM", "HunYuanVLForConditionalGeneration")
@ModelBase.register("HunYuanDenseV1ForCausalLM")
class HunYuanModel(TextModel):
model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE
@ -11994,28 +11994,58 @@ class HunYuanModel(TextModel):
@ModelBase.register("HunYuanVLForConditionalGeneration")
class HunyuanOCRVisionModel(MmprojModel):
class HunyuanVLVisionModel(MmprojModel):
# Handles both HunyuanOCR and HunyuanVL, which share the HF architecture name
# "HunYuanVLForConditionalGeneration" and the `vit.perceive.*` vision layout.
# Each variant maps to a different projector type in clip.cpp so image
# preprocessing follows the correct code path.
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
# HunyuanOCR uses max_image_size instead of image_size
# HunyuanOCR / HunyuanVL uses max_image_size instead of image_size
if "image_size" not in self.hparams_vision:
self.hparams_vision["image_size"] = self.hparams_vision.get("max_image_size", 2048)
@staticmethod
def is_ocr_variant(hparams: dict) -> bool:
"""Return True for HunyuanOCR, False for HunyuanVL.
The projector's output dim must equal the text model's hidden_size by
construction (that's what "projector" means). HunyuanOCR pairs a 1B text
backbone (hidden=1024); HunyuanVL pairs a 4B one (hidden=3072). So the
ViT -> LLM projection dim is a hard architectural signature, not a
magic number.
"""
vision_out = int((hparams.get("vision_config") or {}).get("out_hidden_size", 0))
return vision_out == 1024
def set_gguf_parameters(self):
super().set_gguf_parameters()
assert self.hparams_vision is not None
hparams = self.hparams_vision
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANOCR)
self.gguf_writer.add_vision_use_gelu(True)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("rms_norm_eps", 1e-5))
self.gguf_writer.add_vision_spatial_merge_size(hparams.get("spatial_merge_size", 2))
self.gguf_writer.add_vision_min_pixels(self.preprocessor_config["min_pixels"])
self.gguf_writer.add_vision_max_pixels(self.preprocessor_config["max_pixels"])
vcfg = self.hparams_vision
if self.is_ocr_variant(self.global_config):
# --- HunyuanOCR ---
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANOCR)
self.gguf_writer.add_vision_use_gelu(True)
self.gguf_writer.add_vision_attention_layernorm_eps(vcfg.get("rms_norm_eps", 1e-5))
self.gguf_writer.add_vision_spatial_merge_size(vcfg.get("spatial_merge_size", 2))
self.gguf_writer.add_vision_min_pixels(self.preprocessor_config["min_pixels"])
self.gguf_writer.add_vision_max_pixels(self.preprocessor_config["max_pixels"])
return
# --- HunyuanVL ---
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANVL)
self.gguf_writer.add_vision_use_gelu(str(vcfg["hidden_act"]).lower() == "gelu")
self.gguf_writer.add_vision_attention_layernorm_eps(float(vcfg["rms_norm_eps"]))
self.gguf_writer.add_vision_spatial_merge_size(int(vcfg["spatial_merge_size"]))
self.gguf_writer.add_vision_min_pixels(int(self.preprocessor_config["min_pixels"]))
self.gguf_writer.add_vision_max_pixels(int(self.preprocessor_config["max_pixels"]))
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if not name.startswith("vit."):
return # skip text tensors
return
# strip CLS token (row 0) from position embeddings so resize_position_embeddings works
if "position_embedding" in name:
data_torch = data_torch[1:] # [n_patches+1, n_embd] -> [n_patches, n_embd]
@ -12023,11 +12053,66 @@ class HunyuanOCRVisionModel(MmprojModel):
def tensor_force_quant(self, name, new_name, bid, n_dims):
# force conv weights to F32 or F16 to avoid BF16 IM2COL issues on Metal
# Both HunyuanOCR and HunyuanVL emit the ViT -> LLM projection as mm.0/mm.2.
if ("mm.0." in new_name or "mm.2." in new_name) and new_name.endswith(".weight"):
return gguf.GGMLQuantizationType.F16 if self.ftype == gguf.LlamaFileType.MOSTLY_F16 else gguf.GGMLQuantizationType.F32
return super().tensor_force_quant(name, new_name, bid, n_dims)
@ModelBase.register("HunYuanVLForConditionalGeneration")
class HunyuanVLTextModel(HunYuanModel):
# The "HunYuanVLForConditionalGeneration" HF architecture covers both HunyuanOCR
# and HunyuanVL. HunyuanOCR reuses the HunYuan-Dense text backbone (standard RoPE),
# while HunyuanVL introduces a new LLM arch with XD-RoPE. Detect the variant from
# the config and pick the matching GGUF architecture.
model_arch = gguf.MODEL_ARCH.HUNYUAN_VL
@staticmethod
def _is_ocr_config(hparams: dict) -> bool:
# OCR pairs a 1B text backbone (hidden=1024) with a ViT projector that
# outputs 1024-d; HunyuanVL uses 3072-d. Keep in sync with
# HunyuanVLVisionModel.is_ocr_variant.
return int((hparams.get("vision_config") or {}).get("out_hidden_size", 0)) == 1024
def __init__(self, dir_model: Path, *args, **kwargs):
raw_hparams = kwargs.get("hparams") or ModelBase.load_hparams(dir_model, is_mistral_format=False)
if self._is_ocr_config(raw_hparams):
self.model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE
else:
self.model_arch = gguf.MODEL_ARCH.HUNYUAN_VL
super().__init__(dir_model, *args, **kwargs)
def set_gguf_parameters(self):
super().set_gguf_parameters()
# Only emit XD-RoPE metadata for the HunyuanVL backbone; HunyuanOCR uses
# the HunYuan-Dense arch which already handles standard rope in super().
if self.model_arch != gguf.MODEL_ARCH.HUNYUAN_VL:
return
if self.rope_parameters.get("rope_type") != "xdrope":
return
# defaults for HunyuanVL. The C++ side later computes:
# freq_base = rope_theta * alpha ** (head_dim / (head_dim - 2))
self.gguf_writer.add_rope_freq_base(float(self.rope_parameters["rope_theta"]))
self.gguf_writer.add_rope_scaling_alpha(float(self.rope_parameters["alpha"]))
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_rope_scaling_factor(float(self.rope_parameters.get("factor", 1)))
ctx_len = int(self.hparams["max_position_embeddings"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(ctx_len)
self.gguf_writer.add_context_length(ctx_len)
self.gguf_writer.add_rope_dimension_sections(list(self.rope_parameters["xdrope_section"]))
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Skip vision tensors — they are written by HunyuanVLVisionModel
if name.startswith("vit."):
return
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("SmolLM3ForCausalLM")
class SmolLM3Model(LlamaModel):
model_arch = gguf.MODEL_ARCH.SMOLLM3

View file

@ -1,58 +0,0 @@
# Snapdragon-based Linux devices
## Docker Setup
The easiest way to build llama.cpp for a Snapdragon-based Linux device is using the toolchain Docker image (see [github.com/snapdragon-toolchain](https://github.com/snapdragon-toolchain)).
This image includes OpenCL SDK, Hexagon SDK, CMake, and the ARM64 Linux cross-compilation toolchain.
Cross-compilation is supported on **Linux X86** hosts. The resulting binaries are deployed to and run on the target **Qualcomm Snapdragon ARM64 Linux** device.
```
~/src/llama.cpp$ docker run -it -u $(id -u):$(id -g) --volume $(pwd):/workspace --platform linux/amd64 ghcr.io/snapdragon-toolchain/arm64-linux:v0.1
[d]/> cd /workspace
```
Note: The rest of the **Linux** build process assumes that you're running inside the toolchain container.
## How to Build
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
```
[d]/workspace> cp docs/backend/snapdragon/CMakeUserPresets.json .
[d]/workspace> cmake --preset arm64-linux-snapdragon-release -B build-snapdragon
[d]/workspace> cmake --build build-snapdragon -j $(nproc)
```
To generate an installable "package", simply use `cmake --install` and then zip the result:
```
[d]/workspace> cmake --install build-snapdragon --prefix pkg-snapdragon
[d]/workspace> zip -r pkg-snapdragon.zip pkg-snapdragon
```
## How to Install
For this step, you will deploy the built binaries and libraries to the target Linux device. Transfer `pkg-snapdragon.zip` to the target device, then unzip it and set up the environment variables:
```
$ unzip pkg-snapdragon.zip
$ cd pkg-snapdragon
$ export LD_LIBRARY_PATH=./lib
$ export ADSP_LIBRARY_PATH=./lib
```
At this point, you should also download some models onto the device:
```
$ wget https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q4_0.gguf
```
## How to Run
Next, since we have set up the environment variables, we can run `llama-cli` with the Hexagon backend:
```
$ ./bin/llama-cli -m Llama-3.2-3B-Instruct-Q4_0.gguf --device HTP0 -ngl 99 -p "what is the most popular cookie in the world?"
```

View file

@ -1,158 +0,0 @@
#pragma clang diagnostic ignored "-Wunused-function"
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include <qurt_thread.h>
#include <qurt_futex.h>
#include <HAP_compute_res.h>
#include "hmx-queue.h"
#define QURT_LOWEST_PRIO (254)
// Acquire the HMX hardware lock for this queue's HAP resource context.
// Idempotent: if the lock is already held by this queue, do nothing.
static inline void hmx_lock(struct hmx_queue *q)
{
    if (q->hmx_locked) {
        return; // already held
    }
    HAP_compute_res_hmx_lock(q->hap_rctx);
    q->hmx_locked = true;
}
// Release the HMX hardware lock if this queue currently holds it.
// Idempotent: calling while unlocked is a no-op.
static inline void hmx_unlock(struct hmx_queue *q)
{
    if (!q->hmx_locked) {
        return; // nothing to release
    }
    HAP_compute_res_hmx_unlock(q->hap_rctx);
    q->hmx_locked = false;
}
// Drain every descriptor in [idx_read, idx_write). Runs on the worker
// thread only. Descriptor func values that fall in the hmx_queue_signal
// range are control commands (NOOP / KILL / SUSPEND); anything else is a
// real callback invoked with the HMX lock held (hmx_lock is taken lazily
// on the first work item). After handling a slot, d->done is incremented
// so hmx_queue_pop() can observe completion, then idx_read is published.
static inline void hmx_queue_process(struct hmx_queue *q, bool* killed) {
    unsigned int ir = atomic_load(&q->idx_read);
    while (ir != atomic_load(&q->idx_write)) {
        struct hmx_queue_desc *d = &q->desc[ir];
        // Slots already marked done are skipped.
        // NOTE(review): unclear from this file when a not-yet-consumed slot
        // can already be done — confirm whether this guards re-entry.
        if (!d->done) {
            FARF(HIGH, "hmx-queue-process: ir %u func %p data %p", ir, d->func, d->data);
            // Small integer "func pointers" encode control signals.
            enum hmx_queue_signal sig = (enum hmx_queue_signal) (unsigned int) d->func;
            switch (sig) {
                case HMX_QUEUE_NOOP: /* noop */; break;
                case HMX_QUEUE_KILL: *killed = true; break;     // request thread exit
                case HMX_QUEUE_SUSPEND: hmx_unlock(q); break;   // drop the HMX lock
                default:
                    // Real work item: ensure the HMX unit is locked, then run it.
                    hmx_lock(q);
                    d->func(d->data);
                    break;
            }
            // Mark the slot complete for hmx_queue_pop().
            atomic_fetch_add(&d->done, 1);
        }
        ir = (ir + 1) & q->idx_mask;
        atomic_store(&q->idx_read, ir);
    }
}
// Worker thread entry point. Watches q->seqn for new pushes: spins for up
// to HMX_QUEUE_POLL_COUNT iterations after the last observed change, then
// parks on a futex keyed to q->seqn until a producer wakes it. Work is
// drained in batches via hmx_queue_process(); the loop exits once a KILL
// signal sets `killed`.
static void hmx_queue_thread(void * arg) {
    struct hmx_queue * q = (struct hmx_queue *) arg;
    FARF(HIGH, "hmx-queue-thread: started");
    bool killed = false;
    unsigned int poll_cnt = HMX_QUEUE_POLL_COUNT;
    unsigned int prev_seqn = 0;
    while (!killed) {
        unsigned int seqn = atomic_load(&q->seqn);
        if (seqn == prev_seqn) {
            // No new work: spin a bit before sleeping.
            // NOTE(review): after a spurious futex wakeup poll_cnt is 0 and
            // --poll_cnt wraps to UINT_MAX, extending the spin phase — confirm
            // whether this is intentional.
            if (--poll_cnt) { hex_pause(); continue; }
            FARF(HIGH, "hmx-queue-thread: sleeping");
            // Futex re-checks seqn against prev_seqn, so a push racing with
            // this call still wakes us.
            qurt_futex_wait(&q->seqn, prev_seqn);
            continue;
        }
        prev_seqn = seqn;
        poll_cnt = HMX_QUEUE_POLL_COUNT;  // reset the spin budget
        FARF(HIGH, "hmx-queue-thread: new work");
        hmx_queue_process(q, &killed);
    }
    FARF(HIGH, "hmx-queue-thread: stopped");
}
// Allocate and start an HMX work queue.
//
// capacity is rounded up to the next power of two so index arithmetic can
// use a simple mask. hap_rctx is the HAP compute resource context used by
// the worker thread for HMX lock/unlock.
//
// Returns NULL on any allocation or thread-creation failure. Fix over the
// original: every failure path now releases the partially-built queue
// (previously q, q->desc and q->stack leaked on the later failure paths).
struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx) {
    capacity = hex_ceil_pow2(capacity);
    struct hmx_queue * q = (struct hmx_queue *) memalign(32, sizeof(struct hmx_queue));
    if (q == NULL) {
        FARF(ERROR, "%s: failed to allocate HMX queue\n", __FUNCTION__);
        return NULL;
    }
    memset(q, 0, sizeof(struct hmx_queue));
    q->capacity = capacity;
    q->idx_mask = capacity - 1;
    q->hap_rctx = hap_rctx;
    q->desc = (struct hmx_queue_desc *) memalign(64, capacity * sizeof(struct hmx_queue_desc));
    if (!q->desc) {
        FARF(ERROR, "hmx-queue: failed to allocate HMX queue descriptors\n");
        free(q); // was leaked here before
        return NULL;
    }
    memset(q->desc, 0, capacity * sizeof(struct hmx_queue_desc));
    const size_t stack_size = HMX_QUEUE_THREAD_STACK_SIZE;
    q->stack = (unsigned char *) memalign(64, stack_size);
    if (!q->stack) {
        FARF(ERROR, "hmx-queue: thread stack allocation failed (%zu bytes)", stack_size);
        free(q->desc); // was leaked here before
        free(q);
        return NULL;
    }
    memset(q->stack, 0, stack_size);
    // Match caller thread priority (same pattern as worker-pool.c).
    int prio = qurt_thread_get_priority(qurt_thread_get_id());
    if (prio < 1) {
        prio = 1;
    }
    if (prio > QURT_LOWEST_PRIO) {
        prio = QURT_LOWEST_PRIO;
    }
    qurt_thread_attr_t attr;
    qurt_thread_attr_init(&attr);
    qurt_thread_attr_set_stack_addr(&attr, q->stack);
    qurt_thread_attr_set_stack_size(&attr, stack_size);
    qurt_thread_attr_set_priority(&attr, prio);
    qurt_thread_attr_set_name(&attr, "hmx-queue");
    int err = qurt_thread_create(&q->thread, &attr, hmx_queue_thread, q);
    if (err) {
        FARF(ERROR, "hmx-queue: thread create failed (%d)", err);
        free(q->stack); // was leaked here before
        free(q->desc);
        free(q);
        return NULL;
    }
    // %zu: capacity is a size_t (the original used %u).
    FARF(HIGH, "hmx-queue: capacity %zu\n", capacity);
    return q;
}
// Destroy the queue: drain outstanding work, ask the worker thread to
// exit, wait for it to consume the KILL signal, join the thread, then
// release all memory. Safe to call with q == NULL.
void hmx_queue_delete(struct hmx_queue * q) {
    if (!q) {
        return;
    }
    // Tell the worker to exit.
    hmx_queue_flush(q);                   // wait for in-flight work to finish
    hmx_queue_signal(q, HMX_QUEUE_KILL);  // queue the exit request
    hmx_queue_flush(q);                   // wait until KILL has been consumed
    int status;
    qurt_thread_join(q->thread, &status);
    free(q->desc);
    free(q->stack);
    free(q);
}

View file

@ -1,134 +0,0 @@
#ifndef HMX_QUEUE_H
#define HMX_QUEUE_H
#include <stdbool.h>
#include <stdint.h>
#include <stdatomic.h>
#include <hexagon_types.h>
#include <qurt_thread.h>
#include <qurt_futex.h>
#include <HAP_farf.h>
#include "hex-utils.h"
#ifdef __cplusplus
extern "C" {
#endif
#define HMX_QUEUE_THREAD_STACK_SIZE (16 * 1024)
#define HMX_QUEUE_POLL_COUNT 2000
typedef void (*hmx_queue_func)(void *);
// Dummy funcs used as signals
// Dummy funcs used as signals: these small integer values are stored in a
// descriptor's func field (cast to hmx_queue_func) and are interpreted by
// the worker as control commands instead of callable work items.
enum hmx_queue_signal {
    HMX_QUEUE_NOOP = 0, // aka NULL
    HMX_QUEUE_SUSPEND,  // worker releases the HMX hardware lock
    HMX_QUEUE_KILL      // worker thread exits its loop
};
// One ring slot: a callback, its argument, and a completion counter.
// done is nonzero once the worker has handled the entry.
struct hmx_queue_desc {
    hmx_queue_func func;
    void * data;
    atomic_uint done;
};
// Ring of descriptors plus worker-thread state. Indices wrap via idx_mask
// (capacity is a power of two). idx_write/idx_pop are advanced on the
// producer side, idx_read on the worker side.
struct hmx_queue {
    struct hmx_queue_desc * desc;  // ring storage, `capacity` entries
    atomic_uint idx_write; // updated by producer (push)
    atomic_uint idx_read;  // updated by consumer (process)
    unsigned int idx_pop;  // updated by producer (pop)
    uint32_t idx_mask;     // capacity - 1, for index wrap-around
    uint32_t capacity;     // power-of-two ring size
    atomic_uint seqn;      // incremented for all pushes, used with futex
    qurt_thread_t thread;  // worker thread handle
    void * stack;          // worker thread stack
    uint32_t hap_rctx;     // HAP compute resource context for HMX lock/unlock
    bool hmx_locked;       // worker-side flag: HMX lock currently held
};
struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx);
void hmx_queue_delete(struct hmx_queue * q);
// Build a work descriptor that will invoke func(data); the completion
// counter starts at zero (implicitly zero-initialized).
static inline struct hmx_queue_desc hmx_queue_make_desc(hmx_queue_func func, void * data) {
    struct hmx_queue_desc d = { .func = func, .data = data };
    return d;
}
// Enqueue a descriptor and wake the worker. Returns false when the ring is
// full (one slot is always left empty so full can be distinguished from
// empty). NOTE(review): idx_write is read and written non-atomically here,
// which presumes a single producer thread — confirm against callers.
static inline bool hmx_queue_push(struct hmx_queue * q, struct hmx_queue_desc d) {
    unsigned int ir = atomic_load(&q->idx_read);
    unsigned int iw = q->idx_write;
    if (((iw + 1) & q->idx_mask) == ir) {
        FARF(HIGH, "hmx-queue-push: queue is full\n");
        return false;
    }
    // Reset the completion flag before the slot becomes visible.
    atomic_store(&d.done, 0);
    FARF(HIGH, "hmx-queue-push: iw %u func %p data %p\n", iw, d.func, d.data);
    q->desc[iw] = d;
    // Publish the slot, then bump seqn and wake the worker.
    atomic_store(&q->idx_write, (iw + 1) & q->idx_mask);
    // wake up our thread
    atomic_fetch_add(&q->seqn, 1);
    qurt_futex_wake(&q->seqn, 1);
    return true;
}
// Push a control signal (NOOP/SUSPEND/KILL) encoded as a fake func
// pointer; returns false if the ring is full.
static inline bool hmx_queue_signal(struct hmx_queue *q, enum hmx_queue_signal sig) {
    return hmx_queue_push(q, hmx_queue_make_desc((hmx_queue_func) sig, NULL));
}
// Producer-side emptiness: true when pop has caught up with write.
// Does not consult the worker's idx_read cursor.
static inline bool hmx_queue_empty(struct hmx_queue * q) {
    return q->idx_pop == q->idx_write;
}
// Number of descriptors pushed but not yet consumed by the worker
// (distance from idx_read to idx_write, modulo capacity).
// BUG FIX: the original computed (idx_read - idx_read), which is always 0.
// The intended operands are idx_write and idx_read (matching the fullness
// test in hmx_queue_push).
static inline uint32_t hmx_queue_depth(struct hmx_queue * q) {
    return (q->idx_write - q->idx_read) & q->idx_mask;
}
// Ring capacity (power of two; note hmx_queue_push keeps one slot free).
static inline uint32_t hmx_queue_capacity(struct hmx_queue * q) {
    return q->capacity;
}
// Reap the oldest descriptor the producer has not yet popped.
// Returns a {NULL, NULL} descriptor when everything pushed has already
// been popped. Otherwise busy-waits (hex_pause) until the worker marks
// the slot done, then returns a copy and advances idx_pop.
// Producer-side only: idx_pop is not atomic.
static inline struct hmx_queue_desc hmx_queue_pop(struct hmx_queue * q) {
    unsigned int ip = q->idx_pop;
    unsigned int iw = q->idx_write;
    struct hmx_queue_desc rd = { NULL, NULL };
    if (ip == iw) {
        return rd; // nothing outstanding
    }
    // Wait for desc to complete
    struct hmx_queue_desc * d = &q->desc[ip];
    while (!atomic_load(&d->done)) {
        FARF(HIGH, "hmx-queue-pop: waiting for HMX queue : %u\n", ip);
        hex_pause();
    }
    rd = *d;
    q->idx_pop = (ip + 1) & q->idx_mask;
    FARF(HIGH, "hmx-queue-pop: ip %u func %p data %p\n", ip, rd.func, rd.data);
    return rd;
}
// Block until the queue is drained: pop until a NULL func is returned.
// NOTE(review): a queued HMX_QUEUE_NOOP also has func == NULL, so a
// pushed NOOP would terminate the flush early — confirm this is intended.
static inline void hmx_queue_flush(struct hmx_queue * q) {
    while (hmx_queue_pop(q).func != NULL) ;
}
// Ask the worker to release the HMX hardware lock and wait until the
// request (and everything queued before it) has been processed.
static inline void hmx_queue_suspend(struct hmx_queue *q) {
    hmx_queue_signal(q, HMX_QUEUE_SUSPEND);
    hmx_queue_flush(q);
}
#ifdef __cplusplus
} // extern "C"
#endif
#endif /* HMX_QUEUE_H */

View file

@ -918,6 +918,10 @@ ggml_backend_reg_t ggml_backend_metal_reg(void) {
static std::vector<ggml_backend_device_ptr> devs;
if (!initialized) {
// workaround macOS limitation (kIOGPUCommandBufferCallbackErrorImpactingInteractivity) until proper fix becomes possible
// ref: https://github.com/ggml-org/llama.cpp/issues/20141#issuecomment-4272947703
setenv("AGX_RELAX_CDM_CTXSTORE_TIMEOUT", "1", true);
static ggml_backend_metal_reg_ptr reg_ctx(ggml_backend_metal_reg_init());
for (int i = 0; i < g_devices; ++i) {

View file

@ -1,176 +0,0 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK_K 256
#define K_SCALE_SIZE 12
// Extract the 6-bit scale (*d) and 6-bit min (*m) for sub-block j (0..7)
// from the packed K_SCALE_SIZE-byte scale array q of one Q5_K superblock.
// Sub-blocks 0..3 store their 6 bits directly (masked by mask_d6);
// sub-blocks 4..7 split each value across the nibbles of q[j+4] and the
// top two bits of q[j-4] / q[j]. The masks are supplied by the host —
// presumably mask_d6=0x3F, mask_d4=0x0F, mask_hi2=0xC0; confirm against
// the host-side kernel setup.
inline void get_scale_min_k4(
    int j,
    global const uchar * q,
    uchar * d,
    uchar * m,
    uchar mask_d6,
    uchar mask_d4,
    uchar mask_hi2
) {
    if (j < 4) {
        *d = q[j] & mask_d6;
        *m = q[j+4] & mask_d6;
    } else {
        *d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2);
        *m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2);
    }
}
// Q5_K x activations GEMM (non-shuffled weight layout). Each work-item
// accumulates a 4-row (gx) by 8-column (gy) tile of dst in half precision
// and writes it out guarded by n_no_padding. One outer-loop iteration
// consumes one 32-value Q5_K sub-block for 4 adjacent weight rows.
// NOTE(review): the exact interleaving of src0_* by m and the image
// addressing via n_4 are defined by the host-side repacking — confirm
// against the enqueue code.
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_128
#endif
kernel void kernel_gemm_noshuffle_q5_k_f32(
    global const ushort * src0_q,    // 4-bit quants, 4 per ushort
    global const uchar * src0_qh,    // high (5th) bits, 8 per uchar
    global const uchar * src0_s,     // scales/mins, K_SCALE_SIZE bytes per superblock
    global const half * src0_d,      // per-superblock scale delta
    global const half * src0_dm,     // per-superblock min delta
    read_only image1d_buffer_t src1, // activations, read as half4
    global float * dst,
    ulong offsetd,                   // byte offset into dst
    int m,
    int n,
    int k,
    int n_no_padding,                // unpadded column count; bounds the stores
    uchar mask_d6,
    uchar mask_d4,
    uchar mask_hi2
) {
    dst = (global float *)((global char *)dst + offsetd);
    int n_4 = n >> 2;
    int gy = get_global_id(0);  // output column group (8 columns)
    int gx = get_global_id(1);  // output row group (4 rows)
    int gx_2 = gx << 2;
    half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; // accumulators: 4 rows x 8 cols
    half8 B;
    half4 dequantized_weights;
    int num_blocks_K = k / QK_K;
    // Per-row base pointers; data is interleaved with stride m.
    global const ushort * weight_ptr = src0_q + gx_2;
    global const uchar * qh_ptr = src0_qh + gx_2;
    global const half * d_ptr = src0_d + gx_2;
    global const half * dm_ptr = src0_dm + gx_2;
    for (int i = 0; i < k; i += 32) { // one 32-value sub-block per iteration
        int sb_idx = i / QK_K;        // superblock index
        int sub_idx = (i / 32) % 8;   // sub-block within the superblock
        half4 d = vload4(0, d_ptr + sb_idx * m);
        half4 dm = vload4(0, dm_ptr + sb_idx * m);
        // Scale/min bytes for the 4 rows of this superblock.
        global const uchar * sc0 = src0_s + (gx_2+0) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE;
        global const uchar * sc1 = src0_s + (gx_2+1) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE;
        global const uchar * sc2 = src0_s + (gx_2+2) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE;
        global const uchar * sc3 = src0_s + (gx_2+3) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE;
        uchar sv0, mn0, sv1, mn1, sv2, mn2, sv3, mn3;
        get_scale_min_k4(sub_idx, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2);
        get_scale_min_k4(sub_idx, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2);
        get_scale_min_k4(sub_idx, sc2, &sv2, &mn2, mask_d6, mask_d4, mask_hi2);
        get_scale_min_k4(sub_idx, sc3, &sv3, &mn3, mask_d6, mask_d4, mask_hi2);
        // Effective scale and min per row: delta * 6-bit quantized value.
        half4 scale = convert_half4(convert_float4(d) * convert_float4((uchar4)(sv0, sv1, sv2, sv3)));
        half4 mval = convert_half4(convert_float4(dm) * convert_float4((uchar4)(mn0, mn1, mn2, mn3)));
        for (int l = 0; l < 32; l += 4) {
            int ki = i + l;
            // 16 packed 4-bit quants (4 per row) and the matching high bits.
            ushort4 bits4 = vload4(0, weight_ptr + (ki/4) * m);
            uchar4 qh_bits = vload4(0, qh_ptr + (ki/8) * m);
            int qh_shift = ki % 8;
            // Each j below dequantizes one nibble position (w = q4 | qh<<4,
            // then w * scale - min) for 4 rows and multiplies against the
            // 8-wide activation vector B.
            // j=0
            B.s0123 = read_imageh(src1, gy*2 + (ki+0) * n_4);
            B.s4567 = read_imageh(src1, gy*2+1 + (ki+0) * n_4);
            dequantized_weights.s0 = ((bits4.s0 & 0x000F) | (((qh_bits.s0 >> (qh_shift+0)) & 1) << 4)) * scale.s0 - mval.s0;
            dequantized_weights.s1 = ((bits4.s1 & 0x000F) | (((qh_bits.s1 >> (qh_shift+0)) & 1) << 4)) * scale.s1 - mval.s1;
            dequantized_weights.s2 = ((bits4.s2 & 0x000F) | (((qh_bits.s2 >> (qh_shift+0)) & 1) << 4)) * scale.s2 - mval.s2;
            dequantized_weights.s3 = ((bits4.s3 & 0x000F) | (((qh_bits.s3 >> (qh_shift+0)) & 1) << 4)) * scale.s3 - mval.s3;
            c0 += B * dequantized_weights.s0;
            c1 += B * dequantized_weights.s1;
            c2 += B * dequantized_weights.s2;
            c3 += B * dequantized_weights.s3;
            // j=1
            B.s0123 = read_imageh(src1, gy*2 + (ki+1) * n_4);
            B.s4567 = read_imageh(src1, gy*2+1 + (ki+1) * n_4);
            dequantized_weights.s0 = (((bits4.s0 & 0x00F0) >> 4) | (((qh_bits.s0 >> (qh_shift+1)) & 1) << 4)) * scale.s0 - mval.s0;
            dequantized_weights.s1 = (((bits4.s1 & 0x00F0) >> 4) | (((qh_bits.s1 >> (qh_shift+1)) & 1) << 4)) * scale.s1 - mval.s1;
            dequantized_weights.s2 = (((bits4.s2 & 0x00F0) >> 4) | (((qh_bits.s2 >> (qh_shift+1)) & 1) << 4)) * scale.s2 - mval.s2;
            dequantized_weights.s3 = (((bits4.s3 & 0x00F0) >> 4) | (((qh_bits.s3 >> (qh_shift+1)) & 1) << 4)) * scale.s3 - mval.s3;
            c0 += B * dequantized_weights.s0;
            c1 += B * dequantized_weights.s1;
            c2 += B * dequantized_weights.s2;
            c3 += B * dequantized_weights.s3;
            // j=2
            B.s0123 = read_imageh(src1, gy*2 + (ki+2) * n_4);
            B.s4567 = read_imageh(src1, gy*2+1 + (ki+2) * n_4);
            dequantized_weights.s0 = (((bits4.s0 & 0x0F00) >> 8) | (((qh_bits.s0 >> (qh_shift+2)) & 1) << 4)) * scale.s0 - mval.s0;
            dequantized_weights.s1 = (((bits4.s1 & 0x0F00) >> 8) | (((qh_bits.s1 >> (qh_shift+2)) & 1) << 4)) * scale.s1 - mval.s1;
            dequantized_weights.s2 = (((bits4.s2 & 0x0F00) >> 8) | (((qh_bits.s2 >> (qh_shift+2)) & 1) << 4)) * scale.s2 - mval.s2;
            dequantized_weights.s3 = (((bits4.s3 & 0x0F00) >> 8) | (((qh_bits.s3 >> (qh_shift+2)) & 1) << 4)) * scale.s3 - mval.s3;
            c0 += B * dequantized_weights.s0;
            c1 += B * dequantized_weights.s1;
            c2 += B * dequantized_weights.s2;
            c3 += B * dequantized_weights.s3;
            // j=3
            B.s0123 = read_imageh(src1, gy*2 + (ki+3) * n_4);
            B.s4567 = read_imageh(src1, gy*2+1 + (ki+3) * n_4);
            dequantized_weights.s0 = (((bits4.s0 & 0xF000) >> 12) | (((qh_bits.s0 >> (qh_shift+3)) & 1) << 4)) * scale.s0 - mval.s0;
            dequantized_weights.s1 = (((bits4.s1 & 0xF000) >> 12) | (((qh_bits.s1 >> (qh_shift+3)) & 1) << 4)) * scale.s1 - mval.s1;
            dequantized_weights.s2 = (((bits4.s2 & 0xF000) >> 12) | (((qh_bits.s2 >> (qh_shift+3)) & 1) << 4)) * scale.s2 - mval.s2;
            dequantized_weights.s3 = (((bits4.s3 & 0xF000) >> 12) | (((qh_bits.s3 >> (qh_shift+3)) & 1) << 4)) * scale.s3 - mval.s3;
            c0 += B * dequantized_weights.s0;
            c1 += B * dequantized_weights.s1;
            c2 += B * dequantized_weights.s2;
            c3 += B * dequantized_weights.s3;
        }
    }
    // Write the 4x8 tile column by column; each store is guarded so rows
    // in the padded region [n_no_padding, n) are dropped.
    int idx = (gy<<3)*m + (gx<<2);
    if (idx+3 < m*n_no_padding) {
        vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
        idx += m;
    }
    if (idx+3 < m*n_no_padding) {
        vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
        idx += m;
    }
    if (idx+3 < m*n_no_padding) {
        vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
        idx += m;
    }
    if (idx+3 < m*n_no_padding) {
        vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
        idx += m;
    }
    if (idx+3 < m*n_no_padding) {
        vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
        idx += m;
    }
    if (idx+3 < m*n_no_padding) {
        vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
        idx += m;
    }
    if (idx+3 < m*n_no_padding) {
        vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
        idx += m;
    }
    if (idx+3 < m*n_no_padding) {
        vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
    }
}

View file

@ -1,326 +0,0 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#endif
#define QK_K 256
#define NSUBGROUPS 4
#define SUBGROUP_SIZE 64
// Extract the 6-bit scale (*d) and 6-bit min (*m) for sub-block j (0..7)
// from the packed K_SCALE_SIZE-byte scale array q of one Q5_K superblock.
// Sub-blocks 0..3 store their 6 bits directly (masked by mask_d6);
// sub-blocks 4..7 split each value across the nibbles of q[j+4] and the
// top two bits of q[j-4] / q[j]. Mask values come from the host —
// presumably mask_d6=0x3F, mask_d4=0x0F, mask_hi2=0xC0; confirm against
// the host-side kernel setup.
inline void get_scale_min_k4(
    int j,
    global const uchar * q,
    uchar * d,
    uchar * m,
    uchar mask_d6,
    uchar mask_d4,
    uchar mask_hi2
) {
    if (j < 4) {
        *d = q[j] & mask_d6;
        *m = q[j+4] & mask_d6;
    } else {
        *d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2);
        *m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2);
    }
}
// Dot-product accumulation for one 32-weight chunk, scalar-broadcast variant,
// first half (activation values held by sub-group lanes 0 and 1).
// bits4: packed 4-bit quant nibbles for the fiber's two output rows (fields
// .s0/.s1, .s2/.s3, ... pair row 0 / row 1); bits1: per-weight high (5th)
// quant bit; scale/minv: per-row dequant scale and min; y: the lane's input
// activations. Each weight is rebuilt as (nibble | high_bit << 4) * scale -
// minv and multiplied by the activation broadcast from the owning lane via
// sub_group_broadcast. Declares `shared_y`, which the matching *_1_lo macro
// reuses — expand both in the same scope, this one first.
#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, bits1, scale, minv, y) \
float shared_y; \
shared_y = sub_group_broadcast(y.s0, 0); \
total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s0 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s1 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s1, 0); \
total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s2, 0); \
total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s3, 0); \
total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s4, 0); \
total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s5, 0); \
total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s6, 0); \
total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s7, 0); \
total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s0, 1); \
total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s2 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s3 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s1, 1); \
total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s2, 1); \
total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s3, 1); \
total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s4, 1); \
total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s5, 1); \
total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s6, 1); \
total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s7, 1); \
total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
// Second half of the scalar-broadcast accumulation: same weight
// reconstruction as *_1_hi but consumes the activations held by lanes 2 and
// 3 and the high-bit bytes bits1.s4..s7. Reuses the `shared_y` variable
// declared by dequantizeBlockAccum_ns_sgbroadcast_1_hi, so it must be
// expanded after that macro in the same scope.
#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, bits1, scale, minv, y) \
shared_y = sub_group_broadcast(y.s0, 2); \
total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s4 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s5 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s1, 2); \
total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s2, 2); \
total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s3, 2); \
total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s4, 2); \
total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s5, 2); \
total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s6, 2); \
total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s7, 2); \
total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s0, 3); \
total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s6 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s7 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s1, 3); \
total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s2, 3); \
total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s3, 3); \
total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s4, 3); \
total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s5, 3); \
total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s6, 3); \
total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
shared_y = sub_group_broadcast(y.s7, 3); \
total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
// Vector-broadcast variant of the accumulation (first half, source lanes 0
// and 1): broadcasts an entire float8 of activations per source lane instead
// of one scalar at a time, trading more register pressure for fewer
// sub_group_broadcast calls. Selected when VECTOR_SUB_GROUP_BROADCAST is
// defined. Declares `shared_y` (float8) for reuse by the matching *_8_lo
// macro; expand both in the same scope, this one first.
#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, bits1, scale, minv, y) \
float8 shared_y; \
shared_y = sub_group_broadcast(y, 0); \
total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s0 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \
total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \
total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \
total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \
total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \
total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \
total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \
total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \
total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s1 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \
total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \
total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \
total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \
total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \
total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \
total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \
total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \
shared_y = sub_group_broadcast(y, 1); \
total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s2 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \
total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \
total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \
total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \
total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \
total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \
total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \
total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \
total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s3 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \
total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \
total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \
total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \
total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \
total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \
total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \
total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \
// Second half of the vector-broadcast accumulation: source lanes 2 and 3,
// high-bit bytes bits1.s4..s7. Reuses the float8 `shared_y` declared by
// dequantizeBlockAccum_ns_sgbroadcast_8_hi, so it must be expanded after
// that macro in the same scope.
#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, bits1, scale, minv, y) \
shared_y = sub_group_broadcast(y, 2); \
total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s4 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \
total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \
total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \
total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \
total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \
total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \
total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \
total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \
total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s5 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \
total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \
total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \
total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \
total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \
total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \
total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \
total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \
shared_y = sub_group_broadcast(y, 3); \
total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s6 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \
total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \
total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \
total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \
total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \
total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \
total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \
total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \
total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s7 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \
total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \
total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \
total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \
total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \
total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \
total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \
total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \
// GEMV: multiplies a q5_K-quantized matrix in the deinterleaved "noshuffle"
// layout (streamed through image buffers) by an f32 vector.
// Launch shape: NSUBGROUPS (4) waves of SUBGROUP_SIZE (64) fibers per
// workgroup; each fiber accumulates and writes two consecutive output rows
// (vstore2 at dst[gid * 2]).
//   src0_q  - packed 4-bit quant nibbles (one uint per fiber per line)
//   src0_qh - packed high (5th) quant bits, read as half and bit-reinterpreted
//   src0_d  - super-block scales d    (half2: the fiber's two rows)
//   src0_m  - super-block mins dmin   (half2: the fiber's two rows)
//   src0_s  - packed 6-bit scales/mins, 12 bytes per 256-weight super-block
//   src1    - f32 input vector; dst - f32 output, shifted by offsetd bytes
//   ne00/ne01 - K (row length) and M (number of rows) of src0
//   mask_d6/mask_d4/mask_hi2 - bit masks consumed by get_scale_min_k4
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_gemv_noshuffle_q5_k_f32(
read_only image1d_buffer_t src0_q,
read_only image1d_buffer_t src0_qh,
global half2 * src0_d,
global half2 * src0_m,
global uchar * src0_s,
read_only image1d_buffer_t src1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
uchar mask_d6,
uchar mask_d4,
uchar mask_hi2)
{
// groupId: wave index within the workgroup; slid: lane within the sub-group.
uint groupId = get_local_id(1);
uint gid = get_global_id(0);
ushort slid = get_sub_group_local_id();
uint K = ne00;
uint M = ne01;
// Element strides between consecutive quant "lines" and 32-weight chunks.
uint LINE_STRIDE_A = M / 2;
uint BLOCK_STRIDE_A = NSUBGROUPS * M;
uint LINE_STRIDE_A_QH = M / 2;
uint BLOCK_STRIDE_A_QH = NSUBGROUPS * M / 2;
// 12 bytes of packed 6-bit scales/mins per 256-weight (QK_K) super-block.
uint scales_per_row = (K / QK_K) * 12;
private uint4 regA;
private ushort4 regH;
private half2 regS;
private half2 regM;
private float8 regB;
private float2 totalSum = (float2)(0.0f);
// Each wave strides over the row's 32-weight chunks: k is the chunk index,
// sb the enclosing super-block, j the sub-block within it.
for (uint k = groupId; k < (K / 32); k += NSUBGROUPS) {
uint sb = k / 8;
uint j = k % 8;
half2 d = src0_d[gid + sb * LINE_STRIDE_A];
half2 dm = src0_m[gid + sb * LINE_STRIDE_A];
// Packed scales for the fiber's two output rows (2*gid and 2*gid+1).
global const uchar * sc0 = src0_s + 2 * gid * scales_per_row + sb * 12;
global const uchar * sc1 = src0_s + (2 * gid + 1) * scales_per_row + sb * 12;
uchar sv0, mn0, sv1, mn1;
get_scale_min_k4(j, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2);
get_scale_min_k4(j, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2);
// Effective per-row dequant scale (d * scale) and min (dmin * min).
regS = convert_half2(convert_float2(d) * convert_float2((uchar2)(sv0, sv1)));
regM = convert_half2(convert_float2(dm) * convert_float2((uchar2)(mn0, mn1)));
// Only the first 4 fibers load activations; the accumulation macros
// distribute them to the rest of the wave via sub_group_broadcast.
if (slid < 4) {
regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));
regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));
}
// High (5th) quant bits for the chunk's four lines.
regH.s0 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 0)).x);
regH.s1 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 1)).x);
regH.s2 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 2)).x);
regH.s3 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 3)).x);
// Packed nibbles for lines 0-3, then accumulate the first chunk half.
regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
#ifdef VECTOR_SUB_GROUP_BROADCAST
dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB);
#else
dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB);
#endif // VECTOR_SUB_GROUP_BROADCAST
// Lines 4-7: second half of the chunk.
regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
#ifdef VECTOR_SUB_GROUP_BROADCAST
dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB);
#else
dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB);
#endif // VECTOR_SUB_GROUP_BROADCAST
}
// reduction in local memory, assumes #wave=4
// Waves 1-3 stash their partial sums; after the barrier, wave 0 adds them.
local float2 reduceLM[SUBGROUP_SIZE * 3];
if (groupId == 1) {
reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum;
}
if (groupId == 2) {
reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum;
}
if (groupId == 3) {
reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum;
}
barrier(CLK_LOCAL_MEM_FENCE);
if (groupId == 0) {
totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid];
}
if (groupId == 0) {
totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid];
}
if (groupId == 0) {
totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid];
}
// 2 outputs per fiber in wave 0
if (groupId == 0) {
dst = (global float*)((global char*)dst + offsetd);
vstore2(totalSum, 0, &(dst[gid * 2]));
}
}

View file

@ -1,101 +0,0 @@
diagnostic(off, subgroup_uniformity);
enable f16;
#define KV_TILE 32
#define WG_SIZE 32
// Uniform parameters for the mask-block classifier shader.
struct Params {
// Element offset of the mask tensor within the `mask` buffer.
offset_mask: u32,
// Mask dimensions: query rows and KV positions.
seq_len_q: u32,
seq_len_kv: u32,
// Batch stride of the mask; 0 means one mask is shared by all batches.
stride_mask3: u32,
// Number of KV blocks and Q blocks per batch.
// nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = seq_len_q.
nblk0: u32,
nblk1: u32,
};
@group(0) @binding(0) var<storage, read> mask: array<f16>;
@group(0) @binding(1) var<storage, read_write> blk: array<u32>;
@group(0) @binding(2) var<uniform> params: Params;
const MASK_MIN: f32 = -65504.0;
const MASK_MAX: f32 = 65504.0;
var<workgroup> wg_min: array<f32, WG_SIZE>;
var<workgroup> wg_max: array<f32, WG_SIZE>;
var<workgroup> wg_any: array<u32, WG_SIZE>;
// Classifies one (batch, q_row, kv_block) tile of the attention mask into a
// per-block state written to `blk` for later skipping:
//   0 - no in-range elements, or every value <= MASK_MIN (fully masked)
//   1 - general/mixed mask values
//   2 - every in-range value is exactly 0.0 (mask is a no-op for this block)
// One workgroup scans one KV block of one query row cooperatively; thread 0
// reduces the per-thread partials and writes the state.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>) {
// Dispatch mapping:
// - x indexes KV blocks
// - y flattens (batch_idx, q_blk) as y = batch_idx * nblk1 + q_blk
let kv_blk = wg_id.x;
let y = wg_id.y;
let q_blk = y % params.nblk1;
let batch_idx = y / params.nblk1;
if (kv_blk >= params.nblk0) {
return;
}
let q_start = q_blk;
let k_start = kv_blk * KV_TILE;
// A zero batch stride means the same mask is reused for every batch.
let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
let mask_batch_base = params.offset_mask + mask_batch * params.stride_mask3;
// We keep min/max to classify:
// - fully masked (max <= MASK_MIN)
// - all-zero mask (min == 0 && max == 0)
// - mixed/general mask
var local_min = MASK_MAX;
var local_max = -MASK_MAX;
var local_any = 0u;
let q_row = q_start;
if (q_row < params.seq_len_q) {
let row_base = mask_batch_base + q_row * params.seq_len_kv;
for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) {
let k_col = k_start + k_rel;
if (k_col >= params.seq_len_kv) {
continue;
}
let mv = f32(mask[row_base + k_col]);
local_min = min(local_min, mv);
local_max = max(local_max, mv);
local_any = 1u;
}
}
// Publish per-thread partial min/max/coverage for the serial reduction.
wg_min[local_id.x] = local_min;
wg_max[local_id.x] = local_max;
wg_any[local_id.x] = local_any;
workgroupBarrier();
// Thread 0 writes one state per block.
if (local_id.x == 0u) {
var mmin = wg_min[0];
var mmax = wg_max[0];
var many = wg_any[0];
for (var i = 1u; i < WG_SIZE; i += 1u) {
mmin = min(mmin, wg_min[i]);
mmax = max(mmax, wg_max[i]);
many = max(many, wg_any[i]);
}
// State encoding: 0 = fully masked, 1 = general mask, 2 = all-zero mask.
var state = 0u;
if (many != 0u) {
if (mmax <= MASK_MIN) {
state = 0u;
} else if (mmin == 0.0 && mmax == 0.0) {
state = 2u;
} else {
state = 1u;
}
}
let blk_idx = (batch_idx * params.nblk1 + q_blk) * params.nblk0 + kv_blk;
blk[blk_idx] = state;
}
}

View file

@ -1,78 +0,0 @@
diagnostic(off, subgroup_uniformity);
enable f16;
enable subgroups;
// Default values
#define HEAD_DIM_V 64
#define WG_SIZE 128
// Uniform parameters for the cross-workgroup attention-output reduction.
struct Params {
// Total number of output rows (batch * n_heads * seq_len_q).
nrows: u32,
seq_len_q: u32,
n_heads: u32,
// Element offset of the destination tensor.
offset_dst: u32,
// Number of workgroups whose partial results must be combined per row.
nwg: u32,
// Bases (in f32 elements) of the partial outputs and of the per-partial
// (sum, max) softmax statistics inside the `tmp` buffer.
tmp_data_base: u32,
tmp_stats_base: u32,
};
@group(0) @binding(0) var<storage, read_write> tmp: array<f32>;
@group(0) @binding(1) var<storage, read_write> dst: array<vec4<f32>>;
@group(0) @binding(2) var<uniform> params: Params;
const FLOAT_MIN: f32 = -1.0e9;
// Combines the nwg partial results for one attention output row: each
// partial workgroup left HEAD_DIM_V unnormalized floats in the tmp data
// region plus a (sum, max) softmax statistic pair in the tmp stats region.
// Lane i of the subgroup owns partial i (hence the nwg <= subgroup_size
// requirement): the partials are rescaled to the common max
// (ms = exp(mi - m)), summed across lanes, and normalized by the combined
// denominator s before being written to dst. One workgroup per output row.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(subgroup_id) subgroup_id: u32,
@builtin(num_subgroups) num_subgroups: u32,
@builtin(subgroup_size) subgroup_size: u32,
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
let rid = wg_id.x;
if (rid >= params.nrows) {
return;
}
// Decompose the flat row id into (batch, head, query row).
let rows_per_batch = params.n_heads * params.seq_len_q;
let batch_idx = rid / rows_per_batch;
let rem = rid % rows_per_batch;
let head_idx = rem / params.seq_len_q;
let q_row = rem % params.seq_len_q;
let dst2_stride = HEAD_DIM_V * params.n_heads;
let dst3_stride = dst2_stride * params.seq_len_q;
let row_base = params.offset_dst + batch_idx * dst3_stride + q_row * dst2_stride + head_idx * HEAD_DIM_V;
let thread = sg_inv_id;
// Each lane handles one partial; more partials than lanes is unsupported.
if (params.nwg > subgroup_size) {
return;
}
let stats_base = params.tmp_stats_base + rid * (2u * params.nwg);
let active_thread = thread < params.nwg;
// Inactive lanes contribute a neutral (0, FLOAT_MIN) statistic pair.
let si = select(0.0, tmp[stats_base + 2u * thread + 0u], active_thread);
let mi = select(FLOAT_MIN, tmp[stats_base + 2u * thread + 1u], active_thread);
// Global row max, per-partial rescale factor, and combined denominator.
let m = subgroupMax(mi);
let ms = select(0.0, exp(mi - m), active_thread);
let s = subgroupAdd(si * ms);
let inv_s = select(0.0, 1.0 / s, s != 0.0);
let row_tmp_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg);
// Each subgroup reduces 4 output elements at a time across all partials.
for (var elem_base = subgroup_id * 4u; elem_base < HEAD_DIM_V; elem_base += num_subgroups * 4u) {
var weighted = vec4<f32>(0.0, 0.0, 0.0, 0.0);
if (active_thread) {
let src = row_tmp_base + thread * HEAD_DIM_V + elem_base;
weighted = vec4<f32>(tmp[src + 0u], tmp[src + 1u], tmp[src + 2u], tmp[src + 3u]) * ms;
}
let sum_x = subgroupAdd(weighted.x);
let sum_y = subgroupAdd(weighted.y);
let sum_z = subgroupAdd(weighted.z);
let sum_w = subgroupAdd(weighted.w);
if (thread == 0u) {
let dst_vec_index = (row_base + elem_base) >> 2u;
dst[dst_vec_index] = vec4<f32>(sum_x, sum_y, sum_z, sum_w) * inv_s;
}
}
}

View file

@ -1,729 +0,0 @@
diagnostic(off, chromium.subgroup_matrix_uniformity);
diagnostic(off, subgroup_uniformity);
enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;
#ifdef KV_F32
#define KV_TYPE f32
#else
#define KV_TYPE f16
#endif
#define HEAD_DIM_QK 64
#define HEAD_DIM_V 64
#define SG_MAT_M 8
#define SG_MAT_N 8
#define SG_MAT_K 8
#define Q_TILE SG_MAT_M
#define KV_TILE 16
#define WG_SIZE 64
#ifndef VEC_NE
#define VEC_NE 4u
#endif
#define KV_BLOCKS (KV_TILE / SG_MAT_N)
#define BLOCK_SIZE 32
#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
#if defined(KV_Q4_0)
#define NQ 16
#define F16_PER_BLOCK 9
#define WEIGHTS_PER_F16 4
#elif defined(KV_Q8_0)
#define NQ 8
#define F16_PER_BLOCK 17
#define WEIGHTS_PER_F16 2
#endif
#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
// Extract byte `index` (0 = least significant) from a packed u32.
fn get_byte(value: u32, index: u32) -> u32 {
    let shifted = value >> (index * 8);
    return shifted & 0xFF;
}
// Extract byte `index` (0 = least significant) from a packed u32 and
// sign-extend it to an i32.
fn get_byte_i32(value: u32, index: u32) -> i32 {
    let byte_bits = (value >> (index * 8)) & 0xFF;
    // Place the byte in the top 8 bits, then arithmetic-shift back down so
    // bit 7 propagates through the upper 24 bits.
    return bitcast<i32>(byte_bits << 24) >> 24;
}
// Uniform parameters for the subgroup-matrix flash-attention kernel.
struct Params {
// element offsets of each tensor within its binding
offset_q: u32,
offset_k: u32,
offset_v: u32,
offset_mask: u32,
offset_sinks: u32,
offset_dst: u32,
// shapes of Q/K/V
n_heads: u32,
seq_len_q: u32,
seq_len_kv: u32,
// strides (in elements)
stride_q1: u32,
stride_q2: u32,
stride_q3: u32,
stride_k1: u32,
stride_k2: u32,
stride_k3: u32,
stride_v1: u32,
stride_v2: u32,
stride_v3: u32,
stride_mask3: u32,
// repeat factors for K/V, e.g., MHA vs. MQA vs. GQA
q_per_kv: u32,
// softmax params
scale: f32,
max_bias: f32,
logit_softcap: f32,
// per-head slope parameters (see the `slope` computation in main;
// n_head_log2 splits the heads between the m0 and m1 formulas)
n_head_log2: f32,
m0: f32,
m1: f32,
#ifdef BLK
// precomputed mask-block states: base offset and per-batch grid
// dimensions (blk_nblk0 x blk_nblk1 blocks)
blk_base: u32,
blk_nblk0: u32,
blk_nblk1: u32,
#endif
// split-KV partials: bases of the per-workgroup data/stats regions in
// `tmp`, and the number of workgroups splitting each row
tmp_data_base: u32,
tmp_stats_base: u32,
nwg: u32,
};
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
#if defined(KV_Q4_0) || defined(KV_Q8_0)
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
#else
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
#endif
#if defined(KV_Q4_0) || defined(KV_Q8_0)
@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
#else
@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
#endif
#if defined(MASK) && defined(SINKS)
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
#ifdef BLK
#define BLK_BINDING 5
#define TMP_BINDING 6
#define DST_BINDING 7
#define PARAMS_BINDING 8
#else
#define TMP_BINDING 5
#define DST_BINDING 6
#define PARAMS_BINDING 7
#endif
#elif defined(MASK)
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
#ifdef BLK
#define BLK_BINDING 4
#define TMP_BINDING 5
#define DST_BINDING 6
#define PARAMS_BINDING 7
#else
#define TMP_BINDING 4
#define DST_BINDING 5
#define PARAMS_BINDING 6
#endif
#elif defined(SINKS)
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
#define TMP_BINDING 4
#define DST_BINDING 5
#define PARAMS_BINDING 6
#else
#define TMP_BINDING 3
#define DST_BINDING 4
#define PARAMS_BINDING 5
#endif
#ifdef BLK
@group(0) @binding(BLK_BINDING) var<storage, read_write> blk: array<u32>;
#endif
@group(0) @binding(TMP_BINDING) var<storage, read_write> tmp: array<f32>;
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
// Just a very small float value.
const FLOAT_MIN: f32 = -1.0e9;
var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
#ifndef KV_DIRECT
const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
// we can reuse the same shmem for K and V since we only need one at a time
var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
#endif
var<workgroup> o_shmem: array<f16, Q_TILE * HEAD_DIM_V>;
#ifdef MASK
// storage for mask values
var<workgroup> mask_shmem: array<f16, Q_TILE * KV_TILE>;
#endif
// note that we reuse the same storage for both since we only need one at a time
var<workgroup> inter_shmem: array<f16, Q_TILE * KV_TILE>;
// Storage for row max and exp sum during online softmax
var<workgroup> row_max_shmem: array<f32, Q_TILE>;
var<workgroup> exp_sum_shmem: array<f32, Q_TILE>;
var<workgroup> blk_state_wg: u32;
// Computes one scaled attention logit for KV position `kv_idx` of query tile
// row `q_tile_row`, reading the raw QK product from inter_shmem. Indices at
// or beyond KV_TILE yield FLOAT_MIN (a very negative logit, effectively
// masked). Optionally applies the logit softcap (LOGIT_SOFTCAP builds) and
// adds the tile's mask value (MASK builds), scaled by the per-head `slope`
// when `has_bias` is set; `apply_mask` gates the mask add entirely.
fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 {
var v = select(FLOAT_MIN,
f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale,
kv_idx < KV_TILE);
#ifdef LOGIT_SOFTCAP
v = params.logit_softcap * tanh(v);
#endif
#ifdef MASK
if (apply_mask) {
// Out-of-range positions contribute a neutral mask value of 0.
var mask_val = select(0.0,f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE);
v += select(mask_val, slope * mask_val, has_bias);
}
#endif
return v;
}
// Flash-attention kernel: one workgroup computes a Q_TILE-row tile of the
// attention output for a single (batch, head), streaming K/V in KV_TILE
// chunks with an online softmax (running row max + exp-sum kept in shared
// memory). When params.nwg > 1, the KV axis is split across nwg workgroups
// and each writes partial O plus (exp_sum, row_max) stats to `tmp` for a
// later merge pass.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(subgroup_id) subgroup_id: u32,
@builtin(subgroup_size) subgroup_size: u32,
@builtin(num_subgroups) num_subgroups: u32,
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
// initialize row max for online softmax
for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
row_max_shmem[i] = FLOAT_MIN;
exp_sum_shmem[i] = 0.0;
}
// clear the per-row output accumulator tile
for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) {
o_shmem[i] = 0.0;
}
// workgroups per head/batch
let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
let wg_per_batch = wg_per_head * params.n_heads;
let dst2_stride = HEAD_DIM_V * params.n_heads;
let dst3_stride = dst2_stride * params.seq_len_q;
// iwg selects this workgroup's slice of the KV split (nwg slices total)
let iwg = wg_id.x % params.nwg;
let base_wg_id = wg_id.x / params.nwg;
// batch index
let batch_idx = base_wg_id / wg_per_batch;
let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
let wg_in_batch = base_wg_id % wg_per_batch;
// head index
let head_idx = wg_in_batch / wg_per_head;
let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
// grouped-query attention: q_per_kv query heads share one K/V head
let k_head_idx = head_idx / params.q_per_kv;
let v_head_idx = k_head_idx;
let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;
// starting Q row for this workgroup
let wg_in_head = wg_in_batch % wg_per_head;
let q_row_start = wg_in_head * Q_TILE;
#ifdef MASK
// mask offset
let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
#endif
// per-head bias slope (ALiBi-style): m0^(head+1) for the first n_head_log2
// heads, m1^(2*(head-n_head_log2)+1) for the rest; 1.0 when max_bias == 0
let head = f32(head_idx);
let has_bias = params.max_bias > 0.0;
let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias);
// load q tile into shared memory (zero-padded past seq_len_q)
for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
let q_row = elem_idx / HEAD_DIM_QK;
let q_col = elem_idx % HEAD_DIM_QK;
let head_q_row = q_row_start + q_row;
let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
q_shmem[elem_idx] = f16(select(
0.0,
Q[global_q_row_offset + q_col],
head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK));
}
// main loop: stream over the KV sequence, one KV_TILE per iteration
// (strided by nwg so split workgroups cover disjoint tiles)
for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) {
#ifdef BLK
// per-(Q block, KV block) state from the `blk` metadata:
// 0 means the tile is fully masked out (skipped below); 2 means the mask
// load/application is skipped (see apply_mask) — presumably "no masked
// element in this tile"; other values apply the mask normally.
let q_blk = q_row_start / Q_TILE;
let kv_blk = kv_tile / KV_TILE;
let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk;
let blk_state_local = blk[blk_idx];
#else
let blk_state_local = 1u;
#endif
// broadcast the block state to the whole workgroup via shared memory
if (local_id.x == 0u) {
blk_state_wg = blk_state_local;
}
workgroupBarrier();
let blk_state = blk_state_wg;
let skip_tile = blk_state == 0u;
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
inter_shmem[elem_idx] = f16(0.0);
}
// load k tile into shared memory
#if defined(KV_Q4_0)
// dequantize Q4_0: each block is one f16 scale followed by packed 4-bit
// weights; low nibble goes to the first half of the 32-weight block,
// high nibble to the second half (offset +16)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let k_row = blck_idx / BLOCKS_K;
let global_k_row = kv_tile + k_row;
let block_k = blck_idx % BLOCKS_K;
let row_offset = k_row * HEAD_DIM_QK;
if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = K[base_idx];
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = K[base_idx + 1u + block_offset + j];
let q_1 = K[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_lo;
kv_shmem[row_offset + idx + 16u] = q_hi;
}
}
}
}
#elif defined(KV_Q8_0)
// dequantize Q8_0: one f16 scale followed by 32 signed 8-bit weights
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let k_row = blck_idx / BLOCKS_K;
let global_k_row = kv_tile + k_row;
let block_k = blck_idx % BLOCKS_K;
let row_offset = k_row * HEAD_DIM_QK;
if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = K[base_idx];
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = K[base_idx + 1u + block_offset + j];
let q_1 = K[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_val;
}
}
}
}
#elif defined(KV_DIRECT)
// Direct global loads for KV
#else
// plain f16/f32 K: vec4 copies into shared memory, zero-padded out of range
for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) {
let k_row = elem_idx / HEAD_DIM_QK;
let k_col = elem_idx % HEAD_DIM_QK;
let global_k_row = kv_tile + k_row;
let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK;
let vec_idx = (global_k_row_offset + k_col) >> 2u;
let k4 = select(vec4<KV_TYPE>(0.0), K[vec_idx], in_bounds);
kv_shmem[elem_idx + 0u] = f16(k4.x);
kv_shmem[elem_idx + 1u] = f16(k4.y);
kv_shmem[elem_idx + 2u] = f16(k4.z);
kv_shmem[elem_idx + 3u] = f16(k4.w);
}
#endif
workgroupBarrier();
// accumulate q block * k block into registers across the entire KV tile
// (S = Q·K^T); one subgroup per Q row, VEC_NE KV columns per pass, with the
// per-row dot product split across num_of_threads lanes and reduced below
if (!skip_tile) {
let num_of_threads = subgroup_size / VEC_NE;
let tx = sg_inv_id % num_of_threads;
let ty = sg_inv_id / num_of_threads;
for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
let global_q_row = q_row_start + q_tile_row;
if (global_q_row >= params.seq_len_q) {
continue;
}
let local_q_row_offset = q_tile_row * HEAD_DIM_QK;
for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) {
let kv_idx = kv_base + ty;
var partial_sum: f32 = 0.0;
let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv;
if (kv_valid) {
for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) {
let q_off = local_q_row_offset + i * 4u;
let qv = vec4<f32>(
f32(q_shmem[q_off + 0u]),
f32(q_shmem[q_off + 1u]),
f32(q_shmem[q_off + 2u]),
f32(q_shmem[q_off + 3u]));
#ifdef KV_DIRECT
let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u);
let kv = vec4<f32>(K[idx >> 2u]);
#else
let idx = kv_idx * HEAD_DIM_QK + (i * 4u);
let kv = vec4<f32>(
f32(kv_shmem[idx + 0u]),
f32(kv_shmem[idx + 1u]),
f32(kv_shmem[idx + 2u]),
f32(kv_shmem[idx + 3u]));
#endif
partial_sum += dot(qv, kv);
}
}
var sum = partial_sum;
// Reduce over tx threads (NL) for this ty stripe.
var tx_delta = num_of_threads >> 1u;
loop {
if (tx_delta == 0u) {
break;
}
let sh = subgroupShuffleDown(sum, tx_delta);
if (tx < tx_delta) {
sum += sh;
}
tx_delta >>= 1u;
}
// lane 0 of each ty stripe holds the full dot product; broadcast and store
let sum_bcast = subgroupShuffle(sum, num_of_threads * ty);
if (tx == 0u && kv_valid) {
let dst_idx = q_tile_row * KV_TILE + kv_idx;
inter_shmem[dst_idx] = f16(sum_bcast);
}
}
}
}
#ifdef MASK
// blk_state == 2u marks tiles where the mask need not be applied
let apply_mask = !skip_tile && (blk_state != 2u);
if (apply_mask) {
// load mask tile into shared memory for this KV block
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
let mask_row = elem_idx / KV_TILE;
let mask_col = elem_idx % KV_TILE;
let global_q_row = q_row_start + mask_row;
let global_k_col = kv_tile + mask_col;
let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;
let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col;
mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
}
}
#else
let apply_mask = false;
#endif
workgroupBarrier();
// online softmax
if (!skip_tile) {
for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
let global_q_row = q_row_start + q_tile_row;
if (global_q_row >= params.seq_len_q) {
break;
}
var prev_max = row_max_shmem[q_tile_row];
var final_max = prev_max;
// pass 1: compute final max across the full KV tile in chunks
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
let kv_idx = kv_offset + sg_inv_id;
let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE;
let softmax_term = select(FLOAT_MIN,
calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask),
kv_valid);
final_max = subgroupMax(max(final_max, softmax_term));
}
var total_exp_term: f32 = 0.0;
// pass 2: compute exp sum and write P using final_max
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
let kv_idx = kv_offset + sg_inv_id;
let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask);
let cur_p = select(0.0,
exp(softmax_term - final_max),
kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
total_exp_term += subgroupAdd(cur_p);
if (kv_idx < KV_TILE) {
inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p);
}
}
// rescale the running exp-sum and O accumulator by exp(old_max - new_max)
let cur_exp = exp(prev_max - final_max);
if (sg_inv_id == 0) {
row_max_shmem[q_tile_row] = final_max;
exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term;
}
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
let idx = q_tile_row * HEAD_DIM_V + elem_idx;
o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp);
}
}
}
// load v tile into shared memory (same layouts/dequant as the K tile above)
#if defined(KV_Q4_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let v_row = blck_idx / BLOCKS_V;
let global_v_row = kv_tile + v_row;
let block_k = blck_idx % BLOCKS_V;
let row_offset = v_row * HEAD_DIM_V;
if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = V[base_idx];
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = V[base_idx + 1u + block_offset + j];
let q_1 = V[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_lo;
kv_shmem[row_offset + idx + 16u] = q_hi;
}
}
}
}
#elif defined(KV_Q8_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let v_row = blck_idx / BLOCKS_V;
let global_v_row = kv_tile + v_row;
let block_k = blck_idx % BLOCKS_V;
let row_offset = v_row * HEAD_DIM_V;
if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = V[base_idx];
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = V[base_idx + 1u + block_offset + j];
let q_1 = V[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_val;
}
}
}
}
#elif defined(KV_DIRECT)
// Direct global loads for KV
#else
for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) {
let v_row = elem_idx / HEAD_DIM_V;
let v_col = elem_idx % HEAD_DIM_V;
let global_v_row = kv_tile + v_row;
let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V;
let vec_idx = (global_v_row_offset + v_col) >> 2u;
let v4 = select(vec4<KV_TYPE>(0.0), V[vec_idx], in_bounds);
kv_shmem[elem_idx + 0u] = f16(v4.x);
kv_shmem[elem_idx + 1u] = f16(v4.y);
kv_shmem[elem_idx + 2u] = f16(v4.z);
kv_shmem[elem_idx + 3u] = f16(v4.w);
}
#endif
workgroupBarrier();
if (!skip_tile) {
// we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem
// we want to compute O += P * V across the full KV tile
let ne_threads : u32 = VEC_NE;
let nl_threads = max(1u, subgroup_size / ne_threads);
let tx_pv = sg_inv_id % nl_threads;
let ty_pv = sg_inv_id / nl_threads;
for (var q_tile_row = subgroup_id;
q_tile_row < Q_TILE;
q_tile_row += num_subgroups) {
// each tx lane owns a vec4 column slice of O; ty lanes split the KV rows
for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) {
var lo = vec4<f32>(0.0, 0.0, 0.0, 0.0);
for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) {
let kv_idx = cc * ne_threads + ty_pv;
let v_row = kv_tile + kv_idx;
if (v_row >= params.seq_len_kv) {
continue;
}
let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]);
#ifdef KV_DIRECT
let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u;
let v4 = vec4<f32>(V[v_idx >> 2u]);
#else
let v_idx = kv_idx * HEAD_DIM_V + vec_col * 4u;
let v4 = vec4<f32>(
f32(kv_shmem[v_idx + 0u]),
f32(kv_shmem[v_idx + 1u]),
f32(kv_shmem[v_idx + 2u]),
f32(kv_shmem[v_idx + 3u]));
#endif
lo += p * v4;
}
var lo_x = lo.x;
var lo_y = lo.y;
var lo_z = lo.z;
var lo_w = lo.w;
// Reduce over ty threads (NE) for this tx thread.
var ty_delta = ne_threads >> 1u;
loop {
if (ty_delta == 0u) {
break;
}
let thread_delta = ty_delta * nl_threads;
let shx = subgroupShuffleDown(lo_x, thread_delta);
let shy = subgroupShuffleDown(lo_y, thread_delta);
let shz = subgroupShuffleDown(lo_z, thread_delta);
let shw = subgroupShuffleDown(lo_w, thread_delta);
if (ty_pv < ty_delta) {
lo_x += shx;
lo_y += shy;
lo_z += shz;
lo_w += shw;
}
ty_delta >>= 1u;
}
// ty lane 0 now holds the reduced vec4; accumulate into O
if (ty_pv == 0u) {
let elem_base = vec_col * 4u;
let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base;
o_shmem[o_base_idx + 0u] = f16(f32(o_shmem[o_base_idx + 0u]) + lo_x);
o_shmem[o_base_idx + 1u] = f16(f32(o_shmem[o_base_idx + 1u]) + lo_y);
o_shmem[o_base_idx + 2u] = f16(f32(o_shmem[o_base_idx + 2u]) + lo_z);
o_shmem[o_base_idx + 3u] = f16(f32(o_shmem[o_base_idx + 3u]) + lo_w);
}
}
}
}
workgroupBarrier();
}
#ifdef SINKS
// Sinks are global terms and must be applied exactly once across split workgroups.
if (iwg == 0u) {
for (var q_tile_row = subgroup_id;
q_tile_row < Q_TILE;
q_tile_row += num_subgroups) {
let global_q_row = q_row_start + q_tile_row;
if (global_q_row >= params.seq_len_q) {
break;
}
var prev_max = row_max_shmem[q_tile_row];
// for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0);
let new_max = subgroupMax(max(prev_max, sink_val));
let max_exp = exp(prev_max - new_max);
let sink_exp = exp(sink_val - new_max);
let sink_exp_sum = subgroupAdd(sink_exp);
if (sg_inv_id == 0) {
row_max_shmem[q_tile_row] = new_max;
exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum;
}
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
let idx = q_tile_row * HEAD_DIM_V + elem_idx;
o_shmem[idx] = f16(f32(o_shmem[idx]) * max_exp);
}
}
workgroupBarrier();
}
#endif
// epilogue: either normalize and write the final output (nwg == 1), or
// spill partial O and (exp_sum, row_max) stats for the split-KV merge pass
let rows_per_batch = params.n_heads * params.seq_len_q;
for (var q_tile_row = subgroup_id;
q_tile_row < Q_TILE;
q_tile_row += num_subgroups) {
let global_q_row = q_row_start + q_tile_row;
if (global_q_row >= params.seq_len_q) { break; }
if (params.nwg == 1u) {
let exp_sum = exp_sum_shmem[q_tile_row];
// guard against an all-masked row (exp_sum == 0)
let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
let row_base: u32 =
params.offset_dst + batch_idx * dst3_stride + global_q_row * dst2_stride + head_idx * HEAD_DIM_V;
for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) {
let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
let v = vec4<f32>(
f32(o_shmem[i0]) * scale,
f32(o_shmem[i1]) * scale,
f32(o_shmem[i2]) * scale,
f32(o_shmem[i3]) * scale
);
let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
dst[dst_vec_index] = v;
}
} else {
// split-KV: tmp layout is [row][iwg][HEAD_DIM_V] for data and
// [row][iwg][(exp_sum, row_max)] for stats
let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + global_q_row;
let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V;
let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg;
for (var elem_base = sg_inv_id * 4u;
elem_base < HEAD_DIM_V;
elem_base += subgroup_size * 4u) {
let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
let tbase = tmp_row_data_base + elem_base;
tmp[tbase + 0u] = f32(o_shmem[i0]);
tmp[tbase + 1u] = f32(o_shmem[i1]);
tmp[tbase + 2u] = f32(o_shmem[i2]);
tmp[tbase + 3u] = f32(o_shmem[i3]);
}
if (sg_inv_id == 0u) {
tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row];
tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row];
}
}
}
}

View file

@ -1,155 +0,0 @@
enable f16;
#ifdef TYPE_F32
#define DataType f32
#endif
#ifdef TYPE_F16
#define DataType f16
#endif
#ifdef OP_REGLU
// ReGLU: relu(a) * b — the gate input is clamped at zero before scaling b.
fn op(a: DataType, b: DataType) -> DataType {
    let gated = max(a, DataType(0));
    return gated * b;
}
#endif
#ifdef OP_GEGLU
// GEGLU: gelu(a) * b, using the tanh approximation of GELU.
const SQRT_2_OVER_PI: DataType = 0.79788456080286535587989211986876;
const GELU_COEF_A: DataType = 0.044715;
fn op(a: DataType, b: DataType) -> DataType {
// tanh argument sqrt(2/pi) * (a + 0.044715*a^3), factored as a*(1 + c*a^2)
let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
// 0.5*a*(1 + tanh(val)) rewritten via tanh(x) = 1 - 2/(exp(2x) + 1)
return 0.5 * a * (2.0 - 2.0/ (exp(2* val) + 1)) * b;
}
#endif
#ifdef OP_SWIGLU
// SwiGLU: silu(a) * b, where silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
fn op(a: DataType, b: DataType) -> DataType {
    let sigmoid_denom = 1.0 + exp(-a);
    return (a / sigmoid_denom) * b;
}
#endif
#ifdef OP_SWIGLU_OAI
// Clamped SwiGLU variant: the value input is capped at params.limit, the
// gate is clamped to [-limit, limit], and the sigmoid is scaled by alpha.
fn op(a: f32, b: f32) -> f32 {
    let xi = min(a, params.limit);
    let gi = max(min(b, params.limit), -params.limit);
    let act = xi / (1.0 + exp(-xi * params.alpha));
    return act * (1.0 + gi);
}
#endif
#ifdef OP_GEGLU_ERF
// GEGLU with the exact-GELU form 0.5*a*(1 + erf(a/sqrt(2))) * b, where erf
// is evaluated with the Abramowitz & Stegun 7.1.26 polynomial approximation.
const p_erf: DataType = 0.3275911;
const a1_erf: DataType = 0.254829592;
const a2_erf: DataType = -0.284496736;
const a3_erf: DataType = 1.421413741;
const a4_erf: DataType = -1.453152027;
const a5_erf: DataType = 1.061405429;
const SQRT_2_INV: DataType = 0.7071067811865476;
fn op(a: DataType, b: DataType) -> DataType {
let a_div_sqr2 = a * SQRT_2_INV;
// the approximation is valid for x >= 0; use erf(-x) = -erf(x) for x < 0
let sign_x = sign(a_div_sqr2);
let x = abs(a_div_sqr2);
let t = 1.0 / (1.0 + p_erf * x);
// Horner evaluation of the degree-5 polynomial in t, damped by exp(-x^2)
let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));
let erf_approx = sign_x * y;
return 0.5 * a * (1.0 + erf_approx) * b;
}
#endif
#ifdef OP_GEGLU_QUICK
// GEGLU using the "quick" GELU approximation: a * sigmoid(1.702 * a) * b.
const GELU_QUICK_COEF: DataType = -1.702;
fn op(a: DataType, b: DataType) -> DataType {
    let gate = 1.0 / (1.0 + exp(GELU_QUICK_COEF * a));
    return a * gate * b;
}
#endif
// Uniform parameters for the fused GLU kernel; offsets and strides are in
// elements, not bytes.
struct Params {
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
// Strides (in elements)
stride_src01: u32,
stride_src02: u32,
stride_src03: u32,
stride_src11: u32,
stride_src12: u32,
stride_src13: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// shape of dst
ne: u32, // total element count of dst (dispatch bound)
ne0: u32,
ne1: u32,
ne2: u32,
swapped: u32, // NO_SPLIT only: nonzero swaps which half of src0 feeds a vs b
alpha: f32, // OP_SWIGLU_OAI: sigmoid scaling factor
limit: f32, // OP_SWIGLU_OAI: clamp bound for inputs
}
@group(0) @binding(0)
var<storage, read_write> src0: array<DataType>;
#ifdef NO_SPLIT
// Fused layout: both GLU operands live in src0, ne0 elements apart within a
// row; params.swapped selects which half feeds a (gate) vs b (up).
@group(0) @binding(1)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(2)
var<uniform> params: Params;
// First operand: the second half of the row pair when swapped, else the first.
fn a_value(base: u32) -> DataType {
let offset: u32 = select(0, params.ne0, params.swapped != 0);
return src0[base + offset];
}
// Second operand: the complementary half of the row pair.
fn b_value(base: u32) -> DataType {
let offset: u32 = select(params.ne0, 0, params.swapped != 0);
return src0[base + offset];
}
#else
// Split layout: the two operands come from separate buffers src0 and src1.
@group(0) @binding(1)
var<storage, read_write> src1: array<DataType>;
@group(0) @binding(2)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(3)
var<uniform> params: Params;
fn a_value(base: u32) -> DataType {
return src0[base];
}
fn b_value(base: u32) -> DataType {
return src1[base];
}
#endif
// One invocation per destination element: decode the flat global index into
// 4-D coordinates over dst's shape, gather the two GLU operands via
// a_value/b_value, apply op(), and store the result.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // Drop the tail invocations of the last workgroup.
    if (gid.x >= params.ne) {
        return;
    }
    let plane_elems = params.ne1 * params.ne0;   // elements per i2 plane
    let slice_elems = params.ne2 * plane_elems;  // elements per i3 slice
    let i3 = gid.x / slice_elems;
    var rem = gid.x % slice_elems;
    let i2 = rem / plane_elems;
    rem = rem % plane_elems;
    let i1 = rem / params.ne0;
    let i0 = rem % params.ne0;
    // Translate coordinates into element offsets for each tensor.
    let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
    let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
    let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
    dst[i_dst] = op(a_value(i_a), b_value(i_b));
}

View file

@ -197,6 +197,7 @@ class Keys:
FREQ_BASE_SWA = "{arch}.rope.freq_base_swa"
SCALING_TYPE = "{arch}.rope.scaling.type"
SCALING_FACTOR = "{arch}.rope.scaling.factor"
SCALING_ALPHA = "{arch}.rope.scaling.alpha"
SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
@ -471,6 +472,7 @@ class MODEL_ARCH(IntEnum):
ERNIE4_5_MOE = auto()
HUNYUAN_MOE = auto()
HUNYUAN_DENSE = auto()
HUNYUAN_VL = auto()
SMOLLM3 = auto()
GPT_OSS = auto()
LFM2 = auto()
@ -957,6 +959,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.FALCON_H1: "falcon-h1",
MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
MODEL_ARCH.HUNYUAN_DENSE: "hunyuan-dense",
MODEL_ARCH.HUNYUAN_VL: "hunyuan_vl",
MODEL_ARCH.SMOLLM3: "smollm3",
MODEL_ARCH.GPT_OSS: "gpt-oss",
MODEL_ARCH.LFM2: "lfm2",
@ -3489,6 +3492,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.HUNYUAN_VL: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.SMOLLM3: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@ -4138,6 +4157,7 @@ class VisionProjectorType:
YOUTUVL = "youtuvl"
NEMOTRON_V2_VL = "nemotron_v2_vl"
HUNYUANOCR = "hunyuanocr"
HUNYUANVL = "hunyuanvl"
# Items here are (block size, type size)

View file

@ -973,6 +973,9 @@ class GGUFWriter:
# Writes the "{arch}.rope.scaling.factor" key as a float32 KV entry.
def add_rope_scaling_factor(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)
# Writes the "{arch}.rope.scaling.alpha" key as a float32 KV entry
# (NTK-aware RoPE scaling alpha — see Keys.Rope.SCALING_ALPHA).
def add_rope_scaling_alpha(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_ALPHA.format(arch=self.arch), value)
# Writes the "{arch}.rope.scaling.attn_factor" key as a float32 KV entry.
# NOTE(review): despite the plural name, a single float is serialized.
def add_rope_scaling_attn_factors(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value)

View file

@ -3400,6 +3400,7 @@ static void PrepareMediaEmbds(const int nctx, const std::vector<int> & media_int
//added after https://github.com/ggml-org/llama.cpp/pull/22161, replacing clip_is_mrope function
auto decoder_rope_type = llama_model_rope_type(llama_get_model(llama_ctx_v4));
switch (decoder_rope_type) {
case LLAMA_ROPE_TYPE_NONE:
case LLAMA_ROPE_TYPE_NORM:
case LLAMA_ROPE_TYPE_NEOX:
{

View file

@ -1,53 +0,0 @@
#!/usr/bin/env pwsh
# Launcher for llama-completion.exe from a packaged Snapdragon/Hexagon build.
# Optional environment variables are translated into GGML debug/tuning
# settings; any script arguments are forwarded to the CLI unchanged.
# Basedir on device
$basedir=".\pkg-snapdragon"
$cli_opts=$args
# Model filename, overridable via $env:M; resolved under $basedir\..\..\gguf.
$model="Llama-3.2-3B-Instruct-Q4_0.gguf"
if ($null -ne $env:M) {
$model=$env:M
}
# Target device name, overridable via $env:D (default: first HTP device).
$device="HTP0"
if ($null -ne $env:D) {
$device=$env:D
}
# V=<level> -> hexagon backend verbose logging
if ($null -ne $env:V) {
$env:GGML_HEXAGON_VERBOSE=$env:V
}
# SCHED=<level> -> ggml scheduler debug; also adds -v to the CLI
if ($null -ne $env:SCHED) {
$env:GGML_SCHED_DEBUG=$env:SCHED; $cli_opts="$cli_opts -v"
}
# PROF=<level> -> profiling; op-sync is forced on so timings are meaningful
if ($null -ne $env:PROF) {
$env:GGML_HEXAGON_PROFILE=$env:PROF; $env:GGML_HEXAGON_OPSYNC=1
}
# OPMASK -> restrict which ops the hexagon backend handles
if ($null -ne $env:OPMASK) {
$env:GGML_HEXAGON_OPMASK=$env:OPMASK
}
# NHVX -> number of HVX threads; NDEV -> number of devices; HB -> host buffers
if ($null -ne $env:NHVX) {
$env:GGML_HEXAGON_NHVX=$env:NHVX
}
if ($null -ne $env:NDEV) {
$env:GGML_HEXAGON_NDEV=$env:NDEV
}
if ($null -ne $env:HB) {
$env:GGML_HEXAGON_HOSTBUF=$env:HB
}
# DSP-side libraries shipped in the package
$env:ADSP_LIBRARY_PATH="$basedir\lib"
# Run pinned to CPUs 2-7 (mask 0xfc), full offload (-ngl 99), flash attn on.
& "$basedir\bin\llama-completion.exe" `
--no-mmap -m $basedir\..\..\gguf\$model `
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 `
--ctx-size 8192 --batch-size 256 -fa on `
-ngl 99 -no-cnv --device $device $cli_opts

View file

@ -109,6 +109,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
{ LLM_ARCH_HUNYUAN_DENSE, "hunyuan-dense" },
{ LLM_ARCH_HUNYUAN_VL, "hunyuan_vl" },
{ LLM_ARCH_SMOLLM3, "smollm3" },
{ LLM_ARCH_OPENAI_MOE, "gpt-oss" },
{ LLM_ARCH_LFM2, "lfm2" },
@ -250,6 +251,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
{ LLM_KV_ROPE_SCALING_ALPHA, "%s.rope.scaling.alpha" },
{ LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" },
{ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
{ LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },

View file

@ -113,6 +113,7 @@ enum llm_arch {
LLM_ARCH_ERNIE4_5_MOE,
LLM_ARCH_HUNYUAN_MOE,
LLM_ARCH_HUNYUAN_DENSE,
LLM_ARCH_HUNYUAN_VL,
LLM_ARCH_SMOLLM3,
LLM_ARCH_OPENAI_MOE,
LLM_ARCH_LFM2,
@ -254,6 +255,7 @@ enum llm_kv {
LLM_KV_ROPE_SCALE_LINEAR,
LLM_KV_ROPE_SCALING_TYPE,
LLM_KV_ROPE_SCALING_FACTOR,
LLM_KV_ROPE_SCALING_ALPHA,
LLM_KV_ROPE_SCALING_ATTN_FACTOR,
LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
LLM_KV_ROPE_SCALING_FINETUNED,

View file

@ -116,6 +116,7 @@ struct llama_hparams {
float rope_freq_base_train_swa = 10000.0f;
float rope_freq_scale_train;
float rope_freq_scale_train_swa = 1.0f;
float rope_scaling_alpha = 0.0f; // NTK-aware alpha for XDRoPE
uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul = 0.0f;

View file

@ -852,6 +852,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups, false);
ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false);
if (arch == LLM_ARCH_HUNYUAN_VL || arch == LLM_ARCH_HUNYUAN_DENSE) {
if (hparams.n_expert <= 1) {
hparams.n_expert = 0;
hparams.n_expert_used = 0;
}
}
if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd);
ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd_out_impl);
@ -930,6 +937,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
ml.get_key(LLM_KV_ROPE_SCALING_ALPHA, hparams.rope_scaling_alpha, false);
// non-transformer models do not have attention heads
if (hparams.n_head() > 0) {
@ -2707,9 +2715,18 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_HUNYUAN_VL:
case LLM_ARCH_HUNYUAN_DENSE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false);
// XDRoPE / NTK-aware scaling: base = rope_theta * alpha^(dim / (dim - 2))
if (hparams.rope_scaling_alpha > 0.0f) {
const int dim = hparams.n_embd_head_k();
hparams.rope_freq_base_train = hparams.rope_freq_base_train
* powf(hparams.rope_scaling_alpha, (float)dim / (float)(dim - 2));
}
switch (hparams.n_embd) {
case 1024: type = LLM_TYPE_0_5B; break;
@ -7106,6 +7123,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
}
} break;
case LLM_ARCH_HUNYUAN_VL:
case LLM_ARCH_HUNYUAN_DENSE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -9126,6 +9144,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_hunyuan_moe>(*this, params);
} break;
case LLM_ARCH_HUNYUAN_VL:
case LLM_ARCH_HUNYUAN_DENSE:
{
llm = std::make_unique<llm_build_hunyuan_dense>(*this, params);
@ -9475,6 +9494,9 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GLM4_MOE:
return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_HUNYUAN_VL:
return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX;
// all model arches should be listed explicitly here
case LLM_ARCH_UNKNOWN:
GGML_ABORT("unknown architecture");

View file

@ -6,6 +6,11 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
const bool use_mrope = hparams.use_mrope();
int sections[4];
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
ggml_tensor * cur;
ggml_tensor * inpL;
@ -37,22 +42,36 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons
auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur,
n_embd_head, n_head, n_head_kv, il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
if (use_mrope) {
Qcur = ggml_rope_multi(
ctx0, Qcur, inp_pos, rope_factors,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_multi(
ctx0, Kcur, inp_pos, rope_factors,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
} else {
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = build_norm(Kcur,
model.layers[il].attn_k_norm, nullptr,
LLM_NORM_RMS, il);

View file

@ -150,7 +150,7 @@
#define TN_TOK_BOI "v.boi"
#define TN_TOK_EOI "v.eoi"
// hunyuanocr
// hunyuanocr / hunyuanvl (shared GGUF tensor names)
#define TN_MM_PRE_NORM "mm.pre_norm.%s"
#define TN_TOK_IMG_BEGIN "mm.image_begin"
#define TN_TOK_IMG_END "mm.image_end"
@ -242,6 +242,15 @@
#define TN_STD_BIAS "v.std_bias"
#define TN_STD_SCALE "v.std_scale"
// yasa2
#define TN_YASA_PATCH_LN_W "v.patch_ln.weight"
#define TN_YASA_PATCH_LN_B "v.patch_ln.bias"
#define TN_YASA_BACKBONE_LN_W "v.backbone_ln.weight"
#define TN_YASA_BACKBONE_LN_B "v.backbone_ln.bias"
#define TN_YASA_POS_EMBD "v.vision_pos_embed"
#define TN_YASA_STAGE_DOWN_LN "v.stage.%d.down.ln.%s"
#define TN_YASA_STAGE_DOWN_CONV "v.stage.%d.down.conv.%s"
#define TN_YASA_STAGE_BLK "v.stage.%d.blk.%d.%s.%s"
// align x to upper multiple of n
#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
@ -290,9 +299,11 @@ enum projector_type {
PROJECTOR_TYPE_LFM2A,
PROJECTOR_TYPE_GLM4V,
PROJECTOR_TYPE_YOUTUVL,
PROJECTOR_TYPE_YASA2,
PROJECTOR_TYPE_KIMIK25,
PROJECTOR_TYPE_NEMOTRON_V2_VL,
PROJECTOR_TYPE_HUNYUANOCR,
PROJECTOR_TYPE_HUNYUANVL,
PROJECTOR_TYPE_UNKNOWN,
};
@ -335,9 +346,11 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_LFM2A, "lfm2a"},
{ PROJECTOR_TYPE_GLM4V, "glm4v"},
{ PROJECTOR_TYPE_YOUTUVL, "youtuvl"},
{ PROJECTOR_TYPE_YASA2, "yasa2"},
{ PROJECTOR_TYPE_KIMIK25, "kimik25"},
{ PROJECTOR_TYPE_NEMOTRON_V2_VL, "nemotron_v2_vl"},
{ PROJECTOR_TYPE_HUNYUANOCR, "hunyuanocr"},
{ PROJECTOR_TYPE_HUNYUANVL, "hunyuanvl"},
};
static projector_type clip_projector_type_from_string(const std::string & str) {

View file

@ -268,6 +268,27 @@ struct mobilenetv5_block {
ggml_tensor * attn_norm_w = nullptr;
};
struct yasa2_block {
ggml_tensor * dw_w = nullptr;
ggml_tensor * dw_b = nullptr;
ggml_tensor * ln_w = nullptr;
ggml_tensor * ln_b = nullptr;
ggml_tensor * pw1_w = nullptr;
ggml_tensor * pw1_b = nullptr;
ggml_tensor * grn_w = nullptr;
ggml_tensor * grn_b = nullptr;
ggml_tensor * pw2_w = nullptr;
ggml_tensor * pw2_b = nullptr;
};
struct yasa2_stage {
ggml_tensor * down_ln_w = nullptr;
ggml_tensor * down_ln_b = nullptr;
ggml_tensor * down_conv_w = nullptr;
ggml_tensor * down_conv_b = nullptr;
std::vector<yasa2_block> blocks;
};
struct clip_model {
clip_modality modality = CLIP_MODALITY_VISION;
projector_type proj_type = PROJECTOR_TYPE_MLP;
@ -402,6 +423,15 @@ struct clip_model {
ggml_tensor * msfa_ffn_expand_bn = nullptr;
ggml_tensor * msfa_ffn_project_bn = nullptr;
// yasa2
ggml_tensor * yasa_patch_w = nullptr;
ggml_tensor * yasa_patch_b = nullptr;
ggml_tensor * yasa_patch_ln_w = nullptr;
ggml_tensor * yasa_patch_ln_b = nullptr;
ggml_tensor * yasa_backbone_ln_w = nullptr;
ggml_tensor * yasa_backbone_ln_b = nullptr;
ggml_tensor * yasa_vision_pos_embed = nullptr;
std::vector<yasa2_stage> yasa_stages;
// pixtral, glm4v
ggml_tensor * token_embd_img_break = nullptr;

View file

@ -76,6 +76,7 @@
#include "models/deepseekocr.cpp"
#include "models/mobilenetv5.cpp"
#include "models/youtuvl.cpp"
#include "models/yasa2.cpp"
struct clip_logger_state g_logger_state = {clip_log_callback_default, NULL};
@ -969,6 +970,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
builder = std::make_unique<clip_graph_cogvlm>(ctx, img);
} break;
case PROJECTOR_TYPE_HUNYUANOCR:
case PROJECTOR_TYPE_HUNYUANVL:
{
builder = std::make_unique<clip_graph_hunyuanocr>(ctx, img);
} break;
@ -1004,6 +1006,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
builder = std::make_unique<clip_graph_youtuvl>(ctx, img);
} break;
case PROJECTOR_TYPE_YASA2:
{
builder = std::make_unique<clip_graph_yasa2>(ctx, img);
} break;
default:
GGML_ABORT("missing cgraph builder");
}
@ -1474,6 +1480,16 @@ struct clip_model_loader {
hparams.set_limit_image_tokens(1, 62500);
hparams.set_warmup_n_tokens(16*16); // avoid OOM on warmup
} break;
case PROJECTOR_TYPE_YASA2:
{
hparams.ffn_op = FFN_GELU_ERF;
log_ffn_op = "gelu_erf";
hparams.image_resize_algo = RESIZE_ALGO_BICUBIC;
// reka model performs better when using resize_bicubic, which stretches
// the image to fit fixed square size
hparams.image_resize_pad = false;
} break;
case PROJECTOR_TYPE_GLM4V:
{
hparams.rope_theta = 10000.0f;
@ -1544,6 +1560,16 @@ struct clip_model_loader {
get_u32(KEY_IMAGE_MAX_PIXELS, hparams.image_max_pixels);
hparams.set_warmup_n_tokens(28*28);
} break;
case PROJECTOR_TYPE_HUNYUANVL:
{
hparams.n_merge = 2;
hparams.image_resize_algo = RESIZE_ALGO_BICUBIC_PILLOW;
hparams.image_resize_pad = false;
hparams.ffn_op = FFN_GELU;
get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
hparams.set_limit_image_tokens(256, 16384);
hparams.set_warmup_n_tokens(32*32);
} break;
case PROJECTOR_TYPE_LFM2A:
{
// audio preprocessing params
@ -1929,6 +1955,55 @@ struct clip_model_loader {
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); // merger.mlp.2
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
} break;
case PROJECTOR_TYPE_YASA2:
{
// reuse tensors already loaded by the common section
// (TN_PATCH_EMBD and TN_PATCH_BIAS have the same tensor names)
GGML_ASSERT(model.patch_embeddings_0 && "yasa2 requires v.patch_embd.weight");
model.yasa_patch_w = model.patch_embeddings_0;
model.yasa_patch_b = model.patch_bias;
model.yasa_patch_ln_w = get_tensor(TN_YASA_PATCH_LN_W, false);
model.yasa_patch_ln_b = get_tensor(TN_YASA_PATCH_LN_B, false);
model.yasa_backbone_ln_w = get_tensor(TN_YASA_BACKBONE_LN_W, false);
model.yasa_backbone_ln_b = get_tensor(TN_YASA_BACKBONE_LN_B, false);
model.yasa_vision_pos_embed = get_tensor(TN_YASA_POS_EMBD, false);
model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false);
model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
model.yasa_stages.clear();
for (int s = 0; ; ++s) {
yasa2_stage stage;
stage.down_ln_w = get_tensor(string_format(TN_YASA_STAGE_DOWN_LN, s, "weight"), false);
stage.down_ln_b = get_tensor(string_format(TN_YASA_STAGE_DOWN_LN, s, "bias"), false);
stage.down_conv_w = get_tensor(string_format(TN_YASA_STAGE_DOWN_CONV, s, "weight"), false);
stage.down_conv_b = get_tensor(string_format(TN_YASA_STAGE_DOWN_CONV, s, "bias"), false);
for (int bi = 0; ; ++bi) {
yasa2_block blk;
blk.dw_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "dw", "weight"), false);
if (!blk.dw_w) {
break;
}
blk.dw_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "dw", "bias"), false);
blk.ln_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "ln", "weight"), false);
blk.ln_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "ln", "bias"), false);
blk.pw1_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw1", "weight"), false);
blk.pw1_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw1", "bias"), false);
blk.grn_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "grn", "weight"), false);
blk.grn_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "grn", "bias"), false);
blk.pw2_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw2", "weight"), false);
blk.pw2_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw2", "bias"), false);
stage.blocks.push_back(blk);
}
if (!stage.down_conv_w && stage.blocks.empty()) {
break;
}
model.yasa_stages.push_back(std::move(stage));
}
} break;
case PROJECTOR_TYPE_GLM4V:
{
model.mm_fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
@ -2249,6 +2324,7 @@ struct clip_model_loader {
model.mm_eoi = get_tensor(TN_TOK_EOI);
} break;
case PROJECTOR_TYPE_HUNYUANOCR:
case PROJECTOR_TYPE_HUNYUANVL:
{
// proj.0 -> mm.0 (conv1), proj.2 -> mm.2 (conv2), mlp -> mm.model.fc (linear)
model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
@ -3062,6 +3138,19 @@ void setup_init_vision_shim_kcpp(struct clip_ctx * ctx_v) {
img_end = "<|vision_end|>";
image_preproc = std::make_unique<mtmd_image_preprocessor_youtuvl>(ctx_v);
} break;
case PROJECTOR_TYPE_YASA2:
{
img_beg = "<image>";
img_end = "</image>";
// Currently only supports single-tile preprocessing: any input is downscaled
// to one image_size x image_size tile (64 output tokens via 8x8 adaptive avg
// pool).
// However, the model itself supports llava-uhd multi-tile tiling for high-res
// images. This will be implemented in a future PR (dispatch on has_pinpoints
// - see LDP/COGVLM branch above) and emit image_grid_pinpoints in the conversion
// script.
image_preproc = std::make_unique<mtmd_image_preprocessor_fixed_size>(ctx_v);
} break;
case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_GEMMA3NV:
{
@ -3199,6 +3288,7 @@ void setup_init_vision_shim_kcpp(struct clip_ctx * ctx_v) {
image_preproc = std::make_unique<mtmd_image_preprocessor_deepseekocr>(ctx_v);
} break;
case PROJECTOR_TYPE_HUNYUANOCR:
case PROJECTOR_TYPE_HUNYUANVL:
{
// note: these use fullwidth (U+FF5C) and ▁ (U+2581) to match the tokenizer vocabulary
img_beg = "<hy_place▁holder▁no▁100>";
@ -3287,6 +3377,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 *
case PROJECTOR_TYPE_GLM4V:
case PROJECTOR_TYPE_PADDLEOCR:
case PROJECTOR_TYPE_HUNYUANOCR:
case PROJECTOR_TYPE_HUNYUANVL:
case PROJECTOR_TYPE_YOUTUVL:
return (img->nx / params.patch_size) / 2;
case PROJECTOR_TYPE_STEP3VL:
@ -3306,6 +3397,7 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 *
case PROJECTOR_TYPE_QWEN3VL:
case PROJECTOR_TYPE_GLM4V:
case PROJECTOR_TYPE_PADDLEOCR:
case PROJECTOR_TYPE_HUNYUANVL:
case PROJECTOR_TYPE_YOUTUVL:
return (img->ny / params.patch_size) / 2;
case PROJECTOR_TYPE_STEP3VL:
@ -3333,6 +3425,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
{
// do nothing
} break;
case PROJECTOR_TYPE_YASA2:
{
n_patches = 64; // adaptive average pooling to 8x8 tokens
} break;
case PROJECTOR_TYPE_LDP:
case PROJECTOR_TYPE_LDPV2:
case PROJECTOR_TYPE_GLM_EDGE:
@ -3493,6 +3589,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
n_patches = h * (h + 1) + 1;
} break;
case PROJECTOR_TYPE_HUNYUANOCR:
case PROJECTOR_TYPE_HUNYUANVL:
{
int merge = ctx->model.hparams.n_merge;
int ow = (img->nx / patch_size) / merge;
@ -3953,9 +4050,74 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
case PROJECTOR_TYPE_PHI4:
case PROJECTOR_TYPE_COGVLM:
case PROJECTOR_TYPE_HUNYUANOCR:
case PROJECTOR_TYPE_YASA2:
{
// do nothing
} break;
case PROJECTOR_TYPE_HUNYUANVL:
{
// Compute the HunyuanVL 2D position embedding on CPU (with the
// custom sf=(target+0.1)/n_grid bilinear sampling that the
// reference implementation uses) and upload it to the graph
// input declared in clip_graph_hunyuanocr::build().
GGML_ASSERT(model.position_embeddings != nullptr);
ggml_tensor * src_t = model.position_embeddings;
const int64_t n_embd = src_t->ne[0];
const int64_t n_pos = src_t->ne[1]; // = n_grid * n_grid
const int n_grid = (int)std::lround(std::sqrt((double)n_pos));
GGML_ASSERT((int64_t)n_grid * n_grid == n_pos);
const int out_w = pos_w; // pw
const int out_h = pos_h; // ph
// Pull weight to host.
std::vector<float> src(n_embd * n_pos);
ggml_backend_tensor_get(src_t, src.data(), 0, ggml_nbytes(src_t));
// Output layout matches ggml_new_tensor_2d(F32, n_embd, out_h*out_w):
// ne[0] = n_embd (fastest), ne[1] = out_h*out_w
// dst[(y*out_w + x) * n_embd + c]
std::vector<float> dst((size_t)n_embd * out_h * out_w);
const float sx = (float)(out_w + 0.1f) / (float)n_grid;
const float sy = (float)(out_h + 0.1f) / (float)n_grid;
for (int y = 0; y < out_h; ++y) {
// Match ggml_compute_forward_upscale_f32 pixel-center
// convention (align_corners=False): src_y = (y+0.5)/sy - 0.5.
const float fy = ((float)y + 0.5f) / sy - 0.5f;
int y0 = (int)std::floor(fy);
int y1 = y0 + 1;
y0 = std::clamp(y0, 0, n_grid - 1);
y1 = std::clamp(y1, 0, n_grid - 1);
float wy1 = std::clamp(fy - (float)y0, 0.0f, 1.0f);
const float wy0 = 1.0f - wy1;
for (int x = 0; x < out_w; ++x) {
const float fx = ((float)x + 0.5f) / sx - 0.5f;
int x0 = (int)std::floor(fx);
int x1 = x0 + 1;
x0 = std::clamp(x0, 0, n_grid - 1);
x1 = std::clamp(x1, 0, n_grid - 1);
float wx1 = std::clamp(fx - (float)x0, 0.0f, 1.0f);
const float wx0 = 1.0f - wx1;
const float w00 = wy0 * wx0;
const float w01 = wy0 * wx1;
const float w10 = wy1 * wx0;
const float w11 = wy1 * wx1;
const float * s00 = &src[((size_t)y0 * n_grid + x0) * n_embd];
const float * s01 = &src[((size_t)y0 * n_grid + x1) * n_embd];
const float * s10 = &src[((size_t)y1 * n_grid + x0) * n_embd];
const float * s11 = &src[((size_t)y1 * n_grid + x1) * n_embd];
float * d = &dst[((size_t)y * out_w + x) * n_embd];
for (int c = 0; c < n_embd; ++c) {
d[c] = w00 * s00[c] + w01 * s01[c] + w10 * s10[c] + w11 * s11[c];
}
}
}
set_input_f32("hunyuanvl_pos_embd", dst);
} break;
case PROJECTOR_TYPE_LLAMA4:
{
// set the 2D positions
@ -4376,8 +4538,10 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_PADDLEOCR:
case PROJECTOR_TYPE_KIMIK25:
case PROJECTOR_TYPE_YASA2:
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_HUNYUANOCR:
case PROJECTOR_TYPE_HUNYUANVL:
return ctx->model.mm_model_proj->ne[1];
case PROJECTOR_TYPE_COGVLM:
return ctx->model.mm_4h_to_h_w->ne[1];

View file

@ -5,7 +5,21 @@ ggml_cgraph * clip_graph_hunyuanocr::build() {
const int pw = n_patches_x;
const int ph = n_patches_y;
ggml_tensor * pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BILINEAR);
// Position embedding interpolation.
// HunyuanVL needs scale factors sf=(target+0.1)/n_grid, which the standard
// ggml_interpolate cannot express. To avoid adding a new ggml op, the
// resize is computed on CPU in clip_image_batch_encode and uploaded here
// as a graph input (named "hunyuanvl_pos_embd").
// HunyuanOCR uses the same square layout and the standard ratio-based
// interpolation provided by resize_position_embeddings().
ggml_tensor * pos_embd = nullptr;
if (proj_type == PROJECTOR_TYPE_HUNYUANVL && model.position_embeddings) {
pos_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ph * pw);
ggml_set_name(pos_embd, "hunyuanvl_pos_embd");
ggml_set_input(pos_embd);
} else {
pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BILINEAR);
}
ggml_tensor * inp = build_inp();
ggml_tensor * cur = build_vit(inp, n_patches, NORM_TYPE_NORMAL, hparams.ffn_op, pos_embd, nullptr);

View file

@ -43,6 +43,14 @@ struct clip_graph_youtuvl : clip_graph {
ggml_cgraph * build() override;
};
// Graph builder for the Yasa2 ConvNeXt-based vision encoder
// (implementation in models/yasa2.cpp).
struct clip_graph_yasa2 : clip_graph {
clip_graph_yasa2(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;
// channels-first LayerNorm over the C axis of a [W,H,C,B] tensor
ggml_tensor * layer_norm_channels(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b, float eps = 1e-6f);
// ConvNeXtV2 global response normalization, including the residual add
ggml_tensor * convnext_grn(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b);
};
struct clip_graph_minicpmv : clip_graph {
clip_graph_minicpmv(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;

191
tools/mtmd/models/yasa2.cpp Normal file
View file

@ -0,0 +1,191 @@
// ABOUTME: Yasa2 vision encoder graph builder for ConvNeXt-based architecture.
// ABOUTME: Implements patch embedding, ConvNeXt stages with GRN, and adaptive pooling.
#include "models.h"
// Broadcast-add a per-channel bias vector b_c (shape [C]) onto a
// [W,H,C,B] activation tensor. Returns the input unchanged when the
// bias tensor is absent (nullptr).
static ggml_tensor * add_channel_bias(
        ggml_context * ctx0,
        ggml_tensor  * x_whcb,
        ggml_tensor  * b_c) {
    if (b_c == nullptr) {
        return x_whcb;
    }
    // reshape [C] -> [1,1,C,1] so the add broadcasts over W, H and B
    ggml_tensor * bias4d = ggml_reshape_4d(ctx0, b_c, 1, 1, b_c->ne[0], 1);
    return ggml_add(ctx0, x_whcb, bias4d);
}
// Broadcast-multiply a per-channel scale vector w_c (shape [C]) into a
// [W,H,C,B] activation tensor. Returns the input unchanged when the
// scale tensor is absent (nullptr).
static ggml_tensor * mul_channel_weight(
        ggml_context * ctx0,
        ggml_tensor  * x_whcb,
        ggml_tensor  * w_c) {
    if (w_c == nullptr) {
        return x_whcb;
    }
    // reshape [C] -> [1,1,C,1] so the multiply broadcasts over W, H and B
    ggml_tensor * scale4d = ggml_reshape_4d(ctx0, w_c, 1, 1, w_c->ne[0], 1);
    return ggml_mul(ctx0, x_whcb, scale4d);
}
// Channels-first LayerNorm over the C axis of a [W,H,C,B] tensor.
// Match HF ConvNextLayerNorm(channels_first):
// u = mean_c(x), s = mean_c((x-u)^2), x = (x-u)/sqrt(s+eps)
// cast back to input dtype before affine.
// w and b are optional per-channel affine parameters (nullptr = skip).
ggml_tensor * clip_graph_yasa2::layer_norm_channels(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b, float eps) {
// ggml_mean reduces along ne[0], so move C into dim 0 first
ggml_tensor * cur = ggml_permute(ctx0, inp, 2, 1, 0, 3); // [W,H,C,B] -> [C,H,W,B]
cur = ggml_cont(ctx0, cur);
ggml_tensor * u = ggml_mean(ctx0, cur); // per-position channel mean, [1,H,W,B]
ggml_tensor * xm = ggml_sub(ctx0, cur, u); // centered, [C,H,W,B]
ggml_tensor * s = ggml_mul(ctx0, xm, xm); // [C,H,W,B]
s = ggml_mean(ctx0, s); // variance, [1,H,W,B]
// NOTE: clamp approximates the reference's s+eps while also staying
// finite when buffers are uninitialized during no-alloc warmup
s = ggml_clamp(ctx0, s, eps, 1e30f); // avoid div-by-zero in no-alloc warmup
s = ggml_sqrt(ctx0, s); // std-dev, [1,H,W,B]
ggml_tensor * xhat = ggml_div(ctx0, xm, s); // normalized, [C,H,W,B]
xhat = ggml_permute(ctx0, xhat, 2, 1, 0, 3); // back to [W,H,C,B]
xhat = ggml_cont(ctx0, xhat);
// optional per-channel affine: scale then shift
xhat = mul_channel_weight(ctx0, xhat, w);
xhat = add_channel_bias(ctx0, xhat, b);
return xhat;
}
// Global Response Normalization (ConvNeXtV2) over a [W,H,C,B] tensor.
// Exact ConvNeXtV2 GRN:
// Gx = ||x||_2 over spatial dims (W,H), Nx = Gx / (mean_c(Gx) + eps)
// y = w * (x * Nx) + b + x
// w and b are optional per-channel affine parameters (nullptr = skip);
// the residual add of the input is always applied.
ggml_tensor * clip_graph_yasa2::convnext_grn(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b) {
const int64_t wdim = inp->ne[0];
const int64_t hdim = inp->ne[1];
const int64_t cdim = inp->ne[2];
const int64_t bdim = inp->ne[3];
// Keep GRN math in fp32 for stability; fp16/bf16 accumulation can drift.
// Gx: per-channel L2 norm over all spatial positions
ggml_tensor * sq = ggml_mul(ctx0, inp, inp);
ggml_tensor * sq_flat = ggml_reshape_4d(ctx0, sq, wdim * hdim, cdim, 1, bdim); // [WH,C,1,B]
ggml_tensor * gx = ggml_sum_rows(ctx0, sq_flat); // [1,C,1,B]
gx = ggml_sqrt(ctx0, gx); // [1,C,1,B]
// mean of Gx across channels (ggml_mean reduces dim 0, so put C there)
ggml_tensor * gx_ch_first = ggml_permute(ctx0, gx, 1, 0, 2, 3); // [C,1,1,B]
gx_ch_first = ggml_cont(ctx0, gx_ch_first);
ggml_tensor * gx_mean = ggml_mean(ctx0, gx_ch_first); // [1,1,1,B]
gx_mean = ggml_clamp(ctx0, gx_mean, 1e-6f, 1e30f); // approx +eps, warmup-safe
// Nx broadcast back over the spatial dims
ggml_tensor * nx = ggml_div(ctx0, gx, gx_mean); // [1,C,1,B]
nx = ggml_permute(ctx0, nx, 0, 2, 1, 3); // [1,1,C,B]
nx = ggml_cont(ctx0, nx);
// y = w * (x * Nx) + b + x
ggml_tensor * xnx = ggml_mul(ctx0, inp, nx);
xnx = mul_channel_weight(ctx0, xnx, w);
xnx = add_channel_bias(ctx0, xnx, b);
return ggml_add(ctx0, inp, xnx);
}
// Build the full Yasa2 vision graph:
// raw image -> patch-embed conv -> ConvNeXt stages -> (+pos embed)
// -> adaptive avg-pool to 8x8 tokens -> 2-layer MLP projector.
// Returns the forward graph with the projected embeddings as output.
ggml_cgraph * clip_graph_yasa2::build() {
ggml_tensor * cur = build_inp_raw();
// Patch embedding Conv2d(kernel=4, stride=4)
cur = ggml_conv_2d(ctx0, model.yasa_patch_w, cur, patch_size, patch_size, 0, 0, 1, 1);
cur = add_channel_bias(ctx0, cur, model.yasa_patch_b);
ggml_set_name(cur, "yasa2_patch_conv_out");
cb(cur, "yasa2_patch_conv_out", -1);
cur = layer_norm_channels(cur, model.yasa_patch_ln_w, model.yasa_patch_ln_b, eps);
ggml_set_name(cur, "yasa2_patch_ln_out");
cb(cur, "yasa2_patch_ln_out", -1);
// ConvNeXt stages
for (size_t s = 0; s < model.yasa_stages.size(); ++s) {
const auto & stage = model.yasa_stages[s];
// optional stage downsample: LN then stride-2 conv halves W and H
if (stage.down_conv_w) {
cur = layer_norm_channels(cur, stage.down_ln_w, stage.down_ln_b, eps);
cur = ggml_conv_2d(ctx0, stage.down_conv_w, cur, 2, 2, 0, 0, 1, 1);
cur = add_channel_bias(ctx0, cur, stage.down_conv_b);
ggml_format_name(cur, "yasa2_stage%zu_down_out", s);
}
for (size_t bi = 0; bi < stage.blocks.size(); ++bi) {
const auto & blk = stage.blocks[bi];
ggml_tensor * res = cur; // saved for the residual add below
// depthwise conv (pad 3 keeps W,H) then channels-first LN
ggml_tensor * x = ggml_conv_2d_dw(ctx0, blk.dw_w, cur, 1, 1, 3, 3, 1, 1);
x = add_channel_bias(ctx0, x, blk.dw_b);
x = layer_norm_channels(x, blk.ln_w, blk.ln_b, eps);
// pwconv1/pwconv2 are HF Linear layers over channels; implement via matmul on tokens.
const int64_t w = x->ne[0];
const int64_t h = x->ne[1];
const int64_t b = x->ne[3];
ggml_tensor * tok = ggml_reshape_3d(ctx0, x, w * h, x->ne[2], b); // [T,C,B]
tok = ggml_permute(ctx0, tok, 1, 0, 2, 3); // [C,T,B]
tok = ggml_cont(ctx0, tok);
tok = ggml_mul_mat(ctx0, blk.pw1_w, tok); // [4C,T,B]
if (blk.pw1_b) {
ggml_tensor * b1 = ggml_reshape_3d(ctx0, blk.pw1_b, blk.pw1_b->ne[0], 1, 1); // [4C,1,1]
tok = ggml_add(ctx0, tok, b1);
}
// back to spatial layout for the activation and GRN
x = ggml_permute(ctx0, tok, 1, 0, 2, 3); // [T,4C,B]
x = ggml_cont(ctx0, x);
x = ggml_reshape_4d(ctx0, x, w, h, tok->ne[0], b); // [W,H,4C,B]
x = ggml_gelu_erf(ctx0, x);
x = convnext_grn(x, blk.grn_w, blk.grn_b);
// project 4C -> C with pwconv2, again as a matmul over tokens
tok = ggml_reshape_3d(ctx0, x, w * h, x->ne[2], b); // [T,4C,B]
tok = ggml_permute(ctx0, tok, 1, 0, 2, 3); // [4C,T,B]
tok = ggml_cont(ctx0, tok);
tok = ggml_mul_mat(ctx0, blk.pw2_w, tok); // [C,T,B]
if (blk.pw2_b) {
ggml_tensor * b2 = ggml_reshape_3d(ctx0, blk.pw2_b, blk.pw2_b->ne[0], 1, 1); // [C,1,1]
tok = ggml_add(ctx0, tok, b2);
}
x = ggml_permute(ctx0, tok, 1, 0, 2, 3); // [T,C,B]
x = ggml_cont(ctx0, x);
x = ggml_reshape_4d(ctx0, x, w, h, tok->ne[0], b); // [W,H,C,B]
// residual connection around the whole block
cur = ggml_add(ctx0, res, x);
ggml_format_name(cur, "yasa2_stage%zu_blk%zu_out", s, bi);
}
}
// HF path adds vision position embeddings BEFORE adaptive pooling.
const int64_t pre_w = cur->ne[0];
const int64_t pre_h = cur->ne[1];
ggml_tensor * tokens_pre = ggml_reshape_3d(ctx0, cur, pre_w * pre_h, cur->ne[2], cur->ne[3]); // [T,C,B]
tokens_pre = ggml_permute(ctx0, tokens_pre, 1, 0, 2, 3); // [C,T,B]
tokens_pre = ggml_cont(ctx0, tokens_pre);
// only add when the token count matches (warmup may use tiny images)
if (model.yasa_vision_pos_embed && tokens_pre->ne[1] == model.yasa_vision_pos_embed->ne[1]) {
const int64_t n_ch = model.yasa_vision_pos_embed->ne[0];
const int64_t n_tokens = model.yasa_vision_pos_embed->ne[1];
ggml_tensor * pos = ggml_reshape_3d(ctx0, model.yasa_vision_pos_embed, (int) n_ch, (int) n_tokens, 1);
tokens_pre = ggml_add(ctx0, tokens_pre, pos);
}
cur = ggml_permute(ctx0, tokens_pre, 1, 0, 2, 3); // [T,C,B]
cur = ggml_cont(ctx0, cur);
cur = ggml_reshape_4d(ctx0, cur, pre_w, pre_h, cur->ne[1], cur->ne[2]); // [W,H,C,B]
// AdaptiveAvgPool2d target is 8x8 for real inputs, but warmup can use tiny images.
const int pooled_w = std::min(8, (int) cur->ne[0]);
const int pooled_h = std::min(8, (int) cur->ne[1]);
const int kw = std::max(1, (int) cur->ne[0] / pooled_w);
const int kh = std::max(1, (int) cur->ne[1] / pooled_h);
cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, kw, kh, kw, kh, 0, 0);
// [W,H,C,B] -> [C,T,B]
ggml_tensor * tokens = ggml_reshape_3d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2], cur->ne[3]);
tokens = ggml_permute(ctx0, tokens, 1, 0, 2, 3);
tokens = ggml_cont(ctx0, tokens);
cb(tokens, "yasa2_tokens", -1);
// final projector: Linear -> GELU(erf) -> Linear into the LLM embed dim
GGML_ASSERT(model.mm_0_w && model.mm_2_w);
ggml_tensor * embeddings = build_ffn(
tokens,
model.mm_0_w, model.mm_0_b,
nullptr, nullptr,
model.mm_2_w, model.mm_2_b,
FFN_GELU_ERF,
-1);
cb(embeddings, "yasa2_emb", -1);
ggml_build_forward_expand(gf, embeddings);
return gf;
}

View file

@ -35,15 +35,23 @@ struct mtmd_bitmap {
// position indexing for decoder model
enum mtmd_pos_type {
MTMD_POS_TYPE_NORMAL, // number of positions equals to number of tokens
MTMD_POS_TYPE_MROPE, // qwen-vl mrope style, each image takes max(t,h,w) position indexes
MTMD_POS_TYPE_NORMAL, // number of positions equals to number of tokens
MTMD_POS_TYPE_MROPE, // qwen-vl mrope style, each image takes max(t,h,w) position indexes
MTMD_POS_TYPE_HUNYUANVL, // HunyuanVL mrope + BOI/EOI/newline layout with XD-RoPE dim-3
};
struct mtmd_image_tokens {
uint32_t nx; // number of tokens in x direction
uint32_t ny; // number of tokens in y direction
mtmd_pos_type pos = MTMD_POS_TYPE_NORMAL;
uint32_t n_tokens() const { return nx * ny; }
uint32_t image_idx = 0; // 0-based position of this image among image chunks in the prompt (used by pos == MTMD_POS_TYPE_HUNYUANVL)
// Total number of embedding tokens this image occupies in the decoder
// stream. For the HunyuanVL layout the grid is wrapped with BOI/EOI
// markers plus one newline token per row; every other layout is a
// plain nx*ny grid.
uint32_t n_tokens() const {
if (pos == MTMD_POS_TYPE_HUNYUANVL) {
// [BOI] [row0 tokens + newline] ... [row(ny-1) tokens + newline] [EOI]
return (nx + 1) * ny + 2;
}
return nx * ny;
}
clip_image_f32_batch batch_f32; // preprocessed image patches
std::string id; // optional user-defined ID, useful for KV cache tracking
@ -52,6 +60,7 @@ struct mtmd_image_tokens {
nx,
ny,
pos,
image_idx,
batch_f32.clone(),
id
};
@ -186,6 +195,7 @@ struct mtmd_context {
auto decoder_rope_type = llama_model_rope_type(text_model);
switch (decoder_rope_type) {
case LLAMA_ROPE_TYPE_NONE:
case LLAMA_ROPE_TYPE_NORM:
case LLAMA_ROPE_TYPE_NEOX:
{
@ -316,6 +326,19 @@ struct mtmd_context {
img_end = "<|vision_end|>";
image_preproc = std::make_unique<mtmd_image_preprocessor_youtuvl>(ctx_v);
} break;
case PROJECTOR_TYPE_YASA2:
{
img_beg = "<image>";
img_end = "</image>";
// Currently only supports single-tile preprocessing: any input is downscaled
// to one image_size x image_size tile (64 output tokens via 8x8 adaptive avg
// pool).
// However, the model itself supports llava-uhd multi-tile tiling for high-res
// images. This will be implemented in a future PR (dispatch on has_pinpoints
// - see LDP/COGVLM branch above) and emit image_grid_pinpoints in the conversion
// script.
image_preproc = std::make_unique<mtmd_image_preprocessor_fixed_size>(ctx_v);
} break;
case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_GEMMA3NV:
{
@ -453,6 +476,7 @@ struct mtmd_context {
image_preproc = std::make_unique<mtmd_image_preprocessor_deepseekocr>(ctx_v);
} break;
case PROJECTOR_TYPE_HUNYUANOCR:
case PROJECTOR_TYPE_HUNYUANVL:
{
// note: these use fullwidth (U+FF5C) and ▁ (U+2581) to match the tokenizer vocabulary
img_beg = "<hy_place▁holder▁no▁100>";
@ -598,6 +622,7 @@ struct mtmd_tokenizer {
const llama_vocab * vocab;
mtmd_input_chunks cur;
uint32_t n_images_added = 0; // 0-based index assigned to the next image chunk
mtmd_tokenizer(mtmd_context * ctx,
const mtmd_input_text * text,
@ -806,6 +831,14 @@ struct mtmd_tokenizer {
image_tokens->ny = 1;
}
image_tokens->pos = ctx->pos_type;
// HunyuanVL wraps the image grid with BOI/EOI and adds one newline per row,
// and uses XD-RoPE (dim-3 = image index). Override the position type so that
// n_tokens() and mtmd_image_tokens_get_decoder_pos pick the HunyuanVL layout.
if (ctx->proj_type_v() == PROJECTOR_TYPE_HUNYUANVL) {
image_tokens->pos = MTMD_POS_TYPE_HUNYUANVL;
image_tokens->image_idx = n_images_added;
GGML_ASSERT(n_tokens == (size_t)image_tokens->n_tokens());
}
image_tokens->batch_f32 = std::move(batch_f32);
image_tokens->id = bitmap->id; // optional
@ -826,6 +859,9 @@ struct mtmd_tokenizer {
add_text(ctx->img_end, true); // add image end token
}
// advance image-chunk counter so the next image gets the next XD-RoPE dim-3 slot
n_images_added++;
} else {
// handle audio
@ -1273,6 +1309,38 @@ mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * ima
pos.y = pos_0 + i;
pos.z = pos_0 + i;
} break;
case MTMD_POS_TYPE_HUNYUANVL:
{
// HunyuanVL layout: [BOI] [row0 tokens + newline] ... [row(ny-1) tokens + newline] [EOI]
// Total = 1 + ny*(nx+1) + 1. BOI and EOI use sequential positions in every dim;
// content and row-newline tokens use (row, col) with XD-RoPE dim-3 = image_idx.
const uint32_t nx = image_tokens->nx;
const uint32_t n_total = image_tokens->n_tokens();
if (i == 0) {
// BOI
pos.t = pos_0 + i;
pos.x = pos_0 + i;
pos.y = pos_0 + i;
pos.z = pos_0 + i;
} else if (i == n_total - 1) {
// EOI
pos.t = pos_0 + i;
pos.x = pos_0 + i;
pos.y = pos_0 + i;
pos.z = pos_0 + i;
} else {
// content token at (row, col), or the trailing newline of a row (col == nx)
// section 0 = sequential, section 1 = w(col), section 2 = h(row), section 3 = image_count.
// set_position_mrope_2d writes .y -> section 1 and .x -> section 2
const uint32_t offset = (uint32_t)i - 1;
const uint32_t row = offset / (nx + 1);
const uint32_t col = offset % (nx + 1);
pos.t = pos_0 + i;
pos.x = row;
pos.y = col;
pos.z = image_tokens->image_idx;
}
} break;
default:
GGML_ABORT("invalid position type");
}
@ -1289,6 +1357,10 @@ llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
return std::max(image_tokens->nx, image_tokens->ny);
case MTMD_POS_TYPE_NORMAL:
return image_tokens->n_tokens();
case MTMD_POS_TYPE_HUNYUANVL:
// HunyuanVL: the sequential (dim-0) position advances by the full token count
// (includes BOI/EOI and row newline tokens), not by max(nx, ny)
return image_tokens->n_tokens();
default:
GGML_ABORT("invalid position type");
}

View file

@ -91,6 +91,7 @@ add_test_vision "ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0"
add_test_vision "ggml-org/DeepSeek-OCR-GGUF:Q8_0" -p "Free OCR." --chat-template deepseek-ocr
add_test_vision "ggml-org/dots.ocr-GGUF:Q8_0" -p "OCR"
add_test_vision "ggml-org/HunyuanOCR-GGUF:Q8_0" -p "OCR"
add_test_vision "ggml-org/HunyuanVL-4B-GGUF:Q8_0"
add_test_vision "ggml-org/gemma-4-E2B-it-GGUF:Q8_0" --jinja
add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"

View file

@ -0,0 +1,588 @@
#include "server-chat.h"
#include "server-common.h"
#include <sstream>
// Translate an OpenAI /v1/responses request body into the equivalent
// /v1/chat/completions body:
//  - 'instructions'       -> leading system message
//  - 'input' (string)     -> single user message
//  - 'input' (item list)  -> messages; consecutive assistant-side items
//    (output message / function_call / reasoning) are merged into one
//    assistant message so tool calls stay attached to their text
//  - 'tools'              -> chat-completions tool schema
//  - 'max_output_tokens'  -> 'max_tokens'
// All other fields pass through unchanged. Throws std::invalid_argument
// on unsupported or malformed input.
json server_chat_convert_responses_to_chatcmpl(const json & response_body) {
if (!response_body.contains("input")) {
throw std::invalid_argument("'input' is required");
}
// stateless server: previous responses are not stored, so this cannot work
if (!json_value(response_body, "previous_response_id", std::string{}).empty()) {
throw std::invalid_argument("llama.cpp does not support 'previous_response_id'.");
}
const json input_value = response_body.at("input");
// start from a copy so unrelated fields (sampling params, etc.) pass through
json chatcmpl_body = response_body;
chatcmpl_body.erase("input");
std::vector<json> chatcmpl_messages;
if (response_body.contains("instructions")) {
chatcmpl_messages.push_back({
{"role", "system"},
{"content", json_value(response_body, "instructions", std::string())},
});
chatcmpl_body.erase("instructions");
}
if (input_value.is_string()) {
// #responses_create-input-text_input
chatcmpl_messages.push_back({
{"role", "user"},
{"content", input_value},
});
} else if (input_value.is_array()) {
// #responses_create-input-input_item_list
static auto exists_and_is_array = [](const json & j, const char * key) -> bool {
return j.contains(key) && j.at(key).is_array();
};
static auto exists_and_is_string = [](const json & j, const char * key) -> bool {
return j.contains(key) && j.at(key).is_string();
};
// note: 'item' is taken by value on purpose — it is mutated in place below
for (json item : input_value) {
// assistant-side items directly following an assistant message get merged into it
bool merge_prev = !chatcmpl_messages.empty() && chatcmpl_messages.back().value("role", "") == "assistant";
if (exists_and_is_string(item, "content")) {
// #responses_create-input-input_item_list-input_message-content-text_input
// Only "Input message" contains item["content"]::string
// After converting item["content"]::string to item["content"]::array,
// we can treat "Input message" as sum of "Item-Input message" and "Item-Output message"
item["content"] = json::array({
json {
{"text", item.at("content")},
{"type", "input_text"}
}
});
}
if (exists_and_is_array(item, "content") &&
exists_and_is_string(item, "role") &&
(item.at("role") == "user" ||
item.at("role") == "system" ||
item.at("role") == "developer")
) {
// #responses_create-input-input_item_list-item-input_message
std::vector<json> chatcmpl_content;
for (const json & input_item : item.at("content")) {
const std::string type = json_value(input_item, "type", std::string());
if (type == "input_text") {
if (!input_item.contains("text")) {
throw std::invalid_argument("'Input text' requires 'text'");
}
chatcmpl_content.push_back({
{"text", input_item.at("text")},
{"type", "text"},
});
} else if (type == "input_image") {
// While `detail` is marked as required,
// it has default value("auto") and can be omitted.
if (!input_item.contains("image_url")) {
throw std::invalid_argument("'image_url' is required");
}
chatcmpl_content.push_back({
{"image_url", json {
{"url", input_item.at("image_url")}
}},
{"type", "image_url"},
});
} else if (type == "input_file") {
throw std::invalid_argument("'input_file' is not supported by llamacpp at this moment");
} else {
throw std::invalid_argument("'type' must be one of 'input_text', 'input_image', or 'input_file'");
}
}
// strip responses-API bookkeeping fields not valid in chat-completions
if (item.contains("type")) {
item.erase("type");
}
if (item.contains("status")) {
item.erase("status");
}
item["content"] = chatcmpl_content;
chatcmpl_messages.push_back(item);
} else if (exists_and_is_string(item, "role") &&
item.at("role") == "assistant" &&
exists_and_is_string(item, "type") &&
item.at("type") == "message"
) {
// #responses_create-input-input_item_list-item-output_message
auto chatcmpl_content = json::array();
// Handle both string content and array content
if (item.contains("content") && item.at("content").is_string()) {
// String content - convert to text content part
chatcmpl_content.push_back({
{"text", item.at("content")},
{"type", "text"},
});
} else if (exists_and_is_array(item, "content")) {
// Array content - process each item
for (const auto & output_text : item.at("content")) {
const std::string type = json_value(output_text, "type", std::string());
if (type == "output_text" || type == "input_text") {
// Accept both output_text and input_text (string content gets converted to input_text)
if (!exists_and_is_string(output_text, "text")) {
throw std::invalid_argument("'Output text' requires 'text'");
}
chatcmpl_content.push_back({
{"text", output_text.at("text")},
{"type", "text"},
});
} else if (type == "refusal") {
if (!exists_and_is_string(output_text, "refusal")) {
throw std::invalid_argument("'Refusal' requires 'refusal'");
}
chatcmpl_content.push_back({
{"refusal", output_text.at("refusal")},
{"type", "refusal"},
});
} else {
throw std::invalid_argument("'type' must be one of 'output_text' or 'refusal'");
}
}
}
if (merge_prev) {
// append content parts onto the previous assistant message
auto & prev_msg = chatcmpl_messages.back();
if (!exists_and_is_array(prev_msg, "content")) {
prev_msg["content"] = json::array();
}
auto & prev_content = prev_msg["content"];
prev_content.insert(prev_content.end(), chatcmpl_content.begin(), chatcmpl_content.end());
} else {
item.erase("status");
item.erase("type");
item["content"] = chatcmpl_content;
chatcmpl_messages.push_back(item);
}
} else if (exists_and_is_string(item, "arguments") &&
exists_and_is_string(item, "call_id") &&
exists_and_is_string(item, "name") &&
exists_and_is_string(item, "type") &&
item.at("type") == "function_call"
) {
// #responses_create-input-input_item_list-item-function_tool_call
json tool_call = {
{"function", json {
{"arguments", item.at("arguments")},
{"name", item.at("name")},
}},
{"id", item.at("call_id")},
{"type", "function"},
};
if (merge_prev) {
// attach the call to the previous assistant message's tool_calls
auto & prev_msg = chatcmpl_messages.back();
if (!exists_and_is_array(prev_msg, "tool_calls")) {
prev_msg["tool_calls"] = json::array();
}
prev_msg["tool_calls"].push_back(tool_call);
} else {
chatcmpl_messages.push_back(json {
{"role", "assistant"},
{"tool_calls", json::array({tool_call})}
});
}
} else if (exists_and_is_string(item, "call_id") &&
(exists_and_is_string(item, "output") || exists_and_is_array(item, "output")) &&
exists_and_is_string(item, "type") &&
item.at("type") == "function_call_output"
) {
// #responses_create-input-input_item_list-item-function_tool_call_output
if (item.at("output").is_string()) {
chatcmpl_messages.push_back(json {
{"content", item.at("output")},
{"role", "tool"},
{"tool_call_id", item.at("call_id")},
});
} else {
// structured output: each part must be 'input_text', re-typed to 'text'
json chatcmpl_outputs = item.at("output");
for (json & chatcmpl_output : chatcmpl_outputs) {
if (!chatcmpl_output.contains("type") || chatcmpl_output.at("type") != "input_text") {
throw std::invalid_argument("Output of tool call should be 'Input text'");
}
chatcmpl_output["type"] = "text";
}
chatcmpl_messages.push_back(json {
{"content", chatcmpl_outputs},
{"role", "tool"},
{"tool_call_id", item.at("call_id")},
});
}
} else if (exists_and_is_array(item, "summary") &&
exists_and_is_string(item, "type") &&
item.at("type") == "reasoning") {
// #responses_create-input-input_item_list-item-reasoning
// only the first content entry's text is carried over as reasoning_content
if (!exists_and_is_array(item, "content")) {
throw std::invalid_argument("item['content'] is not an array");
}
if (item.at("content").empty()) {
throw std::invalid_argument("item['content'] is empty");
}
if (!exists_and_is_string(item.at("content")[0], "text")) {
throw std::invalid_argument("item['content']['text'] is not a string");
}
if (merge_prev) {
auto & prev_msg = chatcmpl_messages.back();
prev_msg["reasoning_content"] = item.at("content")[0].at("text");
} else {
chatcmpl_messages.push_back(json {
{"role", "assistant"},
{"content", json::array()},
{"reasoning_content", item.at("content")[0].at("text")},
});
}
} else {
throw std::invalid_argument("Cannot determine type of 'item'");
}
}
} else {
throw std::invalid_argument("'input' must be a string or array of objects");
}
chatcmpl_body["messages"] = chatcmpl_messages;
if (response_body.contains("tools")) {
if (!response_body.at("tools").is_array()) {
throw std::invalid_argument("'tools' must be an array of objects");
}
// responses tools are flat; chat-completions nests them under "function"
std::vector<json> chatcmpl_tools;
for (json resp_tool : response_body.at("tools")) {
json chatcmpl_tool;
if (json_value(resp_tool, "type", std::string()) != "function") {
throw std::invalid_argument("'type' of tool must be 'function'");
}
resp_tool.erase("type");
chatcmpl_tool["type"] = "function";
// strict mode defaults to true when not specified
if (!resp_tool.contains("strict")) {
resp_tool["strict"] = true;
}
chatcmpl_tool["function"] = resp_tool;
chatcmpl_tools.push_back(chatcmpl_tool);
}
chatcmpl_body.erase("tools");
chatcmpl_body["tools"] = chatcmpl_tools;
}
// rename the token budget field to the chat-completions name
if (response_body.contains("max_output_tokens")) {
chatcmpl_body.erase("max_output_tokens");
chatcmpl_body["max_tokens"] = response_body["max_output_tokens"];
}
return chatcmpl_body;
}
// Convert an Anthropic Messages API request body into an OpenAI Chat
// Completions request body. Supported inputs: "system" (string or array of
// text blocks), "messages" with content blocks of type text / thinking /
// image / tool_use / tool_result, "tools", "tool_choice", "stop_sequences",
// "max_tokens", common sampling params, and the Anthropic-specific
// "thinking" and "metadata" params.
// Throws std::runtime_error when 'messages' is missing.
json server_chat_convert_anthropic_to_oai(const json & body) {
    json oai_body;
    // Convert system prompt
    // "system" may be a string or an array of {type:"text", text:...} blocks;
    // text blocks are concatenated into a single OAI system message.
    json oai_messages = json::array();
    auto system_param = json_value(body, "system", json());
    if (!system_param.is_null()) {
        std::string system_content;
        if (system_param.is_string()) {
            system_content = system_param.get<std::string>();
        } else if (system_param.is_array()) {
            for (const auto & block : system_param) {
                if (json_value(block, "type", std::string()) == "text") {
                    system_content += json_value(block, "text", std::string());
                }
            }
        }
        oai_messages.push_back({
            {"role", "system"},
            {"content", system_content}
        });
    }
    // Convert messages
    if (!body.contains("messages")) {
        throw std::runtime_error("'messages' is required");
    }
    const json & messages = body.at("messages");
    if (messages.is_array()) {
        for (const auto & msg : messages) {
            std::string role = json_value(msg, "role", std::string());
            if (!msg.contains("content")) {
                // an assistant message with no content carries nothing useful; drop it
                if (role == "assistant") {
                    continue;
                }
                oai_messages.push_back(msg);
                continue;
            }
            const json & content = msg.at("content");
            // string content is already OAI-compatible; pass the message through as-is
            if (content.is_string()) {
                oai_messages.push_back(msg);
                continue;
            }
            // non-string, non-array content: pass through unchanged (be permissive)
            if (!content.is_array()) {
                oai_messages.push_back(msg);
                continue;
            }
            // content is an array of Anthropic blocks; split it into OAI
            // content parts, tool calls, reasoning text, and tool results
            json tool_calls = json::array();
            json converted_content = json::array();
            json tool_results = json::array(); // emitted afterwards as separate "tool" role messages
            std::string reasoning_content;
            bool has_tool_calls = false;
            for (const auto & block : content) {
                std::string type = json_value(block, "type", std::string());
                if (type == "text") {
                    converted_content.push_back(block);
                } else if (type == "thinking") {
                    // Anthropic "thinking" blocks map to OAI "reasoning_content"
                    reasoning_content += json_value(block, "thinking", std::string());
                } else if (type == "image") {
                    json source = json_value(block, "source", json::object());
                    std::string source_type = json_value(source, "type", std::string());
                    if (source_type == "base64") {
                        // embed the image bytes as a data: URI in OAI image_url format
                        std::string media_type = json_value(source, "media_type", std::string("image/jpeg"));
                        std::string data = json_value(source, "data", std::string());
                        std::ostringstream ss;
                        ss << "data:" << media_type << ";base64," << data;
                        converted_content.push_back({
                            {"type", "image_url"},
                            {"image_url", {
                                {"url", ss.str()}
                            }}
                        });
                    } else if (source_type == "url") {
                        std::string url = json_value(source, "url", std::string());
                        converted_content.push_back({
                            {"type", "image_url"},
                            {"image_url", {
                                {"url", url}
                            }}
                        });
                    }
                    // NOTE: other source types are silently ignored
                } else if (type == "tool_use") {
                    // Anthropic tool_use -> OAI tool_calls entry;
                    // "input" is a JSON object, OAI expects a serialized string
                    tool_calls.push_back({
                        {"id", json_value(block, "id", std::string())},
                        {"type", "function"},
                        {"function", {
                            {"name", json_value(block, "name", std::string())},
                            {"arguments", json_value(block, "input", json::object()).dump()}
                        }}
                    });
                    has_tool_calls = true;
                } else if (type == "tool_result") {
                    // tool_result blocks become standalone OAI "tool" messages;
                    // their content (string or text-block array) is flattened to text
                    std::string tool_use_id = json_value(block, "tool_use_id", std::string());
                    auto result_content = json_value(block, "content", json());
                    std::string result_text;
                    if (result_content.is_string()) {
                        result_text = result_content.get<std::string>();
                    } else if (result_content.is_array()) {
                        for (const auto & c : result_content) {
                            if (json_value(c, "type", std::string()) == "text") {
                                result_text += json_value(c, "text", std::string());
                            }
                        }
                    }
                    tool_results.push_back({
                        {"role", "tool"},
                        {"tool_call_id", tool_use_id},
                        {"content", result_text}
                    });
                }
            }
            // emit the converted message (only if it carries anything), then
            // append any tool-result messages after it
            if (!converted_content.empty() || has_tool_calls || !reasoning_content.empty()) {
                json new_msg = {{"role", role}};
                if (!converted_content.empty()) {
                    new_msg["content"] = converted_content;
                } else if (has_tool_calls || !reasoning_content.empty()) {
                    // a "content" field must be present; use an empty string
                    new_msg["content"] = "";
                }
                if (!tool_calls.empty()) {
                    new_msg["tool_calls"] = tool_calls;
                }
                if (!reasoning_content.empty()) {
                    new_msg["reasoning_content"] = reasoning_content;
                }
                oai_messages.push_back(new_msg);
            }
            for (const auto & tool_msg : tool_results) {
                oai_messages.push_back(tool_msg);
            }
        }
    }
    oai_body["messages"] = oai_messages;
    // Convert tools
    // Anthropic {name, description, input_schema} -> OAI function-tool format
    if (body.contains("tools")) {
        const json & tools = body.at("tools");
        if (tools.is_array()) {
            json oai_tools = json::array();
            for (const auto & tool : tools) {
                oai_tools.push_back({
                    {"type", "function"},
                    {"function", {
                        {"name", json_value(tool, "name", std::string())},
                        {"description", json_value(tool, "description", std::string())},
                        {"parameters", tool.contains("input_schema") ? tool.at("input_schema") : json::object()}
                    }}
                });
            }
            oai_body["tools"] = oai_tools;
        }
    }
    // Convert tool_choice
    // both "any" and "tool" map to OAI "required" (no per-tool forcing here)
    if (body.contains("tool_choice")) {
        const json & tc = body.at("tool_choice");
        if (tc.is_object()) {
            std::string type = json_value(tc, "type", std::string());
            if (type == "auto") {
                oai_body["tool_choice"] = "auto";
            } else if (type == "any" || type == "tool") {
                oai_body["tool_choice"] = "required";
            }
        }
    }
    // Convert stop_sequences to stop
    if (body.contains("stop_sequences")) {
        oai_body["stop"] = body.at("stop_sequences");
    }
    // Handle max_tokens (required in Anthropic, but we're permissive)
    if (body.contains("max_tokens")) {
        oai_body["max_tokens"] = body.at("max_tokens");
    } else {
        oai_body["max_tokens"] = 4096;
    }
    // Pass through common params
    for (const auto & key : {"temperature", "top_p", "top_k", "stream"}) {
        if (body.contains(key)) {
            oai_body[key] = body.at(key);
        }
    }
    // Handle Anthropic-specific thinking param
    // "thinking": {"type": "enabled", "budget_tokens": N} -> thinking_budget_tokens
    if (body.contains("thinking")) {
        json thinking = json_value(body, "thinking", json::object());
        std::string thinking_type = json_value(thinking, "type", std::string());
        if (thinking_type == "enabled") {
            int budget_tokens = json_value(thinking, "budget_tokens", 10000);
            oai_body["thinking_budget_tokens"] = budget_tokens;
        }
    }
    // Handle Anthropic-specific metadata param
    // stash user_id under a private key for downstream consumers
    if (body.contains("metadata")) {
        json metadata = json_value(body, "metadata", json::object());
        std::string user_id = json_value(metadata, "user_id", std::string());
        if (!user_id.empty()) {
            oai_body["__metadata_user_id"] = user_id;
        }
    }
    return oai_body;
}
// Render a streaming message diff as an OAI chat-completions "delta" object.
// Only fields that actually changed are emitted: reasoning_content, content,
// and (when a tool call is in progress) a single-element tool_calls array.
json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
    json out = json::object();
    if (!diff.reasoning_content_delta.empty()) {
        out["reasoning_content"] = diff.reasoning_content_delta;
    }
    if (!diff.content_delta.empty()) {
        out["content"] = diff.content_delta;
    }
    // npos marks "no tool call touched by this diff"
    const bool has_tool_call = diff.tool_call_index != std::string::npos;
    if (has_tool_call) {
        const auto & tc = diff.tool_call_delta;
        json tc_json;
        tc_json["index"] = diff.tool_call_index;
        // id/type are only sent once, when the call first appears
        if (!tc.id.empty()) {
            tc_json["id"]   = tc.id;
            tc_json["type"] = "function";
        }
        const bool has_fn_delta = !tc.name.empty() || !tc.arguments.empty();
        if (has_fn_delta) {
            json fn = json::object();
            if (!tc.name.empty()) {
                fn["name"] = tc.name;
            }
            if (!tc.arguments.empty()) {
                fn["arguments"] = tc.arguments;
            }
            tc_json["function"] = fn;
        }
        out["tool_calls"] = json::array({ tc_json });
    }
    return out;
}
// Convert an OpenAI transcriptions API request (normally multipart/form-data)
// into an OpenAI Chat Completions request body.
// - inp_body:  parsed form fields (prompt, language, response_format, stream,
//              max_tokens, temperature); form-data values arrive as strings
// - in_files:  uploaded files; exactly one part named "file" is required
// - out_files: cleared, then filled with the audio payload to transcribe
// Returns the chat-completions body; throws std::invalid_argument on a missing
// file, unsupported response_format, or a non-numeric max_tokens/temperature.
json convert_transcriptions_to_chatcmpl(
        const json & inp_body,
        const std::map<std::string, raw_buffer> & in_files,
        std::vector<raw_buffer> & out_files) {
    // TODO @ngxson : this function may need to be improved in the future
    // handle input files
    out_files.clear();
    auto it = in_files.find("file");
    if (it == in_files.end()) {
        throw std::invalid_argument("No input file found for transcription");
    }
    out_files.push_back(it->second);
    // handle input data
    std::string prompt          = json_value(inp_body, "prompt", std::string());
    std::string language        = json_value(inp_body, "language", std::string());
    std::string response_format = json_value(inp_body, "response_format", std::string("json"));
    if (response_format != "json") {
        throw std::invalid_argument("Only 'json' response_format is supported for transcription");
    }
    if (prompt.empty()) {
        prompt = "Transcribe audio to text";
    }
    if (!language.empty()) {
        prompt += string_format(" (language: %s)", language.c_str());
    }
    // the media marker tells the mtmd pipeline where to inject the audio
    prompt += get_media_marker();
    json chatcmpl_body = inp_body; // copy all fields
    chatcmpl_body["messages"] = json::array({
        {
            {"role", "user"},
            {"content", prompt},
        },
    });
    // input normally comes from form-data, where every value is a string, so
    // coerce the types here; be permissive and also accept native JSON types
    // (bool/number) in case the client posted a JSON body instead
    if (inp_body.contains("stream") && inp_body.at("stream").is_boolean()) {
        // already the correct type - keep as-is (chatcmpl_body is a copy)
    } else {
        std::string stream = json_value(inp_body, "stream", std::string("false"));
        chatcmpl_body["stream"] = stream == "true";
    }
    if (inp_body.contains("max_tokens")) {
        const json & v = inp_body.at("max_tokens");
        if (v.is_string()) {
            // std::stoul throws std::invalid_argument on a non-numeric string
            chatcmpl_body["max_tokens"] = std::stoul(v.get<std::string>());
        }
        // a native JSON number is already correct - leave the copied value
    }
    if (inp_body.contains("temperature")) {
        const json & v = inp_body.at("temperature");
        if (v.is_string()) {
            chatcmpl_body["temperature"] = std::stof(v.get<std::string>());
        }
        // a native JSON number is already correct - leave the copied value
    }
    return chatcmpl_body;
}

View file

@ -0,0 +1,24 @@
// Chat conversion functions for server (Responses API, Anthropic API, OAI streaming diffs)
#pragma once
#include "chat.h"
#include "server-common.h"
#include <nlohmann/json_fwd.hpp>
using json = nlohmann::ordered_json;
// Convert OpenAI Responses API format to OpenAI Chat Completions API format
json server_chat_convert_responses_to_chatcmpl(const json & body);
// Convert Anthropic Messages API format to OpenAI Chat Completions API format
json server_chat_convert_anthropic_to_oai(const json & body);
// convert OpenAI transcriptions API format to OpenAI Chat Completions API format
json convert_transcriptions_to_chatcmpl(
const json & body,
const std::map<std::string, raw_buffer> & in_files,
std::vector<raw_buffer> & out_files);
json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);

View file

@ -1027,6 +1027,8 @@ json oaicompat_chat_params_parse(
}
}
auto caps = common_chat_templates_get_caps(opt.tmpls.get());
common_chat_templates_inputs inputs;
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
inputs.tools = common_chat_tools_parse_oaicompat(tools);
@ -1034,7 +1036,7 @@ json oaicompat_chat_params_parse(
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
inputs.grammar = grammar;
inputs.use_jinja = opt.use_jinja;
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", caps["supports_parallel_tool_calls"]);
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
inputs.reasoning_format = opt.reasoning_format;
if (body.contains("reasoning_format")) {
@ -1164,573 +1166,6 @@ json oaicompat_chat_params_parse(
return llama_params;
}
json convert_responses_to_chatcmpl(const json & response_body) {
if (!response_body.contains("input")) {
throw std::invalid_argument("'input' is required");
}
if (!json_value(response_body, "previous_response_id", std::string{}).empty()) {
throw std::invalid_argument("llama.cpp does not support 'previous_response_id'.");
}
const json input_value = response_body.at("input");
json chatcmpl_body = response_body;
chatcmpl_body.erase("input");
std::vector<json> chatcmpl_messages;
if (response_body.contains("instructions")) {
chatcmpl_messages.push_back({
{"role", "system"},
{"content", json_value(response_body, "instructions", std::string())},
});
chatcmpl_body.erase("instructions");
}
if (input_value.is_string()) {
// #responses_create-input-text_input
chatcmpl_messages.push_back({
{"role", "user"},
{"content", input_value},
});
} else if (input_value.is_array()) {
// #responses_create-input-input_item_list
static auto exists_and_is_array = [](const json & j, const char * key) -> bool {
return j.contains(key) && j.at(key).is_array();
};
static auto exists_and_is_string = [](const json & j, const char * key) -> bool {
return j.contains(key) && j.at(key).is_string();
};
for (json item : input_value) {
bool merge_prev = !chatcmpl_messages.empty() && chatcmpl_messages.back().value("role", "") == "assistant";
if (exists_and_is_string(item, "content")) {
// #responses_create-input-input_item_list-input_message-content-text_input
// Only "Input message" contains item["content"]::string
// After converting item["content"]::string to item["content"]::array,
// we can treat "Input message" as sum of "Item-Input message" and "Item-Output message"
item["content"] = json::array({
json {
{"text", item.at("content")},
{"type", "input_text"}
}
});
}
if (exists_and_is_array(item, "content") &&
exists_and_is_string(item, "role") &&
(item.at("role") == "user" ||
item.at("role") == "system" ||
item.at("role") == "developer")
) {
// #responses_create-input-input_item_list-item-input_message
std::vector<json> chatcmpl_content;
for (const json & input_item : item.at("content")) {
const std::string type = json_value(input_item, "type", std::string());
if (type == "input_text") {
if (!input_item.contains("text")) {
throw std::invalid_argument("'Input text' requires 'text'");
}
chatcmpl_content.push_back({
{"text", input_item.at("text")},
{"type", "text"},
});
} else if (type == "input_image") {
// While `detail` is marked as required,
// it has default value("auto") and can be omitted.
if (!input_item.contains("image_url")) {
throw std::invalid_argument("'image_url' is required");
}
chatcmpl_content.push_back({
{"image_url", json {
{"url", input_item.at("image_url")}
}},
{"type", "image_url"},
});
} else if (type == "input_file") {
throw std::invalid_argument("'input_file' is not supported by llamacpp at this moment");
// if (input_item.contains("file_url")) {
// // chat completion API does not support file_url
// throw std::invalid_argument("'file_url' is not supported");
// }
// if (!input_item.contains("file_data") || !input_item.contains("filename")) {
// throw std::invalid_argument("Both 'file_data' and 'filename' are required");
// }
// chatcmpl_content.push_back({
// {"file", json {
// {"file_data", input_item.at("file_data")},
// {"filename", input_item.at("filename")},
// }},
// {"type", "file"},
// });
} else {
throw std::invalid_argument("'type' must be one of 'input_text', 'input_image', or 'input_file'");
}
}
if (item.contains("type")) {
item.erase("type");
}
if (item.contains("status")) {
item.erase("status");
}
item["content"] = chatcmpl_content;
chatcmpl_messages.push_back(item);
} else if (exists_and_is_array(item, "content") &&
exists_and_is_string(item, "role") &&
item.at("role") == "assistant" &&
// exists_and_is_string(item, "status") &&
// (item.at("status") == "in_progress" ||
// item.at("status") == "completed" ||
// item.at("status") == "incomplete") &&
// item["status"] not sent by codex-cli
exists_and_is_string(item, "type") &&
item.at("type") == "message"
) {
// #responses_create-input-input_item_list-item-output_message
auto chatcmpl_content = json::array();
for (const auto & output_text : item.at("content")) {
const std::string type = json_value(output_text, "type", std::string());
if (type == "output_text") {
if (!exists_and_is_string(output_text, "text")) {
throw std::invalid_argument("'Output text' requires 'text'");
// Ignore annotations and logprobs for now
chatcmpl_content.push_back({
{"text", output_text.at("text")},
{"type", "text"},
});
}
} else if (type == "refusal") {
if (!exists_and_is_string(output_text, "refusal")) {
throw std::invalid_argument("'Refusal' requires 'refusal'");
// Ignore annotations and logprobs for now
chatcmpl_content.push_back({
{"refusal", output_text.at("refusal")},
{"type", "refusal"},
});
}
} else {
throw std::invalid_argument("'type' must be one of 'output_text' or 'refusal'");
}
}
if (merge_prev) {
auto & prev_msg = chatcmpl_messages.back();
if (!exists_and_is_array(prev_msg, "content")) {
prev_msg["content"] = json::array();
}
auto & prev_content = prev_msg["content"];
prev_content.insert(prev_content.end(), chatcmpl_content.begin(), chatcmpl_content.end());
} else {
item.erase("status");
item.erase("type");
item["content"] = chatcmpl_content;
chatcmpl_messages.push_back(item);
}
} else if (exists_and_is_string(item, "arguments") &&
exists_and_is_string(item, "call_id") &&
exists_and_is_string(item, "name") &&
exists_and_is_string(item, "type") &&
item.at("type") == "function_call"
) {
// #responses_create-input-input_item_list-item-function_tool_call
json tool_call = {
{"function", json {
{"arguments", item.at("arguments")},
{"name", item.at("name")},
}},
{"id", item.at("call_id")},
{"type", "function"},
};
if (merge_prev) {
auto & prev_msg = chatcmpl_messages.back();
if (!exists_and_is_array(prev_msg, "tool_calls")) {
prev_msg["tool_calls"] = json::array();
}
prev_msg["tool_calls"].push_back(tool_call);
} else {
chatcmpl_messages.push_back(json {
{"role", "assistant"},
{"tool_calls", json::array({tool_call})}
});
}
} else if (exists_and_is_string(item, "call_id") &&
(exists_and_is_string(item, "output") || exists_and_is_array(item, "output")) &&
exists_and_is_string(item, "type") &&
item.at("type") == "function_call_output"
) {
// #responses_create-input-input_item_list-item-function_tool_call_output
if (item.at("output").is_string()) {
chatcmpl_messages.push_back(json {
{"content", item.at("output")},
{"role", "tool"},
{"tool_call_id", item.at("call_id")},
});
} else {
json chatcmpl_outputs = item.at("output");
for (json & chatcmpl_output : chatcmpl_outputs) {
if (!chatcmpl_output.contains("type") || chatcmpl_output.at("type") != "input_text") {
throw std::invalid_argument("Output of tool call should be 'Input text'");
}
chatcmpl_output["type"] = "text";
}
chatcmpl_messages.push_back(json {
{"content", chatcmpl_outputs},
{"role", "tool"},
{"tool_call_id", item.at("call_id")},
});
}
} else if (// exists_and_is_string(item, "id") &&
// item["id"] not sent by codex-cli
exists_and_is_array(item, "summary") &&
exists_and_is_string(item, "type") &&
item.at("type") == "reasoning") {
// #responses_create-input-input_item_list-item-reasoning
if (!exists_and_is_array(item, "content")) {
throw std::invalid_argument("item['content'] is not an array");
}
if (item.at("content").empty()) {
throw std::invalid_argument("item['content'] is empty");
}
if (!exists_and_is_string(item.at("content")[0], "text")) {
throw std::invalid_argument("item['content']['text'] is not a string");
}
if (merge_prev) {
auto & prev_msg = chatcmpl_messages.back();
prev_msg["reasoning_content"] = item.at("content")[0].at("text");
} else {
chatcmpl_messages.push_back(json {
{"role", "assistant"},
{"content", json::array()},
{"reasoning_content", item.at("content")[0].at("text")},
});
}
} else {
throw std::invalid_argument("Cannot determine type of 'item'");
}
}
} else {
throw std::invalid_argument("'input' must be a string or array of objects");
}
chatcmpl_body["messages"] = chatcmpl_messages;
if (response_body.contains("tools")) {
if (!response_body.at("tools").is_array()) {
throw std::invalid_argument("'tools' must be an array of objects");
}
std::vector<json> chatcmpl_tools;
for (json resp_tool : response_body.at("tools")) {
json chatcmpl_tool;
if (json_value(resp_tool, "type", std::string()) != "function") {
throw std::invalid_argument("'type' of tool must be 'function'");
}
resp_tool.erase("type");
chatcmpl_tool["type"] = "function";
if (!resp_tool.contains("strict")) {
resp_tool["strict"] = true;
}
chatcmpl_tool["function"] = resp_tool;
chatcmpl_tools.push_back(chatcmpl_tool);
}
chatcmpl_body.erase("tools");
chatcmpl_body["tools"] = chatcmpl_tools;
}
if (response_body.contains("max_output_tokens")) {
chatcmpl_body.erase("max_output_tokens");
chatcmpl_body["max_tokens"] = response_body["max_output_tokens"];
}
return chatcmpl_body;
}
json convert_transcriptions_to_chatcmpl(
const json & inp_body,
const std::map<std::string, raw_buffer> & in_files,
std::vector<raw_buffer> & out_files) {
// TODO @ngxson : this function may need to be improved in the future
// handle input files
out_files.clear();
auto it = in_files.find("file");
if (it != in_files.end()) {
out_files.push_back(it->second);
} else {
throw std::invalid_argument("No input file found for transcription");
}
// handle input data
std::string prompt = json_value(inp_body, "prompt", std::string());
std::string language = json_value(inp_body, "language", std::string());
std::string response_format = json_value(inp_body, "response_format", std::string("json"));
if (response_format != "json") {
throw std::invalid_argument("Only 'json' response_format is supported for transcription");
}
if (prompt.empty()) {
prompt = "Transcribe audio to text";
}
if (!language.empty()) {
prompt += string_format(" (language: %s)", language.c_str());
}
prompt += get_media_marker();
json chatcmpl_body = inp_body; // copy all fields
chatcmpl_body["messages"] = json::array({
{
{"role", "user"},
{"content", prompt},
},
});
// because input from form-data, everything is string, we need to correct the types here
std::string stream = json_value(inp_body, "stream", std::string("false"));
chatcmpl_body["stream"] = stream == "true";
if (inp_body.contains("max_tokens")) {
std::string inp = inp_body["max_tokens"].get<std::string>();
chatcmpl_body["max_tokens"] = std::stoul(inp);
}
if (inp_body.contains("temperature")) {
std::string inp = inp_body["temperature"].get<std::string>();
chatcmpl_body["temperature"] = std::stof(inp);
}
return chatcmpl_body;
}
json convert_anthropic_to_oai(const json & body) {
json oai_body;
// Convert system prompt
json oai_messages = json::array();
auto system_param = json_value(body, "system", json());
if (!system_param.is_null()) {
std::string system_content;
if (system_param.is_string()) {
system_content = system_param.get<std::string>();
} else if (system_param.is_array()) {
for (const auto & block : system_param) {
if (json_value(block, "type", std::string()) == "text") {
system_content += json_value(block, "text", std::string());
}
}
}
oai_messages.push_back({
{"role", "system"},
{"content", system_content}
});
}
// Convert messages
if (!body.contains("messages")) {
throw std::runtime_error("'messages' is required");
}
const json & messages = body.at("messages");
if (messages.is_array()) {
for (const auto & msg : messages) {
std::string role = json_value(msg, "role", std::string());
if (!msg.contains("content")) {
if (role == "assistant") {
continue;
}
oai_messages.push_back(msg);
continue;
}
const json & content = msg.at("content");
if (content.is_string()) {
oai_messages.push_back(msg);
continue;
}
if (!content.is_array()) {
oai_messages.push_back(msg);
continue;
}
json tool_calls = json::array();
json converted_content = json::array();
json tool_results = json::array();
std::string reasoning_content;
bool has_tool_calls = false;
for (const auto & block : content) {
std::string type = json_value(block, "type", std::string());
if (type == "text") {
converted_content.push_back(block);
} else if (type == "thinking") {
reasoning_content += json_value(block, "thinking", std::string());
} else if (type == "image") {
json source = json_value(block, "source", json::object());
std::string source_type = json_value(source, "type", std::string());
if (source_type == "base64") {
std::string media_type = json_value(source, "media_type", std::string("image/jpeg"));
std::string data = json_value(source, "data", std::string());
std::ostringstream ss;
ss << "data:" << media_type << ";base64," << data;
converted_content.push_back({
{"type", "image_url"},
{"image_url", {
{"url", ss.str()}
}}
});
} else if (source_type == "url") {
std::string url = json_value(source, "url", std::string());
converted_content.push_back({
{"type", "image_url"},
{"image_url", {
{"url", url}
}}
});
}
} else if (type == "tool_use") {
tool_calls.push_back({
{"id", json_value(block, "id", std::string())},
{"type", "function"},
{"function", {
{"name", json_value(block, "name", std::string())},
{"arguments", json_value(block, "input", json::object()).dump()}
}}
});
has_tool_calls = true;
} else if (type == "tool_result") {
std::string tool_use_id = json_value(block, "tool_use_id", std::string());
auto result_content = json_value(block, "content", json());
std::string result_text;
if (result_content.is_string()) {
result_text = result_content.get<std::string>();
} else if (result_content.is_array()) {
for (const auto & c : result_content) {
if (json_value(c, "type", std::string()) == "text") {
result_text += json_value(c, "text", std::string());
}
}
}
tool_results.push_back({
{"role", "tool"},
{"tool_call_id", tool_use_id},
{"content", result_text}
});
}
}
if (!converted_content.empty() || has_tool_calls || !reasoning_content.empty()) {
json new_msg = {{"role", role}};
if (!converted_content.empty()) {
new_msg["content"] = converted_content;
} else if (has_tool_calls || !reasoning_content.empty()) {
new_msg["content"] = "";
}
if (!tool_calls.empty()) {
new_msg["tool_calls"] = tool_calls;
}
if (!reasoning_content.empty()) {
new_msg["reasoning_content"] = reasoning_content;
}
oai_messages.push_back(new_msg);
}
for (const auto & tool_msg : tool_results) {
oai_messages.push_back(tool_msg);
}
}
}
oai_body["messages"] = oai_messages;
// Convert tools
if (body.contains("tools")) {
const json & tools = body.at("tools");
if (tools.is_array()) {
json oai_tools = json::array();
for (const auto & tool : tools) {
oai_tools.push_back({
{"type", "function"},
{"function", {
{"name", json_value(tool, "name", std::string())},
{"description", json_value(tool, "description", std::string())},
{"parameters", tool.contains("input_schema") ? tool.at("input_schema") : json::object()}
}}
});
}
oai_body["tools"] = oai_tools;
}
}
// Convert tool_choice
if (body.contains("tool_choice")) {
const json & tc = body.at("tool_choice");
if (tc.is_object()) {
std::string type = json_value(tc, "type", std::string());
if (type == "auto") {
oai_body["tool_choice"] = "auto";
} else if (type == "any" || type == "tool") {
oai_body["tool_choice"] = "required";
}
}
}
// Convert stop_sequences to stop
if (body.contains("stop_sequences")) {
oai_body["stop"] = body.at("stop_sequences");
}
// Handle max_tokens (required in Anthropic, but we're permissive)
if (body.contains("max_tokens")) {
oai_body["max_tokens"] = body.at("max_tokens");
} else {
oai_body["max_tokens"] = 4096;
}
// Pass through common params
for (const auto & key : {"temperature", "top_p", "top_k", "stream"}) {
if (body.contains(key)) {
oai_body[key] = body.at(key);
}
}
// Handle Anthropic-specific thinking param
if (body.contains("thinking")) {
json thinking = json_value(body, "thinking", json::object());
std::string thinking_type = json_value(thinking, "type", std::string());
if (thinking_type == "enabled") {
int budget_tokens = json_value(thinking, "budget_tokens", 10000);
oai_body["thinking_budget_tokens"] = budget_tokens;
}
}
// Handle Anthropic-specific metadata param
if (body.contains("metadata")) {
json metadata = json_value(body, "metadata", json::object());
std::string user_id = json_value(metadata, "user_id", std::string());
if (!user_id.empty()) {
oai_body["__metadata_user_id"] = user_id;
}
}
return oai_body;
}
json format_embeddings_response_oaicompat(
const json & request,
const std::string & model_name,

View file

@ -307,18 +307,6 @@ json oaicompat_chat_params_parse(
const server_chat_params & opt,
std::vector<raw_buffer> & out_files);
// convert OpenAI Responses API format to OpenAI Chat Completions API format
json convert_responses_to_chatcmpl(const json & body);
// convert OpenAI transcriptions API format to OpenAI Chat Completions API format
json convert_transcriptions_to_chatcmpl(
const json & body,
const std::map<std::string, raw_buffer> & in_files,
std::vector<raw_buffer> & out_files);
// convert Anthropic Messages API format to OpenAI Chat Completions API format
json convert_anthropic_to_oai(const json & body);
// TODO: move it to server-task.cpp
json format_embeddings_response_oaicompat(
const json & request,

View file

@ -1,5 +1,6 @@
#include "server-context.h"
#include "server-chat.h"
#include "server-common.h"
#include "server-http.h"
#include "server-task.h"
@ -1044,8 +1045,8 @@ private:
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
/* enable_thinking */ enable_thinking,
/* reasoning_budget */ params_base.reasoning_budget,
/* reasoning_budget_msg */ params_base.reasoning_budget_message,
/* reasoning_budget */ params_base.sampling.reasoning_budget_tokens,
/* reasoning_budget_msg */ params_base.sampling.reasoning_budget_message,
/* media_path */ params_base.media_path,
/* force_pure_content */ params_base.force_pure_content_parser
};
@ -2960,7 +2961,13 @@ private:
// verify and try to accept the draft
{
common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));
const bool use_ckpt = slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
// only save the sampler sampler state if we use checkpoints
common_sampler_ptr smpl_save;
if (use_ckpt) {
smpl_save.reset(common_sampler_clone(slot.smpl.get()));
}
GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft);
@ -2972,7 +2979,7 @@ private:
// check for partial draft acceptance
if (accepted.size() < slot.spec_draft.size() + 1) {
if (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
if (use_ckpt) {
// partial acceptance is not supported by the context -> truncate the draft and restore the state
slot.spec_draft = std::move(accepted);
@ -3774,7 +3781,7 @@ void server_routes::init_routes() {
this->post_responses_oai = [this](const server_http_req & req) {
auto res = create_response();
std::vector<raw_buffer> files;
json body = convert_responses_to_chatcmpl(json::parse(req.body));
json body = server_chat_convert_responses_to_chatcmpl(json::parse(req.body));
SRV_DBG("%s\n", "Request converted: OpenAI Responses -> OpenAI Chat Completions");
SRV_DBG("converted request: %s\n", body.dump().c_str());
json body_parsed = oaicompat_chat_params_parse(
@ -3819,7 +3826,7 @@ void server_routes::init_routes() {
this->post_anthropic_messages = [this](const server_http_req & req) {
auto res = create_response();
std::vector<raw_buffer> files;
json body = convert_anthropic_to_oai(json::parse(req.body));
json body = server_chat_convert_anthropic_to_oai(json::parse(req.body));
SRV_DBG("%s\n", "Request converted: Anthropic -> OpenAI Chat Completions");
SRV_DBG("converted request: %s\n", body.dump().c_str());
json body_parsed = oaicompat_chat_params_parse(
@ -3837,7 +3844,7 @@ void server_routes::init_routes() {
this->post_anthropic_count_tokens = [this](const server_http_req & req) {
auto res = create_response();
std::vector<raw_buffer> files;
json body = convert_anthropic_to_oai(json::parse(req.body));
json body = server_chat_convert_anthropic_to_oai(json::parse(req.body));
SRV_DBG("%s\n", "Request converted: Anthropic -> OpenAI Chat Completions");
SRV_DBG("converted request: %s\n", body.dump().c_str());
json body_parsed = oaicompat_chat_params_parse(

View file

@ -712,6 +712,11 @@ void server_models::unload(const std::string & name) {
if (it->second.meta.is_running()) {
SRV_INF("stopping model instance name=%s\n", name.c_str());
stopping_models.insert(name);
if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) {
// special case: if model is in loading state, unloading means force-killing it
SRV_WRN("model name=%s is still loading, force-killing\n", name.c_str());
subprocess_terminate(it->second.subproc.get());
}
cv_stop.notify_all();
// status change will be handled by the managing thread
} else {

View file

@ -1,6 +1,7 @@
#include "server-task.h"
#include "build-info.h"
#include "server-chat.h"
#include "chat.h"
#include "common.h"
#include "json-schema-to-grammar.h"
@ -873,7 +874,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
json {
{"finish_reason", nullptr},
{"index", index},
{"delta", common_chat_msg_diff_to_json_oaicompat(diff)},
{"delta", server_chat_msg_diff_to_json_oaicompat(diff)},
},
})},
{"created", t},
@ -1110,7 +1111,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() {
json server_task_result_cmpl_final::to_json_oaicompat_asr() {
json event = json {
{"type", "transcript.text.done"},
{"text", content},
{"text", oaicompat_msg.content},
{"usage", json {
{"type", "tokens"},
{"input_tokens", n_prompt_tokens},
@ -1522,7 +1523,7 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat() {
}
for (const auto & diff : oaicompat_msg_diffs) {
add_delta(common_chat_msg_diff_to_json_oaicompat(diff));
add_delta(server_chat_msg_diff_to_json_oaicompat(diff));
}
if (!deltas.empty()) {

View file

@ -872,7 +872,8 @@ bool write_websocket_frame(Stream &strm, ws::Opcode opcode,
if (strm.write(reinterpret_cast<char *>(header), 2) < 0) { return false; }
uint8_t ext[8];
for (int i = 7; i >= 0; i--) {
ext[7 - i] = static_cast<uint8_t>((len >> (i * 8)) & 0xFF);
ext[7 - i] =
static_cast<uint8_t>((static_cast<uint64_t>(len) >> (i * 8)) & 0xFF);
}
if (strm.write(reinterpret_cast<char *>(ext), 8) < 0) { return false; }
}
@ -1034,10 +1035,15 @@ bool canonicalize_path(const char *path, std::string &resolved) {
char buf[_MAX_PATH];
if (_fullpath(buf, path, _MAX_PATH) == nullptr) { return false; }
resolved = buf;
#else
#elif defined(PATH_MAX)
char buf[PATH_MAX];
if (realpath(path, buf) == nullptr) { return false; }
resolved = buf;
#else
auto buf = realpath(path, nullptr);
auto guard = scope_exit([&]() { std::free(buf); });
if (buf == nullptr) { return false; }
resolved = buf;
#endif
return true;
}
@ -2765,6 +2771,35 @@ EncodingType encoding_type(const Request &req, const Response &res) {
return best;
}
std::unique_ptr<compressor> make_compressor(EncodingType type) {
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
if (type == EncodingType::Gzip) {
return detail::make_unique<gzip_compressor>();
}
#endif
#ifdef CPPHTTPLIB_BROTLI_SUPPORT
if (type == EncodingType::Brotli) {
return detail::make_unique<brotli_compressor>();
}
#endif
#ifdef CPPHTTPLIB_ZSTD_SUPPORT
if (type == EncodingType::Zstd) {
return detail::make_unique<zstd_compressor>();
}
#endif
(void)type;
return nullptr;
}
const char *encoding_name(EncodingType type) {
switch (type) {
case EncodingType::Gzip: return "gzip";
case EncodingType::Brotli: return "br";
case EncodingType::Zstd: return "zstd";
default: return "";
}
}
bool nocompressor::compress(const char *data, size_t data_length,
bool /*last*/, Callback callback) {
if (!data_length) { return true; }
@ -3097,6 +3132,29 @@ const char *get_header_value(const Headers &headers,
return def;
}
size_t get_header_value_count(const Headers &headers,
const std::string &key) {
auto r = headers.equal_range(key);
return static_cast<size_t>(std::distance(r.first, r.second));
}
template <typename Map>
typename Map::mapped_type
get_multimap_value(const Map &m, const std::string &key, size_t id) {
auto rng = m.equal_range(key);
auto it = rng.first;
std::advance(it, static_cast<ssize_t>(id));
if (it != rng.second) { return it->second; }
return typename Map::mapped_type();
}
void set_header(Headers &headers, const std::string &key,
const std::string &val) {
if (fields::is_field_name(key) && fields::is_field_value(val)) {
headers.emplace(key, val);
}
}
bool read_headers(Stream &strm, Headers &headers) {
const auto bufsiz = 2048;
char buf[bufsiz];
@ -5791,16 +5849,12 @@ std::string Request::get_header_value(const std::string &key,
}
size_t Request::get_header_value_count(const std::string &key) const {
auto r = headers.equal_range(key);
return static_cast<size_t>(std::distance(r.first, r.second));
return detail::get_header_value_count(headers, key);
}
void Request::set_header(const std::string &key,
const std::string &val) {
if (detail::fields::is_field_name(key) &&
detail::fields::is_field_value(val)) {
headers.emplace(key, val);
}
detail::set_header(headers, key, val);
}
bool Request::has_trailer(const std::string &key) const {
@ -5809,11 +5863,7 @@ bool Request::has_trailer(const std::string &key) const {
std::string Request::get_trailer_value(const std::string &key,
size_t id) const {
auto rng = trailers.equal_range(key);
auto it = rng.first;
std::advance(it, static_cast<ssize_t>(id));
if (it != rng.second) { return it->second; }
return std::string();
return detail::get_multimap_value(trailers, key, id);
}
size_t Request::get_trailer_value_count(const std::string &key) const {
@ -5827,11 +5877,7 @@ bool Request::has_param(const std::string &key) const {
std::string Request::get_param_value(const std::string &key,
size_t id) const {
auto rng = params.equal_range(key);
auto it = rng.first;
std::advance(it, static_cast<ssize_t>(id));
if (it != rng.second) { return it->second; }
return std::string();
return detail::get_multimap_value(params, key, id);
}
std::vector<std::string>
@ -5886,11 +5932,7 @@ size_t MultipartFormData::get_field_count(const std::string &key) const {
FormData MultipartFormData::get_file(const std::string &key,
size_t id) const {
auto rng = files.equal_range(key);
auto it = rng.first;
std::advance(it, static_cast<ssize_t>(id));
if (it != rng.second) { return it->second; }
return FormData();
return detail::get_multimap_value(files, key, id);
}
std::vector<FormData>
@ -5929,16 +5971,12 @@ std::string Response::get_header_value(const std::string &key,
}
size_t Response::get_header_value_count(const std::string &key) const {
auto r = headers.equal_range(key);
return static_cast<size_t>(std::distance(r.first, r.second));
return detail::get_header_value_count(headers, key);
}
void Response::set_header(const std::string &key,
const std::string &val) {
if (detail::fields::is_field_name(key) &&
detail::fields::is_field_value(val)) {
headers.emplace(key, val);
}
detail::set_header(headers, key, val);
}
bool Response::has_trailer(const std::string &key) const {
return trailers.find(key) != trailers.end();
@ -5946,11 +5984,7 @@ bool Response::has_trailer(const std::string &key) const {
std::string Response::get_trailer_value(const std::string &key,
size_t id) const {
auto rng = trailers.equal_range(key);
auto it = rng.first;
std::advance(it, static_cast<ssize_t>(id));
if (it != rng.second) { return it->second; }
return std::string();
return detail::get_multimap_value(trailers, key, id);
}
size_t Response::get_trailer_value_count(const std::string &key) const {
@ -6253,15 +6287,6 @@ void ThreadPool::worker(bool is_dynamic) {
assert(true == static_cast<bool>(fn));
fn();
// Dynamic thread: exit if queue is empty after task completion
if (is_dynamic) {
std::unique_lock<std::mutex> lock(mutex_);
if (jobs_.empty()) {
move_to_finished(std::this_thread::get_id());
break;
}
}
}
#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \
@ -6791,61 +6816,51 @@ Server::make_matcher(const std::string &pattern) {
}
Server &Server::Get(const std::string &pattern, Handler handler) {
get_handlers_.emplace_back(make_matcher(pattern), std::move(handler));
return *this;
return add_handler(get_handlers_, pattern, std::move(handler));
}
Server &Server::Post(const std::string &pattern, Handler handler) {
post_handlers_.emplace_back(make_matcher(pattern), std::move(handler));
return *this;
return add_handler(post_handlers_, pattern, std::move(handler));
}
Server &Server::Post(const std::string &pattern,
HandlerWithContentReader handler) {
post_handlers_for_content_reader_.emplace_back(make_matcher(pattern),
std::move(handler));
return *this;
return add_handler(post_handlers_for_content_reader_, pattern,
std::move(handler));
}
Server &Server::Put(const std::string &pattern, Handler handler) {
put_handlers_.emplace_back(make_matcher(pattern), std::move(handler));
return *this;
return add_handler(put_handlers_, pattern, std::move(handler));
}
Server &Server::Put(const std::string &pattern,
HandlerWithContentReader handler) {
put_handlers_for_content_reader_.emplace_back(make_matcher(pattern),
std::move(handler));
return *this;
return add_handler(put_handlers_for_content_reader_, pattern,
std::move(handler));
}
Server &Server::Patch(const std::string &pattern, Handler handler) {
patch_handlers_.emplace_back(make_matcher(pattern), std::move(handler));
return *this;
return add_handler(patch_handlers_, pattern, std::move(handler));
}
Server &Server::Patch(const std::string &pattern,
HandlerWithContentReader handler) {
patch_handlers_for_content_reader_.emplace_back(make_matcher(pattern),
std::move(handler));
return *this;
return add_handler(patch_handlers_for_content_reader_, pattern,
std::move(handler));
}
Server &Server::Delete(const std::string &pattern, Handler handler) {
delete_handlers_.emplace_back(make_matcher(pattern), std::move(handler));
return *this;
return add_handler(delete_handlers_, pattern, std::move(handler));
}
Server &Server::Delete(const std::string &pattern,
HandlerWithContentReader handler) {
delete_handlers_for_content_reader_.emplace_back(make_matcher(pattern),
std::move(handler));
return *this;
return add_handler(delete_handlers_for_content_reader_, pattern,
std::move(handler));
}
Server &Server::Options(const std::string &pattern, Handler handler) {
options_handlers_.emplace_back(make_matcher(pattern), std::move(handler));
return *this;
return add_handler(options_handlers_, pattern, std::move(handler));
}
Server &Server::WebSocket(const std::string &pattern,
@ -7054,6 +7069,11 @@ Server &Server::set_payload_max_length(size_t length) {
return *this;
}
Server &Server::set_websocket_max_missed_pongs(int count) {
websocket_max_missed_pongs_ = count;
return *this;
}
Server &Server::set_websocket_ping_interval(time_t sec) {
websocket_ping_interval_sec_ = sec;
return *this;
@ -7279,23 +7299,10 @@ Server::write_content_with_provider(Stream &strm, const Request &req,
if (res.is_chunked_content_provider_) {
auto type = detail::encoding_type(req, res);
std::unique_ptr<detail::compressor> compressor;
if (type == detail::EncodingType::Gzip) {
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
compressor = detail::make_unique<detail::gzip_compressor>();
#endif
} else if (type == detail::EncodingType::Brotli) {
#ifdef CPPHTTPLIB_BROTLI_SUPPORT
compressor = detail::make_unique<detail::brotli_compressor>();
#endif
} else if (type == detail::EncodingType::Zstd) {
#ifdef CPPHTTPLIB_ZSTD_SUPPORT
compressor = detail::make_unique<detail::zstd_compressor>();
#endif
} else {
auto compressor = detail::make_compressor(type);
if (!compressor) {
compressor = detail::make_unique<detail::nocompressor>();
}
assert(compressor != nullptr);
return detail::write_content_chunked(strm, res.content_provider_,
is_shutting_down, *compressor);
@ -7917,14 +7924,8 @@ void Server::apply_ranges(const Request &req, Response &res,
if (res.content_provider_) {
if (res.is_chunked_content_provider_) {
res.set_header("Transfer-Encoding", "chunked");
if (type == detail::EncodingType::Gzip) {
res.set_header("Content-Encoding", "gzip");
res.set_header("Vary", "Accept-Encoding");
} else if (type == detail::EncodingType::Brotli) {
res.set_header("Content-Encoding", "br");
res.set_header("Vary", "Accept-Encoding");
} else if (type == detail::EncodingType::Zstd) {
res.set_header("Content-Encoding", "zstd");
if (type != detail::EncodingType::None) {
res.set_header("Content-Encoding", detail::encoding_name(type));
res.set_header("Vary", "Accept-Encoding");
}
}
@ -7955,27 +7956,7 @@ void Server::apply_ranges(const Request &req, Response &res,
if (type != detail::EncodingType::None) {
output_pre_compression_log(req, res);
std::unique_ptr<detail::compressor> compressor;
std::string content_encoding;
if (type == detail::EncodingType::Gzip) {
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
compressor = detail::make_unique<detail::gzip_compressor>();
content_encoding = "gzip";
#endif
} else if (type == detail::EncodingType::Brotli) {
#ifdef CPPHTTPLIB_BROTLI_SUPPORT
compressor = detail::make_unique<detail::brotli_compressor>();
content_encoding = "br";
#endif
} else if (type == detail::EncodingType::Zstd) {
#ifdef CPPHTTPLIB_ZSTD_SUPPORT
compressor = detail::make_unique<detail::zstd_compressor>();
content_encoding = "zstd";
#endif
}
if (compressor) {
if (auto compressor = detail::make_compressor(type)) {
std::string compressed;
if (compressor->compress(res.body.data(), res.body.size(), true,
[&](const char *data, size_t data_len) {
@ -7983,7 +7964,7 @@ void Server::apply_ranges(const Request &req, Response &res,
return true;
})) {
res.body.swap(compressed);
res.set_header("Content-Encoding", content_encoding);
res.set_header("Content-Encoding", detail::encoding_name(type));
res.set_header("Vary", "Accept-Encoding");
}
}
@ -8231,7 +8212,8 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
{
// Use WebSocket-specific read timeout instead of HTTP timeout
strm.set_read_timeout(CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND, 0);
ws::WebSocket ws(strm, req, true, websocket_ping_interval_sec_);
ws::WebSocket ws(strm, req, true, websocket_ping_interval_sec_,
websocket_max_missed_pongs_);
entry.handler(req, ws);
}
return true;
@ -10822,38 +10804,6 @@ void ClientImpl::enable_server_hostname_verification(bool enabled) {
}
#endif
// ClientImpl::set_ca_cert_store is defined after TLS namespace (uses helpers)
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert,
std::size_t size) const {
auto mem = BIO_new_mem_buf(ca_cert, static_cast<int>(size));
auto se = detail::scope_exit([&] { BIO_free_all(mem); });
if (!mem) { return nullptr; }
auto inf = PEM_X509_INFO_read_bio(mem, nullptr, nullptr, nullptr);
if (!inf) { return nullptr; }
auto cts = X509_STORE_new();
if (cts) {
for (auto i = 0; i < static_cast<int>(sk_X509_INFO_num(inf)); i++) {
auto itmp = sk_X509_INFO_value(inf, i);
if (!itmp) { continue; }
if (itmp->x509) { X509_STORE_add_cert(cts, itmp->x509); }
if (itmp->crl) { X509_STORE_add_crl(cts, itmp->crl); }
}
}
sk_X509_INFO_pop_free(inf, X509_INFO_free);
return cts;
}
void ClientImpl::set_server_certificate_verifier(
std::function<SSLVerifierResponse(SSL *ssl)> /*verifier*/) {
// Base implementation does nothing - SSLClient overrides this
}
#endif
void ClientImpl::set_logger(Logger logger) {
logger_ = std::move(logger);
}
@ -10927,10 +10877,10 @@ Client::Client(const std::string &scheme_host_port,
cli_ = detail::make_unique<ClientImpl>(scheme_host_port, 80,
client_cert_path, client_key_path);
}
} // namespace detail
}
Client::Client(const std::string &host, int port)
: cli_(detail::make_unique<ClientImpl>(host, port)) {}
: Client(host, port, std::string(), std::string()) {}
Client::Client(const std::string &host, int port,
const std::string &client_cert_path,
@ -11505,12 +11455,6 @@ void Client::set_follow_location(bool on) {
void Client::set_path_encode(bool on) { cli_->set_path_encode(on); }
[[deprecated("Use set_path_encode() instead. "
"This function will be removed by v1.0.0.")]]
void Client::set_url_encode(bool on) {
cli_->set_path_encode(on);
}
void Client::set_compress(bool on) { cli_->set_compress(on); }
void Client::set_decompress(bool on) { cli_->set_decompress(on); }
@ -11893,24 +11837,31 @@ SSLClient::SSLClient(const std::string &host)
SSLClient::SSLClient(const std::string &host, int port)
: SSLClient(host, port, std::string(), std::string()) {}
void SSLClient::init_ctx() {
ctx_ = tls::create_client_context();
if (ctx_) { tls::set_min_version(ctx_, tls::Version::TLS1_2); }
}
void SSLClient::reset_ctx_on_error() {
last_backend_error_ = tls::get_error();
tls::free_context(ctx_);
ctx_ = nullptr;
}
SSLClient::SSLClient(const std::string &host, int port,
const std::string &client_cert_path,
const std::string &client_key_path,
const std::string &private_key_password)
: ClientImpl(host, port, client_cert_path, client_key_path) {
ctx_ = tls::create_client_context();
init_ctx();
if (!ctx_) { return; }
tls::set_min_version(ctx_, tls::Version::TLS1_2);
if (!client_cert_path.empty() && !client_key_path.empty()) {
const char *password =
private_key_password.empty() ? nullptr : private_key_password.c_str();
if (!tls::set_client_cert_file(ctx_, client_cert_path.c_str(),
client_key_path.c_str(), password)) {
last_backend_error_ = tls::get_error();
tls::free_context(ctx_);
ctx_ = nullptr;
reset_ctx_on_error();
}
}
}
@ -11918,17 +11869,13 @@ SSLClient::SSLClient(const std::string &host, int port,
SSLClient::SSLClient(const std::string &host, int port,
const PemMemory &pem)
: ClientImpl(host, port) {
ctx_ = tls::create_client_context();
init_ctx();
if (!ctx_) { return; }
tls::set_min_version(ctx_, tls::Version::TLS1_2);
if (pem.cert_pem && pem.key_pem) {
if (!tls::set_client_cert_pem(ctx_, pem.cert_pem, pem.key_pem,
pem.private_key_password)) {
last_backend_error_ = tls::get_error();
tls::free_context(ctx_);
ctx_ = nullptr;
reset_ctx_on_error();
}
}
}
@ -12479,41 +12426,6 @@ std::string Request::sni() const {
* Group 8: TLS abstraction layer - OpenSSL backend
*/
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
// These wrappers forward to deprecated APIs that will be removed by v1.0.0.
// Suppress C4996 / -Wdeprecated-declarations so that MSVC /sdl builds (which
// promote C4996 to an error) compile cleanly even though the wrappers
// themselves are also marked [[deprecated]].
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4996)
#elif defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif
SSL_CTX *Client::ssl_context() const {
if (is_ssl_) { return static_cast<SSLClient &>(*cli_).ssl_context(); }
return nullptr;
}
void Client::set_server_certificate_verifier(
std::function<SSLVerifierResponse(SSL *ssl)> verifier) {
cli_->set_server_certificate_verifier(verifier);
}
long Client::get_verify_result() const {
if (is_ssl_) { return static_cast<SSLClient &>(*cli_).get_verify_result(); }
return -1; // NOTE: -1 doesn't match any of X509_V_ERR_???
}
#if defined(_MSC_VER)
#pragma warning(pop)
#elif defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif
#endif // CPPHTTPLIB_OPENSSL_SUPPORT
/*
* OpenSSL Backend Implementation
*/
@ -12523,54 +12435,6 @@ namespace tls {
namespace impl {
// OpenSSL-specific helpers for converting native types to PEM
std::string x509_to_pem(X509 *cert) {
if (!cert) return {};
BIO *bio = BIO_new(BIO_s_mem());
if (!bio) return {};
if (PEM_write_bio_X509(bio, cert) != 1) {
BIO_free(bio);
return {};
}
char *data = nullptr;
long len = BIO_get_mem_data(bio, &data);
std::string pem(data, static_cast<size_t>(len));
BIO_free(bio);
return pem;
}
std::string evp_pkey_to_pem(EVP_PKEY *key) {
if (!key) return {};
BIO *bio = BIO_new(BIO_s_mem());
if (!bio) return {};
if (PEM_write_bio_PrivateKey(bio, key, nullptr, nullptr, 0, nullptr,
nullptr) != 1) {
BIO_free(bio);
return {};
}
char *data = nullptr;
long len = BIO_get_mem_data(bio, &data);
std::string pem(data, static_cast<size_t>(len));
BIO_free(bio);
return pem;
}
std::string x509_store_to_pem(X509_STORE *store) {
if (!store) return {};
std::string pem;
auto objs = X509_STORE_get0_objects(store);
if (!objs) return {};
auto count = sk_X509_OBJECT_num(objs);
for (decltype(count) i = 0; i < count; i++) {
auto obj = sk_X509_OBJECT_value(objs, i);
if (X509_OBJECT_get_type(obj) == X509_LU_X509) {
auto cert = X509_OBJECT_get0_X509(obj);
if (cert) { pem += x509_to_pem(cert); }
}
}
return pem;
}
// Helper to map OpenSSL SSL_get_error to ErrorCode
ErrorCode map_ssl_error(int ssl_error, int &out_errno) {
switch (ssl_error) {
@ -12603,8 +12467,10 @@ STACK_OF(X509_NAME) *
X509 *cert = nullptr;
while ((cert = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr)) !=
nullptr) {
X509_NAME *name = X509_get_subject_name(cert);
if (name) { sk_X509_NAME_push(ca_list, X509_NAME_dup(name)); }
const X509_NAME *name = X509_get_subject_name(cert);
if (name) {
sk_X509_NAME_push(ca_list, X509_NAME_dup(const_cast<X509_NAME *>(name)));
}
X509_free(cert);
}
BIO_free(bio);
@ -12612,45 +12478,6 @@ STACK_OF(X509_NAME) *
return ca_list;
}
// Helper: Extract CA names from X509_STORE
// Returns a new STACK_OF(X509_NAME)* or nullptr on failure
// Caller takes ownership of returned list
STACK_OF(X509_NAME) *
extract_client_ca_list_from_store(X509_STORE *store) {
if (!store) { return nullptr; }
auto ca_list = sk_X509_NAME_new_null();
if (!ca_list) { return nullptr; }
auto objs = X509_STORE_get0_objects(store);
if (!objs) {
sk_X509_NAME_free(ca_list);
return nullptr;
}
auto count = sk_X509_OBJECT_num(objs);
for (decltype(count) i = 0; i < count; i++) {
auto obj = sk_X509_OBJECT_value(objs, i);
if (X509_OBJECT_get_type(obj) == X509_LU_X509) {
auto cert = X509_OBJECT_get0_X509(obj);
if (cert) {
auto subject = X509_get_subject_name(cert);
if (subject) {
auto name_dup = X509_NAME_dup(subject);
if (name_dup) { sk_X509_NAME_push(ca_list, name_dup); }
}
}
}
}
if (sk_X509_NAME_num(ca_list) == 0) {
sk_X509_NAME_free(ca_list);
return nullptr;
}
return ca_list;
}
// OpenSSL verify callback wrapper
int openssl_verify_callback(int preverify_ok, X509_STORE_CTX *ctx) {
auto &callback = get_verify_callback();
@ -13086,6 +12913,9 @@ ssize_t read(session_t session, void *buf, size_t len, TlsError &err) {
auto ssl_err = SSL_get_error(ssl, ret);
err.code = impl::map_ssl_error(ssl_err, err.sys_errno);
if (err.code == ErrorCode::PeerClosed) {
return 0;
} // Gracefully handle the peer closed state.
if (err.code == ErrorCode::Fatal) { err.backend_code = ERR_get_error(); }
return -1;
}
@ -13523,164 +13353,8 @@ std::string verify_error_string(long error_code) {
return str ? str : "unknown error";
}
namespace impl {
// OpenSSL-specific helpers for public API wrappers
ctx_t create_server_context_from_x509(X509 *cert, EVP_PKEY *key,
X509_STORE *client_ca_store,
int &out_error) {
out_error = 0;
auto cert_pem = x509_to_pem(cert);
auto key_pem = evp_pkey_to_pem(key);
if (cert_pem.empty() || key_pem.empty()) {
out_error = static_cast<int>(ERR_get_error());
return nullptr;
}
auto ctx = create_server_context();
if (!ctx) {
out_error = static_cast<int>(get_error());
return nullptr;
}
if (!set_server_cert_pem(ctx, cert_pem.c_str(), key_pem.c_str(), nullptr)) {
out_error = static_cast<int>(get_error());
free_context(ctx);
return nullptr;
}
if (client_ca_store) {
// Set cert store for verification (SSL_CTX_set_cert_store takes ownership)
SSL_CTX_set_cert_store(static_cast<SSL_CTX *>(ctx), client_ca_store);
// Extract and set client CA list directly from store (more efficient than
// PEM conversion)
auto ca_list = extract_client_ca_list_from_store(client_ca_store);
if (ca_list) {
SSL_CTX_set_client_CA_list(static_cast<SSL_CTX *>(ctx), ca_list);
}
set_verify_client(ctx, true);
}
return ctx;
}
void update_server_certs_from_x509(ctx_t ctx, X509 *cert, EVP_PKEY *key,
X509_STORE *client_ca_store) {
auto cert_pem = x509_to_pem(cert);
auto key_pem = evp_pkey_to_pem(key);
if (!cert_pem.empty() && !key_pem.empty()) {
update_server_cert(ctx, cert_pem.c_str(), key_pem.c_str(), nullptr);
}
if (client_ca_store) {
auto ca_pem = x509_store_to_pem(client_ca_store);
if (!ca_pem.empty()) { update_server_client_ca(ctx, ca_pem.c_str()); }
X509_STORE_free(client_ca_store);
}
}
ctx_t create_client_context_from_x509(X509 *cert, EVP_PKEY *key,
const char *password,
uint64_t &out_error) {
out_error = 0;
auto ctx = create_client_context();
if (!ctx) {
out_error = get_error();
return nullptr;
}
if (cert && key) {
auto cert_pem = x509_to_pem(cert);
auto key_pem = evp_pkey_to_pem(key);
if (cert_pem.empty() || key_pem.empty()) {
out_error = ERR_get_error();
free_context(ctx);
return nullptr;
}
if (!set_client_cert_pem(ctx, cert_pem.c_str(), key_pem.c_str(),
password)) {
out_error = get_error();
free_context(ctx);
return nullptr;
}
}
return ctx;
}
} // namespace impl
} // namespace tls
// ClientImpl::set_ca_cert_store - defined here to use
// tls::impl::x509_store_to_pem Deprecated: converts X509_STORE to PEM and
// stores for redirect transfer
void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) {
if (ca_cert_store) {
ca_cert_pem_ = tls::impl::x509_store_to_pem(ca_cert_store);
}
}
SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key,
X509_STORE *client_ca_cert_store) {
ctx_ = tls::impl::create_server_context_from_x509(
cert, private_key, client_ca_cert_store, last_ssl_error_);
}
SSLServer::SSLServer(
const std::function<bool(SSL_CTX &ssl_ctx)> &setup_ssl_ctx_callback) {
// Use abstract API to create context
ctx_ = tls::create_server_context();
if (ctx_) {
// Pass to OpenSSL-specific callback (ctx_ is SSL_CTX* internally)
auto ssl_ctx = static_cast<SSL_CTX *>(ctx_);
if (!setup_ssl_ctx_callback(*ssl_ctx)) {
tls::free_context(ctx_);
ctx_ = nullptr;
}
}
}
SSL_CTX *SSLServer::ssl_context() const {
return static_cast<SSL_CTX *>(ctx_);
}
void SSLServer::update_certs(X509 *cert, EVP_PKEY *private_key,
X509_STORE *client_ca_cert_store) {
std::lock_guard<std::mutex> guard(ctx_mutex_);
tls::impl::update_server_certs_from_x509(ctx_, cert, private_key,
client_ca_cert_store);
}
SSLClient::SSLClient(const std::string &host, int port,
X509 *client_cert, EVP_PKEY *client_key,
const std::string &private_key_password)
: ClientImpl(host, port) {
const char *password =
private_key_password.empty() ? nullptr : private_key_password.c_str();
ctx_ = tls::impl::create_client_context_from_x509(
client_cert, client_key, password, last_backend_error_);
}
long SSLClient::get_verify_result() const { return verify_result_; }
void SSLClient::set_server_certificate_verifier(
std::function<SSLVerifierResponse(SSL *ssl)> verifier) {
// Wrap SSL* callback into backend-independent session_verifier_
auto v = std::make_shared<std::function<SSLVerifierResponse(SSL *)>>(
std::move(verifier));
session_verifier_ = [v](tls::session_t session) {
return (*v)(static_cast<SSL *>(session));
};
}
SSL_CTX *SSLClient::ssl_context() const {
return static_cast<SSL_CTX *>(ctx_);
}
bool SSLClient::verify_host(X509 *server_cert) const {
/* Quote from RFC2818 section 3.1 "Server Identity"
@ -16194,7 +15868,11 @@ ReadResult WebSocket::read(std::string &msg) {
payload.size(), true, !is_server_);
continue;
}
case Opcode::Pong: continue;
case Opcode::Pong: {
std::lock_guard<std::mutex> lock(ping_mutex_);
unacked_pings_ = 0;
continue;
}
case Opcode::Close: {
if (!closed_.exchange(true)) {
// Echo close frame back
@ -16228,7 +15906,11 @@ ReadResult WebSocket::read(std::string &msg) {
true, !is_server_);
continue;
}
if (cont_opcode == Opcode::Pong) { continue; }
if (cont_opcode == Opcode::Pong) {
std::lock_guard<std::mutex> lock(ping_mutex_);
unacked_pings_ = 0;
continue;
}
if (cont_opcode == Opcode::Close) {
if (!closed_.exchange(true)) {
std::lock_guard<std::mutex> lock(write_mutex_);
@ -16316,12 +15998,22 @@ void WebSocket::start_heartbeat() {
while (!closed_) {
ping_cv_.wait_for(lock, std::chrono::seconds(ping_interval_sec_));
if (closed_) { break; }
// If the peer has failed to respond to the previous pings, give up.
// RFC 6455 does not define a pong-timeout mechanism; this is an
// opt-in liveness check controlled by max_missed_pongs_.
if (max_missed_pongs_ > 0 && unacked_pings_ >= max_missed_pongs_) {
lock.unlock();
close(CloseStatus::GoingAway, "pong timeout");
return;
}
lock.unlock();
if (!send_frame(Opcode::Ping, nullptr, 0)) {
lock.lock();
closed_ = true;
break;
}
lock.lock();
unacked_pings_++;
}
});
}
@ -16449,8 +16141,9 @@ bool WebSocketClient::connect() {
Request req;
req.method = "GET";
req.path = path_;
ws_ = std::unique_ptr<WebSocket>(
new WebSocket(std::move(strm), req, false, websocket_ping_interval_sec_));
ws_ = std::unique_ptr<WebSocket>(new WebSocket(std::move(strm), req, false,
websocket_ping_interval_sec_,
websocket_max_missed_pongs_));
return true;
}
@ -16494,6 +16187,10 @@ void WebSocketClient::set_websocket_ping_interval(time_t sec) {
websocket_ping_interval_sec_ = sec;
}
void WebSocketClient::set_websocket_max_missed_pongs(int count) {
websocket_max_missed_pongs_ = count;
}
void WebSocketClient::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; }
void WebSocketClient::set_address_family(int family) {

View file

@ -8,8 +8,8 @@
#ifndef CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_VERSION "0.42.0"
#define CPPHTTPLIB_VERSION_NUM "0x002a00"
#define CPPHTTPLIB_VERSION "0.43.1"
#define CPPHTTPLIB_VERSION_NUM "0x002b01"
#ifdef _WIN32
#if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00
@ -205,6 +205,10 @@
#define CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND 30
#endif
#ifndef CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS
#define CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS 0
#endif
/*
* Headers
*/
@ -1720,6 +1724,8 @@ public:
Server &set_websocket_ping_interval(
const std::chrono::duration<Rep, Period> &duration);
Server &set_websocket_max_missed_pongs(int count);
bool bind_to_port(const std::string &host, int port, int socket_flags = 0);
int bind_to_any_port(const std::string &host, int socket_flags = 0);
bool listen_after_bind();
@ -1756,6 +1762,7 @@ protected:
size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH;
time_t websocket_ping_interval_sec_ =
CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND;
int websocket_max_missed_pongs_ = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS;
private:
using Handlers =
@ -1767,6 +1774,14 @@ private:
static std::unique_ptr<detail::MatcherBase>
make_matcher(const std::string &pattern);
template <typename H>
Server &add_handler(
std::vector<std::pair<std::unique_ptr<detail::MatcherBase>, H>> &handlers,
const std::string &pattern, H handler) {
handlers.emplace_back(make_matcher(pattern), std::move(handler));
return *this;
}
Server &set_error_handler_core(HandlerWithResponse handler, std::true_type);
Server &set_error_handler_core(Handler handler, std::false_type);
@ -1928,15 +1943,6 @@ private:
int ssl_error_ = 0;
uint64_t ssl_backend_error_ = 0;
#endif
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
public:
[[deprecated("Use ssl_backend_error() instead. "
"This function will be removed by v1.0.0.")]]
uint64_t ssl_openssl_error() const {
return ssl_backend_error_;
}
#endif
};
struct ClientConnection {
@ -2409,22 +2415,6 @@ protected:
int last_ssl_error_ = 0;
uint64_t last_backend_error_ = 0;
#endif
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
public:
[[deprecated("Use load_ca_cert_store() instead. "
"This function will be removed by v1.0.0.")]]
void set_ca_cert_store(X509_STORE *ca_cert_store);
[[deprecated("Use tls::create_ca_store() instead. "
"This function will be removed by v1.0.0.")]]
X509_STORE *create_ca_cert_store(const char *ca_cert, std::size_t size) const;
[[deprecated("Use set_server_certificate_verifier(VerifyCallback) instead. "
"This function will be removed by v1.0.0.")]]
virtual void set_server_certificate_verifier(
std::function<SSLVerifierResponse(SSL *ssl)> verifier);
#endif
};
class Client {
@ -2599,7 +2589,6 @@ public:
void set_follow_location(bool on);
void set_path_encode(bool on);
void set_url_encode(bool on);
void set_compress(bool on);
@ -2647,22 +2636,6 @@ public:
private:
bool is_ssl_ = false;
#endif
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
public:
[[deprecated("Use tls_context() instead. "
"This function will be removed by v1.0.0.")]]
SSL_CTX *ssl_context() const;
[[deprecated("Use set_session_verifier(session_t) instead. "
"This function will be removed by v1.0.0.")]]
void set_server_certificate_verifier(
std::function<SSLVerifierResponse(SSL *ssl)> verifier);
[[deprecated("Use Result::ssl_backend_error() instead. "
"This function will be removed by v1.0.0.")]]
long get_verify_result() const;
#endif
};
#ifdef CPPHTTPLIB_SSL_ENABLED
@ -2708,29 +2681,6 @@ private:
std::mutex ctx_mutex_;
int last_ssl_error_ = 0;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
public:
[[deprecated("Use SSLServer(PemMemory) or "
"SSLServer(ContextSetupCallback) instead. "
"This constructor will be removed by v1.0.0.")]]
SSLServer(X509 *cert, EVP_PKEY *private_key,
X509_STORE *client_ca_cert_store = nullptr);
[[deprecated("Use SSLServer(ContextSetupCallback) instead. "
"This constructor will be removed by v1.0.0.")]]
SSLServer(
const std::function<bool(SSL_CTX &ssl_ctx)> &setup_ssl_ctx_callback);
[[deprecated("Use tls_context() instead. "
"This function will be removed by v1.0.0.")]]
SSL_CTX *ssl_context() const;
[[deprecated("Use update_certs_pem() instead. "
"This function will be removed by v1.0.0.")]]
void update_certs(X509 *cert, EVP_PKEY *private_key,
X509_STORE *client_ca_cert_store = nullptr);
#endif
};
class SSLClient final : public ClientImpl {
@ -2794,6 +2744,9 @@ private:
Response &res, bool &success, Error &error);
bool initialize_ssl(Socket &socket, Error &error);
void init_ctx();
void reset_ctx_on_error();
bool load_certs();
tls::ctx_t ctx_ = nullptr;
@ -2811,42 +2764,6 @@ private:
friend class ClientImpl;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
public:
[[deprecated("Use SSLClient(host, port, PemMemory) instead. "
"This constructor will be removed by v1.0.0.")]]
explicit SSLClient(const std::string &host, int port, X509 *client_cert,
EVP_PKEY *client_key,
const std::string &private_key_password = std::string());
[[deprecated("Use Result::ssl_backend_error() instead. "
"This function will be removed by v1.0.0.")]]
long get_verify_result() const;
[[deprecated("Use tls_context() instead. "
"This function will be removed by v1.0.0.")]]
SSL_CTX *ssl_context() const;
// Override of a deprecated virtual in ClientImpl. Suppress C4996 /
// -Wdeprecated-declarations on the override declaration itself so that
// MSVC /sdl builds compile cleanly. Will be removed together with the
// base virtual by v1.0.0.
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4996)
#elif defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif
[[deprecated("Use set_session_verifier(session_t) instead. "
"This function will be removed by v1.0.0.")]]
void set_server_certificate_verifier(
std::function<SSLVerifierResponse(SSL *ssl)> verifier) override;
#if defined(_MSC_VER)
#pragma warning(pop)
#elif defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif
private:
bool verify_host(X509 *server_cert) const;
bool verify_host_with_subject_alt_name(X509 *server_cert) const;
@ -3818,17 +3735,21 @@ private:
// Construct a WebSocket over an externally owned stream.
// strm must outlive this object (non-owning reference).
// ping_interval_sec: heartbeat ping period in seconds.
// max_missed_pongs: presumably the number of consecutive unanswered pings
// tolerated before the peer is treated as gone (tracked via unacked_pings_)
// — TODO confirm against the heartbeat implementation.
WebSocket(
    Stream &strm, const Request &req, bool is_server,
    time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND,
    int max_missed_pongs = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS)
    : strm_(strm), req_(req), is_server_(is_server),
      ping_interval_sec_(ping_interval_sec),
      max_missed_pongs_(max_missed_pongs) {
  start_heartbeat();
}
WebSocket(
std::unique_ptr<Stream> &&owned_strm, const Request &req, bool is_server,
time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND)
time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND,
int max_missed_pongs = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS)
: strm_(*owned_strm), owned_strm_(std::move(owned_strm)), req_(req),
is_server_(is_server), ping_interval_sec_(ping_interval_sec) {
is_server_(is_server), ping_interval_sec_(ping_interval_sec),
max_missed_pongs_(max_missed_pongs) {
start_heartbeat();
}
@ -3840,6 +3761,8 @@ private:
Request req_;
bool is_server_;
time_t ping_interval_sec_;
int max_missed_pongs_;
int unacked_pings_ = 0;
std::atomic<bool> closed_{false};
std::mutex write_mutex_;
std::thread ping_thread_;
@ -3869,6 +3792,7 @@ public:
void set_read_timeout(time_t sec, time_t usec = 0);
void set_write_timeout(time_t sec, time_t usec = 0);
void set_websocket_ping_interval(time_t sec);
void set_websocket_max_missed_pongs(int count);
void set_tcp_nodelay(bool on);
void set_address_family(int family);
void set_ipv6_v6only(bool on);
@ -3900,6 +3824,7 @@ private:
time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND;
time_t websocket_ping_interval_sec_ =
CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND;
int websocket_max_missed_pongs_ = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS;
int address_family_ = AF_UNSPEC;
bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY;
bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY;