mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-04-26 10:41:25 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts: # .devops/openvino.Dockerfile # .github/workflows/build-self-hosted.yml # .github/workflows/build.yml # common/chat.cpp # docs/backend/OPENVINO.md # examples/speculative-simple/speculative-simple.cpp # ggml/src/ggml-hexagon/ggml-hexagon.cpp # ggml/src/ggml-hexagon/htp/CMakeLists.txt # ggml/src/ggml-hexagon/htp/htp-ctx.h # ggml/src/ggml-hexagon/htp/htp-ops.h # ggml/src/ggml-hexagon/htp/main.c # ggml/src/ggml-hexagon/libggml-htp.inf # ggml/src/ggml-openvino/ggml-decoder.cpp # ggml/src/ggml-openvino/ggml-openvino-extra.cpp # ggml/src/ggml-openvino/ggml-openvino.cpp # ggml/src/ggml-openvino/ggml-quants.cpp # ggml/src/ggml-openvino/openvino/op/rope.cpp # ggml/src/ggml-openvino/openvino/op_table.cpp # ggml/src/ggml-openvino/openvino/op_table.h # ggml/src/ggml-openvino/openvino/translate_session.cpp # ggml/src/ggml-openvino/openvino/utils.cpp # ggml/src/ggml-openvino/openvino/utils.h # ggml/src/ggml-openvino/utils.cpp # ggml/src/ggml-openvino/utils.h # ggml/src/ggml-sycl/common.hpp # ggml/src/ggml-sycl/convert.cpp # ggml/src/ggml-sycl/convert.hpp # ggml/src/ggml-sycl/gemm.hpp # ggml/src/ggml-sycl/ggml-sycl.cpp # ggml/src/ggml-sycl/set_rows.cpp # ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp # ggml/src/ggml-webgpu/ggml-webgpu.cpp # scripts/sync_vendor.py # tests/CMakeLists.txt # tests/test-chat.cpp # tools/cli/cli.cpp # tools/mtmd/CMakeLists.txt # tools/server/CMakeLists.txt
This commit is contained in:
commit
0755f27372
42 changed files with 1531 additions and 3199 deletions
|
|
@ -3124,14 +3124,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
"token budget for thinking: -1 for unrestricted, 0 for immediate end, N>0 for token budget (default: -1)",
|
||||
[](common_params & params, int value) {
|
||||
if (value < -1) { throw std::invalid_argument("invalid value"); }
|
||||
params.reasoning_budget = value;
|
||||
params.sampling.reasoning_budget_tokens = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-budget-message"}, "MESSAGE",
|
||||
"message injected before the end-of-thinking tag when reasoning budget is exhausted (default: none)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.reasoning_budget_message = value;
|
||||
params.sampling.reasoning_budget_message = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
|
||||
add_opt(common_arg(
|
||||
|
|
@ -3904,6 +3904,17 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--spec-default"},
|
||||
string_format("enable default speculative decoding config"),
|
||||
[](common_params & params) {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
|
||||
params.speculative.ngram_size_n = 24;
|
||||
params.speculative.n_min = 48;
|
||||
params.speculative.n_max = 64;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
|
||||
return ctx_arg;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -408,6 +408,25 @@ json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msg
|
|||
return render_message_to_json(msgs, c);
|
||||
}
|
||||
|
||||
// Serializes tool definitions into an OpenAI-compatible JSON array of
// { "type": "function", "function": { name, description, parameters } } entries.
// An empty tool list yields a default-constructed json value rather than an empty array.
json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
    json out;  // stays default-constructed when there are no tools

    if (!tools.empty()) {
        out = json::array();
        for (const auto & t : tools) {
            json fn = json::object();
            fn["name"]        = t.name;
            fn["description"] = t.description;
            // the tool stores its parameter schema as text; it is parsed into a JSON value here
            fn["parameters"]  = json::parse(t.parameters);

            json entry = json::object();
            entry["type"]     = "function";
            entry["function"] = fn;

            out.push_back(entry);
        }
    }

    return out;
}
|
||||
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
|
||||
std::vector<common_chat_tool> result;
|
||||
|
||||
|
|
@ -443,60 +462,9 @@ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & too
|
|||
return result;
|
||||
}
|
||||
|
||||
// Converts a list of tools to the OpenAI "tools" array layout.
// Returns a default-constructed json value when `tools` is empty.
json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
    if (tools.empty()) {
        return json();
    }

    json serialized = json::array();
    for (size_t i = 0; i < tools.size(); ++i) {
        const common_chat_tool & tool = tools[i];

        // inner "function" object; parameters are parsed from their textual form
        const json function_spec = {
            { "name",        tool.name },
            { "description", tool.description },
            { "parameters",  json::parse(tool.parameters) },
        };

        serialized.push_back({
            { "type",     "function" },
            { "function", function_spec },
        });
    }
    return serialized;
}
|
||||
|
||||
// Builds the OpenAI-style streaming "delta" object for one message diff.
// Only non-empty pieces are emitted: reasoning_content, content, and at most
// one tool call (wrapped in a single-element "tool_calls" array).
json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
    json delta = json::object();

    if (!diff.reasoning_content_delta.empty()) {
        delta["reasoning_content"] = diff.reasoning_content_delta;
    }
    if (!diff.content_delta.empty()) {
        delta["content"] = diff.content_delta;
    }

    // an index of npos means this diff carries no tool call
    if (diff.tool_call_index == std::string::npos) {
        return delta;
    }

    const auto & call     = diff.tool_call_delta;
    const bool   has_name = !call.name.empty();
    const bool   has_args = !call.arguments.empty();

    json tool_call;
    tool_call["index"] = diff.tool_call_index;

    // id and type are only sent together, when an id is present
    if (!call.id.empty()) {
        tool_call["id"]   = call.id;
        tool_call["type"] = "function";
    }

    if (has_name || has_args) {
        json function = json::object();
        if (has_name) {
            function["name"] = call.name;
        }
        if (has_args) {
            function["arguments"] = call.arguments;
        }
        tool_call["function"] = function;
    }

    delta["tool_calls"] = json::array({ tool_call });
    return delta;
}
|
||||
|
||||
#include "common/unicode.h"
|
||||
#include "peg-parser.cpp"
|
||||
#include "chat-peg-parser.cpp"
|
||||
|
||||
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
||||
if (use_jinja) {
|
||||
try {
|
||||
|
|
|
|||
|
|
@ -256,14 +256,13 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates *
|
|||
// Parses a JSON array of messages in OpenAI's chat completion API format.
|
||||
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const nlohmann::ordered_json & messages);
|
||||
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools);
|
||||
|
||||
// DEPRECATED: only used in tests
|
||||
nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
|
||||
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools);
|
||||
nlohmann::ordered_json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
|
||||
|
||||
nlohmann::ordered_json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|
||||
|
||||
// get template caps, useful for reporting to server /props endpoint
|
||||
std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates);
|
||||
|
||||
|
|
|
|||
|
|
@ -275,6 +275,7 @@ struct common_params_sampling {
|
|||
std::vector<llama_token> reasoning_budget_start; // start tag token sequence
|
||||
std::vector<llama_token> reasoning_budget_end; // end tag token sequence
|
||||
std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)
|
||||
std::string reasoning_budget_message; // message injected before end tag when budget exhausted
|
||||
|
||||
bool backend_sampling = false;
|
||||
|
||||
|
|
@ -582,8 +583,6 @@ struct common_params {
|
|||
bool force_pure_content_parser = false;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
int enable_reasoning = -1; // -1 = auto, 0 = disable, 1 = enable
|
||||
int reasoning_budget = -1;
|
||||
std::string reasoning_budget_message; // message injected before end tag when budget exhausted
|
||||
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
|
||||
int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time
|
||||
|
||||
|
|
|
|||
|
|
@ -749,6 +749,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
|||
|
||||
mod.reset();
|
||||
n_low = 0;
|
||||
i_last = 0;
|
||||
}
|
||||
} else {
|
||||
n_low = 0;
|
||||
|
|
|
|||
|
|
@ -11855,7 +11855,7 @@ class LLaDAMoEModel(TextModel):
|
|||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("HunYuanDenseV1ForCausalLM", "HunYuanVLForConditionalGeneration")
|
||||
@ModelBase.register("HunYuanDenseV1ForCausalLM")
|
||||
class HunYuanModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE
|
||||
|
||||
|
|
@ -11994,28 +11994,58 @@ class HunYuanModel(TextModel):
|
|||
|
||||
|
||||
@ModelBase.register("HunYuanVLForConditionalGeneration")
|
||||
class HunyuanOCRVisionModel(MmprojModel):
|
||||
class HunyuanVLVisionModel(MmprojModel):
|
||||
# Handles both HunyuanOCR and HunyuanVL, which share the HF architecture name
|
||||
# "HunYuanVLForConditionalGeneration" and the `vit.perceive.*` vision layout.
|
||||
# Each variant maps to a different projector type in clip.cpp so image
|
||||
# preprocessing follows the correct code path.
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.hparams_vision is not None
|
||||
# HunyuanOCR uses max_image_size instead of image_size
|
||||
# HunyuanOCR / HunyuanVL uses max_image_size instead of image_size
|
||||
if "image_size" not in self.hparams_vision:
|
||||
self.hparams_vision["image_size"] = self.hparams_vision.get("max_image_size", 2048)
|
||||
|
||||
@staticmethod
|
||||
def is_ocr_variant(hparams: dict) -> bool:
|
||||
"""Return True for HunyuanOCR, False for HunyuanVL.
|
||||
|
||||
The projector's output dim must equal the text model's hidden_size by
|
||||
construction (that's what "projector" means). HunyuanOCR pairs a 1B text
|
||||
backbone (hidden=1024); HunyuanVL pairs a 4B one (hidden=3072). So the
|
||||
ViT -> LLM projection dim is a hard architectural signature, not a
|
||||
magic number.
|
||||
"""
|
||||
vision_out = int((hparams.get("vision_config") or {}).get("out_hidden_size", 0))
|
||||
return vision_out == 1024
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
assert self.hparams_vision is not None
|
||||
hparams = self.hparams_vision
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANOCR)
|
||||
self.gguf_writer.add_vision_use_gelu(True)
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("rms_norm_eps", 1e-5))
|
||||
self.gguf_writer.add_vision_spatial_merge_size(hparams.get("spatial_merge_size", 2))
|
||||
self.gguf_writer.add_vision_min_pixels(self.preprocessor_config["min_pixels"])
|
||||
self.gguf_writer.add_vision_max_pixels(self.preprocessor_config["max_pixels"])
|
||||
vcfg = self.hparams_vision
|
||||
|
||||
if self.is_ocr_variant(self.global_config):
|
||||
# --- HunyuanOCR ---
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANOCR)
|
||||
self.gguf_writer.add_vision_use_gelu(True)
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(vcfg.get("rms_norm_eps", 1e-5))
|
||||
self.gguf_writer.add_vision_spatial_merge_size(vcfg.get("spatial_merge_size", 2))
|
||||
self.gguf_writer.add_vision_min_pixels(self.preprocessor_config["min_pixels"])
|
||||
self.gguf_writer.add_vision_max_pixels(self.preprocessor_config["max_pixels"])
|
||||
return
|
||||
|
||||
# --- HunyuanVL ---
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANVL)
|
||||
self.gguf_writer.add_vision_use_gelu(str(vcfg["hidden_act"]).lower() == "gelu")
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(float(vcfg["rms_norm_eps"]))
|
||||
self.gguf_writer.add_vision_spatial_merge_size(int(vcfg["spatial_merge_size"]))
|
||||
self.gguf_writer.add_vision_min_pixels(int(self.preprocessor_config["min_pixels"]))
|
||||
self.gguf_writer.add_vision_max_pixels(int(self.preprocessor_config["max_pixels"]))
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if not name.startswith("vit."):
|
||||
return # skip text tensors
|
||||
return
|
||||
# strip CLS token (row 0) from position embeddings so resize_position_embeddings works
|
||||
if "position_embedding" in name:
|
||||
data_torch = data_torch[1:] # [n_patches+1, n_embd] -> [n_patches, n_embd]
|
||||
|
|
@ -12023,11 +12053,66 @@ class HunyuanOCRVisionModel(MmprojModel):
|
|||
|
||||
def tensor_force_quant(self, name, new_name, bid, n_dims):
|
||||
# force conv weights to F32 or F16 to avoid BF16 IM2COL issues on Metal
|
||||
# Both HunyuanOCR and HunyuanVL emit the ViT -> LLM projection as mm.0/mm.2.
|
||||
if ("mm.0." in new_name or "mm.2." in new_name) and new_name.endswith(".weight"):
|
||||
return gguf.GGMLQuantizationType.F16 if self.ftype == gguf.LlamaFileType.MOSTLY_F16 else gguf.GGMLQuantizationType.F32
|
||||
return super().tensor_force_quant(name, new_name, bid, n_dims)
|
||||
|
||||
|
||||
@ModelBase.register("HunYuanVLForConditionalGeneration")
class HunyuanVLTextModel(HunYuanModel):
    # Text-side converter for the "HunYuanVLForConditionalGeneration" HF architecture,
    # which is shared by HunyuanOCR and HunyuanVL. HunyuanOCR reuses the HunYuan-Dense
    # text backbone (standard RoPE); HunyuanVL uses a new LLM arch with XD-RoPE. The
    # variant is detected from the config and the matching GGUF architecture is chosen.
    model_arch = gguf.MODEL_ARCH.HUNYUAN_VL

    @staticmethod
    def _is_ocr_config(hparams: dict) -> bool:
        # HunyuanOCR: 1B text backbone (hidden=1024) with a ViT projector that outputs
        # 1024-d; HunyuanVL uses 3072-d. Keep in sync with
        # HunyuanVLVisionModel.is_ocr_variant.
        vision_cfg = hparams.get("vision_config") or {}
        projector_dim = int(vision_cfg.get("out_hidden_size", 0))
        return projector_dim == 1024

    def __init__(self, dir_model: Path, *args, **kwargs):
        # The architecture must be decided before the base initializer runs, so the
        # raw hparams are taken from kwargs or loaded from disk here.
        config = kwargs.get("hparams")
        if not config:
            config = ModelBase.load_hparams(dir_model, is_mistral_format=False)

        self.model_arch = (
            gguf.MODEL_ARCH.HUNYUAN_DENSE
            if self._is_ocr_config(config)
            else gguf.MODEL_ARCH.HUNYUAN_VL
        )
        super().__init__(dir_model, *args, **kwargs)

    def set_gguf_parameters(self):
        super().set_gguf_parameters()

        # XD-RoPE metadata is only emitted for the HunyuanVL backbone; HunyuanOCR uses
        # the HunYuan-Dense arch, whose standard rope is already handled by super().
        if self.model_arch != gguf.MODEL_ARCH.HUNYUAN_VL or self.rope_parameters.get("rope_type") != "xdrope":
            return

        rope = self.rope_parameters
        writer = self.gguf_writer

        # Defaults for HunyuanVL. The C++ side later computes:
        #   freq_base = rope_theta * alpha ** (head_dim / (head_dim - 2))
        writer.add_rope_freq_base(float(rope["rope_theta"]))
        writer.add_rope_scaling_alpha(float(rope["alpha"]))
        writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
        writer.add_rope_scaling_factor(float(rope.get("factor", 1)))

        max_positions = int(self.hparams["max_position_embeddings"])
        writer.add_rope_scaling_orig_ctx_len(max_positions)
        writer.add_context_length(max_positions)

        writer.add_rope_dimension_sections(list(rope["xdrope_section"]))

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # Vision tensors ("vit." prefix) are written by HunyuanVLVisionModel; only the
        # remaining tensors are forwarded to the text-model conversion.
        if not name.startswith("vit."):
            yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("SmolLM3ForCausalLM")
|
||||
class SmolLM3Model(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.SMOLLM3
|
||||
|
|
|
|||
|
|
@ -1,58 +0,0 @@
|
|||
# Snapdragon-based Linux devices
|
||||
|
||||
## Docker Setup
|
||||
|
||||
The easiest way to build llama.cpp for a Snapdragon-based Linux device is using the toolchain Docker image (see [github.com/snapdragon-toolchain](https://github.com/snapdragon-toolchain)).
|
||||
This image includes OpenCL SDK, Hexagon SDK, CMake, and the ARM64 Linux cross-compilation toolchain.
|
||||
|
||||
Cross-compilation is supported on **Linux X86** hosts. The resulting binaries are deployed to and run on the target **Qualcomm Snapdragon ARM64 Linux** device.
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ docker run -it -u $(id -u):$(id -g) --volume $(pwd):/workspace --platform linux/amd64 ghcr.io/snapdragon-toolchain/arm64-linux:v0.1
|
||||
[d]/> cd /workspace
|
||||
```
|
||||
|
||||
Note: The rest of the **Linux** build process assumes that you're running inside the toolchain container.
|
||||
|
||||
|
||||
## How to Build
|
||||
|
||||
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
|
||||
|
||||
```
|
||||
[d]/workspace> cp docs/backend/snapdragon/CMakeUserPresets.json .
|
||||
|
||||
[d]/workspace> cmake --preset arm64-linux-snapdragon-release -B build-snapdragon
|
||||
|
||||
[d]/workspace> cmake --build build-snapdragon -j $(nproc)
|
||||
```
|
||||
|
||||
To generate an installable "package", simply use `cmake --install`, then zip the result:
|
||||
|
||||
```
|
||||
[d]/workspace> cmake --install build-snapdragon --prefix pkg-snapdragon
|
||||
[d]/workspace> zip -r pkg-snapdragon.zip pkg-snapdragon
|
||||
```
|
||||
|
||||
## How to Install
|
||||
|
||||
For this step, you will deploy the built binaries and libraries to the target Linux device. Transfer `pkg-snapdragon.zip` to the target device, then unzip it and set up the environment variables:
|
||||
|
||||
```
|
||||
$ unzip pkg-snapdragon.zip
|
||||
$ cd pkg-snapdragon
|
||||
$ export LD_LIBRARY_PATH=./lib
|
||||
$ export ADSP_LIBRARY_PATH=./lib
|
||||
```
|
||||
|
||||
At this point, you should also download some models onto the device:
|
||||
|
||||
```
|
||||
$ wget https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q4_0.gguf
|
||||
```
|
||||
|
||||
## How to Run
|
||||
Next, since we have set up the environment variables, we can run `llama-cli` with the Hexagon backend:
|
||||
```
|
||||
$ ./bin/llama-cli -m Llama-3.2-3B-Instruct-Q4_0.gguf --device HTP0 -ngl 99 -p "what is the most popular cookie in the world?"
|
||||
```
|
||||
|
|
@ -1,158 +0,0 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <qurt_thread.h>
|
||||
#include <qurt_futex.h>
|
||||
|
||||
#include <HAP_compute_res.h>
|
||||
|
||||
#include "hmx-queue.h"
|
||||
|
||||
#define QURT_LOWEST_PRIO (254)
|
||||
|
||||
static inline void hmx_lock(struct hmx_queue *q)
|
||||
{
|
||||
if (!q->hmx_locked) {
|
||||
HAP_compute_res_hmx_lock(q->hap_rctx);
|
||||
q->hmx_locked = true;
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hmx_unlock(struct hmx_queue *q)
|
||||
{
|
||||
if (q->hmx_locked) {
|
||||
HAP_compute_res_hmx_unlock(q->hap_rctx);
|
||||
q->hmx_locked = false;
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hmx_queue_process(struct hmx_queue *q, bool* killed) {
|
||||
unsigned int ir = atomic_load(&q->idx_read);
|
||||
|
||||
while (ir != atomic_load(&q->idx_write)) {
|
||||
struct hmx_queue_desc *d = &q->desc[ir];
|
||||
if (!d->done) {
|
||||
FARF(HIGH, "hmx-queue-process: ir %u func %p data %p", ir, d->func, d->data);
|
||||
|
||||
enum hmx_queue_signal sig = (enum hmx_queue_signal) (unsigned int) d->func;
|
||||
switch (sig) {
|
||||
case HMX_QUEUE_NOOP: /* noop */; break;
|
||||
case HMX_QUEUE_KILL: *killed = true; break;
|
||||
case HMX_QUEUE_SUSPEND: hmx_unlock(q); break;
|
||||
default:
|
||||
hmx_lock(q);
|
||||
d->func(d->data);
|
||||
break;
|
||||
}
|
||||
|
||||
atomic_fetch_add(&d->done, 1);
|
||||
}
|
||||
|
||||
ir = (ir + 1) & q->idx_mask;
|
||||
atomic_store(&q->idx_read, ir);
|
||||
}
|
||||
}
|
||||
|
||||
static void hmx_queue_thread(void * arg) {
|
||||
struct hmx_queue * q = (struct hmx_queue *) arg;
|
||||
|
||||
FARF(HIGH, "hmx-queue-thread: started");
|
||||
|
||||
bool killed = false;
|
||||
|
||||
unsigned int poll_cnt = HMX_QUEUE_POLL_COUNT;
|
||||
unsigned int prev_seqn = 0;
|
||||
while (!killed) {
|
||||
unsigned int seqn = atomic_load(&q->seqn);
|
||||
if (seqn == prev_seqn) {
|
||||
if (--poll_cnt) { hex_pause(); continue; }
|
||||
FARF(HIGH, "hmx-queue-thread: sleeping");
|
||||
qurt_futex_wait(&q->seqn, prev_seqn);
|
||||
continue;
|
||||
}
|
||||
prev_seqn = seqn;
|
||||
poll_cnt = HMX_QUEUE_POLL_COUNT;
|
||||
|
||||
FARF(HIGH, "hmx-queue-thread: new work");
|
||||
|
||||
hmx_queue_process(q, &killed);
|
||||
}
|
||||
|
||||
FARF(HIGH, "hmx-queue-thread: stopped");
|
||||
}
|
||||
|
||||
struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx) {
|
||||
capacity = hex_ceil_pow2(capacity);
|
||||
|
||||
struct hmx_queue * q = (struct hmx_queue *) memalign(32, sizeof(struct hmx_queue));
|
||||
if (q == NULL) {
|
||||
FARF(ERROR, "%s: failed to allocate DMA queue\n", __FUNCTION__);
|
||||
return NULL;
|
||||
}
|
||||
memset(q, 0, sizeof(struct hmx_queue));
|
||||
q->capacity = capacity;
|
||||
q->idx_mask = capacity - 1;
|
||||
q->hap_rctx = hap_rctx;
|
||||
|
||||
q->desc = (struct hmx_queue_desc *) memalign(64, capacity * sizeof(struct hmx_queue_desc));
|
||||
if (!q->desc) {
|
||||
FARF(ERROR, "hmx-queue: failed to allocate HMX queue descriptors\n");
|
||||
return NULL;
|
||||
}
|
||||
memset(q->desc, 0, capacity * sizeof(struct hmx_queue_desc));
|
||||
|
||||
const size_t stack_size = HMX_QUEUE_THREAD_STACK_SIZE;
|
||||
q->stack = (unsigned char *) memalign(64, stack_size);
|
||||
if (!q->stack) {
|
||||
FARF(ERROR, "hmx-queue: thread stack allocation failed (%zu bytes)", stack_size);
|
||||
return NULL;
|
||||
}
|
||||
memset(q->stack, 0, stack_size);
|
||||
|
||||
// Match caller thread priority (same pattern as worker-pool.c).
|
||||
int prio = qurt_thread_get_priority(qurt_thread_get_id());
|
||||
if (prio < 1) {
|
||||
prio = 1;
|
||||
}
|
||||
if (prio > QURT_LOWEST_PRIO) {
|
||||
prio = QURT_LOWEST_PRIO;
|
||||
}
|
||||
|
||||
qurt_thread_attr_t attr;
|
||||
qurt_thread_attr_init(&attr);
|
||||
qurt_thread_attr_set_stack_addr(&attr, q->stack);
|
||||
qurt_thread_attr_set_stack_size(&attr, stack_size);
|
||||
qurt_thread_attr_set_priority(&attr, prio);
|
||||
qurt_thread_attr_set_name(&attr, "hmx-queue");
|
||||
|
||||
int err = qurt_thread_create(&q->thread, &attr, hmx_queue_thread, q);
|
||||
if (err) {
|
||||
FARF(ERROR, "hmx-worker: thread create failed (%d)", err);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
FARF(HIGH, "hmx-queue: capacity %u\n", capacity);
|
||||
|
||||
return q;
|
||||
}
|
||||
|
||||
void hmx_queue_delete(struct hmx_queue * q) {
|
||||
if (!q) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Tell the worker to exit.
|
||||
hmx_queue_flush(q);
|
||||
hmx_queue_signal(q, HMX_QUEUE_KILL);
|
||||
hmx_queue_flush(q);
|
||||
|
||||
int status;
|
||||
qurt_thread_join(q->thread, &status);
|
||||
|
||||
free(q->desc);
|
||||
free(q->stack);
|
||||
free(q);
|
||||
}
|
||||
|
|
@ -1,134 +0,0 @@
|
|||
#ifndef HMX_QUEUE_H
|
||||
#define HMX_QUEUE_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdatomic.h>
|
||||
|
||||
#include <hexagon_types.h>
|
||||
#include <qurt_thread.h>
|
||||
#include <qurt_futex.h>
|
||||
#include <HAP_farf.h>
|
||||
|
||||
#include "hex-utils.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define HMX_QUEUE_THREAD_STACK_SIZE (16 * 1024)
|
||||
#define HMX_QUEUE_POLL_COUNT 2000
|
||||
|
||||
typedef void (*hmx_queue_func)(void *);
|
||||
|
||||
// Dummy funcs used as signals
|
||||
// Control signals carried in a descriptor's `func` field in place of a real
// function pointer; the worker distinguishes them by their small integer value.
enum hmx_queue_signal {
    HMX_QUEUE_NOOP    = 0,  // aka NULL
    HMX_QUEUE_SUSPEND = 1,
    HMX_QUEUE_KILL    = 2
};
|
||||
|
||||
struct hmx_queue_desc {
|
||||
hmx_queue_func func;
|
||||
void * data;
|
||||
atomic_uint done;
|
||||
};
|
||||
|
||||
struct hmx_queue {
|
||||
struct hmx_queue_desc * desc;
|
||||
atomic_uint idx_write; // updated by producer (push)
|
||||
atomic_uint idx_read; // updated by consumer (process)
|
||||
unsigned int idx_pop; // updated by producer (pop)
|
||||
uint32_t idx_mask;
|
||||
uint32_t capacity;
|
||||
|
||||
atomic_uint seqn; // incremented for all pushes, used with futex
|
||||
qurt_thread_t thread;
|
||||
void * stack;
|
||||
uint32_t hap_rctx;
|
||||
bool hmx_locked;
|
||||
};
|
||||
|
||||
struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx);
|
||||
void hmx_queue_delete(struct hmx_queue * q);
|
||||
|
||||
// Builds a descriptor for `func(data)`. Members not listed in the initializer
// (the `done` counter) are zero-initialized.
static inline struct hmx_queue_desc hmx_queue_make_desc(hmx_queue_func func, void * data) {
    struct hmx_queue_desc desc = { func, data };
    return desc;
}
|
||||
|
||||
static inline bool hmx_queue_push(struct hmx_queue * q, struct hmx_queue_desc d) {
|
||||
unsigned int ir = atomic_load(&q->idx_read);
|
||||
unsigned int iw = q->idx_write;
|
||||
|
||||
if (((iw + 1) & q->idx_mask) == ir) {
|
||||
FARF(HIGH, "hmx-queue-push: queue is full\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
atomic_store(&d.done, 0);
|
||||
|
||||
FARF(HIGH, "hmx-queue-push: iw %u func %p data %p\n", iw, d.func, d.data);
|
||||
|
||||
q->desc[iw] = d;
|
||||
atomic_store(&q->idx_write, (iw + 1) & q->idx_mask);
|
||||
// wake up our thread
|
||||
atomic_fetch_add(&q->seqn, 1);
|
||||
qurt_futex_wake(&q->seqn, 1);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static inline bool hmx_queue_signal(struct hmx_queue *q, enum hmx_queue_signal sig) {
|
||||
return hmx_queue_push(q, hmx_queue_make_desc((hmx_queue_func) sig, NULL));
|
||||
}
|
||||
|
||||
static inline bool hmx_queue_empty(struct hmx_queue * q) {
|
||||
return q->idx_pop == q->idx_write;
|
||||
}
|
||||
|
||||
static inline uint32_t hmx_queue_depth(struct hmx_queue * q) {
|
||||
return (q->idx_read - q->idx_read) & q->idx_mask;
|
||||
}
|
||||
|
||||
static inline uint32_t hmx_queue_capacity(struct hmx_queue * q) {
|
||||
return q->capacity;
|
||||
}
|
||||
|
||||
static inline struct hmx_queue_desc hmx_queue_pop(struct hmx_queue * q) {
|
||||
unsigned int ip = q->idx_pop;
|
||||
unsigned int iw = q->idx_write;
|
||||
|
||||
struct hmx_queue_desc rd = { NULL, NULL };
|
||||
if (ip == iw) {
|
||||
return rd;
|
||||
}
|
||||
|
||||
// Wait for desc to complete
|
||||
struct hmx_queue_desc * d = &q->desc[ip];
|
||||
while (!atomic_load(&d->done)) {
|
||||
FARF(HIGH, "hmx-queue-pop: waiting for HMX queue : %u\n", ip);
|
||||
hex_pause();
|
||||
}
|
||||
|
||||
rd = *d;
|
||||
q->idx_pop = (ip + 1) & q->idx_mask;
|
||||
|
||||
FARF(HIGH, "hmx-queue-pop: ip %u func %p data %p\n", ip, rd.func, rd.data);
|
||||
return rd;
|
||||
}
|
||||
|
||||
// Pops descriptors until none are outstanding, waiting for each to complete.
//
// The loop is driven by hmx_queue_empty() rather than by the popped descriptor's
// `func` being NULL: a NOOP signal is stored as a NULL func, so testing the
// return value would end the flush early with work still queued behind it.
static inline void hmx_queue_flush(struct hmx_queue * q) {
    while (!hmx_queue_empty(q)) {
        (void) hmx_queue_pop(q);
    }
}
|
||||
|
||||
static inline void hmx_queue_suspend(struct hmx_queue *q) {
|
||||
hmx_queue_signal(q, HMX_QUEUE_SUSPEND);
|
||||
hmx_queue_flush(q);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
||||
#endif /* HMX_QUEUE_H */
|
||||
|
|
@ -918,6 +918,10 @@ ggml_backend_reg_t ggml_backend_metal_reg(void) {
|
|||
static std::vector<ggml_backend_device_ptr> devs;
|
||||
|
||||
if (!initialized) {
|
||||
// workaround macOS limitation (kIOGPUCommandBufferCallbackErrorImpactingInteractivity) until proper fix becomes possible
|
||||
// ref: https://github.com/ggml-org/llama.cpp/issues/20141#issuecomment-4272947703
|
||||
setenv("AGX_RELAX_CDM_CTXSTORE_TIMEOUT", "1", true);
|
||||
|
||||
static ggml_backend_metal_reg_ptr reg_ctx(ggml_backend_metal_reg_init());
|
||||
|
||||
for (int i = 0; i < g_devices; ++i) {
|
||||
|
|
|
|||
|
|
@ -1,176 +0,0 @@
|
|||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_qcom_reqd_sub_group_size
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
#define QK_K 256
|
||||
#define K_SCALE_SIZE 12
|
||||
|
||||
// Extracts the scale (*d) and min (*m) values for sub-block j from the packed
// scale bytes at q. The bit masks are supplied by the caller.
//   j <  4: both values are taken directly from q[j] and q[j+4], masked with mask_d6.
//   j >= 4: the low bits come from q[j+4] (low and high nibble, masked with mask_d4)
//           and are OR-ed with the mask_hi2 bits of q[j-4] / q[j] shifted right by 2.
inline void get_scale_min_k4(
    int j,
    global const uchar * q,
    uchar * d,
    uchar * m,
    uchar mask_d6,
    uchar mask_d4,
    uchar mask_hi2
) {
    if (j >= 4) {
        const uchar packed = q[j+4];
        *d = (packed & mask_d4) | ((q[j-4] & mask_hi2) >> 2);
        *m = ((packed >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2);
        return;
    }

    *d = q[j] & mask_d6;
    *m = q[j+4] & mask_d6;
}
|
||||
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_128
|
||||
#endif
|
||||
kernel void kernel_gemm_noshuffle_q5_k_f32(
|
||||
global const ushort * src0_q,
|
||||
global const uchar * src0_qh,
|
||||
global const uchar * src0_s,
|
||||
global const half * src0_d,
|
||||
global const half * src0_dm,
|
||||
read_only image1d_buffer_t src1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
int n_no_padding,
|
||||
uchar mask_d6,
|
||||
uchar mask_d4,
|
||||
uchar mask_hi2
|
||||
) {
|
||||
dst = (global float *)((global char *)dst + offsetd);
|
||||
int n_4 = n >> 2;
|
||||
int gy = get_global_id(0);
|
||||
int gx = get_global_id(1);
|
||||
int gx_2 = gx << 2;
|
||||
|
||||
half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;
|
||||
half8 B;
|
||||
half4 dequantized_weights;
|
||||
|
||||
int num_blocks_K = k / QK_K;
|
||||
|
||||
global const ushort * weight_ptr = src0_q + gx_2;
|
||||
global const uchar * qh_ptr = src0_qh + gx_2;
|
||||
global const half * d_ptr = src0_d + gx_2;
|
||||
global const half * dm_ptr = src0_dm + gx_2;
|
||||
|
||||
for (int i = 0; i < k; i += 32) {
|
||||
int sb_idx = i / QK_K;
|
||||
int sub_idx = (i / 32) % 8;
|
||||
|
||||
half4 d = vload4(0, d_ptr + sb_idx * m);
|
||||
half4 dm = vload4(0, dm_ptr + sb_idx * m);
|
||||
|
||||
global const uchar * sc0 = src0_s + (gx_2+0) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE;
|
||||
global const uchar * sc1 = src0_s + (gx_2+1) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE;
|
||||
global const uchar * sc2 = src0_s + (gx_2+2) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE;
|
||||
global const uchar * sc3 = src0_s + (gx_2+3) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE;
|
||||
|
||||
uchar sv0, mn0, sv1, mn1, sv2, mn2, sv3, mn3;
|
||||
get_scale_min_k4(sub_idx, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2);
|
||||
get_scale_min_k4(sub_idx, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2);
|
||||
get_scale_min_k4(sub_idx, sc2, &sv2, &mn2, mask_d6, mask_d4, mask_hi2);
|
||||
get_scale_min_k4(sub_idx, sc3, &sv3, &mn3, mask_d6, mask_d4, mask_hi2);
|
||||
|
||||
half4 scale = convert_half4(convert_float4(d) * convert_float4((uchar4)(sv0, sv1, sv2, sv3)));
|
||||
half4 mval = convert_half4(convert_float4(dm) * convert_float4((uchar4)(mn0, mn1, mn2, mn3)));
|
||||
|
||||
for (int l = 0; l < 32; l += 4) {
|
||||
int ki = i + l;
|
||||
ushort4 bits4 = vload4(0, weight_ptr + (ki/4) * m);
|
||||
uchar4 qh_bits = vload4(0, qh_ptr + (ki/8) * m);
|
||||
int qh_shift = ki % 8;
|
||||
|
||||
// j=0
|
||||
B.s0123 = read_imageh(src1, gy*2 + (ki+0) * n_4);
|
||||
B.s4567 = read_imageh(src1, gy*2+1 + (ki+0) * n_4);
|
||||
dequantized_weights.s0 = ((bits4.s0 & 0x000F) | (((qh_bits.s0 >> (qh_shift+0)) & 1) << 4)) * scale.s0 - mval.s0;
|
||||
dequantized_weights.s1 = ((bits4.s1 & 0x000F) | (((qh_bits.s1 >> (qh_shift+0)) & 1) << 4)) * scale.s1 - mval.s1;
|
||||
dequantized_weights.s2 = ((bits4.s2 & 0x000F) | (((qh_bits.s2 >> (qh_shift+0)) & 1) << 4)) * scale.s2 - mval.s2;
|
||||
dequantized_weights.s3 = ((bits4.s3 & 0x000F) | (((qh_bits.s3 >> (qh_shift+0)) & 1) << 4)) * scale.s3 - mval.s3;
|
||||
c0 += B * dequantized_weights.s0;
|
||||
c1 += B * dequantized_weights.s1;
|
||||
c2 += B * dequantized_weights.s2;
|
||||
c3 += B * dequantized_weights.s3;
|
||||
|
||||
// j=1
|
||||
B.s0123 = read_imageh(src1, gy*2 + (ki+1) * n_4);
|
||||
B.s4567 = read_imageh(src1, gy*2+1 + (ki+1) * n_4);
|
||||
dequantized_weights.s0 = (((bits4.s0 & 0x00F0) >> 4) | (((qh_bits.s0 >> (qh_shift+1)) & 1) << 4)) * scale.s0 - mval.s0;
|
||||
dequantized_weights.s1 = (((bits4.s1 & 0x00F0) >> 4) | (((qh_bits.s1 >> (qh_shift+1)) & 1) << 4)) * scale.s1 - mval.s1;
|
||||
dequantized_weights.s2 = (((bits4.s2 & 0x00F0) >> 4) | (((qh_bits.s2 >> (qh_shift+1)) & 1) << 4)) * scale.s2 - mval.s2;
|
||||
dequantized_weights.s3 = (((bits4.s3 & 0x00F0) >> 4) | (((qh_bits.s3 >> (qh_shift+1)) & 1) << 4)) * scale.s3 - mval.s3;
|
||||
c0 += B * dequantized_weights.s0;
|
||||
c1 += B * dequantized_weights.s1;
|
||||
c2 += B * dequantized_weights.s2;
|
||||
c3 += B * dequantized_weights.s3;
|
||||
|
||||
// j=2
|
||||
B.s0123 = read_imageh(src1, gy*2 + (ki+2) * n_4);
|
||||
B.s4567 = read_imageh(src1, gy*2+1 + (ki+2) * n_4);
|
||||
dequantized_weights.s0 = (((bits4.s0 & 0x0F00) >> 8) | (((qh_bits.s0 >> (qh_shift+2)) & 1) << 4)) * scale.s0 - mval.s0;
|
||||
dequantized_weights.s1 = (((bits4.s1 & 0x0F00) >> 8) | (((qh_bits.s1 >> (qh_shift+2)) & 1) << 4)) * scale.s1 - mval.s1;
|
||||
dequantized_weights.s2 = (((bits4.s2 & 0x0F00) >> 8) | (((qh_bits.s2 >> (qh_shift+2)) & 1) << 4)) * scale.s2 - mval.s2;
|
||||
dequantized_weights.s3 = (((bits4.s3 & 0x0F00) >> 8) | (((qh_bits.s3 >> (qh_shift+2)) & 1) << 4)) * scale.s3 - mval.s3;
|
||||
c0 += B * dequantized_weights.s0;
|
||||
c1 += B * dequantized_weights.s1;
|
||||
c2 += B * dequantized_weights.s2;
|
||||
c3 += B * dequantized_weights.s3;
|
||||
|
||||
// j=3
|
||||
B.s0123 = read_imageh(src1, gy*2 + (ki+3) * n_4);
|
||||
B.s4567 = read_imageh(src1, gy*2+1 + (ki+3) * n_4);
|
||||
dequantized_weights.s0 = (((bits4.s0 & 0xF000) >> 12) | (((qh_bits.s0 >> (qh_shift+3)) & 1) << 4)) * scale.s0 - mval.s0;
|
||||
dequantized_weights.s1 = (((bits4.s1 & 0xF000) >> 12) | (((qh_bits.s1 >> (qh_shift+3)) & 1) << 4)) * scale.s1 - mval.s1;
|
||||
dequantized_weights.s2 = (((bits4.s2 & 0xF000) >> 12) | (((qh_bits.s2 >> (qh_shift+3)) & 1) << 4)) * scale.s2 - mval.s2;
|
||||
dequantized_weights.s3 = (((bits4.s3 & 0xF000) >> 12) | (((qh_bits.s3 >> (qh_shift+3)) & 1) << 4)) * scale.s3 - mval.s3;
|
||||
c0 += B * dequantized_weights.s0;
|
||||
c1 += B * dequantized_weights.s1;
|
||||
c2 += B * dequantized_weights.s2;
|
||||
c3 += B * dequantized_weights.s3;
|
||||
}
|
||||
}
|
||||
|
||||
int idx = (gy<<3)*m + (gx<<2);
|
||||
|
||||
if (idx+3 < m*n_no_padding) {
|
||||
vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if (idx+3 < m*n_no_padding) {
|
||||
vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if (idx+3 < m*n_no_padding) {
|
||||
vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if (idx+3 < m*n_no_padding) {
|
||||
vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if (idx+3 < m*n_no_padding) {
|
||||
vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if (idx+3 < m*n_no_padding) {
|
||||
vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if (idx+3 < m*n_no_padding) {
|
||||
vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if (idx+3 < m*n_no_padding) {
|
||||
vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,326 +0,0 @@
|
|||
// Extensions needed by this file: half-precision types (half, half2, read_imageh)
// and the sub-group built-ins (get_sub_group_local_id, sub_group_broadcast).
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
|
||||
// When the Qualcomm sub-group-size extension is available, define ADRENO_GPU and
// an attribute macro that the kernel below places in front of its definition.
// NOTE(review): the macro name implies that "half" selects a 64-wide sub-group;
// that mapping is not visible in this file -- confirm against the Adreno docs.
#ifdef cl_qcom_reqd_sub_group_size
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#endif
|
||||
|
||||
// QK_K: the kernel derives the number of 12-byte scale records per row as K / QK_K.
#define QK_K 256
|
||||
// NSUBGROUPS: step of the kernel's K loop and factor in its block strides; the
// final reduction is hard-wired for exactly 4 values of get_local_id(1).
#define NSUBGROUPS 4
|
||||
// SUBGROUP_SIZE: number of float2 slots per reduced group in local memory.
#define SUBGROUP_SIZE 64
|
||||
|
||||
inline void get_scale_min_k4(
|
||||
int j,
|
||||
global const uchar * q,
|
||||
uchar * d,
|
||||
uchar * m,
|
||||
uchar mask_d6,
|
||||
uchar mask_d4,
|
||||
uchar mask_hi2
|
||||
) {
|
||||
if (j < 4) {
|
||||
*d = q[j] & mask_d6;
|
||||
*m = q[j+4] & mask_d6;
|
||||
} else {
|
||||
*d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2);
|
||||
*m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2);
|
||||
}
|
||||
}
|
||||
|
||||
// First half of the scalar-broadcast accumulate step.
//
// Declares `float shared_y` in the expanding scope, then performs 16 steps:
// y.s0..y.s7 broadcast from sub-group lane 0, followed by y.s0..y.s7 from lane 1.
// Each step adds (weight * scale - minv) * shared_y to total_sums.s0 and
// total_sums.s1, where the weight is a 4-bit field of a bits4 component OR'ed
// with one bit of a bits1 component moved into bit 4.
//
// Components read: bits4.s0..s7 (even components feed total_sums.s0, odd ones
// total_sums.s1) and bits1.s0..s3 (s0/s2 feed .s0, s1/s3 feed .s1).
// The matching _1_lo macro reuses `shared_y`, so it must be expanded after this
// one in the same scope. The expansion is a bare statement list (not wrapped in
// do { } while (0)), so it can only be used once per scope.
#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, bits1, scale, minv, y) \
|
||||
float shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s0, 0); \
|
||||
total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s0 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s1 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s1, 0); \
|
||||
total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s2, 0); \
|
||||
total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s3, 0); \
|
||||
total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s4, 0); \
|
||||
total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s5, 0); \
|
||||
total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s6, 0); \
|
||||
total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s7, 0); \
|
||||
total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s0, 1); \
|
||||
total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s2 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s3 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s1, 1); \
|
||||
total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s2, 1); \
|
||||
total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s3, 1); \
|
||||
total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s4, 1); \
|
||||
total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s5, 1); \
|
||||
total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s6, 1); \
|
||||
total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s7, 1); \
|
||||
total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
|
||||
|
||||
// Second half of the scalar-broadcast accumulate step.
//
// Same accumulation pattern as the _1_hi macro, but y.s0..y.s7 are broadcast
// from sub-group lanes 2 and 3, and the extra bit comes from bits1.s4..s7
// (s4/s6 feed total_sums.s0, s5/s7 feed total_sums.s1). bits4.s0..s7 are read
// again, so the caller is expected to pass freshly loaded 4-bit data.
// This macro does not declare `shared_y`; it relies on the declaration made by
// the _1_hi macro expanded earlier in the same scope.
#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, bits1, scale, minv, y) \
|
||||
shared_y = sub_group_broadcast(y.s0, 2); \
|
||||
total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s4 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s5 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s1, 2); \
|
||||
total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s2, 2); \
|
||||
total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s3, 2); \
|
||||
total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s4, 2); \
|
||||
total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s5, 2); \
|
||||
total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s6, 2); \
|
||||
total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s7, 2); \
|
||||
total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s0, 3); \
|
||||
total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s6 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s7 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s1, 3); \
|
||||
total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s2, 3); \
|
||||
total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s3, 3); \
|
||||
total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s4, 3); \
|
||||
total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s5, 3); \
|
||||
total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s6, 3); \
|
||||
total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
shared_y = sub_group_broadcast(y.s7, 3); \
|
||||
total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \
|
||||
total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \
|
||||
|
||||
|
||||
// First half of the vector-broadcast accumulate step.
//
// Declares `float8 shared_y` in the expanding scope and broadcasts the whole
// `y` vector twice: from sub-group lane 0, then from lane 1. After each
// broadcast it adds eight (weight * scale - minv) * shared_y.sN terms to
// total_sums.s0, followed by eight terms to total_sums.s1. Each weight is a
// 4-bit field of a bits4 component OR'ed with one bit of a bits1 component
// moved into bit 4.
//
// Components read: bits4.s0..s7 and bits1.s0..s3. Note that the per-accumulator
// order of additions is the same as in the _1_hi variant; only the broadcast
// granularity differs. The _8_lo macro reuses `shared_y` and must follow this
// one in the same scope.
#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, bits1, scale, minv, y) \
|
||||
float8 shared_y; \
|
||||
shared_y = sub_group_broadcast(y, 0); \
|
||||
total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s0 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \
|
||||
total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \
|
||||
total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \
|
||||
total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \
|
||||
total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \
|
||||
total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \
|
||||
total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \
|
||||
total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \
|
||||
total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s1 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \
|
||||
total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \
|
||||
total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \
|
||||
total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \
|
||||
total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \
|
||||
total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \
|
||||
total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \
|
||||
total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \
|
||||
shared_y = sub_group_broadcast(y, 1); \
|
||||
total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s2 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \
|
||||
total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \
|
||||
total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \
|
||||
total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \
|
||||
total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \
|
||||
total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \
|
||||
total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \
|
||||
total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \
|
||||
total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s3 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \
|
||||
total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \
|
||||
total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \
|
||||
total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \
|
||||
total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \
|
||||
total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \
|
||||
total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \
|
||||
total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \
|
||||
|
||||
|
||||
// Second half of the vector-broadcast accumulate step.
//
// Same pattern as the _8_hi macro, but the `y` vector is broadcast from
// sub-group lanes 2 and 3, and the extra bit comes from bits1.s4..s7.
// bits4.s0..s7 are read again, so the caller is expected to pass freshly
// loaded 4-bit data. This macro does not declare `shared_y`; it relies on the
// `float8 shared_y` declared by the _8_hi macro expanded earlier in the same
// scope.
#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, bits1, scale, minv, y) \
|
||||
shared_y = sub_group_broadcast(y, 2); \
|
||||
total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s4 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \
|
||||
total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \
|
||||
total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \
|
||||
total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \
|
||||
total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \
|
||||
total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \
|
||||
total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \
|
||||
total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \
|
||||
total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s5 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \
|
||||
total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \
|
||||
total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \
|
||||
total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \
|
||||
total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \
|
||||
total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \
|
||||
total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \
|
||||
total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \
|
||||
shared_y = sub_group_broadcast(y, 3); \
|
||||
total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s6 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \
|
||||
total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \
|
||||
total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \
|
||||
total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \
|
||||
total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \
|
||||
total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \
|
||||
total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \
|
||||
total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \
|
||||
total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s7 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \
|
||||
total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \
|
||||
total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \
|
||||
total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \
|
||||
total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \
|
||||
total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \
|
||||
total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \
|
||||
total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \
|
||||
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_gemv_noshuffle_q5_k_f32(
|
||||
read_only image1d_buffer_t src0_q,
|
||||
read_only image1d_buffer_t src0_qh,
|
||||
global half2 * src0_d,
|
||||
global half2 * src0_m,
|
||||
global uchar * src0_s,
|
||||
read_only image1d_buffer_t src1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
uchar mask_d6,
|
||||
uchar mask_d4,
|
||||
uchar mask_hi2)
|
||||
{
|
||||
uint groupId = get_local_id(1);
|
||||
uint gid = get_global_id(0);
|
||||
ushort slid = get_sub_group_local_id();
|
||||
|
||||
uint K = ne00;
|
||||
uint M = ne01;
|
||||
|
||||
uint LINE_STRIDE_A = M / 2;
|
||||
uint BLOCK_STRIDE_A = NSUBGROUPS * M;
|
||||
|
||||
uint LINE_STRIDE_A_QH = M / 2;
|
||||
uint BLOCK_STRIDE_A_QH = NSUBGROUPS * M / 2;
|
||||
uint scales_per_row = (K / QK_K) * 12;
|
||||
|
||||
private uint4 regA;
|
||||
private ushort4 regH;
|
||||
private half2 regS;
|
||||
private half2 regM;
|
||||
private float8 regB;
|
||||
|
||||
private float2 totalSum = (float2)(0.0f);
|
||||
|
||||
for (uint k = groupId; k < (K / 32); k += NSUBGROUPS) {
|
||||
uint sb = k / 8;
|
||||
uint j = k % 8;
|
||||
|
||||
half2 d = src0_d[gid + sb * LINE_STRIDE_A];
|
||||
half2 dm = src0_m[gid + sb * LINE_STRIDE_A];
|
||||
|
||||
global const uchar * sc0 = src0_s + 2 * gid * scales_per_row + sb * 12;
|
||||
global const uchar * sc1 = src0_s + (2 * gid + 1) * scales_per_row + sb * 12;
|
||||
|
||||
uchar sv0, mn0, sv1, mn1;
|
||||
get_scale_min_k4(j, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2);
|
||||
get_scale_min_k4(j, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2);
|
||||
|
||||
regS = convert_half2(convert_float2(d) * convert_float2((uchar2)(sv0, sv1)));
|
||||
regM = convert_half2(convert_float2(dm) * convert_float2((uchar2)(mn0, mn1)));
|
||||
|
||||
if (slid < 4) {
|
||||
regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));
|
||||
regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));
|
||||
}
|
||||
|
||||
regH.s0 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 0)).x);
|
||||
regH.s1 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 1)).x);
|
||||
regH.s2 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 2)).x);
|
||||
regH.s3 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 3)).x);
|
||||
|
||||
regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
|
||||
regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
|
||||
regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
|
||||
regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
|
||||
#ifdef VECTOR_SUB_GROUP_BROADCAST
|
||||
dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB);
|
||||
#else
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB);
|
||||
#endif // VECTOR_SUB_GROUP_BROADCAST
|
||||
|
||||
regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
|
||||
regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
|
||||
regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
|
||||
regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
|
||||
#ifdef VECTOR_SUB_GROUP_BROADCAST
|
||||
dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB);
|
||||
#else
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB);
|
||||
#endif // VECTOR_SUB_GROUP_BROADCAST
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #wave=4
|
||||
local float2 reduceLM[SUBGROUP_SIZE * 3];
|
||||
if (groupId == 1) {
|
||||
reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum;
|
||||
}
|
||||
if (groupId == 2) {
|
||||
reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum;
|
||||
}
|
||||
if (groupId == 3) {
|
||||
reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if (groupId == 0) {
|
||||
totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid];
|
||||
}
|
||||
if (groupId == 0) {
|
||||
totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid];
|
||||
}
|
||||
if (groupId == 0) {
|
||||
totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid];
|
||||
}
|
||||
|
||||
// 2 outputs per fiber in wave 0
|
||||
if (groupId == 0) {
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
vstore2(totalSum, 0, &(dst[gid * 2]));
|
||||
}
|
||||
}
|
||||
|
|
@ -1,101 +0,0 @@
|
|||
diagnostic(off, subgroup_uniformity);
|
||||
enable f16;
|
||||
|
||||
#define KV_TILE 32
|
||||
#define WG_SIZE 32
|
||||
|
||||
// Uniform parameters of the mask block-classification shader.
// All offsets and strides are used directly as indices into the `mask` array.
struct Params {
|
||||
// Index of the first mask element; added to every mask access.
offset_mask: u32,
|
||||
// Number of query rows; rows at or beyond this are not read.
seq_len_q: u32,
|
||||
// Number of KV columns; also used as the row stride of the mask.
seq_len_kv: u32,
|
||||
// Stride between batches of the mask; when 0, the batch index is ignored.
stride_mask3: u32,
|
||||
// Number of KV blocks and Q blocks per batch.
|
||||
// nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = seq_len_q.
|
||||
nblk0: u32,
|
||||
nblk1: u32,
|
||||
};
|
||||
|
||||
// Input attention mask (f16), indexed as base + row * seq_len_kv + column.
@group(0) @binding(0) var<storage, read> mask: array<f16>;
|
||||
// Output: one u32 state per (batch, q block, kv block).
@group(0) @binding(1) var<storage, read_write> blk: array<u32>;
|
||||
@group(0) @binding(2) var<uniform> params: Params;
|
||||
|
||||
// Thresholds for classification; +/-65504 is the largest finite f16 magnitude.
// MASK_MAX and -MASK_MAX also seed the running min and max respectively.
const MASK_MIN: f32 = -65504.0;
|
||||
const MASK_MAX: f32 = 65504.0;
|
||||
// Per-invocation partial results, combined by invocation 0 after the barrier.
var<workgroup> wg_min: array<f32, WG_SIZE>;
|
||||
var<workgroup> wg_max: array<f32, WG_SIZE>;
|
||||
var<workgroup> wg_any: array<u32, WG_SIZE>;
|
||||
|
||||
// Classifies one KV_TILE-wide slice of one mask row and writes a single state
// word for it:
//   0 - no in-range element was read, or every element is <= MASK_MIN
//   2 - every element read is exactly 0
//   1 - anything else
//
// Dispatch layout: workgroup x selects the KV block, workgroup y encodes
// batch_idx * nblk1 + q_blk. Each invocation scans its share of the tile,
// publishes min/max/any to workgroup memory, and invocation 0 combines them.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        @builtin(local_invocation_id) local_id: vec3<u32>) {
    let kv_blk = wg_id.x;
    if (kv_blk >= params.nblk0) {
        return;
    }

    let lane = local_id.x;
    let q_blk = wg_id.y % params.nblk1;
    let batch_idx = wg_id.y / params.nblk1;
    let first_col = kv_blk * KV_TILE;

    // A zero batch stride means the same mask is shared by every batch.
    var batch_base = params.offset_mask;
    if (params.stride_mask3 > 0u) {
        batch_base += batch_idx * params.stride_mask3;
    }

    var lo = MASK_MAX;
    var hi = -MASK_MAX;
    var seen = 0u;

    if (q_blk < params.seq_len_q) {
        let row_base = batch_base + q_blk * params.seq_len_kv;
        for (var rel = lane; rel < KV_TILE; rel += WG_SIZE) {
            let col = first_col + rel;
            // Columns past the end of the row are ignored.
            if (col < params.seq_len_kv) {
                let value = f32(mask[row_base + col]);
                lo = min(lo, value);
                hi = max(hi, value);
                seen = 1u;
            }
        }
    }

    wg_min[lane] = lo;
    wg_max[lane] = hi;
    wg_any[lane] = seen;
    workgroupBarrier();

    // Invocation 0 folds the per-invocation results and emits the block state.
    if (lane == 0u) {
        var tile_min = wg_min[0];
        var tile_max = wg_max[0];
        var tile_any = wg_any[0];
        for (var i = 1u; i < WG_SIZE; i += 1u) {
            tile_min = min(tile_min, wg_min[i]);
            tile_max = max(tile_max, wg_max[i]);
            tile_any = max(tile_any, wg_any[i]);
        }

        var state = 1u;
        if (tile_any == 0u || tile_max <= MASK_MIN) {
            state = 0u;
        } else if (tile_min == 0.0 && tile_max == 0.0) {
            state = 2u;
        }

        let out_idx = (batch_idx * params.nblk1 + q_blk) * params.nblk0 + kv_blk;
        blk[out_idx] = state;
    }
}
|
||||
|
|
@ -1,78 +0,0 @@
|
|||
diagnostic(off, subgroup_uniformity);
|
||||
enable f16;
|
||||
enable subgroups;
|
||||
|
||||
// Default values
|
||||
#define HEAD_DIM_V 64
|
||||
#define WG_SIZE 128
|
||||
|
||||
// Uniform parameters of the shader that merges per-workgroup partial results.
struct Params {
|
||||
// Number of output rows; workgroups with an id at or beyond this exit early.
nrows: u32,
|
||||
// Query rows per head; with n_heads it splits a row id into batch/head/row.
seq_len_q: u32,
|
||||
n_heads: u32,
|
||||
// Added to the destination row index before it is shifted right by 2 to
// index `dst` -- presumably counted in f32 elements; confirm with the host code.
offset_dst: u32,
|
||||
// Number of partial results per row; the shader returns without writing
// when this exceeds the subgroup size.
nwg: u32,
|
||||
// Index in `tmp` where partial rows start (HEAD_DIM_V * nwg values per row).
tmp_data_base: u32,
|
||||
// Index in `tmp` where per-partial statistics start (2 * nwg values per row,
// read as pairs at offsets +0 and +1).
tmp_stats_base: u32,
|
||||
};
|
||||
|
||||
// Scratch buffer holding partial rows and their statistics (see Params bases).
@group(0) @binding(0) var<storage, read_write> tmp: array<f32>;
|
||||
// Destination, addressed in vec4 units.
@group(0) @binding(1) var<storage, read_write> dst: array<vec4<f32>>;
|
||||
@group(0) @binding(2) var<uniform> params: Params;
|
||||
|
||||
// Stand-in maximum for inactive invocations so they cannot win subgroupMax.
const FLOAT_MIN: f32 = -1.0e9;
|
||||
|
||||
// Merge pass for split flash attention: one workgroup per output row.
// Lane `sg_inv_id` of every subgroup owns partial result `sg_inv_id` of the row
// (so at most `subgroup_size` partials are supported). Partials are rescaled to
// a common maximum, summed across the subgroup, normalized and written to dst.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        @builtin(subgroup_id) subgroup_id: u32,
        @builtin(num_subgroups) num_subgroups: u32,
        @builtin(subgroup_size) subgroup_size: u32,
        @builtin(subgroup_invocation_id) sg_inv_id: u32) {
    let row_id = wg_id.x;
    // Nothing to do for out-of-range rows, or when a single subgroup cannot
    // hold one lane per partial result.
    if (row_id >= params.nrows || params.nwg > subgroup_size) {
        return;
    }

    // Decompose the flat row id into (batch, head, query row).
    let rows_in_batch = params.n_heads * params.seq_len_q;
    let batch = row_id / rows_in_batch;
    let row_in_batch = row_id % rows_in_batch;
    let head = row_in_batch / params.seq_len_q;
    let query_row = row_in_batch % params.seq_len_q;

    // Destination layout: [batch][query row][head][HEAD_DIM_V].
    let stride_row = HEAD_DIM_V * params.n_heads;
    let stride_batch = stride_row * params.seq_len_q;
    let out_base = params.offset_dst + batch * stride_batch + query_row * stride_row + head * HEAD_DIM_V;

    let lane = sg_inv_id;
    let lane_has_partial = lane < params.nwg;

    // Per-partial statistics are stored as (exp sum, row max) pairs.
    let stat_idx = params.tmp_stats_base + row_id * (2u * params.nwg) + 2u * lane;
    let part_sum = select(0.0, tmp[stat_idx + 0u], lane_has_partial);
    let part_max = select(FLOAT_MIN, tmp[stat_idx + 1u], lane_has_partial);

    // Rescale every partial to the global maximum before combining.
    let global_max = subgroupMax(part_max);
    let rescale = select(0.0, exp(part_max - global_max), lane_has_partial);
    let total_sum = subgroupAdd(part_sum * rescale);
    let norm = select(0.0, 1.0 / total_sum, total_sum != 0.0);

    let data_base = params.tmp_data_base + row_id * (HEAD_DIM_V * params.nwg);
    let elem_step = num_subgroups * 4u;
    // Each subgroup handles a strided set of 4-element chunks of the row.
    for (var chunk = subgroup_id * 4u; chunk < HEAD_DIM_V; chunk += elem_step) {
        var contrib = vec4<f32>(0.0, 0.0, 0.0, 0.0);
        if (lane_has_partial) {
            let p = data_base + lane * HEAD_DIM_V + chunk;
            contrib = vec4<f32>(tmp[p + 0u], tmp[p + 1u], tmp[p + 2u], tmp[p + 3u]) * rescale;
        }

        let acc_x = subgroupAdd(contrib.x);
        let acc_y = subgroupAdd(contrib.y);
        let acc_z = subgroupAdd(contrib.z);
        let acc_w = subgroupAdd(contrib.w);

        // Only one lane per subgroup writes the combined chunk.
        if (lane == 0u) {
            dst[(out_base + chunk) >> 2u] = vec4<f32>(acc_x, acc_y, acc_z, acc_w) * norm;
        }
    }
}
|
||||
|
|
@ -1,729 +0,0 @@
|
|||
diagnostic(off, chromium.subgroup_matrix_uniformity);
|
||||
diagnostic(off, subgroup_uniformity);
|
||||
enable f16;
|
||||
enable subgroups;
|
||||
enable chromium_experimental_subgroup_matrix;
|
||||
|
||||
#ifdef KV_F32
|
||||
#define KV_TYPE f32
|
||||
#else
|
||||
#define KV_TYPE f16
|
||||
#endif
|
||||
|
||||
#define HEAD_DIM_QK 64
|
||||
#define HEAD_DIM_V 64
|
||||
|
||||
|
||||
#define SG_MAT_M 8
|
||||
#define SG_MAT_N 8
|
||||
#define SG_MAT_K 8
|
||||
|
||||
#define Q_TILE SG_MAT_M
|
||||
#define KV_TILE 16
|
||||
#define WG_SIZE 64
|
||||
#ifndef VEC_NE
|
||||
#define VEC_NE 4u
|
||||
#endif
|
||||
|
||||
#define KV_BLOCKS (KV_TILE / SG_MAT_N)
|
||||
|
||||
#define BLOCK_SIZE 32
|
||||
#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
|
||||
#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
|
||||
#if defined(KV_Q4_0)
|
||||
#define NQ 16
|
||||
#define F16_PER_BLOCK 9
|
||||
#define WEIGHTS_PER_F16 4
|
||||
#elif defined(KV_Q8_0)
|
||||
#define NQ 8
|
||||
#define F16_PER_BLOCK 17
|
||||
#define WEIGHTS_PER_F16 2
|
||||
#endif
|
||||
#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
|
||||
|
||||
// Returns byte number `index` (0 = least significant) of `value`, zero-extended.
fn get_byte(value: u32, index: u32) -> u32 {
    let bit_shift = index * 8;
    let shifted = value >> bit_shift;
    return shifted & 0xFF;
}
|
||||
|
||||
// Returns byte number `index` (0 = least significant) of `value`, sign-extended
// to i32 as an 8-bit two's-complement number.
fn get_byte_i32(value: u32, index: u32) -> i32 {
    let raw_byte = (value >> (index * 8)) & 0xFF;
    // Move the byte to the top of the word so the arithmetic shift on the
    // signed reinterpretation replicates its sign bit.
    let in_top_byte = raw_byte << 24;
    return bitcast<i32>(in_top_byte) >> 24;
}
|
||||
|
||||
// Uniform parameters for the split-KV flash-attention kernel.
struct Params {
|
||||
// Base offsets added when indexing the Q, K, V, mask, sinks and dst buffers.
offset_q: u32,
|
||||
offset_k: u32,
|
||||
offset_v: u32,
|
||||
offset_mask: u32,
|
||||
offset_sinks: u32,
|
||||
offset_dst: u32,
|
||||
|
||||
// shapes of Q/K/V
|
||||
n_heads: u32,
|
||||
seq_len_q: u32,
|
||||
seq_len_kv: u32,
|
||||
|
||||
// strides (in elements)
|
||||
// Suffix 1 = per row, 2 = per head, 3 = per batch (as used by the kernel's offset math).
stride_q1: u32,
|
||||
stride_q2: u32,
|
||||
stride_q3: u32,
|
||||
stride_k1: u32,
|
||||
stride_k2: u32,
|
||||
stride_k3: u32,
|
||||
stride_v1: u32,
|
||||
stride_v2: u32,
|
||||
stride_v3: u32,
|
||||
// Per-batch stride of the mask; the kernel uses batch 0 of the block table when this is 0.
stride_mask3: u32,
|
||||
|
||||
// repeat factors for K/V, e.g., MHA vs. MQA vs. GQA
|
||||
// K/V head index is computed as head_idx / q_per_kv.
q_per_kv: u32,
|
||||
|
||||
// softmax params
|
||||
// Multiplier applied to every Q*K score.
scale: f32,
|
||||
// A value > 0 enables the per-head slope that multiplies the mask.
max_bias: f32,
|
||||
// Multiplier on tanh(score) when LOGIT_SOFTCAP is defined.
logit_softcap: f32,
|
||||
// n_head_log2, m0 and m1 feed the per-head slope: m0^(head+1) below n_head_log2,
// m1^(2*(head-n_head_log2)+1) otherwise.
n_head_log2: f32,
|
||||
m0: f32,
|
||||
m1: f32,
|
||||
|
||||
#ifdef BLK
|
||||
// Location and dimensions of the per-tile mask state table read from `blk`
// (nblk0 entries along KV tiles, nblk1 along Q tiles).
blk_base: u32,
|
||||
blk_nblk0: u32,
|
||||
blk_nblk1: u32,
|
||||
#endif
|
||||
|
||||
// Offsets in `tmp` of the partial outputs and the (exp sum, row max) pairs,
// written only when nwg > 1.
tmp_data_base: u32,
|
||||
tmp_stats_base: u32,
|
||||
// Number of workgroups that share one Q tile, each taking every nwg-th KV tile.
nwg: u32,
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
|
||||
#if defined(KV_Q4_0) || defined(KV_Q8_0)
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
|
||||
#else
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
|
||||
#endif
|
||||
#if defined(KV_Q4_0) || defined(KV_Q8_0)
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
|
||||
#else
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
|
||||
#endif
|
||||
#if defined(MASK) && defined(SINKS)
|
||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
|
||||
#ifdef BLK
|
||||
#define BLK_BINDING 5
|
||||
#define TMP_BINDING 6
|
||||
#define DST_BINDING 7
|
||||
#define PARAMS_BINDING 8
|
||||
#else
|
||||
#define TMP_BINDING 5
|
||||
#define DST_BINDING 6
|
||||
#define PARAMS_BINDING 7
|
||||
#endif
|
||||
#elif defined(MASK)
|
||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||
#ifdef BLK
|
||||
#define BLK_BINDING 4
|
||||
#define TMP_BINDING 5
|
||||
#define DST_BINDING 6
|
||||
#define PARAMS_BINDING 7
|
||||
#else
|
||||
#define TMP_BINDING 4
|
||||
#define DST_BINDING 5
|
||||
#define PARAMS_BINDING 6
|
||||
#endif
|
||||
#elif defined(SINKS)
|
||||
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
|
||||
#define TMP_BINDING 4
|
||||
#define DST_BINDING 5
|
||||
#define PARAMS_BINDING 6
|
||||
#else
|
||||
#define TMP_BINDING 3
|
||||
#define DST_BINDING 4
|
||||
#define PARAMS_BINDING 5
|
||||
#endif
|
||||
|
||||
#ifdef BLK
|
||||
@group(0) @binding(BLK_BINDING) var<storage, read_write> blk: array<u32>;
|
||||
#endif
|
||||
@group(0) @binding(TMP_BINDING) var<storage, read_write> tmp: array<f32>;
|
||||
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
|
||||
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
|
||||
|
||||
// Large-magnitude negative sentinel used in place of -infinity (initial row max, out-of-range scores).
|
||||
const FLOAT_MIN: f32 = -1.0e9;
|
||||
|
||||
var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
|
||||
|
||||
#ifndef KV_DIRECT
|
||||
const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
|
||||
// we can reuse the same shmem for K and V since we only need one at a time
|
||||
var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
|
||||
#endif
|
||||
|
||||
var<workgroup> o_shmem: array<f16, Q_TILE * HEAD_DIM_V>;
|
||||
|
||||
#ifdef MASK
|
||||
// storage for mask values
|
||||
var<workgroup> mask_shmem: array<f16, Q_TILE * KV_TILE>;
|
||||
#endif
|
||||
|
||||
// note that we reuse the same storage for both since we only need one at a time
|
||||
var<workgroup> inter_shmem: array<f16, Q_TILE * KV_TILE>;
|
||||
|
||||
// Storage for row max and exp sum during online softmax
|
||||
var<workgroup> row_max_shmem: array<f32, Q_TILE>;
|
||||
var<workgroup> exp_sum_shmem: array<f32, Q_TILE>;
|
||||
var<workgroup> blk_state_wg: u32;
|
||||
|
||||
// Returns the pre-softmax logit for column `kv_idx` of Q-tile row `q_tile_row`:
// the raw score from inter_shmem times params.scale (FLOAT_MIN when kv_idx is
// outside the KV tile), optionally soft-capped, plus the mask value (multiplied
// by `slope` when `has_bias` is set) if `apply_mask` is true.
fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 {
    let inside_tile = kv_idx < KV_TILE;
    let scaled_score = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale;
    var term = select(FLOAT_MIN, scaled_score, inside_tile);
#ifdef LOGIT_SOFTCAP
    term = params.logit_softcap * tanh(term);
#endif
#ifdef MASK
    if (apply_mask) {
        let stored_mask = f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]);
        let mask_term = select(0.0, stored_mask, inside_tile);
        if (has_bias) {
            term += slope * mask_term;
        } else {
            term += mask_term;
        }
    }
#endif
    return term;
}
|
||||
|
||||
// Split-KV flash-attention kernel entry point.
// Each workgroup handles one tile of Q_TILE query rows for one (batch, head)
// pair and one of params.nwg interleaved slices of the KV sequence
// (iwg = wg_id.x % nwg). For every KV tile it computes Q*K^T scores, updates a
// running row max / exp sum (online softmax) and accumulates P*V in o_shmem.
// With nwg == 1 the normalized rows are written to dst; otherwise the
// unnormalized partial rows plus (exp sum, row max) are written to tmp --
// presumably for a separate merge pass (TODO confirm against the host code).
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(subgroup_id) subgroup_id: u32,
|
||||
@builtin(subgroup_size) subgroup_size: u32,
|
||||
@builtin(num_subgroups) num_subgroups: u32,
|
||||
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
|
||||
|
||||
// initialize row max for online softmax
|
||||
for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
|
||||
row_max_shmem[i] = FLOAT_MIN;
|
||||
exp_sum_shmem[i] = 0.0;
|
||||
}
|
||||
|
||||
// zero the output accumulator
for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) {
|
||||
o_shmem[i] = 0.0;
|
||||
}
|
||||
|
||||
// workgroups per head/batch
|
||||
let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
|
||||
let wg_per_batch = wg_per_head * params.n_heads;
|
||||
|
||||
let dst2_stride = HEAD_DIM_V * params.n_heads;
|
||||
let dst3_stride = dst2_stride * params.seq_len_q;
|
||||
|
||||
// iwg: which KV slice this workgroup takes; base_wg_id: which Q tile it belongs to
let iwg = wg_id.x % params.nwg;
|
||||
let base_wg_id = wg_id.x / params.nwg;
|
||||
|
||||
// batch index
|
||||
let batch_idx = base_wg_id / wg_per_batch;
|
||||
let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
|
||||
let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
|
||||
let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
|
||||
let wg_in_batch = base_wg_id % wg_per_batch;
|
||||
|
||||
// head index
|
||||
let head_idx = wg_in_batch / wg_per_head;
|
||||
let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
|
||||
// q_per_kv consecutive query heads share one K/V head
let k_head_idx = head_idx / params.q_per_kv;
|
||||
let v_head_idx = k_head_idx;
|
||||
let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
|
||||
let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;
|
||||
|
||||
// starting Q row for this workgroup
|
||||
let wg_in_head = wg_in_batch % wg_per_head;
|
||||
let q_row_start = wg_in_head * Q_TILE;
|
||||
|
||||
#ifdef MASK
|
||||
// mask offset
|
||||
let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
|
||||
#endif
|
||||
|
||||
// per-head slope multiplying the mask; 1.0 unless max_bias > 0
let head = f32(head_idx);
|
||||
let has_bias = params.max_bias > 0.0;
|
||||
let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias);
|
||||
|
||||
// load q tile into shared memory
|
||||
// rows past seq_len_q are stored as zero
for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
|
||||
let q_row = elem_idx / HEAD_DIM_QK;
|
||||
let q_col = elem_idx % HEAD_DIM_QK;
|
||||
let head_q_row = q_row_start + q_row;
|
||||
let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
|
||||
q_shmem[elem_idx] = f16(select(
|
||||
0.0,
|
||||
Q[global_q_row_offset + q_col],
|
||||
head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK));
|
||||
}
|
||||
|
||||
// main loop: this workgroup visits KV tiles iwg, iwg + nwg, iwg + 2*nwg, ...
for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) {
|
||||
#ifdef BLK
|
||||
// look up the state of this (Q tile, KV tile) block in the blk table
let q_blk = q_row_start / Q_TILE;
|
||||
let kv_blk = kv_tile / KV_TILE;
|
||||
let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
|
||||
let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk;
|
||||
let blk_state_local = blk[blk_idx];
|
||||
#else
|
||||
let blk_state_local = 1u;
|
||||
#endif
|
||||
// broadcast the block state through workgroup memory
if (local_id.x == 0u) {
|
||||
blk_state_wg = blk_state_local;
|
||||
}
|
||||
workgroupBarrier();
|
||||
let blk_state = blk_state_wg;
|
||||
// state 0: skip all compute for this tile; state 2: compute without applying the mask (see apply_mask below)
let skip_tile = blk_state == 0u;
|
||||
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
|
||||
inter_shmem[elem_idx] = f16(0.0);
|
||||
}
|
||||
|
||||
// load k tile into shared memory
|
||||
#if defined(KV_Q4_0)
|
||||
// Q4_0 path: each thread expands NQ packed bytes of one block; d = K[base_idx] is the block scale
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let k_row = blck_idx / BLOCKS_K;
|
||||
let global_k_row = kv_tile + k_row;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let row_offset = k_row * HEAD_DIM_QK;
|
||||
|
||||
if (global_k_row < params.seq_len_kv) {
|
||||
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = K[base_idx];
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = K[base_idx + 1u + block_offset + j];
|
||||
let q_1 = K[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
// low nibble goes to position idx, high nibble to idx + 16; both are offset by -8 and scaled by d
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_lo;
|
||||
kv_shmem[row_offset + idx + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_Q8_0)
|
||||
// Q8_0 path: each byte is a signed 8-bit weight scaled by d
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let k_row = blck_idx / BLOCKS_K;
|
||||
let global_k_row = kv_tile + k_row;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let row_offset = k_row * HEAD_DIM_QK;
|
||||
|
||||
if (global_k_row < params.seq_len_kv) {
|
||||
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = K[base_idx];
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = K[base_idx + 1u + block_offset + j];
|
||||
let q_1 = K[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f16(q_byte) * d;
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_DIRECT)
|
||||
// Direct global loads for KV
|
||||
#else
|
||||
// unquantized path: copy K as vec4 loads, writing zeros for rows past seq_len_kv
for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) {
|
||||
let k_row = elem_idx / HEAD_DIM_QK;
|
||||
let k_col = elem_idx % HEAD_DIM_QK;
|
||||
let global_k_row = kv_tile + k_row;
|
||||
let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
|
||||
let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK;
|
||||
let vec_idx = (global_k_row_offset + k_col) >> 2u;
|
||||
let k4 = select(vec4<KV_TYPE>(0.0), K[vec_idx], in_bounds);
|
||||
kv_shmem[elem_idx + 0u] = f16(k4.x);
|
||||
kv_shmem[elem_idx + 1u] = f16(k4.y);
|
||||
kv_shmem[elem_idx + 2u] = f16(k4.z);
|
||||
kv_shmem[elem_idx + 3u] = f16(k4.w);
|
||||
}
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
// accumulate q block * k block into registers across the entire KV tile
|
||||
if (!skip_tile) {
|
||||
// Each subgroup is viewed as VEC_NE stripes (ty) of num_of_threads lanes (tx):
// a stripe handles one KV column, its lanes split the head dimension.
// NOTE(review): unlike nl_threads in the P*V stage there is no max(1u, ...) guard here,
// so this assumes subgroup_size >= VEC_NE -- confirm the host guarantees it.
let num_of_threads = subgroup_size / VEC_NE;
|
||||
let tx = sg_inv_id % num_of_threads;
|
||||
let ty = sg_inv_id / num_of_threads;
|
||||
for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
|
||||
let global_q_row = q_row_start + q_tile_row;
|
||||
if (global_q_row >= params.seq_len_q) {
|
||||
continue;
|
||||
}
|
||||
let local_q_row_offset = q_tile_row * HEAD_DIM_QK;
|
||||
|
||||
for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) {
|
||||
let kv_idx = kv_base + ty;
|
||||
var partial_sum: f32 = 0.0;
|
||||
let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv;
|
||||
if (kv_valid) {
|
||||
for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) {
|
||||
let q_off = local_q_row_offset + i * 4u;
|
||||
|
||||
let qv = vec4<f32>(
|
||||
f32(q_shmem[q_off + 0u]),
|
||||
f32(q_shmem[q_off + 1u]),
|
||||
f32(q_shmem[q_off + 2u]),
|
||||
f32(q_shmem[q_off + 3u]));
|
||||
#ifdef KV_DIRECT
|
||||
let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u);
|
||||
let kv = vec4<f32>(K[idx >> 2u]);
|
||||
#else
|
||||
let idx = kv_idx * HEAD_DIM_QK + (i * 4u);
|
||||
let kv = vec4<f32>(
|
||||
f32(kv_shmem[idx + 0u]),
|
||||
f32(kv_shmem[idx + 1u]),
|
||||
f32(kv_shmem[idx + 2u]),
|
||||
f32(kv_shmem[idx + 3u]));
|
||||
#endif
|
||||
partial_sum += dot(qv, kv);
|
||||
}
|
||||
}
|
||||
var sum = partial_sum;
|
||||
// Reduce over tx threads (NL) for this ty stripe.
|
||||
var tx_delta = num_of_threads >> 1u;
|
||||
loop {
|
||||
if (tx_delta == 0u) {
|
||||
break;
|
||||
}
|
||||
let sh = subgroupShuffleDown(sum, tx_delta);
|
||||
if (tx < tx_delta) {
|
||||
sum += sh;
|
||||
}
|
||||
tx_delta >>= 1u;
|
||||
}
|
||||
|
||||
// fetch the reduced value from the first lane of this stripe; only that lane stores it
let sum_bcast = subgroupShuffle(sum, num_of_threads * ty);
|
||||
if (tx == 0u && kv_valid) {
|
||||
let dst_idx = q_tile_row * KV_TILE + kv_idx;
|
||||
inter_shmem[dst_idx] = f16(sum_bcast);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#ifdef MASK
|
||||
let apply_mask = !skip_tile && (blk_state != 2u);
|
||||
if (apply_mask) {
|
||||
// load mask tile into shared memory for this KV block
|
||||
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
|
||||
let mask_row = elem_idx / KV_TILE;
|
||||
let mask_col = elem_idx % KV_TILE;
|
||||
let global_q_row = q_row_start + mask_row;
|
||||
let global_k_col = kv_tile + mask_col;
|
||||
let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;
|
||||
let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col;
|
||||
mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
|
||||
}
|
||||
}
|
||||
#else
|
||||
let apply_mask = false;
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
// online softmax
|
||||
if (!skip_tile) {
|
||||
for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
|
||||
let global_q_row = q_row_start + q_tile_row;
|
||||
if (global_q_row >= params.seq_len_q) {
|
||||
break;
|
||||
}
|
||||
|
||||
var prev_max = row_max_shmem[q_tile_row];
|
||||
var final_max = prev_max;
|
||||
// pass 1: compute final max across the full KV tile in chunks
|
||||
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
|
||||
let kv_idx = kv_offset + sg_inv_id;
|
||||
let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE;
|
||||
let softmax_term = select(FLOAT_MIN,
|
||||
calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask),
|
||||
kv_valid);
|
||||
final_max = subgroupMax(max(final_max, softmax_term));
|
||||
}
|
||||
|
||||
var total_exp_term: f32 = 0.0;
|
||||
// pass 2: compute exp sum and write P using final_max
|
||||
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
|
||||
let kv_idx = kv_offset + sg_inv_id;
|
||||
let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask);
|
||||
let cur_p = select(0.0,
|
||||
exp(softmax_term - final_max),
|
||||
kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
|
||||
total_exp_term += subgroupAdd(cur_p);
|
||||
if (kv_idx < KV_TILE) {
|
||||
inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p);
|
||||
}
|
||||
}
|
||||
|
||||
// factor that rescales previously accumulated sums/outputs to the new max
let cur_exp = exp(prev_max - final_max);
|
||||
|
||||
if (sg_inv_id == 0) {
|
||||
row_max_shmem[q_tile_row] = final_max;
|
||||
exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term;
|
||||
}
|
||||
|
||||
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
|
||||
let idx = q_tile_row * HEAD_DIM_V + elem_idx;
|
||||
o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// load v tile into shared memory
|
||||
// (same decoding as the K tile above, using V, BLOCKS_V and HEAD_DIM_V; kv_shmem is reused)
#if defined(KV_Q4_0)
|
||||
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let v_row = blck_idx / BLOCKS_V;
|
||||
let global_v_row = kv_tile + v_row;
|
||||
let block_k = blck_idx % BLOCKS_V;
|
||||
let row_offset = v_row * HEAD_DIM_V;
|
||||
|
||||
if (global_v_row < params.seq_len_kv) {
|
||||
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = V[base_idx];
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = V[base_idx + 1u + block_offset + j];
|
||||
let q_1 = V[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_lo;
|
||||
kv_shmem[row_offset + idx + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_Q8_0)
|
||||
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let v_row = blck_idx / BLOCKS_V;
|
||||
let global_v_row = kv_tile + v_row;
|
||||
let block_k = blck_idx % BLOCKS_V;
|
||||
let row_offset = v_row * HEAD_DIM_V;
|
||||
|
||||
if (global_v_row < params.seq_len_kv) {
|
||||
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = V[base_idx];
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = V[base_idx + 1u + block_offset + j];
|
||||
let q_1 = V[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f16(q_byte) * d;
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_DIRECT)
|
||||
// Direct global loads for KV
|
||||
#else
|
||||
for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) {
|
||||
let v_row = elem_idx / HEAD_DIM_V;
|
||||
let v_col = elem_idx % HEAD_DIM_V;
|
||||
let global_v_row = kv_tile + v_row;
|
||||
let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
|
||||
let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V;
|
||||
let vec_idx = (global_v_row_offset + v_col) >> 2u;
|
||||
let v4 = select(vec4<KV_TYPE>(0.0), V[vec_idx], in_bounds);
|
||||
kv_shmem[elem_idx + 0u] = f16(v4.x);
|
||||
kv_shmem[elem_idx + 1u] = f16(v4.y);
|
||||
kv_shmem[elem_idx + 2u] = f16(v4.z);
|
||||
kv_shmem[elem_idx + 3u] = f16(v4.w);
|
||||
}
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
if (!skip_tile) {
|
||||
// we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem
|
||||
// we want to compute O += P * V across the full KV tile
|
||||
// lane layout: ty_pv selects the KV row within a group of ne_threads, tx_pv the 4-wide output column
let ne_threads : u32 = VEC_NE;
|
||||
let nl_threads = max(1u, subgroup_size / ne_threads);
|
||||
let tx_pv = sg_inv_id % nl_threads;
|
||||
let ty_pv = sg_inv_id / nl_threads;
|
||||
for (var q_tile_row = subgroup_id;
|
||||
q_tile_row < Q_TILE;
|
||||
q_tile_row += num_subgroups) {
|
||||
for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) {
|
||||
var lo = vec4<f32>(0.0, 0.0, 0.0, 0.0);
|
||||
for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) {
|
||||
let kv_idx = cc * ne_threads + ty_pv;
|
||||
let v_row = kv_tile + kv_idx;
|
||||
if (v_row >= params.seq_len_kv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]);
|
||||
#ifdef KV_DIRECT
|
||||
let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u;
|
||||
let v4 = vec4<f32>(V[v_idx >> 2u]);
|
||||
#else
|
||||
let v_idx = kv_idx * HEAD_DIM_V + vec_col * 4u;
|
||||
let v4 = vec4<f32>(
|
||||
f32(kv_shmem[v_idx + 0u]),
|
||||
f32(kv_shmem[v_idx + 1u]),
|
||||
f32(kv_shmem[v_idx + 2u]),
|
||||
f32(kv_shmem[v_idx + 3u]));
|
||||
#endif
|
||||
lo += p * v4;
|
||||
}
|
||||
|
||||
var lo_x = lo.x;
|
||||
var lo_y = lo.y;
|
||||
var lo_z = lo.z;
|
||||
var lo_w = lo.w;
|
||||
// Reduce over ty threads (NE) for this tx thread.
|
||||
var ty_delta = ne_threads >> 1u;
|
||||
loop {
|
||||
if (ty_delta == 0u) {
|
||||
break;
|
||||
}
|
||||
let thread_delta = ty_delta * nl_threads;
|
||||
let shx = subgroupShuffleDown(lo_x, thread_delta);
|
||||
let shy = subgroupShuffleDown(lo_y, thread_delta);
|
||||
let shz = subgroupShuffleDown(lo_z, thread_delta);
|
||||
let shw = subgroupShuffleDown(lo_w, thread_delta);
|
||||
if (ty_pv < ty_delta) {
|
||||
lo_x += shx;
|
||||
lo_y += shy;
|
||||
lo_z += shz;
|
||||
lo_w += shw;
|
||||
}
|
||||
ty_delta >>= 1u;
|
||||
}
|
||||
|
||||
// lanes with ty_pv == 0 hold the reduced sums and add them into the f16 accumulator
if (ty_pv == 0u) {
|
||||
let elem_base = vec_col * 4u;
|
||||
let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base;
|
||||
o_shmem[o_base_idx + 0u] = f16(f32(o_shmem[o_base_idx + 0u]) + lo_x);
|
||||
o_shmem[o_base_idx + 1u] = f16(f32(o_shmem[o_base_idx + 1u]) + lo_y);
|
||||
o_shmem[o_base_idx + 2u] = f16(f32(o_shmem[o_base_idx + 2u]) + lo_z);
|
||||
o_shmem[o_base_idx + 3u] = f16(f32(o_shmem[o_base_idx + 3u]) + lo_w);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// make sure every thread is done with the shared tiles before the next iteration overwrites them
workgroupBarrier();
|
||||
}
|
||||
|
||||
|
||||
#ifdef SINKS
|
||||
// Sinks are global terms and must be applied exactly once across split workgroups.
|
||||
if (iwg == 0u) {
|
||||
for (var q_tile_row = subgroup_id;
|
||||
q_tile_row < Q_TILE;
|
||||
q_tile_row += num_subgroups) {
|
||||
let global_q_row = q_row_start + q_tile_row;
|
||||
if (global_q_row >= params.seq_len_q) {
|
||||
break;
|
||||
}
|
||||
|
||||
var prev_max = row_max_shmem[q_tile_row];
|
||||
|
||||
// for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
|
||||
let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0);
|
||||
let new_max = subgroupMax(max(prev_max, sink_val));
|
||||
let max_exp = exp(prev_max - new_max);
|
||||
let sink_exp = exp(sink_val - new_max);
|
||||
|
||||
let sink_exp_sum = subgroupAdd(sink_exp);
|
||||
|
||||
if (sg_inv_id == 0) {
|
||||
row_max_shmem[q_tile_row] = new_max;
|
||||
exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum;
|
||||
}
|
||||
|
||||
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
|
||||
let idx = q_tile_row * HEAD_DIM_V + elem_idx;
|
||||
o_shmem[idx] = f16(f32(o_shmem[idx]) * max_exp);
|
||||
}
|
||||
}
|
||||
workgroupBarrier();
|
||||
}
|
||||
#endif
|
||||
// write-out: final rows go to dst when nwg == 1, otherwise partial rows and statistics go to tmp
let rows_per_batch = params.n_heads * params.seq_len_q;
|
||||
for (var q_tile_row = subgroup_id;
|
||||
q_tile_row < Q_TILE;
|
||||
q_tile_row += num_subgroups) {
|
||||
|
||||
let global_q_row = q_row_start + q_tile_row;
|
||||
if (global_q_row >= params.seq_len_q) { break; }
|
||||
|
||||
if (params.nwg == 1u) {
|
||||
let exp_sum = exp_sum_shmem[q_tile_row];
|
||||
// a zero exp sum yields an all-zero row instead of dividing by zero
let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
|
||||
let row_base: u32 =
|
||||
params.offset_dst + batch_idx * dst3_stride + global_q_row * dst2_stride + head_idx * HEAD_DIM_V;
|
||||
|
||||
for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) {
|
||||
let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
|
||||
let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
|
||||
let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
|
||||
let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
|
||||
|
||||
let v = vec4<f32>(
|
||||
f32(o_shmem[i0]) * scale,
|
||||
f32(o_shmem[i1]) * scale,
|
||||
f32(o_shmem[i2]) * scale,
|
||||
f32(o_shmem[i3]) * scale
|
||||
);
|
||||
|
||||
let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
|
||||
dst[dst_vec_index] = v;
|
||||
}
|
||||
} else {
|
||||
// rid: flat (batch, head, query row) id; each row owns nwg slots in tmp, this workgroup fills slot iwg
let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + global_q_row;
|
||||
let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V;
|
||||
let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg;
|
||||
|
||||
for (var elem_base = sg_inv_id * 4u;
|
||||
elem_base < HEAD_DIM_V;
|
||||
elem_base += subgroup_size * 4u) {
|
||||
|
||||
let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
|
||||
let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
|
||||
let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
|
||||
let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
|
||||
|
||||
let tbase = tmp_row_data_base + elem_base;
|
||||
tmp[tbase + 0u] = f32(o_shmem[i0]);
|
||||
tmp[tbase + 1u] = f32(o_shmem[i1]);
|
||||
tmp[tbase + 2u] = f32(o_shmem[i2]);
|
||||
tmp[tbase + 3u] = f32(o_shmem[i3]);
|
||||
}
|
||||
|
||||
// statistics layout per slot: [0] = exp sum, [1] = row max
if (sg_inv_id == 0u) {
|
||||
tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row];
|
||||
tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,155 +0,0 @@
|
|||
enable f16;
|
||||
|
||||
#ifdef TYPE_F32
|
||||
#define DataType f32
|
||||
#endif
|
||||
#ifdef TYPE_F16
|
||||
#define DataType f16
|
||||
#endif
|
||||
|
||||
#ifdef OP_REGLU
|
||||
// REGLU: the negative part of `a` is clamped to zero, then multiplied by `b`.
fn op(a: DataType, b: DataType) -> DataType {
    let rectified = max(a, 0);
    return rectified * b;
}
|
||||
#endif
|
||||
|
||||
#ifdef OP_GEGLU
|
||||
const SQRT_2_OVER_PI: DataType = 0.79788456080286535587989211986876;
|
||||
const GELU_COEF_A: DataType = 0.044715;
|
||||
|
||||
// GEGLU: GELU(a) * b, with tanh written as 1 - 2 / (exp(2x) + 1)
// so that (1 + tanh(x)) becomes 2 - 2 / (exp(2x) + 1).
fn op(a: DataType, b: DataType) -> DataType {
    let inner = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
    let exp_term = exp(2* inner);
    let one_plus_tanh = 2.0 - 2.0/ (exp_term + 1);
    return 0.5 * a * one_plus_tanh * b;
}
|
||||
#endif
|
||||
|
||||
#ifdef OP_SWIGLU
|
||||
// SWIGLU: a * sigmoid(a), multiplied by `b`.
fn op(a: DataType, b: DataType) -> DataType {
    let denom = 1.0 + exp(-a);
    let silu = a / denom;
    return silu * b;
}
|
||||
#endif
|
||||
#ifdef OP_SWIGLU_OAI
|
||||
// SWIGLU_OAI: `a` is capped from above at params.limit and `b` is limited to
// [-params.limit, params.limit]; returns a' * sigmoid(alpha * a') * (1 + b').
fn op(a: f32, b: f32) -> f32 {
    let gate_in = min(a, params.limit);
    let linear_in = max(min(b, params.limit), -params.limit);
    let gated = gate_in / (1.0 + exp(-gate_in * params.alpha));
    return gated * (1.0 + linear_in);
}
|
||||
#endif
|
||||
#ifdef OP_GEGLU_ERF
|
||||
const p_erf: DataType = 0.3275911;
|
||||
const a1_erf: DataType = 0.254829592;
|
||||
const a2_erf: DataType = -0.284496736;
|
||||
const a3_erf: DataType = 1.421413741;
|
||||
const a4_erf: DataType = -1.453152027;
|
||||
const a5_erf: DataType = 1.061405429;
|
||||
const SQRT_2_INV: DataType = 0.7071067811865476;
|
||||
|
||||
// GEGLU_ERF: 0.5 * a * (1 + erf(a / sqrt(2))) * b, where erf is evaluated with
// the five-coefficient polynomial in t = 1 / (1 + p * |x|) given by the *_erf
// constants.
fn op(a: DataType, b: DataType) -> DataType {
    let scaled = a * SQRT_2_INV;
    let sgn = sign(scaled);
    let x = abs(scaled);
    let t = 1.0 / (1.0 + p_erf * x);

    // Horner evaluation of the polynomial, highest coefficient first.
    var poly = a5_erf * t + a4_erf;
    poly = poly * t + a3_erf;
    poly = poly * t + a2_erf;
    poly = poly * t + a1_erf;

    let erf_abs = 1.0 - (poly * t * exp(-x * x));
    let erf_value = sgn * erf_abs;
    return 0.5 * a * (1.0 + erf_value) * b;
}
|
||||
#endif
|
||||
#ifdef OP_GEGLU_QUICK
|
||||
const GELU_QUICK_COEF: DataType = -1.702;
|
||||
|
||||
// GEGLU_QUICK: a * (1 / (1 + exp(GELU_QUICK_COEF * a))), multiplied by `b`.
fn op(a: DataType, b: DataType) -> DataType {
    let sigmoid = 1.0 / (1.0 + exp(GELU_QUICK_COEF * a));
    return a * sigmoid * b;
}
|
||||
#endif
|
||||
|
||||
// Uniform parameters shared by all GLU variants in this shader.
struct Params {
|
||||
// Base offsets added to every src0 / src1 / dst index.
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
offset_dst: u32,
|
||||
|
||||
// Strides (in elements)
|
||||
// Trailing digit is the dimension (1..3); dimension 0 is indexed with an implicit stride of 1.
stride_src01: u32,
|
||||
stride_src02: u32,
|
||||
stride_src03: u32,
|
||||
|
||||
stride_src11: u32,
|
||||
stride_src12: u32,
|
||||
stride_src13: u32,
|
||||
|
||||
stride_dst1: u32,
|
||||
stride_dst2: u32,
|
||||
stride_dst3: u32,
|
||||
|
||||
// shape of dst
|
||||
// ne is the total element count (invocations with gid.x >= ne exit); ne0..ne2 split a flat index into coordinates.
ne: u32,
|
||||
ne0: u32,
|
||||
ne1: u32,
|
||||
ne2: u32,
|
||||
|
||||
// In the NO_SPLIT variant, non-zero swaps which half of src0 supplies `a` and which supplies `b`.
swapped: u32,
|
||||
// alpha and limit are read only by the OP_SWIGLU_OAI variant.
alpha: f32,
|
||||
limit: f32,
|
||||
}
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src0: array<DataType>;
|
||||
|
||||
#ifdef NO_SPLIT
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> dst: array<DataType>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
|
||||
// NO_SPLIT path: both operands live in src0, `ne0` elements apart.
// Returns the `a` operand: src0[base] normally, src0[base + ne0] when
// params.swapped is non-zero (WGSL select(f, t, cond) yields t when cond is true).
fn a_value(base: u32) -> DataType {
|
||||
let offset: u32 = select(0, params.ne0, params.swapped != 0);
|
||||
return src0[base + offset];
|
||||
}
|
||||
|
||||
// NO_SPLIT path: returns the `b` operand from the other half of src0 relative
// to a_value: src0[base + ne0] normally, src0[base] when params.swapped is non-zero.
// NOTE(review): main passes an index built from offset_src1 / stride_src1*;
// presumably the host sets those to match src0 in this mode - confirm.
fn b_value(base: u32) -> DataType {
|
||||
let offset: u32 = select(params.ne0, 0, params.swapped != 0);
|
||||
return src0[base + offset];
|
||||
}
|
||||
|
||||
#else
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> src1: array<DataType>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<storage, read_write> dst: array<DataType>;
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<uniform> params: Params;
|
||||
|
||||
// Split path (NO_SPLIT not defined): the `a` operand is read directly from src0.
// params.swapped is not consulted here.
fn a_value(base: u32) -> DataType {
|
||||
return src0[base];
|
||||
}
|
||||
|
||||
// Split path (NO_SPLIT not defined): the `b` operand is read directly from the
// separate src1 buffer.
fn b_value(base: u32) -> DataType {
|
||||
return src1[base];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// Entry point: one invocation per dst element.
// gid.x is treated as a flat index over a (ne0, ne1, ne2, ...) shaped dst, is
// decomposed into (i0, i1, i2, i3), and the result of op(a, b) is written to dst.
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
// guard against invocations beyond the number of dst elements
if (gid.x >= params.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
// peel off the outermost index first, then reduce the remainder
var i = gid.x;
|
||||
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
|
||||
i = i % (params.ne2 * params.ne1 * params.ne0);
|
||||
let i2 = i / (params.ne1 * params.ne0);
|
||||
i = i % (params.ne1 * params.ne0);
|
||||
let i1 = i / params.ne0;
|
||||
let i0 = i % params.ne0;
|
||||
|
||||
// element indices: dimension 0 uses stride 1, dimensions 1..3 use the given strides
let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
|
||||
let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
|
||||
let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
|
||||
|
||||
// a_value / b_value hide whether the operands come from one buffer or two
dst[i_dst] = op(a_value(i_a), b_value(i_b));
|
||||
}
|
||||
|
|
@ -197,6 +197,7 @@ class Keys:
|
|||
FREQ_BASE_SWA = "{arch}.rope.freq_base_swa"
|
||||
SCALING_TYPE = "{arch}.rope.scaling.type"
|
||||
SCALING_FACTOR = "{arch}.rope.scaling.factor"
|
||||
SCALING_ALPHA = "{arch}.rope.scaling.alpha"
|
||||
SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
|
||||
SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
|
||||
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
|
||||
|
|
@ -471,6 +472,7 @@ class MODEL_ARCH(IntEnum):
|
|||
ERNIE4_5_MOE = auto()
|
||||
HUNYUAN_MOE = auto()
|
||||
HUNYUAN_DENSE = auto()
|
||||
HUNYUAN_VL = auto()
|
||||
SMOLLM3 = auto()
|
||||
GPT_OSS = auto()
|
||||
LFM2 = auto()
|
||||
|
|
@ -957,6 +959,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.FALCON_H1: "falcon-h1",
|
||||
MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
|
||||
MODEL_ARCH.HUNYUAN_DENSE: "hunyuan-dense",
|
||||
MODEL_ARCH.HUNYUAN_VL: "hunyuan_vl",
|
||||
MODEL_ARCH.SMOLLM3: "smollm3",
|
||||
MODEL_ARCH.GPT_OSS: "gpt-oss",
|
||||
MODEL_ARCH.LFM2: "lfm2",
|
||||
|
|
@ -3489,6 +3492,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.HUNYUAN_VL: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.SMOLLM3: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
|
@ -4138,6 +4157,7 @@ class VisionProjectorType:
|
|||
YOUTUVL = "youtuvl"
|
||||
NEMOTRON_V2_VL = "nemotron_v2_vl"
|
||||
HUNYUANOCR = "hunyuanocr"
|
||||
HUNYUANVL = "hunyuanvl"
|
||||
|
||||
|
||||
# Items here are (block size, type size)
|
||||
|
|
|
|||
|
|
@ -973,6 +973,9 @@ class GGUFWriter:
|
|||
def add_rope_scaling_factor(self, value: float) -> None:
|
||||
self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)
|
||||
|
||||
def add_rope_scaling_alpha(self, value: float) -> None:
|
||||
self.add_float32(Keys.Rope.SCALING_ALPHA.format(arch=self.arch), value)
|
||||
|
||||
def add_rope_scaling_attn_factors(self, value: float) -> None:
|
||||
self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value)
|
||||
|
||||
|
|
|
|||
|
|
@ -3400,6 +3400,7 @@ static void PrepareMediaEmbds(const int nctx, const std::vector<int> & media_int
|
|||
//added after https://github.com/ggml-org/llama.cpp/pull/22161, replacing clip_is_mrope function
|
||||
auto decoder_rope_type = llama_model_rope_type(llama_get_model(llama_ctx_v4));
|
||||
switch (decoder_rope_type) {
|
||||
case LLAMA_ROPE_TYPE_NONE:
|
||||
case LLAMA_ROPE_TYPE_NORM:
|
||||
case LLAMA_ROPE_TYPE_NEOX:
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,53 +0,0 @@
|
|||
|
||||
#!/usr/bin/env pwsh
|
||||
|
||||
# Basedir on device
|
||||
$basedir=".\pkg-snapdragon"
|
||||
|
||||
$cli_opts=$args
|
||||
|
||||
$model="Llama-3.2-3B-Instruct-Q4_0.gguf"
|
||||
if ($null -ne $env:M) {
|
||||
$model=$env:M
|
||||
}
|
||||
|
||||
$device="HTP0"
|
||||
if ($null -ne $env:D) {
|
||||
$device=$env:D
|
||||
}
|
||||
|
||||
if ($null -ne $env:V) {
|
||||
$env:GGML_HEXAGON_VERBOSE=$env:V
|
||||
}
|
||||
|
||||
if ($null -ne $env:SCHED) {
|
||||
$env:GGML_SCHED_DEBUG=$env:SCHED; $cli_opts="$cli_opts -v"
|
||||
}
|
||||
|
||||
if ($null -ne $env:PROF) {
|
||||
$env:GGML_HEXAGON_PROFILE=$env:PROF; $env:GGML_HEXAGON_OPSYNC=1
|
||||
}
|
||||
|
||||
if ($null -ne $env:OPMASK) {
|
||||
$env:GGML_HEXAGON_OPMASK=$env:OPMASK
|
||||
}
|
||||
|
||||
if ($null -ne $env:NHVX) {
|
||||
$env:GGML_HEXAGON_NHVX=$env:NHVX
|
||||
}
|
||||
|
||||
if ($null -ne $env:NDEV) {
|
||||
$env:GGML_HEXAGON_NDEV=$env:NDEV
|
||||
}
|
||||
|
||||
if ($null -ne $env:HB) {
|
||||
$env:GGML_HEXAGON_HOSTBUF=$env:HB
|
||||
}
|
||||
|
||||
$env:ADSP_LIBRARY_PATH="$basedir\lib"
|
||||
|
||||
& "$basedir\bin\llama-completion.exe" `
|
||||
--no-mmap -m $basedir\..\..\gguf\$model `
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 `
|
||||
--ctx-size 8192 --batch-size 256 -fa on `
|
||||
-ngl 99 -no-cnv --device $device $cli_opts
|
||||
|
|
@ -109,6 +109,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
|
||||
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
|
||||
{ LLM_ARCH_HUNYUAN_DENSE, "hunyuan-dense" },
|
||||
{ LLM_ARCH_HUNYUAN_VL, "hunyuan_vl" },
|
||||
{ LLM_ARCH_SMOLLM3, "smollm3" },
|
||||
{ LLM_ARCH_OPENAI_MOE, "gpt-oss" },
|
||||
{ LLM_ARCH_LFM2, "lfm2" },
|
||||
|
|
@ -250,6 +251,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
|||
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
|
||||
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
|
||||
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
|
||||
{ LLM_KV_ROPE_SCALING_ALPHA, "%s.rope.scaling.alpha" },
|
||||
{ LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" },
|
||||
{ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
|
||||
{ LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
|
||||
|
|
|
|||
|
|
@ -113,6 +113,7 @@ enum llm_arch {
|
|||
LLM_ARCH_ERNIE4_5_MOE,
|
||||
LLM_ARCH_HUNYUAN_MOE,
|
||||
LLM_ARCH_HUNYUAN_DENSE,
|
||||
LLM_ARCH_HUNYUAN_VL,
|
||||
LLM_ARCH_SMOLLM3,
|
||||
LLM_ARCH_OPENAI_MOE,
|
||||
LLM_ARCH_LFM2,
|
||||
|
|
@ -254,6 +255,7 @@ enum llm_kv {
|
|||
LLM_KV_ROPE_SCALE_LINEAR,
|
||||
LLM_KV_ROPE_SCALING_TYPE,
|
||||
LLM_KV_ROPE_SCALING_FACTOR,
|
||||
LLM_KV_ROPE_SCALING_ALPHA,
|
||||
LLM_KV_ROPE_SCALING_ATTN_FACTOR,
|
||||
LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
|
||||
LLM_KV_ROPE_SCALING_FINETUNED,
|
||||
|
|
|
|||
|
|
@ -116,6 +116,7 @@ struct llama_hparams {
|
|||
float rope_freq_base_train_swa = 10000.0f;
|
||||
float rope_freq_scale_train;
|
||||
float rope_freq_scale_train_swa = 1.0f;
|
||||
float rope_scaling_alpha = 0.0f; // NTK-aware alpha for XDRoPE
|
||||
|
||||
uint32_t n_ctx_orig_yarn;
|
||||
float rope_yarn_log_mul = 0.0f;
|
||||
|
|
|
|||
|
|
@ -852,6 +852,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups, false);
|
||||
ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false);
|
||||
|
||||
if (arch == LLM_ARCH_HUNYUAN_VL || arch == LLM_ARCH_HUNYUAN_DENSE) {
|
||||
if (hparams.n_expert <= 1) {
|
||||
hparams.n_expert = 0;
|
||||
hparams.n_expert_used = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
|
||||
ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd);
|
||||
ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd_out_impl);
|
||||
|
|
@ -930,6 +937,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
|
||||
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_ALPHA, hparams.rope_scaling_alpha, false);
|
||||
|
||||
// non-transformer models do not have attention heads
|
||||
if (hparams.n_head() > 0) {
|
||||
|
|
@ -2707,9 +2715,18 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_HUNYUAN_VL:
|
||||
case LLM_ARCH_HUNYUAN_DENSE:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false);
|
||||
|
||||
// XDRoPE / NTK-aware scaling: base = rope_theta * alpha^(dim / (dim - 2))
|
||||
if (hparams.rope_scaling_alpha > 0.0f) {
|
||||
const int dim = hparams.n_embd_head_k();
|
||||
hparams.rope_freq_base_train = hparams.rope_freq_base_train
|
||||
* powf(hparams.rope_scaling_alpha, (float)dim / (float)(dim - 2));
|
||||
}
|
||||
|
||||
switch (hparams.n_embd) {
|
||||
case 1024: type = LLM_TYPE_0_5B; break;
|
||||
|
|
@ -7106,6 +7123,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_HUNYUAN_VL:
|
||||
case LLM_ARCH_HUNYUAN_DENSE:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
|
@ -9126,6 +9144,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
{
|
||||
llm = std::make_unique<llm_build_hunyuan_moe>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_HUNYUAN_VL:
|
||||
case LLM_ARCH_HUNYUAN_DENSE:
|
||||
{
|
||||
llm = std::make_unique<llm_build_hunyuan_dense>(*this, params);
|
||||
|
|
@ -9475,6 +9494,9 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|||
case LLM_ARCH_GLM4_MOE:
|
||||
return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_HUNYUAN_VL:
|
||||
return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
// all model arches should be listed explicitly here
|
||||
case LLM_ARCH_UNKNOWN:
|
||||
GGML_ABORT("unknown architecture");
|
||||
|
|
|
|||
|
|
@ -6,6 +6,11 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons
|
|||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
const bool use_mrope = hparams.use_mrope();
|
||||
|
||||
int sections[4];
|
||||
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
|
|
@ -37,22 +42,36 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons
|
|||
auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur,
|
||||
n_embd_head, n_head, n_head_kv, il);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
if (use_mrope) {
|
||||
Qcur = ggml_rope_multi(
|
||||
ctx0, Qcur, inp_pos, rope_factors,
|
||||
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = ggml_rope_multi(
|
||||
ctx0, Kcur, inp_pos, rope_factors,
|
||||
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
} else {
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
}
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = build_norm(Kcur,
|
||||
model.layers[il].attn_k_norm, nullptr,
|
||||
LLM_NORM_RMS, il);
|
||||
|
|
|
|||
|
|
@ -150,7 +150,7 @@
|
|||
#define TN_TOK_BOI "v.boi"
|
||||
#define TN_TOK_EOI "v.eoi"
|
||||
|
||||
// hunyuanocr
|
||||
// hunyuanocr / hunyuanvl (shared GGUF tensor names)
|
||||
#define TN_MM_PRE_NORM "mm.pre_norm.%s"
|
||||
#define TN_TOK_IMG_BEGIN "mm.image_begin"
|
||||
#define TN_TOK_IMG_END "mm.image_end"
|
||||
|
|
@ -242,6 +242,15 @@
|
|||
#define TN_STD_BIAS "v.std_bias"
|
||||
#define TN_STD_SCALE "v.std_scale"
|
||||
|
||||
// yasa2
|
||||
#define TN_YASA_PATCH_LN_W "v.patch_ln.weight"
|
||||
#define TN_YASA_PATCH_LN_B "v.patch_ln.bias"
|
||||
#define TN_YASA_BACKBONE_LN_W "v.backbone_ln.weight"
|
||||
#define TN_YASA_BACKBONE_LN_B "v.backbone_ln.bias"
|
||||
#define TN_YASA_POS_EMBD "v.vision_pos_embed"
|
||||
#define TN_YASA_STAGE_DOWN_LN "v.stage.%d.down.ln.%s"
|
||||
#define TN_YASA_STAGE_DOWN_CONV "v.stage.%d.down.conv.%s"
|
||||
#define TN_YASA_STAGE_BLK "v.stage.%d.blk.%d.%s.%s"
|
||||
|
||||
// align x to upper multiple of n
|
||||
#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
|
||||
|
|
@ -290,9 +299,11 @@ enum projector_type {
|
|||
PROJECTOR_TYPE_LFM2A,
|
||||
PROJECTOR_TYPE_GLM4V,
|
||||
PROJECTOR_TYPE_YOUTUVL,
|
||||
PROJECTOR_TYPE_YASA2,
|
||||
PROJECTOR_TYPE_KIMIK25,
|
||||
PROJECTOR_TYPE_NEMOTRON_V2_VL,
|
||||
PROJECTOR_TYPE_HUNYUANOCR,
|
||||
PROJECTOR_TYPE_HUNYUANVL,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
|
|
@ -335,9 +346,11 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
|||
{ PROJECTOR_TYPE_LFM2A, "lfm2a"},
|
||||
{ PROJECTOR_TYPE_GLM4V, "glm4v"},
|
||||
{ PROJECTOR_TYPE_YOUTUVL, "youtuvl"},
|
||||
{ PROJECTOR_TYPE_YASA2, "yasa2"},
|
||||
{ PROJECTOR_TYPE_KIMIK25, "kimik25"},
|
||||
{ PROJECTOR_TYPE_NEMOTRON_V2_VL, "nemotron_v2_vl"},
|
||||
{ PROJECTOR_TYPE_HUNYUANOCR, "hunyuanocr"},
|
||||
{ PROJECTOR_TYPE_HUNYUANVL, "hunyuanvl"},
|
||||
};
|
||||
|
||||
static projector_type clip_projector_type_from_string(const std::string & str) {
|
||||
|
|
|
|||
|
|
@ -268,6 +268,27 @@ struct mobilenetv5_block {
|
|||
ggml_tensor * attn_norm_w = nullptr;
|
||||
};
|
||||
|
||||
// Tensors of one block inside a yasa2 stage. All pointers default to nullptr.
// The suffixes correspond to the "dw", "ln", "pw1", "grn" and "pw2" tensor-name
// components used with TN_YASA_STAGE_BLK ("v.stage.%d.blk.%d.%s.%s");
// *_w is the "weight" tensor and *_b the "bias" tensor.
// NOTE(review): presumably depthwise conv / layer norm / pointwise 1 / GRN /
// pointwise 2 of a ConvNeXt block - confirm against the graph builder.
struct yasa2_block {
|
||||
ggml_tensor * dw_w = nullptr;
|
||||
ggml_tensor * dw_b = nullptr;
|
||||
ggml_tensor * ln_w = nullptr;
|
||||
ggml_tensor * ln_b = nullptr;
|
||||
ggml_tensor * pw1_w = nullptr;
|
||||
ggml_tensor * pw1_b = nullptr;
|
||||
ggml_tensor * grn_w = nullptr;
|
||||
ggml_tensor * grn_b = nullptr;
|
||||
ggml_tensor * pw2_w = nullptr;
|
||||
ggml_tensor * pw2_b = nullptr;
|
||||
};
|
||||
|
||||
// One yasa2 stage: optional "down" tensors plus an ordered list of blocks.
// down_ln_* map to TN_YASA_STAGE_DOWN_LN ("v.stage.%d.down.ln.%s") and
// down_conv_* to TN_YASA_STAGE_DOWN_CONV ("v.stage.%d.down.conv.%s").
// Any of the pointers may remain nullptr when the tensor is absent.
struct yasa2_stage {
|
||||
ggml_tensor * down_ln_w = nullptr;
|
||||
ggml_tensor * down_ln_b = nullptr;
|
||||
ggml_tensor * down_conv_w = nullptr;
|
||||
ggml_tensor * down_conv_b = nullptr;
|
||||
// blocks of this stage, in index order
std::vector<yasa2_block> blocks;
|
||||
};
|
||||
|
||||
struct clip_model {
|
||||
clip_modality modality = CLIP_MODALITY_VISION;
|
||||
projector_type proj_type = PROJECTOR_TYPE_MLP;
|
||||
|
|
@ -402,6 +423,15 @@ struct clip_model {
|
|||
ggml_tensor * msfa_ffn_expand_bn = nullptr;
|
||||
ggml_tensor * msfa_ffn_project_bn = nullptr;
|
||||
|
||||
// yasa2
|
||||
ggml_tensor * yasa_patch_w = nullptr;
|
||||
ggml_tensor * yasa_patch_b = nullptr;
|
||||
ggml_tensor * yasa_patch_ln_w = nullptr;
|
||||
ggml_tensor * yasa_patch_ln_b = nullptr;
|
||||
ggml_tensor * yasa_backbone_ln_w = nullptr;
|
||||
ggml_tensor * yasa_backbone_ln_b = nullptr;
|
||||
ggml_tensor * yasa_vision_pos_embed = nullptr;
|
||||
std::vector<yasa2_stage> yasa_stages;
|
||||
|
||||
// pixtral, glm4v
|
||||
ggml_tensor * token_embd_img_break = nullptr;
|
||||
|
|
|
|||
|
|
@ -76,6 +76,7 @@
|
|||
#include "models/deepseekocr.cpp"
|
||||
#include "models/mobilenetv5.cpp"
|
||||
#include "models/youtuvl.cpp"
|
||||
#include "models/yasa2.cpp"
|
||||
|
||||
struct clip_logger_state g_logger_state = {clip_log_callback_default, NULL};
|
||||
|
||||
|
|
@ -969,6 +970,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
builder = std::make_unique<clip_graph_cogvlm>(ctx, img);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_HUNYUANOCR:
|
||||
case PROJECTOR_TYPE_HUNYUANVL:
|
||||
{
|
||||
builder = std::make_unique<clip_graph_hunyuanocr>(ctx, img);
|
||||
} break;
|
||||
|
|
@ -1004,6 +1006,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
{
|
||||
builder = std::make_unique<clip_graph_youtuvl>(ctx, img);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_YASA2:
|
||||
{
|
||||
builder = std::make_unique<clip_graph_yasa2>(ctx, img);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("missing cgraph builder");
|
||||
}
|
||||
|
|
@ -1474,6 +1480,16 @@ struct clip_model_loader {
|
|||
hparams.set_limit_image_tokens(1, 62500);
|
||||
hparams.set_warmup_n_tokens(16*16); // avoid OOM on warmup
|
||||
} break;
|
||||
case PROJECTOR_TYPE_YASA2:
|
||||
{
|
||||
hparams.ffn_op = FFN_GELU_ERF;
|
||||
log_ffn_op = "gelu_erf";
|
||||
hparams.image_resize_algo = RESIZE_ALGO_BICUBIC;
|
||||
|
||||
// reka model performs better when using resize_bicubic, which stretches
|
||||
// the image to fit fixed square size
|
||||
hparams.image_resize_pad = false;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GLM4V:
|
||||
{
|
||||
hparams.rope_theta = 10000.0f;
|
||||
|
|
@ -1544,6 +1560,16 @@ struct clip_model_loader {
|
|||
get_u32(KEY_IMAGE_MAX_PIXELS, hparams.image_max_pixels);
|
||||
hparams.set_warmup_n_tokens(28*28);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_HUNYUANVL:
|
||||
{
|
||||
hparams.n_merge = 2;
|
||||
hparams.image_resize_algo = RESIZE_ALGO_BICUBIC_PILLOW;
|
||||
hparams.image_resize_pad = false;
|
||||
hparams.ffn_op = FFN_GELU;
|
||||
get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
|
||||
hparams.set_limit_image_tokens(256, 16384);
|
||||
hparams.set_warmup_n_tokens(32*32);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_LFM2A:
|
||||
{
|
||||
// audio preprocessing params
|
||||
|
|
@ -1929,6 +1955,55 @@ struct clip_model_loader {
|
|||
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); // merger.mlp.2
|
||||
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_YASA2:
|
||||
{
|
||||
// reuse tensors already loaded by the common section
|
||||
// (TN_PATCH_EMBD and TN_PATCH_BIAS have the same tensor names)
|
||||
GGML_ASSERT(model.patch_embeddings_0 && "yasa2 requires v.patch_embd.weight");
|
||||
model.yasa_patch_w = model.patch_embeddings_0;
|
||||
model.yasa_patch_b = model.patch_bias;
|
||||
model.yasa_patch_ln_w = get_tensor(TN_YASA_PATCH_LN_W, false);
|
||||
model.yasa_patch_ln_b = get_tensor(TN_YASA_PATCH_LN_B, false);
|
||||
model.yasa_backbone_ln_w = get_tensor(TN_YASA_BACKBONE_LN_W, false);
|
||||
model.yasa_backbone_ln_b = get_tensor(TN_YASA_BACKBONE_LN_B, false);
|
||||
model.yasa_vision_pos_embed = get_tensor(TN_YASA_POS_EMBD, false);
|
||||
model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
|
||||
model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false);
|
||||
model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
|
||||
model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
|
||||
|
||||
model.yasa_stages.clear();
|
||||
for (int s = 0; ; ++s) {
|
||||
yasa2_stage stage;
|
||||
stage.down_ln_w = get_tensor(string_format(TN_YASA_STAGE_DOWN_LN, s, "weight"), false);
|
||||
stage.down_ln_b = get_tensor(string_format(TN_YASA_STAGE_DOWN_LN, s, "bias"), false);
|
||||
stage.down_conv_w = get_tensor(string_format(TN_YASA_STAGE_DOWN_CONV, s, "weight"), false);
|
||||
stage.down_conv_b = get_tensor(string_format(TN_YASA_STAGE_DOWN_CONV, s, "bias"), false);
|
||||
|
||||
for (int bi = 0; ; ++bi) {
|
||||
yasa2_block blk;
|
||||
blk.dw_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "dw", "weight"), false);
|
||||
if (!blk.dw_w) {
|
||||
break;
|
||||
}
|
||||
blk.dw_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "dw", "bias"), false);
|
||||
blk.ln_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "ln", "weight"), false);
|
||||
blk.ln_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "ln", "bias"), false);
|
||||
blk.pw1_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw1", "weight"), false);
|
||||
blk.pw1_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw1", "bias"), false);
|
||||
blk.grn_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "grn", "weight"), false);
|
||||
blk.grn_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "grn", "bias"), false);
|
||||
blk.pw2_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw2", "weight"), false);
|
||||
blk.pw2_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw2", "bias"), false);
|
||||
stage.blocks.push_back(blk);
|
||||
}
|
||||
|
||||
if (!stage.down_conv_w && stage.blocks.empty()) {
|
||||
break;
|
||||
}
|
||||
model.yasa_stages.push_back(std::move(stage));
|
||||
}
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GLM4V:
|
||||
{
|
||||
model.mm_fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
|
||||
|
|
@ -2249,6 +2324,7 @@ struct clip_model_loader {
|
|||
model.mm_eoi = get_tensor(TN_TOK_EOI);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_HUNYUANOCR:
|
||||
case PROJECTOR_TYPE_HUNYUANVL:
|
||||
{
|
||||
// proj.0 -> mm.0 (conv1), proj.2 -> mm.2 (conv2), mlp -> mm.model.fc (linear)
|
||||
model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
|
||||
|
|
@ -3062,6 +3138,19 @@ void setup_init_vision_shim_kcpp(struct clip_ctx * ctx_v) {
|
|||
img_end = "<|vision_end|>";
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_youtuvl>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_YASA2:
|
||||
{
|
||||
img_beg = "<image>";
|
||||
img_end = "</image>";
|
||||
// Currently only supports single-tile preprocessing: any input is downscaled
|
||||
// to one image_size x image_size tile (64 output tokens via 8x8 adaptive avg
|
||||
// pool).
|
||||
// However, the model itself supports llava-uhd multi-tile tiling for high-res
|
||||
// images. This will be implemented in a future PR (dispatch on has_pinpoints
|
||||
// - see LDP/COGVLM branch above) and emit image_grid_pinpoints in the conversion
|
||||
// script.
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_fixed_size>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
case PROJECTOR_TYPE_GEMMA3NV:
|
||||
{
|
||||
|
|
@ -3199,6 +3288,7 @@ void setup_init_vision_shim_kcpp(struct clip_ctx * ctx_v) {
|
|||
image_preproc = std::make_unique<mtmd_image_preprocessor_deepseekocr>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_HUNYUANOCR:
|
||||
case PROJECTOR_TYPE_HUNYUANVL:
|
||||
{
|
||||
// note: these use fullwidth | (U+FF5C) and ▁ (U+2581) to match the tokenizer vocabulary
|
||||
img_beg = "<|hy_place▁holder▁no▁100|>";
|
||||
|
|
@ -3287,6 +3377,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 *
|
|||
case PROJECTOR_TYPE_GLM4V:
|
||||
case PROJECTOR_TYPE_PADDLEOCR:
|
||||
case PROJECTOR_TYPE_HUNYUANOCR:
|
||||
case PROJECTOR_TYPE_HUNYUANVL:
|
||||
case PROJECTOR_TYPE_YOUTUVL:
|
||||
return (img->nx / params.patch_size) / 2;
|
||||
case PROJECTOR_TYPE_STEP3VL:
|
||||
|
|
@ -3306,6 +3397,7 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 *
|
|||
case PROJECTOR_TYPE_QWEN3VL:
|
||||
case PROJECTOR_TYPE_GLM4V:
|
||||
case PROJECTOR_TYPE_PADDLEOCR:
|
||||
case PROJECTOR_TYPE_HUNYUANVL:
|
||||
case PROJECTOR_TYPE_YOUTUVL:
|
||||
return (img->ny / params.patch_size) / 2;
|
||||
case PROJECTOR_TYPE_STEP3VL:
|
||||
|
|
@ -3333,6 +3425,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
|||
{
|
||||
// do nothing
|
||||
} break;
|
||||
case PROJECTOR_TYPE_YASA2:
|
||||
{
|
||||
n_patches = 64; // adaptive average pooling to 8x8 tokens
|
||||
} break;
|
||||
case PROJECTOR_TYPE_LDP:
|
||||
case PROJECTOR_TYPE_LDPV2:
|
||||
case PROJECTOR_TYPE_GLM_EDGE:
|
||||
|
|
@ -3493,6 +3589,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
|||
n_patches = h * (h + 1) + 1;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_HUNYUANOCR:
|
||||
case PROJECTOR_TYPE_HUNYUANVL:
|
||||
{
|
||||
int merge = ctx->model.hparams.n_merge;
|
||||
int ow = (img->nx / patch_size) / merge;
|
||||
|
|
@ -3953,9 +4050,74 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
case PROJECTOR_TYPE_PHI4:
|
||||
case PROJECTOR_TYPE_COGVLM:
|
||||
case PROJECTOR_TYPE_HUNYUANOCR:
|
||||
case PROJECTOR_TYPE_YASA2:
|
||||
{
|
||||
// do nothing
|
||||
} break;
|
||||
case PROJECTOR_TYPE_HUNYUANVL:
|
||||
{
|
||||
// Compute the HunyuanVL 2D position embedding on CPU (with the
|
||||
// custom sf=(target+0.1)/n_grid bilinear sampling that the
|
||||
// reference implementation uses) and upload it to the graph
|
||||
// input declared in clip_graph_hunyuanocr::build().
|
||||
GGML_ASSERT(model.position_embeddings != nullptr);
|
||||
ggml_tensor * src_t = model.position_embeddings;
|
||||
const int64_t n_embd = src_t->ne[0];
|
||||
const int64_t n_pos = src_t->ne[1]; // = n_grid * n_grid
|
||||
const int n_grid = (int)std::lround(std::sqrt((double)n_pos));
|
||||
GGML_ASSERT((int64_t)n_grid * n_grid == n_pos);
|
||||
const int out_w = pos_w; // pw
|
||||
const int out_h = pos_h; // ph
|
||||
|
||||
// Pull weight to host.
|
||||
std::vector<float> src(n_embd * n_pos);
|
||||
ggml_backend_tensor_get(src_t, src.data(), 0, ggml_nbytes(src_t));
|
||||
|
||||
// Output layout matches ggml_new_tensor_2d(F32, n_embd, out_h*out_w):
|
||||
// ne[0] = n_embd (fastest), ne[1] = out_h*out_w
|
||||
// dst[(y*out_w + x) * n_embd + c]
|
||||
std::vector<float> dst((size_t)n_embd * out_h * out_w);
|
||||
|
||||
const float sx = (float)(out_w + 0.1f) / (float)n_grid;
|
||||
const float sy = (float)(out_h + 0.1f) / (float)n_grid;
|
||||
|
||||
for (int y = 0; y < out_h; ++y) {
|
||||
// Match ggml_compute_forward_upscale_f32 pixel-center
|
||||
// convention (align_corners=False): src_y = (y+0.5)/sy - 0.5.
|
||||
const float fy = ((float)y + 0.5f) / sy - 0.5f;
|
||||
int y0 = (int)std::floor(fy);
|
||||
int y1 = y0 + 1;
|
||||
y0 = std::clamp(y0, 0, n_grid - 1);
|
||||
y1 = std::clamp(y1, 0, n_grid - 1);
|
||||
float wy1 = std::clamp(fy - (float)y0, 0.0f, 1.0f);
|
||||
const float wy0 = 1.0f - wy1;
|
||||
for (int x = 0; x < out_w; ++x) {
|
||||
const float fx = ((float)x + 0.5f) / sx - 0.5f;
|
||||
int x0 = (int)std::floor(fx);
|
||||
int x1 = x0 + 1;
|
||||
x0 = std::clamp(x0, 0, n_grid - 1);
|
||||
x1 = std::clamp(x1, 0, n_grid - 1);
|
||||
float wx1 = std::clamp(fx - (float)x0, 0.0f, 1.0f);
|
||||
const float wx0 = 1.0f - wx1;
|
||||
|
||||
const float w00 = wy0 * wx0;
|
||||
const float w01 = wy0 * wx1;
|
||||
const float w10 = wy1 * wx0;
|
||||
const float w11 = wy1 * wx1;
|
||||
|
||||
const float * s00 = &src[((size_t)y0 * n_grid + x0) * n_embd];
|
||||
const float * s01 = &src[((size_t)y0 * n_grid + x1) * n_embd];
|
||||
const float * s10 = &src[((size_t)y1 * n_grid + x0) * n_embd];
|
||||
const float * s11 = &src[((size_t)y1 * n_grid + x1) * n_embd];
|
||||
float * d = &dst[((size_t)y * out_w + x) * n_embd];
|
||||
for (int c = 0; c < n_embd; ++c) {
|
||||
d[c] = w00 * s00[c] + w01 * s01[c] + w10 * s10[c] + w11 * s11[c];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
set_input_f32("hunyuanvl_pos_embd", dst);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_LLAMA4:
|
||||
{
|
||||
// set the 2D positions
|
||||
|
|
@ -4376,8 +4538,10 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
|||
case PROJECTOR_TYPE_KIMIVL:
|
||||
case PROJECTOR_TYPE_PADDLEOCR:
|
||||
case PROJECTOR_TYPE_KIMIK25:
|
||||
case PROJECTOR_TYPE_YASA2:
|
||||
return ctx->model.mm_2_w->ne[1];
|
||||
case PROJECTOR_TYPE_HUNYUANOCR:
|
||||
case PROJECTOR_TYPE_HUNYUANVL:
|
||||
return ctx->model.mm_model_proj->ne[1];
|
||||
case PROJECTOR_TYPE_COGVLM:
|
||||
return ctx->model.mm_4h_to_h_w->ne[1];
|
||||
|
|
|
|||
|
|
@ -5,7 +5,21 @@ ggml_cgraph * clip_graph_hunyuanocr::build() {
|
|||
const int pw = n_patches_x;
|
||||
const int ph = n_patches_y;
|
||||
|
||||
ggml_tensor * pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BILINEAR);
|
||||
// Position embedding interpolation.
|
||||
// HunyuanVL needs scale factors sf=(target+0.1)/n_grid, which the standard
|
||||
// ggml_interpolate cannot express. To avoid adding a new ggml op, the
|
||||
// resize is computed on CPU in clip_image_batch_encode and uploaded here
|
||||
// as a graph input (named "hunyuanvl_pos_embd").
|
||||
// HunyuanOCR uses the same square layout and the standard ratio-based
|
||||
// interpolation provided by resize_position_embeddings().
|
||||
ggml_tensor * pos_embd = nullptr;
|
||||
if (proj_type == PROJECTOR_TYPE_HUNYUANVL && model.position_embeddings) {
|
||||
pos_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ph * pw);
|
||||
ggml_set_name(pos_embd, "hunyuanvl_pos_embd");
|
||||
ggml_set_input(pos_embd);
|
||||
} else {
|
||||
pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BILINEAR);
|
||||
}
|
||||
|
||||
ggml_tensor * inp = build_inp();
|
||||
ggml_tensor * cur = build_vit(inp, n_patches, NORM_TYPE_NORMAL, hparams.ffn_op, pos_embd, nullptr);
|
||||
|
|
|
|||
|
|
@ -43,6 +43,14 @@ struct clip_graph_youtuvl : clip_graph {
|
|||
ggml_cgraph * build() override;
|
||||
};
|
||||
|
||||
// Graph builder for the yasa2 projector type; forwards construction to
// clip_graph and overrides build().
struct clip_graph_yasa2 : clip_graph {
|
||||
clip_graph_yasa2(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
||||
ggml_cgraph * build() override;
|
||||
|
||||
// Layer norm across the channel dimension with optional per-channel weight w
// and bias b; eps defaults to 1e-6.
ggml_tensor * layer_norm_channels(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b, float eps = 1e-6f);
|
||||
// GRN step with per-channel weight w and bias b.
ggml_tensor * convnext_grn(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b);
|
||||
};
|
||||
|
||||
struct clip_graph_minicpmv : clip_graph {
|
||||
clip_graph_minicpmv(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
||||
ggml_cgraph * build() override;
|
||||
|
|
|
|||
191
tools/mtmd/models/yasa2.cpp
Normal file
191
tools/mtmd/models/yasa2.cpp
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
// ABOUTME: Yasa2 vision encoder graph builder for ConvNeXt-based architecture.
|
||||
// ABOUTME: Implements patch embedding, ConvNeXt stages with GRN, and adaptive pooling.
|
||||
|
||||
#include "models.h"
|
||||
|
||||
// Adds a per-channel bias `b_c` to `x_whcb` (laid out [W,H,C,B] by the callers in this file).
// A null bias is treated as "no bias" and the input tensor is returned unchanged.
static ggml_tensor * add_channel_bias(ggml_context * ctx0, ggml_tensor * x_whcb, ggml_tensor * b_c) {
    if (b_c == nullptr) {
        return x_whcb;
    }
    // View the 1-D bias as [1,1,C,1] so that it lines up with the channel dim of the input.
    const int64_t n_channels = b_c->ne[0];
    return ggml_add(ctx0, x_whcb, ggml_reshape_4d(ctx0, b_c, 1, 1, n_channels, 1));
}
|
||||
|
||||
// Scales `x_whcb` (laid out [W,H,C,B] by the callers in this file) by a per-channel weight `w_c`.
// A null weight is treated as "no scaling" and the input tensor is returned unchanged.
static ggml_tensor * mul_channel_weight(ggml_context * ctx0, ggml_tensor * x_whcb, ggml_tensor * w_c) {
    if (w_c == nullptr) {
        return x_whcb;
    }
    // View the 1-D weight as [1,1,C,1] so that it lines up with the channel dim of the input.
    const int64_t n_channels = w_c->ne[0];
    return ggml_mul(ctx0, x_whcb, ggml_reshape_4d(ctx0, w_c, 1, 1, n_channels, 1));
}
|
||||
|
||||
ggml_tensor * clip_graph_yasa2::layer_norm_channels(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b, float eps) {
    // Channel-wise LayerNorm for a [W,H,C,B] activation, modelled on HF ConvNextLayerNorm(channels_first):
    //   mean = mean_c(x), var = mean_c((x - mean)^2), y = (x - mean) / sqrt(var), then y * w + b.
    // Note that the variance is clamped from below at `eps` (rather than having eps added) before the sqrt.

    // [W,H,C,B] -> [C,H,W,B]: put channels in dim 0, which is where ggml_mean is applied below.
    ggml_tensor * ch_first = ggml_cont(ctx0, ggml_permute(ctx0, inp, 2, 1, 0, 3));

    ggml_tensor * mean     = ggml_mean(ctx0, ch_first);       // [1,H,W,B]
    ggml_tensor * centered = ggml_sub(ctx0, ch_first, mean);  // [C,H,W,B]

    ggml_tensor * var = ggml_mean(ctx0, ggml_mul(ctx0, centered, centered)); // [1,H,W,B]
    var = ggml_clamp(ctx0, var, eps, 1e30f); // lower bound avoids a division by zero (e.g. during no-alloc warmup)
    ggml_tensor * stddev = ggml_sqrt(ctx0, var);              // [1,H,W,B]

    ggml_tensor * normed = ggml_div(ctx0, centered, stddev);  // [C,H,W,B]

    // Back to [W,H,C,B] before applying the optional affine parameters.
    normed = ggml_cont(ctx0, ggml_permute(ctx0, normed, 2, 1, 0, 3));
    normed = mul_channel_weight(ctx0, normed, w);
    return add_channel_bias(ctx0, normed, b);
}
|
||||
|
||||
ggml_tensor * clip_graph_yasa2::convnext_grn(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b) {
    // ConvNeXtV2 Global Response Normalization on a [W,H,C,B] activation:
    //   Gx  = L2 norm of x over the spatial dims (W,H), one value per channel
    //   Nx  = Gx / mean_c(Gx)
    //   out = x + (w * (x * Nx) + b)
    // mean_c(Gx) is clamped from below at 1e-6 instead of having an epsilon added.
    const int64_t n_spatial  = inp->ne[0] * inp->ne[1];
    const int64_t n_channels = inp->ne[2];
    const int64_t n_batch    = inp->ne[3];

    // Per-channel L2 norm: square, flatten (W,H) into one row, sum the row, take the root.
    ggml_tensor * energy = ggml_mul(ctx0, inp, inp);
    energy = ggml_reshape_4d(ctx0, energy, n_spatial, n_channels, 1, n_batch);          // [WH,C,1,B]
    ggml_tensor * gx = ggml_sqrt(ctx0, ggml_sum_rows(ctx0, energy));                    // [1,C,1,B]

    // Mean of the norms across channels (channels moved to dim 0 for ggml_mean).
    ggml_tensor * gx_by_channel = ggml_cont(ctx0, ggml_permute(ctx0, gx, 1, 0, 2, 3));  // [C,1,1,B]
    ggml_tensor * gx_avg = ggml_mean(ctx0, gx_by_channel);                              // [1,1,1,B]
    gx_avg = ggml_clamp(ctx0, gx_avg, 1e-6f, 1e30f);                                    // keeps the division safe

    ggml_tensor * nx = ggml_div(ctx0, gx, gx_avg);                                      // [1,C,1,B]
    nx = ggml_cont(ctx0, ggml_permute(ctx0, nx, 0, 2, 1, 3));                           // [1,1,C,B]

    // Affine-transformed response plus the residual input.
    ggml_tensor * scaled = ggml_mul(ctx0, inp, nx);
    scaled = mul_channel_weight(ctx0, scaled, w);
    scaled = add_channel_bias(ctx0, scaled, b);
    return ggml_add(ctx0, inp, scaled);
}
|
||||
|
||||
ggml_cgraph * clip_graph_yasa2::build() {
    // Builds the Yasa2 vision graph:
    //   patch-embedding conv -> ConvNeXt stages -> optional position embedding
    //   -> average pooling -> two-layer GELU projector (mm_0 / mm_2).

    // [W,H,C,B] -> contiguous [C,T,B] with T = W*H
    const auto spatial_to_tokens = [&](ggml_tensor * t) -> ggml_tensor * {
        ggml_tensor * flat = ggml_reshape_3d(ctx0, t, t->ne[0] * t->ne[1], t->ne[2], t->ne[3]);
        return ggml_cont(ctx0, ggml_permute(ctx0, flat, 1, 0, 2, 3));
    };

    // [C,T,B] -> contiguous [W,H,C,B]
    const auto tokens_to_spatial = [&](ggml_tensor * t, int64_t width, int64_t height, int64_t batch) -> ggml_tensor * {
        ggml_tensor * transposed = ggml_cont(ctx0, ggml_permute(ctx0, t, 1, 0, 2, 3));
        return ggml_reshape_4d(ctx0, transposed, width, height, t->ne[0], batch);
    };

    // Linear layer over the channel dim of [C,T,B] tokens, with an optional bias.
    const auto linear = [&](ggml_tensor * t, ggml_tensor * weight, ggml_tensor * bias) -> ggml_tensor * {
        ggml_tensor * out = ggml_mul_mat(ctx0, weight, t);
        if (bias) {
            out = ggml_add(ctx0, out, ggml_reshape_3d(ctx0, bias, bias->ne[0], 1, 1));
        }
        return out;
    };

    ggml_tensor * cur = build_inp_raw();

    // Patch embedding: strided conv (stride = patch_size), bias, then channel LayerNorm.
    cur = ggml_conv_2d(ctx0, model.yasa_patch_w, cur, patch_size, patch_size, 0, 0, 1, 1);
    cur = add_channel_bias(ctx0, cur, model.yasa_patch_b);
    ggml_set_name(cur, "yasa2_patch_conv_out");
    cb(cur, "yasa2_patch_conv_out", -1);
    cur = layer_norm_channels(cur, model.yasa_patch_ln_w, model.yasa_patch_ln_b, eps);
    ggml_set_name(cur, "yasa2_patch_ln_out");
    cb(cur, "yasa2_patch_ln_out", -1);

    // ConvNeXt stages
    size_t stage_idx = 0;
    for (const auto & stage : model.yasa_stages) {
        // Optional downsampling in front of the stage: LayerNorm followed by a stride-2 conv.
        if (stage.down_conv_w) {
            cur = layer_norm_channels(cur, stage.down_ln_w, stage.down_ln_b, eps);
            cur = ggml_conv_2d(ctx0, stage.down_conv_w, cur, 2, 2, 0, 0, 1, 1);
            cur = add_channel_bias(ctx0, cur, stage.down_conv_b);
            ggml_format_name(cur, "yasa2_stage%zu_down_out", stage_idx);
        }

        size_t block_idx = 0;
        for (const auto & blk : stage.blocks) {
            ggml_tensor * residual = cur;

            // Depthwise conv (stride 1, padding 3) + bias + channel LayerNorm.
            ggml_tensor * x = ggml_conv_2d_dw(ctx0, blk.dw_w, cur, 1, 1, 3, 3, 1, 1);
            x = add_channel_bias(ctx0, x, blk.dw_b);
            x = layer_norm_channels(x, blk.ln_w, blk.ln_b, eps);

            const int64_t width  = x->ne[0];
            const int64_t height = x->ne[1];
            const int64_t batch  = x->ne[3];

            // pwconv1 is a Linear layer over channels in HF, so it is applied as a matmul on tokens.
            ggml_tensor * tok = linear(spatial_to_tokens(x), blk.pw1_w, blk.pw1_b); // [4C,T,B]
            x = tokens_to_spatial(tok, width, height, batch);                       // [W,H,4C,B]
            x = ggml_gelu_erf(ctx0, x);
            x = convnext_grn(x, blk.grn_w, blk.grn_b);

            // pwconv2: project back down, again as a matmul on tokens.
            tok = linear(spatial_to_tokens(x), blk.pw2_w, blk.pw2_b);               // [C,T,B]
            x = tokens_to_spatial(tok, width, height, batch);                       // [W,H,C,B]

            cur = ggml_add(ctx0, residual, x);
            ggml_format_name(cur, "yasa2_stage%zu_blk%zu_out", stage_idx, block_idx);
            ++block_idx;
        }
        ++stage_idx;
    }

    // The HF path adds vision position embeddings BEFORE the adaptive pooling.
    const int64_t pre_w = cur->ne[0];
    const int64_t pre_h = cur->ne[1];
    ggml_tensor * tokens_pre = spatial_to_tokens(cur); // [C,T,B]
    // The embedding is only applied when its token count matches the current feature map.
    if (model.yasa_vision_pos_embed && tokens_pre->ne[1] == model.yasa_vision_pos_embed->ne[1]) {
        const int64_t n_ch     = model.yasa_vision_pos_embed->ne[0];
        const int64_t n_tokens = model.yasa_vision_pos_embed->ne[1];
        ggml_tensor * pos = ggml_reshape_3d(ctx0, model.yasa_vision_pos_embed, (int) n_ch, (int) n_tokens, 1);
        tokens_pre = ggml_add(ctx0, tokens_pre, pos);
    }
    cur = tokens_to_spatial(tokens_pre, pre_w, pre_h, tokens_pre->ne[2]); // [W,H,C,B]

    // AdaptiveAvgPool2d targets 8x8 for real inputs; warmup may feed tiny images, so the
    // target is capped by the actual feature-map size and the kernel is never below 1.
    const int in_w     = (int) cur->ne[0];
    const int in_h     = (int) cur->ne[1];
    const int pooled_w = std::min(8, in_w);
    const int pooled_h = std::min(8, in_h);
    const int kw       = std::max(1, in_w / pooled_w);
    const int kh       = std::max(1, in_h / pooled_h);
    cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, kw, kh, kw, kh, 0, 0);

    // [W,H,C,B] -> [C,T,B]
    ggml_tensor * tokens = spatial_to_tokens(cur);
    cb(tokens, "yasa2_tokens", -1);

    // Projector: mm_0 -> GELU(erf) -> mm_2; both weight tensors are mandatory.
    GGML_ASSERT(model.mm_0_w && model.mm_2_w);
    ggml_tensor * embeddings = build_ffn(
        tokens,
        model.mm_0_w, model.mm_0_b,
        nullptr, nullptr,
        model.mm_2_w, model.mm_2_b,
        FFN_GELU_ERF,
        -1);
    cb(embeddings, "yasa2_emb", -1);

    ggml_build_forward_expand(gf, embeddings);
    return gf;
}
|
||||
|
|
@ -35,15 +35,23 @@ struct mtmd_bitmap {
|
|||
|
||||
// position indexing for decoder model
|
||||
enum mtmd_pos_type {
|
||||
MTMD_POS_TYPE_NORMAL, // number of positions equals to number of tokens
|
||||
MTMD_POS_TYPE_MROPE, // qwen-vl mrope style, each image takes max(t,h,w) position indexes
|
||||
MTMD_POS_TYPE_NORMAL, // number of positions equals to number of tokens
|
||||
MTMD_POS_TYPE_MROPE, // qwen-vl mrope style, each image takes max(t,h,w) position indexes
|
||||
MTMD_POS_TYPE_HUNYUANVL, // HunyuanVL mrope + BOI/EOI/newline layout with XD-RoPE dim-3
|
||||
};
|
||||
|
||||
struct mtmd_image_tokens {
|
||||
uint32_t nx; // number of tokens in x direction
|
||||
uint32_t ny; // number of tokens in y direction
|
||||
mtmd_pos_type pos = MTMD_POS_TYPE_NORMAL;
|
||||
uint32_t n_tokens() const { return nx * ny; }
|
||||
uint32_t image_idx = 0; // 0-based position of this image among image chunks in the prompt(used by pos == MTMD_POS_TYPE_HUNYUANVL)
|
||||
uint32_t n_tokens() const {
|
||||
if (pos == MTMD_POS_TYPE_HUNYUANVL) {
|
||||
// [BOI] [row0 tokens + newline] ... [row(ny-1) tokens + newline] [EOI]
|
||||
return (nx + 1) * ny + 2;
|
||||
}
|
||||
return nx * ny;
|
||||
}
|
||||
clip_image_f32_batch batch_f32; // preprocessed image patches
|
||||
std::string id; // optional user-defined ID, useful for KV cache tracking
|
||||
|
||||
|
|
@ -52,6 +60,7 @@ struct mtmd_image_tokens {
|
|||
nx,
|
||||
ny,
|
||||
pos,
|
||||
image_idx,
|
||||
batch_f32.clone(),
|
||||
id
|
||||
};
|
||||
|
|
@ -186,6 +195,7 @@ struct mtmd_context {
|
|||
|
||||
auto decoder_rope_type = llama_model_rope_type(text_model);
|
||||
switch (decoder_rope_type) {
|
||||
case LLAMA_ROPE_TYPE_NONE:
|
||||
case LLAMA_ROPE_TYPE_NORM:
|
||||
case LLAMA_ROPE_TYPE_NEOX:
|
||||
{
|
||||
|
|
@ -316,6 +326,19 @@ struct mtmd_context {
|
|||
img_end = "<|vision_end|>";
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_youtuvl>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_YASA2:
|
||||
{
|
||||
img_beg = "<image>";
|
||||
img_end = "</image>";
|
||||
// Currently only supports single-tile preprocessing: any input is downscaled
|
||||
// to one image_size x image_size tile (64 output tokens via 8x8 adaptive avg
|
||||
// pool).
|
||||
// However, the model itself supports llava-uhd multi-tile tiling for high-res
|
||||
// images. This will be implemented in a future PR (dispatch on has_pinpoints
|
||||
// - see LDP/COGVLM branch above) and emit image_grid_pinpoints in the conversion
|
||||
// script.
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_fixed_size>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
case PROJECTOR_TYPE_GEMMA3NV:
|
||||
{
|
||||
|
|
@ -453,6 +476,7 @@ struct mtmd_context {
|
|||
image_preproc = std::make_unique<mtmd_image_preprocessor_deepseekocr>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_HUNYUANOCR:
|
||||
case PROJECTOR_TYPE_HUNYUANVL:
|
||||
{
|
||||
// note: these use fullwidth | (U+FF5C) and ▁ (U+2581) to match the tokenizer vocabulary
|
||||
img_beg = "<|hy_place▁holder▁no▁100|>";
|
||||
|
|
@ -598,6 +622,7 @@ struct mtmd_tokenizer {
|
|||
const llama_vocab * vocab;
|
||||
|
||||
mtmd_input_chunks cur;
|
||||
uint32_t n_images_added = 0; // 0-based index assigned to the next image chunk
|
||||
|
||||
mtmd_tokenizer(mtmd_context * ctx,
|
||||
const mtmd_input_text * text,
|
||||
|
|
@ -806,6 +831,14 @@ struct mtmd_tokenizer {
|
|||
image_tokens->ny = 1;
|
||||
}
|
||||
image_tokens->pos = ctx->pos_type;
|
||||
// HunyuanVL wraps the image grid with BOI/EOI and adds one newline per row,
|
||||
// and uses XD-RoPE (dim-3 = image index). Override the position type so that
|
||||
// n_tokens() and mtmd_image_tokens_get_decoder_pos pick the HunyuanVL layout.
|
||||
if (ctx->proj_type_v() == PROJECTOR_TYPE_HUNYUANVL) {
|
||||
image_tokens->pos = MTMD_POS_TYPE_HUNYUANVL;
|
||||
image_tokens->image_idx = n_images_added;
|
||||
GGML_ASSERT(n_tokens == (size_t)image_tokens->n_tokens());
|
||||
}
|
||||
image_tokens->batch_f32 = std::move(batch_f32);
|
||||
image_tokens->id = bitmap->id; // optional
|
||||
|
||||
|
|
@ -826,6 +859,9 @@ struct mtmd_tokenizer {
|
|||
add_text(ctx->img_end, true); // add image end token
|
||||
}
|
||||
|
||||
// advance image-chunk counter so the next image gets the next XD-RoPE dim-3 slot
|
||||
n_images_added++;
|
||||
|
||||
} else {
|
||||
// handle audio
|
||||
|
||||
|
|
@ -1273,6 +1309,38 @@ mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * ima
|
|||
pos.y = pos_0 + i;
|
||||
pos.z = pos_0 + i;
|
||||
} break;
|
||||
case MTMD_POS_TYPE_HUNYUANVL:
|
||||
{
|
||||
// HunyuanVL layout: [BOI] [row0 tokens + newline] ... [row(ny-1) tokens + newline] [EOI]
|
||||
// Total = 1 + ny*(nx+1) + 1. BOI and EOI use sequential positions in every dim;
|
||||
// content and row-newline tokens use (row, col) with XD-RoPE dim-3 = image_idx.
|
||||
const uint32_t nx = image_tokens->nx;
|
||||
const uint32_t n_total = image_tokens->n_tokens();
|
||||
if (i == 0) {
|
||||
// BOI
|
||||
pos.t = pos_0 + i;
|
||||
pos.x = pos_0 + i;
|
||||
pos.y = pos_0 + i;
|
||||
pos.z = pos_0 + i;
|
||||
} else if (i == n_total - 1) {
|
||||
// EOI
|
||||
pos.t = pos_0 + i;
|
||||
pos.x = pos_0 + i;
|
||||
pos.y = pos_0 + i;
|
||||
pos.z = pos_0 + i;
|
||||
} else {
|
||||
// content token at (row, col), or the trailing newline of a row (col == nx)
|
||||
// section 0 = sequential, section 1 = w(col), section 2 = h(row), section 3 = image_count.
|
||||
// set_position_mrope_2d writes .y -> section 1 and .x -> section 2
|
||||
const uint32_t offset = (uint32_t)i - 1;
|
||||
const uint32_t row = offset / (nx + 1);
|
||||
const uint32_t col = offset % (nx + 1);
|
||||
pos.t = pos_0 + i;
|
||||
pos.x = row;
|
||||
pos.y = col;
|
||||
pos.z = image_tokens->image_idx;
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("invalid position type");
|
||||
}
|
||||
|
|
@ -1289,6 +1357,10 @@ llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
|
|||
return std::max(image_tokens->nx, image_tokens->ny);
|
||||
case MTMD_POS_TYPE_NORMAL:
|
||||
return image_tokens->n_tokens();
|
||||
case MTMD_POS_TYPE_HUNYUANVL:
|
||||
// HunyuanVL: the sequential (dim-0) position advances by the full token count
|
||||
// (includes BOI/EOI and row newline tokens), not by max(nx, ny)
|
||||
return image_tokens->n_tokens();
|
||||
default:
|
||||
GGML_ABORT("invalid position type");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -91,6 +91,7 @@ add_test_vision "ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0"
|
|||
add_test_vision "ggml-org/DeepSeek-OCR-GGUF:Q8_0" -p "Free OCR." --chat-template deepseek-ocr
|
||||
add_test_vision "ggml-org/dots.ocr-GGUF:Q8_0" -p "OCR"
|
||||
add_test_vision "ggml-org/HunyuanOCR-GGUF:Q8_0" -p "OCR"
|
||||
add_test_vision "ggml-org/HunyuanVL-4B-GGUF:Q8_0"
|
||||
add_test_vision "ggml-org/gemma-4-E2B-it-GGUF:Q8_0" --jinja
|
||||
|
||||
add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"
|
||||
|
|
|
|||
588
tools/server/server-chat.cpp
Normal file
588
tools/server/server-chat.cpp
Normal file
|
|
@ -0,0 +1,588 @@
|
|||
#include "server-chat.h"
|
||||
#include "server-common.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
// Converts an OpenAI "Responses"-style request body into a chat-completions request body.
//
// - 'input' is required and must be a string or an array of input items.
// - 'instructions' becomes a leading system message.
// - Input items (messages, function calls, function call outputs, reasoning) are mapped to
//   chat messages; consecutive assistant-side items are merged into one assistant message.
// - 'tools' entries are wrapped as {"type":"function","function":{...}} ('strict' defaults to true).
// - 'max_output_tokens' is renamed to 'max_tokens'.
// All other fields of the request are copied through unchanged.
//
// Throws std::invalid_argument when the request cannot be mapped.
json server_chat_convert_responses_to_chatcmpl(const json & response_body) {
    if (!response_body.contains("input")) {
        throw std::invalid_argument("'input' is required");
    }
    const std::string prev_response_id = json_value(response_body, "previous_response_id", std::string{});
    if (!prev_response_id.empty()) {
        throw std::invalid_argument("llama.cpp does not support 'previous_response_id'.");
    }

    // Start from a copy of the request; the translated fields are replaced below.
    json chatcmpl_body = response_body;
    chatcmpl_body.erase("input");

    const json & input_value = response_body.at("input");
    std::vector<json> chatcmpl_messages;

    if (response_body.contains("instructions")) {
        chatcmpl_messages.push_back({
            {"role",    "system"},
            {"content", json_value(response_body, "instructions", std::string())},
        });
        chatcmpl_body.erase("instructions");
    }

    const auto has_array = [](const json & j, const char * key) -> bool {
        return j.contains(key) && j.at(key).is_array();
    };
    const auto has_string = [](const json & j, const char * key) -> bool {
        return j.contains(key) && j.at(key).is_string();
    };
    // Value of a string field, or "" when the field is missing or not a string.
    const auto string_or_empty = [&](const json & j, const char * key) -> std::string {
        return has_string(j, key) ? j.at(key).get<std::string>() : std::string();
    };

    if (input_value.is_string()) {
        // #responses_create-input-text_input
        chatcmpl_messages.push_back({
            {"role",    "user"},
            {"content", input_value},
        });
    } else if (input_value.is_array()) {
        // #responses_create-input-input_item_list
        for (json item : input_value) { // copied on purpose: the item is rewritten in place
            // Assistant-side items are folded into the preceding assistant message, if there is one.
            const bool prev_is_assistant =
                !chatcmpl_messages.empty() && chatcmpl_messages.back().value("role", "") == "assistant";

            if (has_string(item, "content")) {
                // #responses_create-input-input_item_list-input_message-content-text_input
                // Only "Input message" carries a plain string in item["content"]. Wrapping it into a
                // one-element content array lets the branches below handle a single representation.
                item["content"] = json::array({
                    json {
                        {"text", item.at("content")},
                        {"type", "input_text"}
                    }
                });
            }

            const std::string item_role = string_or_empty(item, "role");
            const std::string item_type = string_or_empty(item, "type");

            // #responses_create-input-input_item_list-item-input_message
            if (has_array(item, "content") &&
                    (item_role == "user" || item_role == "system" || item_role == "developer")) {
                std::vector<json> parts;
                for (const json & input_item : item.at("content")) {
                    const std::string part_type = json_value(input_item, "type", std::string());

                    if (part_type == "input_text") {
                        if (!input_item.contains("text")) {
                            throw std::invalid_argument("'Input text' requires 'text'");
                        }
                        parts.push_back({
                            {"text", input_item.at("text")},
                            {"type", "text"},
                        });
                    } else if (part_type == "input_image") {
                        // `detail` is marked as required, but it has a default ("auto") and may be omitted.
                        if (!input_item.contains("image_url")) {
                            throw std::invalid_argument("'image_url' is required");
                        }
                        parts.push_back({
                            {"image_url", json {
                                {"url", input_item.at("image_url")}
                            }},
                            {"type", "image_url"},
                        });
                    } else if (part_type == "input_file") {
                        throw std::invalid_argument("'input_file' is not supported by llamacpp at this moment");
                    } else {
                        throw std::invalid_argument("'type' must be one of 'input_text', 'input_image', or 'input_file'");
                    }
                }

                // Drop Responses-only fields and keep everything else the caller supplied.
                item.erase("type");
                item.erase("status");
                item["content"] = parts;
                chatcmpl_messages.push_back(item);
                continue;
            }

            // #responses_create-input-input_item_list-item-output_message
            if (item_role == "assistant" && item_type == "message") {
                auto parts = json::array();

                if (item.contains("content") && item.at("content").is_string()) {
                    // String content: convert to a single text part.
                    parts.push_back({
                        {"text", item.at("content")},
                        {"type", "text"},
                    });
                } else if (has_array(item, "content")) {
                    for (const auto & part : item.at("content")) {
                        const std::string part_type = json_value(part, "type", std::string());
                        if (part_type == "output_text" || part_type == "input_text") {
                            // input_text is accepted too: string content was wrapped as input_text above.
                            if (!has_string(part, "text")) {
                                throw std::invalid_argument("'Output text' requires 'text'");
                            }
                            parts.push_back({
                                {"text", part.at("text")},
                                {"type", "text"},
                            });
                        } else if (part_type == "refusal") {
                            if (!has_string(part, "refusal")) {
                                throw std::invalid_argument("'Refusal' requires 'refusal'");
                            }
                            parts.push_back({
                                {"refusal", part.at("refusal")},
                                {"type", "refusal"},
                            });
                        } else {
                            throw std::invalid_argument("'type' must be one of 'output_text' or 'refusal'");
                        }
                    }
                }

                if (prev_is_assistant) {
                    auto & prev_msg = chatcmpl_messages.back();
                    if (!has_array(prev_msg, "content")) {
                        prev_msg["content"] = json::array();
                    }
                    auto & prev_content = prev_msg["content"];
                    prev_content.insert(prev_content.end(), parts.begin(), parts.end());
                } else {
                    item.erase("status");
                    item.erase("type");
                    item["content"] = parts;
                    chatcmpl_messages.push_back(item);
                }
                continue;
            }

            // #responses_create-input-input_item_list-item-function_tool_call
            if (has_string(item, "arguments") &&
                    has_string(item, "call_id") &&
                    has_string(item, "name") &&
                    item_type == "function_call") {
                json tool_call = {
                    {"function", json {
                        {"arguments", item.at("arguments")},
                        {"name",      item.at("name")},
                    }},
                    {"id",   item.at("call_id")},
                    {"type", "function"},
                };

                if (prev_is_assistant) {
                    auto & prev_msg = chatcmpl_messages.back();
                    if (!has_array(prev_msg, "tool_calls")) {
                        prev_msg["tool_calls"] = json::array();
                    }
                    prev_msg["tool_calls"].push_back(tool_call);
                } else {
                    chatcmpl_messages.push_back(json {
                        {"role",       "assistant"},
                        {"tool_calls", json::array({tool_call})}
                    });
                }
                continue;
            }

            // #responses_create-input-input_item_list-item-function_tool_call_output
            if (has_string(item, "call_id") &&
                    (has_string(item, "output") || has_array(item, "output")) &&
                    item_type == "function_call_output") {
                if (item.at("output").is_string()) {
                    chatcmpl_messages.push_back(json {
                        {"content",      item.at("output")},
                        {"role",         "tool"},
                        {"tool_call_id", item.at("call_id")},
                    });
                } else {
                    // Array output: every element must be an 'input_text' part, re-tagged as 'text'.
                    json outputs = item.at("output");
                    for (json & output_part : outputs) {
                        if (!output_part.contains("type") || output_part.at("type") != "input_text") {
                            throw std::invalid_argument("Output of tool call should be 'Input text'");
                        }
                        output_part["type"] = "text";
                    }
                    chatcmpl_messages.push_back(json {
                        {"content",      outputs},
                        {"role",         "tool"},
                        {"tool_call_id", item.at("call_id")},
                    });
                }
                continue;
            }

            // #responses_create-input-input_item_list-item-reasoning
            if (has_array(item, "summary") && item_type == "reasoning") {
                if (!has_array(item, "content")) {
                    throw std::invalid_argument("item['content'] is not an array");
                }
                if (item.at("content").empty()) {
                    throw std::invalid_argument("item['content'] is empty");
                }
                if (!has_string(item.at("content")[0], "text")) {
                    throw std::invalid_argument("item['content']['text'] is not a string");
                }

                // Only the first content entry is used as the reasoning text.
                const json & reasoning_text = item.at("content")[0].at("text");
                if (prev_is_assistant) {
                    chatcmpl_messages.back()["reasoning_content"] = reasoning_text;
                } else {
                    chatcmpl_messages.push_back(json {
                        {"role",              "assistant"},
                        {"content",           json::array()},
                        {"reasoning_content", reasoning_text},
                    });
                }
                continue;
            }

            throw std::invalid_argument("Cannot determine type of 'item'");
        }
    } else {
        throw std::invalid_argument("'input' must be a string or array of objects");
    }

    chatcmpl_body["messages"] = chatcmpl_messages;

    if (response_body.contains("tools")) {
        if (!response_body.at("tools").is_array()) {
            throw std::invalid_argument("'tools' must be an array of objects");
        }
        std::vector<json> chatcmpl_tools;
        for (json resp_tool : response_body.at("tools")) { // copied: 'type' is stripped, 'strict' may be added
            if (json_value(resp_tool, "type", std::string()) != "function") {
                throw std::invalid_argument("'type' of tool must be 'function'");
            }
            resp_tool.erase("type");
            if (!resp_tool.contains("strict")) {
                resp_tool["strict"] = true;
            }

            json chatcmpl_tool;
            chatcmpl_tool["type"]     = "function";
            chatcmpl_tool["function"] = resp_tool;
            chatcmpl_tools.push_back(chatcmpl_tool);
        }
        chatcmpl_body.erase("tools");
        chatcmpl_body["tools"] = chatcmpl_tools;
    }

    if (response_body.contains("max_output_tokens")) {
        chatcmpl_body.erase("max_output_tokens");
        chatcmpl_body["max_tokens"] = response_body["max_output_tokens"];
    }

    return chatcmpl_body;
}
|
||||
|
||||
// Converts an Anthropic Messages-style request body into an OpenAI chat-completions body.
//
// Mapping performed:
// - 'system' (string, or array of text blocks that are concatenated) -> leading system message
// - 'messages': string / non-array content is passed through; array content is split into
//   text and image parts ('content'), 'tool_use' blocks ('tool_calls'), 'thinking' blocks
//   ('reasoning_content') and 'tool_result' blocks (separate role=tool messages that are
//   appended after the message they came from)
// - 'tools'          -> OpenAI function tools ('input_schema' becomes 'parameters')
// - 'tool_choice'    -> "auto" / "required" / "none"
// - 'stop_sequences' -> 'stop'
// - 'max_tokens'     -> copied, defaulting to 4096 when absent
// - 'temperature', 'top_p', 'top_k', 'stream' are copied when present
// - 'thinking' {type: "enabled"} -> 'thinking_budget_tokens' (default budget 10000)
// - 'metadata.user_id'           -> '__metadata_user_id'
//
// Throws std::runtime_error when 'messages' is missing.
json server_chat_convert_anthropic_to_oai(const json & body) {
    json oai_body;

    // Convert system prompt
    json oai_messages = json::array();
    auto system_param = json_value(body, "system", json());
    if (!system_param.is_null()) {
        std::string system_content;

        if (system_param.is_string()) {
            system_content = system_param.get<std::string>();
        } else if (system_param.is_array()) {
            // Only text blocks contribute; they are concatenated without a separator.
            for (const auto & block : system_param) {
                if (json_value(block, "type", std::string()) == "text") {
                    system_content += json_value(block, "text", std::string());
                }
            }
        }

        oai_messages.push_back({
            {"role", "system"},
            {"content", system_content}
        });
    }

    // Convert messages
    if (!body.contains("messages")) {
        // NOTE(review): other request-validation errors in this file use std::invalid_argument;
        // the exception type is kept as-is here because callers may depend on it.
        throw std::runtime_error("'messages' is required");
    }
    const json & messages = body.at("messages");
    if (messages.is_array()) {
        for (const auto & msg : messages) {
            std::string role = json_value(msg, "role", std::string());

            if (!msg.contains("content")) {
                // Assistant messages without content are dropped; anything else is passed through.
                if (role == "assistant") {
                    continue;
                }
                oai_messages.push_back(msg);
                continue;
            }

            const json & content = msg.at("content");

            // String content (and any non-array content) needs no conversion.
            if (content.is_string() || !content.is_array()) {
                oai_messages.push_back(msg);
                continue;
            }

            json tool_calls = json::array();
            json converted_content = json::array();
            json tool_results = json::array();
            std::string reasoning_content;

            for (const auto & block : content) {
                std::string type = json_value(block, "type", std::string());

                if (type == "text") {
                    converted_content.push_back(block);
                } else if (type == "thinking") {
                    reasoning_content += json_value(block, "thinking", std::string());
                } else if (type == "image") {
                    json source = json_value(block, "source", json::object());
                    std::string source_type = json_value(source, "type", std::string());

                    if (source_type == "base64") {
                        // Inline image data is re-encoded as a data: URL.
                        std::string media_type = json_value(source, "media_type", std::string("image/jpeg"));
                        std::string data = json_value(source, "data", std::string());
                        std::ostringstream ss;
                        ss << "data:" << media_type << ";base64," << data;

                        converted_content.push_back({
                            {"type", "image_url"},
                            {"image_url", {
                                {"url", ss.str()}
                            }}
                        });
                    } else if (source_type == "url") {
                        std::string url = json_value(source, "url", std::string());
                        converted_content.push_back({
                            {"type", "image_url"},
                            {"image_url", {
                                {"url", url}
                            }}
                        });
                    }
                    // Other source types are ignored.
                } else if (type == "tool_use") {
                    // OpenAI expects the arguments as a JSON-encoded string, hence dump().
                    tool_calls.push_back({
                        {"id", json_value(block, "id", std::string())},
                        {"type", "function"},
                        {"function", {
                            {"name", json_value(block, "name", std::string())},
                            {"arguments", json_value(block, "input", json::object()).dump()}
                        }}
                    });
                } else if (type == "tool_result") {
                    std::string tool_use_id = json_value(block, "tool_use_id", std::string());

                    // The result may be a plain string or an array of blocks; only text blocks are kept.
                    auto result_content = json_value(block, "content", json());
                    std::string result_text;
                    if (result_content.is_string()) {
                        result_text = result_content.get<std::string>();
                    } else if (result_content.is_array()) {
                        for (const auto & c : result_content) {
                            if (json_value(c, "type", std::string()) == "text") {
                                result_text += json_value(c, "text", std::string());
                            }
                        }
                    }

                    tool_results.push_back({
                        {"role", "tool"},
                        {"tool_call_id", tool_use_id},
                        {"content", result_text}
                    });
                }
                // Blocks of any other type are ignored.
            }

            const bool has_tool_calls = !tool_calls.empty();
            const bool has_reasoning  = !reasoning_content.empty();

            if (!converted_content.empty() || has_tool_calls || has_reasoning) {
                json new_msg = {{"role", role}};
                if (!converted_content.empty()) {
                    new_msg["content"] = converted_content;
                } else {
                    // Tool calls and/or reasoning only: emit an empty string as content.
                    new_msg["content"] = "";
                }
                if (has_tool_calls) {
                    new_msg["tool_calls"] = tool_calls;
                }
                if (has_reasoning) {
                    new_msg["reasoning_content"] = reasoning_content;
                }
                oai_messages.push_back(new_msg);
            }

            // Tool results follow the message they were extracted from.
            for (const auto & tool_msg : tool_results) {
                oai_messages.push_back(tool_msg);
            }
        }
    }

    oai_body["messages"] = oai_messages;

    // Convert tools
    if (body.contains("tools")) {
        const json & tools = body.at("tools");
        if (tools.is_array()) {
            json oai_tools = json::array();
            for (const auto & tool : tools) {
                oai_tools.push_back({
                    {"type", "function"},
                    {"function", {
                        {"name", json_value(tool, "name", std::string())},
                        {"description", json_value(tool, "description", std::string())},
                        {"parameters", tool.contains("input_schema") ? tool.at("input_schema") : json::object()}
                    }}
                });
            }
            oai_body["tools"] = oai_tools;
        }
    }

    // Convert tool_choice
    if (body.contains("tool_choice")) {
        const json & tc = body.at("tool_choice");
        if (tc.is_object()) {
            std::string type = json_value(tc, "type", std::string());
            if (type == "auto") {
                oai_body["tool_choice"] = "auto";
            } else if (type == "any" || type == "tool") {
                // Both forms force a tool call; the specific tool name of "tool" is not forwarded.
                oai_body["tool_choice"] = "required";
            } else if (type == "none") {
                // Anthropic's explicit "do not call tools" must not be dropped silently.
                oai_body["tool_choice"] = "none";
            }
        }
    }

    // Convert stop_sequences to stop
    if (body.contains("stop_sequences")) {
        oai_body["stop"] = body.at("stop_sequences");
    }

    // Handle max_tokens (required in Anthropic, but we're permissive)
    if (body.contains("max_tokens")) {
        oai_body["max_tokens"] = body.at("max_tokens");
    } else {
        oai_body["max_tokens"] = 4096;
    }

    // Pass through common params
    for (const auto & key : {"temperature", "top_p", "top_k", "stream"}) {
        if (body.contains(key)) {
            oai_body[key] = body.at(key);
        }
    }

    // Handle Anthropic-specific thinking param
    if (body.contains("thinking")) {
        json thinking = json_value(body, "thinking", json::object());
        std::string thinking_type = json_value(thinking, "type", std::string());
        if (thinking_type == "enabled") {
            int budget_tokens = json_value(thinking, "budget_tokens", 10000);
            oai_body["thinking_budget_tokens"] = budget_tokens;
        }
    }

    // Handle Anthropic-specific metadata param
    if (body.contains("metadata")) {
        json metadata = json_value(body, "metadata", json::object());
        std::string user_id = json_value(metadata, "user_id", std::string());
        if (!user_id.empty()) {
            oai_body["__metadata_user_id"] = user_id;
        }
    }

    return oai_body;
}
|
||||
|
||||
// Build the OpenAI-compatible "delta" object for a single chat message diff.
//
// Only the parts of the diff that carry data are emitted:
//   - "reasoning_content" when the reasoning delta is non-empty
//   - "content"           when the content delta is non-empty
//   - "tool_calls"        (one-element array) when the diff refers to a tool call
json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
    json out = json::object();

    const auto & reasoning = diff.reasoning_content_delta;
    if (!reasoning.empty()) {
        out["reasoning_content"] = reasoning;
    }

    const auto & text = diff.content_delta;
    if (!text.empty()) {
        out["content"] = text;
    }

    // npos marks "no tool call in this diff"
    if (diff.tool_call_index == std::string::npos) {
        return out;
    }

    const auto & call = diff.tool_call_delta;

    json entry;
    entry["index"] = diff.tool_call_index;

    // the id (and the fixed "function" type) is only sent when an id is present
    if (!call.id.empty()) {
        entry["id"]   = call.id;
        entry["type"] = "function";
    }

    const bool has_name = !call.name.empty();
    const bool has_args = !call.arguments.empty();
    if (has_name || has_args) {
        json fn = json::object();
        if (has_name) {
            fn["name"] = call.name;
        }
        if (has_args) {
            fn["arguments"] = call.arguments;
        }
        entry["function"] = fn;
    }

    out["tool_calls"] = json::array({ entry });
    return out;
}
|
||||
|
||||
// Convert an OpenAI transcriptions request (multipart form fields + uploaded
// files) into a Chat Completions request body.
//
//   inp_body  : form fields; every value arrives as a string
//   in_files  : uploaded files keyed by form field name; "file" is required
//   out_files : cleared, then filled with the single input file
//
// Throws std::invalid_argument when no "file" entry exists or when
// response_format is anything other than "json".
json convert_transcriptions_to_chatcmpl(
        const json & inp_body,
        const std::map<std::string, raw_buffer> & in_files,
        std::vector<raw_buffer> & out_files) {
    // TODO @ngxson : this function may need to be improved in the future

    // input file: must be present under the "file" key
    out_files.clear();
    const auto file_it = in_files.find("file");
    if (file_it == in_files.end()) {
        throw std::invalid_argument("No input file found for transcription");
    }
    out_files.push_back(file_it->second);

    // input fields
    std::string       user_prompt = json_value(inp_body, "prompt",          std::string());
    const std::string lang        = json_value(inp_body, "language",        std::string());
    const std::string format      = json_value(inp_body, "response_format", std::string("json"));
    if (format != "json") {
        throw std::invalid_argument("Only 'json' response_format is supported for transcription");
    }

    // assemble the prompt: default text, optional language hint, then the media marker
    if (user_prompt.empty()) {
        user_prompt = "Transcribe audio to text";
    }
    if (!lang.empty()) {
        user_prompt += string_format(" (language: %s)", lang.c_str());
    }
    user_prompt += get_media_marker();

    // start from a copy of all incoming fields and attach the single user message
    json result = inp_body;
    const json user_msg = {
        {"role", "user"},
        {"content", user_prompt},
    };
    result["messages"] = json::array({ user_msg });

    // form-data delivers everything as strings, so restore the proper JSON types
    const bool want_stream = json_value(inp_body, "stream", std::string("false")) == "true";
    result["stream"] = want_stream;

    if (inp_body.contains("max_tokens")) {
        const std::string raw = inp_body["max_tokens"].get<std::string>();
        result["max_tokens"] = std::stoul(raw);
    }

    if (inp_body.contains("temperature")) {
        const std::string raw = inp_body["temperature"].get<std::string>();
        result["temperature"] = std::stof(raw);
    }

    return result;
}
|
||||
24
tools/server/server-chat.h
Normal file
24
tools/server/server-chat.h
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
// Chat conversion functions for server (Responses API, Anthropic API, OAI streaming diffs)
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "chat.h"
|
||||
#include "server-common.h"
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
// Convert OpenAI Responses API format to OpenAI Chat Completions API format
|
||||
json server_chat_convert_responses_to_chatcmpl(const json & body);
|
||||
|
||||
// Convert Anthropic Messages API format to OpenAI Chat Completions API format
|
||||
json server_chat_convert_anthropic_to_oai(const json & body);
|
||||
|
||||
// Convert OpenAI transcriptions API format to OpenAI Chat Completions API format
|
||||
json convert_transcriptions_to_chatcmpl(
|
||||
const json & body,
|
||||
const std::map<std::string, raw_buffer> & in_files,
|
||||
std::vector<raw_buffer> & out_files);
|
||||
|
||||
// Convert a chat message diff to an OpenAI-compatible streaming "delta" JSON object
// (emits "reasoning_content", "content" and/or a one-element "tool_calls" array, as present in the diff)
json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|
||||
|
|
@ -1027,6 +1027,8 @@ json oaicompat_chat_params_parse(
|
|||
}
|
||||
}
|
||||
|
||||
auto caps = common_chat_templates_get_caps(opt.tmpls.get());
|
||||
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
|
||||
inputs.tools = common_chat_tools_parse_oaicompat(tools);
|
||||
|
|
@ -1034,7 +1036,7 @@ json oaicompat_chat_params_parse(
|
|||
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
|
||||
inputs.grammar = grammar;
|
||||
inputs.use_jinja = opt.use_jinja;
|
||||
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
|
||||
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", caps["supports_parallel_tool_calls"]);
|
||||
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
|
||||
inputs.reasoning_format = opt.reasoning_format;
|
||||
if (body.contains("reasoning_format")) {
|
||||
|
|
@ -1164,573 +1166,6 @@ json oaicompat_chat_params_parse(
|
|||
return llama_params;
|
||||
}
|
||||
|
||||
// Convert an OpenAI Responses API request body into an OpenAI Chat Completions
// request body.
//
//   - 'input' (required) is either a string or an array of input items; it is
//     removed from the result and replaced by 'messages'.
//   - 'instructions', when present, becomes a leading system message.
//   - Assistant-side items (output message, function_call, reasoning) that
//     directly follow an assistant message are merged into that message.
//   - 'tools' entries must have type "function"; each is wrapped as
//     {"type": "function", "function": {...}} with 'strict' defaulting to true.
//   - 'max_output_tokens' is renamed to 'max_tokens'.
//   - All other fields are copied through unchanged.
//
// Throws std::invalid_argument on missing, malformed or unsupported input.
json convert_responses_to_chatcmpl(const json & response_body) {
    if (!response_body.contains("input")) {
        throw std::invalid_argument("'input' is required");
    }
    if (!json_value(response_body, "previous_response_id", std::string{}).empty()) {
        throw std::invalid_argument("llama.cpp does not support 'previous_response_id'.");
    }

    // chatcmpl_body is an independent copy, so referencing the request's 'input' is safe
    const json & input_value = response_body.at("input");
    json chatcmpl_body = response_body;
    chatcmpl_body.erase("input");
    std::vector<json> chatcmpl_messages;

    if (response_body.contains("instructions")) {
        chatcmpl_messages.push_back({
            {"role",    "system"},
            {"content", json_value(response_body, "instructions", std::string())},
        });
        chatcmpl_body.erase("instructions");
    }

    if (input_value.is_string()) {
        // #responses_create-input-text_input
        chatcmpl_messages.push_back({
            {"role",    "user"},
            {"content", input_value},
        });
    } else if (input_value.is_array()) {
        // #responses_create-input-input_item_list

        static auto exists_and_is_array = [](const json & j, const char * key) -> bool {
            return j.contains(key) && j.at(key).is_array();
        };
        static auto exists_and_is_string = [](const json & j, const char * key) -> bool {
            return j.contains(key) && j.at(key).is_string();
        };

        // each item is copied on purpose: it is rewritten in place before being stored
        for (json item : input_value) {
            // assistant-side items are folded into a directly preceding assistant message
            bool merge_prev = !chatcmpl_messages.empty() && chatcmpl_messages.back().value("role", "") == "assistant";

            if (exists_and_is_string(item, "content")) {
                // #responses_create-input-input_item_list-input_message-content-text_input
                // Only "Input message" contains item["content"]::string
                // After converting item["content"]::string to item["content"]::array,
                // we can treat "Input message" as sum of "Item-Input message" and "Item-Output message"
                item["content"] = json::array({
                    json {
                        {"text", item.at("content")},
                        {"type", "input_text"}
                    }
                });
            }

            if (exists_and_is_array(item, "content") &&
                exists_and_is_string(item, "role") &&
                (item.at("role") == "user" ||
                 item.at("role") == "system" ||
                 item.at("role") == "developer")
            ) {
                // #responses_create-input-input_item_list-item-input_message
                std::vector<json> chatcmpl_content;

                for (const json & input_item : item.at("content")) {
                    const std::string type = json_value(input_item, "type", std::string());

                    if (type == "input_text") {
                        if (!input_item.contains("text")) {
                            throw std::invalid_argument("'Input text' requires 'text'");
                        }
                        chatcmpl_content.push_back({
                            {"text", input_item.at("text")},
                            {"type", "text"},
                        });
                    } else if (type == "input_image") {
                        // While `detail` is marked as required,
                        // it has default value("auto") and can be omitted.

                        if (!input_item.contains("image_url")) {
                            throw std::invalid_argument("'image_url' is required");
                        }
                        chatcmpl_content.push_back({
                            {"image_url", json {
                                {"url", input_item.at("image_url")}
                            }},
                            {"type", "image_url"},
                        });
                    } else if (type == "input_file") {
                        throw std::invalid_argument("'input_file' is not supported by llamacpp at this moment");
                        // if (input_item.contains("file_url")) {
                        //     // chat completion API does not support file_url
                        //     throw std::invalid_argument("'file_url' is not supported");
                        // }
                        // if (!input_item.contains("file_data") || !input_item.contains("filename")) {
                        //     throw std::invalid_argument("Both 'file_data' and 'filename' are required");
                        // }
                        // chatcmpl_content.push_back({
                        //     {"file", json {
                        //         {"file_data", input_item.at("file_data")},
                        //         {"filename",  input_item.at("filename")},
                        //     }},
                        //     {"type", "file"},
                        // });
                    } else {
                        throw std::invalid_argument("'type' must be one of 'input_text', 'input_image', or 'input_file'");
                    }
                }

                // drop Responses-only fields before storing the item as a chat message
                if (item.contains("type")) {
                    item.erase("type");
                }
                if (item.contains("status")) {
                    item.erase("status");
                }
                item["content"] = chatcmpl_content;

                chatcmpl_messages.push_back(item);
            } else if (exists_and_is_array(item, "content") &&
                exists_and_is_string(item, "role") &&
                item.at("role") == "assistant" &&
                // exists_and_is_string(item, "status") &&
                // (item.at("status") == "in_progress" ||
                //  item.at("status") == "completed" ||
                //  item.at("status") == "incomplete") &&
                // item["status"] not sent by codex-cli
                exists_and_is_string(item, "type") &&
                item.at("type") == "message"
            ) {
                // #responses_create-input-input_item_list-item-output_message
                auto chatcmpl_content = json::array();

                for (const auto & output_text : item.at("content")) {
                    const std::string type = json_value(output_text, "type", std::string());
                    if (type == "output_text") {
                        if (!exists_and_is_string(output_text, "text")) {
                            throw std::invalid_argument("'Output text' requires 'text'");
                        }
                        // Ignore annotations and logprobs for now
                        // NOTE: the push must stay outside the validation guard above,
                        // otherwise it is unreachable and the assistant text is dropped
                        chatcmpl_content.push_back({
                            {"text", output_text.at("text")},
                            {"type", "text"},
                        });
                    } else if (type == "refusal") {
                        if (!exists_and_is_string(output_text, "refusal")) {
                            throw std::invalid_argument("'Refusal' requires 'refusal'");
                        }
                        // Ignore annotations and logprobs for now
                        chatcmpl_content.push_back({
                            {"refusal", output_text.at("refusal")},
                            {"type",    "refusal"},
                        });
                    } else {
                        throw std::invalid_argument("'type' must be one of 'output_text' or 'refusal'");
                    }
                }

                if (merge_prev) {
                    auto & prev_msg = chatcmpl_messages.back();
                    if (!exists_and_is_array(prev_msg, "content")) {
                        prev_msg["content"] = json::array();
                    }
                    auto & prev_content = prev_msg["content"];
                    prev_content.insert(prev_content.end(), chatcmpl_content.begin(), chatcmpl_content.end());
                } else {
                    item.erase("status");
                    item.erase("type");
                    item["content"] = chatcmpl_content;
                    chatcmpl_messages.push_back(item);
                }
            } else if (exists_and_is_string(item, "arguments") &&
                exists_and_is_string(item, "call_id") &&
                exists_and_is_string(item, "name") &&
                exists_and_is_string(item, "type") &&
                item.at("type") == "function_call"
            ) {
                // #responses_create-input-input_item_list-item-function_tool_call
                json tool_call = {
                    {"function", json {
                        {"arguments", item.at("arguments")},
                        {"name",      item.at("name")},
                    }},
                    {"id",   item.at("call_id")},
                    {"type", "function"},
                };

                if (merge_prev) {
                    auto & prev_msg = chatcmpl_messages.back();
                    if (!exists_and_is_array(prev_msg, "tool_calls")) {
                        prev_msg["tool_calls"] = json::array();
                    }
                    prev_msg["tool_calls"].push_back(tool_call);
                } else {
                    chatcmpl_messages.push_back(json {
                        {"role",       "assistant"},
                        {"tool_calls", json::array({tool_call})}
                    });
                }
            } else if (exists_and_is_string(item, "call_id") &&
                (exists_and_is_string(item, "output") || exists_and_is_array(item, "output")) &&
                exists_and_is_string(item, "type") &&
                item.at("type") == "function_call_output"
            ) {
                // #responses_create-input-input_item_list-item-function_tool_call_output
                if (item.at("output").is_string()) {
                    chatcmpl_messages.push_back(json {
                        {"content",      item.at("output")},
                        {"role",         "tool"},
                        {"tool_call_id", item.at("call_id")},
                    });
                } else {
                    // array output: every part must be 'input_text', which is renamed to 'text'
                    json chatcmpl_outputs = item.at("output");
                    for (json & chatcmpl_output : chatcmpl_outputs) {
                        if (!chatcmpl_output.contains("type") || chatcmpl_output.at("type") != "input_text") {
                            throw std::invalid_argument("Output of tool call should be 'Input text'");
                        }
                        chatcmpl_output["type"] = "text";
                    }
                    chatcmpl_messages.push_back(json {
                        {"content",      chatcmpl_outputs},
                        {"role",         "tool"},
                        {"tool_call_id", item.at("call_id")},
                    });
                }
            } else if (// exists_and_is_string(item, "id") &&
                // item["id"] not sent by codex-cli
                exists_and_is_array(item, "summary") &&
                exists_and_is_string(item, "type") &&
                item.at("type") == "reasoning") {
                // #responses_create-input-input_item_list-item-reasoning

                if (!exists_and_is_array(item, "content")) {
                    throw std::invalid_argument("item['content'] is not an array");
                }
                if (item.at("content").empty()) {
                    throw std::invalid_argument("item['content'] is empty");
                }
                if (!exists_and_is_string(item.at("content")[0], "text")) {
                    throw std::invalid_argument("item['content']['text'] is not a string");
                }

                // only the first content entry is used as the reasoning text
                if (merge_prev) {
                    auto & prev_msg = chatcmpl_messages.back();
                    prev_msg["reasoning_content"] = item.at("content")[0].at("text");
                } else {
                    chatcmpl_messages.push_back(json {
                        {"role", "assistant"},
                        {"content", json::array()},
                        {"reasoning_content", item.at("content")[0].at("text")},
                    });
                }
            } else {
                throw std::invalid_argument("Cannot determine type of 'item'");
            }
        }
    } else {
        throw std::invalid_argument("'input' must be a string or array of objects");
    }

    chatcmpl_body["messages"] = chatcmpl_messages;

    if (response_body.contains("tools")) {
        if (!response_body.at("tools").is_array()) {
            throw std::invalid_argument("'tools' must be an array of objects");
        }
        std::vector<json> chatcmpl_tools;
        // each tool is copied on purpose: it is rewritten before being wrapped
        for (json resp_tool : response_body.at("tools")) {
            json chatcmpl_tool;

            if (json_value(resp_tool, "type", std::string()) != "function") {
                throw std::invalid_argument("'type' of tool must be 'function'");
            }
            resp_tool.erase("type");
            chatcmpl_tool["type"] = "function";

            if (!resp_tool.contains("strict")) {
                resp_tool["strict"] = true;
            }
            chatcmpl_tool["function"] = resp_tool;
            chatcmpl_tools.push_back(chatcmpl_tool);
        }
        chatcmpl_body.erase("tools");
        chatcmpl_body["tools"] = chatcmpl_tools;
    }

    if (response_body.contains("max_output_tokens")) {
        chatcmpl_body.erase("max_output_tokens");
        chatcmpl_body["max_tokens"] = response_body["max_output_tokens"];
    }

    return chatcmpl_body;
}
|
||||
|
||||
// Turn an OpenAI transcriptions request (form fields plus uploaded files) into
// a Chat Completions request body.
//
//   inp_body  : form fields; all values arrive as strings
//   in_files  : uploaded files by form field name; the "file" entry is required
//   out_files : reset and then populated with the single input file
//
// Throws std::invalid_argument if the "file" entry is missing or if
// response_format is not "json".
json convert_transcriptions_to_chatcmpl(
        const json & inp_body,
        const std::map<std::string, raw_buffer> & in_files,
        std::vector<raw_buffer> & out_files) {
    // TODO @ngxson : this function may need to be improved in the future

    // locate the uploaded file
    out_files.clear();
    const auto found = in_files.find("file");
    if (found == in_files.end()) {
        throw std::invalid_argument("No input file found for transcription");
    }
    out_files.push_back(found->second);

    // read the text fields
    std::string       text   = json_value(inp_body, "prompt",          std::string());
    const std::string lang   = json_value(inp_body, "language",        std::string());
    const std::string format = json_value(inp_body, "response_format", std::string("json"));
    if (format != "json") {
        throw std::invalid_argument("Only 'json' response_format is supported for transcription");
    }

    // prompt = (user prompt or default) + optional language hint + media marker
    if (text.empty()) {
        text = "Transcribe audio to text";
    }
    if (!lang.empty()) {
        text += string_format(" (language: %s)", lang.c_str());
    }
    text += get_media_marker();

    // keep every incoming field, then add the single user message
    json out = inp_body;
    const json message = {
        {"role", "user"},
        {"content", text},
    };
    out["messages"] = json::array({ message });

    // values come from form-data as strings; convert them to the expected JSON types
    const bool streaming = json_value(inp_body, "stream", std::string("false")) == "true";
    out["stream"] = streaming;

    if (inp_body.contains("max_tokens")) {
        const std::string value = inp_body["max_tokens"].get<std::string>();
        out["max_tokens"] = std::stoul(value);
    }

    if (inp_body.contains("temperature")) {
        const std::string value = inp_body["temperature"].get<std::string>();
        out["temperature"] = std::stof(value);
    }

    return out;
}
|
||||
|
||||
// Translate an Anthropic Messages API request body into an OpenAI Chat
// Completions request body.
//
//   - "system" (string, or array of text blocks) becomes a leading system message
//   - message content blocks are mapped: text / image -> content parts,
//     thinking -> "reasoning_content", tool_use -> "tool_calls",
//     tool_result -> separate "tool" role messages appended after the message
//   - "tools", "tool_choice", "stop_sequences", "max_tokens", sampling params,
//     "thinking" and "metadata.user_id" are mapped onto the output body
//
// Throws std::runtime_error when "messages" is missing.
json convert_anthropic_to_oai(const json & body) {
    json oai_body;
    json oai_messages = json::array();

    // ---- system prompt ----
    const auto system_param = json_value(body, "system", json());
    if (!system_param.is_null()) {
        std::string system_text;

        if (system_param.is_string()) {
            system_text = system_param.get<std::string>();
        } else if (system_param.is_array()) {
            // concatenate all "text" blocks, ignore everything else
            for (const auto & part : system_param) {
                if (json_value(part, "type", std::string()) != "text") {
                    continue;
                }
                system_text += json_value(part, "text", std::string());
            }
        }

        oai_messages.push_back({
            {"role", "system"},
            {"content", system_text}
        });
    }

    // ---- messages ----
    if (!body.contains("messages")) {
        throw std::runtime_error("'messages' is required");
    }
    const json & messages = body.at("messages");
    if (messages.is_array()) {
        for (const auto & msg : messages) {
            const std::string role = json_value(msg, "role", std::string());

            if (!msg.contains("content")) {
                // content-less assistant messages are dropped, anything else is forwarded as-is
                if (role != "assistant") {
                    oai_messages.push_back(msg);
                }
                continue;
            }

            const json & content = msg.at("content");

            // string content (or any other non-array content) needs no conversion
            if (!content.is_array()) {
                oai_messages.push_back(msg);
                continue;
            }

            json parts         = json::array(); // text / image content parts
            json calls         = json::array(); // tool_use blocks
            json tool_messages = json::array(); // tool_result blocks, emitted as separate messages
            std::string reasoning;

            for (const auto & block : content) {
                const std::string kind = json_value(block, "type", std::string());

                if (kind == "text") {
                    parts.push_back(block);
                } else if (kind == "thinking") {
                    reasoning += json_value(block, "thinking", std::string());
                } else if (kind == "image") {
                    const json        source      = json_value(block, "source", json::object());
                    const std::string source_kind = json_value(source, "type", std::string());

                    if (source_kind == "base64") {
                        // inline image data is re-encoded as a data: URL
                        const std::string media_type = json_value(source, "media_type", std::string("image/jpeg"));
                        const std::string data       = json_value(source, "data", std::string());
                        const std::string data_url   = "data:" + media_type + ";base64," + data;

                        parts.push_back({
                            {"type", "image_url"},
                            {"image_url", {
                                {"url", data_url}
                            }}
                        });
                    } else if (source_kind == "url") {
                        const std::string url = json_value(source, "url", std::string());
                        parts.push_back({
                            {"type", "image_url"},
                            {"image_url", {
                                {"url", url}
                            }}
                        });
                    }
                } else if (kind == "tool_use") {
                    // the tool input object is serialized into the "arguments" string
                    calls.push_back({
                        {"id", json_value(block, "id", std::string())},
                        {"type", "function"},
                        {"function", {
                            {"name", json_value(block, "name", std::string())},
                            {"arguments", json_value(block, "input", json::object()).dump()}
                        }}
                    });
                } else if (kind == "tool_result") {
                    const std::string call_id = json_value(block, "tool_use_id", std::string());

                    // result content: a plain string, or an array whose "text" blocks are concatenated
                    const auto  result = json_value(block, "content", json());
                    std::string result_text;
                    if (result.is_string()) {
                        result_text = result.get<std::string>();
                    } else if (result.is_array()) {
                        for (const auto & piece : result) {
                            if (json_value(piece, "type", std::string()) == "text") {
                                result_text += json_value(piece, "text", std::string());
                            }
                        }
                    }

                    tool_messages.push_back({
                        {"role", "tool"},
                        {"tool_call_id", call_id},
                        {"content", result_text}
                    });
                }
            }

            const bool has_parts     = !parts.empty();
            const bool has_calls     = !calls.empty();
            const bool has_reasoning = !reasoning.empty();

            if (has_parts || has_calls || has_reasoning) {
                json new_msg = {{"role", role}};
                if (has_parts) {
                    new_msg["content"] = parts;
                } else {
                    // only tool calls and/or reasoning: content is an empty string
                    new_msg["content"] = "";
                }
                if (has_calls) {
                    new_msg["tool_calls"] = calls;
                }
                if (has_reasoning) {
                    new_msg["reasoning_content"] = reasoning;
                }
                oai_messages.push_back(new_msg);
            }

            // tool results follow the converted message
            for (const auto & tool_msg : tool_messages) {
                oai_messages.push_back(tool_msg);
            }
        }
    }

    oai_body["messages"] = oai_messages;

    // ---- tools ----
    if (body.contains("tools")) {
        const json & tools = body.at("tools");
        if (tools.is_array()) {
            json oai_tools = json::array();
            for (const auto & tool : tools) {
                const json parameters = tool.contains("input_schema") ? tool.at("input_schema") : json::object();
                oai_tools.push_back({
                    {"type", "function"},
                    {"function", {
                        {"name", json_value(tool, "name", std::string())},
                        {"description", json_value(tool, "description", std::string())},
                        {"parameters", parameters}
                    }}
                });
            }
            oai_body["tools"] = oai_tools;
        }
    }

    // ---- tool_choice ----
    if (body.contains("tool_choice")) {
        const json & choice = body.at("tool_choice");
        if (choice.is_object()) {
            const std::string choice_type = json_value(choice, "type", std::string());
            if (choice_type == "auto") {
                oai_body["tool_choice"] = "auto";
            } else if (choice_type == "any" || choice_type == "tool") {
                oai_body["tool_choice"] = "required";
            }
        }
    }

    // ---- stop_sequences -> stop ----
    if (body.contains("stop_sequences")) {
        oai_body["stop"] = body.at("stop_sequences");
    }

    // ---- max_tokens (required in Anthropic, but we're permissive) ----
    if (body.contains("max_tokens")) {
        oai_body["max_tokens"] = body.at("max_tokens");
    } else {
        oai_body["max_tokens"] = 4096;
    }

    // ---- parameters copied through unchanged ----
    for (const auto & key : {"temperature", "top_p", "top_k", "stream"}) {
        if (body.contains(key)) {
            oai_body[key] = body.at(key);
        }
    }

    // ---- Anthropic-specific "thinking" ----
    if (body.contains("thinking")) {
        const json        thinking      = json_value(body, "thinking", json::object());
        const std::string thinking_type = json_value(thinking, "type", std::string());
        if (thinking_type == "enabled") {
            const int budget_tokens = json_value(thinking, "budget_tokens", 10000);
            oai_body["thinking_budget_tokens"] = budget_tokens;
        }
    }

    // ---- Anthropic-specific "metadata" ----
    if (body.contains("metadata")) {
        const json        metadata = json_value(body, "metadata", json::object());
        const std::string user_id  = json_value(metadata, "user_id", std::string());
        if (!user_id.empty()) {
            oai_body["__metadata_user_id"] = user_id;
        }
    }

    return oai_body;
}
|
||||
|
||||
json format_embeddings_response_oaicompat(
|
||||
const json & request,
|
||||
const std::string & model_name,
|
||||
|
|
|
|||
|
|
@ -307,18 +307,6 @@ json oaicompat_chat_params_parse(
|
|||
const server_chat_params & opt,
|
||||
std::vector<raw_buffer> & out_files);
|
||||
|
||||
// convert OpenAI Responses API format to OpenAI Chat Completions API format
|
||||
json convert_responses_to_chatcmpl(const json & body);
|
||||
|
||||
// convert OpenAI transcriptions API format to OpenAI Chat Completions API format
|
||||
json convert_transcriptions_to_chatcmpl(
|
||||
const json & body,
|
||||
const std::map<std::string, raw_buffer> & in_files,
|
||||
std::vector<raw_buffer> & out_files);
|
||||
|
||||
// convert Anthropic Messages API format to OpenAI Chat Completions API format
|
||||
json convert_anthropic_to_oai(const json & body);
|
||||
|
||||
// TODO: move it to server-task.cpp
|
||||
json format_embeddings_response_oaicompat(
|
||||
const json & request,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
#include "server-context.h"
|
||||
#include "server-chat.h"
|
||||
#include "server-common.h"
|
||||
#include "server-http.h"
|
||||
#include "server-task.h"
|
||||
|
|
@ -1044,8 +1045,8 @@ private:
|
|||
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
|
||||
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
|
||||
/* enable_thinking */ enable_thinking,
|
||||
/* reasoning_budget */ params_base.reasoning_budget,
|
||||
/* reasoning_budget_msg */ params_base.reasoning_budget_message,
|
||||
/* reasoning_budget */ params_base.sampling.reasoning_budget_tokens,
|
||||
/* reasoning_budget_msg */ params_base.sampling.reasoning_budget_message,
|
||||
/* media_path */ params_base.media_path,
|
||||
/* force_pure_content */ params_base.force_pure_content_parser
|
||||
};
|
||||
|
|
@ -2960,7 +2961,13 @@ private:
|
|||
|
||||
// verify and try to accept the draft
|
||||
{
|
||||
common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));
|
||||
const bool use_ckpt = slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
|
||||
// only save the sampler state if we use checkpoints
|
||||
common_sampler_ptr smpl_save;
|
||||
if (use_ckpt) {
|
||||
smpl_save.reset(common_sampler_clone(slot.smpl.get()));
|
||||
}
|
||||
|
||||
GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
|
||||
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft);
|
||||
|
|
@ -2972,7 +2979,7 @@ private:
|
|||
|
||||
// check for partial draft acceptance
|
||||
if (accepted.size() < slot.spec_draft.size() + 1) {
|
||||
if (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
||||
if (use_ckpt) {
|
||||
// partial acceptance is not supported by the context -> truncate the draft and restore the state
|
||||
slot.spec_draft = std::move(accepted);
|
||||
|
||||
|
|
@ -3774,7 +3781,7 @@ void server_routes::init_routes() {
|
|||
this->post_responses_oai = [this](const server_http_req & req) {
|
||||
auto res = create_response();
|
||||
std::vector<raw_buffer> files;
|
||||
json body = convert_responses_to_chatcmpl(json::parse(req.body));
|
||||
json body = server_chat_convert_responses_to_chatcmpl(json::parse(req.body));
|
||||
SRV_DBG("%s\n", "Request converted: OpenAI Responses -> OpenAI Chat Completions");
|
||||
SRV_DBG("converted request: %s\n", body.dump().c_str());
|
||||
json body_parsed = oaicompat_chat_params_parse(
|
||||
|
|
@ -3819,7 +3826,7 @@ void server_routes::init_routes() {
|
|||
this->post_anthropic_messages = [this](const server_http_req & req) {
|
||||
auto res = create_response();
|
||||
std::vector<raw_buffer> files;
|
||||
json body = convert_anthropic_to_oai(json::parse(req.body));
|
||||
json body = server_chat_convert_anthropic_to_oai(json::parse(req.body));
|
||||
SRV_DBG("%s\n", "Request converted: Anthropic -> OpenAI Chat Completions");
|
||||
SRV_DBG("converted request: %s\n", body.dump().c_str());
|
||||
json body_parsed = oaicompat_chat_params_parse(
|
||||
|
|
@ -3837,7 +3844,7 @@ void server_routes::init_routes() {
|
|||
this->post_anthropic_count_tokens = [this](const server_http_req & req) {
|
||||
auto res = create_response();
|
||||
std::vector<raw_buffer> files;
|
||||
json body = convert_anthropic_to_oai(json::parse(req.body));
|
||||
json body = server_chat_convert_anthropic_to_oai(json::parse(req.body));
|
||||
SRV_DBG("%s\n", "Request converted: Anthropic -> OpenAI Chat Completions");
|
||||
SRV_DBG("converted request: %s\n", body.dump().c_str());
|
||||
json body_parsed = oaicompat_chat_params_parse(
|
||||
|
|
|
|||
|
|
@ -712,6 +712,11 @@ void server_models::unload(const std::string & name) {
|
|||
if (it->second.meta.is_running()) {
|
||||
SRV_INF("stopping model instance name=%s\n", name.c_str());
|
||||
stopping_models.insert(name);
|
||||
if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) {
|
||||
// special case: if model is in loading state, unloading means force-killing it
|
||||
SRV_WRN("model name=%s is still loading, force-killing\n", name.c_str());
|
||||
subprocess_terminate(it->second.subproc.get());
|
||||
}
|
||||
cv_stop.notify_all();
|
||||
// status change will be handled by the managing thread
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#include "server-task.h"
|
||||
|
||||
#include "build-info.h"
|
||||
#include "server-chat.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
|
|
@ -873,7 +874,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
|
|||
json {
|
||||
{"finish_reason", nullptr},
|
||||
{"index", index},
|
||||
{"delta", common_chat_msg_diff_to_json_oaicompat(diff)},
|
||||
{"delta", server_chat_msg_diff_to_json_oaicompat(diff)},
|
||||
},
|
||||
})},
|
||||
{"created", t},
|
||||
|
|
@ -1110,7 +1111,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() {
|
|||
json server_task_result_cmpl_final::to_json_oaicompat_asr() {
|
||||
json event = json {
|
||||
{"type", "transcript.text.done"},
|
||||
{"text", content},
|
||||
{"text", oaicompat_msg.content},
|
||||
{"usage", json {
|
||||
{"type", "tokens"},
|
||||
{"input_tokens", n_prompt_tokens},
|
||||
|
|
@ -1522,7 +1523,7 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat() {
|
|||
}
|
||||
|
||||
for (const auto & diff : oaicompat_msg_diffs) {
|
||||
add_delta(common_chat_msg_diff_to_json_oaicompat(diff));
|
||||
add_delta(server_chat_msg_diff_to_json_oaicompat(diff));
|
||||
}
|
||||
|
||||
if (!deltas.empty()) {
|
||||
|
|
|
|||
595
vendor/cpp-httplib/httplib.cpp
vendored
595
vendor/cpp-httplib/httplib.cpp
vendored
|
|
@ -872,7 +872,8 @@ bool write_websocket_frame(Stream &strm, ws::Opcode opcode,
|
|||
if (strm.write(reinterpret_cast<char *>(header), 2) < 0) { return false; }
|
||||
uint8_t ext[8];
|
||||
for (int i = 7; i >= 0; i--) {
|
||||
ext[7 - i] = static_cast<uint8_t>((len >> (i * 8)) & 0xFF);
|
||||
ext[7 - i] =
|
||||
static_cast<uint8_t>((static_cast<uint64_t>(len) >> (i * 8)) & 0xFF);
|
||||
}
|
||||
if (strm.write(reinterpret_cast<char *>(ext), 8) < 0) { return false; }
|
||||
}
|
||||
|
|
@ -1034,10 +1035,15 @@ bool canonicalize_path(const char *path, std::string &resolved) {
|
|||
char buf[_MAX_PATH];
|
||||
if (_fullpath(buf, path, _MAX_PATH) == nullptr) { return false; }
|
||||
resolved = buf;
|
||||
#else
|
||||
#elif defined(PATH_MAX)
|
||||
char buf[PATH_MAX];
|
||||
if (realpath(path, buf) == nullptr) { return false; }
|
||||
resolved = buf;
|
||||
#else
|
||||
auto buf = realpath(path, nullptr);
|
||||
auto guard = scope_exit([&]() { std::free(buf); });
|
||||
if (buf == nullptr) { return false; }
|
||||
resolved = buf;
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
|
@ -2765,6 +2771,35 @@ EncodingType encoding_type(const Request &req, const Response &res) {
|
|||
return best;
|
||||
}
|
||||
|
||||
std::unique_ptr<compressor> make_compressor(EncodingType type) {
|
||||
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
|
||||
if (type == EncodingType::Gzip) {
|
||||
return detail::make_unique<gzip_compressor>();
|
||||
}
|
||||
#endif
|
||||
#ifdef CPPHTTPLIB_BROTLI_SUPPORT
|
||||
if (type == EncodingType::Brotli) {
|
||||
return detail::make_unique<brotli_compressor>();
|
||||
}
|
||||
#endif
|
||||
#ifdef CPPHTTPLIB_ZSTD_SUPPORT
|
||||
if (type == EncodingType::Zstd) {
|
||||
return detail::make_unique<zstd_compressor>();
|
||||
}
|
||||
#endif
|
||||
(void)type;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const char *encoding_name(EncodingType type) {
|
||||
switch (type) {
|
||||
case EncodingType::Gzip: return "gzip";
|
||||
case EncodingType::Brotli: return "br";
|
||||
case EncodingType::Zstd: return "zstd";
|
||||
default: return "";
|
||||
}
|
||||
}
|
||||
|
||||
bool nocompressor::compress(const char *data, size_t data_length,
|
||||
bool /*last*/, Callback callback) {
|
||||
if (!data_length) { return true; }
|
||||
|
|
@ -3097,6 +3132,29 @@ const char *get_header_value(const Headers &headers,
|
|||
return def;
|
||||
}
|
||||
|
||||
size_t get_header_value_count(const Headers &headers,
|
||||
const std::string &key) {
|
||||
auto r = headers.equal_range(key);
|
||||
return static_cast<size_t>(std::distance(r.first, r.second));
|
||||
}
|
||||
|
||||
template <typename Map>
|
||||
typename Map::mapped_type
|
||||
get_multimap_value(const Map &m, const std::string &key, size_t id) {
|
||||
auto rng = m.equal_range(key);
|
||||
auto it = rng.first;
|
||||
std::advance(it, static_cast<ssize_t>(id));
|
||||
if (it != rng.second) { return it->second; }
|
||||
return typename Map::mapped_type();
|
||||
}
|
||||
|
||||
void set_header(Headers &headers, const std::string &key,
|
||||
const std::string &val) {
|
||||
if (fields::is_field_name(key) && fields::is_field_value(val)) {
|
||||
headers.emplace(key, val);
|
||||
}
|
||||
}
|
||||
|
||||
bool read_headers(Stream &strm, Headers &headers) {
|
||||
const auto bufsiz = 2048;
|
||||
char buf[bufsiz];
|
||||
|
|
@ -5791,16 +5849,12 @@ std::string Request::get_header_value(const std::string &key,
|
|||
}
|
||||
|
||||
size_t Request::get_header_value_count(const std::string &key) const {
|
||||
auto r = headers.equal_range(key);
|
||||
return static_cast<size_t>(std::distance(r.first, r.second));
|
||||
return detail::get_header_value_count(headers, key);
|
||||
}
|
||||
|
||||
void Request::set_header(const std::string &key,
|
||||
const std::string &val) {
|
||||
if (detail::fields::is_field_name(key) &&
|
||||
detail::fields::is_field_value(val)) {
|
||||
headers.emplace(key, val);
|
||||
}
|
||||
detail::set_header(headers, key, val);
|
||||
}
|
||||
|
||||
bool Request::has_trailer(const std::string &key) const {
|
||||
|
|
@ -5809,11 +5863,7 @@ bool Request::has_trailer(const std::string &key) const {
|
|||
|
||||
std::string Request::get_trailer_value(const std::string &key,
|
||||
size_t id) const {
|
||||
auto rng = trailers.equal_range(key);
|
||||
auto it = rng.first;
|
||||
std::advance(it, static_cast<ssize_t>(id));
|
||||
if (it != rng.second) { return it->second; }
|
||||
return std::string();
|
||||
return detail::get_multimap_value(trailers, key, id);
|
||||
}
|
||||
|
||||
size_t Request::get_trailer_value_count(const std::string &key) const {
|
||||
|
|
@ -5827,11 +5877,7 @@ bool Request::has_param(const std::string &key) const {
|
|||
|
||||
std::string Request::get_param_value(const std::string &key,
|
||||
size_t id) const {
|
||||
auto rng = params.equal_range(key);
|
||||
auto it = rng.first;
|
||||
std::advance(it, static_cast<ssize_t>(id));
|
||||
if (it != rng.second) { return it->second; }
|
||||
return std::string();
|
||||
return detail::get_multimap_value(params, key, id);
|
||||
}
|
||||
|
||||
std::vector<std::string>
|
||||
|
|
@ -5886,11 +5932,7 @@ size_t MultipartFormData::get_field_count(const std::string &key) const {
|
|||
|
||||
FormData MultipartFormData::get_file(const std::string &key,
|
||||
size_t id) const {
|
||||
auto rng = files.equal_range(key);
|
||||
auto it = rng.first;
|
||||
std::advance(it, static_cast<ssize_t>(id));
|
||||
if (it != rng.second) { return it->second; }
|
||||
return FormData();
|
||||
return detail::get_multimap_value(files, key, id);
|
||||
}
|
||||
|
||||
std::vector<FormData>
|
||||
|
|
@ -5929,16 +5971,12 @@ std::string Response::get_header_value(const std::string &key,
|
|||
}
|
||||
|
||||
size_t Response::get_header_value_count(const std::string &key) const {
|
||||
auto r = headers.equal_range(key);
|
||||
return static_cast<size_t>(std::distance(r.first, r.second));
|
||||
return detail::get_header_value_count(headers, key);
|
||||
}
|
||||
|
||||
void Response::set_header(const std::string &key,
|
||||
const std::string &val) {
|
||||
if (detail::fields::is_field_name(key) &&
|
||||
detail::fields::is_field_value(val)) {
|
||||
headers.emplace(key, val);
|
||||
}
|
||||
detail::set_header(headers, key, val);
|
||||
}
|
||||
bool Response::has_trailer(const std::string &key) const {
|
||||
return trailers.find(key) != trailers.end();
|
||||
|
|
@ -5946,11 +5984,7 @@ bool Response::has_trailer(const std::string &key) const {
|
|||
|
||||
std::string Response::get_trailer_value(const std::string &key,
|
||||
size_t id) const {
|
||||
auto rng = trailers.equal_range(key);
|
||||
auto it = rng.first;
|
||||
std::advance(it, static_cast<ssize_t>(id));
|
||||
if (it != rng.second) { return it->second; }
|
||||
return std::string();
|
||||
return detail::get_multimap_value(trailers, key, id);
|
||||
}
|
||||
|
||||
size_t Response::get_trailer_value_count(const std::string &key) const {
|
||||
|
|
@ -6253,15 +6287,6 @@ void ThreadPool::worker(bool is_dynamic) {
|
|||
|
||||
assert(true == static_cast<bool>(fn));
|
||||
fn();
|
||||
|
||||
// Dynamic thread: exit if queue is empty after task completion
|
||||
if (is_dynamic) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
if (jobs_.empty()) {
|
||||
move_to_finished(std::this_thread::get_id());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \
|
||||
|
|
@ -6791,61 +6816,51 @@ Server::make_matcher(const std::string &pattern) {
|
|||
}
|
||||
|
||||
Server &Server::Get(const std::string &pattern, Handler handler) {
|
||||
get_handlers_.emplace_back(make_matcher(pattern), std::move(handler));
|
||||
return *this;
|
||||
return add_handler(get_handlers_, pattern, std::move(handler));
|
||||
}
|
||||
|
||||
Server &Server::Post(const std::string &pattern, Handler handler) {
|
||||
post_handlers_.emplace_back(make_matcher(pattern), std::move(handler));
|
||||
return *this;
|
||||
return add_handler(post_handlers_, pattern, std::move(handler));
|
||||
}
|
||||
|
||||
Server &Server::Post(const std::string &pattern,
|
||||
HandlerWithContentReader handler) {
|
||||
post_handlers_for_content_reader_.emplace_back(make_matcher(pattern),
|
||||
std::move(handler));
|
||||
return *this;
|
||||
return add_handler(post_handlers_for_content_reader_, pattern,
|
||||
std::move(handler));
|
||||
}
|
||||
|
||||
Server &Server::Put(const std::string &pattern, Handler handler) {
|
||||
put_handlers_.emplace_back(make_matcher(pattern), std::move(handler));
|
||||
return *this;
|
||||
return add_handler(put_handlers_, pattern, std::move(handler));
|
||||
}
|
||||
|
||||
Server &Server::Put(const std::string &pattern,
|
||||
HandlerWithContentReader handler) {
|
||||
put_handlers_for_content_reader_.emplace_back(make_matcher(pattern),
|
||||
std::move(handler));
|
||||
return *this;
|
||||
return add_handler(put_handlers_for_content_reader_, pattern,
|
||||
std::move(handler));
|
||||
}
|
||||
|
||||
Server &Server::Patch(const std::string &pattern, Handler handler) {
|
||||
patch_handlers_.emplace_back(make_matcher(pattern), std::move(handler));
|
||||
return *this;
|
||||
return add_handler(patch_handlers_, pattern, std::move(handler));
|
||||
}
|
||||
|
||||
Server &Server::Patch(const std::string &pattern,
|
||||
HandlerWithContentReader handler) {
|
||||
patch_handlers_for_content_reader_.emplace_back(make_matcher(pattern),
|
||||
std::move(handler));
|
||||
return *this;
|
||||
return add_handler(patch_handlers_for_content_reader_, pattern,
|
||||
std::move(handler));
|
||||
}
|
||||
|
||||
Server &Server::Delete(const std::string &pattern, Handler handler) {
|
||||
delete_handlers_.emplace_back(make_matcher(pattern), std::move(handler));
|
||||
return *this;
|
||||
return add_handler(delete_handlers_, pattern, std::move(handler));
|
||||
}
|
||||
|
||||
Server &Server::Delete(const std::string &pattern,
|
||||
HandlerWithContentReader handler) {
|
||||
delete_handlers_for_content_reader_.emplace_back(make_matcher(pattern),
|
||||
std::move(handler));
|
||||
return *this;
|
||||
return add_handler(delete_handlers_for_content_reader_, pattern,
|
||||
std::move(handler));
|
||||
}
|
||||
|
||||
Server &Server::Options(const std::string &pattern, Handler handler) {
|
||||
options_handlers_.emplace_back(make_matcher(pattern), std::move(handler));
|
||||
return *this;
|
||||
return add_handler(options_handlers_, pattern, std::move(handler));
|
||||
}
|
||||
|
||||
Server &Server::WebSocket(const std::string &pattern,
|
||||
|
|
@ -7054,6 +7069,11 @@ Server &Server::set_payload_max_length(size_t length) {
|
|||
return *this;
|
||||
}
|
||||
|
||||
Server &Server::set_websocket_max_missed_pongs(int count) {
|
||||
websocket_max_missed_pongs_ = count;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Server &Server::set_websocket_ping_interval(time_t sec) {
|
||||
websocket_ping_interval_sec_ = sec;
|
||||
return *this;
|
||||
|
|
@ -7279,23 +7299,10 @@ Server::write_content_with_provider(Stream &strm, const Request &req,
|
|||
if (res.is_chunked_content_provider_) {
|
||||
auto type = detail::encoding_type(req, res);
|
||||
|
||||
std::unique_ptr<detail::compressor> compressor;
|
||||
if (type == detail::EncodingType::Gzip) {
|
||||
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
|
||||
compressor = detail::make_unique<detail::gzip_compressor>();
|
||||
#endif
|
||||
} else if (type == detail::EncodingType::Brotli) {
|
||||
#ifdef CPPHTTPLIB_BROTLI_SUPPORT
|
||||
compressor = detail::make_unique<detail::brotli_compressor>();
|
||||
#endif
|
||||
} else if (type == detail::EncodingType::Zstd) {
|
||||
#ifdef CPPHTTPLIB_ZSTD_SUPPORT
|
||||
compressor = detail::make_unique<detail::zstd_compressor>();
|
||||
#endif
|
||||
} else {
|
||||
auto compressor = detail::make_compressor(type);
|
||||
if (!compressor) {
|
||||
compressor = detail::make_unique<detail::nocompressor>();
|
||||
}
|
||||
assert(compressor != nullptr);
|
||||
|
||||
return detail::write_content_chunked(strm, res.content_provider_,
|
||||
is_shutting_down, *compressor);
|
||||
|
|
@ -7917,14 +7924,8 @@ void Server::apply_ranges(const Request &req, Response &res,
|
|||
if (res.content_provider_) {
|
||||
if (res.is_chunked_content_provider_) {
|
||||
res.set_header("Transfer-Encoding", "chunked");
|
||||
if (type == detail::EncodingType::Gzip) {
|
||||
res.set_header("Content-Encoding", "gzip");
|
||||
res.set_header("Vary", "Accept-Encoding");
|
||||
} else if (type == detail::EncodingType::Brotli) {
|
||||
res.set_header("Content-Encoding", "br");
|
||||
res.set_header("Vary", "Accept-Encoding");
|
||||
} else if (type == detail::EncodingType::Zstd) {
|
||||
res.set_header("Content-Encoding", "zstd");
|
||||
if (type != detail::EncodingType::None) {
|
||||
res.set_header("Content-Encoding", detail::encoding_name(type));
|
||||
res.set_header("Vary", "Accept-Encoding");
|
||||
}
|
||||
}
|
||||
|
|
@ -7955,27 +7956,7 @@ void Server::apply_ranges(const Request &req, Response &res,
|
|||
if (type != detail::EncodingType::None) {
|
||||
output_pre_compression_log(req, res);
|
||||
|
||||
std::unique_ptr<detail::compressor> compressor;
|
||||
std::string content_encoding;
|
||||
|
||||
if (type == detail::EncodingType::Gzip) {
|
||||
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
|
||||
compressor = detail::make_unique<detail::gzip_compressor>();
|
||||
content_encoding = "gzip";
|
||||
#endif
|
||||
} else if (type == detail::EncodingType::Brotli) {
|
||||
#ifdef CPPHTTPLIB_BROTLI_SUPPORT
|
||||
compressor = detail::make_unique<detail::brotli_compressor>();
|
||||
content_encoding = "br";
|
||||
#endif
|
||||
} else if (type == detail::EncodingType::Zstd) {
|
||||
#ifdef CPPHTTPLIB_ZSTD_SUPPORT
|
||||
compressor = detail::make_unique<detail::zstd_compressor>();
|
||||
content_encoding = "zstd";
|
||||
#endif
|
||||
}
|
||||
|
||||
if (compressor) {
|
||||
if (auto compressor = detail::make_compressor(type)) {
|
||||
std::string compressed;
|
||||
if (compressor->compress(res.body.data(), res.body.size(), true,
|
||||
[&](const char *data, size_t data_len) {
|
||||
|
|
@ -7983,7 +7964,7 @@ void Server::apply_ranges(const Request &req, Response &res,
|
|||
return true;
|
||||
})) {
|
||||
res.body.swap(compressed);
|
||||
res.set_header("Content-Encoding", content_encoding);
|
||||
res.set_header("Content-Encoding", detail::encoding_name(type));
|
||||
res.set_header("Vary", "Accept-Encoding");
|
||||
}
|
||||
}
|
||||
|
|
@ -8231,7 +8212,8 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
|
|||
{
|
||||
// Use WebSocket-specific read timeout instead of HTTP timeout
|
||||
strm.set_read_timeout(CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND, 0);
|
||||
ws::WebSocket ws(strm, req, true, websocket_ping_interval_sec_);
|
||||
ws::WebSocket ws(strm, req, true, websocket_ping_interval_sec_,
|
||||
websocket_max_missed_pongs_);
|
||||
entry.handler(req, ws);
|
||||
}
|
||||
return true;
|
||||
|
|
@ -10822,38 +10804,6 @@ void ClientImpl::enable_server_hostname_verification(bool enabled) {
|
|||
}
|
||||
#endif
|
||||
|
||||
// ClientImpl::set_ca_cert_store is defined after TLS namespace (uses helpers)
|
||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert,
|
||||
std::size_t size) const {
|
||||
auto mem = BIO_new_mem_buf(ca_cert, static_cast<int>(size));
|
||||
auto se = detail::scope_exit([&] { BIO_free_all(mem); });
|
||||
if (!mem) { return nullptr; }
|
||||
|
||||
auto inf = PEM_X509_INFO_read_bio(mem, nullptr, nullptr, nullptr);
|
||||
if (!inf) { return nullptr; }
|
||||
|
||||
auto cts = X509_STORE_new();
|
||||
if (cts) {
|
||||
for (auto i = 0; i < static_cast<int>(sk_X509_INFO_num(inf)); i++) {
|
||||
auto itmp = sk_X509_INFO_value(inf, i);
|
||||
if (!itmp) { continue; }
|
||||
|
||||
if (itmp->x509) { X509_STORE_add_cert(cts, itmp->x509); }
|
||||
if (itmp->crl) { X509_STORE_add_crl(cts, itmp->crl); }
|
||||
}
|
||||
}
|
||||
|
||||
sk_X509_INFO_pop_free(inf, X509_INFO_free);
|
||||
return cts;
|
||||
}
|
||||
|
||||
void ClientImpl::set_server_certificate_verifier(
|
||||
std::function<SSLVerifierResponse(SSL *ssl)> /*verifier*/) {
|
||||
// Base implementation does nothing - SSLClient overrides this
|
||||
}
|
||||
#endif
|
||||
|
||||
void ClientImpl::set_logger(Logger logger) {
|
||||
logger_ = std::move(logger);
|
||||
}
|
||||
|
|
@ -10927,10 +10877,10 @@ Client::Client(const std::string &scheme_host_port,
|
|||
cli_ = detail::make_unique<ClientImpl>(scheme_host_port, 80,
|
||||
client_cert_path, client_key_path);
|
||||
}
|
||||
} // namespace detail
|
||||
}
|
||||
|
||||
Client::Client(const std::string &host, int port)
|
||||
: cli_(detail::make_unique<ClientImpl>(host, port)) {}
|
||||
: Client(host, port, std::string(), std::string()) {}
|
||||
|
||||
Client::Client(const std::string &host, int port,
|
||||
const std::string &client_cert_path,
|
||||
|
|
@ -11505,12 +11455,6 @@ void Client::set_follow_location(bool on) {
|
|||
|
||||
void Client::set_path_encode(bool on) { cli_->set_path_encode(on); }
|
||||
|
||||
[[deprecated("Use set_path_encode() instead. "
|
||||
"This function will be removed by v1.0.0.")]]
|
||||
void Client::set_url_encode(bool on) {
|
||||
cli_->set_path_encode(on);
|
||||
}
|
||||
|
||||
void Client::set_compress(bool on) { cli_->set_compress(on); }
|
||||
|
||||
void Client::set_decompress(bool on) { cli_->set_decompress(on); }
|
||||
|
|
@ -11893,24 +11837,31 @@ SSLClient::SSLClient(const std::string &host)
|
|||
SSLClient::SSLClient(const std::string &host, int port)
|
||||
: SSLClient(host, port, std::string(), std::string()) {}
|
||||
|
||||
void SSLClient::init_ctx() {
|
||||
ctx_ = tls::create_client_context();
|
||||
if (ctx_) { tls::set_min_version(ctx_, tls::Version::TLS1_2); }
|
||||
}
|
||||
|
||||
void SSLClient::reset_ctx_on_error() {
|
||||
last_backend_error_ = tls::get_error();
|
||||
tls::free_context(ctx_);
|
||||
ctx_ = nullptr;
|
||||
}
|
||||
|
||||
SSLClient::SSLClient(const std::string &host, int port,
|
||||
const std::string &client_cert_path,
|
||||
const std::string &client_key_path,
|
||||
const std::string &private_key_password)
|
||||
: ClientImpl(host, port, client_cert_path, client_key_path) {
|
||||
ctx_ = tls::create_client_context();
|
||||
init_ctx();
|
||||
if (!ctx_) { return; }
|
||||
|
||||
tls::set_min_version(ctx_, tls::Version::TLS1_2);
|
||||
|
||||
if (!client_cert_path.empty() && !client_key_path.empty()) {
|
||||
const char *password =
|
||||
private_key_password.empty() ? nullptr : private_key_password.c_str();
|
||||
if (!tls::set_client_cert_file(ctx_, client_cert_path.c_str(),
|
||||
client_key_path.c_str(), password)) {
|
||||
last_backend_error_ = tls::get_error();
|
||||
tls::free_context(ctx_);
|
||||
ctx_ = nullptr;
|
||||
reset_ctx_on_error();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -11918,17 +11869,13 @@ SSLClient::SSLClient(const std::string &host, int port,
|
|||
SSLClient::SSLClient(const std::string &host, int port,
|
||||
const PemMemory &pem)
|
||||
: ClientImpl(host, port) {
|
||||
ctx_ = tls::create_client_context();
|
||||
init_ctx();
|
||||
if (!ctx_) { return; }
|
||||
|
||||
tls::set_min_version(ctx_, tls::Version::TLS1_2);
|
||||
|
||||
if (pem.cert_pem && pem.key_pem) {
|
||||
if (!tls::set_client_cert_pem(ctx_, pem.cert_pem, pem.key_pem,
|
||||
pem.private_key_password)) {
|
||||
last_backend_error_ = tls::get_error();
|
||||
tls::free_context(ctx_);
|
||||
ctx_ = nullptr;
|
||||
reset_ctx_on_error();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -12479,41 +12426,6 @@ std::string Request::sni() const {
|
|||
* Group 8: TLS abstraction layer - OpenSSL backend
|
||||
*/
|
||||
|
||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
// These wrappers forward to deprecated APIs that will be removed by v1.0.0.
|
||||
// Suppress C4996 / -Wdeprecated-declarations so that MSVC /sdl builds (which
|
||||
// promote C4996 to an error) compile cleanly even though the wrappers
|
||||
// themselves are also marked [[deprecated]].
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4996)
|
||||
#elif defined(__GNUC__) || defined(__clang__)
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
#endif
|
||||
|
||||
SSL_CTX *Client::ssl_context() const {
|
||||
if (is_ssl_) { return static_cast<SSLClient &>(*cli_).ssl_context(); }
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Client::set_server_certificate_verifier(
|
||||
std::function<SSLVerifierResponse(SSL *ssl)> verifier) {
|
||||
cli_->set_server_certificate_verifier(verifier);
|
||||
}
|
||||
|
||||
long Client::get_verify_result() const {
|
||||
if (is_ssl_) { return static_cast<SSLClient &>(*cli_).get_verify_result(); }
|
||||
return -1; // NOTE: -1 doesn't match any of X509_V_ERR_???
|
||||
}
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(pop)
|
||||
#elif defined(__GNUC__) || defined(__clang__)
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
#endif // CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
|
||||
/*
|
||||
* OpenSSL Backend Implementation
|
||||
*/
|
||||
|
|
@ -12523,54 +12435,6 @@ namespace tls {
|
|||
|
||||
namespace impl {
|
||||
|
||||
// OpenSSL-specific helpers for converting native types to PEM
|
||||
std::string x509_to_pem(X509 *cert) {
|
||||
if (!cert) return {};
|
||||
BIO *bio = BIO_new(BIO_s_mem());
|
||||
if (!bio) return {};
|
||||
if (PEM_write_bio_X509(bio, cert) != 1) {
|
||||
BIO_free(bio);
|
||||
return {};
|
||||
}
|
||||
char *data = nullptr;
|
||||
long len = BIO_get_mem_data(bio, &data);
|
||||
std::string pem(data, static_cast<size_t>(len));
|
||||
BIO_free(bio);
|
||||
return pem;
|
||||
}
|
||||
|
||||
std::string evp_pkey_to_pem(EVP_PKEY *key) {
|
||||
if (!key) return {};
|
||||
BIO *bio = BIO_new(BIO_s_mem());
|
||||
if (!bio) return {};
|
||||
if (PEM_write_bio_PrivateKey(bio, key, nullptr, nullptr, 0, nullptr,
|
||||
nullptr) != 1) {
|
||||
BIO_free(bio);
|
||||
return {};
|
||||
}
|
||||
char *data = nullptr;
|
||||
long len = BIO_get_mem_data(bio, &data);
|
||||
std::string pem(data, static_cast<size_t>(len));
|
||||
BIO_free(bio);
|
||||
return pem;
|
||||
}
|
||||
|
||||
std::string x509_store_to_pem(X509_STORE *store) {
|
||||
if (!store) return {};
|
||||
std::string pem;
|
||||
auto objs = X509_STORE_get0_objects(store);
|
||||
if (!objs) return {};
|
||||
auto count = sk_X509_OBJECT_num(objs);
|
||||
for (decltype(count) i = 0; i < count; i++) {
|
||||
auto obj = sk_X509_OBJECT_value(objs, i);
|
||||
if (X509_OBJECT_get_type(obj) == X509_LU_X509) {
|
||||
auto cert = X509_OBJECT_get0_X509(obj);
|
||||
if (cert) { pem += x509_to_pem(cert); }
|
||||
}
|
||||
}
|
||||
return pem;
|
||||
}
|
||||
|
||||
// Helper to map OpenSSL SSL_get_error to ErrorCode
|
||||
ErrorCode map_ssl_error(int ssl_error, int &out_errno) {
|
||||
switch (ssl_error) {
|
||||
|
|
@ -12603,8 +12467,10 @@ STACK_OF(X509_NAME) *
|
|||
X509 *cert = nullptr;
|
||||
while ((cert = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr)) !=
|
||||
nullptr) {
|
||||
X509_NAME *name = X509_get_subject_name(cert);
|
||||
if (name) { sk_X509_NAME_push(ca_list, X509_NAME_dup(name)); }
|
||||
const X509_NAME *name = X509_get_subject_name(cert);
|
||||
if (name) {
|
||||
sk_X509_NAME_push(ca_list, X509_NAME_dup(const_cast<X509_NAME *>(name)));
|
||||
}
|
||||
X509_free(cert);
|
||||
}
|
||||
BIO_free(bio);
|
||||
|
|
@ -12612,45 +12478,6 @@ STACK_OF(X509_NAME) *
|
|||
return ca_list;
|
||||
}
|
||||
|
||||
// Helper: Extract CA names from X509_STORE
|
||||
// Returns a new STACK_OF(X509_NAME)* or nullptr on failure
|
||||
// Caller takes ownership of returned list
|
||||
STACK_OF(X509_NAME) *
|
||||
extract_client_ca_list_from_store(X509_STORE *store) {
|
||||
if (!store) { return nullptr; }
|
||||
|
||||
auto ca_list = sk_X509_NAME_new_null();
|
||||
if (!ca_list) { return nullptr; }
|
||||
|
||||
auto objs = X509_STORE_get0_objects(store);
|
||||
if (!objs) {
|
||||
sk_X509_NAME_free(ca_list);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto count = sk_X509_OBJECT_num(objs);
|
||||
for (decltype(count) i = 0; i < count; i++) {
|
||||
auto obj = sk_X509_OBJECT_value(objs, i);
|
||||
if (X509_OBJECT_get_type(obj) == X509_LU_X509) {
|
||||
auto cert = X509_OBJECT_get0_X509(obj);
|
||||
if (cert) {
|
||||
auto subject = X509_get_subject_name(cert);
|
||||
if (subject) {
|
||||
auto name_dup = X509_NAME_dup(subject);
|
||||
if (name_dup) { sk_X509_NAME_push(ca_list, name_dup); }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (sk_X509_NAME_num(ca_list) == 0) {
|
||||
sk_X509_NAME_free(ca_list);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return ca_list;
|
||||
}
|
||||
|
||||
// OpenSSL verify callback wrapper
|
||||
int openssl_verify_callback(int preverify_ok, X509_STORE_CTX *ctx) {
|
||||
auto &callback = get_verify_callback();
|
||||
|
|
@ -13086,6 +12913,9 @@ ssize_t read(session_t session, void *buf, size_t len, TlsError &err) {
|
|||
|
||||
auto ssl_err = SSL_get_error(ssl, ret);
|
||||
err.code = impl::map_ssl_error(ssl_err, err.sys_errno);
|
||||
if (err.code == ErrorCode::PeerClosed) {
|
||||
return 0;
|
||||
} // Gracefully handle the peer closed state.
|
||||
if (err.code == ErrorCode::Fatal) { err.backend_code = ERR_get_error(); }
|
||||
return -1;
|
||||
}
|
||||
|
|
@ -13523,164 +13353,8 @@ std::string verify_error_string(long error_code) {
|
|||
return str ? str : "unknown error";
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
|
||||
// OpenSSL-specific helpers for public API wrappers
|
||||
ctx_t create_server_context_from_x509(X509 *cert, EVP_PKEY *key,
|
||||
X509_STORE *client_ca_store,
|
||||
int &out_error) {
|
||||
out_error = 0;
|
||||
auto cert_pem = x509_to_pem(cert);
|
||||
auto key_pem = evp_pkey_to_pem(key);
|
||||
if (cert_pem.empty() || key_pem.empty()) {
|
||||
out_error = static_cast<int>(ERR_get_error());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto ctx = create_server_context();
|
||||
if (!ctx) {
|
||||
out_error = static_cast<int>(get_error());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!set_server_cert_pem(ctx, cert_pem.c_str(), key_pem.c_str(), nullptr)) {
|
||||
out_error = static_cast<int>(get_error());
|
||||
free_context(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (client_ca_store) {
|
||||
// Set cert store for verification (SSL_CTX_set_cert_store takes ownership)
|
||||
SSL_CTX_set_cert_store(static_cast<SSL_CTX *>(ctx), client_ca_store);
|
||||
|
||||
// Extract and set client CA list directly from store (more efficient than
|
||||
// PEM conversion)
|
||||
auto ca_list = extract_client_ca_list_from_store(client_ca_store);
|
||||
if (ca_list) {
|
||||
SSL_CTX_set_client_CA_list(static_cast<SSL_CTX *>(ctx), ca_list);
|
||||
}
|
||||
|
||||
set_verify_client(ctx, true);
|
||||
}
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
void update_server_certs_from_x509(ctx_t ctx, X509 *cert, EVP_PKEY *key,
|
||||
X509_STORE *client_ca_store) {
|
||||
auto cert_pem = x509_to_pem(cert);
|
||||
auto key_pem = evp_pkey_to_pem(key);
|
||||
|
||||
if (!cert_pem.empty() && !key_pem.empty()) {
|
||||
update_server_cert(ctx, cert_pem.c_str(), key_pem.c_str(), nullptr);
|
||||
}
|
||||
|
||||
if (client_ca_store) {
|
||||
auto ca_pem = x509_store_to_pem(client_ca_store);
|
||||
if (!ca_pem.empty()) { update_server_client_ca(ctx, ca_pem.c_str()); }
|
||||
X509_STORE_free(client_ca_store);
|
||||
}
|
||||
}
|
||||
|
||||
ctx_t create_client_context_from_x509(X509 *cert, EVP_PKEY *key,
|
||||
const char *password,
|
||||
uint64_t &out_error) {
|
||||
out_error = 0;
|
||||
auto ctx = create_client_context();
|
||||
if (!ctx) {
|
||||
out_error = get_error();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (cert && key) {
|
||||
auto cert_pem = x509_to_pem(cert);
|
||||
auto key_pem = evp_pkey_to_pem(key);
|
||||
if (cert_pem.empty() || key_pem.empty()) {
|
||||
out_error = ERR_get_error();
|
||||
free_context(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
if (!set_client_cert_pem(ctx, cert_pem.c_str(), key_pem.c_str(),
|
||||
password)) {
|
||||
out_error = get_error();
|
||||
free_context(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
|
||||
} // namespace tls
|
||||
|
||||
// ClientImpl::set_ca_cert_store - defined here to use
|
||||
// tls::impl::x509_store_to_pem Deprecated: converts X509_STORE to PEM and
|
||||
// stores for redirect transfer
|
||||
void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) {
|
||||
if (ca_cert_store) {
|
||||
ca_cert_pem_ = tls::impl::x509_store_to_pem(ca_cert_store);
|
||||
}
|
||||
}
|
||||
|
||||
SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key,
|
||||
X509_STORE *client_ca_cert_store) {
|
||||
ctx_ = tls::impl::create_server_context_from_x509(
|
||||
cert, private_key, client_ca_cert_store, last_ssl_error_);
|
||||
}
|
||||
|
||||
SSLServer::SSLServer(
|
||||
const std::function<bool(SSL_CTX &ssl_ctx)> &setup_ssl_ctx_callback) {
|
||||
// Use abstract API to create context
|
||||
ctx_ = tls::create_server_context();
|
||||
if (ctx_) {
|
||||
// Pass to OpenSSL-specific callback (ctx_ is SSL_CTX* internally)
|
||||
auto ssl_ctx = static_cast<SSL_CTX *>(ctx_);
|
||||
if (!setup_ssl_ctx_callback(*ssl_ctx)) {
|
||||
tls::free_context(ctx_);
|
||||
ctx_ = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SSL_CTX *SSLServer::ssl_context() const {
|
||||
return static_cast<SSL_CTX *>(ctx_);
|
||||
}
|
||||
|
||||
void SSLServer::update_certs(X509 *cert, EVP_PKEY *private_key,
|
||||
X509_STORE *client_ca_cert_store) {
|
||||
std::lock_guard<std::mutex> guard(ctx_mutex_);
|
||||
tls::impl::update_server_certs_from_x509(ctx_, cert, private_key,
|
||||
client_ca_cert_store);
|
||||
}
|
||||
|
||||
SSLClient::SSLClient(const std::string &host, int port,
|
||||
X509 *client_cert, EVP_PKEY *client_key,
|
||||
const std::string &private_key_password)
|
||||
: ClientImpl(host, port) {
|
||||
const char *password =
|
||||
private_key_password.empty() ? nullptr : private_key_password.c_str();
|
||||
ctx_ = tls::impl::create_client_context_from_x509(
|
||||
client_cert, client_key, password, last_backend_error_);
|
||||
}
|
||||
|
||||
long SSLClient::get_verify_result() const { return verify_result_; }
|
||||
|
||||
// Installs an SSL*-based verifier by adapting it to session_verifier_.
//
// `verifier` is moved into a std::function held by a shared_ptr, and that
// pointer is captured by value in the lambda assigned to session_verifier_.
// On invocation the lambda casts the tls::session_t argument to SSL* and
// forwards it to the wrapped verifier, returning its result.
// NOTE(review): an empty `verifier` is wrapped like any other, so invoking
// the installed lambda would throw std::bad_function_call; confirm callers
// never pass an empty function to "clear" the verifier.
void SSLClient::set_server_certificate_verifier(
|
||||
std::function<SSLVerifierResponse(SSL *ssl)> verifier) {
|
||||
// Wrap SSL* callback into backend-independent session_verifier_
|
||||
auto v = std::make_shared<std::function<SSLVerifierResponse(SSL *)>>(
|
||||
std::move(verifier));
|
||||
session_verifier_ = [v](tls::session_t session) {
|
||||
return (*v)(static_cast<SSL *>(session));
|
||||
};
|
||||
}
|
||||
|
||||
// Returns the client's ctx_ handle cast to an OpenSSL SSL_CTX pointer.
// The result is null whenever ctx_ is null.
SSL_CTX *SSLClient::ssl_context() const {
|
||||
return static_cast<SSL_CTX *>(ctx_);
|
||||
}
|
||||
|
||||
bool SSLClient::verify_host(X509 *server_cert) const {
|
||||
/* Quote from RFC2818 section 3.1 "Server Identity"
|
||||
|
||||
|
|
@ -16194,7 +15868,11 @@ ReadResult WebSocket::read(std::string &msg) {
|
|||
payload.size(), true, !is_server_);
|
||||
continue;
|
||||
}
|
||||
case Opcode::Pong: continue;
|
||||
case Opcode::Pong: {
|
||||
std::lock_guard<std::mutex> lock(ping_mutex_);
|
||||
unacked_pings_ = 0;
|
||||
continue;
|
||||
}
|
||||
case Opcode::Close: {
|
||||
if (!closed_.exchange(true)) {
|
||||
// Echo close frame back
|
||||
|
|
@ -16228,7 +15906,11 @@ ReadResult WebSocket::read(std::string &msg) {
|
|||
true, !is_server_);
|
||||
continue;
|
||||
}
|
||||
if (cont_opcode == Opcode::Pong) { continue; }
|
||||
if (cont_opcode == Opcode::Pong) {
|
||||
std::lock_guard<std::mutex> lock(ping_mutex_);
|
||||
unacked_pings_ = 0;
|
||||
continue;
|
||||
}
|
||||
if (cont_opcode == Opcode::Close) {
|
||||
if (!closed_.exchange(true)) {
|
||||
std::lock_guard<std::mutex> lock(write_mutex_);
|
||||
|
|
@ -16316,12 +15998,22 @@ void WebSocket::start_heartbeat() {
|
|||
while (!closed_) {
|
||||
ping_cv_.wait_for(lock, std::chrono::seconds(ping_interval_sec_));
|
||||
if (closed_) { break; }
|
||||
// If the peer has failed to respond to the previous pings, give up.
|
||||
// RFC 6455 does not define a pong-timeout mechanism; this is an
|
||||
// opt-in liveness check controlled by max_missed_pongs_.
|
||||
if (max_missed_pongs_ > 0 && unacked_pings_ >= max_missed_pongs_) {
|
||||
lock.unlock();
|
||||
close(CloseStatus::GoingAway, "pong timeout");
|
||||
return;
|
||||
}
|
||||
lock.unlock();
|
||||
if (!send_frame(Opcode::Ping, nullptr, 0)) {
|
||||
lock.lock();
|
||||
closed_ = true;
|
||||
break;
|
||||
}
|
||||
lock.lock();
|
||||
unacked_pings_++;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
@ -16449,8 +16141,9 @@ bool WebSocketClient::connect() {
|
|||
Request req;
|
||||
req.method = "GET";
|
||||
req.path = path_;
|
||||
ws_ = std::unique_ptr<WebSocket>(
|
||||
new WebSocket(std::move(strm), req, false, websocket_ping_interval_sec_));
|
||||
ws_ = std::unique_ptr<WebSocket>(new WebSocket(std::move(strm), req, false,
|
||||
websocket_ping_interval_sec_,
|
||||
websocket_max_missed_pongs_));
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
@ -16494,6 +16187,10 @@ void WebSocketClient::set_websocket_ping_interval(time_t sec) {
|
|||
websocket_ping_interval_sec_ = sec;
|
||||
}
|
||||
|
||||
// Sets how many pings may go unanswered before the connection is given up.
//
// The value is only stored in websocket_max_missed_pongs_ here; connect()
// passes it to the WebSocket constructor. The heartbeat loop enforces the
// limit only when it is greater than zero, so 0 leaves the check disabled.
void WebSocketClient::set_websocket_max_missed_pongs(int count) {
|
||||
websocket_max_missed_pongs_ = count;
|
||||
}
|
||||
|
||||
// Records the requested TCP_NODELAY preference in tcp_nodelay_.
void WebSocketClient::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; }
|
||||
|
||||
void WebSocketClient::set_address_family(int family) {
|
||||
|
|
|
|||
139
vendor/cpp-httplib/httplib.h
vendored
139
vendor/cpp-httplib/httplib.h
vendored
|
|
@ -8,8 +8,8 @@
|
|||
#ifndef CPPHTTPLIB_HTTPLIB_H
|
||||
#define CPPHTTPLIB_HTTPLIB_H
|
||||
|
||||
#define CPPHTTPLIB_VERSION "0.42.0"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x002a00"
|
||||
#define CPPHTTPLIB_VERSION "0.43.1"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x002b01"
|
||||
|
||||
#ifdef _WIN32
|
||||
#if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00
|
||||
|
|
@ -205,6 +205,10 @@
|
|||
#define CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND 30
|
||||
#endif
|
||||
|
||||
#ifndef CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS
|
||||
#define CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS 0
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Headers
|
||||
*/
|
||||
|
|
@ -1720,6 +1724,8 @@ public:
|
|||
Server &set_websocket_ping_interval(
|
||||
const std::chrono::duration<Rep, Period> &duration);
|
||||
|
||||
Server &set_websocket_max_missed_pongs(int count);
|
||||
|
||||
bool bind_to_port(const std::string &host, int port, int socket_flags = 0);
|
||||
int bind_to_any_port(const std::string &host, int socket_flags = 0);
|
||||
bool listen_after_bind();
|
||||
|
|
@ -1756,6 +1762,7 @@ protected:
|
|||
size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH;
|
||||
time_t websocket_ping_interval_sec_ =
|
||||
CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND;
|
||||
int websocket_max_missed_pongs_ = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS;
|
||||
|
||||
private:
|
||||
using Handlers =
|
||||
|
|
@ -1767,6 +1774,14 @@ private:
|
|||
static std::unique_ptr<detail::MatcherBase>
|
||||
make_matcher(const std::string &pattern);
|
||||
|
||||
// Registers `handler` for `pattern` in the given handler list.
//
// A matcher is built from `pattern` with make_matcher and appended to
// `handlers` together with the handler, which is moved into place.
// Returns *this.
template <typename H>
|
||||
Server &add_handler(
|
||||
std::vector<std::pair<std::unique_ptr<detail::MatcherBase>, H>> &handlers,
|
||||
const std::string &pattern, H handler) {
|
||||
handlers.emplace_back(make_matcher(pattern), std::move(handler));
|
||||
return *this;
|
||||
}
|
||||
|
||||
Server &set_error_handler_core(HandlerWithResponse handler, std::true_type);
|
||||
Server &set_error_handler_core(Handler handler, std::false_type);
|
||||
|
||||
|
|
@ -1928,15 +1943,6 @@ private:
|
|||
int ssl_error_ = 0;
|
||||
uint64_t ssl_backend_error_ = 0;
|
||||
#endif
|
||||
|
||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
public:
|
||||
[[deprecated("Use ssl_backend_error() instead. "
|
||||
"This function will be removed by v1.0.0.")]]
|
||||
uint64_t ssl_openssl_error() const {
|
||||
return ssl_backend_error_;
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
struct ClientConnection {
|
||||
|
|
@ -2409,22 +2415,6 @@ protected:
|
|||
int last_ssl_error_ = 0;
|
||||
uint64_t last_backend_error_ = 0;
|
||||
#endif
|
||||
|
||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
public:
|
||||
[[deprecated("Use load_ca_cert_store() instead. "
|
||||
"This function will be removed by v1.0.0.")]]
|
||||
void set_ca_cert_store(X509_STORE *ca_cert_store);
|
||||
|
||||
[[deprecated("Use tls::create_ca_store() instead. "
|
||||
"This function will be removed by v1.0.0.")]]
|
||||
X509_STORE *create_ca_cert_store(const char *ca_cert, std::size_t size) const;
|
||||
|
||||
[[deprecated("Use set_server_certificate_verifier(VerifyCallback) instead. "
|
||||
"This function will be removed by v1.0.0.")]]
|
||||
virtual void set_server_certificate_verifier(
|
||||
std::function<SSLVerifierResponse(SSL *ssl)> verifier);
|
||||
#endif
|
||||
};
|
||||
|
||||
class Client {
|
||||
|
|
@ -2599,7 +2589,6 @@ public:
|
|||
void set_follow_location(bool on);
|
||||
|
||||
void set_path_encode(bool on);
|
||||
void set_url_encode(bool on);
|
||||
|
||||
void set_compress(bool on);
|
||||
|
||||
|
|
@ -2647,22 +2636,6 @@ public:
|
|||
private:
|
||||
bool is_ssl_ = false;
|
||||
#endif
|
||||
|
||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
public:
|
||||
[[deprecated("Use tls_context() instead. "
|
||||
"This function will be removed by v1.0.0.")]]
|
||||
SSL_CTX *ssl_context() const;
|
||||
|
||||
[[deprecated("Use set_session_verifier(session_t) instead. "
|
||||
"This function will be removed by v1.0.0.")]]
|
||||
void set_server_certificate_verifier(
|
||||
std::function<SSLVerifierResponse(SSL *ssl)> verifier);
|
||||
|
||||
[[deprecated("Use Result::ssl_backend_error() instead. "
|
||||
"This function will be removed by v1.0.0.")]]
|
||||
long get_verify_result() const;
|
||||
#endif
|
||||
};
|
||||
|
||||
#ifdef CPPHTTPLIB_SSL_ENABLED
|
||||
|
|
@ -2708,29 +2681,6 @@ private:
|
|||
std::mutex ctx_mutex_;
|
||||
|
||||
int last_ssl_error_ = 0;
|
||||
|
||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
public:
|
||||
[[deprecated("Use SSLServer(PemMemory) or "
|
||||
"SSLServer(ContextSetupCallback) instead. "
|
||||
"This constructor will be removed by v1.0.0.")]]
|
||||
SSLServer(X509 *cert, EVP_PKEY *private_key,
|
||||
X509_STORE *client_ca_cert_store = nullptr);
|
||||
|
||||
[[deprecated("Use SSLServer(ContextSetupCallback) instead. "
|
||||
"This constructor will be removed by v1.0.0.")]]
|
||||
SSLServer(
|
||||
const std::function<bool(SSL_CTX &ssl_ctx)> &setup_ssl_ctx_callback);
|
||||
|
||||
[[deprecated("Use tls_context() instead. "
|
||||
"This function will be removed by v1.0.0.")]]
|
||||
SSL_CTX *ssl_context() const;
|
||||
|
||||
[[deprecated("Use update_certs_pem() instead. "
|
||||
"This function will be removed by v1.0.0.")]]
|
||||
void update_certs(X509 *cert, EVP_PKEY *private_key,
|
||||
X509_STORE *client_ca_cert_store = nullptr);
|
||||
#endif
|
||||
};
|
||||
|
||||
class SSLClient final : public ClientImpl {
|
||||
|
|
@ -2794,6 +2744,9 @@ private:
|
|||
Response &res, bool &success, Error &error);
|
||||
bool initialize_ssl(Socket &socket, Error &error);
|
||||
|
||||
void init_ctx();
|
||||
void reset_ctx_on_error();
|
||||
|
||||
bool load_certs();
|
||||
|
||||
tls::ctx_t ctx_ = nullptr;
|
||||
|
|
@ -2811,42 +2764,6 @@ private:
|
|||
friend class ClientImpl;
|
||||
|
||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
public:
|
||||
[[deprecated("Use SSLClient(host, port, PemMemory) instead. "
|
||||
"This constructor will be removed by v1.0.0.")]]
|
||||
explicit SSLClient(const std::string &host, int port, X509 *client_cert,
|
||||
EVP_PKEY *client_key,
|
||||
const std::string &private_key_password = std::string());
|
||||
|
||||
[[deprecated("Use Result::ssl_backend_error() instead. "
|
||||
"This function will be removed by v1.0.0.")]]
|
||||
long get_verify_result() const;
|
||||
|
||||
[[deprecated("Use tls_context() instead. "
|
||||
"This function will be removed by v1.0.0.")]]
|
||||
SSL_CTX *ssl_context() const;
|
||||
|
||||
// Override of a deprecated virtual in ClientImpl. Suppress C4996 /
|
||||
// -Wdeprecated-declarations on the override declaration itself so that
|
||||
// MSVC /sdl builds compile cleanly. Will be removed together with the
|
||||
// base virtual by v1.0.0.
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4996)
|
||||
#elif defined(__GNUC__) || defined(__clang__)
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
#endif
|
||||
[[deprecated("Use set_session_verifier(session_t) instead. "
|
||||
"This function will be removed by v1.0.0.")]]
|
||||
void set_server_certificate_verifier(
|
||||
std::function<SSLVerifierResponse(SSL *ssl)> verifier) override;
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(pop)
|
||||
#elif defined(__GNUC__) || defined(__clang__)
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
private:
|
||||
bool verify_host(X509 *server_cert) const;
|
||||
bool verify_host_with_subject_alt_name(X509 *server_cert) const;
|
||||
|
|
@ -3818,17 +3735,21 @@ private:
|
|||
|
||||
WebSocket(
|
||||
Stream &strm, const Request &req, bool is_server,
|
||||
time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND)
|
||||
time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND,
|
||||
int max_missed_pongs = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS)
|
||||
: strm_(strm), req_(req), is_server_(is_server),
|
||||
ping_interval_sec_(ping_interval_sec) {
|
||||
ping_interval_sec_(ping_interval_sec),
|
||||
max_missed_pongs_(max_missed_pongs) {
|
||||
start_heartbeat();
|
||||
}
|
||||
|
||||
WebSocket(
|
||||
std::unique_ptr<Stream> &&owned_strm, const Request &req, bool is_server,
|
||||
time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND)
|
||||
time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND,
|
||||
int max_missed_pongs = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS)
|
||||
: strm_(*owned_strm), owned_strm_(std::move(owned_strm)), req_(req),
|
||||
is_server_(is_server), ping_interval_sec_(ping_interval_sec) {
|
||||
is_server_(is_server), ping_interval_sec_(ping_interval_sec),
|
||||
max_missed_pongs_(max_missed_pongs) {
|
||||
start_heartbeat();
|
||||
}
|
||||
|
||||
|
|
@ -3840,6 +3761,8 @@ private:
|
|||
Request req_;
|
||||
bool is_server_;
|
||||
time_t ping_interval_sec_;
|
||||
int max_missed_pongs_;
|
||||
int unacked_pings_ = 0;
|
||||
std::atomic<bool> closed_{false};
|
||||
std::mutex write_mutex_;
|
||||
std::thread ping_thread_;
|
||||
|
|
@ -3869,6 +3792,7 @@ public:
|
|||
void set_read_timeout(time_t sec, time_t usec = 0);
|
||||
void set_write_timeout(time_t sec, time_t usec = 0);
|
||||
void set_websocket_ping_interval(time_t sec);
|
||||
void set_websocket_max_missed_pongs(int count);
|
||||
void set_tcp_nodelay(bool on);
|
||||
void set_address_family(int family);
|
||||
void set_ipv6_v6only(bool on);
|
||||
|
|
@ -3900,6 +3824,7 @@ private:
|
|||
time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND;
|
||||
time_t websocket_ping_interval_sec_ =
|
||||
CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND;
|
||||
int websocket_max_missed_pongs_ = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS;
|
||||
int address_family_ = AF_UNSPEC;
|
||||
bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY;
|
||||
bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue