diff --git a/CMakeLists.txt b/CMakeLists.txt index 2f919bc72..5bd8f3816 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -453,6 +453,8 @@ add_library(common2 tools/mtmd/clip.h src/unicode.h src/unicode.cpp + common/unicode.h + common/unicode.cpp src/llama-impl.h src/llama-impl.cpp src/unicode-data.cpp diff --git a/Makefile b/Makefile index 59189728a..52de513c9 100644 --- a/Makefile +++ b/Makefile @@ -110,10 +110,10 @@ endif CUBLASLD_FLAGS = CUBLAS_OBJS = -OBJS_FULL += ggml-alloc.o ggml-cpu-traits.o ggml-quants.o ggml-cpu-quants.o kcpp-quantmapper.o kcpp-repackmapper.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm.o common.o llama-impl.o sampling.o kcpputils.o mtmdaudio.o -OBJS_SIMPLE += ggml-alloc.o ggml-cpu-traits.o ggml-quants_noavx2.o ggml-cpu-quants.o kcpp-quantmapper_noavx2.o kcpp-repackmapper_noavx2.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_noavx2.o common.o llama-impl.o sampling.o kcpputils.o mtmdaudio.o -OBJS_SIMPLER += ggml-alloc.o ggml-cpu-traits.o ggml-quants_noavx1.o ggml-cpu-quants.o kcpp-quantmapper_noavx1.o kcpp-repackmapper_noavx1.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_noavx1.o common.o llama-impl.o sampling.o kcpputils.o mtmdaudio.o -OBJS_FAILSAFE += ggml-alloc.o ggml-cpu-traits.o ggml-quants_failsafe.o ggml-cpu-quants.o kcpp-quantmapper_failsafe.o kcpp-repackmapper_failsafe.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_failsafe.o common.o llama-impl.o sampling.o kcpputils.o mtmdaudio.o +OBJS_FULL += ggml-alloc.o ggml-cpu-traits.o ggml-quants.o ggml-cpu-quants.o kcpp-quantmapper.o kcpp-repackmapper.o unicode.o unicode-common.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm.o common.o llama-impl.o sampling.o kcpputils.o mtmdaudio.o +OBJS_SIMPLE += ggml-alloc.o ggml-cpu-traits.o ggml-quants_noavx2.o ggml-cpu-quants.o kcpp-quantmapper_noavx2.o kcpp-repackmapper_noavx2.o unicode.o unicode-common.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_noavx2.o common.o llama-impl.o sampling.o kcpputils.o mtmdaudio.o +OBJS_SIMPLER += ggml-alloc.o ggml-cpu-traits.o ggml-quants_noavx1.o ggml-cpu-quants.o kcpp-quantmapper_noavx1.o kcpp-repackmapper_noavx1.o unicode.o unicode-common.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_noavx1.o common.o llama-impl.o sampling.o kcpputils.o mtmdaudio.o +OBJS_FAILSAFE += ggml-alloc.o ggml-cpu-traits.o ggml-quants_failsafe.o ggml-cpu-quants.o kcpp-quantmapper_failsafe.o kcpp-repackmapper_failsafe.o unicode.o unicode-common.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_failsafe.o common.o llama-impl.o sampling.o kcpputils.o mtmdaudio.o # OS specific ifeq ($(UNAME_S),Linux) @@ -580,6 +580,8 @@ ggml-alloc.o: ggml/src/ggml-alloc.c ggml/include/ggml.h ggml/include/ggml-alloc. $(CC) $(CFLAGS) -c $< -o $@ llava.o: tools/mtmd/llava.cpp tools/mtmd/llava.h $(CXX) $(CXXFLAGS) -c $< -o $@ +unicode-common.o: common/unicode.cpp common/unicode.h + $(CXX) $(CXXFLAGS) -c $< -o $@ unicode.o: src/unicode.cpp src/unicode.h $(CXX) $(CXXFLAGS) -c $< -o $@ unicode-data.o: src/unicode-data.cpp src/unicode-data.h diff --git a/common/chat.cpp b/common/chat.cpp index b30905b9a..13fa4a24d 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -510,7 +510,7 @@ json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) { } #include "chat-parser.cpp" -#include "common/unicode.cpp" +#include "common/unicode.h" #include "peg-parser.cpp" #include "chat-peg-parser.cpp" diff --git a/common/common.cpp b/common/common.cpp index e5f45957d..91057041b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1,7 +1,3 @@ -#if defined(_MSC_VER) -#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING -#endif - #include "ggml.h" #include "gguf.h" @@ -16,12 +12,12 @@ #include "llama.h" #include "sampling.h" #include "ggml/src/ggml-opt.cpp" //dear god pls +#include "unicode.h" #include #include #include #include -#include #include #include #include @@ -713,45 +709,28 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) { return false; } - std::u32string filename_utf32; - try { -#if defined(__clang__) - // disable C++17 deprecation warning for std::codecvt_utf8 -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wdeprecated-declarations" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wdeprecated-declarations" -#endif + size_t offset = 0; + while (offset < filename.size()) { + utf8_parse_result result = parse_utf8_codepoint(filename, offset); - std::wstring_convert, char32_t> converter; - -#if defined(__clang__) -# pragma clang diagnostic pop -#elif defined(__GNUC__) -# pragma GCC diagnostic pop -#endif - - filename_utf32 = converter.from_bytes(filename); - - // If the reverse conversion mismatches, it means overlong UTF-8 sequences were used, - // or invalid encodings were encountered. Reject such attempts - std::string filename_reencoded = converter.to_bytes(filename_utf32); - if (filename_reencoded != filename) { + if (result.status != utf8_parse_result::SUCCESS) { return false; } - } catch (const std::exception &) { - return false; - } + uint32_t c = result.codepoint; - // Check for forbidden codepoints: - // - Control characters - // - Unicode equivalents of illegal characters - // - UTF-16 surrogate pairs - // - UTF-8 replacement character - // - Byte order mark (BOM) - // - Illegal characters: / \ : * ? " < > | - for (char32_t c : filename_utf32) { + if ((result.bytes_consumed == 2 && c < 0x80) || + (result.bytes_consumed == 3 && c < 0x800) || + (result.bytes_consumed == 4 && c < 0x10000)) { + return false; + } + + // Check for forbidden codepoints: + // - Control characters + // - Unicode equivalents of illegal characters + // - UTF-16 surrogate pairs + // - UTF-8 replacement character + // - Byte order mark (BOM) + // - Illegal characters: / \ : * ? " < > | if (c <= 0x1F // Control characters (C0) || c == 0x7F // Control characters (DEL) || (c >= 0x80 && c <= 0x9F) // Control characters (C1) @@ -759,6 +738,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) { || c == 0x2215 // Division Slash (forward slash equivalent) || c == 0x2216 // Set Minus (backslash equivalent) || (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs + || c > 0x10FFFF // Max Unicode limit || c == 0xFFFD // Replacement Character (UTF-8) || c == 0xFEFF // Byte Order Mark (BOM) || c == ':' || c == '*' // Illegal characters @@ -769,6 +749,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) { // Subdirectories not allowed, reject path separators return false; } + offset += result.bytes_consumed; } // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename @@ -1476,66 +1457,6 @@ void common_batch_add( batch.n_tokens++; } -// -// Token utils -// - -size_t common_lcp(const llama_tokens & a, const llama_tokens & b) { - size_t i; - for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} - - return i; -} - -size_t common_lcs(const llama_tokens & a, const llama_tokens & b) { - // check for empty sequences - if (a.empty() || b.empty()) { - return 0; - } - - // get the lengths of the input sequences - size_t a_len = a.size(); - size_t b_len = b.size(); - - // initialize the maximum length of the longest common subsequence (LCS) - size_t max_length = 0; - - // use two rows instead of a 2D matrix to optimize space - std::vector prev_row(b_len + 1, 0); - std::vector curr_row(b_len + 1, 0); - - // iterate through the elements of a - for (size_t i = 1; i <= a_len; i++) { - // iterate through the elements of b - for (size_t j = 1; j <= b_len; j++) { - // if elements at the current positions match - if (a[i - 1] == b[j - 1]) { - // if it's the first element of either sequences, set LCS length to 1 - if (i == 1 || j == 1) { - curr_row[j] = 1; - } else { - // increment LCS length by 1 compared to the previous element - curr_row[j] = prev_row[j - 1] + 1; - } - - // update max_length if necessary - if (curr_row[j] > max_length) { - max_length = curr_row[j]; - } - } else { - // reset LCS length if elements don't match - curr_row[j] = 0; - } - } - - // update the previous row for the next iteration - prev_row = curr_row; - } - - // return the maximum length of the LCS - return max_length; -} - // // Vocab utils // diff --git a/common/common.h b/common/common.h index f5c273d37..285def17a 100644 --- a/common/common.h +++ b/common/common.h @@ -776,16 +776,6 @@ void common_batch_add( const std::vector & seq_ids, bool logits); -// -// Token utils -// - -// longest common prefix -size_t common_lcp(const llama_tokens & a, const llama_tokens & b); - -// longet common subsequence -size_t common_lcs(const llama_tokens & a, const llama_tokens & b); - // // Vocab utils // diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index ed4535020..4352e1328 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -2096,10 +2096,14 @@ static void ggml_compute_forward_gelu_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2113,10 +2117,14 @@ static void ggml_compute_forward_gelu_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2135,10 +2143,14 @@ static void ggml_compute_forward_gelu_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2152,10 +2164,14 @@ static void ggml_compute_forward_gelu_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2276,10 +2292,14 @@ static void ggml_compute_forward_gelu_erf_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2293,10 +2313,14 @@ static void ggml_compute_forward_gelu_erf_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_erf_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2315,10 +2339,14 @@ static void ggml_compute_forward_gelu_erf_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2332,10 +2360,14 @@ static void ggml_compute_forward_gelu_erf_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_erf_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2379,10 +2411,14 @@ static void ggml_compute_forward_gelu_quick_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2396,10 +2432,14 @@ static void ggml_compute_forward_gelu_quick_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_quick_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2418,10 +2458,14 @@ static void ggml_compute_forward_gelu_quick_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2435,10 +2479,14 @@ static void ggml_compute_forward_gelu_quick_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_quick_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2482,10 +2530,14 @@ static void ggml_compute_forward_silu_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2499,10 +2551,14 @@ static void ggml_compute_forward_silu_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_silu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2521,10 +2577,14 @@ static void ggml_compute_forward_silu_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2538,10 +2598,14 @@ static void ggml_compute_forward_silu_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_silu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index 1d9873ad0..1d8344436 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -111,7 +111,7 @@ template static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous_rows(src0) && ggml_is_contiguous_rows(dst) && ggml_are_same_shape(src0, dst)); GGML_TENSOR_UNARY_OP_LOCALS diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index e84558ee6..e61447f00 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -63,11 +63,19 @@ static __global__ void flash_attn_ext_f16( constexpr int frag_m = ncols == 8 ? 32 : 16; constexpr int frag_n = ncols == 8 ? 8 : 16; static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); +#if defined(GGML_USE_HIP) + typedef wmma::fragment frag_a_K; + typedef wmma::fragment frag_a_V; + typedef wmma::fragment frag_b; + typedef wmma::fragment frag_c_KQ; + typedef wmma::fragment frag_c_VKQ; +#else typedef wmma::fragment frag_a_K; typedef wmma::fragment frag_a_V; typedef wmma::fragment frag_b; typedef wmma::fragment frag_c_KQ; typedef wmma::fragment frag_c_VKQ; +#endif constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. @@ -126,6 +134,19 @@ static __global__ void flash_attn_ext_f16( __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; + +#if defined(GGML_USE_HIP) + const _Float16 * K_h_f16 = reinterpret_cast(K_h); + const _Float16 * V_h_f16 = reinterpret_cast(V_h); + _Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ); + _Float16 * VKQ_f16 = reinterpret_cast<_Float16 *>(VKQ); +#else + const half * K_h_f16 = K_h; + const half * V_h_f16 = V_h; + half * KQ_f16 = KQ; + half * VKQ_f16 = VKQ; +#endif + #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -160,7 +181,7 @@ static __global__ void flash_attn_ext_f16( for (int i0 = 0; i0 < D; i0 += 16) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded); } } @@ -180,7 +201,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { frag_a_K K_a; - wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); + wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); @@ -310,7 +331,7 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; wmma::load_matrix_sync( KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], - KQ + j0*(kqar*kqs_padded) + k, + KQ_f16 + j0*(kqar*kqs_padded) + k, kqar*kqs_padded); } } @@ -328,7 +349,7 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; frag_a_V v_a; - wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); + wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); @@ -344,7 +365,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { wmma::store_matrix_sync( - KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], D_padded, wmma::mem_col_major); } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 517559d12..06f3d8045 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -328,31 +328,46 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_l } ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { - GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); char base[256]; char name[256]; - const char * op_str = "undefined"; + int op_num = -1; + switch (op->op) { - case GGML_OP_SUM_ROWS: - op_str = "sum_rows"; break; - case GGML_OP_MEAN: - op_str = "mean"; break; + case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break; + case GGML_OP_MEAN: op_num = OP_SUM_ROWS_NUM_MEAN; break; default: GGML_ABORT("fatal error"); }; - snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t_str = ggml_type_name(op->type); - snprintf(name, 256, "%s", base); + const bool is_c4 = op->src[0]->ne[0] % 4 == 0; + + snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : ""); + snprintf(name, 256, "%s_op=%d", base, op_num); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } res.smem = 32*sizeof(float); + if (is_c4) { + res.smem *= 4; + } + + res.c4 = is_c4; + return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 0c2d75b65..66360c5b1 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1025,7 +1025,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_LOG: - return ggml_is_contiguous_rows(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_TANH: @@ -1045,7 +1045,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_EXPM1: - return ggml_is_contiguous_rows(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); default: return false; } diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 952e1be07..383e0d6e9 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -82,6 +82,7 @@ #define FC_COUNT_EQUAL 1100 #define FC_UNARY 1200 #define FC_BIN 1300 +#define FC_SUM_ROWS 1400 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPSG 8 @@ -118,6 +119,8 @@ #define OP_UNARY_NUM_SOFTPLUS 115 #define OP_UNARY_NUM_EXPM1 116 +#define OP_SUM_ROWS_NUM_SUM_ROWS 10 +#define OP_SUM_ROWS_NUM_MEAN 11 // kernel argument structs // diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 7db95d1c8..20880d955 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -904,6 +904,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + ggml_metal_kargs_sum_rows args = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, @@ -925,21 +930,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + int nth = 32; // SIMD width - while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nth *= 2; } nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - nth = std::min(nth, ne00); + nth = std::min(nth, (int) args.ne00); const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index a385a50b9..6c349aa0c 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -77,6 +77,14 @@ static inline float dot(float x, float y) { return x*y; } +static inline float sum(float x) { + return x; +} + +static inline float sum(float4 x) { + return x[0] + x[1] + x[2] + x[3]; +} + // NOTE: this is not dequantizing - we are simply fitting the template template void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { @@ -910,7 +918,7 @@ constant float a4_erf = -1.453152027f; constant float a5_erf = 1.061405429f; template -T erf_approx(T x) { +inline T erf_approx(T x) { T sign_x = sign(x); x = fabs(x); T t = 1.0f / (1.0f + p_erf * x); @@ -918,10 +926,27 @@ T erf_approx(T x) { return sign_x * y; } +template T elu_approx(T x); + +template<> inline float elu_approx(float x) { + return (x > 0.f) ? x : (exp(x) - 1); +} + +template<> inline float4 elu_approx(float4 x) { + float4 res; + + res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f); + res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f); + res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f); + res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f); + + return res; +} + constant short FC_unary_op [[function_constant(FC_UNARY + 0)]]; constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]]; -template +template kernel void kernel_unary_impl( constant ggml_metal_kargs_unary & args, device const char * src0, @@ -963,111 +988,111 @@ kernel void kernel_unary_impl( } } - device const T0 & x = src0_ptr[i0]; + const TC x = (TC) src0_ptr[i0]; if (FC_OP == OP_UNARY_NUM_SCALE) { - dst_ptr[i0] = args.scale * x + args.bias; + dst_ptr[i0] = (T) (args.scale * x + args.bias); } if (FC_OP == OP_UNARY_NUM_FILL) { - dst_ptr[i0] = args.val; + dst_ptr[i0] = (T) args.val; } if (FC_OP == OP_UNARY_NUM_CLAMP) { - dst_ptr[i0] = clamp(x, args.min, args.max); + dst_ptr[i0] = (T) clamp(x, args.min, args.max); } if (FC_OP == OP_UNARY_NUM_SQR) { - dst_ptr[i0] = x * x; + dst_ptr[i0] = (T) (x * x); } if (FC_OP == OP_UNARY_NUM_SQRT) { - dst_ptr[i0] = sqrt(x); + dst_ptr[i0] = (T) sqrt(x); } if (FC_OP == OP_UNARY_NUM_SIN) { - dst_ptr[i0] = sin(x); + dst_ptr[i0] = (T) sin(x); } if (FC_OP == OP_UNARY_NUM_COS) { - dst_ptr[i0] = cos(x); + dst_ptr[i0] = (T) cos(x); } if (FC_OP == OP_UNARY_NUM_LOG) { - dst_ptr[i0] = log(x); + dst_ptr[i0] = (T) log(x); } if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) { - dst_ptr[i0] = T(x > 0.0f)*x + T(x <= 0.0f)*(x * args.slope); + dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope)); } if (FC_OP == OP_UNARY_NUM_TANH) { - dst_ptr[i0] = precise::tanh(x); + dst_ptr[i0] = (T) precise::tanh(x); } if (FC_OP == OP_UNARY_NUM_RELU) { - dst_ptr[i0] = fmax(0.0f, x); + dst_ptr[i0] = (T) fmax(0, x); } if (FC_OP == OP_UNARY_NUM_SIGMOID) { - dst_ptr[i0] = 1.0f / (1.0f + exp(-x)); + dst_ptr[i0] = (T) (1 / (1 + exp(-x))); } if (FC_OP == OP_UNARY_NUM_GELU) { - dst_ptr[i0] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); + dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x)))); } if (FC_OP == OP_UNARY_NUM_GELU_ERF) { - dst_ptr[i0] = 0.5f*x*(1.0f + erf_approx(SQRT_2_INV*x)); + dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x))); } if (FC_OP == OP_UNARY_NUM_GELU_QUICK) { - dst_ptr[i0] = x * (1.0f/(1.0f + exp(GELU_QUICK_COEF*x))); + dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x)))); } if (FC_OP == OP_UNARY_NUM_SILU) { - dst_ptr[i0] = x / (1.0f + exp(-x)); + dst_ptr[i0] = (T) (x / (1 + exp(-x))); } if (FC_OP == OP_UNARY_NUM_ELU) { - dst_ptr[i0] = T(x > 0.0f)*x + T(x <= 0.0f)*(exp(x) - 1.0f); + dst_ptr[i0] = (T) elu_approx(x); } if (FC_OP == OP_UNARY_NUM_NEG) { - dst_ptr[i0] = -x; + dst_ptr[i0] = (T) -x; } if (FC_OP == OP_UNARY_NUM_ABS) { - dst_ptr[i0] = fabs(x); + dst_ptr[i0] = (T) fabs(x); } if (FC_OP == OP_UNARY_NUM_SGN) { - dst_ptr[i0] = T(x > 0.0f) - T(x < 0.0f); + dst_ptr[i0] = T(x > 0) - T(x < 0); } if (FC_OP == OP_UNARY_NUM_STEP) { - dst_ptr[i0] = T(x > 0.0f); + dst_ptr[i0] = T(x > 0); } if (FC_OP == OP_UNARY_NUM_HARDSWISH) { - dst_ptr[i0] = x * fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f)); + dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5))); } if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) { - dst_ptr[i0] = fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f)); + dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5)); } if (FC_OP == OP_UNARY_NUM_EXP) { - dst_ptr[i0] = exp(x); + dst_ptr[i0] = (T) exp(x); } if (FC_OP == OP_UNARY_NUM_SOFTPLUS) { - dst_ptr[i0] = select(log(1.0f + exp(x)), x, x > 20.0f); + dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20); } if (FC_OP == OP_UNARY_NUM_EXPM1) { // TODO: precise implementation - dst_ptr[i0] = exp(x) - 1.0f; + dst_ptr[i0] = (T) (exp(x) - 1); } } @@ -1075,11 +1100,12 @@ kernel void kernel_unary_impl( #undef FC_CNT } -typedef decltype(kernel_unary_impl) kernel_unary_t; - -template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl; -template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl; +typedef decltype(kernel_unary_impl) kernel_unary_t; +template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl; +template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl; +template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl; +template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl; // OP: 0 - add, 1 - sub, 2 - mul, 3 - div constant short FC_bin_op [[function_constant(FC_BIN + 0)]]; @@ -1483,33 +1509,35 @@ kernel void kernel_op_sum_f32( } } -template -kernel void kernel_sum_rows( +constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]]; + +template +kernel void kernel_sum_rows_impl( constant ggml_metal_kargs_sum_rows & args, - device const float * src0, - device float * dst, - threadgroup float * shmem_f32 [[threadgroup(0)]], + device const char * src0, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - int64_t i3 = tgpig.z; - int64_t i2 = tgpig.y; - int64_t i1 = tgpig.x; +#define FC_OP FC_sum_rows_op - if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { - return; - } + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; + + threadgroup T0 * shmem_t = (threadgroup T0 *) shmem; if (sgitg == 0) { - shmem_f32[tiisg] = 0.0f; + shmem_t[tiisg] = 0.0f; } - device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); - device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); + device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); - float sumf = 0; + T0 sumf = T0(0.0f); for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { sumf += src_row[i0]; @@ -1520,23 +1548,33 @@ kernel void kernel_sum_rows( threadgroup_barrier(mem_flags::mem_threadgroup); if (tiisg == 0) { - shmem_f32[sgitg] = sumf; + shmem_t[sgitg] = sumf; } threadgroup_barrier(mem_flags::mem_threadgroup); - sumf = shmem_f32[tiisg]; + sumf = shmem_t[tiisg]; sumf = simd_sum(sumf); if (tpitg.x == 0) { - dst_row[0] = norm ? sumf / args.ne00 : sumf; + if (FC_OP == OP_SUM_ROWS_NUM_MEAN) { + if (is_same::value) { + dst_row[0] = sum(sumf) / (4*args.ne00); + } else { + dst_row[0] = sum(sumf) / args.ne00; + } + } else { + dst_row[0] = sum(sumf); + } } + +#undef FC_OP } -typedef decltype(kernel_sum_rows) kernel_sum_rows_t; +typedef decltype(kernel_sum_rows_impl) kernel_sum_rows_t; -template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; -template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; +template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl; +template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl; template kernel void kernel_cumsum_blk( @@ -2417,9 +2455,6 @@ kernel void kernel_solve_tri_f32( const short K = FC_solve_tri_k; const short NP = PAD2(N, NW); - const int32_t ne02 = args.ne02; - const int32_t ne03 = args.ne03; - const int32_t i03 = tgpig.z; const int32_t i02 = tgpig.y; const int32_t i01 = tgpig.x*NSG + sgitg; @@ -5931,7 +5966,7 @@ kernel void kernel_flash_attn_ext_vec( static_assert(DK4 % NL == 0, "DK4 must be divisible by NL"); static_assert(DV4 % NL == 0, "DV4 must be divisible by NL"); - const short T = PK + NSG*SH; // shared memory size per query in (half) + //const short T = PK + NSG*SH; // shared memory size per query in (half) //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t @@ -8519,7 +8554,9 @@ kernel void kernel_mul_mm( threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); +#ifdef GGML_METAL_HAS_TENSOR threadgroup float * sc = (threadgroup float *)(shmem); +#endif constexpr int NR0 = 64; constexpr int NR1 = 32; @@ -8642,8 +8679,8 @@ kernel void kernel_mul_mm( const short sx = (tiitg%NL1); const short sy = (tiitg/NL1)/8; - const short dx = sx; - const short dy = sy; + //const short dx = sx; + //const short dy = sy; const short ly = (tiitg/NL1)%8; @@ -8892,7 +8929,9 @@ kernel void kernel_mul_mm_id( threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); +#ifdef GGML_METAL_HAS_TENSOR threadgroup float * sc = (threadgroup float *)(shmem); +#endif constexpr int NR0 = 64; constexpr int NR1 = 32; @@ -9027,8 +9066,8 @@ kernel void kernel_mul_mm_id( const short sx = (tiitg%NL1); const short sy = (tiitg/NL1)/8; - const short dx = sx; - const short dy = sy; + //const short dx = sx; + //const short dy = sy; const short ly = (tiitg/NL1)%8; diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl new file mode 100644 index 000000000..3602c92fe --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl @@ -0,0 +1,158 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 2 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q6_k_f32_l4_lm( + global uchar * src0_ql, + global uchar * src0_qh, + global char * src0_s, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + + int ib = idx / 128; // 2 values per idx + int iqs = idx % 128; // 0..127 + + int n = iqs / 64; // 0,1 + int b = (iqs % 64) / 32; // 0,1 + int is_b = (iqs % 16) / 8; // 0,1 + int qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + int is = 8 * n + qhshift + is_b; // 0..15 + int qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 + int qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + + float dscale = (float)src0_d[ib] * (float)src0_s[ib*16 + is]; + + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 0] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 0] >> qhshift) & 3) << 4)) - 32); + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 1] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 1] >> qhshift) & 3) << 4)) - 32); + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl new file mode 100644 index 000000000..71ab98982 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl @@ -0,0 +1,180 @@ +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// block_q4_K +//------------------------------------------------------------------------------ +#define QK_K 256 +#define K_SCALE_SIZE 12 + +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + + uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uchar qs[QK_K/2]; // 4-bit quants +} block_q4_K; + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // number of rows each SIMD group works on +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // SIMD group size +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +#undef BLOCK_STRIDE +// number of (super) blocks each subgroup processes +// each thread in a subgroup processes a block (32 weights) +#define BLOCK_STRIDE (N_SIMDWIDTH/8) + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_K_f32( + global char * src0, + int offset0, + global char * src1, + int offset1, + global char * dst, + int offsetd, + int ne00, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + ushort kmask1 = 0x3f3f; + ushort kmask2 = 0x0f0f; + ushort kmask3 = 0xc0c0; + + int ix = get_sub_group_local_id()/8; // super block index + int it = get_sub_group_local_id()%8; // block index (inside super block) + int iq = it/4; // 0 or 1 - first or second half of the super block + int ir = it%4; // 0...3 - block index in the half super block + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + int offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global block_q4_K * x = (global block_q4_K *) (src0 + offset_src0); + global float * y = (global float *) (src1 + offset_src1); + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f}; + float all_sum; + + global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + ushort sc16[4]; + uchar * sc8 = (uchar *)sc16; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+0]; + sumy.s0 += yl[i+0]; + + yl[i+8] = y4[i+32]; + sumy.s1 += yl[i+8]; + + yh[i+0] = y4[i+128]; + sumy.s2 += yh[i+0]; + + yh[i+8] = y4[i+160]; + sumy.s3 += yh[i+8]; + } + + global ushort * sc = (global ushort *)x[ib].scales + iq; + global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir; + global half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + global ushort * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F); + acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00); + acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0); + acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000); + acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F); + acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00); + acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0); + acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] + + (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f + + (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] + + (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) - + dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]); + + q1 += nb01/2; + sc += nb01/2; + dh += nb01/2; + } + + y4 += BLOCK_STRIDE * QK_K; + } + + global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0; + + for (int row = 0; row < N_DST; ++row) { + all_sum = sub_group_reduce_add(sumf[row]); + if (first_row + row < ne01) { + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 6756bab03..143d6b836 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5765,7 +5765,7 @@ static struct ggml_tensor * ggml_unary_impl( struct ggml_tensor * a, enum ggml_unary_op op, bool inplace) { - GGML_ASSERT(ggml_is_contiguous_1(a)); + GGML_ASSERT(ggml_is_contiguous_rows(a)); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index e3b06f490..ae2c8f77a 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/webui/src/app.css b/tools/server/webui/src/app.css index 9705040a4..3ab21f0cc 100644 --- a/tools/server/webui/src/app.css +++ b/tools/server/webui/src/app.css @@ -14,11 +14,11 @@ --popover-foreground: oklch(0.145 0 0); --primary: oklch(0.205 0 0); --primary-foreground: oklch(0.985 0 0); - --secondary: oklch(0.97 0 0); + --secondary: oklch(0.95 0 0); --secondary-foreground: oklch(0.205 0 0); --muted: oklch(0.97 0 0); --muted-foreground: oklch(0.556 0 0); - --accent: oklch(0.97 0 0); + --accent: oklch(0.95 0 0); --accent-foreground: oklch(0.205 0 0); --destructive: oklch(0.577 0.245 27.325); --border: oklch(0.875 0 0); @@ -37,7 +37,7 @@ --sidebar-accent-foreground: oklch(0.205 0 0); --sidebar-border: oklch(0.922 0 0); --sidebar-ring: oklch(0.708 0 0); - --code-background: oklch(0.975 0 0); + --code-background: oklch(0.985 0 0); --code-foreground: oklch(0.145 0 0); --layer-popover: 1000000; } @@ -51,7 +51,7 @@ --popover-foreground: oklch(0.985 0 0); --primary: oklch(0.922 0 0); --primary-foreground: oklch(0.205 0 0); - --secondary: oklch(0.269 0 0); + --secondary: oklch(0.29 0 0); --secondary-foreground: oklch(0.985 0 0); --muted: oklch(0.269 0 0); --muted-foreground: oklch(0.708 0 0); @@ -116,12 +116,62 @@ --color-sidebar-ring: var(--sidebar-ring); } +:root { + --chat-form-area-height: 8rem; + --chat-form-area-offset: 2rem; + --max-message-height: max(24rem, min(80dvh, calc(100dvh - var(--chat-form-area-height) - 12rem))); +} + +@media (min-width: 640px) { + :root { + --chat-form-area-height: 24rem; + --chat-form-area-offset: 12rem; + } +} + @layer base { * { @apply border-border outline-ring/50; } + body { @apply bg-background text-foreground; + scrollbar-width: thin; + scrollbar-gutter: stable; + } + + /* Global scrollbar styling - visible only on hover */ + * { + scrollbar-width: thin; + scrollbar-color: transparent transparent; + transition: scrollbar-color 0.2s ease; + } + + *:hover { + scrollbar-color: hsl(var(--muted-foreground) / 0.3) transparent; + } + + *::-webkit-scrollbar { + width: 6px; + height: 6px; + } + + *::-webkit-scrollbar-track { + background: transparent; + } + + *::-webkit-scrollbar-thumb { + background: transparent; + border-radius: 3px; + transition: background 0.2s ease; + } + + *:hover::-webkit-scrollbar-thumb { + background: hsl(var(--muted-foreground) / 0.3); + } + + *::-webkit-scrollbar-thumb:hover { + background: hsl(var(--muted-foreground) / 0.5); } } diff --git a/tools/server/webui/src/lib/components/app/actions/ActionIcon.svelte b/tools/server/webui/src/lib/components/app/actions/ActionIcon.svelte new file mode 100644 index 000000000..4494ea880 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/actions/ActionIcon.svelte @@ -0,0 +1,48 @@ + + + + + + + + +

{tooltip}

+
+
diff --git a/tools/server/webui/src/lib/components/app/actions/ActionIconCopyToClipboard.svelte b/tools/server/webui/src/lib/components/app/actions/ActionIconCopyToClipboard.svelte new file mode 100644 index 000000000..bf6cd4fb2 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/actions/ActionIconCopyToClipboard.svelte @@ -0,0 +1,18 @@ + + + canCopy && copyToClipboard(text)} +/> diff --git a/tools/server/webui/src/lib/components/app/actions/ActionIconRemove.svelte b/tools/server/webui/src/lib/components/app/actions/ActionIconRemove.svelte new file mode 100644 index 000000000..1ae3d2177 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/actions/ActionIconRemove.svelte @@ -0,0 +1,26 @@ + + + diff --git a/tools/server/webui/src/lib/components/app/actions/ActionIconsCodeBlock.svelte b/tools/server/webui/src/lib/components/app/actions/ActionIconsCodeBlock.svelte new file mode 100644 index 000000000..54ff0af1a --- /dev/null +++ b/tools/server/webui/src/lib/components/app/actions/ActionIconsCodeBlock.svelte @@ -0,0 +1,46 @@ + + +
+
+ +
+ + {#if showPreview} + + {/if} +
diff --git a/tools/server/webui/src/lib/components/app/actions/index.ts b/tools/server/webui/src/lib/components/app/actions/index.ts new file mode 100644 index 000000000..43485c7b7 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/actions/index.ts @@ -0,0 +1,19 @@ +/** + * + * ACTIONS + * + * Small interactive components for user actions. + * + */ + +/** Styled icon button for action triggers with tooltip. */ +export { default as ActionIcon } from './ActionIcon.svelte'; + +/** Code block actions component (copy, preview). */ +export { default as ActionIconsCodeBlock } from './ActionIconsCodeBlock.svelte'; + +/** Copy-to-clipboard icon button with click handler. */ +export { default as ActionIconCopyToClipboard } from './ActionIconCopyToClipboard.svelte'; + +/** Remove/delete icon button with X icon. */ +export { default as ActionIconRemove } from './ActionIconRemove.svelte'; diff --git a/tools/server/webui/src/lib/components/app/badges/BadgeChatStatistic.svelte b/tools/server/webui/src/lib/components/app/badges/BadgeChatStatistic.svelte new file mode 100644 index 000000000..a2b28d205 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/badges/BadgeChatStatistic.svelte @@ -0,0 +1,44 @@ + + +{#if tooltipLabel} + + + + {#snippet icon()} + + {/snippet} + + {value} + + + +

{tooltipLabel}

+
+
+{:else} + + {#snippet icon()} + + {/snippet} + + {value} + +{/if} diff --git a/tools/server/webui/src/lib/components/app/badges/BadgeInfo.svelte b/tools/server/webui/src/lib/components/app/badges/BadgeInfo.svelte new file mode 100644 index 000000000..c70af6f42 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/badges/BadgeInfo.svelte @@ -0,0 +1,27 @@ + + + diff --git a/tools/server/webui/src/lib/components/app/badges/BadgeModality.svelte b/tools/server/webui/src/lib/components/app/badges/BadgeModality.svelte new file mode 100644 index 000000000..a0d5e863c --- /dev/null +++ b/tools/server/webui/src/lib/components/app/badges/BadgeModality.svelte @@ -0,0 +1,39 @@ + + +{#each displayableModalities as modality, index (index)} + {@const IconComponent = MODALITY_ICONS[modality]} + {@const label = MODALITY_LABELS[modality]} + + + {#if IconComponent} + + {/if} + + {label} + +{/each} diff --git a/tools/server/webui/src/lib/components/app/badges/index.ts b/tools/server/webui/src/lib/components/app/badges/index.ts new file mode 100644 index 000000000..860afe308 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/badges/index.ts @@ -0,0 +1,16 @@ +/** + * + * BADGES & INDICATORS + * + * Small visual indicators for status and metadata. + * + */ + +/** Badge displaying chat statistics (tokens, timing). */ +export { default as BadgeChatStatistic } from './BadgeChatStatistic.svelte'; + +/** Generic info badge with optional tooltip and click handler. */ +export { default as BadgeInfo } from './BadgeInfo.svelte'; + +/** Badge indicating model modality (vision, audio, tools). */ +export { default as BadgeModality } from './BadgeModality.svelte'; diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte index 27ab975cb..95645295f 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte @@ -27,11 +27,13 @@ interface Props { class?: string; disabled?: boolean; + initialMessage?: string; isLoading?: boolean; onFileRemove?: (fileId: string) => void; onFileUpload?: (files: File[]) => void; onSend?: (message: string, files?: ChatUploadedFile[]) => Promise; onStop?: () => void; + onSystemPromptAdd?: (draft: { message: string; files: ChatUploadedFile[] }) => void; showHelperText?: boolean; uploadedFiles?: ChatUploadedFile[]; } @@ -39,11 +41,13 @@ let { class: className, disabled = false, + initialMessage = '', isLoading = false, onFileRemove, onFileUpload, onSend, onStop, + onSystemPromptAdd, showHelperText = true, uploadedFiles = $bindable([]) }: Props = $props(); @@ -53,15 +57,28 @@ let currentConfig = $derived(config()); let fileInputRef: ChatFormFileInputInvisible | undefined = $state(undefined); let isRecording = $state(false); - let message = $state(''); + let message = $state(initialMessage); let pasteLongTextToFileLength = $derived.by(() => { const n = Number(currentConfig.pasteLongTextToFileLen); return Number.isNaN(n) ? Number(SETTING_CONFIG_DEFAULT.pasteLongTextToFileLen) : n; }); let previousIsLoading = $state(isLoading); + let previousInitialMessage = $state(initialMessage); let recordingSupported = $state(false); let textareaRef: ChatFormTextarea | undefined = $state(undefined); + // Sync message when initialMessage prop changes (e.g., after draft restoration) + $effect(() => { + if (initialMessage !== previousInitialMessage) { + message = initialMessage; + previousInitialMessage = initialMessage; + } + }); + + function handleSystemPromptClick() { + onSystemPromptAdd?.({ message, files: uploadedFiles }); + } + // Check if model is selected (in ROUTER mode) let conversationModel = $derived( chatStore.getConversationModel(activeMessages() as DatabaseMessage[]) @@ -308,6 +325,7 @@ onFileUpload={handleFileUpload} onMicClick={handleMicClick} onStop={handleStop} + onSystemPromptClick={handleSystemPromptClick} /> diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte index dd3726809..3545b4aeb 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte @@ -1,5 +1,6 @@ + + { + open = value; + onToggle?.(); + }} + class={className} +> + + +
+ {#if Icon} + + {/if} + + {title} + + {#if subtitle} + {subtitle} + {/if} +
+ +
+ + + Toggle content +
+
+ + +
+ {@render children()} +
+
+
+
diff --git a/tools/server/webui/src/lib/components/app/content/MarkdownContent.svelte b/tools/server/webui/src/lib/components/app/content/MarkdownContent.svelte new file mode 100644 index 000000000..ef6c7e064 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/content/MarkdownContent.svelte @@ -0,0 +1,1201 @@ + + +
+ {#each renderedBlocks as block (block.id)} +
+ + {@html block.html} +
+ {/each} + + {#if unstableBlockHtml} +
+ + {@html unstableBlockHtml} +
+ {/if} + + {#if incompleteCodeBlock} +
+
+ {incompleteCodeBlock.language || 'text'} + { + previewCode = code; + previewLanguage = lang; + previewDialogOpen = true; + }} + /> +
+
streamingAutoScroll.handleScroll()} + > +
{@html highlightCode(
+							incompleteCodeBlock.code,
+							incompleteCodeBlock.language || 'text'
+						)}
+
+
+ {/if} +
+ + + + diff --git a/tools/server/webui/src/lib/components/app/content/SyntaxHighlightedCode.svelte b/tools/server/webui/src/lib/components/app/content/SyntaxHighlightedCode.svelte new file mode 100644 index 000000000..625fdc7b1 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/content/SyntaxHighlightedCode.svelte @@ -0,0 +1,95 @@ + + +
+ +
{@html highlightedHtml}
+
+ + diff --git a/tools/server/webui/src/lib/components/app/content/index.ts b/tools/server/webui/src/lib/components/app/content/index.ts new file mode 100644 index 000000000..bca1c9f4c --- /dev/null +++ b/tools/server/webui/src/lib/components/app/content/index.ts @@ -0,0 +1,79 @@ +/** + * + * CONTENT RENDERING + * + * Components for rendering rich content: markdown, code, and previews. + * + */ + +/** + * **MarkdownContent** - Rich markdown renderer + * + * Renders markdown content with syntax highlighting, LaTeX math, + * tables, links, and code blocks. Optimized for streaming with + * incremental block-based rendering. + * + * **Features:** + * - GFM (GitHub Flavored Markdown): tables, task lists, strikethrough + * - LaTeX math via KaTeX (`$inline$` and `$$block$$`) + * - Syntax highlighting (highlight.js) with language detection + * - Code copy buttons with click feedback + * - External links open in new tab with security attrs + * - Image attachment resolution from message extras + * - Dark/light theme support (auto-switching) + * - Streaming-optimized incremental rendering + * - Code preview dialog for large blocks + * + * @example + * ```svelte + * + * ``` + */ +export { default as MarkdownContent } from './MarkdownContent.svelte'; + +/** + * **SyntaxHighlightedCode** - Code syntax highlighting + * + * Renders code with syntax highlighting using highlight.js. + * Supports theme switching and scrollable containers. + * + * **Features:** + * - Auto language detection with fallback + * - Dark/light theme auto-switching + * - Scrollable container with configurable max dimensions + * - Monospace font styling + * - Preserves whitespace and formatting + * + * @example + * ```svelte + * + * ``` + */ +export { default as SyntaxHighlightedCode } from './SyntaxHighlightedCode.svelte'; + +/** + * **CollapsibleContentBlock** - Expandable content card + * + * Reusable collapsible card with header, icon, and auto-scroll. + * Used for tool calls and reasoning blocks in chat messages. + * + * **Features:** + * - Collapsible content with smooth animation + * - Custom icon and title display + * - Optional subtitle/status text + * - Auto-scroll during streaming (pauses on user scroll) + * - Configurable max height with overflow scroll + * + * @example + * ```svelte + * + * {reasoningContent} + * + * ``` + */ +export { default as CollapsibleContentBlock } from './CollapsibleContentBlock.svelte'; diff --git a/tools/server/webui/src/lib/components/app/misc/ConversationSelection.svelte b/tools/server/webui/src/lib/components/app/misc/ConversationSelection.svelte index e2095e087..21412f47e 100644 --- a/tools/server/webui/src/lib/components/app/misc/ConversationSelection.svelte +++ b/tools/server/webui/src/lib/components/app/misc/ConversationSelection.svelte @@ -17,9 +17,13 @@ let { conversations, messageCountMap = new Map(), mode, onCancel, onConfirm }: Props = $props(); let searchQuery = $state(''); - let selectedIds = $state.raw>(new SvelteSet(conversations.map((c) => c.id))); + let selectedIds = $state.raw>(getInitialSelectedIds()); let lastClickedId = $state(null); + function getInitialSelectedIds(): SvelteSet { + return new SvelteSet(conversations.map((c) => c.id)); + } + let filteredConversations = $derived( conversations.filter((conv) => { const name = conv.name || 'Untitled conversation'; @@ -92,7 +96,7 @@ } function handleCancel() { - selectedIds = new SvelteSet(conversations.map((c) => c.id)); + selectedIds = getInitialSelectedIds(); searchQuery = ''; lastClickedId = null; @@ -100,7 +104,7 @@ } export function reset() { - selectedIds = new SvelteSet(conversations.map((c) => c.id)); + selectedIds = getInitialSelectedIds(); searchQuery = ''; lastClickedId = null; } diff --git a/tools/server/webui/src/lib/components/app/misc/HorizontalScrollCarousel.svelte b/tools/server/webui/src/lib/components/app/misc/HorizontalScrollCarousel.svelte new file mode 100644 index 000000000..e302f83e1 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/misc/HorizontalScrollCarousel.svelte @@ -0,0 +1,93 @@ + + +
+ + +
+ {@render children?.()} +
+ + +
diff --git a/tools/server/webui/src/lib/components/app/misc/KeyboardShortcutInfo.svelte b/tools/server/webui/src/lib/components/app/misc/KeyboardShortcutInfo.svelte index 5b7522fe1..da55abda0 100644 --- a/tools/server/webui/src/lib/components/app/misc/KeyboardShortcutInfo.svelte +++ b/tools/server/webui/src/lib/components/app/misc/KeyboardShortcutInfo.svelte @@ -11,7 +11,9 @@ let baseClasses = 'px-1 pointer-events-none inline-flex select-none items-center gap-0.5 font-sans text-md font-medium opacity-0 transition-opacity -my-1'; - let variantClasses = variant === 'destructive' ? 'text-destructive' : 'text-muted-foreground'; + let variantClasses = $derived( + variant === 'destructive' ? 'text-destructive' : 'text-muted-foreground' + ); diff --git a/tools/server/webui/src/lib/components/app/misc/TruncatedText.svelte b/tools/server/webui/src/lib/components/app/misc/TruncatedText.svelte new file mode 100644 index 000000000..9a8731fc7 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/misc/TruncatedText.svelte @@ -0,0 +1,48 @@ + + +{#if isTruncated} + + + + {text} + + + + +

{text}

+
+
+{:else} + + {text} + +{/if} diff --git a/tools/server/webui/src/lib/components/app/misc/index.ts b/tools/server/webui/src/lib/components/app/misc/index.ts new file mode 100644 index 000000000..02bd70b24 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/misc/index.ts @@ -0,0 +1,45 @@ +/** + * + * MISC + * + * Miscellaneous utility components. + * + */ + +/** + * **ConversationSelection** - Multi-select conversation picker + * + * List of conversations with checkboxes for multi-selection. + * Used in import/export dialogs for selecting conversations. + * + * **Features:** + * - Search/filter conversations by name + * - Select all / deselect all controls + * - Shift-click for range selection + * - Message count display per conversation + * - Mode-specific UI (export vs import) + */ +export { default as ConversationSelection } from './ConversationSelection.svelte'; + +/** + * Horizontal scrollable carousel with navigation arrows. + * Used for displaying items in a horizontally scrollable container + * with left/right navigation buttons that appear on hover. + */ +export { default as HorizontalScrollCarousel } from './HorizontalScrollCarousel.svelte'; + +/** + * **TruncatedText** - Text with ellipsis and tooltip + * + * Displays text with automatic truncation and full content in tooltip. + * Useful for long names or paths in constrained spaces. + */ +export { default as TruncatedText } from './TruncatedText.svelte'; + +/** + * **KeyboardShortcutInfo** - Keyboard shortcut hint display + * + * Displays keyboard shortcut hints (e.g., "⌘ + Enter"). + * Supports special keys like shift, cmd, and custom text. + */ +export { default as KeyboardShortcutInfo } from './KeyboardShortcutInfo.svelte'; diff --git a/tools/server/webui/src/lib/components/app/navigation/DropdownMenuActions.svelte b/tools/server/webui/src/lib/components/app/navigation/DropdownMenuActions.svelte new file mode 100644 index 000000000..83d856d10 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/navigation/DropdownMenuActions.svelte @@ -0,0 +1,86 @@ + + + + e.stopPropagation()} + > + {#if triggerTooltip} + + + {@render iconComponent(triggerIcon, 'h-3 w-3')} + {triggerTooltip} + + +

{triggerTooltip}

+
+
+ {:else} + {@render iconComponent(triggerIcon, 'h-3 w-3')} + {/if} +
+ + + {#each actions as action, index (action.label)} + {#if action.separator && index > 0} + + {/if} + + +
+ {@render iconComponent( + action.icon, + `h-4 w-4 ${action.variant === 'destructive' ? 'text-destructive' : ''}` + )} + {action.label} +
+ + {#if action.shortcut} + + {/if} +
+ {/each} +
+
+ +{#snippet iconComponent(IconComponent: Component, className: string)} + +{/snippet} diff --git a/tools/server/webui/src/lib/components/app/navigation/DropdownMenuSearchable.svelte b/tools/server/webui/src/lib/components/app/navigation/DropdownMenuSearchable.svelte new file mode 100644 index 000000000..3bd68d3bd --- /dev/null +++ b/tools/server/webui/src/lib/components/app/navigation/DropdownMenuSearchable.svelte @@ -0,0 +1,50 @@ + + +
+ +
+ +
+ {@render children()} + + {#if isEmpty} +
{emptyMessage}
+ {/if} +
+ +{#if footer} + + + {@render footer()} +{/if} diff --git a/tools/server/webui/src/lib/components/app/navigation/index.ts b/tools/server/webui/src/lib/components/app/navigation/index.ts new file mode 100644 index 000000000..051491b86 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/navigation/index.ts @@ -0,0 +1,65 @@ +/** + * + * NAVIGATION & MENUS + * + * Components for dropdown menus and action selection. + * + */ + +/** + * **DropdownMenuSearchable** - Searchable content for dropdown menus + * + * Renders a search input with filtered content area, empty state, and optional footer. + * Designed to be injected into any dropdown container (DropdownMenu.Content, + * DropdownMenu.SubContent, etc.) without providing its own Root. + * + * **Features:** + * - Search/filter input + * - Keyboard navigation support + * - Custom content and footer via snippets + * - Empty state message + * + * @example + * ```svelte + * + * ... + * + * + * {#each items as item}{/each} + * + * + * + * ``` + */ +export { default as DropdownMenuSearchable } from './DropdownMenuSearchable.svelte'; + +/** + * **DropdownMenuActions** - Multi-action dropdown menu + * + * Dropdown menu for multiple action options with icons and shortcuts. + * Supports destructive variants and keyboard shortcut hints. + * + * **Features:** + * - Configurable trigger icon with tooltip + * - Action items with icons and labels + * - Destructive variant styling + * - Keyboard shortcut display + * - Separator support between groups + * + * @example + * ```svelte + * + * ``` + */ +export { default as DropdownMenuActions } from './DropdownMenuActions.svelte'; diff --git a/tools/server/webui/src/lib/components/app/server/ServerErrorSplash.svelte b/tools/server/webui/src/lib/components/app/server/ServerErrorSplash.svelte index fa4c2842c..520e5bf56 100644 --- a/tools/server/webui/src/lib/components/app/server/ServerErrorSplash.svelte +++ b/tools/server/webui/src/lib/components/app/server/ServerErrorSplash.svelte @@ -8,6 +8,7 @@ import { serverStore, serverLoading } from '$lib/stores/server.svelte'; import { config, settingsStore } from '$lib/stores/settings.svelte'; import { fade, fly, scale } from 'svelte/transition'; + import { KeyboardKey } from '$lib/enums/keyboard'; interface Props { class?: string; @@ -117,7 +118,7 @@ } function handleApiKeyKeydown(event: KeyboardEvent) { - if (event.key === 'Enter') { + if (event.key === KeyboardKey.ENTER) { handleSaveApiKey(); } } diff --git a/tools/server/webui/src/lib/components/app/server/ServerStatus.svelte b/tools/server/webui/src/lib/components/app/server/ServerStatus.svelte index d9f6d4a32..86a962de1 100644 --- a/tools/server/webui/src/lib/components/app/server/ServerStatus.svelte +++ b/tools/server/webui/src/lib/components/app/server/ServerStatus.svelte @@ -48,7 +48,7 @@ {model || 'Unknown Model'} - {#if serverData.default_generation_settings.n_ctx} + {#if serverData?.default_generation_settings?.n_ctx} ctx: {serverData.default_generation_settings.n_ctx.toLocaleString()} diff --git a/tools/server/webui/src/lib/components/app/server/index.ts b/tools/server/webui/src/lib/components/app/server/index.ts new file mode 100644 index 000000000..39ac5b482 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/server/index.ts @@ -0,0 +1,80 @@ +/** + * + * SERVER + * + * Components for displaying server connection state and handling + * connection errors. Integrates with serverStore for state management. + * + */ + +/** + * **ServerStatus** - Server connection status indicator + * + * Compact status display showing connection state, model name, + * and context size. Used in headers and loading screens. + * + * **Architecture:** + * - Reads state from serverStore (props, loading, error) + * - Displays model name from modelsStore + * + * **Features:** + * - Status dot: green (connected), yellow (connecting), red (error), gray (unknown) + * - Status text label + * - Model name badge with icon + * - Context size badge + * - Optional error action button + * + * @example + * ```svelte + * + * ``` + */ +export { default as ServerStatus } from './ServerStatus.svelte'; + +/** + * **ServerErrorSplash** - Full-screen connection error display + * + * Blocking error screen shown when server connection fails. + * Provides retry options and API key input for authentication errors. + * + * **Architecture:** + * - Detects access denied errors for API key flow + * - Validates API key against server before saving + * - Integrates with settingsStore for API key persistence + * + * **Features:** + * - Error message display with icon + * - Retry connection button with loading state + * - API key input for authentication errors + * - API key validation with success/error feedback + * - Troubleshooting section with server start commands + * - Animated transitions for UI elements + * + * @example + * ```svelte + * + * ``` + */ +export { default as ServerErrorSplash } from './ServerErrorSplash.svelte'; + +/** + * **ServerLoadingSplash** - Full-screen loading display + * + * Shown during initial server connection. Displays loading animation + * with ServerStatus component for real-time connection state. + * + * **Features:** + * - Animated server icon + * - Customizable loading message + * - Embedded ServerStatus for live updates + * + * @example + * ```svelte + * + * ``` + */ +export { default as ServerLoadingSplash } from './ServerLoadingSplash.svelte'; diff --git a/tools/server/webui/src/lib/components/ui/badge/badge.svelte b/tools/server/webui/src/lib/components/ui/badge/badge.svelte index 4d1514549..c3e6ac072 100644 --- a/tools/server/webui/src/lib/components/ui/badge/badge.svelte +++ b/tools/server/webui/src/lib/components/ui/badge/badge.svelte @@ -42,7 +42,7 @@ bind:this={ref} data-slot="badge" {href} - class={cn(badgeVariants({ variant }), className)} + class={cn(badgeVariants({ variant }), className, 'backdrop-blur-sm')} {...restProps} > {@render children?.()} diff --git a/tools/server/webui/src/lib/components/ui/button/button.svelte b/tools/server/webui/src/lib/components/ui/button/button.svelte index d12c8de14..d29358c8e 100644 --- a/tools/server/webui/src/lib/components/ui/button/button.svelte +++ b/tools/server/webui/src/lib/components/ui/button/button.svelte @@ -12,8 +12,9 @@ 'bg-destructive shadow-xs hover:bg-destructive/90 focus-visible:ring-destructive/20 dark:focus-visible:ring-destructive/40 dark:bg-destructive/60 text-white', outline: 'bg-background shadow-xs hover:bg-accent hover:text-accent-foreground dark:bg-input/30 dark:border-input dark:hover:bg-input/50 border', - secondary: 'bg-secondary text-secondary-foreground shadow-xs hover:bg-secondary/80', - ghost: 'hover:bg-accent hover:text-accent-foreground dark:hover:bg-accent/50', + secondary: + 'dark:bg-secondary dark:text-secondary-foreground bg-background shadow-sm text-foreground hover:bg-muted-foreground/20', + ghost: 'hover:text-accent-foreground hover:bg-muted-foreground/10', link: 'text-primary underline-offset-4 hover:underline' }, size: { diff --git a/tools/server/webui/src/lib/components/ui/card/card.svelte b/tools/server/webui/src/lib/components/ui/card/card.svelte index c40d14309..b9dcd2de6 100644 --- a/tools/server/webui/src/lib/components/ui/card/card.svelte +++ b/tools/server/webui/src/lib/components/ui/card/card.svelte @@ -1,6 +1,7 @@ - +{#snippet tooltipContent()} {@render children?.()} @@ -44,4 +50,12 @@ {/snippet} - +{/snippet} + +{#if noPortal} + {@render tooltipContent()} +{:else} + + {@render tooltipContent()} + +{/if} diff --git a/tools/server/webui/src/lib/constants/binary-detection.ts b/tools/server/webui/src/lib/constants/binary-detection.ts index a4440fde5..eac919ad9 100644 --- a/tools/server/webui/src/lib/constants/binary-detection.ts +++ b/tools/server/webui/src/lib/constants/binary-detection.ts @@ -1,9 +1,6 @@ export interface BinaryDetectionOptions { - /** Number of characters to check from the beginning of the file */ prefixLength: number; - /** Maximum ratio of suspicious characters allowed (0.0 to 1.0) */ suspiciousCharThresholdRatio: number; - /** Maximum absolute number of null bytes allowed */ maxAbsoluteNullBytes: number; } diff --git a/tools/server/webui/src/lib/constants/chat-form.ts b/tools/server/webui/src/lib/constants/chat-form.ts new file mode 100644 index 000000000..c5e3dc3d1 --- /dev/null +++ b/tools/server/webui/src/lib/constants/chat-form.ts @@ -0,0 +1,3 @@ +export const INITIAL_FILE_SIZE = 0; +export const PROMPT_CONTENT_SEPARATOR = '\n\n'; +export const CLIPBOARD_CONTENT_QUOTE_PREFIX = '"'; diff --git a/tools/server/webui/src/lib/constants/code-blocks.ts b/tools/server/webui/src/lib/constants/code-blocks.ts new file mode 100644 index 000000000..0f7265104 --- /dev/null +++ b/tools/server/webui/src/lib/constants/code-blocks.ts @@ -0,0 +1,8 @@ +export const CODE_BLOCK_SCROLL_CONTAINER_CLASS = 'code-block-scroll-container'; +export const CODE_BLOCK_WRAPPER_CLASS = 'code-block-wrapper'; +export const CODE_BLOCK_HEADER_CLASS = 'code-block-header'; +export const CODE_BLOCK_ACTIONS_CLASS = 'code-block-actions'; +export const CODE_LANGUAGE_CLASS = 'code-language'; +export const COPY_CODE_BTN_CLASS = 'copy-code-btn'; +export const PREVIEW_CODE_BTN_CLASS = 'preview-code-btn'; +export const RELATIVE_CLASS = 'relative'; diff --git a/tools/server/webui/src/lib/constants/code.ts b/tools/server/webui/src/lib/constants/code.ts new file mode 100644 index 000000000..12bcd0db7 --- /dev/null +++ b/tools/server/webui/src/lib/constants/code.ts @@ -0,0 +1,7 @@ +export const NEWLINE = '\n'; +export const DEFAULT_LANGUAGE = 'text'; +export const LANG_PATTERN = /^(\w*)\n?/; +export const AMPERSAND_REGEX = /&/g; +export const LT_REGEX = //g; +export const FENCE_PATTERN = /^```|\n```/g; diff --git a/tools/server/webui/src/lib/constants/css-classes.ts b/tools/server/webui/src/lib/constants/css-classes.ts new file mode 100644 index 000000000..46076e55f --- /dev/null +++ b/tools/server/webui/src/lib/constants/css-classes.ts @@ -0,0 +1,10 @@ +export const BOX_BORDER = + 'border border-border/30 focus-within:border-border dark:border-border/20 dark:focus-within:border-border'; + +export const INPUT_CLASSES = ` + bg-muted/60 dark:bg-muted/75 + ${BOX_BORDER} + shadow-sm + outline-none + text-foreground +`; diff --git a/tools/server/webui/src/lib/constants/formatters.ts b/tools/server/webui/src/lib/constants/formatters.ts new file mode 100644 index 000000000..d6d1b883f --- /dev/null +++ b/tools/server/webui/src/lib/constants/formatters.ts @@ -0,0 +1,8 @@ +export const MS_PER_SECOND = 1000; +export const SECONDS_PER_MINUTE = 60; +export const SECONDS_PER_HOUR = 3600; +export const SHORT_DURATION_THRESHOLD = 1; +export const MEDIUM_DURATION_THRESHOLD = 10; + +/** Default display value when no performance time is available */ +export const DEFAULT_PERFORMANCE_TIME = '0s'; diff --git a/tools/server/webui/src/lib/constants/markdown.ts b/tools/server/webui/src/lib/constants/markdown.ts new file mode 100644 index 000000000..783d31a22 --- /dev/null +++ b/tools/server/webui/src/lib/constants/markdown.ts @@ -0,0 +1,4 @@ +export const IMAGE_NOT_ERROR_BOUND_SELECTOR = 'img:not([data-error-bound])'; +export const DATA_ERROR_BOUND_ATTR = 'errorBound'; +export const DATA_ERROR_HANDLED_ATTR = 'errorHandled'; +export const BOOL_TRUE_STRING = 'true'; diff --git a/tools/server/webui/src/lib/constants/processing-info.ts b/tools/server/webui/src/lib/constants/processing-info.ts index 726439211..2c3f7dc53 100644 --- a/tools/server/webui/src/lib/constants/processing-info.ts +++ b/tools/server/webui/src/lib/constants/processing-info.ts @@ -1 +1,8 @@ export const PROCESSING_INFO_TIMEOUT = 2000; + +/** + * Statistics units labels + */ +export const STATS_UNITS = { + TOKENS_PER_SECOND: 't/s' +} as const; diff --git a/tools/server/webui/src/lib/constants/settings-fields.ts b/tools/server/webui/src/lib/constants/settings-fields.ts new file mode 100644 index 000000000..79a6e9287 --- /dev/null +++ b/tools/server/webui/src/lib/constants/settings-fields.ts @@ -0,0 +1,33 @@ +/** + * List of all numeric fields in settings configuration. + * These fields will be converted from strings to numbers during save. + */ +export const NUMERIC_FIELDS = [ + 'temperature', + 'top_k', + 'top_p', + 'min_p', + 'max_tokens', + 'pasteLongTextToFileLen', + 'dynatemp_range', + 'dynatemp_exponent', + 'typ_p', + 'xtc_probability', + 'xtc_threshold', + 'repeat_last_n', + 'repeat_penalty', + 'presence_penalty', + 'frequency_penalty', + 'dry_multiplier', + 'dry_base', + 'dry_allowed_length', + 'dry_penalty_last_n', + 'agenticMaxTurns', + 'agenticMaxToolPreviewLines' +] as const; + +/** + * Fields that must be positive integers (>= 1). + * These will be clamped to minimum 1 and rounded during save. + */ +export const POSITIVE_INTEGER_FIELDS = ['agenticMaxTurns', 'agenticMaxToolPreviewLines'] as const; diff --git a/tools/server/webui/src/lib/constants/tooltip-config.ts b/tools/server/webui/src/lib/constants/tooltip-config.ts index 3c30c8c07..ad76ab352 100644 --- a/tools/server/webui/src/lib/constants/tooltip-config.ts +++ b/tools/server/webui/src/lib/constants/tooltip-config.ts @@ -1 +1 @@ -export const TOOLTIP_DELAY_DURATION = 100; +export const TOOLTIP_DELAY_DURATION = 500; diff --git a/tools/server/webui/src/lib/constants/ui.ts b/tools/server/webui/src/lib/constants/ui.ts new file mode 100644 index 000000000..a75b30f2f --- /dev/null +++ b/tools/server/webui/src/lib/constants/ui.ts @@ -0,0 +1 @@ +export const SYSTEM_MESSAGE_PLACEHOLDER = 'System message'; diff --git a/tools/server/webui/src/lib/contexts/chat-actions.context.ts b/tools/server/webui/src/lib/contexts/chat-actions.context.ts new file mode 100644 index 000000000..eba0fec02 --- /dev/null +++ b/tools/server/webui/src/lib/contexts/chat-actions.context.ts @@ -0,0 +1,34 @@ +import { getContext, setContext } from 'svelte'; + +export interface ChatActionsContext { + copy: (message: DatabaseMessage) => void; + delete: (message: DatabaseMessage) => void; + navigateToSibling: (siblingId: string) => void; + editWithBranching: ( + message: DatabaseMessage, + newContent: string, + newExtras?: DatabaseMessageExtra[] + ) => void; + editWithReplacement: ( + message: DatabaseMessage, + newContent: string, + shouldBranch: boolean + ) => void; + editUserMessagePreserveResponses: ( + message: DatabaseMessage, + newContent: string, + newExtras?: DatabaseMessageExtra[] + ) => void; + regenerateWithBranching: (message: DatabaseMessage, modelOverride?: string) => void; + continueAssistantMessage: (message: DatabaseMessage) => void; +} + +const CHAT_ACTIONS_KEY = Symbol.for('chat-actions'); + +export function setChatActionsContext(ctx: ChatActionsContext): ChatActionsContext { + return setContext(CHAT_ACTIONS_KEY, ctx); +} + +export function getChatActionsContext(): ChatActionsContext { + return getContext(CHAT_ACTIONS_KEY); +} diff --git a/tools/server/webui/src/lib/contexts/index.ts b/tools/server/webui/src/lib/contexts/index.ts new file mode 100644 index 000000000..73ff6f96f --- /dev/null +++ b/tools/server/webui/src/lib/contexts/index.ts @@ -0,0 +1,13 @@ +export { + getMessageEditContext, + setMessageEditContext, + type MessageEditContext, + type MessageEditState, + type MessageEditActions +} from './message-edit.context'; + +export { + getChatActionsContext, + setChatActionsContext, + type ChatActionsContext +} from './chat-actions.context'; diff --git a/tools/server/webui/src/lib/contexts/message-edit.context.ts b/tools/server/webui/src/lib/contexts/message-edit.context.ts new file mode 100644 index 000000000..7af116daa --- /dev/null +++ b/tools/server/webui/src/lib/contexts/message-edit.context.ts @@ -0,0 +1,39 @@ +import { getContext, setContext } from 'svelte'; + +export interface MessageEditState { + readonly isEditing: boolean; + readonly editedContent: string; + readonly editedExtras: DatabaseMessageExtra[]; + readonly editedUploadedFiles: ChatUploadedFile[]; + readonly originalContent: string; + readonly originalExtras: DatabaseMessageExtra[]; + readonly showSaveOnlyOption: boolean; +} + +export interface MessageEditActions { + setContent: (content: string) => void; + setExtras: (extras: DatabaseMessageExtra[]) => void; + setUploadedFiles: (files: ChatUploadedFile[]) => void; + save: () => void; + saveOnly: () => void; + cancel: () => void; + startEdit: () => void; +} + +export type MessageEditContext = MessageEditState & MessageEditActions; + +const MESSAGE_EDIT_KEY = Symbol.for('chat-message-edit'); + +/** + * Sets the message edit context. Call this in the parent component (ChatMessage.svelte). + */ +export function setMessageEditContext(ctx: MessageEditContext): MessageEditContext { + return setContext(MESSAGE_EDIT_KEY, ctx); +} + +/** + * Gets the message edit context. Call this in child components. + */ +export function getMessageEditContext(): MessageEditContext { + return getContext(MESSAGE_EDIT_KEY); +} diff --git a/tools/server/webui/src/lib/enums/chat.ts b/tools/server/webui/src/lib/enums/chat.ts index 2b9eb7bc2..0b6f357d9 100644 --- a/tools/server/webui/src/lib/enums/chat.ts +++ b/tools/server/webui/src/lib/enums/chat.ts @@ -1,4 +1,51 @@ export enum ChatMessageStatsView { GENERATION = 'generation', - READING = 'reading' + READING = 'reading', + TOOLS = 'tools', + SUMMARY = 'summary' +} + +/** + * Reasoning format options for API requests. + */ +export enum ReasoningFormat { + NONE = 'none', + AUTO = 'auto' +} + +/** + * Message roles for chat messages. + */ +export enum MessageRole { + USER = 'user', + ASSISTANT = 'assistant', + SYSTEM = 'system', + TOOL = 'tool' +} + +/** + * Message types for different content kinds. + */ +export enum MessageType { + ROOT = 'root', + TEXT = 'text', + THINK = 'think', + SYSTEM = 'system' +} + +/** + * Content part types for API chat message content. + */ +export enum ContentPartType { + TEXT = 'text', + IMAGE_URL = 'image_url', + INPUT_AUDIO = 'input_audio' +} + +/** + * Error dialog types for displaying server/timeout errors. + */ +export enum ErrorDialogType { + TIMEOUT = 'timeout', + SERVER = 'server' } diff --git a/tools/server/webui/src/lib/enums/keyboard.ts b/tools/server/webui/src/lib/enums/keyboard.ts new file mode 100644 index 000000000..b8f6d5f7a --- /dev/null +++ b/tools/server/webui/src/lib/enums/keyboard.ts @@ -0,0 +1,15 @@ +/** + * Keyboard key names for event handling + */ +export enum KeyboardKey { + ENTER = 'Enter', + ESCAPE = 'Escape', + ARROW_UP = 'ArrowUp', + ARROW_DOWN = 'ArrowDown', + TAB = 'Tab', + D_LOWER = 'd', + D_UPPER = 'D', + E_UPPER = 'E', + K_LOWER = 'k', + O_UPPER = 'O' +} diff --git a/tools/server/webui/src/lib/enums/settings.ts b/tools/server/webui/src/lib/enums/settings.ts new file mode 100644 index 000000000..f17f21976 --- /dev/null +++ b/tools/server/webui/src/lib/enums/settings.ts @@ -0,0 +1,26 @@ +/** + * Parameter source - indicates whether a parameter uses default or custom value + */ +export enum ParameterSource { + DEFAULT = 'default', + CUSTOM = 'custom' +} + +/** + * Syncable parameter type - data types for parameters that can be synced with server + */ +export enum SyncableParameterType { + NUMBER = 'number', + STRING = 'string', + BOOLEAN = 'boolean' +} + +/** + * Settings field type - defines the input type for settings fields + */ +export enum SettingsFieldType { + INPUT = 'input', + TEXTAREA = 'textarea', + CHECKBOX = 'checkbox', + SELECT = 'select' +} diff --git a/tools/server/webui/src/lib/hooks/use-auto-scroll.svelte.ts b/tools/server/webui/src/lib/hooks/use-auto-scroll.svelte.ts new file mode 100644 index 000000000..bbaa5d136 --- /dev/null +++ b/tools/server/webui/src/lib/hooks/use-auto-scroll.svelte.ts @@ -0,0 +1,165 @@ +import { AUTO_SCROLL_AT_BOTTOM_THRESHOLD, AUTO_SCROLL_INTERVAL } from '$lib/constants/auto-scroll'; + +export interface AutoScrollOptions { + /** Whether auto-scroll is disabled globally (e.g., from settings) */ + disabled?: boolean; +} + +/** + * Creates an auto-scroll controller for a scrollable container. + * + * Features: + * - Auto-scrolls to bottom during streaming/loading + * - Stops auto-scroll when user manually scrolls up + * - Resumes auto-scroll when user scrolls back to bottom + */ +export class AutoScrollController { + private _autoScrollEnabled = $state(true); + private _userScrolledUp = $state(false); + private _lastScrollTop = $state(0); + private _scrollInterval: ReturnType | undefined; + private _scrollTimeout: ReturnType | undefined; + private _container: HTMLElement | undefined; + private _disabled: boolean; + + constructor(options: AutoScrollOptions = {}) { + this._disabled = options.disabled ?? false; + } + + get autoScrollEnabled(): boolean { + return this._autoScrollEnabled; + } + + get userScrolledUp(): boolean { + return this._userScrolledUp; + } + + /** + * Binds the controller to a scrollable container element. + */ + setContainer(container: HTMLElement | undefined): void { + this._container = container; + } + + /** + * Updates the disabled state. + */ + setDisabled(disabled: boolean): void { + this._disabled = disabled; + if (disabled) { + this._autoScrollEnabled = false; + this.stopInterval(); + } + } + + /** + * Handles scroll events to detect user scroll direction and toggle auto-scroll. + */ + handleScroll(): void { + if (this._disabled || !this._container) return; + + const { scrollTop, scrollHeight, clientHeight } = this._container; + const distanceFromBottom = scrollHeight - scrollTop - clientHeight; + const isAtBottom = distanceFromBottom < AUTO_SCROLL_AT_BOTTOM_THRESHOLD; + + if (scrollTop < this._lastScrollTop && !isAtBottom) { + this._userScrolledUp = true; + this._autoScrollEnabled = false; + } else if (isAtBottom && this._userScrolledUp) { + this._userScrolledUp = false; + this._autoScrollEnabled = true; + } + + if (this._scrollTimeout) { + clearTimeout(this._scrollTimeout); + } + + this._scrollTimeout = setTimeout(() => { + if (isAtBottom) { + this._userScrolledUp = false; + this._autoScrollEnabled = true; + } + }, AUTO_SCROLL_INTERVAL); + + this._lastScrollTop = scrollTop; + } + + /** + * Scrolls the container to the bottom. + */ + scrollToBottom(behavior: ScrollBehavior = 'smooth'): void { + if (this._disabled || !this._container) return; + + this._container.scrollTo({ + top: this._container.scrollHeight, + behavior + }); + } + + /** + * Enables auto-scroll (e.g., when user sends a message). + */ + enable(): void { + if (this._disabled) return; + this._userScrolledUp = false; + this._autoScrollEnabled = true; + } + + /** + * Starts the auto-scroll interval for continuous scrolling during streaming. + */ + startInterval(): void { + if (this._disabled || this._scrollInterval) return; + + this._scrollInterval = setInterval(() => { + this.scrollToBottom(); + }, AUTO_SCROLL_INTERVAL); + } + + /** + * Stops the auto-scroll interval. + */ + stopInterval(): void { + if (this._scrollInterval) { + clearInterval(this._scrollInterval); + this._scrollInterval = undefined; + } + } + + /** + * Updates the auto-scroll interval based on streaming state. + * Call this in a $effect to automatically manage the interval. + */ + updateInterval(isStreaming: boolean): void { + if (this._disabled) { + this.stopInterval(); + return; + } + + if (isStreaming && this._autoScrollEnabled) { + if (!this._scrollInterval) { + this.startInterval(); + } + } else { + this.stopInterval(); + } + } + + /** + * Cleans up resources. Call this in onDestroy or when the component unmounts. + */ + destroy(): void { + this.stopInterval(); + if (this._scrollTimeout) { + clearTimeout(this._scrollTimeout); + this._scrollTimeout = undefined; + } + } +} + +/** + * Creates a new AutoScrollController instance. + */ +export function createAutoScrollController(options: AutoScrollOptions = {}): AutoScrollController { + return new AutoScrollController(options); +} diff --git a/tools/server/webui/src/lib/hooks/use-processing-state.svelte.ts b/tools/server/webui/src/lib/hooks/use-processing-state.svelte.ts index c06cf2886..068440cdc 100644 --- a/tools/server/webui/src/lib/hooks/use-processing-state.svelte.ts +++ b/tools/server/webui/src/lib/hooks/use-processing-state.svelte.ts @@ -1,7 +1,9 @@ import { activeProcessingState } from '$lib/stores/chat.svelte'; import { config } from '$lib/stores/settings.svelte'; +import { STATS_UNITS } from '$lib/constants/processing-info'; +import type { ApiProcessingState } from '$lib/types'; -export interface LiveProcessingStats { +interface LiveProcessingStats { tokensProcessed: number; totalTokens: number; timeMs: number; @@ -9,7 +11,7 @@ export interface LiveProcessingStats { etaSecs?: number; } -export interface LiveGenerationStats { +interface LiveGenerationStats { tokensGenerated: number; timeMs: number; tokensPerSecond: number; @@ -18,6 +20,7 @@ export interface LiveGenerationStats { export interface UseProcessingStateReturn { readonly processingState: ApiProcessingState | null; getProcessingDetails(): string[]; + getTechnicalDetails(): string[]; getProcessingMessage(): string; getPromptProgressText(): string | null; getLiveProcessingStats(): LiveProcessingStats | null; @@ -138,8 +141,31 @@ export function useProcessingState(): UseProcessingStateReturn { const details: string[] = []; + // Show prompt processing progress with ETA during preparation phase + if (stateToUse.promptProgress) { + const { processed, total, time_ms, cache } = stateToUse.promptProgress; + const actualProcessed = processed - cache; + const actualTotal = total - cache; + + if (actualProcessed < actualTotal && actualProcessed > 0) { + const percent = Math.round((actualProcessed / actualTotal) * 100); + const eta = getETASecs(actualProcessed, actualTotal, time_ms); + + if (eta !== undefined) { + const etaSecs = Math.ceil(eta); + details.push(`Processing ${percent}% (ETA: ${etaSecs}s)`); + } else { + details.push(`Processing ${percent}%`); + } + } + } + // Always show context info when we have valid data - if (stateToUse.contextUsed >= 0 && stateToUse.contextTotal > 0) { + if ( + typeof stateToUse.contextTotal === 'number' && + stateToUse.contextUsed >= 0 && + stateToUse.contextTotal > 0 + ) { const contextPercent = Math.round((stateToUse.contextUsed / stateToUse.contextTotal) * 100); details.push( @@ -163,7 +189,57 @@ export function useProcessingState(): UseProcessingStateReturn { } if (stateToUse.tokensPerSecond && stateToUse.tokensPerSecond > 0) { - details.push(`${stateToUse.tokensPerSecond.toFixed(1)} tokens/sec`); + details.push(`${stateToUse.tokensPerSecond.toFixed(1)} ${STATS_UNITS.TOKENS_PER_SECOND}`); + } + + if (stateToUse.speculative) { + details.push('Speculative decoding enabled'); + } + + return details; + } + + /** + * Returns technical details without the progress message (for bottom bar) + */ + function getTechnicalDetails(): string[] { + const stateToUse = processingState || lastKnownState; + if (!stateToUse) { + return []; + } + + const details: string[] = []; + + // Always show context info when we have valid data + if ( + typeof stateToUse.contextTotal === 'number' && + stateToUse.contextUsed >= 0 && + stateToUse.contextTotal > 0 + ) { + const contextPercent = Math.round((stateToUse.contextUsed / stateToUse.contextTotal) * 100); + + details.push( + `Context: ${stateToUse.contextUsed}/${stateToUse.contextTotal} (${contextPercent}%)` + ); + } + + if (stateToUse.outputTokensUsed > 0) { + // Handle infinite max_tokens (-1) case + if (stateToUse.outputTokensMax <= 0) { + details.push(`Output: ${stateToUse.outputTokensUsed}/∞`); + } else { + const outputPercent = Math.round( + (stateToUse.outputTokensUsed / stateToUse.outputTokensMax) * 100 + ); + + details.push( + `Output: ${stateToUse.outputTokensUsed}/${stateToUse.outputTokensMax} (${outputPercent}%)` + ); + } + } + + if (stateToUse.tokensPerSecond && stateToUse.tokensPerSecond > 0) { + details.push(`${stateToUse.tokensPerSecond.toFixed(1)} ${STATS_UNITS.TOKENS_PER_SECOND}`); } if (stateToUse.speculative) { @@ -251,6 +327,7 @@ export function useProcessingState(): UseProcessingStateReturn { return processingState; }, getProcessingDetails, + getTechnicalDetails, getProcessingMessage, getPromptProgressText, getLiveProcessingStats, diff --git a/tools/server/webui/src/lib/markdown/enhance-code-blocks.ts b/tools/server/webui/src/lib/markdown/enhance-code-blocks.ts index 6f0e03e21..168de9740 100644 --- a/tools/server/webui/src/lib/markdown/enhance-code-blocks.ts +++ b/tools/server/webui/src/lib/markdown/enhance-code-blocks.ts @@ -13,6 +13,16 @@ import type { Plugin } from 'unified'; import type { Root, Element, ElementContent } from 'hast'; import { visit } from 'unist-util-visit'; +import { + CODE_BLOCK_SCROLL_CONTAINER_CLASS, + CODE_BLOCK_WRAPPER_CLASS, + CODE_BLOCK_HEADER_CLASS, + CODE_BLOCK_ACTIONS_CLASS, + CODE_LANGUAGE_CLASS, + COPY_CODE_BTN_CLASS, + PREVIEW_CODE_BTN_CLASS, + RELATIVE_CLASS +} from '$lib/constants/code-blocks'; declare global { interface Window { @@ -42,7 +52,7 @@ function createCopyButton(codeId: string): Element { type: 'element', tagName: 'button', properties: { - className: ['copy-code-btn'], + className: [COPY_CODE_BTN_CLASS], 'data-code-id': codeId, title: 'Copy code', type: 'button' @@ -56,7 +66,7 @@ function createPreviewButton(codeId: string): Element { type: 'element', tagName: 'button', properties: { - className: ['preview-code-btn'], + className: [PREVIEW_CODE_BTN_CLASS], 'data-code-id': codeId, title: 'Preview code', type: 'button' @@ -75,30 +85,39 @@ function createHeader(language: string, codeId: string): Element { return { type: 'element', tagName: 'div', - properties: { className: ['code-block-header'] }, + properties: { className: [CODE_BLOCK_HEADER_CLASS] }, children: [ { type: 'element', tagName: 'span', - properties: { className: ['code-language'] }, + properties: { className: [CODE_LANGUAGE_CLASS] }, children: [{ type: 'text', value: language }] }, { type: 'element', tagName: 'div', - properties: { className: ['code-block-actions'] }, + properties: { className: [CODE_BLOCK_ACTIONS_CLASS] }, children: actions } ] }; } +function createScrollContainer(preElement: Element): Element { + return { + type: 'element', + tagName: 'div', + properties: { className: [CODE_BLOCK_SCROLL_CONTAINER_CLASS] }, + children: [preElement] + }; +} + function createWrapper(header: Element, preElement: Element): Element { return { type: 'element', tagName: 'div', - properties: { className: ['code-block-wrapper'] }, - children: [header, preElement] + properties: { className: [CODE_BLOCK_WRAPPER_CLASS, RELATIVE_CLASS] }, + children: [header, createScrollContainer(preElement)] }; } diff --git a/tools/server/webui/src/lib/services/database.service.ts b/tools/server/webui/src/lib/services/database.service.ts new file mode 100644 index 000000000..0d5a9c1b9 --- /dev/null +++ b/tools/server/webui/src/lib/services/database.service.ts @@ -0,0 +1,368 @@ +import Dexie, { type EntityTable } from 'dexie'; +import { findDescendantMessages } from '$lib/utils'; + +class LlamacppDatabase extends Dexie { + conversations!: EntityTable; + messages!: EntityTable; + + constructor() { + super('LlamacppWebui'); + + this.version(1).stores({ + conversations: 'id, lastModified, currNode, name', + messages: 'id, convId, type, role, timestamp, parent, children' + }); + } +} + +const db = new LlamacppDatabase(); +import { v4 as uuid } from 'uuid'; +import { MessageRole } from '$lib/enums/chat'; + +export class DatabaseService { + /** + * + * + * Conversations + * + * + */ + + /** + * Creates a new conversation. + * + * @param name - Name of the conversation + * @returns The created conversation + */ + static async createConversation(name: string): Promise { + const conversation: DatabaseConversation = { + id: uuid(), + name, + lastModified: Date.now(), + currNode: '' + }; + + await db.conversations.add(conversation); + return conversation; + } + + /** + * + * + * Messages + * + * + */ + + /** + * Creates a new message branch by adding a message and updating parent/child relationships. + * Also updates the conversation's currNode to point to the new message. + * + * @param message - Message to add (without id) + * @param parentId - Parent message ID to attach to + * @returns The created message + */ + static async createMessageBranch( + message: Omit, + parentId: string | null + ): Promise { + return await db.transaction('rw', [db.conversations, db.messages], async () => { + // Handle null parent (root message case) + if (parentId !== null) { + const parentMessage = await db.messages.get(parentId); + if (!parentMessage) { + throw new Error(`Parent message ${parentId} not found`); + } + } + + const newMessage: DatabaseMessage = { + ...message, + id: uuid(), + parent: parentId, + toolCalls: message.toolCalls ?? '', + children: [] + }; + + await db.messages.add(newMessage); + + // Update parent's children array if parent exists + if (parentId !== null) { + const parentMessage = await db.messages.get(parentId); + if (parentMessage) { + await db.messages.update(parentId, { + children: [...parentMessage.children, newMessage.id] + }); + } + } + + await this.updateConversation(message.convId, { + currNode: newMessage.id + }); + + return newMessage; + }); + } + + /** + * Creates a root message for a new conversation. + * Root messages are not displayed but serve as the tree root for branching. + * + * @param convId - Conversation ID + * @returns The created root message + */ + static async createRootMessage(convId: string): Promise { + const rootMessage: DatabaseMessage = { + id: uuid(), + convId, + type: 'root', + timestamp: Date.now(), + role: MessageRole.SYSTEM, + content: '', + parent: null, + toolCalls: '', + children: [] + }; + + await db.messages.add(rootMessage); + return rootMessage.id; + } + + /** + * Creates a system prompt message for a conversation. + * + * @param convId - Conversation ID + * @param systemPrompt - The system prompt content (must be non-empty) + * @param parentId - Parent message ID (typically the root message) + * @returns The created system message + * @throws Error if systemPrompt is empty + */ + static async createSystemMessage( + convId: string, + systemPrompt: string, + parentId: string + ): Promise { + const trimmedPrompt = systemPrompt.trim(); + if (!trimmedPrompt) { + throw new Error('Cannot create system message with empty content'); + } + + const systemMessage: DatabaseMessage = { + id: uuid(), + convId, + type: MessageRole.SYSTEM, + timestamp: Date.now(), + role: MessageRole.SYSTEM, + content: trimmedPrompt, + parent: parentId, + children: [] + }; + + await db.messages.add(systemMessage); + + const parentMessage = await db.messages.get(parentId); + if (parentMessage) { + await db.messages.update(parentId, { + children: [...parentMessage.children, systemMessage.id] + }); + } + + return systemMessage; + } + + /** + * Deletes a conversation and all its messages. + * + * @param id - Conversation ID + */ + static async deleteConversation(id: string): Promise { + await db.transaction('rw', [db.conversations, db.messages], async () => { + await db.conversations.delete(id); + await db.messages.where('convId').equals(id).delete(); + }); + } + + /** + * Deletes a message and removes it from its parent's children array. + * + * @param messageId - ID of the message to delete + */ + static async deleteMessage(messageId: string): Promise { + await db.transaction('rw', db.messages, async () => { + const message = await db.messages.get(messageId); + if (!message) return; + + // Remove this message from its parent's children array + if (message.parent) { + const parent = await db.messages.get(message.parent); + if (parent) { + parent.children = parent.children.filter((childId: string) => childId !== messageId); + await db.messages.put(parent); + } + } + + // Delete the message + await db.messages.delete(messageId); + }); + } + + /** + * Deletes a message and all its descendant messages (cascading deletion). + * This removes the entire branch starting from the specified message. + * + * @param conversationId - ID of the conversation containing the message + * @param messageId - ID of the root message to delete (along with all descendants) + * @returns Array of all deleted message IDs + */ + static async deleteMessageCascading( + conversationId: string, + messageId: string + ): Promise { + return await db.transaction('rw', db.messages, async () => { + // Get all messages in the conversation to find descendants + const allMessages = await db.messages.where('convId').equals(conversationId).toArray(); + + // Find all descendant messages + const descendants = findDescendantMessages(allMessages, messageId); + const allToDelete = [messageId, ...descendants]; + + // Get the message to delete for parent cleanup + const message = await db.messages.get(messageId); + if (message && message.parent) { + const parent = await db.messages.get(message.parent); + if (parent) { + parent.children = parent.children.filter((childId: string) => childId !== messageId); + await db.messages.put(parent); + } + } + + // Delete all messages in the branch + await db.messages.bulkDelete(allToDelete); + + return allToDelete; + }); + } + + /** + * Gets all conversations, sorted by last modified time (newest first). + * + * @returns Array of conversations + */ + static async getAllConversations(): Promise { + return await db.conversations.orderBy('lastModified').reverse().toArray(); + } + + /** + * Gets a conversation by ID. + * + * @param id - Conversation ID + * @returns The conversation if found, otherwise undefined + */ + static async getConversation(id: string): Promise { + return await db.conversations.get(id); + } + + /** + * Gets all messages in a conversation, sorted by timestamp (oldest first). + * + * @param convId - Conversation ID + * @returns Array of messages in the conversation + */ + static async getConversationMessages(convId: string): Promise { + return await db.messages.where('convId').equals(convId).sortBy('timestamp'); + } + + /** + * Updates a conversation. + * + * @param id - Conversation ID + * @param updates - Partial updates to apply + * @returns Promise that resolves when the conversation is updated + */ + static async updateConversation( + id: string, + updates: Partial> + ): Promise { + await db.conversations.update(id, { + ...updates, + lastModified: Date.now() + }); + } + + /** + * + * + * Navigation + * + * + */ + + /** + * Updates the conversation's current node (active branch). + * This determines which conversation path is currently being viewed. + * + * @param convId - Conversation ID + * @param nodeId - Message ID to set as current node + */ + static async updateCurrentNode(convId: string, nodeId: string): Promise { + await this.updateConversation(convId, { + currNode: nodeId + }); + } + + /** + * Updates a message. + * + * @param id - Message ID + * @param updates - Partial updates to apply + * @returns Promise that resolves when the message is updated + */ + static async updateMessage( + id: string, + updates: Partial> + ): Promise { + await db.messages.update(id, updates); + } + + /** + * + * + * Import + * + * + */ + + /** + * Imports multiple conversations and their messages. + * Skips conversations that already exist. + * + * @param data - Array of { conv, messages } objects + */ + static async importConversations( + data: { conv: DatabaseConversation; messages: DatabaseMessage[] }[] + ): Promise<{ imported: number; skipped: number }> { + let importedCount = 0; + let skippedCount = 0; + + return await db.transaction('rw', [db.conversations, db.messages], async () => { + for (const item of data) { + const { conv, messages } = item; + + const existing = await db.conversations.get(conv.id); + if (existing) { + console.warn(`Conversation "${conv.name}" already exists, skipping...`); + skippedCount++; + continue; + } + + await db.conversations.add(conv); + for (const msg of messages) { + await db.messages.put(msg); + } + + importedCount++; + } + + return { imported: importedCount, skipped: skippedCount }; + }); + } +} diff --git a/tools/server/webui/src/lib/services/models.service.ts b/tools/server/webui/src/lib/services/models.service.ts new file mode 100644 index 000000000..7357c3f40 --- /dev/null +++ b/tools/server/webui/src/lib/services/models.service.ts @@ -0,0 +1,99 @@ +import { ServerModelStatus } from '$lib/enums'; +import { apiFetch, apiPost } from '$lib/utils/api-fetch'; + +export class ModelsService { + /** + * + * + * Listing + * + * + */ + + /** + * Fetch list of models from OpenAI-compatible endpoint. + * Works in both MODEL and ROUTER modes. + * + * @returns List of available models with basic metadata + */ + static async list(): Promise { + return apiFetch('/v1/models'); + } + + /** + * Fetch list of all models with detailed metadata (ROUTER mode). + * Returns models with load status, paths, and other metadata + * beyond what the OpenAI-compatible endpoint provides. + * + * @returns List of models with detailed status and configuration info + */ + static async listRouter(): Promise { + return apiFetch('/v1/models'); + } + + /** + * + * + * Load/Unload + * + * + */ + + /** + * Load a model (ROUTER mode only). + * Sends POST request to `/models/load`. Note: the endpoint returns success + * before loading completes — use polling to await actual load status. + * + * @param modelId - Model identifier to load + * @param extraArgs - Optional additional arguments to pass to the model instance + * @returns Load response from the server + */ + static async load(modelId: string, extraArgs?: string[]): Promise { + const payload: { model: string; extra_args?: string[] } = { model: modelId }; + if (extraArgs && extraArgs.length > 0) { + payload.extra_args = extraArgs; + } + + return apiPost('/models/load', payload); + } + + /** + * Unload a model (ROUTER mode only). + * Sends POST request to `/models/unload`. Note: the endpoint returns success + * before unloading completes — use polling to await actual unload status. + * + * @param modelId - Model identifier to unload + * @returns Unload response from the server + */ + static async unload(modelId: string): Promise { + return apiPost('/models/unload', { model: modelId }); + } + + /** + * + * + * Status + * + * + */ + + /** + * Check if a model is loaded based on its metadata. + * + * @param model - Model data entry from the API response + * @returns True if the model status is LOADED + */ + static isModelLoaded(model: ApiModelDataEntry): boolean { + return model.status.value === ServerModelStatus.LOADED; + } + + /** + * Check if a model is currently loading. + * + * @param model - Model data entry from the API response + * @returns True if the model status is LOADING + */ + static isModelLoading(model: ApiModelDataEntry): boolean { + return model.status.value === ServerModelStatus.LOADING; + } +} diff --git a/tools/server/webui/src/lib/services/parameter-sync.service.spec.ts b/tools/server/webui/src/lib/services/parameter-sync.service.spec.ts new file mode 100644 index 000000000..46cce5e7c --- /dev/null +++ b/tools/server/webui/src/lib/services/parameter-sync.service.spec.ts @@ -0,0 +1,148 @@ +import { describe, it, expect } from 'vitest'; +import { ParameterSyncService } from './parameter-sync.service'; + +describe('ParameterSyncService', () => { + describe('roundFloatingPoint', () => { + it('should fix JavaScript floating-point precision issues', () => { + // Test the specific values from the screenshot + const mockServerParams = { + top_p: 0.949999988079071, + min_p: 0.009999999776482582, + temperature: 0.800000011920929, + top_k: 40, + samplers: ['top_k', 'typ_p', 'top_p', 'min_p', 'temperature'] + }; + + const result = ParameterSyncService.extractServerDefaults({ + ...mockServerParams, + // Add other required fields to match the API type + n_predict: 512, + seed: -1, + dynatemp_range: 0.0, + dynatemp_exponent: 1.0, + xtc_probability: 0.0, + xtc_threshold: 0.1, + typ_p: 1.0, + repeat_last_n: 64, + repeat_penalty: 1.0, + presence_penalty: 0.0, + frequency_penalty: 0.0, + dry_multiplier: 0.0, + dry_base: 1.75, + dry_allowed_length: 2, + dry_penalty_last_n: -1, + mirostat: 0, + mirostat_tau: 5.0, + mirostat_eta: 0.1, + stop: [], + max_tokens: -1, + n_keep: 0, + n_discard: 0, + ignore_eos: false, + stream: true, + logit_bias: [], + n_probs: 0, + min_keep: 0, + grammar: '', + grammar_lazy: false, + grammar_triggers: [], + preserved_tokens: [], + chat_format: '', + reasoning_format: '', + reasoning_in_content: false, + thinking_forced_open: false, + 'speculative.n_max': 0, + 'speculative.n_min': 0, + 'speculative.p_min': 0.0, + timings_per_token: false, + post_sampling_probs: false, + lora: [], + top_n_sigma: 0.0, + dry_sequence_breakers: [] + } as ApiLlamaCppServerProps['default_generation_settings']['params']); + + // Check that the problematic floating-point values are rounded correctly + expect(result.top_p).toBe(0.95); + expect(result.min_p).toBe(0.01); + expect(result.temperature).toBe(0.8); + expect(result.top_k).toBe(40); // Integer should remain unchanged + expect(result.samplers).toBe('top_k;typ_p;top_p;min_p;temperature'); + }); + + it('should preserve non-numeric values', () => { + const mockServerParams = { + samplers: ['top_k', 'temperature'], + max_tokens: -1, + temperature: 0.7 + }; + + const result = ParameterSyncService.extractServerDefaults({ + ...mockServerParams, + // Minimal required fields + n_predict: 512, + seed: -1, + dynatemp_range: 0.0, + dynatemp_exponent: 1.0, + top_k: 40, + top_p: 0.95, + min_p: 0.05, + xtc_probability: 0.0, + xtc_threshold: 0.1, + typ_p: 1.0, + repeat_last_n: 64, + repeat_penalty: 1.0, + presence_penalty: 0.0, + frequency_penalty: 0.0, + dry_multiplier: 0.0, + dry_base: 1.75, + dry_allowed_length: 2, + dry_penalty_last_n: -1, + mirostat: 0, + mirostat_tau: 5.0, + mirostat_eta: 0.1, + stop: [], + n_keep: 0, + n_discard: 0, + ignore_eos: false, + stream: true, + logit_bias: [], + n_probs: 0, + min_keep: 0, + grammar: '', + grammar_lazy: false, + grammar_triggers: [], + preserved_tokens: [], + chat_format: '', + reasoning_format: '', + reasoning_in_content: false, + thinking_forced_open: false, + 'speculative.n_max': 0, + 'speculative.n_min': 0, + 'speculative.p_min': 0.0, + timings_per_token: false, + post_sampling_probs: false, + lora: [], + top_n_sigma: 0.0, + dry_sequence_breakers: [] + } as ApiLlamaCppServerProps['default_generation_settings']['params']); + + expect(result.samplers).toBe('top_k;temperature'); + expect(result.max_tokens).toBe(-1); + expect(result.temperature).toBe(0.7); + }); + + it('should merge webui settings from props when provided', () => { + const result = ParameterSyncService.extractServerDefaults(null, { + pasteLongTextToFileLen: 0, + pdfAsImage: true, + renderUserContentAsMarkdown: false, + theme: 'dark' + }); + + expect(result.pasteLongTextToFileLen).toBe(0); + expect(result.pdfAsImage).toBe(true); + expect(result.renderUserContentAsMarkdown).toBe(false); + expect(result.theme).toBeUndefined(); + }); + }); +}); diff --git a/tools/server/webui/src/lib/services/parameter-sync.service.ts b/tools/server/webui/src/lib/services/parameter-sync.service.ts new file mode 100644 index 000000000..6cb53d12d --- /dev/null +++ b/tools/server/webui/src/lib/services/parameter-sync.service.ts @@ -0,0 +1,400 @@ +import { normalizeFloatingPoint } from '$lib/utils'; +import { SyncableParameterType, ParameterSource } from '$lib/enums/settings'; + +type ParameterValue = string | number | boolean; +type ParameterRecord = Record; + +interface ParameterInfo { + value: string | number | boolean; + source: ParameterSource; + serverDefault?: string | number | boolean; + userOverride?: string | number | boolean; +} + +interface SyncableParameter { + key: string; + serverKey: string; + type: SyncableParameterType; + canSync: boolean; +} + +/** + * Mapping of webui setting keys to server parameter keys. + * Only parameters listed here can be synced from the server `/props` endpoint. + * Each entry defines the webui key, corresponding server key, value type, + * and whether sync is enabled. + */ +export const SYNCABLE_PARAMETERS: SyncableParameter[] = [ + { + key: 'temperature', + serverKey: 'temperature', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { key: 'top_k', serverKey: 'top_k', type: SyncableParameterType.NUMBER, canSync: true }, + { key: 'top_p', serverKey: 'top_p', type: SyncableParameterType.NUMBER, canSync: true }, + { key: 'min_p', serverKey: 'min_p', type: SyncableParameterType.NUMBER, canSync: true }, + { + key: 'dynatemp_range', + serverKey: 'dynatemp_range', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { + key: 'dynatemp_exponent', + serverKey: 'dynatemp_exponent', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { + key: 'xtc_probability', + serverKey: 'xtc_probability', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { + key: 'xtc_threshold', + serverKey: 'xtc_threshold', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { key: 'typ_p', serverKey: 'typ_p', type: SyncableParameterType.NUMBER, canSync: true }, + { + key: 'repeat_last_n', + serverKey: 'repeat_last_n', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { + key: 'repeat_penalty', + serverKey: 'repeat_penalty', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { + key: 'presence_penalty', + serverKey: 'presence_penalty', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { + key: 'frequency_penalty', + serverKey: 'frequency_penalty', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { + key: 'dry_multiplier', + serverKey: 'dry_multiplier', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { key: 'dry_base', serverKey: 'dry_base', type: SyncableParameterType.NUMBER, canSync: true }, + { + key: 'dry_allowed_length', + serverKey: 'dry_allowed_length', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { + key: 'dry_penalty_last_n', + serverKey: 'dry_penalty_last_n', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { key: 'max_tokens', serverKey: 'max_tokens', type: SyncableParameterType.NUMBER, canSync: true }, + { key: 'samplers', serverKey: 'samplers', type: SyncableParameterType.STRING, canSync: true }, + { + key: 'pasteLongTextToFileLen', + serverKey: 'pasteLongTextToFileLen', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { + key: 'pdfAsImage', + serverKey: 'pdfAsImage', + type: SyncableParameterType.BOOLEAN, + canSync: true + }, + { + key: 'showThoughtInProgress', + serverKey: 'showThoughtInProgress', + type: SyncableParameterType.BOOLEAN, + canSync: true + }, + { + key: 'keepStatsVisible', + serverKey: 'keepStatsVisible', + type: SyncableParameterType.BOOLEAN, + canSync: true + }, + { + key: 'showMessageStats', + serverKey: 'showMessageStats', + type: SyncableParameterType.BOOLEAN, + canSync: true + }, + { + key: 'askForTitleConfirmation', + serverKey: 'askForTitleConfirmation', + type: SyncableParameterType.BOOLEAN, + canSync: true + }, + { + key: 'disableAutoScroll', + serverKey: 'disableAutoScroll', + type: SyncableParameterType.BOOLEAN, + canSync: true + }, + { + key: 'renderUserContentAsMarkdown', + serverKey: 'renderUserContentAsMarkdown', + type: SyncableParameterType.BOOLEAN, + canSync: true + }, + { + key: 'autoMicOnEmpty', + serverKey: 'autoMicOnEmpty', + type: SyncableParameterType.BOOLEAN, + canSync: true + }, + { + key: 'pyInterpreterEnabled', + serverKey: 'pyInterpreterEnabled', + type: SyncableParameterType.BOOLEAN, + canSync: true + }, + { + key: 'enableContinueGeneration', + serverKey: 'enableContinueGeneration', + type: SyncableParameterType.BOOLEAN, + canSync: true + } +]; + +export class ParameterSyncService { + /** + * + * + * Extraction + * + * + */ + + /** + * Round floating-point numbers to avoid JavaScript precision issues. + * E.g., 0.1 + 0.2 = 0.30000000000000004 → 0.3 + * + * @param value - Parameter value to normalize + * @returns Precision-normalized value + */ + private static roundFloatingPoint(value: ParameterValue): ParameterValue { + return normalizeFloatingPoint(value) as ParameterValue; + } + + /** + * Extract server default parameters that can be synced from `/props` response. + * Handles both generation settings parameters and webui-specific settings. + * Converts samplers array to semicolon-delimited string for UI display. + * + * @param serverParams - Raw generation settings from server `/props` endpoint + * @param webuiSettings - Optional webui-specific settings from server + * @returns Record of extracted parameter key-value pairs with normalized precision + */ + static extractServerDefaults( + serverParams: ApiLlamaCppServerProps['default_generation_settings']['params'] | null, + webuiSettings?: Record + ): ParameterRecord { + const extracted: ParameterRecord = {}; + + if (serverParams) { + for (const param of SYNCABLE_PARAMETERS) { + if (param.canSync && param.serverKey in serverParams) { + const value = (serverParams as unknown as Record)[ + param.serverKey + ]; + if (value !== undefined) { + // Apply precision rounding to avoid JavaScript floating-point issues + extracted[param.key] = this.roundFloatingPoint(value); + } + } + } + + // Handle samplers array conversion to string + if (serverParams.samplers && Array.isArray(serverParams.samplers)) { + extracted.samplers = serverParams.samplers.join(';'); + } + } + + if (webuiSettings) { + for (const param of SYNCABLE_PARAMETERS) { + if (param.canSync && param.serverKey in webuiSettings) { + const value = webuiSettings[param.serverKey]; + if (value !== undefined) { + extracted[param.key] = this.roundFloatingPoint(value); + } + } + } + } + + return extracted; + } + + /** + * + * + * Merging + * + * + */ + + /** + * Merge server defaults with current user settings. + * User overrides always take priority — only parameters not in `userOverrides` + * set will be updated from server defaults. + * + * @param currentSettings - Current parameter values in the settings store + * @param serverDefaults - Default values extracted from server props + * @param userOverrides - Set of parameter keys explicitly overridden by the user + * @returns Merged parameter record with user overrides preserved + */ + static mergeWithServerDefaults( + currentSettings: ParameterRecord, + serverDefaults: ParameterRecord, + userOverrides: Set = new Set() + ): ParameterRecord { + const merged = { ...currentSettings }; + + for (const [key, serverValue] of Object.entries(serverDefaults)) { + // Only update if user hasn't explicitly overridden this parameter + if (!userOverrides.has(key)) { + merged[key] = this.roundFloatingPoint(serverValue); + } + } + + return merged; + } + + /** + * + * + * Info + * + * + */ + + /** + * Get parameter information including source and values. + * Used by ChatSettingsParameterSourceIndicator to display the correct badge + * (Custom vs Default) for each parameter in the settings UI. + * + * @param key - The parameter key to get info for + * @param currentValue - The current value of the parameter + * @param propsDefaults - Server default values from `/props` + * @param userOverrides - Set of parameter keys explicitly overridden by the user + * @returns Parameter info with source, server default, and user override values + */ + static getParameterInfo( + key: string, + currentValue: ParameterValue, + propsDefaults: ParameterRecord, + userOverrides: Set + ): ParameterInfo { + const hasPropsDefault = propsDefaults[key] !== undefined; + const isUserOverride = userOverrides.has(key); + + // Simple logic: either using default (from props) or custom (user override) + const source = isUserOverride ? ParameterSource.CUSTOM : ParameterSource.DEFAULT; + + return { + value: currentValue, + source, + serverDefault: hasPropsDefault ? propsDefaults[key] : undefined, // Keep same field name for compatibility + userOverride: isUserOverride ? currentValue : undefined + }; + } + + /** + * Check if a parameter can be synced from server. + * + * @param key - The parameter key to check + * @returns True if the parameter is in the syncable parameters list + */ + static canSyncParameter(key: string): boolean { + return SYNCABLE_PARAMETERS.some((param) => param.key === key && param.canSync); + } + + /** + * Get all syncable parameter keys. + * + * @returns Array of parameter keys that can be synced from server + */ + static getSyncableParameterKeys(): string[] { + return SYNCABLE_PARAMETERS.filter((param) => param.canSync).map((param) => param.key); + } + + /** + * Validate a server parameter value against its expected type. + * + * @param key - The parameter key to validate + * @param value - The value to validate + * @returns True if value matches the expected type for this parameter + */ + static validateServerParameter(key: string, value: ParameterValue): boolean { + const param = SYNCABLE_PARAMETERS.find((p) => p.key === key); + if (!param) return false; + + switch (param.type) { + case SyncableParameterType.NUMBER: + return typeof value === 'number' && !isNaN(value); + case SyncableParameterType.STRING: + return typeof value === 'string'; + case SyncableParameterType.BOOLEAN: + return typeof value === 'boolean'; + default: + return false; + } + } + + /** + * + * + * Diff + * + * + */ + + /** + * Create a diff between current settings and server defaults. + * Shows which parameters differ from server values, useful for debugging + * and for the "Reset to defaults" functionality. + * + * @param currentSettings - Current parameter values in the settings store + * @param serverDefaults - Default values extracted from server props + * @returns Record of parameter diffs with current value, server value, and whether they differ + */ + static createParameterDiff( + currentSettings: ParameterRecord, + serverDefaults: ParameterRecord + ): Record { + const diff: Record< + string, + { current: ParameterValue; server: ParameterValue; differs: boolean } + > = {}; + + for (const key of this.getSyncableParameterKeys()) { + const currentValue = currentSettings[key]; + const serverValue = serverDefaults[key]; + + if (serverValue !== undefined) { + diff[key] = { + current: currentValue, + server: serverValue, + differs: currentValue !== serverValue + }; + } + } + + return diff; + } +} diff --git a/tools/server/webui/src/lib/services/props.service.ts b/tools/server/webui/src/lib/services/props.service.ts new file mode 100644 index 000000000..7373b7e01 --- /dev/null +++ b/tools/server/webui/src/lib/services/props.service.ts @@ -0,0 +1,47 @@ +import { apiFetchWithParams } from '$lib/utils/api-fetch'; + +export class PropsService { + /** + * + * + * Fetching + * + * + */ + + /** + * Fetches global server properties from the `/props` endpoint. + * In MODEL mode, returns modalities for the single loaded model. + * In ROUTER mode, returns server-wide settings without model-specific modalities. + * + * @param autoload - If false, prevents automatic model loading (default: false) + * @returns Server properties including default generation settings and capabilities + * @throws {Error} If the request fails or returns invalid data + */ + static async fetch(autoload = false): Promise { + const params: Record = {}; + if (!autoload) { + params.autoload = 'false'; + } + + return apiFetchWithParams('./props', params, { authOnly: true }); + } + + /** + * Fetches server properties for a specific model (ROUTER mode only). + * Required in ROUTER mode because global `/props` does not include per-model modalities. + * + * @param modelId - The model ID to fetch properties for + * @param autoload - If false, prevents automatic model loading (default: false) + * @returns Server properties specific to the requested model + * @throws {Error} If the request fails, model not found, or model not loaded + */ + static async fetchForModel(modelId: string, autoload = false): Promise { + const params: Record = { model: modelId }; + if (!autoload) { + params.autoload = 'false'; + } + + return apiFetchWithParams('./props', params, { authOnly: true }); + } +} diff --git a/tools/server/webui/src/lib/stores/chat.svelte.ts b/tools/server/webui/src/lib/stores/chat.svelte.ts index 879b2f324..89de4f080 100644 --- a/tools/server/webui/src/lib/stores/chat.svelte.ts +++ b/tools/server/webui/src/lib/stores/chat.svelte.ts @@ -15,6 +15,7 @@ import { } from '$lib/utils'; import { SvelteMap } from 'svelte/reactivity'; import { DEFAULT_CONTEXT } from '$lib/constants/default-context'; +import { SYSTEM_MESSAGE_PLACEHOLDER } from '$lib/constants/ui'; /** * chatStore - Active AI interaction and streaming state management @@ -76,6 +77,10 @@ class ChatStore { private isStreamingActive = $state(false); private isEditModeActive = $state(false); private addFilesHandler: ((files: File[]) => void) | null = $state(null); + pendingEditMessageId = $state(null); + // Draft preservation for navigation (e.g., when adding system prompt from welcome page) + private _pendingDraftMessage = $state(''); + private _pendingDraftFiles = $state([]); // ───────────────────────────────────────────────────────────────────────────── // Loading State @@ -455,6 +460,166 @@ class ChatStore { } } + /** + * Adds a system message at the top of a conversation and triggers edit mode. + * The system message is inserted between root and the first message of the active branch. + * Creates a new conversation if one doesn't exist. + */ + async addSystemPrompt(): Promise { + let activeConv = conversationsStore.activeConversation; + + // Create conversation if needed + if (!activeConv) { + await conversationsStore.createConversation(); + activeConv = conversationsStore.activeConversation; + } + if (!activeConv) return; + + try { + // Get all messages to find the root + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); + const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); + let rootId: string; + + // Create root message if it doesn't exist + if (!rootMessage) { + rootId = await DatabaseService.createRootMessage(activeConv.id); + } else { + rootId = rootMessage.id; + } + + // Check if there's already a system message as root's child + const existingSystemMessage = allMessages.find( + (m) => m.role === 'system' && m.parent === rootId + ); + + if (existingSystemMessage) { + // If system message exists, just trigger edit mode on it + this.pendingEditMessageId = existingSystemMessage.id; + + // Make sure it's in active messages at the beginning + if (!conversationsStore.activeMessages.some((m) => m.id === existingSystemMessage.id)) { + conversationsStore.activeMessages.unshift(existingSystemMessage); + } + return; + } + + // Find the first message of the active branch (child of root that's in activeMessages) + const activeMessages = conversationsStore.activeMessages; + const firstActiveMessage = activeMessages.find((m) => m.parent === rootId); + + // Create new system message with placeholder content (will be edited by user) + const systemMessage = await DatabaseService.createSystemMessage( + activeConv.id, + SYSTEM_MESSAGE_PLACEHOLDER, + rootId + ); + + // If there's a first message in the active branch, re-parent it to the system message + if (firstActiveMessage) { + // Update the first message's parent to be the system message + await DatabaseService.updateMessage(firstActiveMessage.id, { + parent: systemMessage.id + }); + + // Update the system message's children to include the first message + await DatabaseService.updateMessage(systemMessage.id, { + children: [firstActiveMessage.id] + }); + + // Remove first message from root's children + const updatedRootChildren = rootMessage + ? rootMessage.children.filter((id: string) => id !== firstActiveMessage.id) + : []; + // Note: system message was already added to root's children by createSystemMessage + await DatabaseService.updateMessage(rootId, { + children: [ + ...updatedRootChildren.filter((id: string) => id !== systemMessage.id), + systemMessage.id + ] + }); + + // Update local state + const firstMsgIndex = conversationsStore.findMessageIndex(firstActiveMessage.id); + if (firstMsgIndex !== -1) { + conversationsStore.updateMessageAtIndex(firstMsgIndex, { parent: systemMessage.id }); + } + } + + // Add system message to active messages at the beginning + conversationsStore.activeMessages.unshift(systemMessage); + + // Set pending edit message ID to trigger edit mode + this.pendingEditMessageId = systemMessage.id; + + conversationsStore.updateConversationTimestamp(); + } catch (error) { + console.error('Failed to add system prompt:', error); + } + } + + /** + * Removes a system message placeholder without deleting its children. + * Re-parents children back to the root message. + * If this is a new empty conversation (only root + system placeholder), deletes the entire conversation. + * @returns true if the entire conversation was deleted, false otherwise + */ + async removeSystemPromptPlaceholder(messageId: string): Promise { + const activeConv = conversationsStore.activeConversation; + if (!activeConv) return false; + + try { + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); + const systemMessage = allMessages.find((m) => m.id === messageId); + if (!systemMessage || systemMessage.role !== 'system') return false; + + const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); + if (!rootMessage) return false; + + // Check if this is a new empty conversation (only root + system placeholder) + const isEmptyConversation = allMessages.length === 2 && systemMessage.children.length === 0; + + if (isEmptyConversation) { + // Delete the entire conversation + await conversationsStore.deleteConversation(activeConv.id); + return true; + } + + // Re-parent system message's children to root + for (const childId of systemMessage.children) { + await DatabaseService.updateMessage(childId, { parent: rootMessage.id }); + + // Update local state + const childIndex = conversationsStore.findMessageIndex(childId); + if (childIndex !== -1) { + conversationsStore.updateMessageAtIndex(childIndex, { parent: rootMessage.id }); + } + } + + // Update root's children: remove system message, add system's children + const newRootChildren = [ + ...rootMessage.children.filter((id: string) => id !== messageId), + ...systemMessage.children + ]; + await DatabaseService.updateMessage(rootMessage.id, { children: newRootChildren }); + + // Delete the system message (without cascade) + await DatabaseService.deleteMessage(messageId); + + // Remove from active messages + const systemIndex = conversationsStore.findMessageIndex(messageId); + if (systemIndex !== -1) { + conversationsStore.activeMessages.splice(systemIndex, 1); + } + + conversationsStore.updateConversationTimestamp(); + return false; + } catch (error) { + console.error('Failed to remove system prompt placeholder:', error); + return false; + } + } + private async createAssistantMessage(parentId?: string): Promise { const activeConv = conversationsStore.activeConversation; if (!activeConv) return null; @@ -916,6 +1081,28 @@ class ChatStore { if (!activeConv) return { totalCount: 0, userMessages: 0, assistantMessages: 0, messageTypes: [] }; const allMessages = await conversationsStore.getConversationMessages(activeConv.id); + const messageToDelete = allMessages.find((m) => m.id === messageId); + + // For system messages, don't count descendants as they will be preserved (reparented to root) + if (messageToDelete?.role === 'system') { + const messagesToDelete = allMessages.filter((m) => m.id === messageId); + let userMessages = 0, + assistantMessages = 0; + const messageTypes: string[] = []; + + for (const msg of messagesToDelete) { + if (msg.role === 'user') { + userMessages++; + if (!messageTypes.includes('user message')) messageTypes.push('user message'); + } else if (msg.role === 'assistant') { + assistantMessages++; + if (!messageTypes.includes('assistant response')) messageTypes.push('assistant response'); + } + } + + return { totalCount: 1, userMessages, assistantMessages, messageTypes }; + } + const descendants = findDescendantMessages(allMessages, messageId); const allToDelete = [messageId, ...descendants]; const messagesToDelete = allMessages.filter((m) => allToDelete.includes(m.id)); @@ -1381,6 +1568,31 @@ class ChatStore { return this.addFilesHandler; } + savePendingDraft(message: string, files: ChatUploadedFile[]): void { + this._pendingDraftMessage = message; + this._pendingDraftFiles = [...files]; + } + + consumePendingDraft(): { message: string; files: ChatUploadedFile[] } | null { + if (!this._pendingDraftMessage && this._pendingDraftFiles.length === 0) { + return null; + } + + const draft = { + message: this._pendingDraftMessage, + files: [...this._pendingDraftFiles] + }; + + this._pendingDraftMessage = ''; + this._pendingDraftFiles = []; + + return draft; + } + + hasPendingDraft(): boolean { + return Boolean(this._pendingDraftMessage) || this._pendingDraftFiles.length > 0; + } + public getAllLoadingChats(): string[] { return Array.from(this.chatLoadingStates.keys()); } @@ -1485,3 +1697,7 @@ export const isEditing = () => chatStore.isEditing(); export const isLoading = () => chatStore.isLoading; export const setEditModeActive = (handler: (files: File[]) => void) => chatStore.setEditModeActive(handler); +export const pendingEditMessageId = () => chatStore.pendingEditMessageId; +export const clearPendingEditMessageId = () => (chatStore.pendingEditMessageId = null); +export const removeSystemPromptPlaceholder = (messageId: string) => + chatStore.removeSystemPromptPlaceholder(messageId); diff --git a/tools/server/webui/src/lib/types/api.d.ts b/tools/server/webui/src/lib/types/api.d.ts index 714509f02..307e3b71d 100644 --- a/tools/server/webui/src/lib/types/api.d.ts +++ b/tools/server/webui/src/lib/types/api.d.ts @@ -1,8 +1,19 @@ -import type { ServerModelStatus, ServerRole } from '$lib/enums'; -import type { ChatMessagePromptProgress } from './chat'; +import type { ContentPartType, ServerModelStatus, ServerRole } from '$lib/enums'; +import type { ChatMessagePromptProgress, ChatRole } from './chat'; + +export interface ApiChatCompletionToolFunction { + name: string; + description?: string; + parameters: Record; +} + +export interface ApiChatCompletionTool { + type: 'function'; + function: ApiChatCompletionToolFunction; +} export interface ApiChatMessageContentPart { - type: 'text' | 'image_url' | 'input_audio'; + type: ContentPartType; text?: string; image_url?: { url: string; @@ -34,6 +45,8 @@ export interface ApiErrorResponse { export interface ApiChatMessageData { role: ChatRole; content: string | ApiChatMessageContentPart[]; + tool_calls?: ApiChatCompletionToolCall[]; + tool_call_id?: string; timestamp?: number; } @@ -188,6 +201,7 @@ export interface ApiChatCompletionRequest { stream?: boolean; model?: string; return_progress?: boolean; + tools?: ApiChatCompletionTool[]; // Reasoning parameters reasoning_format?: string; // Generation parameters @@ -247,6 +261,7 @@ export interface ApiChatCompletionStreamChunk { model?: string; tool_calls?: ApiChatCompletionToolCallDelta[]; }; + finish_reason?: string | null; }>; timings?: { prompt_n?: number; @@ -267,8 +282,9 @@ export interface ApiChatCompletionResponse { content: string; reasoning_content?: string; model?: string; - tool_calls?: ApiChatCompletionToolCallDelta[]; + tool_calls?: ApiChatCompletionToolCall[]; }; + finish_reason?: string | null; }>; } @@ -335,7 +351,7 @@ export interface ApiProcessingState { tokensDecoded: number; tokensRemaining: number; contextUsed: number; - contextTotal: number; + contextTotal: number | null; outputTokensUsed: number; // Total output tokens (thinking + regular content) outputTokensMax: number; // Max output tokens allowed temperature: number; diff --git a/tools/server/webui/src/lib/types/models.d.ts b/tools/server/webui/src/lib/types/models.d.ts index ef44a2cb6..505867a1f 100644 --- a/tools/server/webui/src/lib/types/models.d.ts +++ b/tools/server/webui/src/lib/types/models.d.ts @@ -1,8 +1,5 @@ import type { ApiModelDataEntry, ApiModelDetails } from '$lib/types/api'; -/** - * Model modalities - vision and audio capabilities - */ export interface ModelModalities { vision: boolean; audio: boolean; @@ -14,8 +11,15 @@ export interface ModelOption { model: string; description?: string; capabilities: string[]; - /** Model modalities from /props endpoint */ modalities?: ModelModalities; details?: ApiModelDetails['details']; meta?: ApiModelDataEntry['meta']; } + +/** + * Modality capabilities for file validation + */ +export interface ModalityCapabilities { + hasVision: boolean; + hasAudio: boolean; +} diff --git a/tools/server/webui/src/lib/utils/abort.ts b/tools/server/webui/src/lib/utils/abort.ts new file mode 100644 index 000000000..fc4f31ec6 --- /dev/null +++ b/tools/server/webui/src/lib/utils/abort.ts @@ -0,0 +1,151 @@ +/** + * Abort Signal Utilities + * + * Provides utilities for consistent AbortSignal propagation across the application. + * These utilities help ensure that async operations can be properly cancelled + * when needed (e.g., user stops generation, navigates away, etc.). + */ + +/** + * Throws an AbortError if the signal is aborted. + * Use this at the start of async operations to fail fast. + * + * @param signal - Optional AbortSignal to check + * @throws DOMException with name 'AbortError' if signal is aborted + * + * @example + * ```ts + * async function fetchData(signal?: AbortSignal) { + * throwIfAborted(signal); + * // ... proceed with operation + * } + * ``` + */ +export function throwIfAborted(signal?: AbortSignal): void { + if (signal?.aborted) { + throw new DOMException('Operation was aborted', 'AbortError'); + } +} + +/** + * Checks if an error is an AbortError. + * Use this to distinguish between user-initiated cancellation and actual errors. + * + * @param error - Error to check + * @returns true if the error is an AbortError + * + * @example + * ```ts + * try { + * await fetchData(signal); + * } catch (error) { + * if (isAbortError(error)) { + * // User cancelled - no error dialog needed + * return; + * } + * // Handle actual error + * } + * ``` + */ +export function isAbortError(error: unknown): boolean { + if (error instanceof DOMException && error.name === 'AbortError') { + return true; + } + if (error instanceof Error && error.name === 'AbortError') { + return true; + } + return false; +} + +/** + * Creates a new AbortController that is linked to one or more parent signals. + * When any parent signal aborts, the returned controller also aborts. + * + * Useful for creating child operations that should be cancelled when + * either the parent operation or their own timeout/condition triggers. + * + * @param signals - Parent signals to link to (undefined signals are ignored) + * @returns A new AbortController linked to all provided signals + * + * @example + * ```ts + * // Link to user's abort signal and add a timeout + * const linked = createLinkedController(userSignal, timeoutSignal); + * await fetch(url, { signal: linked.signal }); + * ``` + */ +export function createLinkedController(...signals: (AbortSignal | undefined)[]): AbortController { + const controller = new AbortController(); + + for (const signal of signals) { + if (!signal) continue; + + // If already aborted, abort immediately + if (signal.aborted) { + controller.abort(signal.reason); + return controller; + } + + // Link to parent signal + signal.addEventListener('abort', () => controller.abort(signal.reason), { once: true }); + } + + return controller; +} + +/** + * Creates an AbortSignal that times out after the specified duration. + * + * @param ms - Timeout duration in milliseconds + * @returns AbortSignal that will abort after the timeout + * + * @example + * ```ts + * const signal = createTimeoutSignal(5000); // 5 second timeout + * await fetch(url, { signal }); + * ``` + */ +export function createTimeoutSignal(ms: number): AbortSignal { + return AbortSignal.timeout(ms); +} + +/** + * Wraps a promise to reject if the signal is aborted. + * Useful for making non-abortable promises respect an AbortSignal. + * + * @param promise - Promise to wrap + * @param signal - AbortSignal to respect + * @returns Promise that rejects with AbortError if signal aborts + * + * @example + * ```ts + * // Make a non-abortable operation respect abort signal + * const result = await withAbortSignal( + * someNonAbortableOperation(), + * signal + * ); + * ``` + */ +export async function withAbortSignal(promise: Promise, signal?: AbortSignal): Promise { + if (!signal) return promise; + + throwIfAborted(signal); + + return new Promise((resolve, reject) => { + const abortHandler = () => { + reject(new DOMException('Operation was aborted', 'AbortError')); + }; + + signal.addEventListener('abort', abortHandler, { once: true }); + + promise + .then((value) => { + signal.removeEventListener('abort', abortHandler); + resolve(value); + }) + .catch((error) => { + signal.removeEventListener('abort', abortHandler); + reject(error); + }); + }); +} diff --git a/tools/server/webui/src/lib/utils/api-fetch.ts b/tools/server/webui/src/lib/utils/api-fetch.ts new file mode 100644 index 000000000..28757a966 --- /dev/null +++ b/tools/server/webui/src/lib/utils/api-fetch.ts @@ -0,0 +1,154 @@ +import { base } from '$app/paths'; +import { getJsonHeaders, getAuthHeaders } from './api-headers'; + +/** + * API Fetch Utilities + * + * Provides common fetch patterns used across services: + * - Automatic JSON headers + * - Error handling with proper error messages + * - Base path resolution + */ + +export interface ApiFetchOptions extends Omit { + /** + * Use auth-only headers (no Content-Type). + * Default: false (uses JSON headers with Content-Type: application/json) + */ + authOnly?: boolean; + /** + * Additional headers to merge with default headers. + */ + headers?: Record; +} + +/** + * Fetch JSON data from an API endpoint with standard headers and error handling. + * + * @param path - API path (will be prefixed with base path) + * @param options - Fetch options with additional authOnly flag + * @returns Parsed JSON response + * @throws Error with formatted message on failure + * + * @example + * ```typescript + * // GET request + * const models = await apiFetch('/v1/models'); + * + * // POST request + * const result = await apiFetch('/models/load', { + * method: 'POST', + * body: JSON.stringify({ model: 'gpt-4' }) + * }); + * ``` + */ +export async function apiFetch(path: string, options: ApiFetchOptions = {}): Promise { + const { authOnly = false, headers: customHeaders, ...fetchOptions } = options; + + const baseHeaders = authOnly ? getAuthHeaders() : getJsonHeaders(); + const headers = { ...baseHeaders, ...customHeaders }; + + const url = path.startsWith('http://') || path.startsWith('https://') ? path : `${base}${path}`; + + const response = await fetch(url, { + ...fetchOptions, + headers + }); + + if (!response.ok) { + const errorMessage = await parseErrorMessage(response); + throw new Error(errorMessage); + } + + return response.json() as Promise; +} + +/** + * Fetch with URL constructed from base URL and query parameters. + * + * @param basePath - Base API path + * @param params - Query parameters to append + * @param options - Fetch options + * @returns Parsed JSON response + * + * @example + * ```typescript + * const props = await apiFetchWithParams('./props', { + * model: 'gpt-4', + * autoload: 'false' + * }); + * ``` + */ +export async function apiFetchWithParams( + basePath: string, + params: Record, + options: ApiFetchOptions = {} +): Promise { + const url = new URL(basePath, window.location.href); + + for (const [key, value] of Object.entries(params)) { + if (value !== undefined && value !== null) { + url.searchParams.set(key, value); + } + } + + const { authOnly = false, headers: customHeaders, ...fetchOptions } = options; + + const baseHeaders = authOnly ? getAuthHeaders() : getJsonHeaders(); + const headers = { ...baseHeaders, ...customHeaders }; + + const response = await fetch(url.toString(), { + ...fetchOptions, + headers + }); + + if (!response.ok) { + const errorMessage = await parseErrorMessage(response); + throw new Error(errorMessage); + } + + return response.json() as Promise; +} + +/** + * POST JSON data to an API endpoint. + * + * @param path - API path + * @param body - Request body (will be JSON stringified) + * @param options - Additional fetch options + * @returns Parsed JSON response + */ +export async function apiPost( + path: string, + body: B, + options: ApiFetchOptions = {} +): Promise { + return apiFetch(path, { + method: 'POST', + body: JSON.stringify(body), + ...options + }); +} + +/** + * Parse error message from a failed response. + * Tries to extract error message from JSON body, falls back to status text. + */ +async function parseErrorMessage(response: Response): Promise { + try { + const errorData = await response.json(); + if (errorData?.error?.message) { + return errorData.error.message; + } + if (errorData?.error && typeof errorData.error === 'string') { + return errorData.error; + } + if (errorData?.message) { + return errorData.message; + } + } catch { + // JSON parsing failed, use status text + } + + return `Request failed: ${response.status} ${response.statusText}`; +} diff --git a/tools/server/webui/src/lib/utils/branching.ts b/tools/server/webui/src/lib/utils/branching.ts index 3be56047a..ee3a505ee 100644 --- a/tools/server/webui/src/lib/utils/branching.ts +++ b/tools/server/webui/src/lib/utils/branching.ts @@ -15,6 +15,8 @@ * └── message 5 (assistant) */ +import { MessageRole } from '$lib/enums/chat'; + /** * Filters messages to get the conversation path from root to a specific leaf node. * If the leafNodeId doesn't exist, returns the path with the latest timestamp. @@ -65,8 +67,13 @@ export function filterByLeafNodeId( currentNode = nodeMap.get(currentNode.parent); } - // Sort by timestamp to get chronological order (root to leaf) - result.sort((a, b) => a.timestamp - b.timestamp); + // Sort: system messages first, then by timestamp + result.sort((a, b) => { + if (a.role === MessageRole.SYSTEM && b.role !== MessageRole.SYSTEM) return -1; + if (a.role !== MessageRole.SYSTEM && b.role === MessageRole.SYSTEM) return 1; + + return a.timestamp - b.timestamp; + }); return result; } diff --git a/tools/server/webui/src/lib/utils/browser-only.ts b/tools/server/webui/src/lib/utils/browser-only.ts index 0af800638..27d2be4aa 100644 --- a/tools/server/webui/src/lib/utils/browser-only.ts +++ b/tools/server/webui/src/lib/utils/browser-only.ts @@ -23,7 +23,7 @@ export { } from './pdf-processing'; // File conversion utilities (depends on pdf-processing) -export { parseFilesToMessageExtras, type FileProcessingResult } from './convert-files-to-extra'; +export { parseFilesToMessageExtras } from './convert-files-to-extra'; // File upload processing utilities (depends on pdf-processing, svg-to-png, webp-to-png) export { processFilesToChatUploaded } from './process-uploaded-files'; diff --git a/tools/server/webui/src/lib/utils/cache-ttl.ts b/tools/server/webui/src/lib/utils/cache-ttl.ts new file mode 100644 index 000000000..9d1f00582 --- /dev/null +++ b/tools/server/webui/src/lib/utils/cache-ttl.ts @@ -0,0 +1,293 @@ +const DEFAULT_CACHE_TTL_MS = 5 * 60 * 1000; +const DEFAULT_CACHE_MAX_ENTRIES = 100; + +/** + * TTL Cache - Time-To-Live cache implementation for memory optimization + * + * Provides automatic expiration of cached entries to prevent memory bloat + * in long-running sessions. + * + * @example + * ```ts + * const cache = new TTLCache({ ttlMs: 5 * 60 * 1000 }); // 5 minutes + * cache.set('key', data); + * const value = cache.get('key'); // null if expired + * ``` + */ + +export interface TTLCacheOptions { + /** Time-to-live in milliseconds. Default: 5 minutes */ + ttlMs?: number; + /** Maximum number of entries. Oldest entries are evicted when exceeded. Default: 100 */ + maxEntries?: number; + /** Callback when an entry expires or is evicted */ + onEvict?: (key: string, value: unknown) => void; +} + +interface CacheEntry { + value: T; + expiresAt: number; + lastAccessed: number; +} + +export class TTLCache { + private cache = new Map>(); + private readonly ttlMs: number; + private readonly maxEntries: number; + private readonly onEvict?: (key: string, value: unknown) => void; + + constructor(options: TTLCacheOptions = {}) { + this.ttlMs = options.ttlMs ?? DEFAULT_CACHE_TTL_MS; + this.maxEntries = options.maxEntries ?? DEFAULT_CACHE_MAX_ENTRIES; + this.onEvict = options.onEvict; + } + + /** + * Get a value from cache. Returns null if expired or not found. + */ + get(key: K): V | null { + const entry = this.cache.get(key); + if (!entry) return null; + + if (Date.now() > entry.expiresAt) { + this.delete(key); + return null; + } + + // Update last accessed time for LRU-like behavior + entry.lastAccessed = Date.now(); + return entry.value; + } + + /** + * Set a value in cache with TTL. + */ + set(key: K, value: V, customTtlMs?: number): void { + // Evict oldest entries if at capacity + if (this.cache.size >= this.maxEntries && !this.cache.has(key)) { + this.evictOldest(); + } + + const ttl = customTtlMs ?? this.ttlMs; + const now = Date.now(); + + this.cache.set(key, { + value, + expiresAt: now + ttl, + lastAccessed: now + }); + } + + /** + * Check if key exists and is not expired. + */ + has(key: K): boolean { + const entry = this.cache.get(key); + if (!entry) return false; + + if (Date.now() > entry.expiresAt) { + this.delete(key); + return false; + } + + return true; + } + + /** + * Delete a specific key from cache. + */ + delete(key: K): boolean { + const entry = this.cache.get(key); + if (entry && this.onEvict) { + this.onEvict(key, entry.value); + } + return this.cache.delete(key); + } + + /** + * Clear all entries from cache. + */ + clear(): void { + if (this.onEvict) { + for (const [key, entry] of this.cache) { + this.onEvict(key, entry.value); + } + } + this.cache.clear(); + } + + /** + * Get the number of entries (including potentially expired ones). + */ + get size(): number { + return this.cache.size; + } + + /** + * Remove all expired entries from cache. + * Call periodically for proactive cleanup. + */ + prune(): number { + const now = Date.now(); + let pruned = 0; + + for (const [key, entry] of this.cache) { + if (now > entry.expiresAt) { + this.delete(key); + pruned++; + } + } + + return pruned; + } + + /** + * Get all valid (non-expired) keys. + */ + keys(): K[] { + const now = Date.now(); + const validKeys: K[] = []; + + for (const [key, entry] of this.cache) { + if (now <= entry.expiresAt) { + validKeys.push(key); + } + } + + return validKeys; + } + + /** + * Evict the oldest (least recently accessed) entry. + */ + private evictOldest(): void { + let oldestKey: K | null = null; + let oldestTime = Infinity; + + for (const [key, entry] of this.cache) { + if (entry.lastAccessed < oldestTime) { + oldestTime = entry.lastAccessed; + oldestKey = key; + } + } + + if (oldestKey !== null) { + this.delete(oldestKey); + } + } + + /** + * Refresh TTL for an existing entry without changing the value. + */ + touch(key: K): boolean { + const entry = this.cache.get(key); + if (!entry) return false; + + const now = Date.now(); + if (now > entry.expiresAt) { + this.delete(key); + return false; + } + + entry.expiresAt = now + this.ttlMs; + entry.lastAccessed = now; + return true; + } +} + +/** + * Reactive TTL Map for Svelte stores + * Wraps SvelteMap with TTL functionality + */ +export class ReactiveTTLMap { + private entries = $state>>(new Map()); + private readonly ttlMs: number; + private readonly maxEntries: number; + + constructor(options: TTLCacheOptions = {}) { + this.ttlMs = options.ttlMs ?? DEFAULT_CACHE_TTL_MS; + this.maxEntries = options.maxEntries ?? DEFAULT_CACHE_MAX_ENTRIES; + } + + get(key: K): V | null { + const entry = this.entries.get(key); + if (!entry) return null; + + if (Date.now() > entry.expiresAt) { + this.entries.delete(key); + return null; + } + + entry.lastAccessed = Date.now(); + return entry.value; + } + + set(key: K, value: V, customTtlMs?: number): void { + if (this.entries.size >= this.maxEntries && !this.entries.has(key)) { + this.evictOldest(); + } + + const ttl = customTtlMs ?? this.ttlMs; + const now = Date.now(); + + this.entries.set(key, { + value, + expiresAt: now + ttl, + lastAccessed: now + }); + } + + has(key: K): boolean { + const entry = this.entries.get(key); + if (!entry) return false; + + if (Date.now() > entry.expiresAt) { + this.entries.delete(key); + return false; + } + + return true; + } + + delete(key: K): boolean { + return this.entries.delete(key); + } + + clear(): void { + this.entries.clear(); + } + + get size(): number { + return this.entries.size; + } + + prune(): number { + const now = Date.now(); + let pruned = 0; + + for (const [key, entry] of this.entries) { + if (now > entry.expiresAt) { + this.entries.delete(key); + pruned++; + } + } + + return pruned; + } + + private evictOldest(): void { + let oldestKey: K | null = null; + let oldestTime = Infinity; + + for (const [key, entry] of this.entries) { + if (entry.lastAccessed < oldestTime) { + oldestTime = entry.lastAccessed; + oldestKey = key; + } + } + + if (oldestKey !== null) { + this.entries.delete(oldestKey); + } + } +} diff --git a/tools/server/webui/src/lib/utils/code.ts b/tools/server/webui/src/lib/utils/code.ts new file mode 100644 index 000000000..67efc6b27 --- /dev/null +++ b/tools/server/webui/src/lib/utils/code.ts @@ -0,0 +1,85 @@ +import hljs from 'highlight.js'; +import { + NEWLINE, + DEFAULT_LANGUAGE, + LANG_PATTERN, + AMPERSAND_REGEX, + LT_REGEX, + GT_REGEX, + FENCE_PATTERN +} from '$lib/constants/code'; + +export interface IncompleteCodeBlock { + language: string; + code: string; + openingIndex: number; +} + +/** + * Highlights code using highlight.js + * @param code - The code to highlight + * @param language - The programming language + * @returns HTML string with syntax highlighting + */ +export function highlightCode(code: string, language: string): string { + if (!code) return ''; + + try { + const lang = language.toLowerCase(); + const isSupported = hljs.getLanguage(lang); + + if (isSupported) { + return hljs.highlight(code, { language: lang }).value; + } else { + return hljs.highlightAuto(code).value; + } + } catch { + // Fallback to escaped plain text + return code + .replace(AMPERSAND_REGEX, '&') + .replace(LT_REGEX, '<') + .replace(GT_REGEX, '>'); + } +} + +/** + * Detects if markdown ends with an incomplete code block (opened but not closed). + * Returns the code block info if found, null otherwise. + * @param markdown - The raw markdown string to check + * @returns IncompleteCodeBlock info or null + */ +export function detectIncompleteCodeBlock(markdown: string): IncompleteCodeBlock | null { + // Count all code fences in the markdown + // A code block is incomplete if there's an odd number of ``` fences + const fencePattern = new RegExp(FENCE_PATTERN.source, FENCE_PATTERN.flags); + const fences: number[] = []; + let fenceMatch; + + while ((fenceMatch = fencePattern.exec(markdown)) !== null) { + // Store the position after the ``` + const pos = fenceMatch[0].startsWith(NEWLINE) ? fenceMatch.index + 1 : fenceMatch.index; + fences.push(pos); + } + + // If even number of fences (including 0), all code blocks are closed + if (fences.length % 2 === 0) { + return null; + } + + // Odd number means last code block is incomplete + // The last fence is the opening of the incomplete block + const openingIndex = fences[fences.length - 1]; + const afterOpening = markdown.slice(openingIndex + 3); + + // Extract language and code content + const langMatch = afterOpening.match(LANG_PATTERN); + const language = langMatch?.[1] || DEFAULT_LANGUAGE; + const codeStartIndex = openingIndex + 3 + (langMatch?.[0]?.length ?? 0); + const code = markdown.slice(codeStartIndex); + + return { + language, + code, + openingIndex + }; +} diff --git a/tools/server/webui/src/lib/utils/data-url.ts b/tools/server/webui/src/lib/utils/data-url.ts new file mode 100644 index 000000000..6f55be793 --- /dev/null +++ b/tools/server/webui/src/lib/utils/data-url.ts @@ -0,0 +1,10 @@ +/** + * Creates a base64 data URL from MIME type and base64-encoded data. + * + * @param mimeType - The MIME type (e.g., 'image/png', 'audio/mp3') + * @param base64Data - The base64-encoded data + * @returns A data URL string in format 'data:{mimeType};base64,{data}' + */ +export function createBase64DataUrl(mimeType: string, base64Data: string): string { + return `data:${mimeType};base64,${base64Data}`; +} diff --git a/tools/server/webui/src/lib/utils/debounce.ts b/tools/server/webui/src/lib/utils/debounce.ts new file mode 100644 index 000000000..90a5a0178 --- /dev/null +++ b/tools/server/webui/src/lib/utils/debounce.ts @@ -0,0 +1,22 @@ +/** + * @param fn - The function to debounce + * @param delay - The delay in milliseconds + * @returns A debounced version of the function + */ +export function debounce) => void>( + fn: T, + delay: number +): (...args: Parameters) => void { + let timeoutId: ReturnType | null = null; + + return (...args: Parameters) => { + if (timeoutId) { + clearTimeout(timeoutId); + } + + timeoutId = setTimeout(() => { + fn(...args); + timeoutId = null; + }, delay); + }; +} diff --git a/tools/server/webui/src/lib/utils/image-error-fallback.ts b/tools/server/webui/src/lib/utils/image-error-fallback.ts new file mode 100644 index 000000000..6e3260f4a --- /dev/null +++ b/tools/server/webui/src/lib/utils/image-error-fallback.ts @@ -0,0 +1,10 @@ +/** + * Simplified HTML fallback for external images that fail to load. + * Displays a centered message with a link to open the image in a new tab. + */ +export function getImageErrorFallbackHtml(src: string): string { + return `
+ Image cannot be displayed + (open link) +
`; +} diff --git a/tools/server/webui/src/lib/utils/index.ts b/tools/server/webui/src/lib/utils/index.ts index 588167b8c..5eb2bbaea 100644 --- a/tools/server/webui/src/lib/utils/index.ts +++ b/tools/server/webui/src/lib/utils/index.ts @@ -9,6 +9,7 @@ // API utilities export { getAuthHeaders, getJsonHeaders } from './api-headers'; +export { apiFetch, apiFetchWithParams, apiPost, type ApiFetchOptions } from './api-fetch'; export { validateApiKey } from './api-key-validation'; // Attachment utilities @@ -75,8 +76,7 @@ export { maskInlineLaTeX, preprocessLaTeX } from './latex-protection'; export { isFileTypeSupportedByModel, filterFilesByModalities, - generateModalityErrorMessage, - type ModalityCapabilities + generateModalityErrorMessage } from './modality-file-validation'; // Model name utilities @@ -93,3 +93,6 @@ export { getLanguageFromFilename } from './syntax-highlight-language'; // Text file utilities export { isTextFileByName, readFileAsText, isLikelyTextFile } from './text-files'; + +// Image error fallback utilities +export { getImageErrorFallbackHtml } from './image-error-fallback'; diff --git a/tools/server/webui/src/lib/utils/modality-file-validation.ts b/tools/server/webui/src/lib/utils/modality-file-validation.ts index 136c08414..02fb4e4a3 100644 --- a/tools/server/webui/src/lib/utils/modality-file-validation.ts +++ b/tools/server/webui/src/lib/utils/modality-file-validation.ts @@ -5,12 +5,7 @@ import { getFileTypeCategory } from '$lib/utils'; import { FileTypeCategory } from '$lib/enums'; - -/** Modality capabilities for file validation */ -export interface ModalityCapabilities { - hasVision: boolean; - hasAudio: boolean; -} +import type { ModalityCapabilities } from '$lib/types/models'; /** * Check if a file type is supported by the given modalities diff --git a/tools/server/webui/src/lib/utils/text-files.ts b/tools/server/webui/src/lib/utils/text-files.ts index e8006de64..2f1a575d1 100644 --- a/tools/server/webui/src/lib/utils/text-files.ts +++ b/tools/server/webui/src/lib/utils/text-files.ts @@ -3,10 +3,8 @@ * Handles text file detection, reading, and validation */ -import { - DEFAULT_BINARY_DETECTION_OPTIONS, - type BinaryDetectionOptions -} from '$lib/constants/binary-detection'; +import { DEFAULT_BINARY_DETECTION_OPTIONS } from '$lib/constants/binary-detection'; +import type { BinaryDetectionOptions } from '$lib/constants/binary-detection'; import { FileExtensionText } from '$lib/enums'; /**