Merge commit '5016b72862' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	docs/ops.md
#	docs/ops/SYCL.csv
#	ggml/src/ggml-cann/aclnn_ops.cpp
#	ggml/src/ggml-opencl/ggml-opencl.cpp
#	ggml/src/ggml-sycl/backend.hpp
#	ggml/src/ggml-sycl/binbcast.cpp
#	ggml/src/ggml-sycl/binbcast.hpp
#	ggml/src/ggml-sycl/common.hpp
#	ggml/src/ggml-sycl/element_wise.cpp
#	ggml/src/ggml-sycl/element_wise.hpp
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	tests/test-chat-parser.cpp
#	tests/test-json-partial.cpp
This commit is contained in:
Concedo 2025-10-16 12:05:21 +08:00
commit 1ff97f8a00
32 changed files with 912 additions and 401 deletions

View file

@ -432,7 +432,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
if (is_arguments_path({})) {
// Entire JSON is the arguments and was parsed fully.
return consume_json_result {
partial->json.dump(),
partial->json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true),
/* .is_partial = */ false,
};
}
@ -444,7 +444,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
std::vector<std::string> path;
std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
if (is_arguments_path(path)) {
auto arguments = j.dump();
auto arguments = j.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true);
if (is_partial() && !partial->healing_marker.marker.empty()) {
auto idx = arguments.find(partial->healing_marker.json_dump_marker);
if (idx != std::string::npos) {

View file

@ -5,6 +5,7 @@
#include <nlohmann/json.hpp>
#include <string>
#include <regex>
using json = nlohmann::ordered_json;
@ -168,6 +169,47 @@ bool common_json_parse(
}
}
// Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
auto is_high_surrogate = [&](const std::string & s) {
// Check if a partial of a high surrogate (U+D800-U+DBFF)
return s.length() >= 4 &&
s[0] == '\\' && s[1] == 'u' &&
std::tolower(s[2]) == 'd' &&
(s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
};
// Initialize the unicode marker to a low surrogate to handle the edge case
// where a high surrogate (U+D800-U+DBFF) is immediately followed by a
// backslash (\)
std::string unicode_marker_padding = "udc00";
std::smatch last_unicode_seq;
if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
std::smatch second_last_seq;
std::string prelude = str.substr(0, last_unicode_seq.position());
// Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
if (is_high_surrogate(last_unicode_seq.str())) {
// If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF)
unicode_marker_padding += "\\udc00";
} else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
if (is_high_surrogate(second_last_seq.str())) {
// If this follows a high surrogate, pad it to be a low surrogate
if (last_unicode_seq.length() == 2) {
unicode_marker_padding = "dc00";
} else if (last_unicode_seq.length() == 3) {
unicode_marker_padding = "c00";
} else {
// The original unicode_marker_padding is already padded with 0s
}
}
}
}
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
@ -186,6 +228,9 @@ bool common_json_parse(
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an object value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
// Was inside an object value string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
} else {
// find last :
auto last_pos = str.find_last_of(':');
@ -205,6 +250,9 @@ bool common_json_parse(
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an array value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
// Was inside an array value string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
// Had just finished a value
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
@ -230,6 +278,9 @@ bool common_json_parse(
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
// Was inside an object key string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
} else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
// Was inside an object key string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
} else {
auto last_pos = str.find_last_of(':');
if (last_pos == std::string::npos) {

View file

@ -68,7 +68,7 @@ struct ggml_compute_params {
#endif // __VXE2__
#endif // __s390x__ && __VEC__
#if defined(__ARM_FEATURE_SVE)
#if defined(__ARM_FEATURE_SVE) && defined(__linux__)
#include <sys/prctl.h>
#endif

View file

@ -694,8 +694,13 @@ bool ggml_is_numa(void) {
#endif
static void ggml_init_arm_arch_features(void) {
#if defined(__linux__) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
#if defined(__linux__)
ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
#else
// TODO: add support of SVE for non-linux systems
#error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here."
#endif
#endif
}

View file

@ -540,10 +540,12 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
}
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) && (!oob_check || i_KQ < k_VKQ_sup) ?
slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
if (!oob_check || i_KQ < k_VKQ_sup) {
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
}
}
KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);
@ -581,10 +583,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
float KQ_sum_add = 0.0f;
#pragma unroll
for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
const float val = expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]);
if (!oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup) {
KQ_sum_add += val;
}
const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
KQ_sum_add += val;
tmp[i0/(np*warp_size)][jc1] = val;
}
KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;
@ -975,26 +976,6 @@ static __global__ void flash_attn_tile(
}
}
if (gridDim.y == 1) {
#pragma unroll
for (int jc0 = 0; jc0 < cpw; ++jc0) {
#ifdef FAST_FP16_AVAILABLE
const half2 KQ_sum_jc_inv = make_half2(1.0f/KQ_sum[jc0], 1.0f/KQ_sum[jc0]);
#pragma unroll
for (int i = 0; i < (DVp/2)/warp_size; ++i) {
VKQ[jc0*((DVp/2)/warp_size) + i] *= KQ_sum_jc_inv;
}
#else
const float KQ_sum_jc_inv = 1.0f/KQ_sum[jc0];
#pragma unroll
for (int i = 0; i < (DVp/2)/warp_size; ++i) {
VKQ[jc0*((DVp/2)/warp_size) + i].x *= KQ_sum_jc_inv;
VKQ[jc0*((DVp/2)/warp_size) + i].y *= KQ_sum_jc_inv;
}
#endif // FAST_FP16_AVAILABLE
}
}
// Write back results:
#pragma unroll
for (int jc0 = 0; jc0 < cpw; ++jc0) {
@ -1007,6 +988,8 @@ static __global__ void flash_attn_tile(
return;
}
const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;
const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
#ifdef FAST_FP16_AVAILABLE
@ -1017,6 +1000,8 @@ static __global__ void flash_attn_tile(
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
tmp[i1].x *= scale;
tmp[i1].y *= scale;
}
if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) {
ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
@ -1027,6 +1012,11 @@ static __global__ void flash_attn_tile(
#pragma unroll
for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) {
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale;
VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale;
}
ggml_cuda_memcpy_1<cpy_ne_D*4>(
&dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D],
&VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);

View file

@ -268,6 +268,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_SUM);

    // Pipeline name is derived from the source tensor type, e.g. "kernel_op_sum_f32".
    char base[256];
    char name[256];

    snprintf(base, sizeof(base), "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, sizeof(name), "%s", base);

    // Return the cached pipeline when present; compile it on first use otherwise.
    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
    if (!res) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
@ -1482,3 +1501,40 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_OPT_STEP_ADAMW);

    // Pipeline name is derived from the parameter tensor type, e.g. "kernel_opt_step_adamw_f32".
    char base[256];
    char name[256];

    snprintf(base, sizeof(base), "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, sizeof(name), "%s", base);

    // Return the cached pipeline when present; compile it on first use otherwise.
    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
    if (!res) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_OPT_STEP_SGD);

    // Pipeline name is derived from the parameter tensor type, e.g. "kernel_opt_step_sgd_f32".
    char base[256];
    char name[256];

    snprintf(base, sizeof(base), "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, sizeof(name), "%s", base);

    // Return the cached pipeline when present; compile it on first use otherwise.
    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
    if (!res) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

View file

@ -109,6 +109,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
@ -134,6 +135,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,

View file

@ -656,6 +656,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_COS:
case GGML_OP_LOG:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
@ -798,6 +799,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return false;
};
}
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return has_simdgroup_reduction;
default:
return false;
}

View file

@ -544,6 +544,10 @@ typedef struct{
float limit;
} ggml_metal_kargs_glu;
typedef struct {
uint64_t np;
} ggml_metal_kargs_sum;
typedef struct {
int64_t ne00;
int64_t ne01;
@ -773,4 +777,12 @@ typedef struct {
uint64_t nb01;
} ggml_metal_kargs_argmax;
typedef struct {
int64_t np;
} ggml_metal_kargs_opt_step_adamw;
typedef struct {
int64_t np;
} ggml_metal_kargs_opt_step_sgd;
#endif // GGML_METAL_IMPL

View file

@ -301,6 +301,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_glu(ctx, idx);
} break;
case GGML_OP_SUM:
{
n_fuse = ggml_metal_op_sum(ctx, idx);
} break;
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
{
@ -410,6 +414,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_argmax(ctx, idx);
} break;
case GGML_OP_OPT_STEP_ADAMW:
{
n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
} break;
case GGML_OP_OPT_STEP_SGD:
{
n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
} break;
default:
{
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
@ -840,6 +852,30 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
return 1;
}
// Encode a GGML_OP_SUM node: reduce all elements of src0 into a single scalar.
int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);

    // Number of input elements the kernel has to accumulate.
    ggml_metal_kargs_sum args = {
        /*.np =*/ (uint64_t) ggml_nelements(op->src[0]),
    };

    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);

    // A single threadgroup with a single thread: the kernel performs a serial reduction.
    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);

    return 1;
}
int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
@ -3401,3 +3437,73 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
return 1;
}
// Encode a GGML_OP_OPT_STEP_ADAMW node: one in-place AdamW update step over src0.
int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
    GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);

    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);

    // Total number of parameters to update in this step.
    const int64_t nel = ggml_nelements(op->src[0]);

    ggml_metal_kargs_opt_step_adamw args = {
        /*.np =*/ nel,
    };

    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);

    // Buffers in kernel order: src0..src4 bound to slots 1..5
    // (params, grads, first moment, second moment, hyperparameters).
    for (int i = 0; i < 5; ++i) {
        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[i]), 1 + i);
    }

    // One thread per element, capped by the pipeline's threadgroup limit.
    const int     nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
    const int64_t ntg = (nel + nth - 1)/nth;

    ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);

    return 1;
}
// Encode a GGML_OP_OPT_STEP_SGD node: one in-place SGD update step over src0.
int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
    GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);

    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);

    // Total number of parameters to update in this step.
    const int64_t nel = ggml_nelements(op->src[0]);

    ggml_metal_kargs_opt_step_sgd args = {
        /*.np =*/ nel,
    };

    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);

    // Buffers in kernel order: src0..src2 bound to slots 1..3
    // (params, grads, hyperparameters).
    for (int i = 0; i < 3; ++i) {
        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[i]), 1 + i);
    }

    // One thread per element, capped by the pipeline's threadgroup limit.
    const int     nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
    const int64_t ntg = (nel + nth - 1)/nth;

    ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);

    return 1;
}

View file

@ -50,6 +50,7 @@ int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
@ -78,6 +79,8 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
#ifdef __cplusplus
}

View file

@ -1723,6 +1723,24 @@ kernel void kernel_geglu_quick_f32(
}
}
// Reduce all `args.np` elements of src0 into a single scalar: dst[0] = sum(src0).
// Intentionally serial: only thread 0 of the threadgroup does any work, every
// other thread exits immediately (the host dispatches a 1x1x1 grid for this op).
kernel void kernel_op_sum_f32(
constant ggml_metal_kargs_sum & args,
device const float * src0,
device float * dst,
ushort tiitg[[thread_index_in_threadgroup]]) {
// Only the first thread accumulates; avoids any need for synchronization.
if (tiitg != 0) {
return;
}
float acc = 0.0f;
for (ulong i = 0; i < args.np; ++i) {
acc += src0[i];
}
dst[0] = acc;
}
template <bool norm>
kernel void kernel_sum_rows(
constant ggml_metal_kargs_sum_rows & args,
@ -8754,3 +8772,51 @@ kernel void kernel_pool_2d_avg_f32(
o_ptr[cur_oh * args.OW + cur_ow] = res;
}
// One AdamW optimizer step, one thread per parameter element:
//   x      - parameters (updated in place)
//   g      - gradients (read-only)
//   g_m    - first-moment estimate (updated in place)
//   g_v    - second-moment estimate (updated in place)
//   pars   - 7 hyperparameters, read below: alpha, beta1, beta2, eps, wd,
//            beta1h, beta2h (the latter two presumably precomputed bias-correction
//            factors 1/(1-beta^t) supplied by the host -- confirm against caller)
kernel void kernel_opt_step_adamw_f32(
constant ggml_metal_kargs_opt_step_adamw & args,
device float * x,
device const float * g,
device float * g_m,
device float * g_v,
device const float * pars,
uint gid[[thread_position_in_grid]]) {
// Guard the tail: the grid may be rounded up past the element count.
if (gid >= args.np) {
return;
}
const float alpha = pars[0];
const float beta1 = pars[1];
const float beta2 = pars[2];
const float eps = pars[3];
const float wd = pars[4];
const float beta1h = pars[5];
const float beta2h = pars[6];
const float gi = g[gid];
// Exponential moving averages of the gradient and squared gradient.
const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);
g_m[gid] = gmi;
g_v[gid] = gvi;
const float mh = gmi * beta1h;
const float vh = sqrt(gvi * beta2h) + eps;
// Decoupled weight decay (the (1 - alpha*wd) factor) plus the Adam update.
x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
}
// One SGD optimizer step, one thread per parameter element:
//   x    - parameters (updated in place)
//   g    - gradients (read-only)
//   pars - hyperparameters; pars[0] is presumably the learning rate and pars[1]
//          the weight decay (decoupled, as in the AdamW kernel) -- confirm
//          against the host code that fills src[2].
kernel void kernel_opt_step_sgd_f32(
constant ggml_metal_kargs_opt_step_sgd & args,
device float * x,
device const float * g,
device const float * pars,
uint gid[[thread_position_in_grid]]) {
// Guard the tail: the grid may be rounded up past the element count.
if (gid >= args.np) {
return;
}
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
}

View file

@ -0,0 +1,79 @@
#include "count-equal.hpp"
#include <cstdint>
// Count elements where x[i] == y[i] within this work-group's chunk [i0, i1)
// and atomically add the partial count into *dst.
//   dk - chunk size assigned to each work-group
//   k  - total number of elements
template <typename T>
static void count_equal(const T *__restrict__ x, const T *__restrict__ y,
int64_t *__restrict__ dst, const int64_t dk,
const int64_t k) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
// This work-group handles elements [i0, i1), strided by WARP_SIZE per lane.
const int64_t i0 = (int64_t)item_ct1.get_group(2) * dk;
const int64_t i1 = sycl::min(i0 + dk, k);
int nequal = 0;
for (int64_t i = i0 + item_ct1.get_local_id(2); i < i1; i += WARP_SIZE) {
const T xi = x[i];
const T yi = y[i];
nequal += xi == yi;
}
// Sum the per-lane counts across the sub-group; lane 0 holds the total.
nequal = warp_reduce_sum(nequal);
if (item_ct1.get_local_id(2) != 0) {
return;
}
// NOTE(review): a 32-bit atomic add into an int64 destination assumes the
// count fits in the low 32 bits (the host asserts ne < 2^30 and zero-fills
// dst) and a little-endian layout -- confirm this matches the CUDA version.
dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
(int *)dst, nequal);
}
// GGML_OP_COUNT_EQUAL for the SYCL backend: dst (a single I64 scalar) receives
// the number of positions where src0 and src1 are equal.
void ggml_sycl_count_equal(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->type == src1->type);
GGML_ASSERT( dst->type == GGML_TYPE_I64);
GGML_ASSERT(ggml_are_same_shape(src0, src1));
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(dst));
int64_t * dst_d = (int64_t *) dst->data;
dpct::queue_ptr stream = ctx.stream();
const int id = get_current_device_id();
const int nsm = ggml_sycl_info().devices[id].nsm;
const int64_t ne = ggml_nelements(src0);
// The kernel accumulates via a 32-bit atomic add, so the total must fit in int.
GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int");
// Chunk size per work-group: spread the elements over ~4 work-groups per SM,
// rounded up to a multiple of SYCL_COUNT_EQUAL_CHUNK_SIZE.
const int64_t dne =
GGML_PAD((ne + 4 * nsm - 1) / (4 * nsm), SYCL_COUNT_EQUAL_CHUNK_SIZE);
// The kernel only adds into dst, so it must start at zero.
SYCL_CHECK(CHECK_TRY_ERROR(stream->memset(dst_d, 0, ggml_nbytes(dst))));
const dpct::dim3 block_dims(WARP_SIZE, 1, 1);
const dpct::dim3 block_nums(
std::min((int64_t)4 * nsm, (ne + SYCL_COUNT_EQUAL_CHUNK_SIZE - 1) /
SYCL_COUNT_EQUAL_CHUNK_SIZE),
1, 1);
switch (src0->type) {
case GGML_TYPE_I32: {
const int *src0_d = (const int *)src0->data;
const int *src1_d = (const int *)src1->data;
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
count_equal(src0_d, src1_d, dst_d, dne, ne);
GGML_UNUSED(item_ct1);
});
} break;
default:
// Only I32 inputs are supported so far.
GGML_ASSERT(false);
break;
}
}

View file

@ -0,0 +1,9 @@
#ifndef GGML_SYCL_COUNT_EQUAL_HPP
#define GGML_SYCL_COUNT_EQUAL_HPP
#include "common.hpp"
// Granularity (in elements) used when splitting the input across work-groups.
#define SYCL_COUNT_EQUAL_CHUNK_SIZE 128
// GGML_OP_COUNT_EQUAL: dst receives the number of equal elements of its two sources.
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif //GGML_SYCL_COUNT_EQUAL_HPP

View file

@ -0,0 +1,97 @@
//
// MIT license
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//#include "common.hpp"
#include "pad.hpp"
// Pad kernel: each work-item writes one element of the padded dst tensor.
//   lpN/rpN - left/right padding on dimension N
//   neN     - dst extent on dimension N (source extent is neN - lpN - rpN)
// Elements inside the un-padded region copy from src; the rest are zero-filled.
static void pad_f32(const float * src, float * dst,
const int lp0, const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3, const int rp3,
const int ne0, const int ne1, const int ne2, const int ne3) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
// Map the 3D nd-range onto the 4D dst index: dim 0 from threads, dim 1 from
// grid axis 1, dims 2 and 3 packed into grid axis 0.
int i0 = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);
int i1 = item_ct1.get_group(1);
int i2 = item_ct1.get_group(0) % ne2;
int i3 = item_ct1.get_group(0) / ne2;
// Dim 0 is rounded up to the block size, so out-of-range items must bail out.
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}
// operation
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
// Inside the interior (non-padding) region: copy the corresponding src value.
if ((i0 >= lp0 && i0 < ne0 - rp0) &&
(i1 >= lp1 && i1 < ne1 - rp1) &&
(i2 >= lp2 && i2 < ne2 - rp2) &&
(i3 >= lp3 && i3 < ne3 - rp3)) {
const int64_t i00 = i0 - lp0;
const int64_t i01 = i1 - lp1;
const int64_t i02 = i2 - lp2;
const int64_t i03 = i3 - lp3;
// Source extents, recovered by stripping the padding from the dst extents.
const int64_t ne02 = ne2 - lp2 - rp2;
const int64_t ne01 = ne1 - lp1 - rp1;
const int64_t ne00 = ne0 - lp0 - rp0;
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) +
i02 * (ne00 * ne01) + i01 * ne00 + i00;
dst[dst_idx] = src[src_idx];
} else {
// Padding region: zero fill.
dst[dst_idx] = 0.0f;
}
}
static void pad_f32_sycl(const float *src, float *dst, const int lp0,
const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3,
const int rp3, const int ne0, const int ne1,
const int ne2, const int ne3,
dpct::queue_ptr stream) {
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
dpct::dim3 gridDim(num_blocks, ne1, ne2 * ne3);
stream->parallel_for(
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
pad_f32(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1,
ne2, ne3);
});
}
// GGML_OP_PAD for the SYCL backend: copy src0 into dst with zero padding.
// The per-dimension left/right padding amounts are carried in dst->op_params.
void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(src0));

    // op_params layout: [lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3].
    const int32_t * pads = (const int32_t *) dst->op_params;

    const float * src0_d = (const float *) src0->data;
    float       * dst_d  = (float *) dst->data;

    dpct::queue_ptr stream = ctx.stream();

    pad_f32_sycl(src0_d, dst_d,
                 pads[0], pads[1], pads[2], pads[3],
                 pads[4], pads[5], pads[6], pads[7],
                 dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
}
// Public entry point for GGML_OP_PAD: adds op-level debug tracing around the
// actual implementation in ggml_sycl_op_pad.
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_pad(ctx, dst);
}

View file

@ -0,0 +1,24 @@
//
// MIT license
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
#ifndef GGML_SYCL_PAD_HPP
#define GGML_SYCL_PAD_HPP
#include "common.hpp"
// Work-items per work-group along dim 0 of the pad kernel launch.
#define SYCL_PAD_BLOCK_SIZE 256
// GGML_OP_PAD: entry point with debug tracing, and the underlying implementation.
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_PAD_HPP

Binary file not shown.

View file

@ -50,6 +50,7 @@
"eslint-plugin-svelte": "^3.0.0",
"fflate": "^0.8.2",
"globals": "^16.0.0",
"mdast": "^3.0.0",
"mdsvex": "^0.12.3",
"playwright": "^1.53.0",
"prettier": "^3.4.2",
@ -66,6 +67,7 @@
"tw-animate-css": "^1.3.5",
"typescript": "^5.0.0",
"typescript-eslint": "^8.20.0",
"unified": "^11.0.5",
"uuid": "^13.0.0",
"vite": "^7.0.4",
"vite-plugin-devtools-json": "^0.2.0",
@ -2128,6 +2130,66 @@
"node": ">=14.0.0"
}
},
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/core": {
"version": "1.4.3",
"dev": true,
"inBundle": true,
"license": "MIT",
"optional": true,
"dependencies": {
"@emnapi/wasi-threads": "1.0.2",
"tslib": "^2.4.0"
}
},
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/runtime": {
"version": "1.4.3",
"dev": true,
"inBundle": true,
"license": "MIT",
"optional": true,
"dependencies": {
"tslib": "^2.4.0"
}
},
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/wasi-threads": {
"version": "1.0.2",
"dev": true,
"inBundle": true,
"license": "MIT",
"optional": true,
"dependencies": {
"tslib": "^2.4.0"
}
},
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@napi-rs/wasm-runtime": {
"version": "0.2.11",
"dev": true,
"inBundle": true,
"license": "MIT",
"optional": true,
"dependencies": {
"@emnapi/core": "^1.4.3",
"@emnapi/runtime": "^1.4.3",
"@tybys/wasm-util": "^0.9.0"
}
},
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@tybys/wasm-util": {
"version": "0.9.0",
"dev": true,
"inBundle": true,
"license": "MIT",
"optional": true,
"dependencies": {
"tslib": "^2.4.0"
}
},
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/tslib": {
"version": "2.8.0",
"dev": true,
"inBundle": true,
"license": "0BSD",
"optional": true
},
"node_modules/@tailwindcss/oxide-win32-arm64-msvc": {
"version": "4.1.11",
"resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-arm64-msvc/-/oxide-win32-arm64-msvc-4.1.11.tgz",
@ -4946,6 +5008,13 @@
"url": "https://github.com/sponsors/wooorm"
}
},
"node_modules/mdast": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/mdast/-/mdast-3.0.0.tgz",
"integrity": "sha512-xySmf8g4fPKMeC07jXGz971EkLbWAJ83s4US2Tj9lEdnZ142UP5grN73H1Xd3HzrdbU5o9GYYP/y8F9ZSwLE9g==",
"dev": true,
"license": "MIT"
},
"node_modules/mdast-util-find-and-replace": {
"version": "3.0.2",
"resolved": "https://registry.npmjs.org/mdast-util-find-and-replace/-/mdast-util-find-and-replace-3.0.2.tgz",

View file

@ -52,6 +52,7 @@
"eslint-plugin-svelte": "^3.0.0",
"fflate": "^0.8.2",
"globals": "^16.0.0",
"mdast": "^3.0.0",
"mdsvex": "^0.12.3",
"playwright": "^1.53.0",
"prettier": "^3.4.2",
@ -68,6 +69,7 @@
"tw-animate-css": "^1.3.5",
"typescript": "^5.0.0",
"typescript-eslint": "^8.20.0",
"unified": "^11.0.5",
"uuid": "^13.0.0",
"vite": "^7.0.4",
"vite-plugin-devtools-json": "^0.2.0",

View file

@ -7,6 +7,7 @@
ChatMessages,
ChatProcessingInfo,
EmptyFileAlertDialog,
ChatErrorDialog,
ServerErrorSplash,
ServerInfo,
ServerLoadingSplash,
@ -22,10 +23,11 @@
activeMessages,
activeConversation,
deleteConversation,
dismissErrorDialog,
errorDialog,
isLoading,
sendMessage,
stopGeneration,
setMaxContextError
stopGeneration
} from '$lib/stores/chat.svelte';
import {
supportsVision,
@ -34,7 +36,6 @@
serverWarning,
serverStore
} from '$lib/stores/server.svelte';
import { contextService } from '$lib/services';
import { parseFilesToMessageExtras } from '$lib/utils/convert-files-to-extra';
import { isFileTypeSupported } from '$lib/utils/file-type';
import { filterFilesByModalities } from '$lib/utils/modality-file-validation';
@ -79,6 +80,7 @@
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
);
let activeErrorDialog = $derived(errorDialog());
let isServerLoading = $derived(serverLoading());
async function handleDeleteConfirm() {
@ -105,6 +107,12 @@
}
}
function handleErrorDialogOpenChange(open: boolean) {
if (!open) {
dismissErrorDialog();
}
}
function handleDragOver(event: DragEvent) {
event.preventDefault();
}
@ -183,21 +191,6 @@
const extras = result?.extras;
// Check context limit using real-time slots data
const contextCheck = await contextService.checkContextLimit();
if (contextCheck && contextCheck.wouldExceed) {
const errorMessage = contextService.getContextErrorMessage(contextCheck);
setMaxContextError({
message: errorMessage,
estimatedTokens: contextCheck.currentUsage,
maxContext: contextCheck.maxContext
});
return false;
}
// Enable autoscroll for user-initiated message sending
userScrolledUp = false;
autoScrollEnabled = true;
@ -461,6 +454,13 @@
}}
/>
<ChatErrorDialog
message={activeErrorDialog?.message ?? ''}
onOpenChange={handleErrorDialogOpenChange}
open={Boolean(activeErrorDialog)}
type={activeErrorDialog?.type ?? 'server'}
/>
<style>
.conversation-chat-form {
position: relative;

View file

@ -0,0 +1,60 @@
<script lang="ts">
import * as AlertDialog from '$lib/components/ui/alert-dialog';
import { AlertTriangle, TimerOff } from '@lucide/svelte';
// Modal dialog shown when a chat request fails, either because the server
// returned an error ('server') or the request timed out ('timeout').
interface Props {
open: boolean;
type: 'timeout' | 'server';
message: string;
onOpenChange?: (open: boolean) => void;
}
let { open = $bindable(), type, message, onOpenChange }: Props = $props();
// Presentation is derived from the error type: icon, title, copy and colors.
const isTimeout = $derived(type === 'timeout');
const title = $derived(isTimeout ? 'TCP Timeout' : 'Server Error');
const description = $derived(
isTimeout
? 'The request did not receive a response from the server before timing out.'
: 'The server responded with an error message. Review the details below.'
);
const iconClass = $derived(isTimeout ? 'text-destructive' : 'text-amber-500');
const badgeClass = $derived(
isTimeout
? 'border-destructive/40 bg-destructive/10 text-destructive'
: 'border-amber-500/40 bg-amber-500/10 text-amber-600 dark:text-amber-400'
);
// Keep the bindable `open` prop in sync and notify the parent of changes.
function handleOpenChange(newOpen: boolean) {
open = newOpen;
onOpenChange?.(newOpen);
}
</script>
<AlertDialog.Root {open} onOpenChange={handleOpenChange}>
<AlertDialog.Content>
<AlertDialog.Header>
<AlertDialog.Title class="flex items-center gap-2">
<!-- Icon reflects the error type: timer for timeouts, warning triangle otherwise. -->
{#if isTimeout}
<TimerOff class={`h-5 w-5 ${iconClass}`} />
{:else}
<AlertTriangle class={`h-5 w-5 ${iconClass}`} />
{/if}
{title}
</AlertDialog.Title>
<AlertDialog.Description>
{description}
</AlertDialog.Description>
</AlertDialog.Header>
<!-- The raw error message from the server/request, highlighted by severity. -->
<div class={`rounded-lg border px-4 py-3 text-sm ${badgeClass}`}>
<p class="font-medium">{message}</p>
</div>
<AlertDialog.Footer>
<AlertDialog.Action onclick={() => handleOpenChange(false)}>Close</AlertDialog.Action>
</AlertDialog.Footer>
</AlertDialog.Content>
</AlertDialog.Root>

View file

@ -1,66 +0,0 @@
<script lang="ts">
import { AlertTriangle } from '@lucide/svelte';
import * as AlertDialog from '$lib/components/ui/alert-dialog';
import { maxContextError, clearMaxContextError } from '$lib/stores/chat.svelte';
</script>
<!--
	Alert dialog shown when the chat store reports a max-context error.
	Open state is driven entirely by the maxContextError() store value;
	dismissing the dialog (either way) clears the stored error.
-->
<AlertDialog.Root
open={maxContextError() !== null}
onOpenChange={(open) => !open && clearMaxContextError()}
>
<AlertDialog.Content>
<AlertDialog.Header>
<AlertDialog.Title class="flex items-center gap-2">
<AlertTriangle class="h-5 w-5 text-destructive" />
Message Too Long
</AlertDialog.Title>
<AlertDialog.Description>
Your message exceeds the model's context window and cannot be processed.
</AlertDialog.Description>
</AlertDialog.Header>
{#if maxContextError()}
<div class="space-y-3 text-sm">
<!-- Token usage summary taken from the store's error details. -->
<div class="rounded-lg bg-muted p-3">
<div class="mb-2 font-medium">Token Usage:</div>
<div class="space-y-1 text-muted-foreground">
<div>
Estimated tokens:
<span class="font-mono">
{maxContextError()?.estimatedTokens.toLocaleString()}
</span>
</div>
<div>
Context window:
<span class="font-mono">
{maxContextError()?.maxContext.toLocaleString()}
</span>
</div>
</div>
</div>
<!-- Actionable suggestions for getting back under the limit. -->
<div>
<div class="mb-2 font-medium">Suggestions:</div>
<ul class="list-inside list-disc space-y-1 text-muted-foreground">
<li>Shorten your message</li>
<li>Remove some file attachments</li>
<li>Start a new conversation</li>
</ul>
</div>
</div>
{/if}
<AlertDialog.Footer>
<AlertDialog.Action onclick={() => clearMaxContextError()}>Got it</AlertDialog.Action>
</AlertDialog.Footer>
</AlertDialog.Content>
</AlertDialog.Root>

View file

@ -30,12 +30,11 @@ export { default as ChatSidebar } from './chat/ChatSidebar/ChatSidebar.svelte';
export { default as ChatSidebarConversationItem } from './chat/ChatSidebar/ChatSidebarConversationItem.svelte';
export { default as ChatSidebarSearch } from './chat/ChatSidebar/ChatSidebarSearch.svelte';
export { default as ChatErrorDialog } from './dialogs/ChatErrorDialog.svelte';
export { default as EmptyFileAlertDialog } from './dialogs/EmptyFileAlertDialog.svelte';
export { default as ConversationTitleUpdateDialog } from './dialogs/ConversationTitleUpdateDialog.svelte';
export { default as MaximumContextAlertDialog } from './dialogs/MaximumContextAlertDialog.svelte';
export { default as KeyboardShortcutInfo } from './misc/KeyboardShortcutInfo.svelte';
export { default as MarkdownContent } from './misc/MarkdownContent.svelte';

View file

@ -14,6 +14,7 @@
import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
import githubLightCss from 'highlight.js/styles/github.css?inline';
import { mode } from 'mode-watcher';
import { remarkLiteralHtml } from '$lib/markdown/literal-html';
interface Props {
content: string;
@ -50,36 +51,59 @@
.use(remarkGfm) // GitHub Flavored Markdown
.use(remarkMath) // Parse $inline$ and $$block$$ math
.use(remarkBreaks) // Convert line breaks to <br>
.use(remarkRehype) // Convert to rehype (HTML AST)
.use(remarkLiteralHtml) // Treat raw HTML as literal text with preserved indentation
.use(remarkRehype) // Convert Markdown AST to rehype
.use(rehypeKatex) // Render math using KaTeX
.use(rehypeHighlight) // Add syntax highlighting
.use(rehypeStringify); // Convert to HTML string
});
function enhanceLinks(html: string): string {
if (!html.includes('<a')) {
return html;
}
const tempDiv = document.createElement('div');
tempDiv.innerHTML = html;
// Make all links open in new tabs
const linkElements = tempDiv.querySelectorAll('a[href]');
let mutated = false;
for (const link of linkElements) {
const target = link.getAttribute('target');
const rel = link.getAttribute('rel');
if (target !== '_blank' || rel !== 'noopener noreferrer') {
mutated = true;
}
link.setAttribute('target', '_blank');
link.setAttribute('rel', 'noopener noreferrer');
}
return tempDiv.innerHTML;
return mutated ? tempDiv.innerHTML : html;
}
function enhanceCodeBlocks(html: string): string {
if (!html.includes('<pre')) {
return html;
}
const tempDiv = document.createElement('div');
tempDiv.innerHTML = html;
const preElements = tempDiv.querySelectorAll('pre');
let mutated = false;
for (const [index, pre] of Array.from(preElements).entries()) {
const codeElement = pre.querySelector('code');
if (!codeElement) continue;
if (!codeElement) {
continue;
}
mutated = true;
let language = 'text';
const classList = Array.from(codeElement.classList);
@ -127,7 +151,7 @@
pre.parentNode?.replaceChild(wrapper, pre);
}
return tempDiv.innerHTML;
return mutated ? tempDiv.innerHTML : html;
}
async function processMarkdown(text: string): Promise<string> {

View file

@ -0,0 +1,15 @@
// Matches either a Unix (\n) or Windows (\r\n) line ending for splitting text.
export const LINE_BREAK = /\r?\n/;

// mdast node types whose children are phrasing content. Literal HTML found
// under one of these parents can be spliced inline; anywhere else it must be
// wrapped in its own paragraph first.
const PHRASE_PARENT_TYPES = [
	'paragraph',
	'heading',
	'emphasis',
	'strong',
	'delete',
	'link',
	'linkReference',
	'tableCell'
];

export const PHRASE_PARENTS = new Set(PHRASE_PARENT_TYPES);

// Non-breaking space used so browsers do not collapse preserved indentation.
export const NBSP = '\u00a0';

// A leading tab is rendered as four non-breaking spaces.
export const TAB_AS_SPACES = NBSP + NBSP + NBSP + NBSP;

View file

@ -0,0 +1,121 @@
import type { Plugin } from 'unified';
import { visit } from 'unist-util-visit';
import type { Break, Content, Paragraph, PhrasingContent, Root, Text } from 'mdast';
import { LINE_BREAK, NBSP, PHRASE_PARENTS, TAB_AS_SPACES } from '$lib/constants/literal-html';
/**
* remark plugin that rewrites raw HTML nodes into plain-text equivalents.
*
* remark parses inline HTML into `html` nodes even when we do not want to render
* them. We turn each of those nodes into regular text (plus `<br>` break markers)
* so the downstream rehype pipeline escapes the characters instead of executing
* them. Leading spaces and tab characters are converted to nonbreaking spaces to
* keep indentation identical to the original author input.
*/
/**
 * Converts the leading whitespace of a line into non-collapsing characters:
 * each leading space becomes a NBSP and each leading tab becomes four NBSPs.
 * The remainder of the line (from the first non-space, non-tab character on)
 * is returned untouched.
 */
function preserveIndent(line: string): string {
	const leading = line.match(/^[ \t]*/)?.[0] ?? '';
	const converted = leading.replace(/[ \t]/g, (char) => (char === '\t' ? TAB_AS_SPACES : NBSP));
	return converted + line.slice(leading.length);
}
/**
 * Turns raw HTML source text into mdast phrasing nodes: one text node per
 * input line (indentation preserved via preserveIndent), with an explicit
 * `break` node between consecutive lines. Always yields at least one text
 * node, even for empty input.
 */
function createLiteralChildren(value: string): PhrasingContent[] {
	const result: PhrasingContent[] = [];
	value.split(LINE_BREAK).forEach((rawLine, position) => {
		if (position > 0) {
			result.push({ type: 'break' } as Break as unknown as PhrasingContent);
		}
		result.push({
			type: 'text',
			value: preserveIndent(rawLine)
		} as Text as unknown as PhrasingContent);
	});
	if (result.length === 0) {
		result.push({ type: 'text', value: '' } as Text as unknown as PhrasingContent);
	}
	return result;
}
/**
 * unified plugin entry point: visits every raw `html` node in the tree and
 * replaces it with the escaped literal-text nodes from createLiteralChildren.
 * Returns adjusted visitor indices so traversal continues correctly after
 * each in-place splice.
 */
export const remarkLiteralHtml: Plugin<[], Root> = () => {
return (tree) => {
visit(tree, 'html', (node, index, parent) => {
// Without a parent and a numeric index the node cannot be spliced out.
if (!parent || typeof index !== 'number') {
return;
}
const replacement = createLiteralChildren(node.value);
if (!PHRASE_PARENTS.has(parent.type as string)) {
// Block-level context: the parent cannot hold phrasing content
// directly, so wrap the literal text in a paragraph. The
// `data.literalHtml` flag marks it for merging with neighbors below.
const paragraph: Paragraph = {
type: 'paragraph',
children: replacement as Paragraph['children'],
data: { literalHtml: true }
};
const siblings = parent.children as unknown as Content[];
siblings.splice(index, 1, paragraph as unknown as Content);
if (index > 0) {
const previous = siblings[index - 1] as Paragraph | undefined;
if (
previous?.type === 'paragraph' &&
(previous.data as { literalHtml?: boolean } | undefined)?.literalHtml
) {
// Previous sibling is also a literal-HTML paragraph: merge this
// one into it so consecutive raw-HTML nodes render as a single
// block, joined by a single break node.
const prevChildren = previous.children as unknown as PhrasingContent[];
if (prevChildren.length) {
const lastChild = prevChildren[prevChildren.length - 1];
if (lastChild.type !== 'break') {
prevChildren.push({
type: 'break'
} as Break as unknown as PhrasingContent);
}
}
prevChildren.push(...(paragraph.children as unknown as PhrasingContent[]));
// Drop the just-inserted paragraph now that its children moved.
siblings.splice(index, 1);
// Revisit the same index, which now holds the next sibling.
return index;
}
}
// Resume visiting after the newly inserted paragraph.
return index + 1;
}
// Phrasing context: splice the literal nodes inline in place of the
// html node and resume after the last inserted node.
(parent.children as unknown as PhrasingContent[]).splice(
index,
1,
...(replacement as unknown as PhrasingContent[])
);
return index + replacement.length;
});
};
};

View file

@ -13,7 +13,7 @@ import { slotsService } from './slots';
* - Manages streaming and non-streaming response parsing
* - Provides request abortion capabilities
* - Converts database messages to API format
* - Handles error translation and context detection
* - Handles error translation for server responses
*
* - **ChatStore**: Stateful orchestration and UI state management
* - Uses ChatService for all AI model communication
@ -26,7 +26,6 @@ import { slotsService } from './slots';
* - Streaming response handling with real-time callbacks
* - Reasoning content extraction and processing
* - File attachment processing (images, PDFs, audio, text)
* - Context error detection and reporting
* - Request lifecycle management (abort, cleanup)
*/
export class ChatService {
@ -209,10 +208,13 @@ export class ChatService {
userFriendlyError = new Error(
'Unable to connect to server - please check if the server is running'
);
userFriendlyError.name = 'NetworkError';
} else if (error.message.includes('ECONNREFUSED')) {
userFriendlyError = new Error('Connection refused - server may be offline');
userFriendlyError.name = 'NetworkError';
} else if (error.message.includes('ETIMEDOUT')) {
userFriendlyError = new Error('Request timeout - server may be overloaded');
userFriendlyError = new Error('Request timed out - the server took too long to respond');
userFriendlyError.name = 'TimeoutError';
} else {
userFriendlyError = error;
}
@ -262,6 +264,7 @@ export class ChatService {
let fullReasoningContent = '';
let hasReceivedData = false;
let lastTimings: ChatMessageTimings | undefined;
let streamFinished = false;
try {
let chunk = '';
@ -277,18 +280,8 @@ export class ChatService {
if (line.startsWith('data: ')) {
const data = line.slice(6);
if (data === '[DONE]') {
if (!hasReceivedData && aggregatedContent.length === 0) {
const contextError = new Error(
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
);
contextError.name = 'ContextError';
onError?.(contextError);
return;
}
onComplete?.(aggregatedContent, fullReasoningContent || undefined, lastTimings);
return;
streamFinished = true;
continue;
}
try {
@ -326,13 +319,13 @@ export class ChatService {
}
}
if (!hasReceivedData && aggregatedContent.length === 0) {
const contextError = new Error(
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
);
contextError.name = 'ContextError';
onError?.(contextError);
return;
if (streamFinished) {
if (!hasReceivedData && aggregatedContent.length === 0) {
const noResponseError = new Error('No response received from server. Please try again.');
throw noResponseError;
}
onComplete?.(aggregatedContent, fullReasoningContent || undefined, lastTimings);
}
} catch (error) {
const err = error instanceof Error ? error : new Error('Stream error');
@ -368,12 +361,8 @@ export class ChatService {
const responseText = await response.text();
if (!responseText.trim()) {
const contextError = new Error(
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
);
contextError.name = 'ContextError';
onError?.(contextError);
throw contextError;
const noResponseError = new Error('No response received from server. Please try again.');
throw noResponseError;
}
const data: ApiChatCompletionResponse = JSON.parse(responseText);
@ -385,22 +374,14 @@ export class ChatService {
}
if (!content.trim()) {
const contextError = new Error(
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
);
contextError.name = 'ContextError';
onError?.(contextError);
throw contextError;
const noResponseError = new Error('No response received from server. Please try again.');
throw noResponseError;
}
onComplete?.(content, reasoningContent);
return content;
} catch (error) {
if (error instanceof Error && error.name === 'ContextError') {
throw error;
}
const err = error instanceof Error ? error : new Error('Parse error');
onError?.(err);
@ -594,37 +575,19 @@ export class ChatService {
const errorText = await response.text();
const errorData: ApiErrorResponse = JSON.parse(errorText);
if (errorData.error?.type === 'exceed_context_size_error') {
const contextError = errorData.error as ApiContextSizeError;
const error = new Error(contextError.message);
error.name = 'ContextError';
// Attach structured context information
(
error as Error & {
contextInfo?: { promptTokens: number; maxContext: number; estimatedTokens: number };
}
).contextInfo = {
promptTokens: contextError.n_prompt_tokens,
maxContext: contextError.n_ctx,
estimatedTokens: contextError.n_prompt_tokens
};
return error;
}
// Fallback for other error types
const message = errorData.error?.message || 'Unknown server error';
return new Error(message);
const error = new Error(message);
error.name = response.status === 400 ? 'ServerError' : 'HttpError';
return error;
} catch {
// If we can't parse the error response, return a generic error
return new Error(`Server error (${response.status}): ${response.statusText}`);
const fallback = new Error(`Server error (${response.status}): ${response.statusText}`);
fallback.name = 'HttpError';
return fallback;
}
}
/**
* Updates the processing state with timing information from the server response
* @param timings - Timing data from the API response
* @param promptProgress - Progress data from the API response
*/
private updateProcessingState(
timings?: ChatMessageTimings,
promptProgress?: ChatMessagePromptProgress

View file

@ -1,102 +0,0 @@
import { slotsService } from './slots';
// Result of a context-window capacity check (see ContextService.checkContextLimit).
export interface ContextCheckResult {
// True when no room remains once the reserved response tokens are subtracted.
wouldExceed: boolean;
// Tokens currently consumed in the context window.
currentUsage: number;
// Total size of the model's context window, in tokens.
maxContext: number;
// Tokens still free for new content (clamped to never be negative).
availableTokens: number;
// Tokens held back for the model's response.
reservedTokens: number;
}
/**
* ContextService - Context window management and limit checking
*
* This service provides context window monitoring and limit checking using real-time
* server data from the slots service. It helps prevent context overflow by tracking
* current usage and calculating available space for new content.
*
* **Architecture & Relationships:**
* - **ContextService** (this class): Context limit monitoring
* - Uses SlotsService for real-time context usage data
* - Calculates available tokens with configurable reserves
* - Provides context limit checking and error messaging
* - Helps prevent context window overflow
*
* - **SlotsService**: Provides current context usage from server slots
* - **ChatStore**: Uses context checking before sending messages
* - **UI Components**: Display context usage warnings and limits
*
* **Key Features:**
* - **Real-time Context Checking**: Uses live server data for accuracy
* - **Token Reservation**: Reserves tokens for response generation
* - **Limit Detection**: Prevents context window overflow
* - **Usage Reporting**: Detailed context usage statistics
* - **Error Messaging**: User-friendly context limit messages
* - **Configurable Reserves**: Adjustable token reservation for responses
*
* **Context Management:**
* - Monitors current context usage from active slots
* - Calculates available space considering reserved tokens
* - Provides early warning before context limits are reached
* - Helps optimize conversation length and content
*/
/**
 * Monitors context-window usage via real-time slot data and reports whether
 * a configurable reserve of response tokens still fits in the window.
 */
export class ContextService {
	// Tokens held back for the model's response when computing free space.
	private reserveTokens: number;

	constructor(reserveTokens = 512) {
		this.reserveTokens = reserveTokens;
	}

	/**
	 * Checks if the context limit would be exceeded
	 *
	 * @returns {Promise<ContextCheckResult | null>} Promise that resolves to the context check result or null if an error occurs
	 */
	async checkContextLimit(): Promise<ContextCheckResult | null> {
		try {
			const state = await slotsService.getCurrentState();

			if (!state) {
				return null;
			}

			const maxContext = state.contextTotal;
			const currentUsage = state.contextUsed;
			const freeTokens = maxContext - currentUsage - this.reserveTokens;

			return {
				wouldExceed: freeTokens <= 0,
				currentUsage,
				maxContext,
				availableTokens: Math.max(0, freeTokens),
				reservedTokens: this.reserveTokens
			};
		} catch (error) {
			console.warn('Error checking context limit:', error);
			return null;
		}
	}

	/**
	 * Returns a formatted error message for context limit exceeded
	 *
	 * @param {ContextCheckResult} result - Context check result
	 * @returns {string} Formatted error message
	 */
	getContextErrorMessage(result: ContextCheckResult): string {
		const percentUsed = Math.round((result.currentUsage / result.maxContext) * 100);

		return `Context window is nearly full. Current usage: ${result.currentUsage.toLocaleString()}/${result.maxContext.toLocaleString()} tokens (${percentUsed}%). Available space: ${result.availableTokens.toLocaleString()} tokens (${result.reservedTokens} reserved for response).`;
	}

	/**
	 * Sets the number of tokens to reserve for response generation
	 *
	 * @param {number} tokens - Number of tokens to reserve
	 */
	setReserveTokens(tokens: number): void {
		this.reserveTokens = tokens;
	}
}

export const contextService = new ContextService();

View file

@ -1,3 +1,2 @@
export { chatService } from './chat';
export { contextService } from './context';
export { slotsService } from './slots';

View file

@ -39,7 +39,6 @@ import type { ExportedConversations } from '$lib/types/database';
* - Conversation branching for exploring different response paths
* - Streaming AI responses with real-time content updates
* - File attachment support (images, PDFs, text files, audio)
* - Context window management with error recovery
* - Partial response saving when generation is interrupted
* - Message editing with automatic response regeneration
*/
@ -48,11 +47,9 @@ class ChatStore {
activeMessages = $state<DatabaseMessage[]>([]);
conversations = $state<DatabaseConversation[]>([]);
currentResponse = $state('');
errorDialogState = $state<{ type: 'timeout' | 'server'; message: string } | null>(null);
isInitialized = $state(false);
isLoading = $state(false);
maxContextError = $state<{ message: string; estimatedTokens: number; maxContext: number } | null>(
null
);
titleUpdateConfirmationCallback?: (currentTitle: string, newTitle: string) => Promise<boolean>;
constructor() {
@ -69,8 +66,6 @@ class ChatStore {
try {
await this.loadConversations();
this.maxContextError = null;
this.isInitialized = true;
} catch (error) {
console.error('Failed to initialize chat store:', error);
@ -99,8 +94,6 @@ class ChatStore {
this.activeConversation = conversation;
this.activeMessages = [];
this.maxContextError = null;
await goto(`#/chat/${conversation.id}`);
return conversation.id;
@ -133,8 +126,6 @@ class ChatStore {
this.activeMessages = await DatabaseStore.getConversationMessages(convId);
}
this.maxContextError = null;
return true;
} catch (error) {
console.error('Failed to load conversation:', error);
@ -418,56 +409,6 @@ class ChatStore {
return;
}
if (error.name === 'ContextError') {
console.warn('Context error detected:', error.message);
this.isLoading = false;
this.currentResponse = '';
const messageIndex = this.activeMessages.findIndex(
(m: DatabaseMessage) => m.id === assistantMessage.id
);
if (messageIndex !== -1) {
this.activeMessages.splice(messageIndex, 1);
DatabaseStore.deleteMessage(assistantMessage.id).catch(console.error);
}
// Use structured context info from new exceed_context_size_error format if available
const contextInfo = (
error as Error & {
contextInfo?: { promptTokens: number; maxContext: number; estimatedTokens: number };
}
).contextInfo;
let estimatedTokens = 0;
let maxContext = serverStore.serverProps?.default_generation_settings.n_ctx || 8192;
if (contextInfo) {
// Use precise token counts from server response
estimatedTokens = contextInfo.promptTokens;
maxContext = contextInfo.maxContext;
} else {
// Fallback to estimation for older error format
try {
// Rough estimation: ~4 characters per token
const messageContent = JSON.stringify(messages);
estimatedTokens = Math.ceil(messageContent.length / 4);
} catch {
estimatedTokens = 0;
}
}
this.maxContextError = {
message: error.message,
estimatedTokens,
maxContext
};
if (onError) {
onError(error);
}
return;
}
console.error('Streaming error:', error);
this.isLoading = false;
this.currentResponse = '';
@ -477,9 +418,19 @@ class ChatStore {
);
if (messageIndex !== -1) {
this.activeMessages[messageIndex].content = `Error: ${error.message}`;
const [failedMessage] = this.activeMessages.splice(messageIndex, 1);
if (failedMessage) {
DatabaseStore.deleteMessage(failedMessage.id).catch((cleanupError) => {
console.error('Failed to remove assistant message after error:', cleanupError);
});
}
}
const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server';
this.showErrorDialog(dialogType, error.message);
if (onError) {
onError(error);
}
@ -487,6 +438,14 @@ class ChatStore {
});
}
private showErrorDialog(type: 'timeout' | 'server', message: string): void {
this.errorDialogState = { type, message };
}
dismissErrorDialog(): void {
this.errorDialogState = null;
}
/**
* Checks if an error is an abort error (user cancelled operation)
* @param error - The error to check
@ -574,6 +533,7 @@ class ChatStore {
return;
}
this.errorDialogState = null;
this.isLoading = true;
this.currentResponse = '';
@ -603,37 +563,23 @@ class ChatStore {
const conversationContext = this.activeMessages.slice(0, -1);
await this.streamChatCompletion(
conversationContext,
assistantMessage,
undefined,
(error: Error) => {
if (error.name === 'ContextError' && userMessage) {
const userMessageIndex = this.findMessageIndex(userMessage.id);
if (userMessageIndex !== -1) {
this.activeMessages.splice(userMessageIndex, 1);
DatabaseStore.deleteMessage(userMessage.id).catch(console.error);
}
}
}
);
await this.streamChatCompletion(conversationContext, assistantMessage);
} catch (error) {
if (this.isAbortError(error)) {
this.isLoading = false;
return;
}
if (error instanceof Error && error.name === 'ContextError' && userMessage) {
const userMessageIndex = this.findMessageIndex(userMessage.id);
if (userMessageIndex !== -1) {
this.activeMessages.splice(userMessageIndex, 1);
DatabaseStore.deleteMessage(userMessage.id).catch(console.error);
}
}
console.error('Failed to send message:', error);
this.isLoading = false;
if (!this.errorDialogState) {
if (error instanceof Error) {
const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server';
this.showErrorDialog(dialogType, error.message);
} else {
this.showErrorDialog('server', 'Unknown error occurred while sending message');
}
}
}
}
@ -662,24 +608,6 @@ class ChatStore {
this.currentResponse = '';
}
/**
* Clears the max context error state
* Removes any displayed context limit warnings
*/
clearMaxContextError(): void {
this.maxContextError = null;
}
/**
* Sets the max context error state
* @param error - The context error details or null to clear
*/
setMaxContextError(
error: { message: string; estimatedTokens: number; maxContext: number } | null
): void {
this.maxContextError = error;
}
/**
* Saves partial response if generation was interrupted
* Preserves user's partial content and timing data when generation is stopped early
@ -1250,7 +1178,6 @@ class ChatStore {
this.activeMessages = [];
this.currentResponse = '';
this.isLoading = false;
this.maxContextError = null;
}
/** Refreshes active messages based on currNode after branch navigation */
@ -1538,6 +1465,7 @@ class ChatStore {
private async generateResponseForMessage(userMessageId: string): Promise<void> {
if (!this.activeConversation) return;
this.errorDialogState = null;
this.isLoading = true;
this.currentResponse = '';
@ -1584,7 +1512,7 @@ export const activeMessages = () => chatStore.activeMessages;
export const isLoading = () => chatStore.isLoading;
export const currentResponse = () => chatStore.currentResponse;
export const isInitialized = () => chatStore.isInitialized;
export const maxContextError = () => chatStore.maxContextError;
export const errorDialog = () => chatStore.errorDialogState;
export const createConversation = chatStore.createConversation.bind(chatStore);
export const downloadConversation = chatStore.downloadConversation.bind(chatStore);
@ -1592,9 +1520,9 @@ export const exportAllConversations = chatStore.exportAllConversations.bind(chat
export const importConversations = chatStore.importConversations.bind(chatStore);
export const deleteConversation = chatStore.deleteConversation.bind(chatStore);
export const sendMessage = chatStore.sendMessage.bind(chatStore);
export const dismissErrorDialog = chatStore.dismissErrorDialog.bind(chatStore);
export const gracefulStop = chatStore.gracefulStop.bind(chatStore);
export const clearMaxContextError = chatStore.clearMaxContextError.bind(chatStore);
export const setMaxContextError = chatStore.setMaxContextError.bind(chatStore);
// Branching operations
export const refreshActiveMessages = chatStore.refreshActiveMessages.bind(chatStore);

View file

@ -197,7 +197,7 @@ class ServerStore {
errorMessage = 'Server not found - check server address';
isOfflineLikeError = true;
} else if (error.message.includes('ETIMEDOUT')) {
errorMessage = 'Connection timeout - server may be overloaded';
errorMessage = 'Request timed out - the server took too long to respond';
isOfflineLikeError = true;
} else if (error.message.includes('503')) {
errorMessage = 'Server temporarily unavailable - try again shortly';

View file

@ -1,11 +1,7 @@
<script lang="ts">
import '../app.css';
import { page } from '$app/state';
import {
ChatSidebar,
ConversationTitleUpdateDialog,
MaximumContextAlertDialog
} from '$lib/components/app';
import { ChatSidebar, ConversationTitleUpdateDialog } from '$lib/components/app';
import {
activeMessages,
isLoading,
@ -145,8 +141,6 @@
<Toaster richColors />
<MaximumContextAlertDialog />
<ConversationTitleUpdateDialog
bind:open={titleUpdateDialogOpen}
currentTitle={titleUpdateCurrentTitle}