Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/ISSUE_TEMPLATE/010-bug-compilation.yml
#	.github/ISSUE_TEMPLATE/011-bug-results.yml
#	.github/ISSUE_TEMPLATE/019-bug-misc.yml
#	.github/ISSUE_TEMPLATE/020-enhancement.yml
#	.github/ISSUE_TEMPLATE/030-research.yml
#	.github/ISSUE_TEMPLATE/040-refactor.yml
#	ggml/CMakeLists.txt
#	ggml/src/ggml-cann/ggml-cann.cpp
#	ggml/src/ggml-hexagon/CMakeLists.txt
#	ggml/src/ggml-hexagon/ggml-hexagon.cpp
#	ggml/src/ggml-hexagon/htp/CMakeLists.txt
#	ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake
#	ggml/src/ggml-hexagon/htp/flash-attn-ops.c
#	ggml/src/ggml-hexagon/htp/hex-utils.h
#	ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c
#	ggml/src/ggml-hexagon/htp/hmx-ops.h
#	ggml/src/ggml-hexagon/htp/hmx-utils.h
#	ggml/src/ggml-hexagon/htp/hvx-base.h
#	ggml/src/ggml-hexagon/htp/hvx-copy.h
#	ggml/src/ggml-hexagon/htp/hvx-exp.h
#	ggml/src/ggml-hexagon/htp/unary-ops.c
#	ggml/src/ggml-opencl/CMakeLists.txt
#	ggml/src/ggml-opencl/ggml-opencl.cpp
#	ggml/src/ggml-opencl/kernels/cvt.cl
#	ggml/src/ggml-rpc/ggml-rpc.cpp
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	ggml/src/ggml-virtgpu/ggml-backend.cpp
#	ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
#	ggml/src/ggml-webgpu/ggml-webgpu.cpp
#	ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
#	ggml/src/ggml-zdnn/ggml-zdnn.cpp
#	ggml/src/ggml-zendnn/ggml-zendnn.cpp
#	scripts/sync-ggml.last
#	tests/test-backend-ops.cpp
This commit is contained in:
Concedo 2026-05-02 18:07:50 +08:00
commit 7c70187e26
190 changed files with 11572 additions and 7414 deletions

View file

@ -57,7 +57,7 @@ static fs::path get_cache_directory() {
#ifndef _WIN32
const struct passwd * pw = getpwuid(getuid());
if (pw->pw_dir && *pw->pw_dir) {
if (pw && pw->pw_dir && *pw->pw_dir) {
return fs::path(pw->pw_dir) / ".cache" / "huggingface" / "hub";
}
#endif

View file

@ -13232,17 +13232,18 @@ class LazyTorchTensor(gguf.LazyBase):
}
# only used when byteswapping data. Only correct size is needed
# TODO: uncomment uint64, uint32, and uint16, ref: https://github.com/pytorch/pytorch/issues/58734
_dtype_byteswap_map: dict[torch.dtype, type] = {
torch.float64: np.float64,
torch.float32: np.float32,
torch.bfloat16: np.float16,
torch.float16: np.float16,
torch.int64: np.int64,
torch.uint64: np.uint64,
# torch.uint64: np.uint64,
torch.int32: np.int32,
torch.uint32: np.uint32,
# torch.uint32: np.uint32,
torch.int16: np.int16,
torch.uint16: np.uint16,
# torch.uint16: np.uint16,
torch.int8: np.int8,
torch.uint8: np.uint8,
torch.bool: np.uint8,

View file

@ -2100,8 +2100,8 @@ static const ggml_backend_i ggml_backend_meta_i = {
/* .free = */ ggml_backend_meta_free,
/* .set_tensor_async = */ ggml_backend_meta_set_tensor_async,
/* .get_tensor_async = */ ggml_backend_meta_get_tensor_async,
/* .get_tensor_2d_async = */ nullptr,
/* .set_tensor_2d_async = */ nullptr,
/* .get_tensor_2d_async = */ nullptr,
/* .cpy_tensor_async = */ nullptr,
/* .synchronize = */ ggml_backend_meta_synchronize,
/* .graph_plan_create = */ nullptr,

View file

@ -262,9 +262,9 @@ static struct ggml_backend_i blas_backend_i = {
/* .get_name = */ ggml_backend_blas_get_name,
/* .free = */ ggml_backend_blas_free,
/* .set_tensor_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .set_tensor_2d_async = */ NULL,
/* .get_tensor_async = */ NULL,
/* .set_tensor_2d_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .cpy_tensor_async = */ NULL,
/* .synchronize = */ NULL,
/* .graph_plan_create = */ NULL,

View file

@ -195,8 +195,8 @@ static const struct ggml_backend_i ggml_backend_cpu_i = {
/* .free = */ ggml_backend_cpu_free,
/* .set_tensor_async = */ NULL,
/* .get_tensor_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .set_tensor_2d_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .cpy_tensor_async = */ NULL,
/* .synchronize = */ NULL,
/* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create,

View file

@ -4609,8 +4609,8 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
/* .free = */ ggml_backend_cuda_free,
/* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async,
/* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async,
/* .get_tensor_2d_async = */ ggml_backend_cuda_set_tensor_2d_async,
/* .set_tensor_2d_async = */ ggml_backend_cuda_get_tensor_2d_async,
/* .set_tensor_2d_async = */ ggml_backend_cuda_set_tensor_2d_async,
/* .get_tensor_2d_async = */ ggml_backend_cuda_get_tensor_2d_async,
/* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async,
/* .synchronize = */ ggml_backend_cuda_synchronize,
/* .graph_plan_create = */ NULL,

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,16 @@
#ifndef VTCM_UTILS_H
#define VTCM_UTILS_H
#include "hex-utils.h"
#include <assert.h>
#include <stdint.h>
#include <hexagon_types.h>
// Sequential ("bump") allocator over a VTCM region: returns the current
// cursor position and advances the cursor by `size` bytes. No alignment
// and no bounds checking is performed; the caller owns the region and the
// cursor.
static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
    uint8_t * const chunk = *vtcm_ptr;
    *vtcm_ptr = chunk + size;
    return chunk;
}
#endif // VTCM_UTILS_H

View file

@ -166,8 +166,8 @@ static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = {
/* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor,
/* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor,
/* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor,
/* .get_tensor_2d_async = */ NULL,
/* .set_tensor_2d_async = */ NULL,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor,
/* .clear = */ ggml_backend_metal_buffer_private_clear,
/* .reset = */ NULL,
@ -567,8 +567,8 @@ static ggml_backend_i ggml_backend_metal_i = {
/* .free = */ ggml_backend_metal_free,
/* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
/* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
/* .get_tensor_2d_async = */ NULL,
/* .set_tensor_2d_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .cpy_tensor_async = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups
/* .synchronize = */ ggml_backend_metal_synchronize,
/* .graph_plan_create = */ NULL,

View file

@ -0,0 +1,302 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
#define TILESIZE_K 16
#define TILESIZE_M 64
#define TILESIZE_N 32
// Dequantize 8 MXFP4 (E2M1) codes, packed four-per-ushort in fp4x8, into a
// half8 using pure bit manipulation (no lookup table).
// Each 4-bit code is sign(1) | exponent(2, bias 1) | mantissa(1).
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
// Shift each code's 3-bit exp+mantissa field into fp16 bit positions 9..11
// (mantissa lands in bit 9, exponent in bits 10..11).
fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;
// Re-bias: adding 0x3800 (fp16 exponent field 14) turns code exponent e
// into fp16 exponent e+14, i.e. 2^(e-1); applied only to nonzero codes so
// that +/-0 stays zero.
bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;
// Pattern 0x0200 is the subnormal code 0.5 (exp 0, mantissa 1): clear the
// mantissa bit so the bias alone produces exactly 0.5.
fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;
// Move each code's sign bit into fp16 bit 15.
sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
sign_b.hi = fp4x8.s0 & 0x8000;
fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;
// Same decode for the second group of four codes (fp4x8.s1).
ushort2 fp16_packed_a_1, fp16_packed_b_1;
fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;
fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;
sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
sign_b.hi = fp4x8.s1 & 0x8000;
fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;
// Reinterpret the eight assembled bit patterns as fp16 values.
return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
}
// Accumulates a 16(K) x 16(column) fp16 tile product into c_reg (float16).
// a_reg is a half16 holding 16 K-elements of one row; b_lm is a __local
// half4 array of staged B columns, with consecutive 4-element K groups
// spaced 32 entries apart; lm_offset selects the 16-column group. Each
// acc.sX collects dot() products of 4 K-elements at a time; after every
// 8 K-elements the half-precision partial sums are flushed into the float
// accumulator via convert_float8 to limit fp16 rounding error. Expects a
// half16 named `acc` to be declared in the enclosing scope.
// (All commentary lives above the #define: the body lines end in line
// continuations, so trailing comments would break the macro.)
#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \
acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \
acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \
acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \
acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \
acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \
acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \
acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \
acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \
acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \
acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \
acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \
acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \
acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \
acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \
acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \
acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \
acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \
acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \
acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \
acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \
acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \
acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \
acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \
acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \
acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \
acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \
acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \
acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \
acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \
acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \
acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \
acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \
c_reg.lo += convert_float8(acc.lo); \
c_reg.hi += convert_float8(acc.hi); \
acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \
acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \
acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \
acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \
acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \
acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \
acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \
acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \
acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \
acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \
acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \
acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \
acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \
acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \
acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \
acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \
acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \
acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \
acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \
acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \
acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \
acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \
acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \
acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \
acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \
acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \
acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \
acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \
acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \
acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \
acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \
acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \
c_reg.lo += convert_float8(acc.lo); \
c_reg.hi += convert_float8(acc.hi); \
// Convert an E8M0 scale (biased-127 power-of-two exponent byte) to fp16.
// Subtracting 112 re-biases 127 -> 15; the result is shifted into the fp16
// exponent field (bits 10..14). If any bit above the 5-bit exponent range is
// set (mask 0x00E0) the scale is out of fp16 range and +inf (0x7C00) is
// returned.
// NOTE(review): for x < 112 the subtraction wraps in ushort and also trips
// the 0x00E0 mask, so underflowing scales map to +inf rather than 0 --
// presumably inputs are expected to stay within fp16 range; confirm with
// callers.
static inline half e8m0_to_fp16(uchar x) {
ushort bits;
bits = (ushort)(x) - (ushort)(112);
bits = ((bits & 0x00E0) != 0) ? 0x7C00 : (bits << 10);
return as_half(bits);
}
// Convert an E8M0 scale (biased-127 power-of-two exponent byte) to fp32 by
// placing it directly into the fp32 exponent field (bits 23..30). x == 0
// encodes 2^-127, which has no normal fp32 representation, so the subnormal
// bit pattern 0x00400000 (exactly 2^-127) is substituted.
static inline float e8m0_to_fp32(uchar x) {
int bits;
bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
return as_float(bits);
}
// Tiled GEMM for MoE: MXFP4 expert weights (quants in image src0_q, E8M0
// scales in src0_d) x fp32 activations (image src1, pre-gathered into
// post-router tile order). Each work-item computes a 1 x TILESIZE_N strip
// of C; the work-group covers a TILESIZE_M x TILESIZE_N tile. Results are
// scattered back to dst rows through the post-router index table src2;
// src2_emap maps each N tile to its owning expert.
// NOTE(review): the transposed quant layout and image element packing are
// inferred from the indexing below -- confirm against the host upload code.
__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair
kernel void kernel_gemm_moe_mxfp4_f32_ns(
__read_only image1d_buffer_t src0_q,
__global uchar * src0_d,
__read_only image1d_buffer_t src1,
__global uint * src2,
__global ushort * src2_emap,
__write_only image1d_buffer_t dst,
__global int * total_tiles,
uint ne00,
uint ne01
) {
uint block_id_m = get_global_id(1); // m_tile
uint block_id_n = get_global_id(2); // n_tile
// Boundary check
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
return;
}
__private half16 reg_a;
__private float32 reg_c = (float32)(0);
__local half4 shared_b[128];
const ushort expert_id = src2_emap[block_id_n];
const uint row = block_id_m * TILESIZE_M;
const uint col = block_id_n * TILESIZE_N;
uint sub_block_id_m = get_local_id(0);
// Per-fiber offsets for cooperatively staging a B sub-tile into shared_b:
// each fiber loads 8 floats as two float4 halves spaced 16 rows apart.
uint2 b_global_offset;
b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00;
b_global_offset.y = b_global_offset.x + (16 * ne00);
uint2 b_local_offset;
b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2);
b_local_offset.y = b_local_offset.x + 16;
// Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks
for (uint step = 0; step < ne00; step += TILESIZE_K * 2) {
// First sub-block
uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5);
uint b_sub_offset = col * ne00 + step;
// Load scale for current mxfp4 block; one E8M0 scale covers the whole
// 32-element block, i.e. both sub-blocks of this iteration.
uint s_offset = s_sub_offset + get_global_id(0);
float s = e8m0_to_fp32(src0_d[s_offset]);
// Load 16 fp4 (64-bits) in transposed layout
uint2 mxfp4x16;
mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
// Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements
float8 bx8_f32;
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
// Convert to half and store to LM to share within the subgroup
half8 bx8_f16 = convert_half8(bx8_f32);
shared_b[b_local_offset.x] = bx8_f16.lo;
shared_b[b_local_offset.y] = bx8_f16.hi;
// Dequantization
reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s;
reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s;
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
// 32 16x16 fp16 dot product with 8 elements reduction for better precision
half16 acc;
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
// Repeat for second sub-block
uint half_step = step + TILESIZE_K;
q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
b_sub_offset = col * ne00 + half_step;
// Load next 16 fp4 (64-bits) in transposed layout
mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
// Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
// Convert to half and store to LM to share within the subgroup
bx8_f16 = convert_half8(bx8_f32);
shared_b[b_local_offset.x] = bx8_f16.lo;
shared_b[b_local_offset.y] = bx8_f16.hi;
// Dequantization (same scale s: still inside the same 32-element block)
reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s;
reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s;
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
// 32 16x16 fp16 dot product with 8 elements reduction for better precision
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
}
// Load poster router and share in LM
__local uint out_idx[TILESIZE_N];
if (get_local_id(0) < TILESIZE_N) {
uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)];
// Padding slots (0xFFFFFFFF) borrow the tile's first output index; their
// writes are overridden by the fenced write of reg_c.s0 at the end.
if (idx == 0xFFFFFFFF) {
idx = src2[block_id_n * TILESIZE_N + 0];
}
out_idx[get_local_id(0)] = idx * ne01;
}
barrier(CLK_LOCAL_MEM_FENCE);
// Scatter results back to original position in output grid
uint m_offset = row + get_local_id(0);
write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1));
write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2));
write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3));
write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4));
write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5));
write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6));
write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7));
write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8));
write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9));
write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa));
write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb));
write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc));
write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd));
write_imagef(dst, out_idx[14] + m_offset, (reg_c.se));
write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf));
write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg));
write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh));
write_imagef(dst, out_idx[18] + m_offset, (reg_c.si));
write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj));
write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk));
write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl));
write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm));
write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn));
write_imagef(dst, out_idx[24] + m_offset, (reg_c.so));
write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp));
write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq));
write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr));
write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss));
write_imagef(dst, out_idx[29] + m_offset, (reg_c.st));
write_imagef(dst, out_idx[30] + m_offset, (reg_c.su));
write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv));
// Store zero padding parts to the index of first output in tile, override correct result in the end
barrier(CLK_GLOBAL_MEM_FENCE);
write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0));
}

View file

@ -0,0 +1,161 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define QK_MXFP4 32
#define N_SIMDGROUP 4
#define SIMDGROUP_WIDTH 64
// Dequantize 8 MXFP4 (E2M1) codes, packed four-per-ushort in fp4x8, into a
// half8 using pure bit manipulation (no lookup table).
// Each 4-bit code is sign(1) | exponent(2, bias 1) | mantissa(1).
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
// Shift each code's 3-bit exp+mantissa field into fp16 bit positions 9..11
// (mantissa lands in bit 9, exponent in bits 10..11).
fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;
// Re-bias: adding 0x3800 (fp16 exponent field 14) turns code exponent e
// into fp16 exponent e+14, i.e. 2^(e-1); applied only to nonzero codes so
// that +/-0 stays zero.
bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;
// Pattern 0x0200 is the subnormal code 0.5 (exp 0, mantissa 1): clear the
// mantissa bit so the bias alone produces exactly 0.5.
fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;
// Move each code's sign bit into fp16 bit 15.
sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
sign_b.hi = fp4x8.s0 & 0x8000;
fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;
// Same decode for the second group of four codes (fp4x8.s1).
ushort2 fp16_packed_a_1, fp16_packed_b_1;
fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;
fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;
sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
sign_b.hi = fp4x8.s1 & 0x8000;
fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;
// Reinterpret the eight assembled bit patterns as fp16 values.
return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
}
// Convert an E8M0 scale (biased-127 power-of-two exponent byte) to fp32 by
// placing it directly into the fp32 exponent field (bits 23..30). x == 0
// encodes 2^-127, which has no normal fp32 representation, so the subnormal
// bit pattern 0x00400000 (exactly 2^-127) is substituted.
static inline float e8m0_to_fp32(uchar x) {
int bits;
bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
return as_float(bits);
}
// GEMV for MoE: dst[i01] = sum_k scale(block) * mxfp4(w[i01,k]) * x[k] for
// the expert selected by src2. src0_q holds MXFP4 quants (one uint = 8
// codes) with what appears to be a transposed layout (stride ne01 between
// consecutive quant words of a block); src0_e holds the per-block E8M0
// scales. N_SIMDGROUP subgroups stripe the K blocks; their partial sums
// meet in local memory and subgroup 0 writes the final outputs.
__attribute__((qcom_reqd_sub_group_size("half")))
__kernel void kernel_gemv_moe_mxfp4_f32_ns(
__global uint * src0_q,
__global uchar * src0_e,
__read_only image1d_buffer_t src1,
__global uint * src2,
__global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne11
) {
uint i01 = get_global_id(0);
uint i20 = get_global_id(2);
uint sgid = get_local_id(1);
uint slid = get_sub_group_local_id();
// Activation column for this output (src1 holds ne11 columns of ne00 floats).
uint i11 = i20 % ne11;
uint expert_id = src2[i20];
// Offset (in 32-element MXFP4 blocks) of this expert's weight matrix.
uint expert_offset = expert_id * ne00 * ne01 / 32;
__private float sum = 0.0f; // each thread calculate partial sum of one output
// loop along ne00 in block granularity, skip 4 blocks every iter
for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {
// load one block of q
uint4 regQ;
uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01;
regQ.s0 = src0_q[block_offset];
regQ.s1 = src0_q[block_offset + ne01];
regQ.s2 = src0_q[block_offset + ne01 * 2];
regQ.s3 = src0_q[block_offset + ne01 * 3];
// Walk the activation column in float4 steps (8 float4s per 32-wide block).
uint offset = i11 * ne00 / 4 + ib00 * 8;
half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));
float4 shared_y4;
shared_y4 = read_imagef(src1, (offset + 0));
float4 acc = shared_y4 * convert_float4(fp16x8.lo);
shared_y4 = read_imagef(src1, (offset + 1));
acc += shared_y4 * convert_float4(fp16x8.hi);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));
shared_y4 = read_imagef(src1, (offset + 2));
acc += shared_y4 * convert_float4(fp16x8.lo);
shared_y4 = read_imagef(src1, (offset + 3));
acc += shared_y4 * convert_float4(fp16x8.hi);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));
shared_y4 = read_imagef(src1, (offset + 4));
acc += shared_y4 * convert_float4(fp16x8.lo);
shared_y4 = read_imagef(src1, (offset + 5));
acc += shared_y4 * convert_float4(fp16x8.hi);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));
shared_y4 = read_imagef(src1, (offset + 6));
acc += shared_y4 * convert_float4(fp16x8.lo);
shared_y4 = read_imagef(src1, (offset + 7));
acc += shared_y4 * convert_float4(fp16x8.hi);
// Apply the block's E8M0 scale (one scale per 32-element block).
uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset];
sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
}
// reduction in local memory, assumes #subgroups=4
__local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
barrier(CLK_LOCAL_MEM_FENCE);
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
// 1 outputs per thread in subgroup 0
if (sgid == 0) {
// offsetd is a byte offset; >> 2 converts it to a float index.
dst = dst + (offsetd >> 2);
dst[i01 + i20 * ne01] = sum;
}
}

View file

@ -0,0 +1,30 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define QK4_0 32
// Gather activation rows into post-router tile order for the MoE GEMM.
// router[post_router_idx] holds a flattened routing index or the sentinel
// 0xFFFFFFFF for padding slots; padding slots produce zero rows.
// NOTE(review): activation_idx = router_idx / map_ratio presumably maps a
// (token, top-k slot) index back to its token row, with map_ratio = top-k
// count -- confirm against the host-side launch.
kernel void kernel_moe_reorder_b(
global float4 * src,
global uint * router,
global float4 * dst,
global int * total_tiles,
uint K,
ushort map_ratio,
uint tile_size
) {
// One work-item copies one float4 (4 elements along K) of one output row.
uint k_4 = get_global_id(0);
uint post_router_idx = get_global_id(1);
if ((k_4 >= (K / 4)) || (post_router_idx >= total_tiles[0] * tile_size)) {
return;
}
uint router_idx = router[post_router_idx];
float4 out = (float4)(0);
if (router_idx != 0xFFFFFFFF) {
ushort activation_idx = router_idx / map_ratio;
out = src[activation_idx * K / 4 + k_4];
}
dst[post_router_idx * K / 4 + k_4] = out;
}

View file

@ -0,0 +1,82 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
// Count how many (token, top-k) assignments each expert received.
// hist must be zero-initialized before launch; atomic_inc makes the
// concurrent updates safe.
// NOTE(review): input rows are n_experts wide but only the first topK
// entries of each row are read -- presumably the selected expert ids sit at
// the front of each row; confirm against the producer of `input`.
__kernel void kernel_moe_histogram(
__global const int * input,
__global int * hist,
uint N,
uint topK,
uint n_experts
) {
uint n = get_global_id(0);
uint k = get_global_id(1);
if (n >= N || k >= topK) {
return;
}
int expert_id = input[n * n_experts + k];
atomic_inc(&hist[expert_id]);
}
// Exclusive scan over per-expert tile counts: assigns each expert its first
// tile index (tile_offset) and records the grand total in total_tiles.
// Also resets hist and slot_counter to zero for the scatter pass that
// follows.
// NOTE(review): the loop is sequential over all experts, so this kernel is
// presumably enqueued as a single work-item -- confirm on the host side.
__kernel void kernel_moe_scan(
__global int * hist,
__global int * tile_offset,
__global int * total_tiles,
__global int * slot_counter,
int tile_size,
uint n_experts
) {
int offset = 0;
for (int v = 0; v < n_experts; v++) {
int count = hist[v];
// Round the expert's assignment count up to whole tiles.
int tiles = (count + tile_size - 1) / tile_size;
tile_offset[v] = offset;
offset += tiles;
hist[v] = 0;
slot_counter[v] = 0;
}
*total_tiles = offset;
}
// Scatter each (token, top-k) assignment into its expert's tile slots:
// post_router[out_pos] records the flattened source index n * topK + k, and
// emap[tile] records the tile's owning expert. slot_counter must be zeroed
// beforehand (kernel_moe_scan does this); atomic_inc hands out slots
// race-free. Concurrent writes to emap[tile_idx] all store the same expert
// id, so that race is benign.
// NOTE(review): the tile size is hard-coded as 32 here while kernel_moe_scan
// takes it as a parameter -- confirm tile_size == 32 on the host side.
__kernel void kernel_moe_scatter(
__global const int * input,
__global int * post_router,
__global ushort * emap,
__global const int * tile_offset,
__global int * slot_counter,
int N,
int topK,
uint n_experts
) {
uint n = get_global_id(0);
uint k = get_global_id(1);
if (n >= N || k >= topK) {
return;
}
int val = input[n * n_experts + k];
int local_slot = atomic_inc(&slot_counter[val]);
int tile_idx = tile_offset[val] + (local_slot / 32);
int lane = local_slot % 32;
int out_pos = tile_idx * 32 + lane;
post_router[out_pos] = n * topK + k;
emap[tile_idx] = val;
}
// Pre-fill every slot of the active tiles in post_router with the padding
// sentinel 0xFFFFFFFF (-1 as int), so slots that kernel_moe_scatter does not
// claim are recognizable as padding downstream.
// NOTE(review): get_global_id(1) is assumed to range over [0, tile_size)
// via the launch dimensions -- it is not bounds-checked here.
__kernel void kernel_moe_fill(
__global int * post_router,
__global int * total_tiles,
int tile_size
) {
int tile_id = get_global_id(0);
int vec_id_in_tile = get_global_id(1);
if (tile_id < total_tiles[0]) {
post_router[tile_id * tile_size + vec_id_in_tile] = 0xFFFFFFFF;
}
}

View file

@ -445,10 +445,12 @@ struct vk_fa_pipeline_state {
bool f32acc;
uint32_t flags;
uint32_t limit_occupancy_shmem;
ggml_type k_type;
ggml_type v_type;
bool operator<(const vk_fa_pipeline_state &b) const {
return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) <
std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem);
return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem, k_type, v_type) <
std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem, b.k_type, b.v_type);
}
};
@ -3046,7 +3048,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device
return result;
}
static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
GGML_UNUSED(n_kv);
GGML_UNUSED(f32acc);
@ -3060,7 +3062,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device
if (small_rows) {
result.block_rows = 32;
result.block_cols = 32;
} else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) {
} else if (ggml_is_quantized(k_type) || ggml_is_quantized(v_type) || hsk >= 256 || hsv >= 256) {
result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64;
result.block_cols = 32;
} else {
@ -3074,7 +3076,13 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device
return result;
}
static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
// Mixed K/V is only implemented on the coopmat2 (flash_attn_cm2) path; never use scalar/cm1.
if (k_type != v_type) {
GGML_ASSERT(device->coopmat2);
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
}
FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
@ -3086,7 +3094,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
if (path == FA_COOPMAT1) {
bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
(!f32acc && device->coopmat_support_16x16x16_f16acc);
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);
if (!shape_ok || !shmem_ok) {
@ -3099,20 +3107,25 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
path = FA_SCALAR;
}
// Q1_0 K/V is only implemented on coopmat2 (flash_attn_cm2); there is no scalar FA shader for it.
if ((k_type == GGML_TYPE_Q1_0 || v_type == GGML_TYPE_Q1_0) && device->coopmat2) {
path = FA_COOPMAT2;
}
switch (path) {
case FA_SCALAR:
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
case FA_COOPMAT1:
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
case FA_COOPMAT2:
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
default:
throw std::runtime_error("unsupported FaCodePath");
}
}
static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc,
bool use_mask, bool use_mask_opt, bool use_logit_softcap) {
bool use_mask, bool use_mask_opt, bool use_logit_softcap, ggml_type k_type, ggml_type v_type) {
const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary &&
(device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2);
@ -3123,12 +3136,32 @@ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const
const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size;
return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem};
return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem, k_type, v_type};
}
static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) {
return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split,
state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem};
const auto fa_block_bytes = [](ggml_type t) -> uint32_t {
// decodeBufF32 uses a block of vec4s for a better memory access pattern.
return t == GGML_TYPE_F32 ? 16u : (uint32_t) ggml_type_size(t);
};
return {
/* 0 WorkGroupSize */ state.workgroup_size,
/* 1 Br */ state.Br,
/* 2 Bc */ state.Bc,
/* 3 HSK */ state.HSK,
/* 4 HSV */ state.HSV,
/* 5 Clamp */ static_cast<uint32_t>(!state.aligned),
/* 6 D_split */ state.D_split,
/* 7 row_split */ state.row_split,
/* 8 SubGroupSize */ state.subgroup_size,
/* 9 SHMEM_STAGING */ state.shmem_staging ? 1u : 0u,
/*10 Flags */ state.flags,
/*11 LIMIT_OCCUPANCY_SHMEM */ state.limit_occupancy_shmem,
/*12 FaTypeK */ static_cast<uint32_t>(state.k_type),
/*13 FaTypeV */ static_cast<uint32_t>(state.v_type),
/*14 FaBlockBytesK */ fa_block_bytes(state.k_type),
/*15 FaBlockBytesV */ fa_block_bytes(state.v_type),
};
}
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
@ -3583,16 +3616,35 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
#endif
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
#define CREATE_FA_CM2_MIXED() \
for (int fa_k_ty = 0; fa_k_ty < (int)GGML_TYPE_COUNT; ++fa_k_ty) { \
for (auto &fa : device->pipeline_flash_attn_f32_f16[fa_k_ty]) { \
FaCodePath path = fa.first.path; \
uint32_t Br = fa.first.Br; \
uint32_t Bc = fa.first.Bc; \
bool aligned = fa.first.aligned; \
bool f32acc = fa.first.f32acc; \
if (path == FA_COOPMAT2) { \
if (aligned) { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \
} \
} else { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \
} \
} \
} \
} \
}
if (device->coopmat2) {
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
CREATE_FA_CM2_MIXED();
}
#undef CREATE_FA_CM2_MIXED
#endif
#undef CREATE_FA
@ -6872,7 +6924,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
}
}
static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) {
VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
// Check if src is pinned memory
vk_buffer buf = nullptr;
@ -6882,7 +6934,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
if (buf != nullptr) {
// Memory is pinned, use as staging buffer
std::vector<vk::BufferCopy> slices(1);
if (width == spitch) {
if (width == spitch && width == dpitch) {
// Only do single write if stride is equal
slices[0].srcOffset = buf_offset;
slices[0].dstOffset = offset;
@ -6891,7 +6943,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
slices.resize(height);
for (size_t i = 0; i < height; i++) {
slices[i].srcOffset = buf_offset + i * spitch;
slices[i].dstOffset = offset + i * width;
slices[i].dstOffset = offset + i * dpitch;
slices[i].size = width;
}
}
@ -6908,21 +6960,30 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
}
// Staging buffer required
const size_t copy_size = width*height;
ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size);
const size_t staging_size = width * height;
ggml_vk_ensure_sync_staging_buffer(dst->device, staging_size);
vk_buffer& staging_buffer = dst->device->sync_staging;
VkBufferCopy buf_copy = {
0,
offset,
copy_size};
std::vector<vk::BufferCopy> slices(1);
if (width == dpitch) {
slices[0].srcOffset = 0;
slices[0].dstOffset = offset;
slices[0].size = staging_size;
} else {
slices.resize(height);
for (size_t i = 0; i < height; i++) {
slices[i].srcOffset = i * width;
slices[i].dstOffset = offset + i * dpitch;
slices[i].size = width;
}
}
ggml_vk_sync_buffers(nullptr, subctx);
vkCmdCopyBuffer(subctx->s->buffer->buf, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
subctx->s->buffer->buf.copyBuffer((VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, slices);
if (width == spitch) {
deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys);
deferred_memcpy((uint8_t *)staging_buffer->ptr, src, staging_size, &subctx->in_memcpys);
} else {
for (size_t i = 0; i < height; i++) {
deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
@ -6933,24 +6994,24 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
static bool ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging);
return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, size, 1, sync_staging);
}
static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) {
static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t dpitch, size_t width, size_t height) {
VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")");
// Buffer is already mapped
if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
for (size_t i = 0; i < height; i++) {
memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
memcpy((uint8_t *)dst->ptr + offset + i * dpitch, (const uint8_t *) src + i * spitch, width);
}
} else {
std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dst->device, subctx);
bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, dpitch, width, height, true);
GGML_ASSERT(ret);
ggml_vk_ctx_end(subctx);
@ -6971,7 +7032,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) {
VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")");
ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1);
ggml_vk_buffer_write_2d(dst, offset, src, size, size, size, 1);
}
static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) {
@ -7017,15 +7078,35 @@ static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
}
// Fall back to staging buffer
const size_t copy_size = dpitch * height;
ggml_vk_ensure_sync_staging_buffer(src->device, copy_size);
const size_t staging_size = width * height;
ggml_vk_ensure_sync_staging_buffer(src->device, staging_size);
vk_buffer& staging_buffer = src->device->sync_staging;
ggml_vk_sync_buffers(nullptr, subctx);
subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer, slices);
std::vector<vk::BufferCopy> staging_slices(1);
if (width == spitch) {
staging_slices[0].srcOffset = offset;
staging_slices[0].dstOffset = 0;
staging_slices[0].size = staging_size;
} else {
staging_slices.resize(height);
for (size_t i = 0; i < height; i++) {
staging_slices[i].srcOffset = offset + i * spitch;
staging_slices[i].dstOffset = i * width;
staging_slices[i].size = width;
}
}
deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
ggml_vk_sync_buffers(nullptr, subctx);
subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer, staging_slices);
if (width == dpitch) {
deferred_memcpy(dst, staging_buffer->ptr, staging_size, &subctx->out_memcpys);
} else {
for (size_t i = 0; i < height; i++) {
deferred_memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) staging_buffer->ptr + i * width, width, &subctx->out_memcpys);
}
}
return true;
}
@ -7033,8 +7114,8 @@ static bool ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t
return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging);
}
static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {
VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")");
static void ggml_vk_buffer_read_2d(vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height) {
VK_LOG_DEBUG("ggml_vk_buffer_read_2d(" << src->buffer << ", " << offset << ", " << width << ", " << height << ")");
// If the device is not an UMA device the memory is host-accessible through rebar. While writing
// through PCIe is sufficient fast reading back data from PCIe is slower than going through
@ -7042,18 +7123,20 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_
if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) {
GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
memcpy(dst, (uint8_t *) src->ptr + offset, size);
for (size_t i = 0; i < height; i++) {
memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) src->ptr + offset + i * spitch, width);
}
} else {
std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(src->device, subctx);
bool ret = ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true);
bool ret = ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, spitch, dpitch, width, height, true);
GGML_ASSERT(ret);
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, src->device->fence);
VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read_2d waitForFences");
src->device->device.resetFences({ src->device->fence });
ggml_vk_queue_command_pools_cleanup(src->device);
@ -7063,6 +7146,11 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_
}
}
static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {
VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")");
ggml_vk_buffer_read_2d(src, offset, dst, size, size, size, 1);
}
static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")");
// Make sure both buffers are on same device
@ -7094,7 +7182,7 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
// Copy to src staging buffer
ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
// Copy to dst buffer
ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1);
ggml_vk_buffer_write(dst, dst_offset, src->device->sync_staging->ptr, size);
}
}
@ -9033,8 +9121,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
assert(dst->type == GGML_TYPE_F32);
assert(q->type == GGML_TYPE_F32);
assert(k->type == v->type);
uint32_t gqa_ratio = 1;
uint32_t qk_ratio = neq2 / nek2;
uint32_t workgroups_x = (uint32_t)neq1;
@ -9045,7 +9131,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc);
vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, v->type, f32acc);
const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u);
if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
@ -9058,7 +9144,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
workgroups_y /= gqa_ratio;
}
tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc);
tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc);
if (tuning_params.path != FA_COOPMAT2) {
GGML_ASSERT(k->type == v->type);
}
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
@ -9097,7 +9187,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16;
vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
mask != nullptr, use_mask_opt, logit_softcap != 0);
mask != nullptr, use_mask_opt, logit_softcap != 0, k->type, v->type);
vk_pipeline pipeline = nullptr;
@ -13642,6 +13732,20 @@ static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml
ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
}
static void ggml_backend_vk_buffer_set_tensor_2d(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset,
size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor_2d(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ", " <<
n_copies << ", " << stride_tensor << ", " << stride_data << ")");
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
vk_buffer buf = buf_ctx->dev_buffer;
if (size == 0) {
return;
}
ggml_vk_buffer_write_2d(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, stride_data, stride_tensor, size, n_copies);
}
static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
@ -13655,6 +13759,21 @@ static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, cons
ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
}
static void ggml_backend_vk_buffer_get_tensor_2d(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset,
size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor_2d(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ", " <<
n_copies << ", " << stride_tensor << ", " << stride_data << ")");
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
if (size == 0) {
return;
}
vk_buffer buf = buf_ctx->dev_buffer;
ggml_vk_buffer_read_2d(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, stride_tensor, stride_data, size, n_copies);
}
static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
if (ggml_nbytes(src) == 0) {
return true;
@ -13689,8 +13808,8 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
/* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_vk_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_vk_buffer_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .set_tensor_2d = */ ggml_backend_vk_buffer_set_tensor_2d,
/* .get_tensor_2d = */ ggml_backend_vk_buffer_get_tensor_2d,
/* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor,
/* .clear = */ ggml_backend_vk_buffer_clear,
/* .reset = */ NULL,
@ -13846,8 +13965,9 @@ static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_b
return &ctx->device->buffer_type;
}
static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")");
static void ggml_backend_vk_set_tensor_2d_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset,
size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
VK_LOG_DEBUG("ggml_backend_vk_set_tensor_2d_async(" << size << ", " << n_copies << ")");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
@ -13861,7 +13981,6 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
if (ctx->device->async_use_transfer_queue) {
if (ctx->transfer_ctx.expired()) {
// Initialize new transfer context
cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
ctx->transfer_ctx = cpy_ctx;
ggml_vk_ctx_begin(ctx->device, cpy_ctx);
@ -13876,25 +13995,48 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
bool ret = ggml_vk_buffer_write_async(cpy_ctx, buf, dst_offset, data, size);
bool ret = ggml_vk_buffer_write_2d_async(cpy_ctx, buf, dst_offset, data, stride_data, stride_tensor, size, n_copies);
if (!ret) {
ggml_vk_ensure_sync_staging_buffer(ctx, size);
const size_t staging_size = size * n_copies;
ggml_vk_ensure_sync_staging_buffer(ctx, staging_size);
ggml_vk_sync_buffers(nullptr, cpy_ctx);
vk::BufferCopy buffer_cpy;
buffer_cpy.srcOffset = 0;
buffer_cpy.dstOffset = dst_offset;
buffer_cpy.size = size;
std::vector<vk::BufferCopy> slices(1);
if (size == stride_tensor) {
slices[0].srcOffset = 0;
slices[0].dstOffset = dst_offset;
slices[0].size = staging_size;
} else {
slices.resize(n_copies);
for (size_t i = 0; i < n_copies; i++) {
slices[i].srcOffset = i * size;
slices[i].dstOffset = dst_offset + i * stride_tensor;
slices[i].size = size;
}
}
cpy_ctx->s->buffer->buf.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
deferred_memcpy(ctx->sync_staging->ptr, data, size, &cpy_ctx->in_memcpys);
cpy_ctx->s->buffer->buf.copyBuffer(ctx->sync_staging->buffer, buf->buffer, slices);
if (size == stride_data) {
deferred_memcpy(ctx->sync_staging->ptr, data, staging_size, &cpy_ctx->in_memcpys);
} else {
for (size_t i = 0; i < n_copies; i++) {
deferred_memcpy((uint8_t *)ctx->sync_staging->ptr + i * size, (const uint8_t *)data + i * stride_data, size, &cpy_ctx->in_memcpys);
}
}
ggml_vk_synchronize(ctx);
}
}
static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")");
static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")");
ggml_backend_vk_set_tensor_2d_async(backend, tensor, data, offset, size, 1, size, size);
}
static void ggml_backend_vk_get_tensor_2d_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset,
size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
VK_LOG_DEBUG("ggml_backend_vk_get_tensor_2d_async(" << size << ", " << n_copies << ")");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
@ -13909,24 +14051,45 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
vk_buffer buf = buf_ctx->dev_buffer;
auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
bool ret = ggml_vk_buffer_read_async(compute_ctx, buf, src_offset, data, size);
bool ret = ggml_vk_buffer_read_2d_async(compute_ctx, buf, src_offset, data, stride_tensor, stride_data, size, n_copies);
// If that failed, copy synchronously through a staging buffer
if (!ret) {
ggml_vk_ensure_sync_staging_buffer(ctx, size);
const size_t staging_size = size * n_copies;
ggml_vk_ensure_sync_staging_buffer(ctx, staging_size);
ggml_vk_sync_buffers(nullptr, compute_ctx);
vk::BufferCopy buffer_cpy;
buffer_cpy.srcOffset = src_offset;
buffer_cpy.dstOffset = 0;
buffer_cpy.size = size;
std::vector<vk::BufferCopy> slices(1);
if (size == stride_tensor) {
slices[0].srcOffset = src_offset;
slices[0].dstOffset = 0;
slices[0].size = staging_size;
} else {
slices.resize(n_copies);
for (size_t i = 0; i < n_copies; i++) {
slices[i].srcOffset = src_offset + i * stride_tensor;
slices[i].dstOffset = i * size;
slices[i].size = size;
}
}
compute_ctx->s->buffer->buf.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy });
deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys);
compute_ctx->s->buffer->buf.copyBuffer(buf->buffer, ctx->sync_staging->buffer, slices);
if (size == stride_data) {
deferred_memcpy(data, ctx->sync_staging->ptr, staging_size, &compute_ctx->out_memcpys);
} else {
for (size_t i = 0; i < n_copies; i++) {
deferred_memcpy((uint8_t *)data + i * stride_data, (const uint8_t *)ctx->sync_staging->ptr + i * size, size, &compute_ctx->out_memcpys);
}
}
ggml_vk_synchronize(ctx);
}
}
static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")");
ggml_backend_vk_get_tensor_2d_async(backend, tensor, data, offset, size, 1, size, size);
}
static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async(" << src << " -> " << dst << ", size=" << ggml_nbytes(src) << ")");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend_dst->context;
@ -15150,8 +15313,8 @@ static ggml_backend_i ggml_backend_vk_interface = {
/* .free = */ ggml_backend_vk_free,
/* .set_tensor_async = */ ggml_backend_vk_set_tensor_async,
/* .get_tensor_async = */ ggml_backend_vk_get_tensor_async,
/* .get_tensor_2d_async = */ NULL,
/* .set_tensor_2d_async = */ NULL,
/* .set_tensor_2d_async = */ ggml_backend_vk_set_tensor_2d_async,
/* .get_tensor_2d_async = */ ggml_backend_vk_get_tensor_2d_async,
/* .cpy_tensor_async = */ ggml_backend_vk_cpy_tensor_async,
/* .synchronize = */ ggml_backend_vk_synchronize,
/* .graph_plan_create = */ NULL,
@ -15508,38 +15671,27 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
return false;
}
// It's straightforward to support different K/V dequant, but would
// significantly increase the number of pipelines
if (op->src[1]->type != op->src[2]->type) {
// mismatching K/V type is currently supported for coopmat2 only.
if (op->src[1]->type != op->src[2]->type && !coopmat2) {
return false;
}
switch (op->src[1]->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_IQ4_NL:
// supported in scalar and coopmat2 paths
break;
// K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
//case GGML_TYPE_Q2_K:
//case GGML_TYPE_Q3_K:
//case GGML_TYPE_Q4_K:
//case GGML_TYPE_Q5_K:
//case GGML_TYPE_Q6_K:
//case GGML_TYPE_IQ1_S:
//case GGML_TYPE_IQ1_M:
//case GGML_TYPE_IQ2_XXS:
//case GGML_TYPE_IQ2_XS:
//case GGML_TYPE_IQ2_S:
//case GGML_TYPE_IQ3_XXS:
//case GGML_TYPE_IQ3_S:
//case GGML_TYPE_IQ4_XS:
default:
auto fa_kv_ok = [coopmat2](ggml_type t) {
switch (t) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_0:
return true;
case GGML_TYPE_Q1_0:
return coopmat2;
default:
return false;
}
};
if (!fa_kv_ok(op->src[1]->type) || !fa_kv_ok(op->src[2]->type)) {
return false;
}
if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) {

View file

@ -13,6 +13,12 @@ layout (constant_id = 8) const uint32_t SubGroupSize = 32;
layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0;
layout (constant_id = 10) const uint32_t Flags = 0;
layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0;
// ggml_type enumerant for K/V
layout (constant_id = 12) const uint32_t FaTypeK = 0;
layout (constant_id = 13) const uint32_t FaTypeV = 0;
// sizeof(decode buffer): quants -> ggml block size; F32 -> 16 (decodeBufF32 vec4).
layout (constant_id = 14) const uint32_t FaBlockBytesK = 2;
layout (constant_id = 15) const uint32_t FaBlockBytesV = 2;
const bool USE_MASK_OPT = (Flags & 1) != 0;
const bool MASK_ENABLE = (Flags & 2) != 0;

View file

@ -17,8 +17,57 @@
#extension GL_EXT_null_initializer : enable
#include "types.glsl"
#include "dequant_funcs_cm2.glsl"
#include "flash_attn_base.glsl"
#include "dequant_funcs_cm2.glsl"
// buffer_reference stride = sizeof(struct) = FaBlockBytesK/V.
layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_K {
uint8_t raw[FaBlockBytesK];
};
layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_V {
uint8_t raw[FaBlockBytesV];
};
uint fa_block_elems(uint ty) {
switch (ty) {
case 0u: return 4u; // GGML_TYPE_F32: vec4 block (matches decodeBufF32 / dequantFuncF32)
case 1u: return 1u; // GGML_TYPE_F16
case 2u: return uint(QUANT_K_Q4_0);
case 3u: return uint(QUANT_K_Q4_1);
case 6u: return uint(QUANT_K_Q5_0);
case 7u: return uint(QUANT_K_Q5_1);
case 8u: return uint(QUANT_K_Q8_0);
case 41u: return uint(QUANT_K_Q1_0);
default:
return 1u;
}
}
float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
switch (FaTypeK) {
case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock);
case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
default: return float16_t(0);
}
}
float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
switch (FaTypeV) {
case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock);
case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
default: return float16_t(0);
}
}
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
@ -55,12 +104,6 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
return max(elem0, elem1);
}
#if BLOCK_SIZE > 1
#define DECODEFUNC , DEQUANTFUNC
#else
#define DECODEFUNC
#endif
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
@ -95,10 +138,6 @@ ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c
}
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
init_indices();
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
@ -107,10 +146,10 @@ void main() {
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
#if BLOCK_SIZE > 1
tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif
const uint bs_k = fa_block_elems(FaTypeK);
const uint bs_v = fa_block_elems(FaTypeV);
tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, bs_k);
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, bs_v);
tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
@ -120,10 +159,12 @@ void main() {
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
{
q_stride &= ~7;
#if BLOCK_SIZE == 1
k_stride &= ~7;
v_stride &= ~7;
#endif
if (bs_k == 1u) {
k_stride &= ~7;
}
if (bs_v == 1u) {
v_stride &= ~7;
}
m_stride &= ~7;
}
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
@ -230,7 +271,13 @@ void main() {
coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
// F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128.
const bool k_use_decode = (bs_k > 1u);
if (k_use_decode) {
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose, faDecodeK);
} else {
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose);
}
S = coopMatMulAdd(Qf16, K_T, S);
if (LOGIT_SOFTCAP) {
@ -291,7 +338,12 @@ void main() {
coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);
const bool v_use_decode = (bs_v > 1u);
if (v_use_decode) {
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad), faDecodeV);
} else {
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad));
}
L = eM*L + rowsum;

View file

@ -658,20 +658,17 @@ void process_shaders() {
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
}
if (fp16) {
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
string_to_spv("flash_attn_f32_f16_mixed", "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
#endif
}
for (const auto& tname : type_names) {
if (tname == "bf16") continue;
if (fp16) {
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
} else {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc);
}
#endif
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",

View file

@ -0,0 +1,154 @@
#ifdef USE_SUBGROUP_REDUCTION
enable subgroups;
#endif
enable f16;
#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"
#include "mul_mat_vec_acc.tmpl"
// Uniform parameters for the expert-routed (mul_mat_id) matrix-vector kernel.
// All offsets and strides are expressed in elements, not bytes (they are
// added directly to typed-array indices below).
struct MulMatIdVecParams {
offset_src0: u32, // element offset into src0 (expert weight matrices)
offset_src1: u32, // element offset into src1 (activation vectors)
offset_ids: u32, // element offset into ids (selected-expert indices)
offset_dst: u32, // element offset into dst
k: u32, // reduction length (columns of src0) — consumed by accumulate_vec_dot
m: u32, // output rows per expert matrix
n_expert: u32, // total number of experts present in src0
n_expert_used: u32, // number of experts selected for the token
b_ne1: u32, // row count of src1; the selected slot index is taken mod b_ne1
stride_01: u32, // src0 row stride — consumed by accumulate_vec_dot
stride_11: u32, // src1 row stride
stride_02: u32, // src0 per-expert (matrix) stride
stride_12: u32, // src1 per-batch stride — NOTE(review): not referenced in this kernel's visible code
};
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // [cols, rows, n_expert]
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // [cols, b_ne1, n_tokens(1)]
@group(0) @binding(2) var<storage, read_write> ids: array<u32>; // [n_expert_used, n_tokens(1)]
@group(0) @binding(3) var<storage, read_write> dst: array<f32>; // [rows, n_expert_used, n_tokens(1)]
// "mul_mat_vec_acc.tmpl" requires params.k, params.m, params.stride_01
@group(0) @binding(4) var<uniform> params: MulMatIdVecParams;
// Flattened as [row][thread] to keep each row's reduction contiguous in memory.
var<workgroup> partial_sums: array<f32, OUTPUTS_PER_WG * WG_SIZE>;
// Index into the flattened [OUTPUTS_PER_WG][WG_SIZE] partial-sum scratch.
// Each output row owns a contiguous run of WG_SIZE partials, so a row's
// reduction walks adjacent workgroup-memory slots.
fn partial_index(row: u32, thread: u32) -> u32 {
    let row_start = row * WG_SIZE;
    return row_start + thread;
}
var<workgroup> gathered_count_ids: array<u32, N_EXPERTS>;
var<workgroup> gathered_expert_used: array<u32, N_EXPERTS>;
// One workgroup computes OUTPUTS_PER_WG rows of dst for exactly one selected
// expert. Workgroups are mapped linearly over (selected experts x row groups);
// each thread accumulates a strided partial dot product, then the partials are
// reduced either with subgroup intrinsics or a shared-memory tree, depending
// on which reduction variant this shader was compiled with.
@compute @workgroup_size(WG_SIZE)
fn main(
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>
#ifdef USE_SUBGROUP_REDUCTION
, @builtin(subgroup_id) subgroup_id: u32,
@builtin(subgroup_invocation_id) subgroup_invocation_id: u32,
@builtin(num_subgroups) num_subgroups: u32,
@builtin(subgroup_size) subgroup_size: u32
#endif
) {
let thread_id = local_id.x;
// Clear the per-expert selection flags before gathering.
for (var i = thread_id;i < params.n_expert;i += WG_SIZE) {
gathered_count_ids[i] = 0;
}
workgroupBarrier();
// gather the selected experts for the target token.
// gathered_count_ids[e] becomes 1 iff expert e was selected;
// gathered_expert_used[e] records which selection slot chose it.
// NOTE(review): if the same expert appears twice in ids, only the last
// slot is kept — presumably ids holds distinct experts per token; confirm.
for (var col = thread_id;col < params.n_expert_used;col += WG_SIZE) {
let expert = ids[params.offset_ids + col];
gathered_count_ids[expert] = 1;
gathered_expert_used[expert] = col;
}
workgroupBarrier();
// Number of workgroups needed to cover all m output rows of one matrix.
let output_groups:u32 = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG;
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
// Walk the experts in order, assigning output_groups workgroups to each
// selected one, until this workgroup's linear id falls inside a slot.
var own_expert:u32 = 0;
var wg_in_batch:u32 = 0;
var wg_sum:u32 = 0;
for (var i = 0u;i < params.n_expert;i += 1) {
let wg_vec_count = gathered_count_ids[i]; // 1 or 0
let wg_per_matrix = output_groups * wg_vec_count;
if (wg_sum <= wg_linear && wg_linear < wg_sum + wg_per_matrix) {
own_expert = i;
wg_in_batch = wg_linear - wg_sum;
break;
}
wg_sum += wg_per_matrix;
}
// First output row handled by this workgroup within its expert's matrix.
let row_base = (wg_linear % output_groups) * OUTPUTS_PER_WG;
let dst1_stride = params.m;
let src0_batch_offset = params.offset_src0 + own_expert * params.stride_02;
let src1_idx_base = params.offset_src1 + (gathered_expert_used[own_expert] % params.b_ne1) * params.stride_11;
let dst_idx_base = params.offset_dst + gathered_expert_used[own_expert] * dst1_stride + row_base;
// Per-thread strided partial dot products for OUTPUTS_PER_WG rows
// (defined in mul_mat_vec_acc.tmpl; uses params.k/m/stride_01).
let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base);
#ifdef USE_SUBGROUP_REDUCTION
// Stage 1: reduce within each subgroup; lane 0 stores one partial per subgroup.
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let subgroup_total = subgroupAdd(acc[row]);
if (subgroup_invocation_id == 0u) {
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
}
}
workgroupBarrier();
// Stage 2: each subgroup reduces one row's per-subgroup partials and writes it.
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
let output_row = row_base + row;
var row_acc = 0.0f;
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
row_acc += partial_sums[partial_index(row, k)];
}
let row_total = subgroupAdd(row_acc);
if (subgroup_invocation_id == 0) {
dst[dst_idx_base + row] = row_total;
}
}
#endif
#ifdef USE_WORKGROUP_REDUCTION
// Spill every thread's partials, then tree-reduce in shared memory.
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] = acc[row];
}
workgroupBarrier();
// Classic halving reduction; assumes WG_SIZE is a power of two.
var stride:u32 = WG_SIZE / 2u;
while (stride > 0) {
if (thread_id < stride) {
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
}
}
workgroupBarrier();
stride = stride / 2;
}
// Thread r writes output row r (slot 0 holds the row's full sum).
if (thread_id < OUTPUTS_PER_WG) {
let output_row = row_base + thread_id;
if (output_row < params.m) {
dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)];
}
}
#endif
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,240 @@
#if defined(SRC_F16) || defined(DST_F16)
enable f16;
#endif
#ifdef SRC_F16
#define SRC_TYPE f16
#else
#define SRC_TYPE f32
#endif
#ifdef DST_F16
#define DST_TYPE f16
#else
#define DST_TYPE f32
#endif
@group(0) @binding(0)
var<storage, read_write> input: array<SRC_TYPE>;
@group(0) @binding(1)
var<storage, read_write> output: array<DST_TYPE>;
// Uniform parameters for the 2D upscale/interpolation kernel.
// Offsets and strides are in elements (they index the typed arrays directly).
struct Params {
offset_i: u32, // element offset into the input array
offset_o: u32, // element offset into the output array
// element strides
si0: u32, si1: u32, si2: u32, si3: u32, // input strides: x, y, z, n
so0: u32, so1: u32, so2: u32, so3: u32, // output strides: x, y, z, n
src_w: u32, // source width
src_h: u32, // source height
src_z: u32, // source depth (channels)
src_n: u32, // source batch count
dst_w: u32, // destination width
dst_h: u32, // destination height
dst_z: u32, // destination depth (channels)
dst_n: u32, // destination batch count
mode_flags: u32, // bit 8 = align-corners; low bits presumably encode the scale mode — confirm against host code
};
@group(0) @binding(2)
var<uniform> params: Params;
const GGML_SCALE_FLAG_ALIGN_CORNERS: u32 = 1u << 8u;
// Fetch one source element as f32, clamping x/y to the valid source extent
// (replicate-edge padding). z and n are assumed to be in range already.
fn get_clamped_input(x: i32, y: i32, z: u32, n: u32) -> f32 {
    let last_x = i32(params.src_w) - 1;
    let last_y = i32(params.src_h) - 1;
    let xi = u32(clamp(x, 0, last_x));
    let yi = u32(clamp(y, 0, last_y));
    let idx = params.offset_i
        + xi * params.si0
        + yi * params.si1
        + z * params.si2
        + n * params.si3;
    return f32(input[idx]);
}
// Cubic convolution kernel weight (Keys-style) with free parameter `a`.
// Support is |t| <= 2; zero outside. Expressions match the standard
// piecewise cubic used by bicubic resampling.
fn cubic_weight(t: f32, a: f32) -> f32 {
    let x = abs(t);
    if (x > 2.0) {
        return 0.0;
    }
    if (x > 1.0) {
        // outer lobe: 1 < |t| <= 2
        return a * x * x * x - 5.0 * a * x * x + 8.0 * a * x - 4.0 * a;
    }
    // inner lobe: |t| <= 1
    return (a + 2.0) * x * x * x - (a + 3.0) * x * x + 1.0;
}
// One invocation computes one destination element. The flat output index is
// decoded into (x, y, z, n); z/n are mapped to the source by nearest lookup,
// while x/y are interpolated per the compile-time mode (NEAREST, BILINEAR,
// optionally ANTIALIAS, or BICUBIC).
@compute @workgroup_size(WG_SIZE)
fn main(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>
) {
// 2D dispatch flattened to a single linear element index.
let i_out = gid.x + (num_wg.x * u32(WG_SIZE)) * gid.y;
let total = params.dst_w * params.dst_h * params.dst_z * params.dst_n;
if (i_out >= total) {
return;
}
// decode (x, y, z, n)
var i = i_out;
let x_dst = i % params.dst_w;
i = i / params.dst_w;
let y_dst = i % params.dst_h;
i = i / params.dst_h;
let z_dst = i % params.dst_z;
let n_dst = i / params.dst_z;
// scale factors
var sf0 = f32(params.dst_w) / f32(params.src_w);
var sf1 = f32(params.dst_h) / f32(params.src_h);
var sf2 = f32(params.dst_z) / f32(params.src_z);
var sf3 = f32(params.dst_n) / f32(params.src_n);
let align_corners = (params.mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) != 0;
// pixel_offset: 0.5 for half-pixel-center (default), 0.0 for align_corners
var pixel_offset = 0.5;
if (align_corners) {
pixel_offset = 0.0;
// align_corners maps corner pixel centers onto each other, so the
// scale is recomputed over (size - 1) intervals when possible.
if (params.dst_w > 1 && params.src_w > 1) {
sf0 = f32(params.dst_w - 1) / f32(params.src_w - 1);
}
if (params.dst_h > 1 && params.src_h > 1) {
sf1 = f32(params.dst_h - 1) / f32(params.src_h - 1);
}
}
// z and n are never interpolated — nearest source slice/batch, clamped.
let z_src = min(params.src_z - 1, u32(floor(f32(z_dst) / sf2)));
let n_src = min(params.src_n - 1, u32(floor(f32(n_dst) / sf3)));
var result = 0.0;
#if defined(NEAREST)
// Nearest-neighbor: floor-divide by the scale and clamp to the edge.
let x_src = min(params.src_w - 1, u32(floor(f32(x_dst) / sf0)));
let y_src = min(params.src_h - 1, u32(floor(f32(y_dst) / sf1)));
result = get_clamped_input(i32(x_src), i32(y_src), z_src, n_src);
#elif defined(BILINEAR)
#if defined(ANTIALIAS)
// Antialiased bilinear: triangle filter over a variable support region.
// When downscaling (sf < 1) the support widens to 1/sf source pixels.
let support0 = max(1.0f / sf0, 1.0f);
let support1 = max(1.0f / sf1, 1.0f);
let invscale0 = 1.0 / support0;
let invscale1 = 1.0 / support1;
// Center of this destination pixel in source coordinates.
let fx = (f32(x_dst) + pixel_offset) / sf0;
let fy = (f32(y_dst) + pixel_offset) / sf1;
let x_min = max(i32(fx - support0 + pixel_offset), 0);
let y_min = max(i32(fy - support1 + pixel_offset), 0);
let x_max = min(i32(fx + support0 + pixel_offset), i32(params.src_w));
let y_max = min(i32(fy + support1 + pixel_offset), i32(params.src_h));
// Accumulate triangle-weighted samples and normalize by total weight.
var weighted_sum = 0.0;
var total_weight = 0.0;
for (var x = x_min; x < x_max; x += 1) {
let wx = max(1.0 - abs(f32(x) - fx + pixel_offset) * invscale0, 0.0);
for (var y = y_min; y < y_max; y += 1) {
let wy = max(1.0 - abs(f32(y) - fy + pixel_offset) * invscale1, 0.0);
let w = wx * wy;
if (w > 0.0) {
weighted_sum += get_clamped_input(x, y, z_src, n_src) * w;
total_weight += w;
}
}
}
if (total_weight > 0.0) {
result = weighted_sum / total_weight;
}
#else
// Plain bilinear: blend the 2x2 neighborhood around the back-projected point.
let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset;
let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset;
let x0 = i32(floor(fx));
let y0 = i32(floor(fy));
let dx = clamp(fx - f32(x0), 0.0, 1.0);
let dy = clamp(fy - f32(y0), 0.0, 1.0);
let a = get_clamped_input(x0, y0, z_src, n_src);
let b = get_clamped_input(x0 + 1, y0, z_src, n_src);
let c = get_clamped_input(x0, y0 + 1, z_src, n_src);
let d = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src);
let wa = (1.0 - dx) * (1.0 - dy);
let wb = dx * (1.0 - dy);
let wc = (1.0 - dx) * dy;
let wd = dx * dy;
result = a * wa + b * wb + c * wc + d * wd;
#endif
#elif defined(BICUBIC)
// bicubic convolution with alpha = -0.75 (PyTorch default)
let alpha = -0.75;
let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset;
let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset;
let x0 = i32(floor(fx));
let y0 = i32(floor(fy));
let dx = fx - f32(x0);
let dy = fy - f32(y0);
// horizontal weights for offsets -1, 0, 1, 2
let wx0 = cubic_weight(dx + 1.0, alpha);
let wx1 = cubic_weight(dx, alpha);
let wx2 = cubic_weight(1.0 - dx, alpha);
let wx3 = cubic_weight(2.0 - dx, alpha);
// vertical weights for offsets -1, 0, 1, 2
let wy0 = cubic_weight(dy + 1.0, alpha);
let wy1 = cubic_weight(dy, alpha);
let wy2 = cubic_weight(1.0 - dy, alpha);
let wy3 = cubic_weight(2.0 - dy, alpha);
// intermediate horizontal interpolation for 4x4 grid of pixels
// x0-1, x0, x0+1, x0+2, y0-1
let p0 = get_clamped_input(x0 - 1, y0 - 1, z_src, n_src);
let p1 = get_clamped_input(x0, y0 - 1, z_src, n_src);
let p2 = get_clamped_input(x0 + 1, y0 - 1, z_src, n_src);
let p3 = get_clamped_input(x0 + 2, y0 - 1, z_src, n_src);
let row0 = p0 * wx0 + p1 * wx1 + p2 * wx2 + p3 * wx3;
// x0-1, x0, x0+1, x0+2, y0
let q0 = get_clamped_input(x0 - 1, y0, z_src, n_src);
let q1 = get_clamped_input(x0, y0, z_src, n_src);
let q2 = get_clamped_input(x0 + 1, y0, z_src, n_src);
let q3 = get_clamped_input(x0 + 2, y0, z_src, n_src);
let row1 = q0 * wx0 + q1 * wx1 + q2 * wx2 + q3 * wx3;
// x0-1, x0, x0+1, x0+2, y0+1
let r0 = get_clamped_input(x0 - 1, y0 + 1, z_src, n_src);
let r1 = get_clamped_input(x0, y0 + 1, z_src, n_src);
let r2 = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src);
let r3 = get_clamped_input(x0 + 2, y0 + 1, z_src, n_src);
let row2 = r0 * wx0 + r1 * wx1 + r2 * wx2 + r3 * wx3;
// x0-1, x0, x0+1, x0+2, y0+2
let s0 = get_clamped_input(x0 - 1, y0 + 2, z_src, n_src);
let s1 = get_clamped_input(x0, y0 + 2, z_src, n_src);
let s2 = get_clamped_input(x0 + 1, y0 + 2, z_src, n_src);
let s3 = get_clamped_input(x0 + 2, y0 + 2, z_src, n_src);
let row3 = s0 * wx0 + s1 * wx1 + s2 * wx2 + s3 * wx3;
// final vertical interpolation
result = row0 * wy0 + row1 * wy1 + row2 * wy2 + row3 * wy3;
#endif
// Scatter the result to its strided destination slot (f16 or f32 per DST_TYPE).
let dst_idx = params.offset_o + x_dst * params.so0 + y_dst * params.so1 + z_dst * params.so2 + n_dst * params.so3;
output[dst_idx] = DST_TYPE(result);
}

View file

@ -59,8 +59,13 @@
uint64_t ggml_graph_next_uid(void) {
#ifdef _MSC_VER
#if defined(_WIN32)
static volatile LONG counter = 1;
return (uint64_t) InterlockedIncrement(&counter) - 1;
#else
static volatile long long counter = 1;
return (uint64_t) _InterlockedIncrement64(&counter) - 1;
#endif
#else
static uint64_t counter = 1;
return __atomic_fetch_add(&counter, 1, __ATOMIC_RELAXED);

View file

@ -40,6 +40,14 @@
#include <TargetConditionals.h>
#endif
#ifdef _WIN32
# define llama_mmap_ftell _ftelli64
# define llama_mmap_fseek _fseeki64
#else
# define llama_mmap_ftell ftello
# define llama_mmap_fseek fseeko
#endif
// TODO: consider moving to llama-impl.h if needed in more places
#if defined(_WIN32)
static std::string llama_format_win_err(DWORD err) {
@ -226,7 +234,7 @@ struct llama_file::impl {
size_t tell() const {
if (fd == -1) {
long ret = std::ftell(fp);
off_t ret = llama_mmap_ftell(fp);
if (ret == -1) {
throw std::runtime_error(format("ftell error: %s", strerror(errno)));
}
@ -244,7 +252,7 @@ struct llama_file::impl {
void seek(size_t offset, int whence) const {
off_t ret = 0;
if (fd == -1) {
ret = std::fseek(fp, (long) offset, whence);
ret = llama_mmap_fseek(fp, offset, whence);
} else {
ret = lseek(fd, offset, whence);
}

View file

@ -685,9 +685,9 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, const llama_mod
LLAMA_LOG_WARN("%s: %-36s - applying manual override: %s -> %s\n",
__func__, tensor_name.c_str(), ggml_type_name(new_type), ggml_type_name(qtype));
new_type = qtype;
manual = true;
break;
}
manual = true;
break;
}
}
}

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load diff

View file

@ -1,5 +1,5 @@
<script lang="ts">
import * as Tooltip from '../src/lib/components/ui/tooltip';
import * as Tooltip from '../../src/lib/components/ui/tooltip';
interface Props {
children: any;

View file

@ -1,7 +1,7 @@
import type { Preview } from '@storybook/sveltekit';
import '../src/app.css';
import ModeWatcherDecorator from './ModeWatcherDecorator.svelte';
import TooltipProviderDecorator from './TooltipProviderDecorator.svelte';
import ModeWatcherDecorator from './decorators/ModeWatcherDecorator.svelte';
import TooltipProviderDecorator from './decorators/TooltipProviderDecorator.svelte';
const preview: Preview = {
parameters: {

View file

@ -3640,9 +3640,9 @@
}
},
"node_modules/bits-ui": {
"version": "2.17.3",
"resolved": "https://registry.npmjs.org/bits-ui/-/bits-ui-2.17.3.tgz",
"integrity": "sha512-Bef41uY9U2jaBJHPhcPvmBNkGec5Wx2z6eioDsTmsaR2vH4QoaOcPi75gzCG3+/2TNr6v/qBwzgWNPYCxNtrEA==",
"version": "2.18.0",
"resolved": "https://registry.npmjs.org/bits-ui/-/bits-ui-2.18.0.tgz",
"integrity": "sha512-GLOBZRVy3hxNHIQ2MpD/+5aK9KcBFZRhUJtZ1UDABXdlVR4K6zFpgt4T+Rwuhf2sQzlc6yK1q/DprHPjwT4Pjw==",
"dev": true,
"license": "MIT",
"dependencies": {

View file

@ -28,7 +28,6 @@ import type {
ApiRouterModelsUnloadResponse,
// Chat types
ChatAttachmentDisplayItem,
ChatAttachmentPreviewItem,
ChatMessageType,
ChatRole,
ChatUploadedFile,
@ -92,7 +91,6 @@ declare global {
ApiRouterModelsUnloadResponse,
// Chat types
ChatAttachmentDisplayItem,
ChatAttachmentPreviewItem,
ChatMessagePromptProgress,
ChatMessageSiblingInfo,
ChatMessageTimings,

View file

@ -1,3 +1,5 @@
import { isElementInViewport } from '$lib/utils/viewport';
/**
* Svelte action that fades in an element when it enters the viewport.
* Uses IntersectionObserver for efficient viewport detection.
@ -12,17 +14,8 @@ export function fadeInView(
) {
const { duration = 300, y = 0, skipIfVisible = false } = options;
if (skipIfVisible) {
const rect = node.getBoundingClientRect();
const isAlreadyVisible =
rect.top < window.innerHeight &&
rect.bottom > 0 &&
rect.left < window.innerWidth &&
rect.right > 0;
if (isAlreadyVisible) {
return;
}
if (skipIfVisible && isElementInViewport(node)) {
return;
}
node.style.opacity = '0';

View file

@ -0,0 +1,11 @@
---
name: app
description: Opinionated app components building on top of ./ui primitives
---
- Can include business logic and state management
- Can include data fetching and caching logic
- Should use original spelling for HTML-native events and `camelCase` for custom events
- Props and markup attributes should be listed alphabetically
- Use JS Objects and Arrays for CSS classes and styles when they are dynamic
- Whenever there can be repetition in the component's markup, if it's too small to be decoupled as a separate component — use Svelte 5's `{#snippet}` + `{@render}`

View file

@ -5,15 +5,16 @@
import { TooltipSide } from '$lib/enums';
interface Props {
icon: Component;
tooltip: string;
variant?: ButtonVariant;
size?: ButtonSize;
iconSize?: string;
ariaLabel?: string;
class?: string;
disabled?: boolean;
icon: Component;
iconSize?: string;
onclick: (e?: MouseEvent) => void;
'aria-label'?: string;
size?: ButtonSize;
stopPropagationOnClick?: boolean;
tooltip: string;
variant?: ButtonVariant;
tooltipSide?: TooltipSide;
}
@ -26,8 +27,9 @@
disabled = false,
iconSize = 'h-3 w-3',
tooltipSide = TooltipSide.TOP,
stopPropagationOnClick = false,
onclick,
'aria-label': ariaLabel
ariaLabel
}: Props = $props();
</script>
@ -37,13 +39,18 @@
{variant}
{size}
{disabled}
{onclick}
onclick={(e: MouseEvent) => {
if (stopPropagationOnClick) e.stopPropagation();
onclick?.(e);
}}
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
aria-label={ariaLabel || tooltip}
>
{@const IconComponent = icon}
<IconComponent class={iconSize} />
{#if icon}
{@const IconComponent = icon}
<IconComponent class={iconSize} />
{/if}
</Button>
</Tooltip.Trigger>

View file

@ -1,18 +1,17 @@
<script lang="ts">
import { Copy } from '@lucide/svelte';
import { copyToClipboard } from '$lib/utils';
import ActionIcon from './ActionIcon.svelte';
interface Props {
ariaLabel?: string;
canCopy?: boolean;
text: string;
}
let { ariaLabel = 'Copy to clipboard', canCopy = true, text }: Props = $props();
export let ariaLabel: string = 'Copy to clipboard';
export let canCopy: boolean = true;
export let text: string;
</script>
<Copy
class="h-3 w-3 flex-shrink-0 cursor-{canCopy ? 'pointer' : 'not-allowed'}"
aria-label={ariaLabel}
<ActionIcon
icon={Copy}
tooltip={ariaLabel}
iconSize="h-4 w-4"
disabled={!canCopy}
onclick={() => canCopy && copyToClipboard(text)}
/>

View file

@ -1,27 +0,0 @@
<script lang="ts">
import { X } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
interface Props {
id: string;
onRemove?: (id: string) => void;
class?: string;
iconSize?: number;
}
let { id, onRemove, class: className = '', iconSize = 3 }: Props = $props();
</script>
<Button
type="button"
variant="ghost"
size="icon-sm"
class="bg-white/20 p-0 hover:bg-white/30 {className}"
onclick={(e: MouseEvent) => {
e.stopPropagation();
onRemove?.(id);
}}
aria-label="Remove file"
>
<X class="h-{iconSize} w-{iconSize}" />
</Button>

View file

@ -1,46 +0,0 @@
<script lang="ts">
import { Eye } from '@lucide/svelte';
import { ActionIconCopyToClipboard } from '$lib/components/app';
import { FileTypeText } from '$lib/enums';
interface Props {
code: string;
language: string;
disabled?: boolean;
onPreview?: (code: string, language: string) => void;
}
let { code, language, disabled = false, onPreview }: Props = $props();
const showPreview = $derived(language?.toLowerCase() === FileTypeText.HTML);
function handlePreview() {
if (disabled) return;
onPreview?.(code, language);
}
</script>
<div class="code-block-actions">
<div class="copy-code-btn" class:opacity-50={disabled} class:!cursor-not-allowed={disabled}>
<ActionIconCopyToClipboard
text={code}
canCopy={!disabled}
ariaLabel={disabled ? 'Code incomplete' : 'Copy code'}
/>
</div>
{#if showPreview}
<button
class="preview-code-btn"
class:opacity-50={disabled}
class:!cursor-not-allowed={disabled}
title={disabled ? 'Code incomplete' : 'Preview code'}
aria-label="Preview code"
aria-disabled={disabled}
type="button"
onclick={handlePreview}
>
<Eye size={16} />
</button>
{/if}
</div>

View file

@ -9,11 +9,5 @@
/** Styled icon button for action triggers with tooltip. */
export { default as ActionIcon } from './ActionIcon.svelte';
/** Code block actions component (copy, preview). */
export { default as ActionIconsCodeBlock } from './ActionIconsCodeBlock.svelte';
/** Copy-to-clipboard icon button with click handler. */
/** Copy-to-clipboard icon button with clipboard logic. */
export { default as ActionIconCopyToClipboard } from './ActionIconCopyToClipboard.svelte';
/** Remove/delete icon button with X icon. */
export { default as ActionIconRemove } from './ActionIconRemove.svelte';

View file

@ -1,5 +1,4 @@
<script lang="ts">
import { cn } from '$lib/components/ui/utils';
import type { Snippet } from 'svelte';
interface Props {
@ -13,10 +12,10 @@
</script>
<button
class={cn(
class={[
'inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75',
className
)}
]}
{onclick}
>
{#if icon}

View file

@ -1,39 +0,0 @@
<script lang="ts">
import { ModelModality } from '$lib/enums';
import { MODALITY_ICONS, MODALITY_LABELS } from '$lib/constants';
import { cn } from '$lib/components/ui/utils';
type DisplayableModality = ModelModality.VISION | ModelModality.AUDIO;
interface Props {
modalities: ModelModality[];
class?: string;
}
let { modalities, class: className = '' }: Props = $props();
// Filter to only modalities that have icons (VISION, AUDIO)
const displayableModalities = $derived(
modalities.filter(
(m): m is DisplayableModality => m === ModelModality.VISION || m === ModelModality.AUDIO
)
);
</script>
{#each displayableModalities as modality, index (index)}
{@const IconComponent = MODALITY_ICONS[modality]}
{@const label = MODALITY_LABELS[modality]}
<span
class={cn(
'inline-flex items-center gap-1 rounded-md bg-muted px-2 py-1 text-xs font-medium',
className
)}
>
{#if IconComponent}
<IconComponent class="h-3 w-3" />
{/if}
{label}
</span>
{/each}

View file

@ -0,0 +1,32 @@
<script lang="ts">
import { Eye, Mic } from '@lucide/svelte';
import { ModelModality } from '$lib/enums';
interface Props {
modalities: ModelModality[];
class?: string;
}
let { modalities, class: className = '' }: Props = $props();
</script>
{#each modalities as modality (modality)}
{#if modality === ModelModality.VISION || modality === ModelModality.AUDIO}
<span
class={[
'inline-flex items-center gap-1 rounded-md bg-muted px-2 py-1 text-xs font-medium',
className
]}
>
{#if modality === ModelModality.VISION}
<Eye class="h-3 w-3" />
Vision
{:else}
<Mic class="h-3 w-3" />
Audio
{/if}
</span>
{/if}
{/each}

View file

@ -6,11 +6,8 @@
*
*/
/** Badge displaying chat statistics (tokens, timing). */
export { default as BadgeChatStatistic } from './BadgeChatStatistic.svelte';
/** Generic info badge with optional tooltip and click handler. */
export { default as BadgeInfo } from './BadgeInfo.svelte';
/** Badge indicating model modality (vision, audio, tools). */
export { default as BadgeModality } from './BadgeModality.svelte';
export { default as BadgesModality } from './BadgesModality.svelte';

View file

@ -1,284 +0,0 @@
<script lang="ts">
import { Button } from '$lib/components/ui/button';
import * as Alert from '$lib/components/ui/alert';
import { SyntaxHighlightedCode } from '$lib/components/app';
import { FileText, Image, Music, FileIcon, Eye, Info } from '@lucide/svelte';
import {
isTextFile,
isImageFile,
isPdfFile,
isAudioFile,
getLanguageFromFilename,
createBase64DataUrl
} from '$lib/utils';
import { convertPDFToImage } from '$lib/utils/browser-only';
import { modelsStore } from '$lib/stores/models.svelte';
interface Props {
// Either an uploaded file or a stored attachment
uploadedFile?: ChatUploadedFile;
attachment?: DatabaseMessageExtra;
// For uploaded files
preview?: string;
name?: string;
textContent?: string;
// For checking vision modality
activeModelId?: string;
}
let { uploadedFile, attachment, preview, name, textContent, activeModelId }: Props = $props();
let hasVisionModality = $derived(
activeModelId ? modelsStore.modelSupportsVision(activeModelId) : false
);
let displayName = $derived(uploadedFile?.name || attachment?.name || name || 'Unknown File');
// Determine file type from uploaded file or attachment
let isAudio = $derived(isAudioFile(attachment, uploadedFile));
let isImage = $derived(isImageFile(attachment, uploadedFile));
let isPdf = $derived(isPdfFile(attachment, uploadedFile));
let isText = $derived(isTextFile(attachment, uploadedFile));
let displayPreview = $derived(
uploadedFile?.preview ||
(isImage && attachment && 'base64Url' in attachment ? attachment.base64Url : preview)
);
let displayTextContent = $derived(
uploadedFile?.textContent ||
(attachment && 'content' in attachment ? attachment.content : textContent)
);
let language = $derived(getLanguageFromFilename(displayName));
let IconComponent = $derived(() => {
if (isImage) return Image;
if (isText || isPdf) return FileText;
if (isAudio) return Music;
return FileIcon;
});
let pdfViewMode = $state<'text' | 'pages'>('pages');
let pdfImages = $state<string[]>([]);
let pdfImagesLoading = $state(false);
let pdfImagesError = $state<string | null>(null);
async function loadPdfImages() {
if (!isPdf || pdfImages.length > 0 || pdfImagesLoading) return;
pdfImagesLoading = true;
pdfImagesError = null;
try {
let file: File | null = null;
if (uploadedFile?.file) {
file = uploadedFile.file;
} else if (isPdf && attachment) {
// Check if we have pre-processed images
if (
'images' in attachment &&
attachment.images &&
Array.isArray(attachment.images) &&
attachment.images.length > 0
) {
pdfImages = attachment.images;
return;
}
// Convert base64 back to File for processing
if ('base64Data' in attachment && attachment.base64Data) {
const base64Data = attachment.base64Data;
const byteCharacters = atob(base64Data);
const byteNumbers = new Array(byteCharacters.length);
for (let i = 0; i < byteCharacters.length; i++) {
byteNumbers[i] = byteCharacters.charCodeAt(i);
}
const byteArray = new Uint8Array(byteNumbers);
file = new File([byteArray], displayName, { type: 'application/pdf' });
}
}
if (file) {
pdfImages = await convertPDFToImage(file);
} else {
throw new Error('No PDF file available for conversion');
}
} catch (error) {
pdfImagesError = error instanceof Error ? error.message : 'Failed to load PDF images';
} finally {
pdfImagesLoading = false;
}
}
export function reset() {
pdfImages = [];
pdfImagesLoading = false;
pdfImagesError = null;
pdfViewMode = 'pages';
}
$effect(() => {
if (isPdf && pdfViewMode === 'pages') {
loadPdfImages();
}
});
</script>
<div class="space-y-4">
<div class="flex items-center justify-end gap-6">
{#if isPdf}
<div class="flex items-center gap-2">
<Button
variant={pdfViewMode === 'text' ? 'default' : 'outline'}
size="sm"
onclick={() => (pdfViewMode = 'text')}
disabled={pdfImagesLoading}
>
<FileText class="mr-1 h-4 w-4" />
Text
</Button>
<Button
variant={pdfViewMode === 'pages' ? 'default' : 'outline'}
size="sm"
onclick={() => {
pdfViewMode = 'pages';
loadPdfImages();
}}
disabled={pdfImagesLoading}
>
{#if pdfImagesLoading}
<div
class="mr-1 h-4 w-4 animate-spin rounded-full border-2 border-current border-t-transparent"
></div>
{:else}
<Eye class="mr-1 h-4 w-4" />
{/if}
Pages
</Button>
</div>
{/if}
</div>
<div class="flex-1 overflow-auto">
{#if isImage && displayPreview}
<div class="flex items-center justify-center">
<img
src={displayPreview}
alt={displayName}
class="max-h-full rounded-lg object-contain shadow-lg"
/>
</div>
{:else if isPdf && pdfViewMode === 'pages'}
{#if !hasVisionModality && activeModelId}
<Alert.Root class="mb-4">
<Info class="h-4 w-4" />
<Alert.Title>Preview only</Alert.Title>
<Alert.Description>
<span class="inline-flex">
The selected model does not support vision. Only the extracted
<!-- svelte-ignore a11y_click_events_have_key_events -->
<!-- svelte-ignore a11y_no_static_element_interactions -->
<span class="mx-1 cursor-pointer underline" onclick={() => (pdfViewMode = 'text')}>
text
</span>
will be sent to the model.
</span>
</Alert.Description>
</Alert.Root>
{/if}
{#if pdfImagesLoading}
<div class="flex items-center justify-center p-8">
<div class="text-center">
<div
class="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-primary border-t-transparent"
></div>
<p class="text-muted-foreground">Converting PDF to images...</p>
</div>
</div>
{:else if pdfImagesError}
<div class="flex items-center justify-center p-8">
<div class="text-center">
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
<p class="mb-4 text-muted-foreground">Failed to load PDF images</p>
<p class="text-sm text-muted-foreground">{pdfImagesError}</p>
<Button class="mt-4" onclick={() => (pdfViewMode = 'text')}>View as Text</Button>
</div>
</div>
{:else if pdfImages.length > 0}
<div class="max-h-[70vh] space-y-4 overflow-auto">
{#each pdfImages as image, index (image)}
<div class="text-center">
<p class="mb-2 text-sm text-muted-foreground">Page {index + 1}</p>
<img
src={image}
alt="PDF Page {index + 1}"
class="mx-auto max-w-full rounded-lg shadow-lg"
/>
</div>
{/each}
</div>
{:else}
<div class="flex items-center justify-center p-8">
<div class="text-center">
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
<p class="mb-4 text-muted-foreground">No PDF pages available</p>
</div>
</div>
{/if}
{:else if (isText || (isPdf && pdfViewMode === 'text')) && displayTextContent}
<SyntaxHighlightedCode code={displayTextContent} {language} maxWidth="calc(69rem - 2rem)" />
{:else if isAudio}
<div class="flex items-center justify-center p-8">
<div class="w-full max-w-md text-center">
<Music class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
{#if uploadedFile?.preview}
<audio controls class="mb-4 w-full" src={uploadedFile.preview}>
Your browser does not support the audio element.
</audio>
{:else if isAudio && attachment && 'mimeType' in attachment && 'base64Data' in attachment}
<audio
controls
class="mb-4 w-full"
src={createBase64DataUrl(attachment.mimeType, attachment.base64Data)}
>
Your browser does not support the audio element.
</audio>
{:else}
<p class="mb-4 text-muted-foreground">Audio preview not available</p>
{/if}
<p class="text-sm text-muted-foreground">
{displayName}
</p>
</div>
</div>
{:else}
<div class="flex items-center justify-center p-8">
<div class="text-center">
{#if IconComponent}
<IconComponent class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
{/if}
<p class="mb-4 text-muted-foreground">Preview not available for this file type</p>
</div>
</div>
{/if}
</div>
</div>

View file

@ -1,165 +0,0 @@
<!--
ChatAttachmentThumbnailFile — thumbnail card for a non-image chat attachment.
Text files render an inline content preview; other files render a type-label badge.
Supports both pending uploads (uploadedFile) and stored attachments (attachment).
-->
<script lang="ts">
import { ActionIconRemove } from '$lib/components/app';
import { formatFileSize, getFileTypeLabel, getPreviewText, isTextFile } from '$lib/utils';
import { AttachmentType } from '$lib/enums';
interface Props {
class?: string;
// Unique id passed back to onRemove so the parent can delete this attachment.
id: string;
onClick?: (event?: MouseEvent) => void;
onRemove?: (id: string) => void;
name: string;
// Readonly mode (message history) hides the remove button and widens the card.
readonly?: boolean;
size?: number;
// Raw file text, when available, used for the inline preview.
textContent?: string;
// Either uploaded file or stored attachment
uploadedFile?: ChatUploadedFile;
attachment?: DatabaseMessageExtra;
}
let {
class: className = '',
id,
onClick,
onRemove,
name,
readonly = false,
size,
textContent,
uploadedFile,
attachment
}: Props = $props();
// True when the attachment should be rendered with an inline text preview.
let isText = $derived(isTextFile(attachment, uploadedFile));
// Short badge label (e.g. "PDF"); prefers the MIME type, then the attachment type, then the file name.
let fileTypeLabel = $derived.by(() => {
if (uploadedFile?.type) {
return getFileTypeLabel(uploadedFile.type);
}
if (attachment) {
if ('mimeType' in attachment && attachment.mimeType) {
return getFileTypeLabel(attachment.mimeType);
}
if (attachment.type) {
return getFileTypeLabel(attachment.type);
}
}
return getFileTypeLabel(name);
});
// For PDFs: whether the file was sent to the model as rendered images or extracted text.
let pdfProcessingMode = $derived.by(() => {
if (attachment?.type === AttachmentType.PDF) {
const pdfAttachment = attachment as DatabaseMessageExtraPdfFile;
return pdfAttachment.processedAsImages ? 'Sent as Image' : 'Sent as Text';
}
return null;
});
</script>
{#if isText}
{#if readonly}
<!-- Readonly mode (ChatMessage) -->
<button
class="cursor-pointer rounded-lg border border-border bg-muted p-3 transition-shadow hover:shadow-md {className} w-full max-w-2xl"
onclick={onClick}
aria-label={`Preview ${name}`}
type="button"
>
<div class="flex items-start gap-3">
<div class="flex min-w-0 flex-1 flex-col items-start text-left">
<span class="w-full truncate text-sm font-medium text-foreground">{name}</span>
{#if size}
<span class="text-xs text-muted-foreground">{formatFileSize(size)}</span>
{/if}
{#if textContent}
<div class="relative mt-2 w-full">
<div
class="overflow-hidden font-mono text-xs leading-relaxed break-words whitespace-pre-wrap text-muted-foreground"
>
{getPreviewText(textContent)}
</div>
<!-- Fade-out gradient hints that the preview is truncated -->
{#if textContent.length > 150}
<div
class="pointer-events-none absolute right-0 bottom-0 left-0 h-6 bg-gradient-to-t from-muted to-transparent"
></div>
{/if}
</div>
{/if}
</div>
</div>
</button>
{:else}
<!-- Non-readonly mode (ChatForm) -->
<button
class="group relative rounded-lg border border-border bg-muted p-3 {className} {textContent
? 'max-h-24 max-w-72'
: 'max-w-36'} cursor-pointer text-left"
onclick={onClick}
>
<!-- Remove button appears on hover only -->
<div class="absolute top-2 right-2 opacity-0 transition-opacity group-hover:opacity-100">
<ActionIconRemove {id} {onRemove} />
</div>
<div class="pr-8">
<span class="mb-3 block truncate text-sm font-medium text-foreground">{name}</span>
{#if textContent}
<div class="relative">
<div
class="overflow-hidden font-mono text-xs leading-relaxed break-words whitespace-pre-wrap text-muted-foreground"
style="max-height: 3rem; line-height: 1.2em;"
>
{getPreviewText(textContent)}
</div>
{#if textContent.length > 150}
<div
class="pointer-events-none absolute right-0 bottom-0 left-0 h-4 bg-gradient-to-t from-muted to-transparent"
></div>
{/if}
</div>
{/if}
</div>
</button>
{/if}
{:else}
<!-- Non-text attachment: type badge + name + size / PDF processing mode -->
<button
class="group flex items-center gap-3 rounded-lg border border-border bg-muted p-3 {className} relative"
onclick={onClick}
>
<div
class="flex h-8 w-8 items-center justify-center rounded bg-primary/10 text-xs font-medium text-primary"
>
{fileTypeLabel}
</div>
<div class="flex flex-col gap-0.5">
<span
class="max-w-24 truncate text-sm font-medium text-foreground {readonly
? ''
: 'group-hover:pr-6'} md:max-w-32"
>
{name}
</span>
{#if pdfProcessingMode}
<span class="text-left text-xs text-muted-foreground">{pdfProcessingMode}</span>
{:else if size}
<span class="text-left text-xs text-muted-foreground">{formatFileSize(size)}</span>
{/if}
</div>
{#if !readonly}
<div class="absolute top-2 right-2 opacity-0 transition-opacity group-hover:opacity-100">
<ActionIconRemove {id} {onRemove} />
</div>
{/if}
</button>
{/if}

View file

@ -1,287 +0,0 @@
<!--
ChatAttachments — renders a message's attachments (pending uploads and stored extras):
MCP prompt/resource chips, image thumbnails and file cards, plus the related preview dialogs.
Can collapse to a single scrollable carousel row with a "View all" button.
-->
<script lang="ts">
import {
ChatAttachmentMcpPrompt,
ChatAttachmentMcpResource,
ChatAttachmentThumbnailImage,
ChatAttachmentThumbnailFile,
HorizontalScrollCarousel,
DialogChatAttachmentPreview,
DialogChatAttachmentsViewAll,
DialogMcpResourcePreview
} from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import { AttachmentType } from '$lib/enums';
import type {
DatabaseMessageExtraMcpPrompt,
DatabaseMessageExtraMcpResource,
MCPResourceAttachment
} from '$lib/types';
import { getAttachmentDisplayItems } from '$lib/utils';
interface Props {
class?: string;
style?: string;
// For ChatMessage - stored attachments
attachments?: DatabaseMessageExtra[];
readonly?: boolean;
// For ChatForm - pending uploads
onFileRemove?: (fileId: string) => void;
uploadedFiles?: ChatUploadedFile[];
// Image size customization
imageClass?: string;
imageHeight?: string;
imageWidth?: string;
// Limit display to single row with "+ X more" button
limitToSingleRow?: boolean;
// For vision modality check
activeModelId?: string;
}
let {
class: className = '',
style = '',
attachments = [],
readonly = false,
onFileRemove,
uploadedFiles = $bindable([]),
// Default to small size for form previews
imageClass = '',
imageHeight = 'h-24',
imageWidth = 'w-auto',
limitToSingleRow = false,
activeModelId
}: Props = $props();
// Normalized list merging pending uploads and stored attachments.
let displayItems = $derived(getAttachmentDisplayItems({ uploadedFiles, attachments }));
let carouselRef: HorizontalScrollCarousel | undefined = $state();
// Whether the carousel currently overflows (controls the "View all" button).
let isScrollable = $state(false);
// Single-attachment preview dialog state.
let previewDialogOpen = $state(false);
let previewItem = $state<ChatAttachmentPreviewItem | null>(null);
// MCP resource preview dialog state.
let mcpResourcePreviewOpen = $state(false);
let mcpResourcePreviewExtra = $state<DatabaseMessageExtraMcpResource | null>(null);
let showViewAll = $derived(limitToSingleRow && displayItems.length > 0 && isScrollable);
let viewAllDialogOpen = $state(false);
// Opens the single-attachment preview dialog for a clicked item.
function openPreview(item: ChatAttachmentDisplayItem, event?: MouseEvent) {
event?.stopPropagation();
event?.preventDefault();
previewItem = {
uploadedFile: item.uploadedFile,
attachment: item.attachment,
preview: item.preview,
name: item.name,
size: item.size,
textContent: item.textContent
};
previewDialogOpen = true;
}
// Opens the MCP resource preview dialog.
function openMcpResourcePreview(extra: DatabaseMessageExtraMcpResource) {
mcpResourcePreviewExtra = extra;
mcpResourcePreviewOpen = true;
}
// Adapts a stored MCP resource extra to the attachment shape the chip component expects.
function toMcpResourceAttachment(
extra: DatabaseMessageExtraMcpResource,
id: string
): MCPResourceAttachment {
return {
id,
resource: {
uri: extra.uri,
name: extra.name,
title: extra.name,
serverName: extra.serverName
}
};
}
// Reset the carousel scroll position whenever the set of items changes.
$effect(() => {
if (carouselRef && displayItems.length) {
carouselRef.resetScroll();
}
});
</script>
{#if displayItems.length > 0}
<div class={className} {style}>
{#if limitToSingleRow}
<!-- Single scrollable row; edge margins inset the first/last chip -->
<HorizontalScrollCarousel
bind:this={carouselRef}
onScrollableChange={(scrollable) => (isScrollable = scrollable)}
>
{#each displayItems as item (item.id)}
{#if item.isMcpPrompt}
{@const mcpPrompt =
item.attachment?.type === AttachmentType.MCP_PROMPT
? (item.attachment as DatabaseMessageExtraMcpPrompt)
: item.uploadedFile?.mcpPrompt
? {
type: AttachmentType.MCP_PROMPT as const,
name: item.name,
serverName: item.uploadedFile.mcpPrompt.serverName,
promptName: item.uploadedFile.mcpPrompt.promptName,
content: item.textContent ?? '',
arguments: item.uploadedFile.mcpPrompt.arguments
}
: null}
{#if mcpPrompt}
<ChatAttachmentMcpPrompt
class="max-w-[300px] min-w-[200px] flex-shrink-0 {limitToSingleRow
? 'first:ml-4 last:mr-4'
: ''}"
prompt={mcpPrompt}
{readonly}
isLoading={item.isLoading}
loadError={item.loadError}
onRemove={onFileRemove ? () => onFileRemove(item.id) : undefined}
/>
{/if}
{:else if item.isMcpResource && item.attachment?.type === AttachmentType.MCP_RESOURCE}
{@const mcpResource = item.attachment as DatabaseMessageExtraMcpResource}
<ChatAttachmentMcpResource
class="flex-shrink-0 {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
attachment={toMcpResourceAttachment(mcpResource, item.id)}
onClick={() => openMcpResourcePreview(mcpResource)}
/>
{:else if item.isImage && item.preview}
<ChatAttachmentThumbnailImage
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
id={item.id}
name={item.name}
preview={item.preview}
{readonly}
onRemove={onFileRemove}
height={imageHeight}
width={imageWidth}
{imageClass}
onClick={(event) => openPreview(item, event)}
/>
{:else}
<ChatAttachmentThumbnailFile
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
id={item.id}
name={item.name}
size={item.size}
{readonly}
onRemove={onFileRemove}
textContent={item.textContent}
attachment={item.attachment}
uploadedFile={item.uploadedFile}
onClick={(event) => openPreview(item, event)}
/>
{/if}
{/each}
</HorizontalScrollCarousel>
{#if showViewAll}
<div class="mt-2 -mr-2 flex justify-end px-4">
<Button
type="button"
variant="ghost"
size="sm"
class="h-6 text-xs text-muted-foreground hover:text-foreground"
onclick={() => (viewAllDialogOpen = true)}
>
View all ({displayItems.length})
</Button>
</div>
{/if}
{:else}
<!-- Wrapping grid layout -->
<div class="flex flex-wrap items-start justify-end gap-3">
{#each displayItems as item (item.id)}
{#if item.isMcpPrompt}
{@const mcpPrompt =
item.attachment?.type === AttachmentType.MCP_PROMPT
? (item.attachment as DatabaseMessageExtraMcpPrompt)
: item.uploadedFile?.mcpPrompt
? {
type: AttachmentType.MCP_PROMPT as const,
name: item.name,
serverName: item.uploadedFile.mcpPrompt.serverName,
promptName: item.uploadedFile.mcpPrompt.promptName,
content: item.textContent ?? '',
arguments: item.uploadedFile.mcpPrompt.arguments
}
: null}
{#if mcpPrompt}
<ChatAttachmentMcpPrompt
class="max-w-[300px] min-w-[200px]"
prompt={mcpPrompt}
{readonly}
isLoading={item.isLoading}
loadError={item.loadError}
onRemove={onFileRemove ? () => onFileRemove(item.id) : undefined}
/>
{/if}
{:else if item.isMcpResource && item.attachment?.type === AttachmentType.MCP_RESOURCE}
{@const mcpResource = item.attachment as DatabaseMessageExtraMcpResource}
<ChatAttachmentMcpResource
attachment={toMcpResourceAttachment(mcpResource, item.id)}
onClick={() => openMcpResourcePreview(mcpResource)}
/>
{:else if item.isImage && item.preview}
<ChatAttachmentThumbnailImage
class="cursor-pointer"
id={item.id}
name={item.name}
preview={item.preview}
{readonly}
onRemove={onFileRemove}
height={imageHeight}
width={imageWidth}
{imageClass}
onClick={(event) => openPreview(item, event)}
/>
{:else}
<ChatAttachmentThumbnailFile
class="cursor-pointer"
id={item.id}
name={item.name}
size={item.size}
{readonly}
onRemove={onFileRemove}
textContent={item.textContent}
attachment={item.attachment}
uploadedFile={item.uploadedFile}
onClick={(event?: MouseEvent) => openPreview(item, event)}
/>
{/if}
{/each}
</div>
{/if}
</div>
{/if}
{#if previewItem}
<DialogChatAttachmentPreview
bind:open={previewDialogOpen}
uploadedFile={previewItem.uploadedFile}
attachment={previewItem.attachment}
preview={previewItem.preview}
name={previewItem.name}
size={previewItem.size}
textContent={previewItem.textContent}
{activeModelId}
/>
{/if}
<DialogChatAttachmentsViewAll
bind:open={viewAllDialogOpen}
{uploadedFiles}
{attachments}
{readonly}
{onFileRemove}
imageHeight="h-64"
{imageClass}
{activeModelId}
/>
{#if mcpResourcePreviewExtra}
<DialogMcpResourcePreview bind:open={mcpResourcePreviewOpen} extra={mcpResourcePreviewExtra} />
{/if}

View file

@ -0,0 +1,119 @@
<!--
ChatAttachmentsList — renders a message's attachments via ChatAttachmentsListItem,
either as a wrapping grid or a single scrollable carousel row, and hosts the
attachments preview and MCP resource preview dialogs.
-->
<script lang="ts">
import {
ChatAttachmentsListItem,
DialogChatAttachmentsPreview,
DialogMcpResourcePreview,
HorizontalScrollCarousel
} from '$lib/components/app';
import type { DatabaseMessageExtraMcpResource } from '$lib/types';
import { getAttachmentDisplayItems, isMcpPrompt, isMcpResource } from '$lib/utils';
interface Props {
class?: string;
style?: string;
// For ChatMessage - stored attachments
attachments?: DatabaseMessageExtra[];
readonly?: boolean;
// For ChatForm - pending uploads
onFileRemove?: (fileId: string) => void;
uploadedFiles?: ChatUploadedFile[];
// Image size customization
imageClass?: string;
imageHeight?: string;
imageWidth?: string;
// Limit display to single row with "+ X more" button
limitToSingleRow?: boolean;
// For vision modality check
activeModelId?: string;
}
let {
class: className = '',
style = '',
attachments = [],
readonly = false,
onFileRemove,
uploadedFiles = $bindable([]),
// Default to small size for form previews
imageClass = '',
imageHeight = 'h-24',
imageWidth = 'w-auto',
limitToSingleRow = false,
activeModelId
}: Props = $props();
let carouselRef: HorizontalScrollCarousel | undefined = $state();
// MCP resource preview dialog state.
let mcpResourcePreviewOpen = $state(false);
let mcpResourcePreviewExtra = $state<DatabaseMessageExtraMcpResource | null>(null);
// Index (among non-MCP items) the preview dialog should focus when opened.
let previewFocusIndex = $state(0);
let viewAllDialogOpen = $state(false);
// Normalized list merging pending uploads and stored attachments.
let displayItems = $derived(getAttachmentDisplayItems({ uploadedFiles, attachments }));
// Opens the preview dialog focused on the clicked item.
function openPreview(item: ChatAttachmentDisplayItem, event?: MouseEvent) {
event?.stopPropagation();
event?.preventDefault();
// Find the index of the clicked item among non-MCP attachments
const nonMcpItems = displayItems.filter((i) => !isMcpPrompt(i) && !isMcpResource(i));
const index = nonMcpItems.findIndex((i) => i.id === item.id);
previewFocusIndex = index >= 0 ? index : 0;
viewAllDialogOpen = true;
}
// Opens the MCP resource preview dialog.
function openMcpResourcePreview(extra: DatabaseMessageExtraMcpResource) {
mcpResourcePreviewExtra = extra;
mcpResourcePreviewOpen = true;
}
// Reset the carousel scroll position whenever the set of items changes.
$effect(() => {
if (carouselRef && displayItems.length) {
carouselRef.resetScroll();
}
});
</script>
{#snippet attachmentitem(item: ChatAttachmentDisplayItem)}
<ChatAttachmentsListItem
{imageClass}
{imageHeight}
{imageWidth}
{item}
{limitToSingleRow}
{onFileRemove}
onMcpResourcePreview={openMcpResourcePreview}
onPreview={(i: ChatAttachmentDisplayItem, event?: MouseEvent) => openPreview(i, event)}
{readonly}
/>
{/snippet}
{#if displayItems.length > 0}
<div class={className} {style}>
{#if limitToSingleRow}
<HorizontalScrollCarousel bind:this={carouselRef}>
{#each displayItems as item (item.id)}
{@render attachmentitem(item)}
{/each}
</HorizontalScrollCarousel>
{:else}
<div class="flex flex-wrap items-start justify-end gap-3">
{#each displayItems as item (item.id)}
{@render attachmentitem(item)}
{/each}
</div>
{/if}
</div>
{/if}
<DialogChatAttachmentsPreview
{activeModelId}
{attachments}
bind:open={viewAllDialogOpen}
{previewFocusIndex}
{uploadedFiles}
/>
{#if mcpResourcePreviewExtra}
<DialogMcpResourcePreview extra={mcpResourcePreviewExtra} bind:open={mcpResourcePreviewOpen} />
{/if}

View file

@ -0,0 +1,132 @@
<!--
ChatAttachmentsListItem — dispatches a single attachment display item to the chip/
thumbnail component matching its kind (MCP prompt, MCP resource, image, or file).
-->
<script lang="ts">
	import {
		ChatAttachmentsListItemMcpPrompt,
		ChatAttachmentsListItemMcpResource,
		ChatAttachmentsListItemThumbnailImage,
		ChatAttachmentsListItemThumbnailFile
	} from '$lib/components/app';
	import { AttachmentType } from '$lib/enums';
	import type {
		ChatAttachmentDisplayItem,
		DatabaseMessageExtraMcpPrompt,
		DatabaseMessageExtraMcpResource,
		MCPResourceAttachment
	} from '$lib/types';
	import { isMcpPrompt, isMcpResource } from '$lib/utils';

	interface Props {
		class?: string;
		imageClass?: string;
		imageHeight?: string;
		imageWidth?: string;
		item: ChatAttachmentDisplayItem;
		// When items are laid out in a single scrollable row, inset the first/last chip.
		limitToSingleRow?: boolean;
		onFileRemove?: (fileId: string) => void;
		onMcpResourcePreview?: (extra: DatabaseMessageExtraMcpResource) => void;
		onPreview?: (item: ChatAttachmentDisplayItem) => void;
		readonly?: boolean;
	}

	let {
		class: className = '',
		imageClass = '',
		imageHeight = 'h-24',
		imageWidth = 'w-auto',
		item,
		limitToSingleRow = false,
		onFileRemove,
		onMcpResourcePreview,
		onPreview,
		readonly = false
	}: Props = $props();

	// Edge margins so the first/last chip is inset from the carousel edges.
	const scrollClasses = $derived(limitToSingleRow ? 'first:ml-4 last:mr-4' : '');

	// Adapts a stored MCP resource extra to the attachment shape the chip component expects.
	function toMcpResourceAttachment(
		extra: DatabaseMessageExtraMcpResource,
		id: string
	): MCPResourceAttachment {
		return {
			id,
			resource: {
				uri: extra.uri,
				name: extra.name,
				title: extra.name,
				serverName: extra.serverName
			}
		};
	}
</script>

{#if isMcpPrompt(item)}
	{@const mcpPrompt =
		item.attachment?.type === AttachmentType.MCP_PROMPT
			? (item.attachment as DatabaseMessageExtraMcpPrompt)
			: item.uploadedFile?.mcpPrompt
				? {
						type: AttachmentType.MCP_PROMPT as const,
						name: item.name,
						serverName: item.uploadedFile.mcpPrompt.serverName,
						promptName: item.uploadedFile.mcpPrompt.promptName,
						content: item.textContent ?? '',
						arguments: item.uploadedFile.mcpPrompt.arguments
					}
				: null}
	{#if mcpPrompt}
		<ChatAttachmentsListItemMcpPrompt
			class="max-w-[300px] min-w-[200px] flex-shrink-0 {className} {scrollClasses}"
			prompt={mcpPrompt}
			{readonly}
			isLoading={item.isLoading}
			loadError={item.loadError}
			onRemove={onFileRemove ? () => onFileRemove(item.id) : undefined}
		/>
	{/if}
{:else if isMcpResource(item)}
	{@const mcpResource = item.attachment as DatabaseMessageExtraMcpResource}
	<ChatAttachmentsListItemMcpResource
		class="flex-shrink-0 {className} {scrollClasses}"
		attachment={toMcpResourceAttachment(mcpResource, item.id)}
		onclick={() => onMcpResourcePreview?.(mcpResource)}
	/>
{:else if item.isImage && item.preview}
	<ChatAttachmentsListItemThumbnailImage
		class="flex-shrink-0 cursor-pointer {className} {scrollClasses}"
		id={item.id}
		name={item.name}
		preview={item.preview}
		{readonly}
		onRemove={onFileRemove}
		height={imageHeight}
		width={imageWidth}
		{imageClass}
		onclick={() => onPreview?.(item)}
	/>
{:else}
	<!-- PDFs and all other non-image files share the same file-card rendering,
	     so the former separate isPdfFile branch (identical props) was merged here. -->
	<ChatAttachmentsListItemThumbnailFile
		class="flex-shrink-0 cursor-pointer {className} {scrollClasses}"
		id={item.id}
		name={item.name}
		size={item.size}
		{readonly}
		onRemove={onFileRemove}
		textContent={item.textContent}
		attachment={item.attachment}
		uploadedFile={item.uploadedFile}
		onclick={() => onPreview?.(item)}
	/>
{/if}

View file

@ -1,40 +1,41 @@
<script lang="ts">
import { ChatMessageMcpPromptContent, ActionIconRemove } from '$lib/components/app';
import { ChatMessageMcpPromptContent, ActionIcon } from '$lib/components/app';
import { X } from '@lucide/svelte';
import type { DatabaseMessageExtraMcpPrompt } from '$lib/types';
import { McpPromptVariant } from '$lib/enums';
interface Props {
class?: string;
prompt: DatabaseMessageExtraMcpPrompt;
readonly?: boolean;
isLoading?: boolean;
loadError?: string;
onRemove?: () => void;
prompt: DatabaseMessageExtraMcpPrompt;
readonly?: boolean;
}
let {
class: className = '',
prompt,
readonly = false,
isLoading = false,
loadError,
onRemove
onRemove,
prompt,
readonly = false
}: Props = $props();
</script>
<div class="group relative {className}">
<ChatMessageMcpPromptContent
{prompt}
variant={McpPromptVariant.ATTACHMENT}
{isLoading}
{loadError}
{prompt}
variant={McpPromptVariant.ATTACHMENT}
/>
{#if !readonly && onRemove}
<div
class="absolute top-10 right-2 flex items-center justify-center opacity-0 transition-opacity group-hover:opacity-100"
>
<ActionIconRemove id={prompt.name} onRemove={() => onRemove?.()} />
<ActionIcon icon={X} tooltip="Remove" stopPropagationOnClick onclick={() => onRemove?.()} />
</div>
{/if}
</div>

View file

@ -1,46 +1,47 @@
<script lang="ts">
import { Loader2, AlertCircle } from '@lucide/svelte';
import { cn } from '$lib/components/ui/utils';
import { mcpStore } from '$lib/stores/mcp.svelte';
import type { MCPResourceAttachment } from '$lib/types';
import * as Tooltip from '$lib/components/ui/tooltip';
import { ActionIconRemove } from '$lib/components/app';
import { ActionIcon } from '$lib/components/app';
import { X } from '@lucide/svelte';
import { getResourceIcon, getResourceDisplayName } from '$lib/utils';
interface Props {
attachment: MCPResourceAttachment;
onRemove?: (attachmentId: string) => void;
onClick?: () => void;
class?: string;
onclick?: () => void;
onRemove?: (attachmentId: string) => void;
}
let { attachment, onRemove, onClick, class: className }: Props = $props();
function getStatusClass(attachment: MCPResourceAttachment): string {
if (attachment.error) return 'border-red-500/50 bg-red-500/10';
if (attachment.loading) return 'border-border/50 bg-muted/30';
return 'border-border/50 bg-muted/30';
}
let { attachment, class: className, onclick, onRemove }: Props = $props();
const ResourceIcon = $derived(
getResourceIcon(attachment.resource.mimeType, attachment.resource.uri)
);
const serverName = $derived(mcpStore.getServerDisplayName(attachment.resource.serverName));
const favicon = $derived(mcpStore.getServerFavicon(attachment.resource.serverName));
function getStatusClass(attachment: MCPResourceAttachment): string {
if (attachment.error) return 'border-red-500/50 bg-red-500/10';
if (attachment.loading) return 'border-border/50 bg-muted/30';
return 'border-border/50 bg-muted/30';
}
</script>
<Tooltip.Root>
<Tooltip.Trigger>
<button
type="button"
class={cn(
class={[
'flex flex-shrink-0 items-center gap-1.5 rounded-md border px-2 py-0.75 text-sm transition-colors',
getStatusClass(attachment),
onClick && 'cursor-pointer hover:bg-muted/50',
onclick && 'cursor-pointer hover:bg-muted/50',
className
)}
onclick={onClick}
disabled={!onClick}
]}
disabled={!onclick}
{onclick}
type="button"
>
{#if attachment.loading}
<Loader2 class="h-3 w-3 animate-spin text-muted-foreground" />
@ -55,11 +56,13 @@
</span>
{#if onRemove}
<ActionIconRemove
<ActionIcon
class="-my-2 -mr-1.5 bg-transparent"
iconSize={2}
id={attachment.id}
{onRemove}
icon={X}
iconSize="h-2 w-2"
onclick={() => onRemove?.(attachment.id)}
stopPropagationOnClick
tooltip="Remove"
/>
{/if}
</button>
@ -69,12 +72,12 @@
<div class="flex items-center gap-1 text-xs">
{#if favicon}
<img
src={favicon}
alt=""
alt={attachment.resource.serverName}
class="h-3 w-3 shrink-0 rounded-sm"
onerror={(e) => {
(e.currentTarget as HTMLImageElement).style.display = 'none';
}}
src={favicon}
/>
{/if}

View file

@ -0,0 +1,174 @@
<script lang="ts">
import { X } from '@lucide/svelte';
import {
formatFileSize,
getFileTypeLabel,
getPreviewText,
isPdfFile,
isTextFile
} from '$lib/utils';
import { ActionIcon } from '$lib/components/app';
import { AttachmentType } from '$lib/enums';
interface Props {
attachment?: DatabaseMessageExtra;
class?: string;
id: string;
onclick?: (event: MouseEvent) => void;
onRemove?: (id: string) => void;
name: string;
readonly?: boolean;
size?: number;
textContent?: string;
// Either uploaded file or stored attachment
uploadedFile?: ChatUploadedFile;
}
let {
attachment,
class: className = '',
id,
onclick,
onRemove,
name,
readonly = false,
size,
textContent,
uploadedFile
}: Props = $props();
let isPdf = $derived(isPdfFile(attachment, uploadedFile));
let isPdfWithContent = $derived(isPdf && !!textContent);
let isText = $derived(isTextFile(attachment, uploadedFile));
let isTextWithContent = $derived(isText && !!textContent);
let fileTypeLabel = $derived.by(() => {
if (uploadedFile?.type) {
return getFileTypeLabel(uploadedFile.type);
}
if (attachment) {
if ('mimeType' in attachment && attachment.mimeType) {
return getFileTypeLabel(attachment.mimeType);
}
if (attachment.type) {
return getFileTypeLabel(attachment.type);
}
}
return getFileTypeLabel(name);
});
let pdfProcessingMode = $derived.by(() => {
if (attachment?.type === AttachmentType.PDF) {
const pdfAttachment = attachment as DatabaseMessageExtraPdfFile;
return pdfAttachment.processedAsImages ? 'Sent as Image' : 'Sent as Text';
}
return null;
});
</script>
{#snippet textPreview(content: string)}
<div class="relative">
<div
class="font-mono text-xs leading-relaxed break-words whitespace-pre-wrap text-muted-foreground {!readonly
? 'max-h-3rem line-height-1.2'
: ''}"
>
{getPreviewText(content)}
</div>
{#if content.length > 150}
<div
class="pointer-events-none absolute right-0 bottom-0 left-0 h-4 bg-gradient-to-t from-muted to-transparent {readonly
? 'h-6'
: ''}"
></div>
{/if}
</div>
{/snippet}
{#snippet removeButton()}
<div class="absolute top-2 right-2 opacity-0 transition-opacity group-hover:opacity-100">
<ActionIcon icon={X} tooltip="Remove" stopPropagationOnClick onclick={() => onRemove?.(id)} />
</div>
{/snippet}
{#snippet fileIcon()}
<div
class="flex h-8 w-8 items-center justify-center rounded bg-primary/10 text-xs font-medium text-primary"
>
{fileTypeLabel}
</div>
{/snippet}
{#snippet info(text: string | undefined)}
{#if text}
<span class="text-xs text-muted-foreground">{text}</span>
{/if}
{/snippet}
{#if isTextWithContent || isPdfWithContent}
<button
aria-label={readonly ? `Preview ${name}` : undefined}
class="rounded-lg border border-border bg-muted p-3 {className} cursor-pointer {readonly
? 'w-full max-w-2xl transition-shadow hover:shadow-md'
: `group relative text-left ${textContent ? 'max-h-24 max-w-72' : 'max-w-36'}`} overflow-hidden"
{onclick}
type="button"
>
{#if !readonly}
{@render removeButton()}
{/if}
<div class={[!readonly && 'pr-8', 'overflow-hidden']}>
{#if readonly}
<div class="flex items-start gap-3">
<div class="flex min-w-0 flex-1 flex-col items-start text-left">
<span class="w-full truncate text-sm font-medium text-foreground">{name}</span>
{@render info(pdfProcessingMode || (size ? formatFileSize(size) : undefined))}
{#if textContent}
{@render textPreview(textContent)}
{/if}
</div>
</div>
{:else}
<span class="mb-3 block truncate text-sm font-medium text-foreground">{name}</span>
{#if textContent}
{@render textPreview(textContent)}
{/if}
{/if}
</div>
</button>
{:else}
<button
class="group flex items-center gap-3 rounded-lg border border-border bg-muted p-3 {className} relative"
{onclick}
type="button"
>
{@render fileIcon()}
<div class="flex flex-col items-start gap-0.5">
<span
class="max-w-24 truncate text-sm font-medium text-foreground {readonly
? ''
: 'group-hover:pr-6'} md:max-w-32"
>
{name}
</span>
{@render info(pdfProcessingMode || (size ? formatFileSize(size) : undefined))}
</div>
{#if !readonly}
{@render removeButton()}
{/if}
</button>
{/if}

View file

@ -1,64 +1,65 @@
<script lang="ts">
import { ActionIconRemove } from '$lib/components/app';
import { ActionIcon } from '$lib/components/app';
import { X } from '@lucide/svelte';
interface Props {
class?: string;
height?: string;
id: string;
imageClass?: string;
onclick?: (event?: MouseEvent) => void;
onRemove?: (id: string) => void;
name: string;
preview: string;
readonly?: boolean;
onRemove?: (id: string) => void;
onClick?: (event?: MouseEvent) => void;
class?: string;
// Customizable size props
width?: string;
height?: string;
imageClass?: string;
}
let {
class: className = '',
height = 'h-16',
id,
imageClass = '',
onclick,
onRemove,
name,
preview,
readonly = false,
onRemove,
onClick,
class: className = '',
// Default to small size for form previews
width = 'w-auto',
height = 'h-16',
imageClass = ''
width = 'w-auto'
}: Props = $props();
</script>
{#snippet image()}
<img src={preview} alt={name} class="{height} {width} cursor-pointer object-cover {imageClass}" />
{/snippet}
<div
class="group relative overflow-hidden rounded-lg bg-muted shadow-lg dark:border dark:border-muted {className}"
>
{#if onClick}
{#if onclick}
<button
type="button"
class="block h-full w-full rounded-lg focus:ring-2 focus:ring-primary focus:ring-offset-2 focus:outline-none"
onclick={onClick}
aria-label="Preview {name}"
class="block h-full w-full rounded-lg focus:ring-2 focus:ring-primary focus:ring-offset-2 focus:outline-none"
{onclick}
type="button"
>
<img
src={preview}
alt={name}
class="{height} {width} cursor-pointer object-cover {imageClass}"
/>
{@render image()}
</button>
{:else}
<img
src={preview}
alt={name}
class="{height} {width} cursor-pointer object-cover {imageClass}"
/>
{@render image()}
{/if}
{#if !readonly}
<div
class="absolute top-1 right-1 flex items-center justify-center opacity-0 transition-opacity group-hover:opacity-100"
>
<ActionIconRemove {id} {onRemove} class="text-white" />
<ActionIcon
class="text-white"
icon={X}
onclick={() => onRemove?.(id)}
stopPropagationOnClick
tooltip="Remove"
/>
</div>
{/if}
</div>

View file

@ -0,0 +1,190 @@
<script lang="ts">
import {
ChatAttachmentsPreviewCurrentItem,
ChatAttachmentsPreviewFileInfo,
ChatAttachmentsPreviewNavButtons,
ChatAttachmentsPreviewThumbnailStrip
} from '$lib/components/app';
import { modelsStore } from '$lib/stores/models.svelte';
import {
createBase64DataUrl,
formatFileSize,
getAttachmentDisplayItems,
getLanguageFromFilename,
isAudioFile,
isImageFile,
isMcpPrompt,
isMcpResource,
isPdfFile,
isTextFile
} from '$lib/utils';
interface PreviewItem {
id: string;
name: string;
size?: number;
preview?: string;
uploadedFile?: ChatUploadedFile;
attachment?: DatabaseMessageExtra;
textContent?: string;
isImage: boolean;
isAudio: boolean;
}
interface Props {
uploadedFiles?: ChatUploadedFile[];
attachments?: DatabaseMessageExtra[];
activeModelId?: string;
class?: string;
previewFocusIndex?: number;
}
let {
uploadedFiles = [],
attachments = [],
activeModelId,
class: className = '',
previewFocusIndex = 0
}: Props = $props();
let allItems = $derived(
getAttachmentDisplayItems({ uploadedFiles, attachments })
.filter((item) => !isMcpPrompt(item) && !isMcpResource(item))
.map(
(item): PreviewItem => ({
...item,
isImage: isImageFile(item.attachment, item.uploadedFile),
isAudio: isAudioFile(item.attachment, item.uploadedFile)
})
)
);
let currentIndex = $state(0);
$effect(() => {
if (previewFocusIndex >= 0 && previewFocusIndex < allItems.length) {
currentIndex = previewFocusIndex;
}
});
$effect(() => {
const handler = (e: Event) => {
const delta = (e as CustomEvent).detail;
if (delta < 0) {
currentIndex = currentIndex > 0 ? currentIndex - 1 : allItems.length - 1;
} else {
currentIndex = currentIndex < allItems.length - 1 ? currentIndex + 1 : 0;
}
};
document.addEventListener('chat-attachments-nav', handler);
return () => document.removeEventListener('chat-attachments-nav', handler);
});
$effect(() => {
const index = currentIndex;
setTimeout(() => {
const thumbnail = document.querySelector(`[data-thumbnail-index="${index}"]`);
thumbnail?.scrollIntoView({ behavior: 'smooth', inline: 'center', block: 'nearest' });
}, 0);
});
let currentItem = $derived(allItems[currentIndex] ?? null);
let displayName = $derived(
currentItem?.name ||
currentItem?.uploadedFile?.name ||
currentItem?.attachment?.name ||
'Unknown File'
);
let isAudio = $derived(
currentItem ? isAudioFile(currentItem.attachment, currentItem.uploadedFile) : false
);
let isImage = $derived(
currentItem ? isImageFile(currentItem.attachment, currentItem.uploadedFile) : false
);
let isPdf = $derived(
currentItem ? isPdfFile(currentItem.attachment, currentItem.uploadedFile) : false
);
let isText = $derived(
currentItem ? isTextFile(currentItem.attachment, currentItem.uploadedFile) : false
);
let displayPreview = $derived(
currentItem?.uploadedFile?.preview ||
(isImage && currentItem?.attachment && 'base64Url' in currentItem.attachment
? currentItem.attachment.base64Url
: currentItem?.preview)
);
let displayTextContent = $derived(
currentItem?.uploadedFile?.textContent ||
(currentItem?.attachment && 'content' in currentItem.attachment
? currentItem.attachment.content
: currentItem?.textContent)
);
let language = $derived(getLanguageFromFilename(displayName));
let fileSize = $derived(currentItem?.size ? formatFileSize(currentItem.size) : '');
let hasVisionModality = $derived(
currentItem && activeModelId ? modelsStore.modelSupportsVision(activeModelId) : false
);
let audioSrc = $derived(
isAudio && currentItem
? (currentItem.uploadedFile?.preview ??
(currentItem.attachment &&
'mimeType' in currentItem.attachment &&
'base64Data' in currentItem.attachment
? createBase64DataUrl(
currentItem.attachment.mimeType,
currentItem.attachment.base64Data
)
: null))
: null
);
export function prev() {
currentIndex = currentIndex > 0 ? currentIndex - 1 : allItems.length - 1;
}
export function next() {
currentIndex = currentIndex < allItems.length - 1 ? currentIndex + 1 : 0;
}
function onNavigate(index: number) {
currentIndex = index;
}
</script>
<div class="{className} flex flex-col text-white">
<div class="relative flex min-h-0 flex-1 items-center justify-center overflow-hidden">
<ChatAttachmentsPreviewNavButtons onPrev={prev} onNext={next} show={allItems.length > 1} />
<div class="flex h-full w-full flex-col items-center justify-start overflow-auto py-4">
{#if currentItem}
<ChatAttachmentsPreviewFileInfo {displayName} {fileSize} />
<ChatAttachmentsPreviewCurrentItem
{currentItem}
{isImage}
{isAudio}
{isPdf}
{isText}
{displayPreview}
{displayTextContent}
{audioSrc}
{language}
{hasVisionModality}
{activeModelId}
/>
{/if}
<ChatAttachmentsPreviewThumbnailStrip items={allItems} {currentIndex} {onNavigate} />
</div>
</div>
</div>

View file

@ -0,0 +1,65 @@
<script lang="ts">
import type { ChatAttachmentDisplayItem } from '$lib/types';
import { Image, Music, FileText, FileIcon } from '@lucide/svelte';
import ChatAttachmentsPreviewCurrentItemPdf from './ChatAttachmentsPreviewCurrentItemPdf.svelte';
import ChatAttachmentsPreviewCurrentItemImage from './ChatAttachmentsPreviewCurrentItemImage.svelte';
import ChatAttachmentsPreviewCurrentItemAudio from './ChatAttachmentsPreviewCurrentItemAudio.svelte';
import ChatAttachmentsPreviewCurrentItemText from './ChatAttachmentsPreviewCurrentItemText.svelte';
import ChatAttachmentsPreviewCurrentItemUnavailable from './ChatAttachmentsPreviewCurrentItemUnavailable.svelte';
interface Props {
currentItem: ChatAttachmentDisplayItem | null;
isImage: boolean;
isAudio: boolean;
isPdf: boolean;
isText: boolean;
displayPreview: string | undefined;
displayTextContent: string | undefined;
audioSrc: string | null;
language: string;
hasVisionModality: boolean;
activeModelId?: string;
}
let {
currentItem,
isImage,
isAudio,
isPdf,
isText,
displayPreview,
displayTextContent,
audioSrc,
language,
hasVisionModality,
activeModelId
}: Props = $props();
let IconComponent = $derived(
isImage ? Image : isText || isPdf ? FileText : isAudio ? Music : FileIcon
);
let isUnavailable = $derived(!isPdf && !isImage && !(isText && displayTextContent) && !isAudio);
</script>
{#if currentItem}
{#key currentItem.id}
{#if isPdf}
<ChatAttachmentsPreviewCurrentItemPdf
{currentItem}
displayName={currentItem.name}
{displayTextContent}
{hasVisionModality}
{activeModelId}
/>
{:else if isImage}
<ChatAttachmentsPreviewCurrentItemImage {currentItem} {displayPreview} />
{:else if isText && displayTextContent}
<ChatAttachmentsPreviewCurrentItemText {displayTextContent} {language} />
{:else if isAudio}
<ChatAttachmentsPreviewCurrentItemAudio {currentItem} {audioSrc} />
{:else if isUnavailable}
<ChatAttachmentsPreviewCurrentItemUnavailable {IconComponent} />
{/if}
{/key}
{/if}

View file

@ -0,0 +1,26 @@
<script lang="ts">
import { Music } from '@lucide/svelte';

// Audio attachment preview: music icon, a native <audio> player when a
// playable source URL is available, and the attachment name as a caption.
interface Props {
// Attachment being previewed; only its display name is read here.
currentItem: { name?: string } | null;
// Data/object URL for the audio payload, or null when none could be derived.
audioSrc: string | null;
}

let { currentItem, audioSrc }: Props = $props();
</script>

<div class="flex flex-1 items-center justify-center p-8">
<div class="w-full max-w-md text-center">
<Music class="mx-auto mb-4 h-16 w-16 text-white/50" />
{#if audioSrc}
<audio controls class="mb-4 w-full" src={audioSrc}>
Your browser does not support the audio element.
</audio>
{:else}
<!-- No playable source: show a textual fallback instead of the player. -->
<p class="mb-4 text-white/70">Audio preview not available</p>
{/if}
<p class="text-sm text-white/50">{currentItem?.name || 'Audio'}</p>
</div>
</div>

View file

@ -0,0 +1,18 @@
<script lang="ts">
// Image attachment preview: renders the preview URL centered and capped to
// the viewport; renders nothing when no preview URL exists.
interface Props {
// Attachment being previewed; only its display name (for alt text) is read.
currentItem: { name?: string } | null;
// Image source URL (data or object URL); undefined when unavailable.
displayPreview: string | undefined;
}

let { currentItem, displayPreview }: Props = $props();
</script>

{#if displayPreview}
<div class="flex flex-1 items-center justify-center">
<img
src={displayPreview}
alt={currentItem?.name || 'preview'}
class="max-h-[80vh] max-w-[80vw] rounded-lg object-contain shadow-lg"
/>
</div>
{/if}

View file

@ -0,0 +1,174 @@
<script lang="ts">
	import type { ChatAttachmentDisplayItem } from '$lib/types';
	import { FileText, Eye, Info } from '@lucide/svelte';
	import { Button } from '$lib/components/ui/button';
	import * as Alert from '$lib/components/ui/alert';
	import { SyntaxHighlightedCode } from '$lib/components/app';
	import { getLanguageFromFilename } from '$lib/utils';
	import { convertPDFToImage } from '$lib/utils/browser-only';
	import { PdfViewMode } from '$lib/enums';

	// PDF attachment preview with two modes: rendered page images (PAGES) or
	// the extracted text (TEXT). Page images are converted lazily on demand.
	interface Props {
		currentItem: ChatAttachmentDisplayItem | null;
		displayName: string;
		displayTextContent: string | undefined;
		hasVisionModality: boolean;
		activeModelId?: string;
	}

	let { currentItem, displayName, displayTextContent, hasVisionModality, activeModelId }: Props =
		$props();

	// View-mode toggle plus lazy page-image conversion state.
	let pdfViewMode = $state<PdfViewMode>(PdfViewMode.PAGES);
	let pdfImages = $state<string[]>([]);
	let pdfImagesLoading = $state(false);
	let pdfImagesError = $state<string | null>(null);

	let language = $derived(getLanguageFromFilename(displayName));

	// Resolve page images for the current item, preferring (in order):
	// the uploaded File object, pre-processed images stored on the attachment,
	// or a File reconstructed from the attachment's base64 payload.
	// No-op if images are already loaded or a conversion is in flight.
	async function loadPdfImages() {
		if (pdfImages.length > 0 || pdfImagesLoading || !currentItem) return;

		pdfImagesLoading = true;
		pdfImagesError = null;

		try {
			let file: File | null = null;

			if (currentItem.uploadedFile?.file) {
				file = currentItem.uploadedFile.file;
			} else if (currentItem.attachment) {
				// Check if we have pre-processed images
				if (
					'images' in currentItem.attachment &&
					currentItem.attachment.images &&
					Array.isArray(currentItem.attachment.images) &&
					currentItem.attachment.images.length > 0
				) {
					pdfImages = currentItem.attachment.images;
					return;
				}

				// Convert base64 back to File for processing
				if ('base64Data' in currentItem.attachment && currentItem.attachment.base64Data) {
					const base64Data = currentItem.attachment.base64Data;
					const byteCharacters = atob(base64Data);
					const byteNumbers = new Array(byteCharacters.length);

					for (let i = 0; i < byteCharacters.length; i++) {
						byteNumbers[i] = byteCharacters.charCodeAt(i);
					}

					const byteArray = new Uint8Array(byteNumbers);

					file = new File([byteArray], displayName, { type: 'application/pdf' });
				}
			}

			if (file) {
				pdfImages = await convertPDFToImage(file);
			} else {
				throw new Error('No PDF file available for conversion');
			}
		} catch (error) {
			pdfImagesError = error instanceof Error ? error.message : 'Failed to load PDF images';
		} finally {
			pdfImagesLoading = false;
		}
	}

	// Convert pages only when (and each time) the PAGES view is active.
	$effect(() => {
		if (pdfViewMode === PdfViewMode.PAGES) {
			loadPdfImages();
		}
	});
</script>

<div class="mb-4 flex items-center justify-end gap-2">
	<Button
		variant={pdfViewMode === PdfViewMode.TEXT ? 'default' : 'outline'}
		size="sm"
		onclick={() => (pdfViewMode = PdfViewMode.TEXT)}
		disabled={pdfImagesLoading}
	>
		<FileText class="mr-1 h-4 w-4" />
		Text
	</Button>

	<Button
		variant={pdfViewMode === PdfViewMode.PAGES ? 'default' : 'outline'}
		size="sm"
		onclick={() => (pdfViewMode = PdfViewMode.PAGES)}
		disabled={pdfImagesLoading}
	>
		{#if pdfImagesLoading}
			<div
				class="mr-1 h-4 w-4 animate-spin rounded-full border-2 border-current border-t-transparent"
			></div>
		{:else}
			<Eye class="mr-1 h-4 w-4" />
		{/if}
		Pages
	</Button>
</div>

{#if !hasVisionModality && activeModelId && currentItem}
	<Alert.Root class="mb-4 max-w-4xl">
		<Info class="h-4 w-4" />
		<Alert.Title>Preview only</Alert.Title>
		<Alert.Description>
			<span class="inline-flex">
				The selected model does not support vision. Only the extracted
				<!-- svelte-ignore a11y_click_events_have_key_events -->
				<!-- svelte-ignore a11y_no_static_element_interactions -->
				<span
					class="mx-1 cursor-pointer underline"
					onclick={() => (pdfViewMode = PdfViewMode.TEXT)}
				>
					text
				</span>
				will be sent to the model.
			</span>
		</Alert.Description>
	</Alert.Root>
{/if}

<!-- Fix: the pages section is only rendered in PAGES mode. Previously it was
     unconditional, so in TEXT mode the page images — or the "No PDF pages
     available" fallback, since the $effect above never converts pages while
     TEXT mode is active — were shown above the extracted text. -->
{#if pdfViewMode === PdfViewMode.PAGES}
	{#if pdfImagesLoading}
		<div class="flex flex-1 items-center justify-center p-8">
			<div class="text-center">
				<div
					class="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-white border-t-transparent"
				></div>
				<p class="text-white/70">Converting PDF to images...</p>
			</div>
		</div>
	{:else if pdfImagesError}
		<div class="flex flex-1 items-center justify-center p-8">
			<div class="text-center">
				<FileText class="mx-auto mb-4 h-16 w-16 text-white/50" />
				<p class="mb-4 text-white/70">Failed to load PDF images</p>
				<p class="text-sm text-white/50">{pdfImagesError}</p>
			</div>
		</div>
	{:else if pdfImages.length > 0}
		{#each pdfImages as image, index (image)}
			<p class="mb-2 text-sm text-white/50">Page {index + 1}</p>
			<img src={image} alt="PDF Page {index + 1}" class="mx-auto max-w-[85vw] rounded-lg shadow-lg" />
			<div class="h-4"></div>
		{/each}
	{:else}
		<div class="flex flex-1 items-center justify-center p-8">
			<div class="text-center">
				<FileText class="mx-auto mb-4 h-16 w-16 text-white/50" />
				<p class="text-white/70">No PDF pages available</p>
			</div>
		</div>
	{/if}
{/if}

{#if pdfViewMode === PdfViewMode.TEXT && displayTextContent}
	<div class="px-4 pb-4">
		<SyntaxHighlightedCode
			class="max-w-4xl"
			code={displayTextContent}
			{language}
			maxHeight="none"
		/>
	</div>
{/if}

View file

@ -0,0 +1,21 @@
<script lang="ts">
import { SyntaxHighlightedCode } from '$lib/components/app';

// Text attachment preview: renders the content through the shared
// syntax highlighter; renders nothing when there is no content.
interface Props {
// Text payload to display; undefined/empty hides the component entirely.
displayTextContent: string | undefined;
// Highlighter language id (derived upstream from the file name).
language: string;
}

let { displayTextContent, language }: Props = $props();
</script>

{#if displayTextContent}
<div class="px-4 pb-4">
<SyntaxHighlightedCode
class="max-w-4xl"
code={displayTextContent}
{language}
maxHeight="none"
/>
</div>
{/if}

View file

@ -0,0 +1,17 @@
<script lang="ts">
import type { Component } from 'svelte';

// Generic placeholder shown when no specialized preview renderer applies.
interface Props {
// Icon matching the attachment's broad type, chosen by the caller.
IconComponent: Component;
}

let { IconComponent }: Props = $props();
</script>

<div class="flex flex-1 items-center justify-center p-8">
<div class="text-center">
<IconComponent class="mx-auto mb-4 h-16 w-16 text-white/50" />
<p class="text-white/70">Preview not available for this file type</p>
</div>
</div>

View file

@ -0,0 +1,16 @@
<script lang="ts">
// Sticky header strip showing the attachment's name and (optionally) its
// already-formatted size; stays pinned while the preview scrolls beneath it.
interface Props {
// File name to display.
displayName: string;
// Human-readable size string; empty string hides the size line.
fileSize: string;
}

let { displayName, fileSize }: Props = $props();
</script>

<div class="sticky top-0 z-[20] mb-4 rounded-lg bg-black/5 px-4 py-2 text-center backdrop-blur-md">
<p class="font-medium text-white">{displayName}</p>
{#if fileSize}
<p class="text-xs text-white/60">{fileSize}</p>
{/if}
</div>

View file

@ -0,0 +1,34 @@
<script lang="ts">
import { ChevronLeft, ChevronRight } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';

// Previous/next overlay buttons for the attachment carousel; hidden when
// `show` is false (i.e. when there is at most one item to navigate).
interface Props {
// Callback for the left (previous) button.
onPrev: () => void;
// Callback for the right (next) button.
onNext: () => void;
// Whether to render the buttons at all.
show: boolean;
}

let { onPrev, onNext, show }: Props = $props();
</script>

{#if show}
<Button
variant="secondary"
size="icon"
class="absolute top-1/2 left-4 z-10 h-8 w-8 -translate-y-1/2 rounded-full bg-background/5 p-0 text-white!"
onclick={onPrev}
aria-label="Previous"
>
<ChevronLeft class="size-4" />
</Button>

<Button
variant="secondary"
size="icon"
class="absolute top-1/2 right-4 z-10 h-8 w-8 -translate-y-1/2 rounded-full bg-background/5 p-0 text-white!"
onclick={onNext}
aria-label="Next"
>
<ChevronRight class="size-4" />
</Button>
{/if}

View file

@ -0,0 +1,63 @@
<script lang="ts">
import { Music, FileText } from '@lucide/svelte';
import { HorizontalScrollCarousel } from '$lib/components/app/misc';

// One entry in the thumbnail strip: image items show their preview,
// everything else shows a type icon plus the uppercase file extension.
interface PreviewItem {
id: string;
name: string;
isImage: boolean;
isAudio: boolean;
preview?: string;
}

// Horizontal thumbnail strip for navigating between preview items.
// Hidden entirely when there is at most one item.
interface Props {
items: PreviewItem[];
// Index of the item currently shown in the main preview (highlighted here).
currentIndex: number;
// Called with the index of a clicked thumbnail.
onNavigate: (index: number) => void;
}

let { items, currentIndex, onNavigate }: Props = $props();

// Uppercase extension after the last '.', or '' when the name has no dot.
function getFileExtension(name: string): string {
const parts = name.split('.');

if (parts.length > 1) {
return parts.pop()?.toUpperCase() ?? '';
}

return '';
}
</script>

{#if items.length > 1}
<div class="sticky bottom-0 z-10 mt-4 flex-shrink-0">
<HorizontalScrollCarousel class="max-w-full">
{#each items as item, index (item.id)}
<button
data-thumbnail-index={index}
class={[
'relative flex-shrink-0 cursor-pointer overflow-hidden rounded border-2 bg-black/80 backdrop-blur-sm transition-all hover:opacity-90',
index === currentIndex ? 'border-white' : 'border-transparent opacity-60',
'[&:not(:first-child)]:last:mr-4 [&:not(:last-child)]:first:ml-4'
]}
onclick={() => onNavigate(index)}
aria-label={`Go to ${item.name}`}
>
{#if item.isImage && item.preview}
<img src={item.preview} alt={item.name} class="h-12 w-12 object-cover" />
{:else}
<!-- Non-image (or image without a preview URL): icon + extension badge. -->
<div
class="bg-foreground-muted/50 flex h-12 w-12 flex-col items-center justify-center gap-0.5 py-1"
>
{#if item.isAudio}
<Music class="h-4 w-4 text-white/70" />
{:else}
<FileText class="h-4 w-4 text-white/70" />
{/if}
<span class="font-mono text-[9px] text-white/60">{getFileExtension(item.name)}</span>
</div>
{/if}
</button>
{/each}
</HorizontalScrollCarousel>
</div>
{/if}

View file

@ -1,117 +0,0 @@
<script lang="ts">
import {
ChatAttachmentThumbnailImage,
ChatAttachmentThumbnailFile,
DialogChatAttachmentPreview
} from '$lib/components/app';
import { getAttachmentDisplayItems } from '$lib/utils';
interface Props {
uploadedFiles?: ChatUploadedFile[];
attachments?: DatabaseMessageExtra[];
readonly?: boolean;
onFileRemove?: (fileId: string) => void;
imageHeight?: string;
imageWidth?: string;
imageClass?: string;
activeModelId?: string;
}
let {
uploadedFiles = [],
attachments = [],
readonly = false,
onFileRemove,
imageHeight = 'h-24',
imageWidth = 'w-auto',
imageClass = '',
activeModelId
}: Props = $props();
let previewDialogOpen = $state(false);
let previewItem = $state<ChatAttachmentPreviewItem | null>(null);
let displayItems = $derived(getAttachmentDisplayItems({ uploadedFiles, attachments }));
let imageItems = $derived(displayItems.filter((item) => item.isImage));
let fileItems = $derived(displayItems.filter((item) => !item.isImage));
function openPreview(item: (typeof displayItems)[0], event?: Event) {
if (event) {
event.preventDefault();
event.stopPropagation();
}
previewItem = {
uploadedFile: item.uploadedFile,
attachment: item.attachment,
preview: item.preview,
name: item.name,
size: item.size,
textContent: item.textContent
};
previewDialogOpen = true;
}
</script>
<div class="space-y-4">
<div class="min-h-0 flex-1 space-y-6 overflow-y-auto px-1">
{#if fileItems.length > 0}
<div>
<h3 class="mb-3 text-sm font-medium text-foreground">Files ({fileItems.length})</h3>
<div class="flex flex-wrap items-start gap-3">
{#each fileItems as item (item.id)}
<ChatAttachmentThumbnailFile
class="cursor-pointer"
id={item.id}
name={item.name}
size={item.size}
{readonly}
onRemove={onFileRemove}
textContent={item.textContent}
attachment={item.attachment}
uploadedFile={item.uploadedFile}
onClick={(event?: MouseEvent) => openPreview(item, event)}
/>
{/each}
</div>
</div>
{/if}
{#if imageItems.length > 0}
<div>
<h3 class="mb-3 text-sm font-medium text-foreground">Images ({imageItems.length})</h3>
<div class="flex flex-wrap items-start gap-3">
{#each imageItems as item (item.id)}
{#if item.preview}
<ChatAttachmentThumbnailImage
class="cursor-pointer"
id={item.id}
name={item.name}
preview={item.preview}
{readonly}
onRemove={onFileRemove}
height={imageHeight}
width={imageWidth}
{imageClass}
onClick={(event) => openPreview(item, event)}
/>
{/if}
{/each}
</div>
</div>
{/if}
</div>
</div>
{#if previewItem}
<DialogChatAttachmentPreview
bind:open={previewDialogOpen}
uploadedFile={previewItem.uploadedFile}
attachment={previewItem.attachment}
preview={previewItem.preview}
name={previewItem.name}
size={previewItem.size}
textContent={previewItem.textContent}
{activeModelId}
/>
{/if}

View file

@ -1,14 +1,13 @@
<script lang="ts">
import {
ChatAttachmentsList,
ChatAttachmentMcpResources,
ChatFormActions,
ChatFormFileInputInvisible,
ChatFormPromptPicker,
ChatFormResourcePicker,
ChatFormTextarea
ChatFormMcpResourcesList,
ChatFormPickers,
ChatFormTextarea,
DialogMcpResourcesBrowser
} from '$lib/components/app';
import { DialogMcpResources } from '$lib/components/app/dialogs';
import {
CLIPBOARD_CONTENT_QUOTE_PREFIX,
INPUT_CLASSES,
@ -54,6 +53,8 @@
isLoading?: boolean;
placeholder?: string;
showMcpPromptButton?: boolean;
showAddButton?: boolean;
showModelSelector?: boolean;
// Event Handlers
onAttachmentRemove?: (index: number) => void;
@ -73,6 +74,8 @@
isLoading = false,
placeholder = 'Type a message...',
showMcpPromptButton = false,
showAddButton = true,
showModelSelector = true,
uploadedFiles = $bindable([]),
value = $bindable(''),
onAttachmentRemove,
@ -85,31 +88,21 @@
onValueChange
}: Props = $props();
/**
*
*
* STATE
*
*
*/
// Component References
let audioRecorder: AudioRecorder | undefined;
let chatFormActionsRef: ChatFormActions | undefined = $state(undefined);
let fileInputRef: ChatFormFileInputInvisible | undefined = $state(undefined);
let promptPickerRef: ChatFormPromptPicker | undefined = $state(undefined);
let resourcePickerRef: ChatFormResourcePicker | undefined = $state(undefined);
let pickersRef: { handleKeydown: (event: KeyboardEvent) => boolean } | undefined =
$state(undefined);
let textareaRef: ChatFormTextarea | undefined = $state(undefined);
// Audio Recording State
let isRecording = $state(false);
let recordingSupported = $state(false);
// Prompt Picker State
// Picker State
let isPromptPickerOpen = $state(false);
let promptSearchQuery = $state('');
// Inline Resource Picker State (triggered by @)
let isInlineResourcePickerOpen = $state(false);
let resourceSearchQuery = $state('');
@ -117,22 +110,12 @@
let isResourceDialogOpen = $state(false);
let preSelectedResourceUri = $state<string | undefined>(undefined);
/**
*
*
* DERIVED STATE
*
*
*/
// Configuration
let currentConfig = $derived(config());
let pasteLongTextToFileLength = $derived.by(() => {
const n = Number(currentConfig.pasteLongTextToFileLen);
return Number.isNaN(n) ? Number(SETTING_CONFIG_DEFAULT.pasteLongTextToFileLen) : n;
});
// Model Selection Logic
let isRouter = $derived(isRouterMode());
let conversationModel = $derived(
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
@ -158,7 +141,6 @@
return null;
});
// Form Validation State
let hasModelSelected = $derived(!isRouter || !!conversationModel || !!selectedModelId());
let hasLoadingAttachments = $derived(uploadedFiles.some((f) => f.isLoading));
let hasAttachments = $derived(
@ -166,27 +148,11 @@
);
let canSubmit = $derived(value.trim().length > 0 || hasAttachments);
/**
*
*
* LIFECYCLE
*
*
*/
onMount(() => {
recordingSupported = isAudioRecordingSupported();
audioRecorder = new AudioRecorder();
});
/**
*
*
* PUBLIC API
*
*
*/
export function focus() {
textareaRef?.focus();
}
@ -199,10 +165,6 @@
chatFormActionsRef?.openModelSelector();
}
/**
* Check if a model is selected, open selector if not
* @returns true if model is selected, false otherwise
*/
export function checkModelSelected(): boolean {
if (!hasModelSelected) {
chatFormActionsRef?.openModelSelector();
@ -211,14 +173,6 @@
return true;
}
/**
*
*
* EVENT HANDLERS - File Management
*
*
*/
function handleFileSelect(files: File[]) {
onFilesAdd?.(files);
}
@ -238,14 +192,6 @@
}
}
/**
*
*
* EVENT HANDLERS - Input & Keyboard
*
*
*/
function handleInput() {
const perChatOverrides = conversationsStore.getAllMcpServerOverrides();
const hasServers = mcpStore.hasEnabledServers(perChatOverrides);
@ -273,11 +219,7 @@
}
function handleKeydown(event: KeyboardEvent) {
if (isPromptPickerOpen && promptPickerRef?.handleKeydown(event)) {
return;
}
if (isInlineResourcePickerOpen && resourcePickerRef?.handleKeydown(event)) {
if (pickersRef?.handleKeydown(event)) {
return;
}
@ -388,14 +330,6 @@
}
}
/**
*
*
* EVENT HANDLERS - Prompt Picker
*
*
*/
function handlePromptLoadStart(
placeholderId: string,
promptInfo: MCPPromptInfo,
@ -474,14 +408,6 @@
textareaRef?.focus();
}
/**
*
*
* EVENT HANDLERS - Inline Resource Picker
*
*
*/
function handleInlineResourcePickerClose() {
isInlineResourcePickerOpen = false;
resourceSearchQuery = '';
@ -489,7 +415,6 @@
}
function handleInlineResourceSelect() {
// Clear the @query from input after resource is attached
if (value.startsWith(RESOURCE_TRIGGER_PREFIX)) {
value = '';
onValueChange?.('');
@ -512,14 +437,6 @@
isResourceDialogOpen = true;
}
/**
*
*
* EVENT HANDLERS - Audio Recording
*
*
*/
async function handleMicClick() {
if (!audioRecorder || !recordingSupported) {
console.warn('Audio recording not supported');
@ -552,29 +469,27 @@
<form
class="relative {className}"
onsubmit={(e) => {
e.preventDefault();
onsubmit={(event) => {
event.preventDefault();
if (!canSubmit || disabled || hasLoadingAttachments) return;
onSubmit?.();
}}
>
<ChatFormPromptPicker
bind:this={promptPickerRef}
isOpen={isPromptPickerOpen}
searchQuery={promptSearchQuery}
onClose={handlePromptPickerClose}
<ChatFormPickers
bind:this={pickersRef}
{isPromptPickerOpen}
{promptSearchQuery}
{isInlineResourcePickerOpen}
{resourceSearchQuery}
onPromptPickerClose={handlePromptPickerClose}
onInlineResourcePickerClose={handleInlineResourcePickerClose}
onInlineResourceSelect={handleInlineResourceSelect}
onPromptLoadStart={handlePromptLoadStart}
onPromptLoadComplete={handlePromptLoadComplete}
onPromptLoadError={handlePromptLoadError}
/>
<ChatFormResourcePicker
bind:this={resourcePickerRef}
isOpen={isInlineResourcePickerOpen}
searchQuery={resourceSearchQuery}
onClose={handleInlineResourcePickerClose}
onResourceSelect={handleInlineResourceSelect}
onBrowse={handleBrowseResources}
onInlineResourceBrowse={handleBrowseResources}
/>
<div
@ -611,7 +526,7 @@
/>
{#if mcpHasResourceAttachments()}
<ChatAttachmentMcpResources
<ChatFormMcpResourcesList
class="mb-3"
onResourceClick={(uri) => {
preSelectedResourceUri = uri;
@ -624,10 +539,11 @@
class="px-3"
bind:this={chatFormActionsRef}
canSend={canSubmit}
hasText={value.trim().length > 0}
{disabled}
{isLoading}
{isRecording}
{showAddButton}
{showModelSelector}
{uploadedFiles}
onFileUpload={handleFileUpload}
onMicClick={handleMicClick}
@ -640,7 +556,7 @@
</div>
</form>
<DialogMcpResources
<DialogMcpResourcesBrowser
bind:open={isResourceDialogOpen}
preSelectedUri={preSelectedResourceUri}
onAttach={(resource: MCPResourceInfo) => {

View file

@ -0,0 +1,33 @@
<script lang="ts">
import { Plus } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import * as Tooltip from '$lib/components/ui/tooltip';
import { ATTACHMENT_TOOLTIP_TEXT } from '$lib/constants';
interface Props {
disabled?: boolean;
onclick?: (e: MouseEvent) => void;
}
let { disabled = false, onclick }: Props = $props();
</script>
<Tooltip.Root>
<Tooltip.Trigger class="w-full">
<Button
class="file-upload-button h-8 w-8 rounded-full p-0"
{disabled}
{onclick}
variant="secondary"
type="button"
>
<span class="sr-only">{ATTACHMENT_TOOLTIP_TEXT}</span>
<Plus class="h-4 w-4" />
</Button>
</Tooltip.Trigger>
<Tooltip.Content>
<p>{ATTACHMENT_TOOLTIP_TEXT}</p>
</Tooltip.Content>
</Tooltip.Root>

View file

@ -1,17 +1,18 @@
<script lang="ts">
import { Plus } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import type { Snippet } from 'svelte';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import * as Tooltip from '$lib/components/ui/tooltip';
import {
ATTACHMENT_FILE_ITEMS,
ATTACHMENT_EXTRA_ITEMS,
ATTACHMENT_MCP_ITEMS,
ATTACHMENT_TOOLTIP_TEXT,
TOOLTIP_DELAY_DURATION
} from '$lib/constants';
import { AttachmentMenuItemId } from '$lib/enums';
import { ChatFormActionToolsSubmenu, ChatFormActionMcpServersSubmenu } from '$lib/components/app';
import {
ChatFormActionAddToolsSubmenu,
ChatFormActionAddMcpServersSubmenu
} from '$lib/components/app';
import { useAttachmentMenu } from '$lib/hooks/use-attachment-menu.svelte';
@ -27,6 +28,7 @@
onMcpPromptClick?: () => void;
onMcpSettingsClick?: () => void;
onMcpResourcesClick?: () => void;
trigger: Snippet<[{ disabled: boolean }]>;
}
let {
@ -40,7 +42,8 @@
onSystemPromptClick,
onMcpPromptClick,
onMcpSettingsClick,
onMcpResourcesClick
onMcpResourcesClick,
trigger
}: Props = $props();
let dropdownOpen = $state(false);
@ -62,24 +65,7 @@
<div class="flex items-center gap-1 {className}">
<DropdownMenu.Root bind:open={dropdownOpen}>
<DropdownMenu.Trigger name="Attach files" {disabled}>
<Tooltip.Root>
<Tooltip.Trigger class="w-full">
<Button
class="file-upload-button h-8 w-8 rounded-full p-0"
{disabled}
variant="secondary"
type="button"
>
<span class="sr-only">{ATTACHMENT_TOOLTIP_TEXT}</span>
<Plus class="h-4 w-4" />
</Button>
</Tooltip.Trigger>
<Tooltip.Content>
<p>{ATTACHMENT_TOOLTIP_TEXT}</p>
</Tooltip.Content>
</Tooltip.Root>
{@render trigger({ disabled })}
</DropdownMenu.Trigger>
<DropdownMenu.Content align="start" class="w-48">
@ -161,9 +147,9 @@
{/if}
{/each}
<ChatFormActionToolsSubmenu />
<ChatFormActionAddToolsSubmenu />
<ChatFormActionMcpServersSubmenu onMcpSettingsClick={handleMcpSettingsClick} />
<ChatFormActionAddMcpServersSubmenu onMcpSettingsClick={handleMcpSettingsClick} />
{#each ATTACHMENT_MCP_ITEMS as item (item.id)}
{#if attachmentMenu.isItemVisible(item.visibleWhen)}

View file

@ -0,0 +1,149 @@
<script lang="ts">
import { Settings, Plus } from '@lucide/svelte';
import { Switch } from '$lib/components/ui/switch';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import { McpLogo, DropdownMenuSearchable } from '$lib/components/app';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { HealthCheckStatus } from '$lib/enums';
import type { MCPServerSettingsEntry } from '$lib/types';
import { goto } from '$app/navigation';
interface Props {
onMcpSettingsClick?: () => void;
}
let { onMcpSettingsClick }: Props = $props();
let mcpSearchQuery = $state('');
let allMcpServers = $derived(mcpStore.getServersSorted());
let mcpServers = $derived(allMcpServers.filter((s) => s.enabled));
let hasMcpServers = $derived(mcpServers.length > 0);
// let hasAnyMcpServers = $derived(allMcpServers.length > 0);
let filteredMcpServers = $derived.by(() => {
const query = mcpSearchQuery.toLowerCase().trim();
if (!query) return mcpServers;
return mcpServers.filter((s) => {
const name = getServerLabel(s).toLowerCase();
const url = s.url.toLowerCase();
return name.includes(query) || url.includes(query);
});
});
function getServerLabel(server: MCPServerSettingsEntry): string {
return mcpStore.getServerLabel(server);
}
function isServerEnabledForChat(serverId: string): boolean {
return conversationsStore.isMcpServerEnabledForChat(serverId);
}
async function toggleServerForChat(serverId: string) {
await conversationsStore.toggleMcpServerForChat(serverId);
}
function handleMcpSubMenuOpen(open: boolean) {
if (open) {
mcpSearchQuery = '';
mcpStore.runHealthChecksForServers(allMcpServers);
}
}
function handleMcpSettingsClick() {
onMcpSettingsClick?.();
goto(`${hasMcpServers ? '' : '?add'}#/settings/mcp`);
}
</script>
<DropdownMenu.Root>
<DropdownMenu.Sub onOpenChange={handleMcpSubMenuOpen}>
<DropdownMenu.SubTrigger class="flex cursor-pointer items-center gap-2">
<McpLogo class="h-4 w-4" />
<span>MCP Servers</span>
</DropdownMenu.SubTrigger>
<DropdownMenu.SubContent class="w-72 pt-0">
{#if hasMcpServers}
<DropdownMenuSearchable
placeholder="Search servers..."
bind:searchValue={mcpSearchQuery}
emptyMessage="No servers found"
isEmpty={filteredMcpServers.length === 0}
>
<div class="max-h-64 overflow-y-auto">
{#each filteredMcpServers as server (server.id)}
{@const healthState = mcpStore.getHealthCheckState(server.id)}
{@const hasError = healthState.status === HealthCheckStatus.ERROR}
{@const isEnabledForChat = isServerEnabledForChat(server.id)}
<button
type="button"
class="flex w-full items-center justify-between gap-2 rounded-sm px-2 py-2 text-left transition-colors hover:bg-accent disabled:cursor-not-allowed disabled:opacity-50"
onclick={() => !hasError && toggleServerForChat(server.id)}
disabled={hasError}
>
<div class="flex min-w-0 flex-1 items-center gap-2">
{#if mcpStore.getServerFavicon(server.id)}
<img
src={mcpStore.getServerFavicon(server.id)}
alt=""
class="h-4 w-4 shrink-0 rounded-sm"
onerror={(e) => {
(e.currentTarget as HTMLImageElement).style.display = 'none';
}}
/>
{/if}
<span class="truncate text-sm">{getServerLabel(server)}</span>
{#if hasError}
<span
class="shrink-0 rounded bg-destructive/15 px-1.5 py-0.5 text-xs text-destructive"
>
Error
</span>
{/if}
</div>
<Switch
checked={isEnabledForChat}
disabled={hasError}
onclick={(e) => e.stopPropagation()}
onCheckedChange={() => toggleServerForChat(server.id)}
/>
</button>
{/each}
</div>
{#snippet footer()}
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={handleMcpSettingsClick}
>
<Settings class="h-4 w-4" />
<span>Manage MCP Servers</span>
</DropdownMenu.Item>
{/snippet}
</DropdownMenuSearchable>
{:else}
<div class="px-2 py-3 text-center text-sm text-muted-foreground">
No MCP servers configured
</div>
<DropdownMenu.Separator />
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={handleMcpSettingsClick}
>
<Plus class="h-4 w-4" />
<span>Add MCP Servers</span>
</DropdownMenu.Item>
{/if}
</DropdownMenu.SubContent>
</DropdownMenu.Sub>
</DropdownMenu.Root>

View file

@ -1,18 +1,17 @@
<script lang="ts">
import { Plus } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import type { Snippet } from 'svelte';
import * as Tooltip from '$lib/components/ui/tooltip';
import * as Sheet from '$lib/components/ui/sheet';
import { TOOLTIP_DELAY_DURATION } from '$lib/constants';
import {
ATTACHMENT_FILE_ITEMS,
ATTACHMENT_EXTRA_ITEMS,
ATTACHMENT_MCP_ITEMS,
ATTACHMENT_TOOLTIP_TEXT
ATTACHMENT_MCP_ITEMS
} from '$lib/constants/attachment-menu';
import { ChatFormActionToolsSubmenu, ChatFormActionMcpServersSubmenu } from '$lib/components/app';
import { McpLogo } from '$lib/components/app';
import { useAttachmentMenu } from '$lib/hooks/use-attachment-menu.svelte';
import { AttachmentMenuItemId } from '$lib/enums';
import { PencilRuler } from '@lucide/svelte';
interface Props {
class?: string;
@ -24,8 +23,8 @@
onFileUpload?: () => void;
onSystemPromptClick?: () => void;
onMcpPromptClick?: () => void;
onMcpSettingsClick?: () => void;
onMcpResourcesClick?: () => void;
trigger: Snippet<[{ disabled: boolean; onclick?: () => void }]>;
}
let {
@ -38,8 +37,8 @@
onFileUpload,
onSystemPromptClick,
onMcpPromptClick,
onMcpSettingsClick,
onMcpResourcesClick
onMcpResourcesClick,
trigger
}: Props = $props();
let sheetOpen = $state(false);
@ -52,28 +51,14 @@
}
);
function handleMcpSettingsClick() {
sheetOpen = false;
onMcpSettingsClick?.();
}
const sheetItemClass =
'flex w-full items-center gap-3 rounded-md px-3 py-2.5 text-left text-sm transition-colors hover:bg-accent active:bg-accent disabled:cursor-not-allowed disabled:opacity-50';
</script>
<div class="flex items-center gap-1 {className}">
<Sheet.Root bind:open={sheetOpen}>
<Button
class="file-upload-button h-8 w-8 rounded-full p-0"
{disabled}
variant="secondary"
type="button"
onclick={() => (sheetOpen = true)}
>
<span class="sr-only">{ATTACHMENT_TOOLTIP_TEXT}</span>
<Plus class="h-4 w-4" />
</Button>
{@render trigger({ disabled, onclick: () => (sheetOpen = true) })}
<!-- <ChatFormActionAddButton {disabled} onclick={() => (sheetOpen = true)} /> -->
<Sheet.Content side="bottom" class="max-h-[85vh] gap-0 overflow-y-auto">
<Sheet.Header>
@ -161,9 +146,17 @@
<div class="my-2 border-t"></div>
<ChatFormActionToolsSubmenu />
<a href="#/settings/mcp" class="flex items-center gap-3 px-3 py-2">
<McpLogo class="inline h-4 w-4" />
<ChatFormActionMcpServersSubmenu onMcpSettingsClick={handleMcpSettingsClick} />
<span class="text-sm">MCP Servers</span>
</a>
<a href="#/settings/chat/tools" class="flex items-center gap-3 px-3 py-2">
<PencilRuler class="inline h-4 w-4" />
<span class="text-sm">Tools</span>
</a>
{#each ATTACHMENT_MCP_ITEMS as item (item.id)}
{#if attachmentMenu.isItemVisible(item.visibleWhen)}

View file

@ -24,6 +24,7 @@
{#if toolsStore.loading}
<div class="px-3 py-4 text-center text-sm text-muted-foreground">
<Loader2 class="mx-auto mb-1 h-4 w-4 animate-spin" />
Loading tools...
</div>
{:else if toolsStore.isToolsEndpointUnreachable}
@ -31,19 +32,21 @@
<span class="flex gap-2">
<Info class="mt-0.5 h-4 w-4 shrink-0" />
<span
>Run llama-server with <code>--tools</code> flag to enable
<strong>Built-in Tools</strong>.</span
>
<span>
Run llama-server with <code>--tools</code> flag to enable
<strong>Built-in Tools</strong>.
</span>
</span>
<span class="flex gap-2">
<Info class="mt-0.5 h-4 w-4 shrink-0" />
<span
>{hasMcpServersAvailable ? 'Enable' : 'Add'} MCP Server(s) to access
<strong>MCP Tools</strong>.</span
>
<span>
{hasMcpServersAvailable ? 'Enable' : 'Add'} MCP Server(s) to access
<strong>MCP Tools</strong>.
</span>
</span>
</div>
{:else if toolsStore.error}

View file

@ -0,0 +1,68 @@
<script lang="ts">
import { IsMobile } from '$lib/hooks/is-mobile.svelte';
import ChatFormActionAddDropdown from './ChatFormActionAddDropdown.svelte';
import ChatFormActionAddSheet from './ChatFormActionAddSheet.svelte';
import ChatFormActionAddButton from './ChatFormActionAddButton.svelte';

// Responsive wrapper for the attach ("+") menu: a bottom sheet on mobile,
// a dropdown menu on desktop. Both variants share the same round trigger
// button via a snippet.
interface Props {
disabled?: boolean;
hasAudioModality?: boolean;
hasMcpPromptsSupport?: boolean;
hasMcpResourcesSupport?: boolean;
hasVisionModality?: boolean;
onFileUpload?: () => void;
onMcpPromptClick?: () => void;
onMcpResourcesClick?: () => void;
onMcpSettingsClick?: () => void;
onSystemPromptClick?: () => void;
}

let {
disabled = false,
hasAudioModality = false,
hasMcpPromptsSupport = false,
hasMcpResourcesSupport = false,
hasVisionModality = false,
onFileUpload,
onMcpPromptClick,
onMcpResourcesClick,
onMcpSettingsClick,
onSystemPromptClick
}: Props = $props();

const isMobile = new IsMobile();
</script>

{#if isMobile.current}
<!-- Mobile: sheet variant. The sheet hands back an onclick that opens it,
so the trigger button receives it explicitly. -->
<ChatFormActionAddSheet
{disabled}
{hasAudioModality}
{hasVisionModality}
{hasMcpPromptsSupport}
{hasMcpResourcesSupport}
{onFileUpload}
{onMcpPromptClick}
{onMcpResourcesClick}
>
{#snippet trigger({ disabled, onclick })}
<ChatFormActionAddButton {disabled} {onclick} />
{/snippet}
</ChatFormActionAddSheet>
{:else}
<!-- Desktop: dropdown variant. No onclick is passed to the button —
presumably the wrapping DropdownMenu.Trigger handles opening itself
(NOTE(review): confirm against ChatFormActionAddDropdown). -->
<ChatFormActionAddDropdown
{disabled}
{hasAudioModality}
{hasVisionModality}
{hasMcpPromptsSupport}
{hasMcpResourcesSupport}
{onFileUpload}
{onMcpPromptClick}
{onMcpResourcesClick}
{onMcpSettingsClick}
{onSystemPromptClick}
>
{#snippet trigger()}
<ChatFormActionAddButton {disabled} />
{/snippet}
</ChatFormActionAddDropdown>
{/if}

View file

@ -1,147 +0,0 @@
<script lang="ts">
import { Settings, Plus } from '@lucide/svelte';
import { Switch } from '$lib/components/ui/switch';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import { McpLogo, DropdownMenuSearchable } from '$lib/components/app';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { HealthCheckStatus } from '$lib/enums';
import type { MCPServerSettingsEntry } from '$lib/types';
import { goto } from '$app/navigation';
interface Props {
onMcpSettingsClick?: () => void;
}
let { onMcpSettingsClick }: Props = $props();
let mcpSearchQuery = $state('');
let allMcpServers = $derived(mcpStore.getServersSorted());
let mcpServers = $derived(allMcpServers.filter((s) => s.enabled));
let hasMcpServers = $derived(mcpServers.length > 0);
// let hasAnyMcpServers = $derived(allMcpServers.length > 0);
let filteredMcpServers = $derived.by(() => {
const query = mcpSearchQuery.toLowerCase().trim();
if (!query) return mcpServers;
return mcpServers.filter((s) => {
const name = getServerLabel(s).toLowerCase();
const url = s.url.toLowerCase();
return name.includes(query) || url.includes(query);
});
});
function getServerLabel(server: MCPServerSettingsEntry): string {
return mcpStore.getServerLabel(server);
}
function isServerEnabledForChat(serverId: string): boolean {
return conversationsStore.isMcpServerEnabledForChat(serverId);
}
async function toggleServerForChat(serverId: string) {
await conversationsStore.toggleMcpServerForChat(serverId);
}
function handleMcpSubMenuOpen(open: boolean) {
if (open) {
mcpSearchQuery = '';
mcpStore.runHealthChecksForServers(allMcpServers);
}
}
function handleMcpSettingsClick() {
onMcpSettingsClick?.();
goto(`${hasMcpServers ? '' : '?add'}#/settings/mcp`);
}
</script>
<DropdownMenu.Sub onOpenChange={handleMcpSubMenuOpen}>
<DropdownMenu.SubTrigger class="flex cursor-pointer items-center gap-2">
<McpLogo class="h-4 w-4" />
<span>MCP Servers</span>
</DropdownMenu.SubTrigger>
<DropdownMenu.SubContent class="w-72 pt-0">
{#if hasMcpServers}
<DropdownMenuSearchable
placeholder="Search servers..."
bind:searchValue={mcpSearchQuery}
emptyMessage="No servers found"
isEmpty={filteredMcpServers.length === 0}
>
<div class="max-h-64 overflow-y-auto">
{#each filteredMcpServers as server (server.id)}
{@const healthState = mcpStore.getHealthCheckState(server.id)}
{@const hasError = healthState.status === HealthCheckStatus.ERROR}
{@const isEnabledForChat = isServerEnabledForChat(server.id)}
<button
type="button"
class="flex w-full items-center justify-between gap-2 rounded-sm px-2 py-2 text-left transition-colors hover:bg-accent disabled:cursor-not-allowed disabled:opacity-50"
onclick={() => !hasError && toggleServerForChat(server.id)}
disabled={hasError}
>
<div class="flex min-w-0 flex-1 items-center gap-2">
{#if mcpStore.getServerFavicon(server.id)}
<img
src={mcpStore.getServerFavicon(server.id)}
alt=""
class="h-4 w-4 shrink-0 rounded-sm"
onerror={(e) => {
(e.currentTarget as HTMLImageElement).style.display = 'none';
}}
/>
{/if}
<span class="truncate text-sm">{getServerLabel(server)}</span>
{#if hasError}
<span
class="shrink-0 rounded bg-destructive/15 px-1.5 py-0.5 text-xs text-destructive"
>
Error
</span>
{/if}
</div>
<Switch
checked={isEnabledForChat}
disabled={hasError}
onclick={(e: MouseEvent) => e.stopPropagation()}
onCheckedChange={() => toggleServerForChat(server.id)}
/>
</button>
{/each}
</div>
{#snippet footer()}
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={handleMcpSettingsClick}
>
<Settings class="h-4 w-4" />
<span>Manage MCP Servers</span>
</DropdownMenu.Item>
{/snippet}
</DropdownMenuSearchable>
{:else}
<div class="px-2 py-3 text-center text-sm text-muted-foreground">
No MCP servers configured
</div>
<DropdownMenu.Separator />
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={handleMcpSettingsClick}
>
<Plus class="h-4 w-4" />
<span>Add MCP Servers</span>
</DropdownMenu.Item>
{/if}
</DropdownMenu.SubContent>
</DropdownMenu.Sub>

View file

@ -0,0 +1,160 @@
<script lang="ts">
import { chatStore } from '$lib/stores/chat.svelte';
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
import { isRouterMode, serverError } from '$lib/stores/server.svelte';
import { ModelsSelectorDropdown, ModelsSelectorSheet } from '$lib/components/app';
import { IsMobile } from '$lib/hooks/is-mobile.svelte';
import { activeMessages } from '$lib/stores/conversations.svelte';
interface Props {
currentModel?: string;
disabled?: boolean;
forceForegroundText?: boolean;
hasAudioModality?: boolean;
hasVisionModality?: boolean;
hasModelSelected?: boolean;
isSelectedModelInCache?: boolean;
submitTooltip?: string;
useGlobalSelection?: boolean;
}
let {
currentModel,
disabled = false,
forceForegroundText = false,
hasAudioModality = $bindable(false),
hasVisionModality = $bindable(false),
hasModelSelected = $bindable(false),
isSelectedModelInCache = $bindable(true),
submitTooltip = $bindable(''),
useGlobalSelection = false
}: Props = $props();
let isRouter = $derived(isRouterMode());
let isOffline = $derived(!!serverError());
let conversationModel = $derived(
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
);
let lastSyncedConversationModel: string | null = null;
$effect(() => {
if (conversationModel && conversationModel !== lastSyncedConversationModel) {
lastSyncedConversationModel = conversationModel;
modelsStore.selectModelByName(conversationModel);
} else if (isRouter && !modelsStore.selectedModelId && modelsStore.loadedModelIds.length > 0) {
lastSyncedConversationModel = null;
// auto-select the first loaded model only when nothing is selected yet
const first = modelOptions().find((m) => modelsStore.loadedModelIds.includes(m.model));
if (first) modelsStore.selectModelById(first.id);
}
});
let activeModelId = $derived.by(() => {
const options = modelOptions();
if (!isRouter) {
return options.length > 0 ? options[0].model : null;
}
const selectedId = selectedModelId();
if (selectedId) {
const model = options.find((m) => m.id === selectedId);
if (model) return model.model;
}
if (conversationModel) {
const model = options.find((m) => m.model === conversationModel);
if (model) return model.model;
}
return null;
});
let modelPropsVersion = $state(0); // Used to trigger reactivity after fetch
$effect(() => {
if (activeModelId) {
const cached = modelsStore.getModelProps(activeModelId);
if (!cached) {
modelsStore.fetchModelProps(activeModelId).then(() => {
modelPropsVersion++;
});
}
}
});
$effect(() => {
hasAudioModality = activeModelId ? modelsStore.modelSupportsAudio(activeModelId) : false;
});
$effect(() => {
void modelPropsVersion;
hasVisionModality = activeModelId ? modelsStore.modelSupportsVision(activeModelId) : false;
});
$effect(() => {
hasModelSelected = !isRouter || !!conversationModel || !!selectedModelId();
});
$effect(() => {
if (!isRouter) {
isSelectedModelInCache = true;
} else if (conversationModel) {
isSelectedModelInCache = modelOptions().some((option) => option.model === conversationModel);
} else {
const currentModelId = selectedModelId();
if (!currentModelId) {
isSelectedModelInCache = false;
} else {
isSelectedModelInCache = modelOptions().some((option) => option.id === currentModelId);
}
}
});
$effect(() => {
if (!hasModelSelected) {
submitTooltip = 'Please select a model first';
} else if (!isSelectedModelInCache) {
submitTooltip = 'Selected model is not available, please select another';
} else {
submitTooltip = '';
}
});
let selectorModelRef: ModelsSelectorDropdown | ModelsSelectorSheet | undefined =
$state(undefined);
let isMobile = new IsMobile();
export function open() {
selectorModelRef?.open();
}
</script>
{#if isMobile.current}
<ModelsSelectorSheet
disabled={disabled || isOffline}
bind:this={selectorModelRef}
{currentModel}
{forceForegroundText}
{useGlobalSelection}
/>
{:else}
<ModelsSelectorDropdown
disabled={disabled || isOffline}
bind:this={selectorModelRef}
{currentModel}
{forceForegroundText}
{useGlobalSelection}
/>
{/if}

View file

@ -2,7 +2,6 @@
import { ArrowUp } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import * as Tooltip from '$lib/components/ui/tooltip';
import { cn } from '$lib/components/ui/utils';
interface Props {
canSend?: boolean;
@ -20,12 +19,11 @@
<Button
type="submit"
disabled={isDisabled}
class={cn(
class={[
'h-8 w-8 rounded-full p-0',
showErrorState
? 'bg-red-400/10 text-red-400 hover:bg-red-400/20 hover:text-red-400 disabled:opacity-100'
: ''
)}
showErrorState &&
'bg-red-400/10 text-red-400 hover:bg-red-400/20 hover:text-red-400 disabled:opacity-100'
]}
{...props}
>
<span class="sr-only">Send</span>

View file

@ -2,31 +2,27 @@
import { Square } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import {
ChatFormActionAttachmentsDropdown,
ChatFormActionAttachmentsSheet,
ChatFormActionsAdd,
ChatFormActionModels,
ChatFormActionRecord,
ChatFormActionSubmit,
ModelsSelectorDropdown,
ModelsSelectorSheet
ChatFormActionSubmit
} from '$lib/components/app';
import { FileTypeCategory } from '$lib/enums';
import { IsMobile } from '$lib/hooks/is-mobile.svelte';
import { chatStore } from '$lib/stores/chat.svelte';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
import { isRouterMode, serverError } from '$lib/stores/server.svelte';
import { config } from '$lib/stores/settings.svelte';
import { activeMessages, conversationsStore } from '$lib/stores/conversations.svelte';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { getFileTypeCategory } from '$lib/utils';
import { goto } from '$app/navigation';
interface Props {
canSend?: boolean;
canSubmit?: boolean;
class?: string;
disabled?: boolean;
isLoading?: boolean;
isRecording?: boolean;
hasText?: boolean;
showAddButton?: boolean;
showModelSelector?: boolean;
uploadedFiles?: ChatUploadedFile[];
onFileUpload?: () => void;
onMicClick?: () => void;
@ -38,11 +34,13 @@
let {
canSend = false,
canSubmit = false,
class: className = '',
disabled = false,
isLoading = false,
isRecording = false,
hasText = false,
showAddButton = true,
showModelSelector = true,
uploadedFiles = [],
onFileUpload,
onMicClick,
@ -53,124 +51,6 @@
}: Props = $props();
let currentConfig = $derived(config());
let isRouter = $derived(isRouterMode());
let isOffline = $derived(!!serverError());
let conversationModel = $derived(
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
);
let lastSyncedConversationModel: string | null = null;
$effect(() => {
if (conversationModel && conversationModel !== lastSyncedConversationModel) {
lastSyncedConversationModel = conversationModel;
modelsStore.selectModelByName(conversationModel);
} else if (isRouter && !modelsStore.selectedModelId && modelsStore.loadedModelIds.length > 0) {
lastSyncedConversationModel = null;
// auto-select the first loaded model only when nothing is selected yet
const first = modelOptions().find((m) => modelsStore.loadedModelIds.includes(m.model));
if (first) modelsStore.selectModelById(first.id);
}
});
let activeModelId = $derived.by(() => {
const options = modelOptions();
if (!isRouter) {
return options.length > 0 ? options[0].model : null;
}
const selectedId = selectedModelId();
if (selectedId) {
const model = options.find((m) => m.id === selectedId);
if (model) return model.model;
}
if (conversationModel) {
const model = options.find((m) => m.model === conversationModel);
if (model) return model.model;
}
return null;
});
let modelPropsVersion = $state(0); // Used to trigger reactivity after fetch
$effect(() => {
if (activeModelId) {
const cached = modelsStore.getModelProps(activeModelId);
if (!cached) {
modelsStore.fetchModelProps(activeModelId).then(() => {
modelPropsVersion++;
});
}
}
});
let hasAudioModality = $derived.by(() => {
if (activeModelId) {
void modelPropsVersion;
return modelsStore.modelSupportsAudio(activeModelId);
}
return false;
});
let hasVisionModality = $derived.by(() => {
if (activeModelId) {
void modelPropsVersion;
return modelsStore.modelSupportsVision(activeModelId);
}
return false;
});
let hasAudioAttachments = $derived(
uploadedFiles.some((file) => getFileTypeCategory(file.type) === FileTypeCategory.AUDIO)
);
let shouldShowRecordButton = $derived(
hasAudioModality && !hasText && !hasAudioAttachments && currentConfig.autoMicOnEmpty
);
let hasModelSelected = $derived(!isRouter || !!conversationModel || !!selectedModelId());
let isSelectedModelInCache = $derived.by(() => {
if (!isRouter) return true;
if (conversationModel) {
return modelOptions().some((option) => option.model === conversationModel);
}
const currentModelId = selectedModelId();
if (!currentModelId) return false;
return modelOptions().some((option) => option.id === currentModelId);
});
let submitTooltip = $derived.by(() => {
if (!hasModelSelected) {
return 'Please select a model first';
}
if (!isSelectedModelInCache) {
return 'Selected model is not available, please select another';
}
return '';
});
let selectorModelRef: ModelsSelectorDropdown | ModelsSelectorSheet | undefined =
$state(undefined);
let isMobile = new IsMobile();
export function openModelSelector() {
selectorModelRef?.open();
}
let hasMcpPromptsSupport = $derived.by(() => {
const perChatOverrides = conversationsStore.getAllMcpServerOverrides();
@ -183,25 +63,34 @@
return mcpStore.hasResourcesCapability(perChatOverrides);
});
let hasAudioModality = $state(false);
let hasVisionModality = $state(false);
let hasModelSelected = $state(false);
let isSelectedModelInCache = $state(true);
let submitTooltip = $state('');
let hasAudioAttachments = $derived(
uploadedFiles.some((file) => getFileTypeCategory(file.type) === FileTypeCategory.AUDIO)
);
let shouldShowRecordButton = $derived(
hasAudioModality && !canSubmit && !hasAudioAttachments && currentConfig.autoMicOnEmpty
);
let selectorModelRef: ChatFormActionModels | undefined = $state(undefined);
export function openModelSelector() {
selectorModelRef?.open();
}
</script>
<div class="flex w-full items-center gap-3 {className}" style="container-type: inline-size">
<div class="mr-auto flex items-center gap-2">
{#if isMobile.current}
<ChatFormActionAttachmentsSheet
{disabled}
{hasAudioModality}
{hasVisionModality}
{hasMcpPromptsSupport}
{hasMcpResourcesSupport}
{onFileUpload}
{onSystemPromptClick}
{onMcpPromptClick}
onMcpSettingsClick={() => goto('#/settings/mcp')}
{onMcpResourcesClick}
/>
{:else}
<ChatFormActionAttachmentsDropdown
<div
class="flex w-full items-center gap-3 {className} {showAddButton ? '' : 'justify-end'}"
style="container-type: inline-size"
>
{#if showAddButton}
<div class="mr-auto flex items-center gap-2">
<ChatFormActionsAdd
{disabled}
{hasAudioModality}
{hasVisionModality}
@ -213,30 +102,24 @@
{onMcpResourcesClick}
onMcpSettingsClick={() => goto('#/settings/mcp')}
/>
{/if}
</div>
</div>
{/if}
<div class="ml-auto flex items-center gap-2">
{#if isMobile.current}
<ModelsSelectorSheet
disabled={disabled || isOffline}
bind:this={selectorModelRef}
currentModel={conversationModel}
forceForegroundText
useGlobalSelection
/>
{:else}
<ModelsSelectorDropdown
disabled={disabled || isOffline}
bind:this={selectorModelRef}
currentModel={conversationModel}
forceForegroundText
useGlobalSelection
/>
{/if}
</div>
{#if showModelSelector}
<ChatFormActionModels
{disabled}
bind:this={selectorModelRef}
bind:hasAudioModality
bind:hasVisionModality
bind:hasModelSelected
bind:isSelectedModelInCache
bind:submitTooltip
forceForegroundText
useGlobalSelection
/>
{/if}
{#if isLoading && !hasText}
{#if isLoading && !canSubmit}
<Button
type="button"
variant="secondary"
@ -253,10 +136,10 @@
<ChatFormActionRecord {disabled} {hasAudioModality} {isLoading} {isRecording} {onMicClick} />
{:else}
<ChatFormActionSubmit
canSend={canSend && hasModelSelected && isSelectedModelInCache}
canSend={canSend && (showModelSelector ? hasModelSelected && isSelectedModelInCache : true)}
{disabled}
tooltipLabel={submitTooltip}
showErrorState={hasModelSelected && !isSelectedModelInCache}
showErrorState={showModelSelector && hasModelSelected && !isSelectedModelInCache}
/>
{/if}
</div>

View file

@ -15,6 +15,7 @@
function handleFileSelect(event: Event) {
const input = event.target as HTMLInputElement;
if (input.files) {
onFileSelect?.(Array.from(input.files));
}

View file

@ -1,31 +0,0 @@
<script lang="ts">
import { browser } from '$app/environment';
import { config } from '$lib/stores/settings.svelte';
interface Props {
class?: string;
show?: boolean;
}
let { class: className = '', show = true }: Props = $props();
let sendOnEnter = $derived(config().sendOnEnter !== false);
let modKey = browser && /Mac|iPhone|iPad|iPod/.test(navigator.platform) ? 'Cmd' : 'Ctrl';
</script>
{#if show}
<div class="mt-6 items-center justify-center {className} hidden md:flex">
{#if sendOnEnter}
<p class="text-xs text-muted-foreground">
Press <kbd class="rounded bg-muted px-1 py-0.5 font-mono text-xs">Enter</kbd> to send,
<kbd class="rounded bg-muted px-1 py-0.5 font-mono text-xs">Shift + Enter</kbd> for new line
</p>
{:else}
<p class="text-xs text-muted-foreground">
Press <kbd class="rounded bg-muted px-1 py-0.5 font-mono text-xs">{modKey} + Enter</kbd> to
send,
<kbd class="rounded bg-muted px-1 py-0.5 font-mono text-xs">Enter</kbd> for new line
</p>
{/if}
</div>
{/if}

View file

@ -4,7 +4,10 @@
mcpResourceAttachments,
mcpHasResourceAttachments
} from '$lib/stores/mcp-resources.svelte';
import { ChatAttachmentMcpResource, HorizontalScrollCarousel } from '$lib/components/app';
import {
ChatAttachmentsListItemMcpResource,
HorizontalScrollCarousel
} from '$lib/components/app';
interface Props {
class?: string;
@ -29,11 +32,11 @@
<div class={className}>
<HorizontalScrollCarousel gapSize="2">
{#each attachments as attachment, i (attachment.id)}
<ChatAttachmentMcpResource
<ChatAttachmentsListItemMcpResource
class={i === 0 ? 'ml-3' : ''}
{attachment}
onRemove={handleRemove}
onClick={() => handleResourceClick(attachment.resource.uri)}
onclick={() => handleResourceClick(attachment.resource.uri)}
/>
{/each}
</HorizontalScrollCarousel>

View file

@ -60,8 +60,7 @@
<div
bind:this={listContainer}
class="{CHAT_FORM_POPOVER_MAX_HEIGHT} p-2"
class:pt-13={showSearchInput}
class={[`${CHAT_FORM_POPOVER_MAX_HEIGHT} p-2`, showSearchInput && 'pt-13']}
>
{#if isLoading}
{#if skeleton}

View file

@ -3,18 +3,18 @@
interface Props {
isSelected?: boolean;
onClick: () => void;
onclick: () => void;
dataIndex?: number;
children: Snippet;
}
let { isSelected = false, onClick, dataIndex, children }: Props = $props();
let { isSelected = false, onclick, dataIndex, children }: Props = $props();
</script>
<button
type="button"
data-picker-index={dataIndex}
onclick={onClick}
{onclick}
class="flex w-full cursor-pointer items-start gap-3 rounded-lg px-3 py-2 text-left hover:bg-accent/50 {isSelected
? 'bg-accent/50'
: ''}"

View file

@ -39,7 +39,7 @@
sideOffset={12}
class="w-[var(--bits-popover-anchor-width)] max-w-none rounded-xl border-border/50 p-0 shadow-xl {className}"
onkeydown={onKeydown}
onOpenAutoFocus={(e) => e.preventDefault()}
onOpenAutoFocus={(event) => event.preventDefault()}
>
{@render children()}
</Popover.Content>

View file

@ -5,7 +5,6 @@
import { KeyboardKey } from '$lib/enums';
import type { MCPPromptInfo, GetPromptResult, MCPServerSettingsEntry } from '$lib/types';
import { SvelteMap } from 'svelte/reactivity';
import Badge from '$lib/components/ui/badge/badge.svelte';
import {
ChatFormPickerPopover,
ChatFormPickerList,
@ -14,6 +13,7 @@
ChatFormPickerListItemSkeleton,
ChatFormPromptPickerArgumentForm
} from '$lib/components/app/chat';
import Badge from '$lib/components/ui/badge/badge.svelte';
interface Props {
class?: string;
@ -97,7 +97,7 @@
prompts = await mcpStore.getAllPrompts();
} catch (error) {
console.error('[ChatFormPromptPicker] Failed to load prompts:', error);
console.error('[ChatFormPickerMcpPrompts] Failed to load prompts:', error);
prompts = [];
} finally {
isLoading = false;
@ -163,7 +163,7 @@
}
if (import.meta.env.DEV) {
console.log('[ChatFormPromptPicker] Fetching completions for:', {
console.log('[ChatFormPickerMcpPrompts] Fetching completions for:', {
serverName: selectedPrompt.serverName,
promptName: selectedPrompt.name,
argName,
@ -182,7 +182,7 @@
);
if (import.meta.env.DEV) {
console.log('[ChatFormPromptPicker] Autocomplete result:', {
console.log('[ChatFormPickerMcpPrompts] Autocomplete result:', {
argName,
value,
result,
@ -205,7 +205,7 @@
suggestions[argName] = [];
}
} catch (error) {
console.error('[ChatFormPromptPicker] Failed to fetch completions:', error);
console.error('[ChatFormPickerMcpPrompts] Failed to fetch completions:', error);
suggestions[argName] = [];
} finally {
loadingSuggestions[argName] = false;
@ -408,7 +408,7 @@
<ChatFormPickerListItem
dataIndex={index}
{isSelected}
onClick={() => handlePromptClick(prompt)}
onclick={() => handlePromptClick(prompt)}
>
<ChatFormPickerItemHeader
{server}

View file

@ -67,7 +67,6 @@
try {
const perChatOverrides = conversationsStore.getAllMcpServerOverrides();
const initialized = await mcpStore.ensureInitialized(perChatOverrides);
if (!initialized) {
@ -79,7 +78,7 @@
await mcpStore.fetchAllResources();
resources = mcpResourceStore.getAllResourceInfos();
} catch (error) {
console.error('[ChatFormResourcePicker] Failed to load resources:', error);
console.error('[ChatFormPickerMcpResources] Failed to load resources:', error);
resources = [];
} finally {
isLoading = false;
@ -88,6 +87,7 @@
function handleResourceClick(resource: MCPResourceInfo) {
mcpStore.attachResource(resource.uri);
onResourceSelect?.(resource);
onClose?.();
}
@ -144,6 +144,7 @@
const sortedResources = [...resources].sort((a, b) => {
const orderA = serverOrderMap.get(a.serverName) ?? Number.MAX_SAFE_INTEGER;
const orderB = serverOrderMap.get(b.serverName) ?? Number.MAX_SAFE_INTEGER;
return orderA - orderB;
});
@ -186,7 +187,7 @@
<ChatFormPickerListItem
dataIndex={index}
{isSelected}
onClick={() => handleResourceClick(resource)}
onclick={() => handleResourceClick(resource)}
>
<ChatFormPickerItemHeader
{server}

View file

@ -0,0 +1,75 @@
<script lang="ts">
import ChatFormPickerMcpPrompts from './ChatFormPickerMcpPrompts/ChatFormPickerMcpPrompts.svelte';
import ChatFormPickerMcpResources from './ChatFormPickerMcpResources.svelte';
import type { GetPromptResult, MCPPromptInfo } from '$lib/types';
interface Props {
isPromptPickerOpen?: boolean;
promptSearchQuery?: string;
isInlineResourcePickerOpen?: boolean;
resourceSearchQuery?: string;
onPromptPickerClose?: () => void;
onInlineResourcePickerClose?: () => void;
onInlineResourceSelect?: () => void;
onPromptLoadStart?: (
placeholderId: string,
promptInfo: MCPPromptInfo,
args?: Record<string, string>
) => void;
onPromptLoadComplete?: (placeholderId: string, result: GetPromptResult) => void;
onPromptLoadError?: (placeholderId: string, error: string) => void;
onInlineResourceBrowse?: () => void;
}
let {
isPromptPickerOpen,
promptSearchQuery,
isInlineResourcePickerOpen,
resourceSearchQuery,
onPromptPickerClose,
onInlineResourcePickerClose,
onInlineResourceSelect,
onPromptLoadStart,
onPromptLoadComplete,
onPromptLoadError,
onInlineResourceBrowse
}: Props = $props();
let promptPickerRef: ChatFormPickerMcpPrompts | undefined = $state(undefined);
let resourcePickerRef: ChatFormPickerMcpResources | undefined = $state(undefined);
/**
* Delegates keyboard events to the active picker child.
* Returns true if the event was handled.
*/
export function handleKeydown(event: KeyboardEvent): boolean {
if (isPromptPickerOpen && promptPickerRef?.handleKeydown(event)) {
return true;
}
if (isInlineResourcePickerOpen && resourcePickerRef?.handleKeydown(event)) {
return true;
}
return false;
}
</script>
<ChatFormPickerMcpPrompts
bind:this={promptPickerRef}
isOpen={isPromptPickerOpen}
searchQuery={promptSearchQuery}
onClose={onPromptPickerClose}
{onPromptLoadStart}
{onPromptLoadComplete}
{onPromptLoadError}
/>
<ChatFormPickerMcpResources
bind:this={resourcePickerRef}
isOpen={isInlineResourcePickerOpen}
searchQuery={resourceSearchQuery}
onClose={onInlineResourcePickerClose}
onResourceSelect={onInlineResourceSelect}
onBrowse={onInlineResourceBrowse}
/>

View file

@ -51,8 +51,10 @@
<textarea
bind:this={textareaElement}
bind:value
class="text-md min-h-12 w-full resize-none border-0 bg-transparent p-0 leading-6 outline-none placeholder:text-muted-foreground focus-visible:ring-0 focus-visible:ring-offset-0"
class:cursor-not-allowed={disabled}
class={[
'text-md min-h-12 w-full resize-none border-0 bg-transparent p-0 leading-6 outline-none placeholder:text-muted-foreground focus-visible:ring-0 focus-visible:ring-offset-0',
disabled && 'cursor-not-allowed'
]}
style="max-height: var(--max-message-height);"
{disabled}
onkeydown={onKeydown}

View file

@ -5,7 +5,8 @@
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { DatabaseService } from '$lib/services/database.service';
import { SYSTEM_MESSAGE_PLACEHOLDER } from '$lib/constants';
import { MessageRole, AttachmentType } from '$lib/enums';
import { REASONING_TAGS } from '$lib/constants/agentic';
import { MessageRole, AttachmentType, AgenticSectionType } from '$lib/enums';
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
import {
ChatMessageAssistant,
@ -14,6 +15,7 @@
ChatMessageMcpPrompt
} from '$lib/components/app/chat';
import { parseFilesToMessageExtras } from '$lib/utils/browser-only';
import { deriveAgenticSections } from '$lib/utils';
import type { DatabaseMessageExtraMcpPrompt } from '$lib/types';
interface Props {
@ -41,6 +43,50 @@
messageTypes: string[];
} | null>(null);
let editedContent = $derived(message.content);
let rawEditContent = $derived.by(() => {
if (message.role !== MessageRole.ASSISTANT) return undefined;
const sections = deriveAgenticSections(message, toolMessages, [], false);
const parts: string[] = [];
for (const section of sections) {
switch (section.type) {
case AgenticSectionType.REASONING:
case AgenticSectionType.REASONING_PENDING:
parts.push(`${REASONING_TAGS.START}\n${section.content}\n${REASONING_TAGS.END}`);
break;
case AgenticSectionType.TEXT:
parts.push(section.content);
break;
case AgenticSectionType.TOOL_CALL:
case AgenticSectionType.TOOL_CALL_PENDING:
case AgenticSectionType.TOOL_CALL_STREAMING: {
const callObj: Record<string, unknown> = { name: section.toolName };
if (section.toolArgs) {
try {
callObj.arguments = JSON.parse(section.toolArgs);
} catch {
callObj.arguments = section.toolArgs;
}
}
parts.push(JSON.stringify(callObj, null, 2));
if (section.toolResult) {
parts.push(`[Tool Result]\n${section.toolResult}`);
}
break;
}
}
}
return parts.join('\n\n\n');
});
let editedExtras = $derived<DatabaseMessageExtra[]>(message.extra ? [...message.extra] : []);
let editedUploadedFiles = $state<ChatUploadedFile[]>([]);
let isEditing = $state(false);
@ -49,6 +95,7 @@
let textareaElement: HTMLTextAreaElement | undefined = $state();
let showSaveOnlyOption = $derived(message.role === MessageRole.USER);
let showBranchAfterEditOption = $derived(message.role === MessageRole.ASSISTANT);
setMessageEditContext({
get isEditing() {
@ -64,7 +111,9 @@
return editedUploadedFiles;
},
get originalContent() {
return message.content;
return message.role === MessageRole.ASSISTANT
? (rawEditContent ?? message.content)
: message.content;
},
get originalExtras() {
return message.extra || [];
@ -72,6 +121,18 @@
get showSaveOnlyOption() {
return showSaveOnlyOption;
},
get showBranchAfterEditOption() {
return showBranchAfterEditOption;
},
get shouldBranchAfterEdit() {
return shouldBranchAfterEdit;
},
get messageRole() {
return message.role;
},
get rawEditContent() {
return rawEditContent;
},
setContent: (content: string) => {
editedContent = content;
},
@ -81,6 +142,9 @@
setUploadedFiles: (files: ChatUploadedFile[]) => {
editedUploadedFiles = files;
},
setShouldBranchAfterEdit: (value: boolean) => {
shouldBranchAfterEdit = value;
},
save: handleSaveEdit,
saveOnly: handleSaveEditOnly,
cancel: handleCancelEdit,
@ -124,7 +188,10 @@
return;
}
editedContent = message.content;
editedContent =
message.role === MessageRole.ASSISTANT
? rawEditContent || message.content || ''
: message.content;
editedExtras = message.extra ? [...message.extra] : [];
editedUploadedFiles = [];
}
@ -155,10 +222,14 @@
function handleEdit() {
isEditing = true;
// Clear temporary placeholder content for system messages
editedContent =
message.role === MessageRole.SYSTEM && message.content === SYSTEM_MESSAGE_PLACEHOLDER
? ''
: message.content;
if (message.role === MessageRole.SYSTEM && message.content === SYSTEM_MESSAGE_PLACEHOLDER) {
editedContent = '';
} else if (message.role === MessageRole.ASSISTANT) {
editedContent = rawEditContent || message.content || '';
} else {
editedContent = message.content;
}
textareaElement?.focus();
editedExtras = message.extra ? [...message.extra] : [];
editedUploadedFiles = [];

View file

@ -1,7 +1,8 @@
<script lang="ts">
import {
ChatMessageAgenticContent,
ChatMessageActions,
ChatMessageActionIcons,
ChatMessageEditForm,
ChatMessageStatistics,
ModelBadge,
ModelsSelectorDropdown
@ -9,22 +10,12 @@
import { getMessageEditContext } from '$lib/contexts';
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
import { isLoading, isChatStreaming } from '$lib/stores/chat.svelte';
import {
autoResizeTextarea,
copyToClipboard,
isIMEComposing,
deriveAgenticSections
} from '$lib/utils';
import { copyToClipboard, deriveAgenticSections } from '$lib/utils';
import { AgenticSectionType } from '$lib/enums';
import { REASONING_TAGS } from '$lib/constants/agentic';
import { tick } from 'svelte';
import { fade } from 'svelte/transition';
import { Check, X } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { Checkbox } from '$lib/components/ui/checkbox';
import { INPUT_CLASSES } from '$lib/constants';
import { MessageRole, KeyboardKey, ChatMessageStatsView } from '$lib/enums';
import Label from '$lib/components/ui/label/label.svelte';
import { MessageRole, ChatMessageStatsView } from '$lib/enums';
import { config } from '$lib/stores/settings.svelte';
import { isRouterMode } from '$lib/stores/server.svelte';
import { modelsStore } from '$lib/stores/models.svelte';
@ -82,19 +73,6 @@
// Get edit context
const editCtx = getMessageEditContext();
// Local state for assistant-specific editing
let shouldBranchAfterEdit = $state(false);
function handleEditKeydown(event: KeyboardEvent) {
if (event.key === KeyboardKey.ENTER && !event.shiftKey && !isIMEComposing(event)) {
event.preventDefault();
editCtx.save();
} else if (event.key === KeyboardKey.ESCAPE) {
event.preventDefault();
editCtx.cancel();
}
}
const isAgentic = $derived(hasAgenticContent(message, toolMessages));
const hasReasoning = $derived(!!message.reasoningContent);
const processingState = useProcessingState();
@ -227,12 +205,6 @@
void copyToClipboard(displayedModel ?? '');
}
$effect(() => {
if (editCtx.isEditing && textareaElement) {
autoResizeTextarea(textareaElement);
}
});
$effect(() => {
if (showProcessingInfoTop || showProcessingInfoBottom) {
processingState.startMonitoring();
@ -258,48 +230,7 @@
{/if}
{#if editCtx.isEditing}
<div class="w-full">
<textarea
bind:this={textareaElement}
value={editCtx.editedContent}
class="min-h-[50vh] w-full resize-y rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}"
onkeydown={handleEditKeydown}
oninput={(e) => {
autoResizeTextarea(e.currentTarget);
editCtx.setContent(e.currentTarget.value);
}}
placeholder="Edit assistant message..."
></textarea>
<div class="mt-2 flex items-center justify-between">
<div class="flex items-center space-x-2">
<Checkbox
id="branch-after-edit"
bind:checked={shouldBranchAfterEdit}
onCheckedChange={(checked) => (shouldBranchAfterEdit = checked === true)}
/>
<Label for="branch-after-edit" class="cursor-pointer text-sm text-muted-foreground">
Branch conversation after edit
</Label>
</div>
<div class="flex gap-2">
<Button class="h-8 px-3" onclick={editCtx.cancel} size="sm" variant="outline">
<X class="mr-1 h-3 w-3" />
Cancel
</Button>
<Button
class="h-8 px-3"
onclick={editCtx.save}
disabled={!editCtx.editedContent?.trim()}
size="sm"
>
<Check class="mr-1 h-3 w-3" />
Save
</Button>
</div>
</div>
</div>
<ChatMessageEditForm />
{:else if message.role === MessageRole.ASSISTANT}
{#if showRawOutput}
<pre class="raw-output">{rawOutputContent || ''}</pre>
@ -388,7 +319,7 @@
</div>
{#if message.timestamp && !editCtx.isEditing}
<ChatMessageActions
<ChatMessageActionIcons
role={MessageRole.ASSISTANT}
justify="start"
actionsPosition="left"

View file

@ -1,6 +1,6 @@
<script lang="ts">
import {
ChatMessageActions,
ChatMessageActionIcons,
ChatMessageEditForm,
ChatMessageMcpPromptContent
} from '$lib/components/app';
@ -63,7 +63,7 @@
{#if message.timestamp}
<div class="max-w-[80%]">
<ChatMessageActions
<ChatMessageActionIcons
actionsPosition="right"
{deletionInfo}
justify="end"

View file

@ -1,14 +1,13 @@
<script lang="ts">
import { Check, X } from '@lucide/svelte';
import { Card } from '$lib/components/ui/card';
import { ChatMessageActionIcons, MarkdownContent } from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import { MarkdownContent } from '$lib/components/app';
import { getMessageEditContext } from '$lib/contexts';
import { Card } from '$lib/components/ui/card';
import { INPUT_CLASSES } from '$lib/constants';
import { getMessageEditContext } from '$lib/contexts';
import { KeyboardKey, MessageRole } from '$lib/enums';
import { config } from '$lib/stores/settings.svelte';
import { isIMEComposing } from '$lib/utils';
import ChatMessageActions from './ChatMessageActions.svelte';
import { KeyboardKey, MessageRole } from '$lib/enums';
interface Props {
class?: string;
@ -213,7 +212,7 @@
{#if message.timestamp}
<div class="max-w-[80%]">
<ChatMessageActions
<ChatMessageActionIcons
actionsPosition="right"
{deletionInfo}
justify="end"

View file

@ -1,9 +1,11 @@
<script lang="ts">
import {
ChatMessageActionIcons,
ChatMessageEditForm,
ChatMessageUserBubble
} from '$lib/components/app/chat';
import { getMessageEditContext } from '$lib/contexts';
import ChatMessageActions from './ChatMessageActions.svelte';
import ChatMessageEditForm from './ChatMessageEditForm.svelte';
import { MessageRole } from '$lib/enums';
import ChatMessageUserBubble from './ChatMessageUserBubble.svelte';
interface Props {
class?: string;
@ -60,7 +62,7 @@
{#if message.timestamp}
<div class="max-w-[80%]">
<ChatMessageActions
<ChatMessageActionIcons
actionsPosition="right"
{deletionInfo}
justify="end"

View file

@ -1,11 +1,9 @@
<script lang="ts">
import { ActionIcon } from '$lib/components/app';
import ChatMessageEditForm from './ChatMessageEditForm.svelte';
import { ActionIcon, ChatMessageEditForm, ChatMessageUserBubble } from '$lib/components/app';
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
import { ArrowUp, Edit, Trash2 } from '@lucide/svelte';
import { getProcessingInfoContext } from '$lib/contexts';
import { useMessageEditContext } from '$lib/hooks/use-message-edit-context.svelte';
import ChatMessageUserBubble from './ChatMessageUserBubble.svelte';
interface Props {
class?: string;

View file

@ -7,12 +7,12 @@
actions: Snippet;
}
let { icon: Icon, message, actions }: Props = $props();
let { icon: IconComponent, message, actions }: Props = $props();
</script>
<div class="my-2 rounded-lg border border-border bg-card p-3">
<div class="mb-3 flex items-center gap-2 text-sm">
<Icon class="h-4 w-4 shrink-0 text-muted-foreground" />
<IconComponent class="h-4 w-4 shrink-0 text-muted-foreground" />
<span>
{@render message()}
</span>

View file

@ -1,12 +1,12 @@
<script lang="ts">
import { ChevronDown, ShieldQuestion } from '@lucide/svelte';
import { ChatMessageActionCard } from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import * as ButtonGroup from '$lib/components/ui/button-group';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import { ToolSource, ToolPermissionDecision } from '$lib/enums';
import { TOOL_SERVER_LABELS } from '$lib/constants';
import { toolsStore } from '$lib/stores/tools.svelte';
import ChatMessageActionCard from './ChatMessageActionCard.svelte';
interface Props {
toolName: string;

View file

@ -2,7 +2,7 @@
import { Edit, Copy, RefreshCw, Trash2, ArrowRight, GitBranch } from '@lucide/svelte';
import {
ActionIcon,
ChatMessageBranchingControls,
ChatMessageActionIconsBranchingControls,
DialogConfirmation
} from '$lib/components/app';
import { Switch } from '$lib/components/ui/switch';
@ -89,7 +89,7 @@
: 'right-0'} flex items-center gap-2 opacity-100 transition-opacity"
>
{#if siblingInfo && siblingInfo.totalSiblings > 1}
<ChatMessageBranchingControls {siblingInfo} {onNavigateToSibling} />
<ChatMessageActionIconsBranchingControls {siblingInfo} {onNavigateToSibling} />
{/if}
<div

View file

@ -0,0 +1,49 @@
<script lang="ts">
import { ChevronLeft, ChevronRight } from '@lucide/svelte';
import { ActionIcon } from '$lib/components/app';
interface Props {
class?: string;
siblingInfo: ChatMessageSiblingInfo | null;
onNavigateToSibling?: (siblingId: string) => void;
}
let { class: className = '', siblingInfo, onNavigateToSibling }: Props = $props();
let hasPrevious = $derived(siblingInfo && siblingInfo.currentIndex > 0);
let hasNext = $derived(siblingInfo && siblingInfo.currentIndex < siblingInfo.totalSiblings - 1);
let nextSiblingId = $derived(
hasNext ? siblingInfo!.siblingIds[siblingInfo!.currentIndex + 1] : null
);
let previousSiblingId = $derived(
hasPrevious ? siblingInfo!.siblingIds[siblingInfo!.currentIndex - 1] : null
);
</script>
{#if siblingInfo && siblingInfo.totalSiblings > 1}
<div
aria-label="Message version {siblingInfo.currentIndex + 1} of {siblingInfo.totalSiblings}"
class="flex items-center gap-1 text-xs text-muted-foreground {className}"
role="navigation"
>
<ActionIcon
icon={ChevronLeft}
tooltip="Previous version"
disabled={!hasPrevious}
class="h-5 w-5 p-0 {!hasPrevious ? '!cursor-not-allowed opacity-30' : ''}"
onclick={() => onNavigateToSibling?.(previousSiblingId!)}
/>
<span class="px-1 font-mono text-xs">
{siblingInfo.currentIndex + 1}/{siblingInfo.totalSiblings}
</span>
<ActionIcon
icon={ChevronRight}
tooltip="Next version"
disabled={!hasNext}
class="h-5 w-5 p-0 {!hasNext ? 'opacity-30' : ''}"
onclick={() => onNavigateToSibling?.(nextSiblingId!)}
/>
</div>
{/if}

View file

@ -5,8 +5,8 @@
CollapsibleContentBlock,
MarkdownContent,
SyntaxHighlightedCode,
ChatMessagePermissionRequest,
ChatMessageContinueRequest
ChatMessageActionCardPermissionRequest,
ChatMessageActionCardContinueRequest
} from '$lib/components/app';
import {
@ -359,7 +359,7 @@
{/if}
{#if pendingPermission && !permissionDismissed}
<ChatMessagePermissionRequest
<ChatMessageActionCardPermissionRequest
toolName={pendingPermission.toolName}
serverLabel={pendingPermission.serverLabel}
onDecision={handlePermission}
@ -367,7 +367,7 @@
{/if}
{#if pendingContinue && !continueDismissed}
<ChatMessageContinueRequest onDecision={handleContinue} />
<ChatMessageActionCardContinueRequest onDecision={handleContinue} />
{/if}
</div>

Some files were not shown because too many files have changed in this diff Show more