Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	docs/development/HOWTO-add-model.md
#	docs/multimodal.md
#	ggml/src/ggml-sycl/convert.cpp
#	ggml/src/ggml-sycl/dequantize.hpp
#	ggml/src/ggml-sycl/element_wise.cpp
#	ggml/src/ggml-sycl/gated_delta_net.cpp
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	ggml/src/ggml-sycl/upscale.cpp
#	ggml/src/ggml-webgpu/ggml-webgpu.cpp
#	tests/test-backend-ops.cpp
#	tests/test-llama-archs.cpp
#	tools/mtmd/CMakeLists.txt
This commit is contained in:
Concedo 2026-04-14 20:06:04 +08:00
commit 9c0b9b0bb1
53 changed files with 3214 additions and 720 deletions

View file

@ -261,6 +261,9 @@ static bool common_pull_file(httplib::Client & cli,
if (progress_step >= p.total / 1000 || p.downloaded == p.total) {
if (callback) {
callback->on_update(p);
if (callback->is_cancelled()) {
return false;
}
}
progress_step = 0;
}
@ -376,6 +379,9 @@ static int common_download_file_single_online(const std::string & url,
}
for (int i = 0; i < max_attempts; ++i) {
if (opts.callback && opts.callback->is_cancelled()) {
break;
}
if (i) {
LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
std::this_thread::sleep_for(std::chrono::seconds(delay));
@ -415,6 +421,12 @@ static int common_download_file_single_online(const std::string & url,
if (opts.callback) {
opts.callback->on_done(p, success);
}
if (opts.callback && opts.callback->is_cancelled() &&
std::filesystem::exists(path_temporary)) {
if (remove(path_temporary.c_str()) != 0) {
LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, path_temporary.c_str());
}
}
if (!success) {
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
return -1; // max attempts reached

View file

@ -21,6 +21,7 @@ public:
virtual void on_start(const common_download_progress & p) = 0;
virtual void on_update(const common_download_progress & p) = 0;
virtual void on_done(const common_download_progress & p, bool ok) = 0;
virtual bool is_cancelled() const { return false; }
};
struct common_remote_params {

View file

@ -4258,9 +4258,7 @@ class Qwen2VLVisionModel(MmprojModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Qwen2_5OmniModel")
class Qwen25OmniModel(Qwen2VLVisionModel):
has_vision_encoder = True
class Qwen25AudioModel(MmprojModel):
has_audio_encoder = True
def __init__(self, *args, **kwargs):
@ -4276,12 +4274,6 @@ class Qwen25OmniModel(Qwen2VLVisionModel):
self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["num_mel_bins"])
self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-5))
def get_vision_config(self) -> dict[str, Any] | None:
return self.global_config["thinker_config"].get("vision_config")
def get_audio_config(self) -> dict[str, Any] | None:
return self.global_config["thinker_config"].get("audio_config")
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
# SinusoidsPositionEmbedding
assert self.hparams_audio is not None
@ -4312,7 +4304,32 @@ class Qwen25OmniModel(Qwen2VLVisionModel):
# this tensor is left unused in transformers code
# https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py#L1809
return
yield from super().modify_tensors(data_torch, name, bid)
yield from MmprojModel.modify_tensors(self, data_torch, name, bid)
return # skip other tensors
@ModelBase.register("Qwen2_5OmniModel")
class Qwen25OmniModel(Qwen2VLVisionModel, Qwen25AudioModel):
has_audio_encoder = True
has_vision_encoder = True
def get_vision_config(self) -> dict[str, Any] | None:
return self.global_config["thinker_config"].get("vision_config")
def get_audio_config(self) -> dict[str, Any] | None:
return self.global_config["thinker_config"].get("audio_config")
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25O)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if "visual." in name:
yield from Qwen2VLVisionModel.modify_tensors(self, data_torch, name, bid)
elif "audio_tower." in name:
yield from Qwen25AudioModel.modify_tensors(self, data_torch, name, bid)
return # skip other tensors
@ModelBase.register("InternVisionModel")
@ -4816,7 +4833,10 @@ class RND1Model(Qwen2MoeModel):
class Qwen3VLVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
if self.hparams_vision is None:
logger.info("No vision config found, skipping vision tensor processing")
return
# Compute image_size if not present
if "image_size" not in self.hparams_vision:
# For Qwen3VL/Qwen3VLMoe, compute from num_position_embeddings
@ -4837,7 +4857,9 @@ class Qwen3VLVisionModel(MmprojModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
# in case mixed modalities, the arch will be handled by subclass
if not self.has_audio_encoder:
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
self.gguf_writer.add_vision_use_gelu(True)
if self.hparams_vision is not None:
@ -4925,11 +4947,64 @@ class Qwen3VLVisionModel(MmprojModel):
return
if name.startswith("visual."):
yield from super().modify_tensors(data_torch, name, bid)
return
yield from MmprojModel.modify_tensors(self, data_torch, name, bid)
return # skip other tensors
# Fall back to parent class for other tensors
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Qwen3OmniMoeForConditionalGeneration")
class Qwen3OmniMmprojModel(Qwen3VLVisionModel, Qwen25AudioModel):
has_audio_encoder = True
has_vision_encoder = True
def get_vision_config(self) -> dict[str, Any] | None:
if self.has_vision_encoder:
return self.global_config["thinker_config"].get("vision_config")
else:
return None
def get_audio_config(self) -> dict[str, Any] | None:
if self.has_audio_encoder:
return self.global_config["thinker_config"].get("audio_config")
else:
return None
def set_gguf_parameters(self):
if self.has_vision_encoder:
Qwen3VLVisionModel.set_gguf_parameters(self)
self.gguf_writer.add_clip_vision_projector_type(gguf.VisionProjectorType.QWEN3VL)
if self.has_audio_encoder:
Qwen25AudioModel.set_gguf_parameters(self)
self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.QWEN3A)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if "visual." in name:
if not self.has_vision_encoder:
raise ValueError(f"Model does not have vision encoder, but found tensor {name}")
# need to transform vision tensor naming, so that modify_tensors() logic can be used correctly
name = name.replace("thinker.visual.", "model.visual.")
if ".merger_list." in name:
name = name.replace(".merger_list.", ".deepstack_merger_list.")
name = name.replace(".ln_q", ".norm")
name = name.replace(".mlp.0", ".linear_fc1")
name = name.replace(".mlp.2", ".linear_fc2")
elif ".merger." in name:
name = name.replace(".ln_q", ".norm")
name = name.replace(".mlp.0", ".linear_fc1")
name = name.replace(".mlp.2", ".linear_fc2")
yield from Qwen3VLVisionModel.modify_tensors(self, data_torch, name, bid)
elif "audio_tower." in name:
if not self.has_audio_encoder:
raise ValueError(f"Model does not have audio encoder, but found tensor {name}")
if "conv2d" in name and name.endswith(".bias"):
# transform conv2d bias [n_embd] --> [1, 1, n_embd]
data_torch = data_torch.unsqueeze(-1).unsqueeze(-1)
yield from Qwen25AudioModel.modify_tensors(self, data_torch, name, bid)
@ModelBase.register("Qwen3ASRForConditionalGeneration")
class Qwen3ASRMmprojModel(Qwen3OmniMmprojModel):
has_audio_encoder = True
has_vision_encoder = False
@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration", "GlmOcrForConditionalGeneration")
@ -4992,6 +5067,8 @@ class Step3VLVisionModel(MmprojModel):
def tensor_force_quant(self, name, new_name, bid, n_dims):
if ".position_embd." in new_name:
return gguf.GGMLQuantizationType.F32
if ("mm.0." in new_name or "mm.1." in new_name) and new_name.endswith(".weight"):
return gguf.GGMLQuantizationType.F16 if self.ftype == gguf.LlamaFileType.MOSTLY_F16 else gguf.GGMLQuantizationType.F32
return super().tensor_force_quant(name, new_name, bid, n_dims)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
@ -5030,9 +5107,10 @@ class Qwen3VLTextModel(Qwen3Model):
def set_gguf_parameters(self):
super().set_gguf_parameters()
# Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
vision_config = self.hparams.get("vision_config", {})
if "thinker_config" in self.hparams:
vision_config = self.hparams["thinker_config"].get("vision_config", {})
else:
vision_config = self.hparams.get("vision_config", {})
deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
@ -5101,6 +5179,70 @@ class Qwen3VLMoeTextModel(Qwen3MoeModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Qwen3OmniMoeForConditionalGeneration")
class Qwen3OmniMoeTextModel(Qwen3VLMoeTextModel):
model_arch = gguf.MODEL_ARCH.QWEN3VLMOE
def set_vocab(self):
super().set_vocab()
# correct BOS/EOS tokens
with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f:
tokenizer_config = json.load(f)
added_tokens = tokenizer_config.get("added_tokens_decoder", {})
for token_id, data in added_tokens.items():
if data.get("content") == "<|im_end|>":
self.gguf_writer.add_bos_token_id(int(token_id))
self.gguf_writer.add_eos_token_id(int(token_id))
break
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_num_deepstack_layers(0)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Skip vision and audio tensors - they go in the mmproj file
if "visual." in name or "audio_tower." in name \
or "talker." in name or "code2wav." in name:
return
name = name.replace("thinker.", "")
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Qwen3ASRForConditionalGeneration")
class Qwen3ASRTextModel(Qwen3VLTextModel):
model_arch = gguf.MODEL_ARCH.QWEN3VL
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_num_deepstack_layers(0)
def set_vocab(self):
super().set_vocab()
# fix chat template, use correct chatml format
self.gguf_writer.add_chat_template("{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}")
# correct BOS/EOS tokens
with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f:
tokenizer_config = json.load(f)
added_tokens = tokenizer_config.get("added_tokens_decoder", {})
for token_id, data in added_tokens.items():
if data.get("content") == "<|im_end|>":
self.gguf_writer.add_bos_token_id(int(token_id))
self.gguf_writer.add_eos_token_id(int(token_id))
break
def modify_tensors(self, data_torch, name, bid):
# qwen3-omni
name = name.replace("thinker.", "")
# Skip vision and audio tensors - they go in the mmproj file
if "visual." in name or "audio_tower." in name \
or "talker." in name or "code2wav." in name:
return
yield from super().modify_tensors(data_torch, name, bid)
class _LinearAttentionVReorderBase(Qwen3NextModel):
model_arch = gguf.MODEL_ARCH.QWEN3NEXT # overridden by subclasses
"""reorders V heads from grouped to tiled order for ggml broadcast

View file

@ -58,26 +58,48 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
size_t temp_storage_bytes = 0;
bool is_capturing = false;
#ifdef USE_CUDA_GRAPH
// Currently (confirmed for CCCL <= 3.2) DeviceSegmentedSort does not support stream capture, while DeviceSegmentedRadixSort does.
// See https://github.com/NVIDIA/cccl/issues/5661#issuecomment-3229037149
// TODO: constrain this to the CCCL versions that have this issue once it's resolved in a future CCCL release.
cudaStreamCaptureStatus capture_status;
CUDA_CHECK(cudaStreamIsCapturing(stream, &capture_status));
is_capturing = (capture_status != cudaStreamCaptureStatusNone);
#endif // USE_CUDA_GRAPH
if (order == GGML_SORT_ORDER_ASC) {
if (nrows == 1) {
CUDA_CHECK(DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream));
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream));
} else if (is_capturing) {
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs(
nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream));
} else {
CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
offset_iterator, offset_iterator + 1, stream));
CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys,
temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
offset_iterator, offset_iterator + 1, stream));
}
} else {
if (nrows == 1) {
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream));
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys,
temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream));
} else if (is_capturing) {
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending(
nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows,
offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream));
} else {
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
stream));
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows,
offset_iterator, offset_iterator + 1, stream));
}
}
@ -86,22 +108,33 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
if (order == GGML_SORT_ORDER_ASC) {
if (nrows == 1) {
CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream));
CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys,
temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream));
} else if (is_capturing) {
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
offset_iterator + 1, 0, sizeof(float) * 8, stream));
} else {
CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream));
CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
offset_iterator + 1, stream));
}
} else {
if (nrows == 1) {
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream));
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys,
temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream));
} else if (is_capturing) {
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending(
d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows,
offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream));
} else {
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
offset_iterator + 1, stream));
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys,
temp_keys, temp_indices, dst, ncols * nrows, nrows,
offset_iterator, offset_iterator + 1, stream));
}
}
}

View file

@ -134,8 +134,9 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
switch (nc) {
case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
case 5: launch_kernel(std::integral_constant<int, 5>{}); break;
case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now.");
default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9 right now.");
}
}

View file

@ -2874,11 +2874,10 @@ struct vk_fa_tuning_params {
}
};
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type);
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
GGML_UNUSED(kv_type);
vk_fa_tuning_params result{};
result.path = FA_SCALAR;
@ -2930,7 +2929,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) {
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) {
result.block_rows /= 2;
}
@ -3461,21 +3460,47 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (device->fp16) {
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product && device->subgroup_clustered) {
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8)
} else
#endif
{
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
}
} else {
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product && device->subgroup_clustered) {
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8)
} else
#endif
{
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
}
}
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat1_fa_support) {
@ -8818,7 +8843,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
}
}
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
GGML_UNUSED(f32acc);
// Needs to be kept up to date on shader changes
const uint32_t wg_size = params.workgroup_size;
@ -8827,21 +8852,51 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
const bool mmq = device->integer_dot_product && device->subgroup_clustered &&
(kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 ||
kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 ||
kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL);
// tmpsh is overestimated slightly
const uint32_t tmpsh = wg_size * sizeof(float);
const uint32_t tmpshv4 = wg_size * 4 * float_type_size;
const uint32_t masksh = Bc * (Br + 1) * float_type_size;
const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
uint32_t Qf, kvsh, kblocksh_size;
if (mmq) {
// block_b_cache: int32_t qs[8] + FLOAT_TYPEV2 ds
const uint32_t block_b_size = 8 * sizeof(int32_t) + 2 * float_type_size;
Qf = Br * (hsk / 32) * block_b_size;
const uint32_t D = std::max(hsk, hsv);
const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
// kvsh uses D = HSV (K goes through kblocksh instead)
kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
// block_a_cache size depends on quant type
uint32_t block_a_size;
switch (kv_type) {
case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break;
case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break;
case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break;
case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break;
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break;
default: block_a_size = 0; break;
}
kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size;
} else {
Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
const uint32_t D = std::max(hsk, hsv);
kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
kblocksh_size = 0;
}
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh + kblocksh_size;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", mmq=" << mmq << ", total_size=" << total_size << ", supported=" << supported);
return supported;
}

View file

@ -10,6 +10,13 @@
#extension GL_EXT_shader_subgroup_extended_types_float16 : require
#endif
#ifdef MMQ
#extension GL_EXT_integer_dot_product : require
#extension GL_KHR_shader_subgroup_clustered : require
#include "mul_mmq_shmem_types.glsl"
#endif
#extension GL_KHR_shader_subgroup_shuffle : enable
#extension GL_KHR_shader_subgroup_vote : enable
@ -41,15 +48,34 @@ shared FLOAT_TYPEV4 tmpshv4[tmpsh_size];
const uint32_t masksh_stride = Br + 1;
shared FLOAT_TYPE masksh[Bc * masksh_stride];
#ifndef MMQ
const uint32_t qf_stride = HSK / 4 + 1;
shared FLOAT_TYPEV4 Qf[Br * qf_stride];
#else
const uint32_t qf_stride = HSK / 32;
shared block_b_cache Qf[Br * qf_stride];
#endif
#ifndef MMQ
const uint32_t D = HSK > HSV ? HSK : HSV;
#else
const uint32_t D = HSV;
#endif
const uint32_t kvsh_stride = D / 4 + 1;
shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];
#ifdef MMQ
shared block_a_cache kblocksh[SHMEM_STAGING != 0 ? Bc * qf_stride : 1];
#endif
shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1];
#ifdef MMQ
#include "flash_attn_mmq_funcs.glsl"
#endif
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
@ -82,10 +108,39 @@ void main() {
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t r = (idx + tid) / (HSK / 4);
if (r < Br && d < HSK / 4 &&
i * Br + r < N) {
const bool is_in_bounds = r < Br && d < HSK / 4 && i * Br + r < N;
#ifndef MMQ
if (is_in_bounds) {
Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
}
#else
const uint buf_ib = r * qf_stride + d / 8;
const uint buf_iqs = d % 8;
FLOAT_TYPEV4 vals = is_in_bounds ? FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale) : FLOAT_TYPEV4(0.0f);
const FLOAT_TYPEV4 abs_vals = abs(vals);
const FLOAT_TYPE thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
const FLOAT_TYPE amax = subgroupClusteredMax(thread_max, 8);
const FLOAT_TYPE qd = amax / FLOAT_TYPE(127.0);
const FLOAT_TYPE qd_inv = qd != FLOAT_TYPE(0.0) ? FLOAT_TYPE(1.0) / qd : FLOAT_TYPE(0.0);
vals = round(vals * qd_inv);
Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals));
#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
if (buf_iqs == 0) {
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
}
#else // Q4_0, Q4_1, Q5_0, Q5_1
const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
if (buf_iqs == 0) {
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
}
#endif
#endif
}
barrier();
@ -195,6 +250,7 @@ void main() {
if (SHMEM_STAGING != 0) {
barrier();
#ifndef MMQ
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
@ -214,9 +270,29 @@ void main() {
kvsh[c * kvsh_stride + d] = K_Tf;
}
}
#else // MMQ
const uint ints_per_block = 8 / QUANT_R_MMQ;
const uint quant_iters = Bc * HSK / 32 * ints_per_block;
[[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) {
const uint32_t iqs = (idx + tid) % ints_per_block;
const uint32_t ib = (idx + tid) / ints_per_block;
const uint32_t c = ib / (HSK / 32);
const uint32_t block = ib % (HSK / 32);
if (idx + gl_WorkGroupSize.x <= quant_iters || c < Bc) {
const uint buf_ib = c * qf_stride + block;
if (!KV_bounds_check || j * Bc + c < KV) {
const uint global_ib = (j * Bc + c) * k_stride + block;
k_block_to_shmem(buf_ib, global_ib, iqs, k_offset);
} else {
k_block_to_shmem_zero(buf_ib, iqs);
}
}
}
#endif // MMQ
barrier();
}
#ifndef MMQ
// More d iterations means Q register caching becomes relevant
// Few iterations means the additional registers needed are worse than the speed-up from caching
if (HSK_per_thread / 4 > 4) {
@ -275,6 +351,110 @@ void main() {
}
}
}
#else // MMQ
const uint hsk4 = HSK_per_thread / 4;
const uint d_per_step = (hsk4 % 8 == 0) ? 8 :
(hsk4 % 4 == 0) ? 4 :
(hsk4 % 2 == 0) ? 2 : 1;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
[[unroll]] for (uint32_t d_block = 0; d_block < HSK_per_thread / 4; d_block += d_per_step) {
int32_t k_quants[d_per_step];
ACC_TYPEV2 k_dm;
if (SHMEM_STAGING != 0) {
const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8;
const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx;
#if QUANT_AUXF == 1
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0);
#else
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm);
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
if (d_per_step == 8) {
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
uint vui = kblocksh[buf_ib].qs[d];
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
#endif
}
} else
#endif
{
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d);
}
}
} else {
const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
const uint ib = coord / BLOCK_SIZE;
const uint iqs = (coord % BLOCK_SIZE);
#if QUANT_AUXF == 1
k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0);
#else
k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset));
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
if (d_per_step == 8) {
#if defined(DATA_A_Q5_0)
uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0],
k_packed.k_data_packed16[k_offset + ib].qh[1]));
#elif defined(DATA_A_Q5_1)
uint qh = k_packed.k_data_packed16[k_offset + ib].qh;
#endif
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
#if defined(A_TYPE_PACKED32)
uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d];
#else
uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0],
k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1]));
#endif
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
uint qh_lo = (qh >> (d * 4)) & 0xF;
uint qh_hi = (qh >> (d * 4 + 16)) & 0xF;
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
#endif
}
} else
#endif
{
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset);
}
}
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint qib = tile_row(r) * qf_stride + (d_tid * (HSK_per_thread / 4) + d_block) / 8;
const uint qiqs = (d_tid * (HSK_per_thread / 4) + d_block) % 8;
int32_t acc = 0;
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
acc += dotPacked4x8EXT(Qf[qib].qs[qiqs + d], k_quants[d]);
}
Sf[r][c] += ACC_TYPE(acc) * ACC_TYPE(Qf[qib].ds.x) * k_dm.x;
if ((d_tid * (HSK_per_thread / 4) + d_block) % 8 == 0) {
Sf[r][c] += k_dot_correction(qib, k_dm);
}
}
}
}
#endif // MMQ
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
// Compute sum across the D_split

View file

@ -89,6 +89,11 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32;
layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32;
#endif
#ifndef BLOCK_SIZE
#define BLOCK_SIZE 1
#endif

View file

@ -0,0 +1,149 @@
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
#ifdef DATA_A_Q4_0
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
#else
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
#endif
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
return int32_t(vui & 0x0F0F0F0F);
}
#endif
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
#ifdef DATA_A_Q5_0
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
uint qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qh[0],
k_packed.k_data_packed16[a_offset + ib].qh[1]));
#else
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
uint qh = k_packed.k_data_packed16[a_offset + ib].qh;
#endif
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
uint qh_bits = (qh >> iqs) & 0xF;
return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
}
#endif
#if defined(DATA_A_Q8_0)
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
return pack32(i16vec2(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2], k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1]));
}
#endif
#if defined(DATA_A_IQ4_NL)
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
u8vec4 idx = unpack8(vui & 0x0F0F0F0F);
return pack32(i8vec4(kvalues_iq4nl_const[idx.x],
kvalues_iq4nl_const[idx.y],
kvalues_iq4nl_const[idx.z],
kvalues_iq4nl_const[idx.w]));
}
#endif
#if QUANT_AUXF == 1
FLOAT_TYPE get_k_d(uint ib, uint a_offset) {
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d);
}
#else
FLOAT_TYPEV2 get_k_dm(uint ib, uint a_offset) {
return FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + ib].dm);
}
#endif
void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) {
#if defined(DATA_A_Q4_0)
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
#elif defined(DATA_A_Q4_1)
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
#elif defined(DATA_A_Q5_0)
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
if (iqs == 0) {
kblocksh[buf_ib].qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qh[0],
k_packed.k_data_packed16[a_offset + global_ib].qh[1]));
}
#elif defined(DATA_A_Q5_1)
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
if (iqs == 0) {
kblocksh[buf_ib].qh = k_packed.k_data_packed16[a_offset + global_ib].qh;
}
#elif defined(DATA_A_Q8_0)
kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
#elif defined(DATA_A_IQ4_NL)
const uint qs = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
kblocksh[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_iq4nl_const[i_a0.x], kvalues_iq4nl_const[i_a0.y],
kvalues_iq4nl_const[i_a0.z], kvalues_iq4nl_const[i_a0.w]));
kblocksh[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_iq4nl_const[i_a1.x], kvalues_iq4nl_const[i_a1.y],
kvalues_iq4nl_const[i_a1.z], kvalues_iq4nl_const[i_a1.w]));
#endif
if (iqs == 0) {
#if QUANT_AUXF == 1
kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[a_offset + global_ib].d);
#else
kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + global_ib].dm);
#endif
}
}
int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) {
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
uint sub = pos % 4;
uint shift = ((pos % 8) >= 4) ? 4 : 0;
return int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
uint sub = pos % 4;
uint shift = ((pos % 8) >= 4) ? 4 : 0;
int32_t result = int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4)) & 0xF;
return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
#elif defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
return kblocksh[buf_ib].qs[pos];
#endif
}
ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) {
#if defined(DATA_A_Q4_0)
return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
#elif defined(DATA_A_Q5_0)
return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
#elif defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
return ACC_TYPE(Qf[qib].ds.y) * k_dm.y;
#else
return ACC_TYPE(0.0);
#endif
}
void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) {
kblocksh[buf_ib].qs[iqs] = 0;
#if defined(DATA_A_IQ4_NL)
kblocksh[buf_ib].qs[iqs + 4] = 0;
#endif
if (iqs == 0) {
#if QUANT_AUXF == 1
kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f);
#else
kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f);
#endif
}
}

View file

@ -32,6 +32,12 @@ struct block_a_cache {
int32_t qs[32/4];
FLOAT_TYPE dm;
};
#elif defined(DATA_A_IQ4_NL)
#define QUANT_R_MMQ 2
struct block_a_cache {
int32_t qs[8];
FLOAT_TYPE dm;
};
#elif defined(DATA_A_MXFP4)
#define QUANT_R_MMQ 2
struct block_a_cache {

View file

@ -1692,6 +1692,7 @@ struct block_iq4_nl_packed16
#if defined(DATA_A_IQ4_NL)
#define QUANT_K QUANT_K_IQ4_NL
#define QUANT_R QUANT_R_IQ4_NL
#define QUANT_AUXF 1
#define A_TYPE block_iq4_nl
#define A_TYPE_PACKED16 block_iq4_nl_packed16
#endif

View file

@ -421,8 +421,8 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s
}
static std::vector<std::future<void>> compiles;
void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false, const std::string& suffix = "") {
name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")) + suffix;
std::string out_path = join_paths(output_dir, name + ".spv");
// if (input_filepath == "") {
@ -642,15 +642,16 @@ void process_shaders() {
for (const bool& fp16 : {false, true}) {
std::map<std::string, std::string> base_dict;
if (fp16) {
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV2", "f16vec2"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
} else {
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}};
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"FLOAT_TYPEV4", "vec4"}};
}
// flash attention
for (const bool& f16acc : {false, true}) {
std::map<std::string, std::string> fa_base_dict = base_dict;
fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float";
fa_base_dict["ACC_TYPEV2"] = fp16 && f16acc ? "f16vec2" : "vec2";
fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4";
if (fp16 && f16acc) {
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
@ -689,6 +690,12 @@ void process_shaders() {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (tname != "f32") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8");
}
#endif
}
}
}

View file

@ -798,6 +798,8 @@ class MODEL_TENSOR(IntEnum):
A_ENC_INP_PROJ = auto() # gemma4
A_ENC_CONV1D = auto()
A_ENC_CONV1D_NORM = auto() # gemma3n
A_ENC_CONV2D = auto()
A_ENC_CONV_OUT = auto()
A_PRE_NORM = auto()
A_POST_NORM = auto()
A_ENC_LAYER_PRE_NORM = auto() # gemma3n
@ -1280,6 +1282,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS: "a.embd_to_logits",
MODEL_TENSOR.A_ENC_INP_PROJ: "a.input_projection",
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
MODEL_TENSOR.A_ENC_CONV2D: "a.conv2d.{bid}",
MODEL_TENSOR.A_ENC_CONV_OUT: "a.conv_out",
MODEL_TENSOR.A_ENC_CONV1D_NORM: "a.conv1d.{bid}.norm",
MODEL_TENSOR.A_PRE_NORM: "a.pre_ln",
MODEL_TENSOR.A_POST_NORM: "a.post_ln",
@ -1426,6 +1430,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS,
MODEL_TENSOR.A_ENC_INP_PROJ,
MODEL_TENSOR.A_ENC_CONV1D,
MODEL_TENSOR.A_ENC_CONV2D,
MODEL_TENSOR.A_ENC_CONV_OUT,
MODEL_TENSOR.A_ENC_CONV1D_NORM,
MODEL_TENSOR.A_PRE_NORM,
MODEL_TENSOR.A_POST_NORM,
@ -4112,6 +4118,7 @@ class VisionProjectorType:
ULTRAVOX = "ultravox"
INTERNVL = "internvl"
QWEN2A = "qwen2a" # audio
QWEN3A = "qwen3a" # audio
GLMA = "glma" # audio
QWEN25O = "qwen2.5o" # omni
VOXTRAL = "voxtral"

View file

@ -1892,6 +1892,14 @@ class TensorNameMap:
"conformer.subsample_conv_projection.input_proj_linear", # gemma4
),
MODEL_TENSOR.A_ENC_CONV2D: (
"audio_tower.conv2d{bid}", # qwen3omni
),
MODEL_TENSOR.A_ENC_CONV_OUT: (
"audio_tower.conv_out", # qwen3omni
),
MODEL_TENSOR.A_PRE_NORM: (),
MODEL_TENSOR.A_POST_NORM: (
@ -2042,7 +2050,8 @@ class TensorNameMap:
MODEL_TENSOR.A_MMPROJ: (
"audio.multi_modal_projector.linear_{bid}", # ultravox, meralion
"audio_adapter.model.{bid}" # lfm2
"audio_adapter.model.{bid}", # lfm2
"audio_tower.proj{bid}", # qwen3omni
),
MODEL_TENSOR.A_MMPROJ_FC: (

View file

@ -2123,6 +2123,7 @@ void kcpp_init_audio_proj(clip_ctx * ctx_a)
switch (proj) {
case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_QWEN25O:
case PROJECTOR_TYPE_QWEN3A:
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_GLMA:
@ -2133,6 +2134,9 @@ void kcpp_init_audio_proj(clip_ctx * ctx_a)
case PROJECTOR_TYPE_LFM2A:
audio_preproc = std::make_unique<mtmd_audio_preprocessor_conformer>(ctx_a);
break;
case PROJECTOR_TYPE_GEMMA4A:
audio_preproc = std::make_unique<mtmd_audio_preprocessor_gemma4a>(ctx_a);
break;
default:
GGML_ABORT("unsupported audio projector type");
}
@ -3700,7 +3704,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
if(clp_ctx_a)
{
int ptype = clip_get_projector_type_ext(clp_ctx_a);
if(ptype==PROJECTOR_TYPE_QWEN2A) //qwen omni
if(ptype==PROJECTOR_TYPE_QWEN2A || ptype==PROJECTOR_TYPE_QWEN3A || ptype==PROJECTOR_TYPE_QWEN25O) //qwen omni
{
aud_start = "<|audio_bos|>";
aud_end = "<|audio_eos|>\n";
@ -3710,6 +3714,11 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
aud_start = "[INST][BEGIN_AUDIO]";
aud_end = "[/INST]\n";
}
else if(ptype==PROJECTOR_TYPE_GEMMA4A)
{
aud_start = "<|audio>";
aud_end = "<audio|>\n";
}
}
TokenizeString(aud_start, lv.chunk_start_seq, file_format, false);

View file

@ -135,6 +135,8 @@
// ultravox
#define TN_CONV1D "a.conv1d.%d.%s"
#define TN_CONV2D "a.conv2d.%d.%s"
#define TN_CONV_OUT "a.conv_out.%s"
#define TN_MM_AUDIO_MLP "mm.a.mlp.%d.%s"
#define TN_MM_AUDIO_FC "mm.a.fc.%s" // fully connected layer
#define TN_MM_NORM_PRE "mm.a.norm_pre.%s"
@ -181,6 +183,21 @@
#define TN_CONV_PW1 "%s.blk.%d.conv_pw1.%s"
#define TN_CONV_PW2 "%s.blk.%d.conv_pw2.%s"
// gemma4 audio conformer
#define TN_A_MM_INP_PROJ "mm.a.input_projection.%s"
#define TN_A_MM_SOFT_EMB_N "mm.a.soft_emb_norm.%s"
#define TN_A_INP_PROJ "a.input_projection.%s"
#define TN_A_CONV1D "a.conv1d.%d.%s"
#define TN_A_CONV1D_NORM "a.conv1d.%d.norm.%s"
#define TN_A_OUT_PROJ "a.pre_encode.out.%s"
#define TN_A_ATTN_PRE_NORM "%s.blk.%d.attn_pre_norm.%s"
#define TN_A_ATTN_POST_NORM "%s.blk.%d.attn_post_norm.%s"
#define TN_A_ATTN_K_REL "%s.blk.%d.attn_k_rel.%s"
#define TN_A_PER_DIM_SCALE "%s.blk.%d.per_dim_scale.%s"
#define TN_A_PER_DIM_K_SCALE "%s.blk.%d.per_dim_k_scale.%s"
#define TN_A_FFN_POST_NORM "%s.blk.%d.ffn_post_norm.%s"
#define TN_A_FFN_POST_NORM_1 "%s.blk.%d.ffn_post_norm_1.%s"
// mobilenetv5 (gemma3n) definitions
#define TN_MNV5_STEM_CONV "v.conv_stem.conv.weight"
#define TN_MNV5_STEM_BIAS "v.conv_stem.conv.bias"
@ -256,6 +273,7 @@ enum projector_type {
PROJECTOR_TYPE_INTERNVL,
PROJECTOR_TYPE_LLAMA4,
PROJECTOR_TYPE_QWEN2A,
PROJECTOR_TYPE_QWEN3A,
PROJECTOR_TYPE_GLMA,
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
PROJECTOR_TYPE_VOXTRAL,
@ -300,6 +318,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
{ PROJECTOR_TYPE_LLAMA4, "llama4"},
{ PROJECTOR_TYPE_QWEN2A, "qwen2a"},
{ PROJECTOR_TYPE_QWEN3A, "qwen3a"},
{ PROJECTOR_TYPE_GLMA, "glma"},
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},

View file

@ -217,6 +217,13 @@ struct clip_layer {
ggml_tensor * conv_pw2_w = nullptr;
ggml_tensor * conv_pw2_b = nullptr;
// gemma4 audio conformer per-layer
ggml_tensor * attn_pre_norm_w = nullptr;
ggml_tensor * attn_k_rel_w = nullptr;
ggml_tensor * per_dim_scale_w = nullptr;
ggml_tensor * per_dim_k_scale_w = nullptr;
ggml_tensor * ff_post_norm_1_w = nullptr;
bool has_deepstack() const {
return deepstack_fc1_w != nullptr;
}
@ -406,10 +413,20 @@ struct clip_model {
ggml_tensor * conv1d_1_b = nullptr;
ggml_tensor * conv1d_2_w = nullptr;
ggml_tensor * conv1d_2_b = nullptr;
ggml_tensor * conv_out_w = nullptr;
ggml_tensor * conv_out_b = nullptr;
ggml_tensor * mm_norm_pre_w = nullptr;
ggml_tensor * mm_norm_pre_b = nullptr;
ggml_tensor * mm_norm_mid_w = nullptr;
// qwen3a
ggml_tensor * conv2d_1_w = nullptr;
ggml_tensor * conv2d_1_b = nullptr;
ggml_tensor * conv2d_2_w = nullptr;
ggml_tensor * conv2d_2_b = nullptr;
ggml_tensor * conv2d_3_w = nullptr;
ggml_tensor * conv2d_3_b = nullptr;
// cogvlm
ggml_tensor * mm_post_fc_norm_w = nullptr;
ggml_tensor * mm_post_fc_norm_b = nullptr;
@ -459,6 +476,15 @@ struct clip_model {
};
std::map<std::string, clamp_info> clamp_info_map;
// gemma4 audio conformer
std::array<ggml_tensor *, 2> sscp_conv_w = {nullptr};
std::array<ggml_tensor *, 2> sscp_conv_b = {nullptr};
std::array<ggml_tensor *, 2> sscp_norm_w = {nullptr};
ggml_tensor * sscp_inp_proj_w = nullptr;
ggml_tensor * sscp_inp_proj_b = nullptr;
ggml_tensor * audio_out_proj_w = nullptr;
ggml_tensor * audio_out_proj_b = nullptr;
bool audio_has_avgpool() const {
return proj_type == PROJECTOR_TYPE_QWEN2A
|| proj_type == PROJECTOR_TYPE_VOXTRAL

View file

@ -54,6 +54,7 @@
#include "models/cogvlm.cpp"
#include "models/conformer.cpp"
#include "models/dotsocr.cpp"
#include "models/gemma4a.cpp"
#include "models/gemma4v.cpp"
#include "models/glm4v.cpp"
#include "models/hunyuanocr.cpp"
@ -68,6 +69,7 @@
#include "models/pixtral.cpp"
#include "models/qwen2vl.cpp"
#include "models/qwen3vl.cpp"
#include "models/qwen3a.cpp"
#include "models/step3vl.cpp"
#include "models/siglip.cpp"
#include "models/whisper-enc.cpp"
@ -986,10 +988,18 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
builder = std::make_unique<clip_graph_conformer>(ctx, img);
} break;
case PROJECTOR_TYPE_GEMMA4A:
{
builder = std::make_unique<clip_graph_gemma4a>(ctx, img);
} break;
case PROJECTOR_TYPE_GLM4V:
{
builder = std::make_unique<clip_graph_glm4v>(ctx, img);
} break;
case PROJECTOR_TYPE_QWEN3A:
{
builder = std::make_unique<clip_graph_qwen3a>(ctx, img);
} break;
case PROJECTOR_TYPE_YOUTUVL:
{
builder = std::make_unique<clip_graph_youtuvl>(ctx, img);
@ -1481,6 +1491,7 @@ struct clip_model_loader {
} break;
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_QWEN3A:
case PROJECTOR_TYPE_GLMA:
case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_MERALION:
@ -1542,6 +1553,16 @@ struct clip_model_loader {
hparams.audio_window_len = 400;
hparams.audio_hop_len = 160;
} break;
case PROJECTOR_TYPE_GEMMA4A:
{
// Gemma4 feature_extraction_gemma4.py:
// frame_length_ms=20 -> 320 samples, n_fft=512, hop=10ms -> 160
hparams.audio_chunk_len = 0; // no fixed-length padding
hparams.audio_sample_rate = 16000;
hparams.audio_n_fft = 512;
hparams.audio_window_len = 320; // 20ms frame (NOT 25ms/400)
hparams.audio_hop_len = 160;
} break;
case PROJECTOR_TYPE_JANUS_PRO:
{
hparams.image_pad_color = {127, 127, 127};
@ -1649,16 +1670,21 @@ struct clip_model_loader {
}
// helper function
std::unordered_set<std::string> loaded_tensor_names;
auto get_tensor = [&](const std::string & name, bool required = true) {
// Each tensor should only be loaded once; duplicates indicate a bug
if (loaded_tensor_names.count(name)) {
throw std::runtime_error(string_format("%s: tensor already loaded: %s\n", __func__, name.c_str()));
}
ggml_tensor * cur = ggml_get_tensor(ctx_meta.get(), name.c_str());
if (!cur && required) {
throw std::runtime_error(string_format("%s: unable to find tensor %s\n", __func__, name.c_str()));
}
if (cur) {
tensors_to_load.push_back(cur);
// add tensors to context
ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data.get(), cur);
ggml_set_name(data_tensor, cur->name);
loaded_tensor_names.insert(name);
cur = data_tensor;
}
return cur;
@ -2141,6 +2167,20 @@ struct clip_model_loader {
model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight"));
model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias"));
} break;
case PROJECTOR_TYPE_QWEN3A:
{
model.conv2d_1_w = get_tensor(string_format(TN_CONV2D, 1, "weight"));
model.conv2d_1_b = get_tensor(string_format(TN_CONV2D, 1, "bias"));
model.conv2d_2_w = get_tensor(string_format(TN_CONV2D, 2, "weight"));
model.conv2d_2_b = get_tensor(string_format(TN_CONV2D, 2, "bias"));
model.conv2d_3_w = get_tensor(string_format(TN_CONV2D, 3, "weight"));
model.conv2d_3_b = get_tensor(string_format(TN_CONV2D, 3, "bias"));
model.conv_out_w = get_tensor(string_format(TN_CONV_OUT, "weight")); // no bias
model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias"));
model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias"));
} break;
case PROJECTOR_TYPE_VOXTRAL:
{
model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
@ -2274,6 +2314,76 @@ struct clip_model_loader {
model.mm_fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
model.mm_fc_b = get_tensor(string_format(TN_MM_PROJECTOR, "bias"));
} break;
case PROJECTOR_TYPE_GEMMA4A:
{
for (int i = 0; i < 2; i++) {
model.sscp_conv_w[i] = get_tensor(string_format(TN_A_CONV1D, i, "weight"));
model.sscp_conv_b[i] = get_tensor(string_format(TN_A_CONV1D, i, "bias"), false);
model.sscp_norm_w[i] = get_tensor(string_format(TN_A_CONV1D_NORM, i, "weight"), false);
}
model.sscp_inp_proj_w = get_tensor(string_format(TN_A_INP_PROJ, "weight"));
model.sscp_inp_proj_b = get_tensor(string_format(TN_A_INP_PROJ, "bias"), false);
model.audio_out_proj_w = get_tensor(string_format(TN_A_OUT_PROJ, "weight"), false);
model.audio_out_proj_b = get_tensor(string_format(TN_A_OUT_PROJ, "bias"), false);
// audio multimodal embedder (mm.a.* namespace, not mm.*)
model.mm_soft_emb_norm_w = get_tensor(string_format(TN_A_MM_SOFT_EMB_N, "weight"), false);
model.mm_input_proj_w = get_tensor(string_format(TN_A_MM_INP_PROJ, "weight"), false);
// Per-layer tensors NOT loaded by the generic loop above
for (int il = 0; il < hparams.n_layer; ++il) {
auto & layer = model.layers[il];
// Gemma4 audio conformer-specific tensors
layer.ff_norm_w = get_tensor(string_format(TN_FFN_NORM, prefix, il, "weight"));
layer.attn_pre_norm_w = get_tensor(string_format(TN_A_ATTN_PRE_NORM, prefix, il, "weight"), false);
layer.per_dim_scale_w = get_tensor(string_format(TN_A_PER_DIM_SCALE, prefix, il, "weight"), false);
layer.per_dim_k_scale_w = get_tensor(string_format(TN_A_PER_DIM_K_SCALE, prefix, il, "weight"), false);
layer.attn_k_rel_w = get_tensor(string_format(TN_A_ATTN_K_REL, prefix, il, "weight"), false);
// Convolution module
// Note: conv_norm / norm_conv are swapped in GGUF due to
// upstream tensor_mapping.py, so we load them in reverse order
layer.norm_conv_w = get_tensor(string_format(TN_CONV_NORM, prefix, il, "weight"), false);
layer.norm_conv_b = get_tensor(string_format(TN_CONV_NORM, prefix, il, "bias"), false);
layer.conv_pw1_w = get_tensor(string_format(TN_CONV_PW1, prefix, il, "weight"));
layer.conv_pw1_b = get_tensor(string_format(TN_CONV_PW1, prefix, il, "bias"), false);
layer.conv_dw_w = get_tensor(string_format(TN_CONV_DW, prefix, il, "weight"));
layer.conv_dw_b = get_tensor(string_format(TN_CONV_DW, prefix, il, "bias"), false);
layer.conv_norm_w = get_tensor(string_format(TN_NORM_CONV, prefix, il, "weight"), false);
layer.conv_norm_b = get_tensor(string_format(TN_NORM_CONV, prefix, il, "bias"), false);
layer.conv_pw2_w = get_tensor(string_format(TN_CONV_PW2, prefix, il, "weight"));
layer.conv_pw2_b = get_tensor(string_format(TN_CONV_PW2, prefix, il, "bias"), false);
// FFN2 (second half-step)
layer.ff_norm_1_w = get_tensor(string_format(TN_FFN_NORM_1, prefix, il, "weight"));
layer.ff_up_1_w = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "weight"));
layer.ff_up_1_b = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "bias"), false);
layer.ff_down_1_w = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "weight"));
layer.ff_down_1_b = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "bias"), false);
layer.ff_post_norm_1_w = get_tensor(string_format(TN_A_FFN_POST_NORM_1, prefix, il, "weight"), false);
}
// Load clamp info for ClippableLinear AFTER all tensors are loaded
for (auto * tensor : tensors_to_load) {
std::string name = tensor->name;
if (string_ends_with2(name, ".weight")) {
std::string name_inp_max = name;
std::string name_inp_min = name;
std::string name_out_max = name;
std::string name_out_min = name;
string_replace_all(name_inp_max, ".weight", ".input_max");
string_replace_all(name_inp_min, ".weight", ".input_min");
string_replace_all(name_out_max, ".weight", ".output_max");
string_replace_all(name_out_min, ".weight", ".output_min");
model.clamp_info_map[name] = {
get_scalar(name_inp_max, FLT_MAX),
get_scalar(name_inp_min, -FLT_MAX),
get_scalar(name_out_max, FLT_MAX),
get_scalar(name_out_min, -FLT_MAX)
};
}
}
} break;
case PROJECTOR_TYPE_LFM2A:
{
for (int i : {0, 2, 3, 5, 6}) {
@ -2334,7 +2444,10 @@ struct clip_model_loader {
ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
for (auto & t : tensors_to_load) {
ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
const size_t offset = tensor_offset[t->name];
GGML_ASSERT(cur && "tensor not found in ctx_data");
auto it_off = tensor_offset.find(t->name);
GGML_ASSERT(it_off != tensor_offset.end() && "no offset for tensor");
const size_t offset = it_off->second;
fin.seekg(offset, std::ios::beg);
if (!fin) {
throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name));
@ -2354,6 +2467,7 @@ struct clip_model_loader {
LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str());
}
}
struct support_info_op {
@ -2626,8 +2740,7 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
// TODO: we don't support audio for Gemma 3N, but GGUF contains audio tensors
// we can remove this check when we implement audio support for Gemma 3N
skip_audio = ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA3NV
|| ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA4V;
skip_audio = ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA3NV;
}
if (loader.has_audio && !skip_audio) {
@ -3098,7 +3211,7 @@ void setup_init_vision_shim_kcpp(struct clip_ctx * ctx_v) {
GGML_ASSERT(image_preproc != nullptr);
//end of lcpp code block
// =====
// =====
}
//kcpp: legacy shim created during upstream PR 21031
@ -3344,6 +3457,15 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
n_patches /= 2;
}
} break;
case PROJECTOR_TYPE_QWEN3A:
{
// 3x stride-2 conv2d: each step is floor((n-1)/2)+1
int n = img->nx;
n = (n - 1) / 2 + 1;
n = (n - 1) / 2 + 1;
n = (n - 1) / 2 + 1;
n_patches = n;
} break;
case PROJECTOR_TYPE_GLMA:
{
n_patches = img->nx;
@ -3381,6 +3503,16 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
{
n_patches = ((((img->nx + 1) / 2) + 1) / 2 + 1) / 2;
} break;
case PROJECTOR_TYPE_GEMMA4A:
{
// Two Conv2D stride-2: O = floor((I + 2p - k) / s) + 1, p=1, k=3, s=2
// O = floor((I - 1) / 2) + 1
int n = img->nx;
for (int i = 0; i < 2; i++) {
n = (n - 1) / 2 + 1;
}
n_patches = n;
} break;
default:
GGML_ABORT("unsupported projector type");
}
@ -3810,6 +3942,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
case PROJECTOR_TYPE_INTERNVL:
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_QWEN3A:
case PROJECTOR_TYPE_GLMA:
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_LFM2:
@ -3840,6 +3973,56 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
set_input_i32("pos_w", pos_data);
} break;
case PROJECTOR_TYPE_GEMMA4A:
{
GGML_ASSERT(imgs.entries.size() == 1);
const auto & img0 = imgs.entries.front();
// Compute n_pos matching SSCP output: two stride-2 convs
int n_pos = img0->nx;
for (int i = 0; i < 2; i++) { n_pos = (n_pos - 1) / 2 + 1; }
// Chunked local attention: blocked causal mask and RPE
const int chunk_size = 12;
const int max_past = 12;
const int context_size = chunk_size + max_past;
const int num_blocks = (n_pos + chunk_size - 1) / chunk_size;
// Blocked causal attention mask: [context_size, chunk_size, num_blocks]
{
std::vector<float> mask(context_size * chunk_size * num_blocks, -1e9f);
for (int b = 0; b < num_blocks; b++) {
for (int q = 0; q < chunk_size; q++) {
int gq = b * chunk_size + q;
for (int k = 0; k < context_size; k++) {
int gk = b * chunk_size - max_past + k;
if (gq < n_pos && gk >= 0 && gk < n_pos && gk <= gq && (gq - gk) < max_past) {
mask[k + q * context_size + b * context_size * chunk_size] = 0.0f;
}
}
}
}
set_input_f32("kq_mask", mask);
}
// Sinusoidal RPE: 13 positions [12, 11, ..., 0]
{
const int n_embd = ctx->model.hparams.n_embd;
const int num_timescales = n_embd / 2;
const float log_timescale_increment = logf(10000.0f) / std::max(num_timescales - 1, 1);
const int rpe_len = max_past + 1;
std::vector<float> pos_emb(n_embd * rpe_len, 0.0f);
for (int p = 0; p < rpe_len; p++) {
float position = (float)(max_past - p);
for (int i = 0; i < num_timescales; i++) {
float inv_ts = expf(-(float)i * log_timescale_increment);
float scaled = position * inv_ts;
pos_emb[p * n_embd + i] = sinf(scaled);
pos_emb[p * n_embd + i + num_timescales] = cosf(scaled);
}
}
set_input_f32("pos_emb", pos_emb);
}
} break;
case PROJECTOR_TYPE_LFM2A:
{
GGML_ASSERT(imgs.entries.size() == 1);
@ -4186,8 +4369,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.mm_model_proj->ne[1];
case PROJECTOR_TYPE_QWEN2A:
return ctx->model.mm_fc_w->ne[1];
case PROJECTOR_TYPE_GLMA:
case PROJECTOR_TYPE_QWEN3A:
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_GLMA:
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_PADDLEOCR:
@ -4201,6 +4385,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.mm_fc_w->ne[1];
case PROJECTOR_TYPE_LFM2A:
return ctx->model.position_embeddings->ne[0];
case PROJECTOR_TYPE_GEMMA4A:
return ctx->model.hparams.projection_dim;
case PROJECTOR_TYPE_GLM4V:
return ctx->model.mm_ffn_down_w->ne[1];
default:
@ -4254,6 +4440,7 @@ bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
switch (ctx->proj_type()) {
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_QWEN3A:
case PROJECTOR_TYPE_GLMA:
case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_MERALION:

View file

@ -0,0 +1,288 @@
/**
* Gemma 4 Audio Conformer Encoder (clip_graph_gemma4a)
*
* Architecture: Conformer with dual half-step FFN, full self-attention
* with sinusoidal RPE, depthwise light conv, and output projection.
*/
#include "models.h"
#include <cmath>
ggml_cgraph * clip_graph_gemma4a::build() {
const float res_weight = 0.5f;
const float norm_eps = 1e-6f;
// 1. Input
ggml_tensor * inp = build_inp_raw(1);
auto * cur = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
// 2. Subsampling Conv2D (symmetric padding=1, matching PyTorch)
{
for (int i = 0; i < 2; i++) {
cur = ggml_conv_2d(ctx0, model.sscp_conv_w[i], cur, 2, 2, 1, 1, 1, 1);
if (model.sscp_conv_b[i]) {
cur = ggml_add(ctx0, cur, model.sscp_conv_b[i]);
}
// nn.LayerNorm(channels): permute ch to ne[0], normalize, permute back
if (model.sscp_norm_w[i]) {
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
cur = ggml_norm(ctx0, cur, norm_eps);
cur = ggml_mul(ctx0, cur, model.sscp_norm_w[i]);
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
}
cur = ggml_relu(ctx0, cur);
}
// Flatten [freq, time, ch, 1] -> [ch*freq, time]
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2]);
if (model.sscp_inp_proj_w) {
cur = build_mm(model.sscp_inp_proj_w, cur);
if (model.sscp_inp_proj_b) {
cur = ggml_add(ctx0, cur, model.sscp_inp_proj_b);
}
}
}
const int64_t n_pos = cur->ne[1];
// Chunked local attention parameters
const int64_t C = 12; // chunk_size
const int64_t P = 12; // max_past_horizon (context_left - 1)
const int64_t S = C + P; // context_size = 24
const int64_t R = P + 1; // RPE positions = 13
const int64_t B = (n_pos + C - 1) / C; // num_blocks
const int64_t Np = B * C; // padded sequence length
const int64_t pad_seq = Np - n_pos;
// Input tensors: blocked RPE and blocked attention mask
ggml_tensor * pos_emb = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_head * d_head, R);
ggml_set_name(pos_emb, "pos_emb");
ggml_set_input(pos_emb);
ggml_tensor * kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, S, C, B);
ggml_set_name(kq_mask, "kq_mask");
ggml_set_input(kq_mask);
// 3. Conformer Blocks
for (int il = 0; il < hparams.n_layer; il++) {
const auto & layer = model.layers[il];
auto * residual = cur;
// FFN 1 (half-step)
if (layer.ff_norm_w && layer.ff_up_w && layer.ff_down_w) {
cur = build_norm(cur, layer.ff_norm_w, nullptr, NORM_TYPE_RMS, norm_eps, il);
cur = build_ffn(cur,
layer.ff_up_w, nullptr, nullptr, nullptr,
layer.ff_down_w, nullptr, FFN_SILU, il);
if (layer.ff_post_norm_w) {
cur = build_norm(cur, layer.ff_post_norm_w, nullptr, NORM_TYPE_RMS, norm_eps, il);
}
residual = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, res_weight));
}
// Chunked local self-attention with RPE
if (layer.q_w && layer.k_w && layer.v_w && layer.o_w) {
const float q_scale = (1.0f / sqrtf((float)d_head)) / logf(2.0f);
const float k_scale = logf(1.0f + expf(1.0f)) / logf(2.0f);
const float softcap = 50.0f;
ggml_tensor * attn_norm_w = layer.attn_pre_norm_w ? layer.attn_pre_norm_w : layer.ln_1_w;
cur = attn_norm_w
? build_norm(residual, attn_norm_w, nullptr, NORM_TYPE_RMS, norm_eps, il)
: residual;
ggml_tensor * Qcur = build_mm(layer.q_w, cur);
ggml_tensor * Kcur = build_mm(layer.k_w, cur);
ggml_tensor * Vcur = build_mm(layer.v_w, cur);
// [n_embd, n_pos] -> [D, H, N]
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
// Q/K scaling
Qcur = ggml_scale(ctx0, Qcur, q_scale);
if (layer.per_dim_scale_w) {
Qcur = ggml_mul(ctx0, Qcur, ggml_reshape_3d(ctx0, layer.per_dim_scale_w, d_head, 1, 1));
}
Kcur = ggml_scale(ctx0, Kcur, k_scale);
if (layer.per_dim_k_scale_w) {
Kcur = ggml_mul(ctx0, Kcur, ggml_reshape_3d(ctx0, layer.per_dim_k_scale_w, d_head, 1, 1));
}
// Q blocking: [D, H, N] -> pad to Np -> reshape [D, H, C, B]
// ggml permute: ne[ax_i] = src->ne[i], so (0,3,1,2) sends H->3, C->1, B->2
Qcur = ggml_pad(ctx0, Qcur, 0, 0, pad_seq, 0); // [D, H, Np]
Qcur = ggml_reshape_4d(ctx0, Qcur, d_head, n_head, C, B); // [D, H, C, B]
Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 3, 1, 2)); // [D, C, B, H]
// K/V block context extraction via overlapping view:
// Pad to S*B elements, roll right by P to create left-padding,
// then view with stride C in the block dimension (overlapping windows).
auto extract_blocks = [&](ggml_tensor * t) -> ggml_tensor * {
// [D, H, N] -> pad to S*B -> roll right by P -> cont (materialize)
const int64_t pad_kv = S * B - n_pos;
t = ggml_pad(ctx0, t, 0, 0, pad_kv, 0); // [D, H, S*B]
t = ggml_roll(ctx0, t, 0, 0, P, 0); // left-pad by P
t = ggml_cont(ctx0, t); // materialize roll (removes view offset)
// Overlapping view: stride for B dim is C positions, not S
// ne = [D, H, S, B], data_size = D*H*S*B*sizeof = source_nbytes (exact fit)
// nb1=D*sizeof, nb2=D*H*sizeof, nb3=C*D*H*sizeof (overlap: C < S)
t = ggml_view_4d(ctx0, t, d_head, n_head, S, B,
t->nb[1], t->nb[2], C * t->nb[2], 0);
t = ggml_cont(ctx0, t); // materialize overlapping windows
return t;
};
ggml_tensor * Kblk = extract_blocks(Kcur);
// [D, H, S, B] -> [D, S, B, H] via permute(0,3,1,2)
Kblk = ggml_cont(ctx0, ggml_permute(ctx0, Kblk, 0, 3, 1, 2));
ggml_tensor * Vblk = extract_blocks(Vcur);
// [D, H, S, B] -> [S, D, B, H] via permute(1,3,0,2)
Vblk = ggml_cont(ctx0, ggml_permute(ctx0, Vblk, 1, 3, 0, 2));
// Content attention: Q @ K^T
// Kblk=[D,S,B,H], Qcur=[D,C,B,H] -> mul_mat contracts on D -> [S,C,B,H]
ggml_tensor * matrix_ac = ggml_mul_mat(ctx0, Kblk, Qcur);
// Relative position attention
if (layer.attn_k_rel_w) {
// RPE: [n_embd, R] -> project -> [D, H, R] -> [D, R, H]
auto * p = ggml_mul_mat(ctx0, layer.attn_k_rel_w, pos_emb);
p = ggml_reshape_3d(ctx0, p, d_head, n_head, R);
p = ggml_cont(ctx0, ggml_permute(ctx0, p, 0, 2, 1, 3)); // [D, R, H]
// Q_flat @ RPE^T: [D, C*B, H] @ [D, R, H] -> [R, C*B, H]
auto * Q_flat = ggml_reshape_3d(ctx0, Qcur, d_head, C * B, n_head);
auto * matrix_bd = ggml_mul_mat(ctx0, p, Q_flat); // [R, C*B, H]
matrix_bd = ggml_reshape_4d(ctx0, matrix_bd, R, C, B, n_head); // [R, C, B, H]
// Blocked relative shift (appendix B of Transformer-XL)
{
matrix_bd = ggml_pad(ctx0, matrix_bd, S + 1 - R, 0, 0, 0); // [S+1, C, B, H]
matrix_bd = ggml_reshape_3d(ctx0, matrix_bd, (S + 1) * C, B, n_head);
matrix_bd = ggml_view_3d(ctx0, matrix_bd,
C * S, B, n_head,
matrix_bd->nb[1], matrix_bd->nb[2], 0);
matrix_bd = ggml_cont(ctx0, matrix_bd); // [C*S, B, H]
matrix_bd = ggml_reshape_4d(ctx0, matrix_bd, S, C, B, n_head); // [S, C, B, H]
}
matrix_ac = ggml_add(ctx0, matrix_ac, matrix_bd);
}
auto * scores = matrix_ac; // [S, C, B, H]
// Softcap
scores = ggml_scale(ctx0, scores, 1.0f / softcap);
scores = ggml_tanh(ctx0, scores);
scores = ggml_scale(ctx0, scores, softcap);
// Blocked attention mask: [S, C, B] broadcasts over H
scores = ggml_add(ctx0, scores, kq_mask);
ggml_tensor * attn = ggml_soft_max(ctx0, scores);
// attn @ V: [S,C,B,H] @ [S,D,B,H] -> [D,C,B,H]
ggml_tensor * x = ggml_mul_mat(ctx0, Vblk, attn);
// [D,C,B,H] -> [D,H,C,B] via permute(0,2,3,1) -> flatten -> trim
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 0, 2, 3, 1));
x = ggml_cont_2d(ctx0, x, d_head * n_head, C * B);
if (pad_seq > 0) {
x = ggml_view_2d(ctx0, x, d_head * n_head, n_pos, x->nb[1], 0);
x = ggml_cont(ctx0, x);
}
x = build_mm(layer.o_w, x);
if (layer.o_b) { x = ggml_add(ctx0, x, layer.o_b); }
if (layer.attn_post_norm_w) {
x = build_norm(x, layer.attn_post_norm_w, nullptr, NORM_TYPE_RMS, norm_eps, il);
}
residual = ggml_add(ctx0, residual, x);
}
// Convolution Module
if (layer.norm_conv_w && layer.conv_pw1_w && layer.conv_dw_w && layer.conv_pw2_w) {
cur = build_norm(residual, layer.norm_conv_w, nullptr, NORM_TYPE_RMS, norm_eps, il);
auto * x = build_mm(layer.conv_pw1_w, cur);
// GLU
{
int64_t d = x->ne[0] / 2;
ggml_tensor * gate = ggml_sigmoid(ctx0,
ggml_cont(ctx0, ggml_view_2d(ctx0, x, d, x->ne[1], x->nb[1], d * x->nb[0])));
x = ggml_mul(ctx0,
ggml_view_2d(ctx0, x, d, x->ne[1], x->nb[1], 0), gate);
x = ggml_cont(ctx0, ggml_transpose(ctx0, x));
}
// Causal depthwise Conv1D via ggml_ssm_conv (pad+roll for left-only padding).
x = ggml_pad(ctx0, x, 4, 0, 0, 0);
x = ggml_roll(ctx0, x, 4, 0, 0, 0);
x = ggml_ssm_conv(ctx0, x, layer.conv_dw_w);
if (layer.conv_dw_b) {
x = ggml_add(ctx0, x, layer.conv_dw_b);
}
if (layer.conv_norm_w) {
x = ggml_rms_norm(ctx0, x, norm_eps);
x = ggml_mul(ctx0, x, layer.conv_norm_w);
}
x = ggml_silu(ctx0, x);
x = build_mm(layer.conv_pw2_w, x);
residual = ggml_add(ctx0, residual, x);
}
// FFN 2 (half-step)
if (layer.ff_norm_1_w && layer.ff_up_1_w && layer.ff_down_1_w) {
cur = build_norm(residual, layer.ff_norm_1_w, nullptr, NORM_TYPE_RMS, norm_eps, il);
cur = build_ffn(cur,
layer.ff_up_1_w, nullptr, nullptr, nullptr,
layer.ff_down_1_w, nullptr, FFN_SILU, il);
if (layer.ff_post_norm_1_w) {
cur = build_norm(cur, layer.ff_post_norm_1_w, nullptr, NORM_TYPE_RMS, norm_eps, il);
}
residual = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, res_weight));
}
// Layer output norm
cur = layer.ln_2_w
? build_norm(residual, layer.ln_2_w, nullptr, NORM_TYPE_RMS, norm_eps, il)
: residual;
}
// 4. Output Projection
if (model.audio_out_proj_w) {
cur = build_mm(model.audio_out_proj_w, cur);
if (model.audio_out_proj_b) {
cur = ggml_add(ctx0, cur, model.audio_out_proj_b);
}
}
// 5. Audio Multimodal Embedder
cur = ggml_rms_norm(ctx0, cur, norm_eps);
if (model.mm_soft_emb_norm_w) {
cur = ggml_mul(ctx0, cur, model.mm_soft_emb_norm_w);
}
if (model.mm_input_proj_w) {
cur = build_mm(model.mm_input_proj_w, cur);
}
ggml_build_forward_expand(gf, cur);
return gf;
}
ggml_tensor * clip_graph_gemma4a::build_mm(ggml_tensor * w, ggml_tensor * x) const {
auto it = model.clamp_info_map.find(w->name);
if (it == model.clamp_info_map.end()) {
return ggml_mul_mat(ctx0, w, x);
}
const auto & ci = it->second;
ggml_tensor * clamped = ggml_clamp(ctx0, x, ci.inp_min, ci.inp_max);
ggml_tensor * out = ggml_mul_mat(ctx0, w, clamped);
return ggml_clamp(ctx0, out, ci.out_min, ci.out_max);
}

View file

@ -103,6 +103,12 @@ struct clip_graph_conformer : clip_graph {
ggml_cgraph * build() override;
};
struct clip_graph_gemma4a : clip_graph {
clip_graph_gemma4a(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;
ggml_tensor * build_mm(ggml_tensor * w, ggml_tensor * x) const override;
};
struct clip_graph_glm4v : clip_graph {
clip_graph_glm4v(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;
@ -146,6 +152,11 @@ struct clip_graph_mobilenetv5 : clip_graph {
const mobilenetv5_block & block);
};
struct clip_graph_qwen3a : clip_graph {
clip_graph_qwen3a(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;
};
struct clip_graph_kimik25 : clip_graph {
clip_graph_kimik25(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;

View file

@ -0,0 +1,68 @@
#include "models.h"
ggml_cgraph * clip_graph_qwen3a::build() {
ggml_tensor * inp = build_inp_raw(1);
// conv2d block
// TODO: do we need to split by chunks of n_window each like on transformers impl?
{
inp = ggml_conv_2d(ctx0, model.conv2d_1_w, inp, 2, 2, 1, 1, 1, 1);
inp = ggml_add(ctx0, inp, model.conv2d_1_b);
inp = ggml_gelu_erf(ctx0, inp);
inp = ggml_conv_2d(ctx0, model.conv2d_2_w, inp, 2, 2, 1, 1, 1, 1);
inp = ggml_add(ctx0, inp, model.conv2d_2_b);
inp = ggml_gelu_erf(ctx0, inp);
inp = ggml_conv_2d(ctx0, model.conv2d_3_w, inp, 2, 2, 1, 1, 1, 1);
inp = ggml_add(ctx0, inp, model.conv2d_3_b);
inp = ggml_gelu_erf(ctx0, inp);
// inp [n_pos, n_mels/8, channels, 1] (W, H, C, N)
cb(inp, "after_conv_blocks", -1);
const int64_t n_pos_after_conv = inp->ne[0];
const int64_t n_mel_after_conv = inp->ne[1]; // 128/8 = 16
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 3, 1));
inp = ggml_reshape_2d(ctx0, inp, n_pos_after_conv, n_mel_after_conv * inp->ne[3]); // [n_pos, 7680]
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); // [7680, n_pos]
// project to n_embd
inp = ggml_mul_mat(ctx0, model.conv_out_w, inp);
if (model.conv_out_b) {
inp = ggml_add(ctx0, inp, model.conv_out_b);
}
cb(inp, "after_conv_out", -1);
}
auto n_pos = inp->ne[1];
ggml_tensor * pos_embd_selected = ggml_view_2d(
ctx0, model.position_embeddings,
model.position_embeddings->ne[0], n_pos,
model.position_embeddings->nb[1], 0
);
ggml_tensor * cur = build_vit(
inp, n_pos,
NORM_TYPE_NORMAL,
hparams.ffn_op,
pos_embd_selected,
nullptr);
cb(cur, "after_transformer", -1);
// projector
cur = build_ffn(cur,
model.mm_1_w, model.mm_1_b,
nullptr, nullptr,
model.mm_2_w, model.mm_2_b,
FFN_GELU_ERF,
-1);
cb(cur, "projected", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}

View file

@ -8,6 +8,7 @@
#include <vector>
#include <fstream>
#include <algorithm>
#include <functional>
// some of the code here is copied from whisper.cpp
@ -37,23 +38,36 @@ void mtmd_audio_cache::fill_mel_filterbank_matrix(int n_mel,
float fmin,
float fmax,
bool slaney_area_norm,
float scale) {
float scale,
bool use_htk) {
GGML_ASSERT(n_mel > 0 && n_fft > 1);
if (fmax <= 0.0f) {
fmax = 0.5f * sample_rate;
}
// Slaney scale (matches librosa default)
const double min_log_hz = 1000.0;
const double lin_slope = 3 / 200.;
const double min_log_mel = min_log_hz * lin_slope;
const double log_step = log(6.4) / 27.0;
auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
};
auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
};
std::function<double(double)> hz_to_mel;
std::function<double(double)> mel_to_hz;
if (use_htk) {
hz_to_mel = [](const double f_hz) -> double {
return 2595.0 * log10(1.0 + f_hz / 700.0);
};
mel_to_hz = [](const double m) -> double {
return 700.0 * (pow(10.0, m / 2595.0) - 1.0);
};
} else {
// Slaney scale (matches librosa default)
const double min_log_hz = 1000.0;
const double lin_slope = 3 / 200.;
const double min_log_mel = min_log_hz * lin_slope;
const double log_step = log(6.4) / 27.0;
hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
};
mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
};
}
// infer N_fft from n_fft_bins
const double bin_hz_step = double(sample_rate) / double(n_fft);
@ -257,10 +271,13 @@ struct filter_params {
int32_t hann_window_size;
int32_t hop_length;
int32_t sample_rate;
bool center_padding = false;
float preemph = 0.f;
bool no_padding = false;
bool center_padding = false;
float preemph = 0.f;
bool use_natural_log = false;
bool norm_per_feature = false;
bool use_magnitude = false; // |X| instead of |X|^2
float mel_floor = 5.960464477539063e-08f;
};
static void log_mel_spectrogram_worker_thread(int ith,
@ -301,10 +318,10 @@ static void log_mel_spectrogram_worker_thread(int ith,
// FFT
fft(cache, fft_in.data(), frame_size, fft_out.data());
// Calculate modulus^2 of complex numbers
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
// Calculate modulus^2 (power) or modulus (magnitude)
for (int j = 0; j < n_fft_bins; j++) {
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
float power = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
fft_out[j] = params.use_magnitude ? sqrtf(power) : power;
}
// mel spectrogram
@ -324,9 +341,10 @@ static void log_mel_spectrogram_worker_thread(int ith,
for (; k < n_fft_bins; k++) {
sum += fft_out[k] * filters.data[j * n_fft_bins + k];
}
sum = std::max(sum, (double)params.mel_floor);
sum = params.use_natural_log
? log(sum + 5.960464477539063e-08)
: log10(std::max(sum, 1e-10));
? log(sum)
: log10(sum);
out.data[j * out.n_len + i] = sum;
}
}
@ -360,7 +378,12 @@ static bool log_mel_spectrogram(
// Padding
std::vector<float> samples_padded;
if (params.center_padding) {
if (params.no_padding) {
// no padding, use samples as-is
samples_padded = std::vector<float>(samples, samples + n_samples);
samples = samples_padded.data();
n_samples = samples_padded.size();
} else if (params.center_padding) {
const auto pad_amount = frame_size / 2;
samples_padded = std::vector<float>(n_samples + 2 * pad_amount, 0);
std::copy(samples, samples + n_samples, samples_padded.data() + pad_amount);
@ -464,8 +487,8 @@ static bool log_mel_spectrogram(
out.data[i * out.n_len + j] = 0.0;
}
}
} else {
// clamping and normalization
} else if (!params.no_padding) {
// Whisper-style clamping and normalization (NOT used by Gemma4)
double mmax = -1e20;
for (int i = 0; i < out.n_mel*out.n_len; i++) {
if (out.data[i] > mmax) {
@ -627,6 +650,87 @@ bool mtmd_audio_preprocessor_conformer::preprocess(const float *
return true;
}
//
// mtmd_audio_preprocessor_gemma4a
//
void mtmd_audio_preprocessor_gemma4a::initialize() {
cache.fill_sin_cos_table(hparams.audio_n_fft);
// Standard periodic Hann window, zero-padded to FFT size
cache.hann_window.assign(hparams.audio_n_fft, 0.0f);
for (uint32_t i = 0; i < (uint32_t)hparams.audio_window_len; i++) {
cache.hann_window[i] = 0.5f - 0.5f * cosf((2.0f * (float)M_PI * i) / hparams.audio_window_len);
}
// HTK mel scale, no Slaney area normalization
cache.fill_mel_filterbank_matrix(
hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate,
0.0f, hparams.audio_sample_rate / 2.0f,
/*slaney_area_norm=*/ false,
/*scale=*/ 1.0f,
/*use_htk=*/ true
);
}
bool mtmd_audio_preprocessor_gemma4a::preprocess(const float * samples,
size_t n_samples,
std::vector<mtmd_audio_mel> & output) {
if (n_samples == 0) {
return false;
}
GGML_ASSERT(!cache.sin_vals.empty());
GGML_ASSERT(!cache.cos_vals.empty());
GGML_ASSERT(!cache.filters.data.empty());
filter_params params;
params.n_mel = hparams.n_mel_bins;
params.n_fft_bins = 1 + (hparams.audio_n_fft / 2);
params.hann_window_size = hparams.audio_n_fft; // window is zero-padded to FFT size
params.hop_length = hparams.audio_hop_len;
params.sample_rate = hparams.audio_sample_rate;
params.no_padding = true;
params.center_padding = false;
params.preemph = 0.0f;
params.use_natural_log = true;
params.use_magnitude = true;
params.mel_floor = 0.001f;
params.norm_per_feature = false;
// Split into 30-second chunks (model context limit, ~750 tokens each)
const size_t chunk_samples = 30 * hparams.audio_sample_rate;
for (size_t off = 0; off < n_samples; off += chunk_samples) {
const float * chunk_ptr = samples + off;
size_t chunk_len = std::min(chunk_samples, n_samples - off);
// Semicausal left-padding + right-padding to match PyTorch frame count
const int pad_left = hparams.audio_window_len / 2;
const int fft_size = hparams.audio_n_fft;
const int hop = hparams.audio_hop_len;
const int n_with_left = (int)chunk_len + pad_left;
// PyTorch: unfold(size=frame_length+1, step=hop) on semicausal-padded waveform
const int pt_frames = (n_with_left - (hparams.audio_window_len + 1)) / hop + 1;
const int n_padded_needed = (pt_frames - 1) * hop + fft_size;
const int total_pad = std::max((int)(n_padded_needed - (int)chunk_len), pad_left);
std::vector<float> padded_samples(total_pad + chunk_len, 0.0f);
std::copy(chunk_ptr, chunk_ptr + chunk_len, padded_samples.data() + pad_left);
mtmd_audio_mel out_chunk;
bool ok = log_mel_spectrogram(padded_samples.data(), padded_samples.size(), 4, params, cache, out_chunk);
if (!ok) {
return false;
}
// Trim to PyTorch frame count
out_chunk.n_len = std::min(out_chunk.n_len, pt_frames);
output.push_back(std::move(out_chunk));
}
return true;
}
//
// mtmd_audio_streaming_istft implementation
//

View file

@ -45,7 +45,8 @@ struct mtmd_audio_cache {
float fmin = 0.0f, // e.g. 0.0
float fmax = -1.0f, // e.g. sr/2; pass -1 for auto
bool slaney_area_norm = true,
float scale = 1.0f // optional extra scaling
float scale = 1.0f,
bool use_htk = false
);
};
@ -77,6 +78,15 @@ struct mtmd_audio_preprocessor_conformer : mtmd_audio_preprocessor {
mtmd_audio_cache cache;
};
struct mtmd_audio_preprocessor_gemma4a : mtmd_audio_preprocessor {
mtmd_audio_preprocessor_gemma4a(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {}
void initialize() override;
bool preprocess(const float * samples, size_t n_samples, std::vector<mtmd_audio_mel> & output) override;
private:
mtmd_audio_cache cache;
};
//
// streaming ISTFT - converts spectrogram frames back to audio one frame at a time
//

View file

@ -274,7 +274,8 @@ int32_t mtmd_helper_decode_image_chunk(
batch_embd.set_position_normal(n_past, seq_id);
}
if (mtmd_decode_use_non_causal(ctx)) {
const bool use_non_causal = mtmd_decode_use_non_causal(ctx, chunk);
if (use_non_causal) {
llama_set_causal_attn(lctx, false);
// TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
}
@ -302,7 +303,7 @@ int32_t mtmd_helper_decode_image_chunk(
n_past += mtmd_input_chunk_get_n_pos(chunk);
*new_n_past = n_past;
if (mtmd_decode_use_non_causal(ctx)) {
if (use_non_causal) {
llama_set_causal_attn(lctx, true);
}
return 0;

View file

@ -198,35 +198,38 @@ struct img_tool {
private:
// Bilinear resize function
static void resize_bilinear(const clip_image_u8 & src, clip_image_u8 & dst, int target_width, int target_height) {
GGML_ASSERT(src.nx >= 2 && src.ny >= 2);
if (src.nx == 0 || src.ny == 0) { dst.nx = dst.ny = 0; dst.buf.clear(); return; }
if (target_width <= 0) target_width = 1;
if (target_height <= 0) target_height = 1;
dst.nx = target_width;
dst.ny = target_height;
dst.buf.resize(3 * target_width * target_height);
float x_ratio = static_cast<float>(src.nx - 1) / target_width;
float y_ratio = static_cast<float>(src.ny - 1) / target_height;
float x_ratio = target_width > 1 ? static_cast<float>(src.nx - 1) / (target_width - 1) : 0.0f;
float y_ratio = target_height > 1 ? static_cast<float>(src.ny - 1) / (target_height - 1) : 0.0f;
for (int y = 0; y < target_height; y++) {
for (int x = 0; x < target_width; x++) {
float px = x_ratio * x;
float py = y_ratio * y;
int x_floor = std::min(static_cast<int>(px), src.nx - 2);
int y_floor = std::min(static_cast<int>(py), src.ny - 2);
float x_lerp = px - x_floor;
float y_lerp = py - y_floor;
for (int y = 0; y < target_height; ++y) {
for (int x = 0; x < target_width; ++x) {
float px = x * x_ratio;
float py = y * y_ratio;
for (int c = 0; c < 3; c++) {
float top = lerp(
static_cast<float>(src.buf[3 * (y_floor * src.nx + x_floor) + c]),
static_cast<float>(src.buf[3 * (y_floor * src.nx + (x_floor + 1)) + c]),
x_lerp
);
float bottom = lerp(
static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + x_floor) + c]),
static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + (x_floor + 1)) + c]),
x_lerp
);
dst.buf[3 * (y * target_width + x) + c] = static_cast<uint8_t>(lerp(top, bottom, y_lerp));
int x0 = std::min(static_cast<int>(px), src.nx - 1);
int y0 = std::min(static_cast<int>(py), src.ny - 1);
int x1 = std::min(x0 + 1, src.nx - 1);
int y1 = std::min(y0 + 1, src.ny - 1);
float xf = px - x0;
float yf = py - y0;
for (int c = 0; c < 3; ++c) {
float top = lerp(static_cast<float>(src.buf[3 * (y0 * src.nx + x0) + c]),
static_cast<float>(src.buf[3 * (y0 * src.nx + x1) + c]),
xf);
float bottom = lerp(static_cast<float>(src.buf[3 * (y1 * src.nx + x0) + c]),
static_cast<float>(src.buf[3 * (y1 * src.nx + x1) + c]),
xf);
dst.buf[3 * (y * target_width + x) + c] = static_cast<uint8_t>(lerp(top, bottom, yf));
}
}
}

View file

@ -455,6 +455,7 @@ struct mtmd_context {
// set preprocessor
switch (proj) {
case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_QWEN3A:
case PROJECTOR_TYPE_QWEN25O:
{
// <|audio_bos|> ... (embeddings) ... <|audio_eos|>
@ -484,6 +485,12 @@ struct mtmd_context {
{
audio_preproc = std::make_unique<mtmd_audio_preprocessor_conformer>(ctx_a);
} break;
case PROJECTOR_TYPE_GEMMA4A:
{
aud_beg = "<|audio>";
aud_end = "<audio|>";
audio_preproc = std::make_unique<mtmd_audio_preprocessor_gemma4a>(ctx_a);
} break;
default:
throw std::runtime_error(string_format("%s: unexpected audio projector type %d\n", __func__, proj));
}
@ -1010,8 +1017,12 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
return ctx->image_embd_v.data();
}
bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
switch (ctx->proj_type_v()) {
bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
auto proj_type = ctx->proj_type_v();
if (chunk && chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
proj_type = ctx->proj_type_a();
}
switch (proj_type) {
case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_GEMMA4V:
return true;
@ -1021,6 +1032,10 @@ bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
}
bool mtmd_decode_use_mrope(mtmd_context * ctx) {
if (ctx->ctx_v == nullptr && ctx->proj_type_a() == PROJECTOR_TYPE_QWEN3A) {
// qwen3-asr
return true;
}
switch (ctx->proj_type_v()) {
case PROJECTOR_TYPE_QWEN2VL:
case PROJECTOR_TYPE_QWEN25VL:

View file

@ -114,7 +114,8 @@ MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
MTMD_API void mtmd_free(mtmd_context * ctx);
// whether we need to set non-causal mask before llama_decode
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
// if chunk is nullptr, we assume the default case where chunk is an image chunk
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk);
// whether the current model use M-RoPE for llama_decode
MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);

View file

@ -91,11 +91,13 @@ add_test_vision "ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0"
add_test_vision "ggml-org/DeepSeek-OCR-GGUF:Q8_0" -p "Free OCR." --chat-template deepseek-ocr
add_test_vision "ggml-org/dots.ocr-GGUF:Q8_0" -p "OCR"
add_test_vision "ggml-org/HunyuanOCR-GGUF:Q8_0" -p "OCR"
add_test_vision "ggml-org/gemma-4-E2B-it-GGUF:Q8_0" --jinja
add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"
add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
add_test_audio "ggml-org/Voxtral-Mini-3B-2507-GGUF:Q4_K_M"
add_test_audio "ggml-org/LFM2-Audio-1.5B-GGUF:Q8_0"
add_test_audio "ggml-org/gemma-4-E2B-it-GGUF:Q8_0" --jinja
# to test the big models, run: ./tests.sh big
if [ "$RUN_BIG_TESTS" = true ]; then

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load diff

View file

@ -926,7 +926,8 @@ void server_models_routes::init_routes() {
res_ok(res, {
// TODO: add support for this on web UI
{"role", "router"},
{"max_instances", 4}, // dummy value for testing
{"max_instances", params.models_max},
{"models_autoload", params.models_autoload},
// this is a dummy response to make sure webui doesn't break
{"model_alias", "llama-server"},
{"model_path", "none"},
@ -935,6 +936,7 @@ void server_models_routes::init_routes() {
{"n_ctx", 0},
}},
{"webui_settings", webui_settings},
{"build_info", build_info},
});
return res;
}

View file

@ -9,6 +9,19 @@ def create_server():
server = ServerPreset.router()
def test_router_props():
global server
server.models_max = 2
server.no_models_autoload = True
server.start()
res = server.make_request("GET", "/props")
assert res.status_code == 200
assert res.body["role"] == "router"
assert res.body["max_instances"] == 2
assert res.body["models_autoload"] is False
assert res.body["build_info"].startswith("b")
@pytest.mark.parametrize(
"model,success",
[

View file

@ -9,7 +9,14 @@
import { getMessageEditContext } from '$lib/contexts';
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
import { isLoading, isChatStreaming } from '$lib/stores/chat.svelte';
import { autoResizeTextarea, copyToClipboard, isIMEComposing } from '$lib/utils';
import {
autoResizeTextarea,
copyToClipboard,
isIMEComposing,
deriveAgenticSections
} from '$lib/utils';
import { AgenticSectionType } from '$lib/enums';
import { REASONING_TAGS } from '$lib/constants/agentic';
import { tick } from 'svelte';
import { fade } from 'svelte/transition';
import { Check, X } from '@lucide/svelte';
@ -95,6 +102,49 @@
let currentConfig = $derived(config());
let isRouter = $derived(isRouterMode());
let showRawOutput = $state(false);
let rawOutputContent = $derived.by(() => {
const sections = deriveAgenticSections(message, toolMessages, [], false);
const parts: string[] = [];
for (const section of sections) {
switch (section.type) {
case AgenticSectionType.REASONING:
case AgenticSectionType.REASONING_PENDING:
parts.push(`${REASONING_TAGS.START}\n${section.content}\n${REASONING_TAGS.END}`);
break;
case AgenticSectionType.TEXT:
parts.push(section.content);
break;
case AgenticSectionType.TOOL_CALL:
case AgenticSectionType.TOOL_CALL_PENDING:
case AgenticSectionType.TOOL_CALL_STREAMING: {
const callObj: Record<string, unknown> = { name: section.toolName };
if (section.toolArgs) {
try {
callObj.arguments = JSON.parse(section.toolArgs);
} catch {
callObj.arguments = section.toolArgs;
}
}
parts.push(JSON.stringify(callObj, null, 2));
if (section.toolResult) {
parts.push(`[Tool Result]\n${section.toolResult}`);
}
break;
}
}
}
return parts.join('\n\n\n');
});
let activeStatsView = $state<ChatMessageStatsView>(ChatMessageStatsView.GENERATION);
let statsContainerEl: HTMLDivElement | undefined = $state();
@ -252,7 +302,7 @@
</div>
{:else if message.role === MessageRole.ASSISTANT}
{#if showRawOutput}
<pre class="raw-output">{messageContent || ''}</pre>
<pre class="raw-output">{rawOutputContent || ''}</pre>
{:else}
<ChatMessageAgenticContent
{message}

View file

@ -89,6 +89,11 @@
key: SETTINGS_KEYS.ASK_FOR_TITLE_CONFIRMATION,
label: 'Ask for confirmation before changing conversation title',
type: SettingsFieldType.CHECKBOX
},
{
key: SETTINGS_KEYS.TITLE_GENERATION_USE_FIRST_LINE,
label: 'Use first non-empty line for conversation title',
type: SettingsFieldType.CHECKBOX
}
]
},

View file

@ -15,6 +15,18 @@
let { logs, connectionTimeMs, defaultExpanded = false, class: className }: Props = $props();
let isExpanded = $derived(defaultExpanded);
function formatLogDetails(details: unknown): string {
if (details == null) {
return '';
}
try {
return JSON.stringify(details, null, 2);
} catch {
return String(details);
}
}
</script>
{#if logs.length > 0}
@ -53,6 +65,16 @@
<span class="break-all">{log.message}</span>
</div>
{#if log.details !== undefined}
<details class="ml-11">
<summary class="cursor-pointer text-[10px] text-muted-foreground"> details </summary>
<pre
class="mt-1 overflow-x-auto rounded bg-background/70 p-2 text-[10px] break-all whitespace-pre-wrap text-foreground/80">
{formatLogDetails(log.details)}</pre>
</details>
{/if}
{/each}
</div>
</Collapsible.Content>

View file

@ -15,6 +15,11 @@ export const DEFAULT_AGENTIC_CONFIG: AgenticConfig = {
maxToolPreviewLines: 25
} as const;
export const REASONING_TAGS = {
START: '<think>',
END: '</think>'
} as const;
/**
* @deprecated Legacy marker tags - only used for migration of old stored messages.
* New messages use structured fields (reasoningContent, toolCalls, toolCallId).

View file

@ -48,6 +48,26 @@ export const EXPECTED_THEMED_ICON_PAIR_COUNT = 2;
/** CORS proxy URL query parameter name */
export const CORS_PROXY_URL_PARAM = 'url';
/** Number of trailing characters to keep visible when partially redacting mcp-session-id */
export const MCP_SESSION_ID_VISIBLE_CHARS = 5;
/** Partial-redaction rules for MCP headers: header name -> visible trailing chars */
export const MCP_PARTIAL_REDACT_HEADERS = new Map<string, number>([
['mcp-session-id', MCP_SESSION_ID_VISIBLE_CHARS]
]);
/** Header names whose values should be redacted in diagnostic logs */
export const REDACTED_HEADERS = new Set([
'authorization',
'api-key',
'cookie',
'mcp-session-id',
'proxy-authorization',
'set-cookie',
'x-auth-token',
'x-api-key'
]);
/** Human-readable labels for MCP transport types */
export const MCP_TRANSPORT_LABELS: Record<MCPTransportType, string> = {
[MCPTransportType.WEBSOCKET]: 'WebSocket',

View file

@ -15,6 +15,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean |
keepStatsVisible: false,
showMessageStats: true,
askForTitleConfirmation: false,
titleGenerationUseFirstLine: false,
pasteLongTextToFileLen: 2500,
copyTextAttachmentsAsPlainText: false,
pdfAsImage: false,
@ -118,6 +119,8 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
'Display generation statistics (tokens/second, token count, duration) below each assistant message.',
askForTitleConfirmation:
'Ask for confirmation before automatically changing conversation title when editing the first message.',
titleGenerationUseFirstLine:
'Use only the first non-empty line of the prompt to generate the conversation title.',
pdfAsImage:
'Parse PDF as image instead of text. Automatically falls back to text processing for non-vision models.',
disableAutoScroll:

View file

@ -15,6 +15,7 @@ export const SETTINGS_KEYS = {
ENABLE_CONTINUE_GENERATION: 'enableContinueGeneration',
PDF_AS_IMAGE: 'pdfAsImage',
ASK_FOR_TITLE_CONFIRMATION: 'askForTitleConfirmation',
TITLE_GENERATION_USE_FIRST_LINE: 'titleGenerationUseFirstLine',
// Display
SHOW_MESSAGE_STATS: 'showMessageStats',
SHOW_THOUGHT_IN_PROGRESS: 'showThoughtInProgress',

View file

@ -15,7 +15,8 @@ import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import {
DEFAULT_MCP_CONFIG,
DEFAULT_CLIENT_VERSION,
DEFAULT_IMAGE_MIME_TYPE
DEFAULT_IMAGE_MIME_TYPE,
MCP_PARTIAL_REDACT_HEADERS
} from '$lib/constants';
import {
MCPConnectionPhase,
@ -43,9 +44,17 @@ import {
buildProxiedUrl,
buildProxiedHeaders,
getAuthHeaders,
sanitizeHeaders,
throwIfAborted,
isAbortError,
createBase64DataUrl
createBase64DataUrl,
getRequestUrl,
getRequestMethod,
getRequestBody,
summarizeRequestBody,
formatDiagnosticErrorMessage,
extractJsonRpcMethods,
type RequestBodySummary
} from '$lib/utils';
interface ToolResultContentItem {
@ -62,6 +71,16 @@ interface ToolCallResult {
_meta?: Record<string, unknown>;
}
interface DiagnosticRequestDetails {
url: string;
method: string;
credentials?: RequestCredentials;
mode?: RequestMode;
headers: Record<string, string>;
body: RequestBodySummary;
jsonRpcMethods?: string[];
}
export class MCPService {
/**
* Create a connection log entry for phase tracking.
@ -87,6 +106,225 @@ export class MCPService {
};
}
private static createDiagnosticRequestDetails(
input: RequestInfo | URL,
init: RequestInit | undefined,
baseInit: RequestInit,
requestHeaders: Headers,
extraRedactedHeaders?: Iterable<string>
): DiagnosticRequestDetails {
const body = getRequestBody(input, init);
const details: DiagnosticRequestDetails = {
url: getRequestUrl(input),
method: getRequestMethod(input, init, baseInit).toUpperCase(),
credentials: init?.credentials ?? baseInit.credentials,
mode: init?.mode ?? baseInit.mode,
headers: sanitizeHeaders(requestHeaders, extraRedactedHeaders, MCP_PARTIAL_REDACT_HEADERS),
body: summarizeRequestBody(body)
};
const jsonRpcMethods = extractJsonRpcMethods(body);
if (jsonRpcMethods) {
details.jsonRpcMethods = jsonRpcMethods;
}
return details;
}
private static summarizeError(error: unknown): Record<string, unknown> {
if (error instanceof Error) {
return {
name: error.name,
message: error.message,
cause:
error.cause instanceof Error
? { name: error.cause.name, message: error.cause.message }
: error.cause,
stack: error.stack?.split('\n').slice(0, 6).join('\n')
};
}
return { value: String(error) };
}
private static getBrowserContext(
targetUrl: URL,
useProxy: boolean
): Record<string, unknown> | undefined {
if (typeof window === 'undefined') {
return undefined;
}
return {
location: window.location.href,
origin: window.location.origin,
protocol: window.location.protocol,
isSecureContext: window.isSecureContext,
targetOrigin: targetUrl.origin,
targetProtocol: targetUrl.protocol,
sameOrigin: window.location.origin === targetUrl.origin,
useProxy
};
}
private static getConnectionHints(
targetUrl: URL,
config: MCPServerConfig,
error: unknown
): string[] {
const hints: string[] = [];
const message = error instanceof Error ? error.message : String(error);
const headerNames = Object.keys(config.headers ?? {});
if (typeof window !== 'undefined') {
if (
window.location.protocol === 'https:' &&
targetUrl.protocol === 'http:' &&
!config.useProxy
) {
hints.push(
'The page is running over HTTPS but the MCP server is HTTP. Browsers often block this as mixed content; enable the proxy or use HTTPS/WSS for the MCP server.'
);
}
if (window.location.origin !== targetUrl.origin && !config.useProxy) {
hints.push(
'This is a cross-origin browser request. If the server is reachable from curl or Node but not from the browser, missing CORS headers are the most likely cause.'
);
}
}
if (headerNames.length > 0) {
hints.push(
`Custom request headers are configured (${headerNames.join(', ')}). That triggers a CORS preflight, so the server must allow OPTIONS and include the matching Access-Control-Allow-Headers response.`
);
}
if (config.credentials && config.credentials !== 'omit') {
hints.push(
'Credentials are enabled for this connection. Cross-origin credentialed requests need Access-Control-Allow-Credentials: true and cannot use a wildcard Access-Control-Allow-Origin.'
);
}
if (message.includes('Failed to fetch')) {
hints.push(
'"Failed to fetch" is a browser-level network failure. Common causes are CORS rejection, mixed-content blocking, certificate/TLS errors, DNS failures, or nothing listening on the target port.'
);
}
return hints;
}
private static createDiagnosticFetch(
serverName: string,
config: MCPServerConfig,
baseInit: RequestInit,
targetUrl: URL,
useProxy: boolean,
onLog?: (log: MCPConnectionLog) => void
): {
fetch: typeof fetch;
disable: () => void;
} {
let enabled = true;
const logIfEnabled = (log: MCPConnectionLog) => {
if (enabled) {
onLog?.(log);
}
};
return {
fetch: async (input, init) => {
const startedAt = performance.now();
const requestHeaders = new Headers(baseInit.headers);
if (typeof Request !== 'undefined' && input instanceof Request) {
for (const [key, value] of input.headers.entries()) {
requestHeaders.set(key, value);
}
}
if (init?.headers) {
for (const [key, value] of new Headers(init.headers).entries()) {
requestHeaders.set(key, value);
}
}
const request = this.createDiagnosticRequestDetails(
input,
init,
baseInit,
requestHeaders,
Object.keys(config.headers ?? {})
);
const { method, url } = request;
logIfEnabled(
this.createLog(
MCPConnectionPhase.INITIALIZING,
`HTTP ${method} ${url}`,
MCPLogLevel.INFO,
{
serverName,
request
}
)
);
try {
const response = await fetch(input, {
...baseInit,
...init,
headers: requestHeaders
});
const durationMs = Math.round(performance.now() - startedAt);
logIfEnabled(
this.createLog(
MCPConnectionPhase.INITIALIZING,
`HTTP ${response.status} ${method} ${url} (${durationMs}ms)`,
response.ok ? MCPLogLevel.INFO : MCPLogLevel.WARN,
{
response: {
url,
status: response.status,
statusText: response.statusText,
headers: sanitizeHeaders(response.headers, undefined, MCP_PARTIAL_REDACT_HEADERS),
durationMs
}
}
)
);
return response;
} catch (error) {
const durationMs = Math.round(performance.now() - startedAt);
logIfEnabled(
this.createLog(
MCPConnectionPhase.ERROR,
`HTTP ${method} ${url} failed: ${formatDiagnosticErrorMessage(error)}`,
MCPLogLevel.ERROR,
{
serverName,
request,
error: this.summarizeError(error),
browser: this.getBrowserContext(targetUrl, useProxy),
hints: this.getConnectionHints(targetUrl, config, error),
durationMs
}
)
);
throw error;
}
},
disable: () => {
enabled = false;
}
};
}
/**
* Detect if an error indicates an expired/invalidated MCP session.
* Per MCP spec 2025-11-25: HTTP 404 means session invalidated, client MUST
@ -113,9 +351,14 @@ export class MCPService {
* @returns Object containing the created transport and the transport type used
* @throws {Error} If url is missing, WebSocket + proxy combination, or all transports fail
*/
static createTransport(config: MCPServerConfig): {
static createTransport(
serverName: string,
config: MCPServerConfig,
onLog?: (log: MCPConnectionLog) => void
): {
transport: Transport;
type: MCPTransportType;
stopPhaseLogging: () => void;
} {
if (!config.url) {
throw new Error('MCP server configuration is missing url');
@ -154,11 +397,20 @@ export class MCPService {
return {
transport: new WebSocketClientTransport(url),
type: MCPTransportType.WEBSOCKET
type: MCPTransportType.WEBSOCKET,
stopPhaseLogging: () => {}
};
}
const url = useProxy ? buildProxiedUrl(config.url) : new URL(config.url);
const { fetch: diagnosticFetch, disable: stopPhaseLogging } = this.createDiagnosticFetch(
serverName,
config,
requestInit,
url,
useProxy,
onLog
);
if (useProxy && import.meta.env.DEV) {
console.log(`[MCPService] Using CORS proxy for ${config.url} -> ${url.href}`);
@ -171,17 +423,24 @@ export class MCPService {
return {
transport: new StreamableHTTPClientTransport(url, {
requestInit
requestInit,
fetch: diagnosticFetch
}),
type: MCPTransportType.STREAMABLE_HTTP
type: MCPTransportType.STREAMABLE_HTTP,
stopPhaseLogging
};
} catch (httpError) {
console.warn(`[MCPService] StreamableHTTP failed, trying SSE transport...`, httpError);
try {
return {
transport: new SSEClientTransport(url, { requestInit }),
type: MCPTransportType.SSE
transport: new SSEClientTransport(url, {
requestInit,
fetch: diagnosticFetch,
eventSourceInit: { fetch: diagnosticFetch }
}),
type: MCPTransportType.SSE,
stopPhaseLogging
};
} catch (sseError) {
const httpMsg = httpError instanceof Error ? httpError.message : String(httpError);
@ -263,7 +522,11 @@ export class MCPService {
console.log(`[MCPService][${serverName}] Creating transport...`);
}
const { transport, type: transportType } = this.createTransport(serverConfig);
const {
transport,
type: transportType,
stopPhaseLogging
} = this.createTransport(serverName, serverConfig, (log) => onPhase?.(log.phase, log));
// Setup WebSocket reconnection handler
if (transportType === MCPTransportType.WEBSOCKET) {
@ -294,6 +557,24 @@ export class MCPService {
}
);
const runtimeErrorHandler = (error: Error) => {
console.error(`[MCPService][${serverName}] Protocol error after initialize:`, error);
};
client.onerror = (error) => {
onPhase?.(
MCPConnectionPhase.ERROR,
this.createLog(
MCPConnectionPhase.ERROR,
`Protocol error: ${error.message}`,
MCPLogLevel.ERROR,
{
error: this.summarizeError(error)
}
)
);
};
// Phase: Initializing
onPhase?.(
MCPConnectionPhase.INITIALIZING,
@ -301,7 +582,49 @@ export class MCPService {
);
console.log(`[MCPService][${serverName}] Connecting to server...`);
await client.connect(transport);
try {
await client.connect(transport);
// Transport diagnostics are only for the initial handshake, not long-lived traffic.
stopPhaseLogging();
client.onerror = runtimeErrorHandler;
} catch (error) {
client.onerror = runtimeErrorHandler;
const url =
(serverConfig.useProxy ?? false)
? buildProxiedUrl(serverConfig.url)
: new URL(serverConfig.url);
onPhase?.(
MCPConnectionPhase.ERROR,
this.createLog(
MCPConnectionPhase.ERROR,
`Connection failed during initialize: ${
error instanceof Error ? error.message : String(error)
}`,
MCPLogLevel.ERROR,
{
error: this.summarizeError(error),
config: {
serverName,
configuredUrl: serverConfig.url,
effectiveUrl: url.href,
transportType,
useProxy: serverConfig.useProxy ?? false,
headers: sanitizeHeaders(
serverConfig.headers,
Object.keys(serverConfig.headers ?? {}),
MCP_PARTIAL_REDACT_HEADERS
),
credentials: serverConfig.credentials
},
browser: this.getBrowserContext(url, serverConfig.useProxy ?? false),
hints: this.getConnectionHints(url, serverConfig, error)
}
)
);
throw error;
}
const serverVersion = client.getServerVersion();
const serverCapabilities = client.getServerCapabilities();

View file

@ -130,6 +130,12 @@ export const SYNCABLE_PARAMETERS: SyncableParameter[] = [
type: SyncableParameterType.BOOLEAN,
canSync: true
},
{
key: 'titleGenerationUseFirstLine',
serverKey: 'titleGenerationUseFirstLine',
type: SyncableParameterType.BOOLEAN,
canSync: true
},
{
key: 'disableAutoScroll',
serverKey: 'disableAutoScroll',

View file

@ -30,7 +30,8 @@ import {
findDescendantMessages,
findLeafNode,
findMessageById,
isAbortError
isAbortError,
generateConversationTitle
} from '$lib/utils';
import {
MAX_INACTIVE_CONVERSATION_STATES,
@ -504,7 +505,10 @@ class ChatStore {
allExtras
);
if (isNewConversation && content)
await conversationsStore.updateConversationName(currentConv.id, content.trim());
await conversationsStore.updateConversationName(
currentConv.id,
generateConversationTitle(content, Boolean(config().titleGenerationUseFirstLine))
);
const assistantMessage = await this.createAssistantMessage(userMessage.id);
conversationsStore.addMessageToActive(assistantMessage);
await this.streamChatCompletion(
@ -896,7 +900,7 @@ class ChatStore {
if (isFirstUserMessage && newContent.trim())
await conversationsStore.updateConversationTitleWithConfirmation(
activeConv.id,
newContent.trim()
generateConversationTitle(newContent, Boolean(config().titleGenerationUseFirstLine))
);
const messagesToRemove = conversationsStore.activeMessages.slice(messageIndex + 1);
for (const message of messagesToRemove) await DatabaseService.deleteMessage(message.id);
@ -1317,7 +1321,7 @@ class ChatStore {
if (rootMessage && msg.parent === rootMessage.id && newContent.trim()) {
await conversationsStore.updateConversationTitleWithConfirmation(
activeConv.id,
newContent.trim()
generateConversationTitle(newContent, Boolean(config().titleGenerationUseFirstLine))
);
}
@ -1391,7 +1395,7 @@ class ChatStore {
if (isFirstUserMessage && newContent.trim())
await conversationsStore.updateConversationTitleWithConfirmation(
activeConv.id,
newContent.trim()
generateConversationTitle(newContent, Boolean(config().titleGenerationUseFirstLine))
);
await conversationsStore.refreshActiveMessages();
if (msg.role === MessageRole.USER)

View file

@ -23,7 +23,12 @@ import { browser } from '$app/environment';
import { toast } from 'svelte-sonner';
import { DatabaseService } from '$lib/services/database.service';
import { config } from '$lib/stores/settings.svelte';
import { filterByLeafNodeId, findLeafNode, runLegacyMigration } from '$lib/utils';
import {
filterByLeafNodeId,
findLeafNode,
runLegacyMigration,
generateConversationTitle
} from '$lib/utils';
import type { McpServerOverride } from '$lib/types/database';
import { MessageRole } from '$lib/enums';
import {
@ -548,7 +553,10 @@ class ConversationsStore {
) {
await this.updateConversationTitleWithConfirmation(
this.activeConversation.id,
newFirstUserMessage.content.trim()
generateConversationTitle(
newFirstUserMessage.content,
Boolean(config().titleGenerationUseFirstLine)
)
);
}
}

View file

@ -1460,12 +1460,14 @@ class MCPStore {
} catch (error) {
const message = error instanceof Error ? error.message : 'Unknown error occurred';
logs.push({
timestamp: new Date(),
phase: MCPConnectionPhase.ERROR,
message: `Connection failed: ${message}`,
level: MCPLogLevel.ERROR
});
if (logs.at(-1)?.phase !== MCPConnectionPhase.ERROR) {
logs.push({
timestamp: new Date(),
phase: MCPConnectionPhase.ERROR,
message: `Connection failed: ${message}`,
level: MCPLogLevel.ERROR
});
}
this.updateHealthCheck(server.id, {
status: HealthCheckStatus.ERROR,

View file

@ -1,4 +1,6 @@
import { config } from '$lib/stores/settings.svelte';
import { REDACTED_HEADERS } from '$lib/constants';
import { redactValue } from './redact';
/**
* Get authorization headers for API requests
@ -20,3 +22,46 @@ export function getJsonHeaders(): Record<string, string> {
...getAuthHeaders()
};
}
/**
* Sanitize HTTP headers by redacting sensitive values.
* Known sensitive headers (from REDACTED_HEADERS) and any extra headers
* specified by the caller are fully redacted. Headers listed in
* `partialRedactHeaders` are partially redacted, showing only the
* specified number of trailing characters.
*
* @param headers - Headers to sanitize
* @param extraRedactedHeaders - Additional header names to fully redact
* @param partialRedactHeaders - Map of header name -> number of trailing chars to keep visible
* @returns Object with header names as keys and (possibly redacted) values
*/
export function sanitizeHeaders(
headers?: HeadersInit,
extraRedactedHeaders?: Iterable<string>,
partialRedactHeaders?: Map<string, number>
): Record<string, string> {
if (!headers) {
return {};
}
const normalized = new Headers(headers);
const sanitized: Record<string, string> = {};
const redactedHeaders = new Set(
Array.from(extraRedactedHeaders ?? [], (header) => header.toLowerCase())
);
for (const [key, value] of normalized.entries()) {
const normalizedKey = key.toLowerCase();
const partialChars = partialRedactHeaders?.get(normalizedKey);
if (partialChars !== undefined) {
sanitized[key] = redactValue(value, partialChars);
} else if (REDACTED_HEADERS.has(normalizedKey) || redactedHeaders.has(normalizedKey)) {
sanitized[key] = redactValue(value);
} else {
sanitized[key] = value;
}
}
return sanitized;
}

View file

@ -8,7 +8,7 @@
*/
// API utilities
export { getAuthHeaders, getJsonHeaders } from './api-headers';
export { getAuthHeaders, getJsonHeaders, sanitizeHeaders } from './api-headers';
export { apiFetch, apiFetchWithParams, apiPost, type ApiFetchOptions } from './api-fetch';
export { validateApiKey } from './api-key-validation';
@ -55,7 +55,7 @@ export {
// File preview utilities
export { getFileTypeLabel } from './file-preview';
export { getPreviewText } from './text';
export { getPreviewText, generateConversationTitle } from './text';
// File type utilities
export {
@ -164,6 +164,20 @@ export { runLegacyMigration, isMigrationNeeded } from './legacy-migration';
// Cache utilities
export { TTLCache, ReactiveTTLMap, type TTLCacheOptions } from './cache-ttl';
// Redaction utilities
export { redactValue } from './redact';
// Request inspection utilities
export {
getRequestUrl,
getRequestMethod,
getRequestBody,
summarizeRequestBody,
formatDiagnosticErrorMessage,
extractJsonRpcMethods,
type RequestBodySummary
} from './request-helpers';
// Abort signal utilities
export {
throwIfAborted,

View file

@ -0,0 +1,14 @@
/**
* Redacts a sensitive value, optionally showing the last N characters.
*
* @param value - The value to redact
* @param showLastChars - If provided, reveals the last N characters with a leading mask
* @returns The redacted string
*/
export function redactValue(value: string, showLastChars?: number): string {
if (showLastChars) {
return `....${value.slice(-showLastChars)}`;
}
return '[redacted]';
}

View file

@ -0,0 +1,111 @@
/**
* HTTP request inspection utilities for diagnostic logging.
* These helpers extract metadata from fetch-style request arguments
* without exposing sensitive payload data.
*/
export interface RequestBodySummary {
kind: string;
size?: number;
}
export function getRequestUrl(input: RequestInfo | URL): string {
if (typeof input === 'string') {
return input;
}
if (input instanceof URL) {
return input.href;
}
return input.url;
}
export function getRequestMethod(
input: RequestInfo | URL,
init?: RequestInit,
baseInit?: RequestInit
): string {
if (init?.method) {
return init.method;
}
if (typeof Request !== 'undefined' && input instanceof Request) {
return input.method;
}
return baseInit?.method ?? 'GET';
}
export function getRequestBody(
input: RequestInfo | URL,
init?: RequestInit
): BodyInit | null | undefined {
if (init?.body !== undefined) {
return init.body;
}
if (typeof Request !== 'undefined' && input instanceof Request) {
return input.body;
}
return undefined;
}
export function summarizeRequestBody(body: BodyInit | null | undefined): RequestBodySummary {
if (body == null) {
return { kind: 'empty' };
}
if (typeof body === 'string') {
return { kind: 'string', size: body.length };
}
if (body instanceof Blob) {
return { kind: 'blob', size: body.size };
}
if (body instanceof URLSearchParams) {
return { kind: 'urlsearchparams', size: body.toString().length };
}
if (body instanceof FormData) {
return { kind: 'formdata' };
}
if (body instanceof ArrayBuffer) {
return { kind: 'arraybuffer', size: body.byteLength };
}
if (ArrayBuffer.isView(body)) {
return { kind: body.constructor.name, size: body.byteLength };
}
return { kind: typeof body };
}
export function formatDiagnosticErrorMessage(error: unknown): string {
const message = error instanceof Error ? error.message : String(error);
return message.includes('Failed to fetch') ? `${message} (check CORS?)` : message;
}
export function extractJsonRpcMethods(body: BodyInit | null | undefined): string[] | undefined {
if (typeof body !== 'string') {
return undefined;
}
try {
const parsed = JSON.parse(body);
const messages = Array.isArray(parsed) ? parsed : [parsed];
const methods = messages
.map((message: Record<string, unknown>) =>
typeof message?.method === 'string' ? (message.method as string) : undefined
)
.filter((method: string | undefined): method is string => Boolean(method));
return methods.length > 0 ? methods : undefined;
} catch {
return undefined;
}
}

View file

@ -1,3 +1,5 @@
import { NEWLINE_SEPARATOR } from '$lib/constants';
/**
* Returns a shortened preview of the provided content capped at the given length.
* Appends an ellipsis when the content exceeds the maximum.
@ -5,3 +7,16 @@
export function getPreviewText(content: string, max = 150): string {
return content.length > max ? content.slice(0, max) + '...' : content;
}
/**
* Generates a single-line title from a potentially multi-line prompt.
* Uses the first non-empty line if `useFirstLine` is true.
*/
export function generateConversationTitle(content: string, useFirstLine: boolean = false): string {
if (useFirstLine) {
const firstLine = content.split(NEWLINE_SEPARATOR).find((line) => line.trim().length > 0);
return firstLine ? firstLine.trim() : content.trim();
}
return content.trim();
}

View file

@ -0,0 +1,252 @@
import { afterEach, describe, expect, it, vi } from 'vitest';
import { Client } from '@modelcontextprotocol/sdk/client';
import { MCPService } from '$lib/services/mcp.service';
import { MCPConnectionPhase, MCPTransportType } from '$lib/enums';
import type { MCPConnectionLog, MCPServerConfig } from '$lib/types';
type DiagnosticFetchFactory = (
serverName: string,
config: MCPServerConfig,
baseInit: RequestInit,
targetUrl: URL,
useProxy: boolean,
onLog?: (log: MCPConnectionLog) => void
) => { fetch: typeof fetch; disable: () => void };
const createDiagnosticFetch = (
config: MCPServerConfig,
onLog?: (log: MCPConnectionLog) => void,
baseInit: RequestInit = {}
) =>
(
MCPService as unknown as { createDiagnosticFetch: DiagnosticFetchFactory }
).createDiagnosticFetch('test-server', config, baseInit, new URL(config.url), false, onLog);
describe('MCPService', () => {
afterEach(() => {
vi.restoreAllMocks();
vi.unstubAllGlobals();
});
it('stops transport phase logging after handshake diagnostics are disabled', async () => {
const logs: MCPConnectionLog[] = [];
const response = new Response('{}', {
status: 200,
headers: { 'content-type': 'application/json' }
});
vi.stubGlobal('fetch', vi.fn().mockResolvedValue(response));
const config: MCPServerConfig = {
url: 'https://example.com/mcp',
transport: MCPTransportType.STREAMABLE_HTTP
};
const controller = createDiagnosticFetch(config, (log) => logs.push(log));
await controller.fetch(config.url, { method: 'POST', body: '{}' });
expect(logs).toHaveLength(2);
expect(logs.every((log) => log.message.includes('https://example.com/mcp'))).toBe(true);
controller.disable();
await controller.fetch(config.url, { method: 'POST', body: '{}' });
expect(logs).toHaveLength(2);
});
it('redacts all configured custom headers in diagnostic request logs', async () => {
const logs: MCPConnectionLog[] = [];
const response = new Response('{}', {
status: 200,
headers: { 'content-type': 'application/json' }
});
vi.stubGlobal('fetch', vi.fn().mockResolvedValue(response));
const config: MCPServerConfig = {
url: 'https://example.com/mcp',
transport: MCPTransportType.STREAMABLE_HTTP,
headers: {
'x-auth-token': 'secret-token',
'x-vendor-api-key': 'secret-key'
}
};
const controller = createDiagnosticFetch(config, (log) => logs.push(log), {
headers: config.headers
});
await controller.fetch(config.url, {
method: 'POST',
headers: { 'content-type': 'application/json' },
body: '{}'
});
expect(logs).toHaveLength(2);
expect(logs[0].details).toMatchObject({
request: {
headers: {
'x-auth-token': '[redacted]',
'x-vendor-api-key': '[redacted]',
'content-type': 'application/json'
}
}
});
});
it('partially redacts mcp-session-id in diagnostic request and response logs', async () => {
const logs: MCPConnectionLog[] = [];
const response = new Response('{}', {
status: 200,
headers: {
'content-type': 'application/json',
'mcp-session-id': 'session-response-67890'
}
});
vi.stubGlobal('fetch', vi.fn().mockResolvedValue(response));
const config: MCPServerConfig = {
url: 'https://example.com/mcp',
transport: MCPTransportType.STREAMABLE_HTTP
};
const controller = createDiagnosticFetch(config, (log) => logs.push(log));
await controller.fetch(config.url, {
method: 'POST',
headers: {
'content-type': 'application/json',
'mcp-session-id': 'session-request-12345'
},
body: '{}'
});
expect(logs).toHaveLength(2);
expect(logs[0].details).toMatchObject({
request: {
headers: {
'content-type': 'application/json',
'mcp-session-id': '....12345'
}
}
});
expect(logs[1].details).toMatchObject({
response: {
headers: {
'content-type': 'application/json',
'mcp-session-id': '....67890'
}
}
});
});
it('extracts JSON-RPC methods without logging the raw request body', async () => {
const logs: MCPConnectionLog[] = [];
const response = new Response('{}', {
status: 200,
headers: { 'content-type': 'application/json' }
});
vi.stubGlobal('fetch', vi.fn().mockResolvedValue(response));
const config: MCPServerConfig = {
url: 'https://example.com/mcp',
transport: MCPTransportType.STREAMABLE_HTTP
};
const controller = createDiagnosticFetch(config, (log) => logs.push(log));
await controller.fetch(config.url, {
method: 'POST',
body: JSON.stringify([
{ jsonrpc: '2.0', id: 1, method: 'initialize' },
{ jsonrpc: '2.0', method: 'notifications/initialized' }
])
});
expect(logs[0].details).toMatchObject({
request: {
method: 'POST',
body: {
kind: 'string',
size: expect.any(Number)
},
jsonRpcMethods: ['initialize', 'notifications/initialized']
}
});
});
it('adds a CORS hint to Failed to fetch diagnostic log messages', async () => {
const logs: MCPConnectionLog[] = [];
const fetchError = new TypeError('Failed to fetch');
vi.stubGlobal('fetch', vi.fn().mockRejectedValue(fetchError));
const config: MCPServerConfig = {
url: 'http://localhost:8000/mcp',
transport: MCPTransportType.STREAMABLE_HTTP
};
const controller = createDiagnosticFetch(config, (log) => logs.push(log));
await expect(controller.fetch(config.url, { method: 'POST', body: '{}' })).rejects.toThrow(
'Failed to fetch'
);
expect(logs).toHaveLength(2);
expect(logs[1].message).toBe(
'HTTP POST http://localhost:8000/mcp failed: Failed to fetch (check CORS?)'
);
});
it('detaches phase error logging after the initialize handshake completes', async () => {
const phaseLogs: Array<{ phase: MCPConnectionPhase; log: MCPConnectionLog }> = [];
const stopPhaseLogging = vi.fn();
let emitClientError: ((error: Error) => void) | undefined;
vi.spyOn(MCPService, 'createTransport').mockReturnValue({
transport: {} as never,
type: MCPTransportType.WEBSOCKET,
stopPhaseLogging
});
vi.spyOn(MCPService, 'listTools').mockResolvedValue([]);
vi.spyOn(Client.prototype, 'getServerVersion').mockReturnValue(undefined);
vi.spyOn(Client.prototype, 'getServerCapabilities').mockReturnValue(undefined);
vi.spyOn(Client.prototype, 'getInstructions').mockReturnValue(undefined);
vi.spyOn(Client.prototype, 'connect').mockImplementation(async function (this: Client) {
emitClientError = (error: Error) => this.onerror?.(error);
this.onerror?.(new Error('handshake protocol error'));
});
await MCPService.connect(
'test-server',
{
url: 'ws://example.com/mcp',
transport: MCPTransportType.WEBSOCKET
},
undefined,
undefined,
(phase, log) => phaseLogs.push({ phase, log })
);
expect(stopPhaseLogging).toHaveBeenCalledTimes(1);
expect(
phaseLogs.filter(
({ phase, log }) =>
phase === MCPConnectionPhase.ERROR &&
log.message === 'Protocol error: handshake protocol error'
)
).toHaveLength(1);
emitClientError?.(new Error('runtime protocol error'));
expect(
phaseLogs.filter(
({ phase, log }) =>
phase === MCPConnectionPhase.ERROR &&
log.message === 'Protocol error: runtime protocol error'
)
).toHaveLength(0);
});
});

View file

@ -0,0 +1,20 @@
import { describe, expect, it } from 'vitest';
import { redactValue } from '$lib/utils/redact';
describe('redactValue', () => {
it('returns [redacted] by default', () => {
expect(redactValue('secret-token')).toBe('[redacted]');
});
it('shows last N characters when showLastChars is provided', () => {
expect(redactValue('session-abc12', 5)).toBe('....abc12');
});
it('handles value shorter than showLastChars', () => {
expect(redactValue('ab', 5)).toBe('....ab');
});
it('returns [redacted] when showLastChars is 0', () => {
expect(redactValue('secret', 0)).toBe('[redacted]');
});
});

View file

@ -0,0 +1,124 @@
import { describe, expect, it } from 'vitest';
import {
getRequestUrl,
getRequestMethod,
getRequestBody,
summarizeRequestBody,
formatDiagnosticErrorMessage,
extractJsonRpcMethods
} from '$lib/utils/request-helpers';
describe('getRequestUrl', () => {
it('returns a plain string input as-is', () => {
expect(getRequestUrl('https://example.com/mcp')).toBe('https://example.com/mcp');
});
it('returns href from a URL object', () => {
expect(getRequestUrl(new URL('https://example.com/mcp'))).toBe('https://example.com/mcp');
});
it('returns url from a Request object', () => {
const req = new Request('https://example.com/mcp');
expect(getRequestUrl(req)).toBe('https://example.com/mcp');
});
});
describe('getRequestMethod', () => {
it('prefers method from init', () => {
expect(getRequestMethod('https://example.com', { method: 'POST' })).toBe('POST');
});
it('falls back to Request.method', () => {
const req = new Request('https://example.com', { method: 'PUT' });
expect(getRequestMethod(req)).toBe('PUT');
});
it('falls back to baseInit.method', () => {
expect(getRequestMethod('https://example.com', undefined, { method: 'DELETE' })).toBe('DELETE');
});
it('defaults to GET', () => {
expect(getRequestMethod('https://example.com')).toBe('GET');
});
});
describe('getRequestBody', () => {
it('returns body from init', () => {
expect(getRequestBody('https://example.com', { body: 'payload' })).toBe('payload');
});
it('returns undefined when no body is present', () => {
expect(getRequestBody('https://example.com')).toBeUndefined();
});
});
describe('summarizeRequestBody', () => {
it('returns empty for null', () => {
expect(summarizeRequestBody(null)).toEqual({ kind: 'empty' });
});
it('returns empty for undefined', () => {
expect(summarizeRequestBody(undefined)).toEqual({ kind: 'empty' });
});
it('returns string kind with size', () => {
expect(summarizeRequestBody('hello')).toEqual({ kind: 'string', size: 5 });
});
it('returns blob kind with size', () => {
const blob = new Blob(['abc']);
expect(summarizeRequestBody(blob)).toEqual({ kind: 'blob', size: 3 });
});
it('returns formdata kind', () => {
expect(summarizeRequestBody(new FormData())).toEqual({ kind: 'formdata' });
});
it('returns arraybuffer kind with size', () => {
expect(summarizeRequestBody(new ArrayBuffer(8))).toEqual({ kind: 'arraybuffer', size: 8 });
});
});
describe('formatDiagnosticErrorMessage', () => {
it('appends CORS hint for Failed to fetch', () => {
expect(formatDiagnosticErrorMessage(new TypeError('Failed to fetch'))).toBe(
'Failed to fetch (check CORS?)'
);
});
it('passes through other error messages unchanged', () => {
expect(formatDiagnosticErrorMessage(new Error('timeout'))).toBe('timeout');
});
it('handles non-Error values', () => {
expect(formatDiagnosticErrorMessage('some string')).toBe('some string');
});
});
describe('extractJsonRpcMethods', () => {
it('extracts methods from a JSON-RPC array', () => {
const body = JSON.stringify([
{ jsonrpc: '2.0', id: 1, method: 'initialize' },
{ jsonrpc: '2.0', method: 'notifications/initialized' }
]);
expect(extractJsonRpcMethods(body)).toEqual(['initialize', 'notifications/initialized']);
});
it('extracts method from a single JSON-RPC message', () => {
const body = JSON.stringify({ jsonrpc: '2.0', id: 1, method: 'tools/list' });
expect(extractJsonRpcMethods(body)).toEqual(['tools/list']);
});
it('returns undefined for non-string body', () => {
expect(extractJsonRpcMethods(null)).toBeUndefined();
expect(extractJsonRpcMethods(undefined)).toBeUndefined();
});
it('returns undefined for invalid JSON', () => {
expect(extractJsonRpcMethods('not json')).toBeUndefined();
});
it('returns undefined when no methods found', () => {
expect(extractJsonRpcMethods(JSON.stringify({ foo: 'bar' }))).toBeUndefined();
});
});

View file

@ -0,0 +1,55 @@
import { describe, expect, it } from 'vitest';
import { sanitizeHeaders } from '$lib/utils/api-headers';
describe('sanitizeHeaders', () => {
it('returns empty object for undefined input', () => {
expect(sanitizeHeaders()).toEqual({});
});
it('passes through non-sensitive headers', () => {
const headers = new Headers({ 'content-type': 'application/json', accept: 'text/html' });
expect(sanitizeHeaders(headers)).toEqual({
'content-type': 'application/json',
accept: 'text/html'
});
});
it('redacts known sensitive headers', () => {
const headers = new Headers({
authorization: 'Bearer secret',
'x-api-key': 'key-123',
'content-type': 'application/json'
});
const result = sanitizeHeaders(headers);
expect(result.authorization).toBe('[redacted]');
expect(result['x-api-key']).toBe('[redacted]');
expect(result['content-type']).toBe('application/json');
});
it('partially redacts headers specified in partialRedactHeaders', () => {
const headers = new Headers({ 'mcp-session-id': 'session-12345' });
const partial = new Map([['mcp-session-id', 5]]);
expect(sanitizeHeaders(headers, undefined, partial)['mcp-session-id']).toBe('....12345');
});
it('fully redacts mcp-session-id when no partialRedactHeaders is given', () => {
const headers = new Headers({ 'mcp-session-id': 'session-12345' });
expect(sanitizeHeaders(headers)['mcp-session-id']).toBe('[redacted]');
});
it('redacts extra headers provided by the caller', () => {
const headers = new Headers({
'x-vendor-key': 'vendor-secret',
'content-type': 'application/json'
});
const result = sanitizeHeaders(headers, ['x-vendor-key']);
expect(result['x-vendor-key']).toBe('[redacted]');
expect(result['content-type']).toBe('application/json');
});
it('handles case-insensitive extra header names', () => {
const headers = new Headers({ 'X-Custom-Token': 'token-value' });
const result = sanitizeHeaders(headers, ['X-CUSTOM-TOKEN']);
expect(result['x-custom-token']).toBe('[redacted]');
});
});