mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-20 01:09:57 +00:00
before merging conflicting round
This commit is contained in:
commit
ebc1cb0641
34 changed files with 1133 additions and 111 deletions
|
|
@ -516,8 +516,8 @@ void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggm
|
|||
const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;
|
||||
constexpr bool need_f16_K = false;
|
||||
constexpr bool need_f16_V = false;
|
||||
const bool need_f16_K = type_K == GGML_TYPE_F16;
|
||||
const bool need_f16_V = type_V == GGML_TYPE_F16;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
||||
}
|
||||
|
|
@ -526,11 +526,6 @@ template <int D, ggml_type type_K, ggml_type type_V>
|
|||
void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
GGML_ASSERT(K->type == type_K);
|
||||
GGML_ASSERT(V->type == type_V);
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
|
|
|||
|
|
@ -116,11 +116,15 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
|||
}
|
||||
}
|
||||
|
||||
#define FATTN_VEC_CASE(D, type_K, type_V) \
|
||||
if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
|
||||
ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
|
||||
return; \
|
||||
} \
|
||||
#define FATTN_VEC_CASE(D, type_K, type_V) \
|
||||
{ \
|
||||
const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
|
||||
const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
|
||||
if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \
|
||||
ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
|
||||
return; \
|
||||
} \
|
||||
} \
|
||||
|
||||
#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
|
||||
FATTN_VEC_CASE( 64, type_K, type_V) \
|
||||
|
|
@ -253,6 +257,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||
|
||||
switch (K->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
|
|
@ -278,7 +283,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
// If Turing tensor cores available, use them:
|
||||
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
|
||||
if (can_use_vector_kernel) {
|
||||
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
|
||||
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
|
|
@ -325,7 +330,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
|
||||
// If there are no tensor cores available, use the generic tile kernel:
|
||||
if (can_use_vector_kernel) {
|
||||
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
|
||||
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
||||
if (Q->ne[1] == 1) {
|
||||
if (!gqa_opt_applies) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
|
|
|
|||
|
|
@ -276,6 +276,15 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
|||
} else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
|
||||
turing_devices_without_mma.push_back({ id, device_name });
|
||||
}
|
||||
|
||||
// Temporary performance fix:
|
||||
// Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls.
|
||||
// TODO: Check for future drivers the default scheduling strategy and
|
||||
// remove this call again when cudaDeviceScheduleSpin is default.
|
||||
if (prop.major == 12 && prop.minor == 1) {
|
||||
CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
|
||||
}
|
||||
|
||||
#endif // defined(GGML_USE_HIP)
|
||||
}
|
||||
|
||||
|
|
@ -2889,7 +2898,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
|||
}
|
||||
|
||||
//if rms norm is the B operand, then we don't handle broadcast
|
||||
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
|
||||
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -3629,9 +3638,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_ACC:
|
||||
return true;
|
||||
case GGML_OP_SUM:
|
||||
return ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_ARGSORT:
|
||||
// TODO: Support arbitrary column width
|
||||
return op->src[0]->ne[0] <= 1024;
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@
|
|||
|
||||
#include <Metal/Metal.h>
|
||||
|
||||
#include <stdatomic.h>
|
||||
|
||||
#ifndef TARGET_OS_VISION
|
||||
#define TARGET_OS_VISION 0
|
||||
#endif
|
||||
|
|
@ -22,6 +24,9 @@
|
|||
// overload of MTLGPUFamilyMetal3 (not available in some environments)
|
||||
static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
|
||||
|
||||
// virtual address for GPU memory allocations
|
||||
static atomic_uintptr_t g_addr_device = 0x000000400ULL;
|
||||
|
||||
#if !GGML_METAL_EMBED_LIBRARY
|
||||
// Here to assist with NSBundle Path Hack
|
||||
@interface GGMLMetalClass : NSObject
|
||||
|
|
@ -657,6 +662,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
case GGML_OP_LOG:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SUM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
|
|
@ -827,7 +833,7 @@ struct ggml_metal_buffer_wrapper {
|
|||
};
|
||||
|
||||
struct ggml_metal_buffer {
|
||||
void * all_data; // TODO: https://github.com/ggml-org/llama.cpp/pull/15985
|
||||
void * all_data;
|
||||
size_t all_size;
|
||||
|
||||
// if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host
|
||||
|
|
@ -965,14 +971,15 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
|
|||
if (shared) {
|
||||
res->all_data = ggml_metal_host_malloc(size_aligned);
|
||||
res->is_shared = true;
|
||||
res->owned = true;
|
||||
} else {
|
||||
// dummy, non-NULL value - we'll populate this after creating the Metal buffer below
|
||||
res->all_data = (void *) 0x000000400ULL;
|
||||
// use virtual address from g_addr_device counter
|
||||
res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed);
|
||||
res->is_shared = false;
|
||||
}
|
||||
res->all_size = size_aligned;
|
||||
|
||||
res->owned = true;
|
||||
|
||||
res->device = ggml_metal_device_get_obj(dev);
|
||||
res->queue = ggml_metal_device_get_queue(dev);
|
||||
|
||||
|
|
@ -983,15 +990,13 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
|
|||
res->buffers[0].metal = nil;
|
||||
|
||||
if (size_aligned > 0) {
|
||||
if (props_dev->use_shared_buffers &&shared) {
|
||||
if (props_dev->use_shared_buffers && shared) {
|
||||
res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data
|
||||
length:size_aligned
|
||||
options:MTLResourceStorageModeShared
|
||||
deallocator:nil];
|
||||
} else {
|
||||
res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
|
||||
|
||||
res->all_data = (void *) (res->buffers[0].metal.gpuAddress);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1139,7 +1144,7 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf) {
|
|||
|
||||
void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
||||
if (buf->is_shared) {
|
||||
memset((char *)tensor->data + offset, value, size);
|
||||
memset((char *) tensor->data + offset, value, size);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -1168,7 +1173,7 @@ void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor
|
|||
|
||||
void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
if (buf->is_shared) {
|
||||
memcpy((char *)tensor->data + offset, data, size);
|
||||
memcpy((char *) tensor->data + offset, data, size);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -1223,7 +1228,7 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor *
|
|||
|
||||
void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
if (buf->is_shared) {
|
||||
memcpy(data, (const char *)tensor->data + offset, size);
|
||||
memcpy(data, (const char *) tensor->data + offset, size);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -251,6 +251,7 @@ typedef struct {
|
|||
int32_t sect_1;
|
||||
int32_t sect_2;
|
||||
int32_t sect_3;
|
||||
bool src2;
|
||||
} ggml_metal_kargs_rope;
|
||||
|
||||
typedef struct {
|
||||
|
|
|
|||
|
|
@ -866,12 +866,25 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
|
|||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
nth = std::min(nth, (int) n);
|
||||
|
||||
const int nsg = (nth + 31) / 32;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
|
@ -2969,6 +2982,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
|
|||
/* sect_1 =*/ sect_1,
|
||||
/* sect_2 =*/ sect_2,
|
||||
/* sect_3 =*/ sect_3,
|
||||
/* src2 =*/ op->src[2] != nullptr,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
|
||||
|
|
|
|||
|
|
@ -1727,18 +1727,48 @@ kernel void kernel_op_sum_f32(
|
|||
constant ggml_metal_kargs_sum & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]]) {
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
if (tiitg != 0) {
|
||||
if (args.np == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
float acc = 0.0f;
|
||||
for (ulong i = 0; i < args.np; ++i) {
|
||||
acc += src0[i];
|
||||
const uint nsg = (ntg.x + 31) / 32;
|
||||
|
||||
float sumf = 0;
|
||||
|
||||
for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
|
||||
sumf += src0[i0];
|
||||
}
|
||||
|
||||
dst[0] = acc;
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float total = 0;
|
||||
|
||||
if (sgitg == 0) {
|
||||
float v = 0;
|
||||
|
||||
if (tpitg.x < nsg) {
|
||||
v = shmem_f32[tpitg.x];
|
||||
}
|
||||
|
||||
total = simd_sum(v);
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
dst[0] = total;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool norm>
|
||||
|
|
@ -3748,7 +3778,7 @@ kernel void kernel_rope_norm(
|
|||
|
||||
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
|
|
@ -3801,7 +3831,7 @@ kernel void kernel_rope_neox(
|
|||
|
||||
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
|
|
@ -3872,7 +3902,7 @@ kernel void kernel_rope_multi(
|
|||
|
||||
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
|
|
@ -3939,7 +3969,7 @@ kernel void kernel_rope_vision(
|
|||
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
|
||||
// end of mrope
|
||||
|
||||
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
|
|
|
|||
154
ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl
Normal file
154
ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#define LOAD_VEC_A 4
|
||||
#define LOAD_VEC_B 4
|
||||
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#define BK 32
|
||||
#define TM 4
|
||||
#define TN 8
|
||||
|
||||
kernel void kernel_mul_mm_q8_0_f32_l4_lm(
|
||||
global char4 * src0_q,
|
||||
global half * src0_d,
|
||||
global float4 * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne11,
|
||||
int ne12,
|
||||
|
||||
int stride_a,
|
||||
int stride_b,
|
||||
int stride_d,
|
||||
|
||||
int batch_stride_a,
|
||||
int batch_stride_b,
|
||||
int batch_stride_d,
|
||||
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = (global float4*)((global char*)src1 + offset1);
|
||||
dst = (global float *)((global char*)dst + offsetd);
|
||||
|
||||
local float buf_a[BM * BK];
|
||||
local float buf_b[BN * BK];
|
||||
|
||||
const int batch_idx = get_global_id(2);
|
||||
|
||||
const int i13 = batch_idx / ne12;
|
||||
const int i12 = batch_idx % ne12;
|
||||
|
||||
const int i03 = i13 / r3;
|
||||
const int i02 = i12 / r2;
|
||||
|
||||
const int batch_idx_a = i03 * ne02 + i02;
|
||||
|
||||
const int ir = get_group_id(0);
|
||||
const int ic = get_group_id(1);
|
||||
|
||||
const int tid = get_local_id(0);
|
||||
const int th_r = tid % (BM / TM);
|
||||
const int th_c = tid / (BM / TM);
|
||||
|
||||
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
|
||||
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
|
||||
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
|
||||
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
|
||||
|
||||
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
|
||||
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
|
||||
|
||||
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
|
||||
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
|
||||
|
||||
float sums[TM * TN];
|
||||
float cache_a[TM];
|
||||
float cache_b[TN];
|
||||
|
||||
for (int i = 0; i < TM * TN; i++) {
|
||||
sums[i] = 0.0f;
|
||||
}
|
||||
|
||||
for (int block = 0; block < ne00; block += BK) {
|
||||
for (int l = 0; l < BM; l += loadstride_a) {
|
||||
if (loadc_a + l < ne01) {
|
||||
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
|
||||
int ib = idx / 8;
|
||||
int iqs = idx % 8;
|
||||
|
||||
float d = (float)src0_d[ib];
|
||||
global char4 * qs = src0_q + ib*8 + iqs;
|
||||
char4 q = *qs;
|
||||
float4 v = convert_float4(q)*d;
|
||||
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v.s0;
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v.s1;
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v.s2;
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v.s3;
|
||||
} else {
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int l = 0; l < BN; l += loadstride_b) {
|
||||
if (loadc_b + l < ne11) {
|
||||
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
|
||||
} else {
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
pos_a += BK / LOAD_VEC_A;
|
||||
pos_b += BK / LOAD_VEC_B;
|
||||
|
||||
for (int i = 0; i < BK; i++) {
|
||||
for (int j = 0; j < TM; j++) {
|
||||
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
|
||||
}
|
||||
|
||||
for (int j = 0; j < TN; j++) {
|
||||
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
|
||||
}
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
const int sums_idx = cc*TM + cr;
|
||||
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
|
||||
const int dr = ir * BM + th_r * TM;
|
||||
const int dc = ic * BN + th_c * TN;
|
||||
|
||||
const int offsets = batch_idx * batch_stride_d;
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
if (dr + cr < ne01 && dc + cc < ne11) {
|
||||
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2665,11 +2665,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
} \
|
||||
}
|
||||
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (device->coopmat1_fa_support) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
|
||||
|
|
@ -2677,6 +2679,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
#endif
|
||||
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (device->coopmat2) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
|
||||
|
|
@ -7487,8 +7490,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
}
|
||||
|
||||
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
|
||||
const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
|
||||
const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
|
||||
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
|
||||
uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
|
||||
|
||||
// For F32, the shader treats it as a block of size 4 (for vec4 loads)
|
||||
if (k->type == GGML_TYPE_F32) {
|
||||
k_stride /= 4;
|
||||
}
|
||||
if (v->type == GGML_TYPE_F32) {
|
||||
v_stride /= 4;
|
||||
}
|
||||
|
||||
uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
|
||||
bool aligned = (KV % alignment) == 0 &&
|
||||
|
|
@ -12690,6 +12701,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
}
|
||||
switch (op->src[1]->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
// supported in scalar and coopmat2 paths
|
||||
|
|
|
|||
|
|
@ -1,6 +1,18 @@
|
|||
|
||||
#include "types.glsl"
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 {
|
||||
vec4 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const vec4 v = bl.block;
|
||||
const uint idx = coordInBlock[1];
|
||||
const f16vec4 vf16 = f16vec4(v);
|
||||
return vf16[idx];
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
|
||||
block_q4_0_packed16 block;
|
||||
};
|
||||
|
|
@ -717,4 +729,6 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
|
|||
#define dequantFuncA dequantFuncIQ4_NL
|
||||
#elif defined(DATA_A_MXFP4)
|
||||
#define dequantFuncA dequantFuncMXFP4
|
||||
#elif defined(DATA_A_F32)
|
||||
#define dequantFuncA dequantFuncF32
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -64,13 +64,31 @@ layout (binding = 4) readonly buffer S {float data_s[];};
|
|||
|
||||
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
|
||||
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
#define BINDING_IDX_K 0
|
||||
#define BINDING_IDX_V 1
|
||||
#if defined(DATA_A_F32)
|
||||
layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;
|
||||
layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;
|
||||
#elif defined(A_TYPE_PACKED16)
|
||||
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
|
||||
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_F32)
|
||||
#undef BLOCK_SIZE
|
||||
#define BLOCK_SIZE 4
|
||||
#define BLOCK_BYTE_SIZE 16
|
||||
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
// iqs is currently always zero in the flash attention shaders
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
return k_packed.k_data_packed[a_offset + ib];
|
||||
} else {
|
||||
return v_packed.v_data_packed[a_offset + ib];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#define BLOCK_BYTE_SIZE 18
|
||||
|
||||
|
|
|
|||
|
|
@ -313,12 +313,12 @@ void main() {
|
|||
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
|
||||
}
|
||||
#else
|
||||
ACC_TYPE sums[WMITER * TM * WNITER * TN];
|
||||
ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
|
||||
FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
|
||||
FLOAT_TYPE_VEC2 cache_b[TN];
|
||||
FLOAT_TYPE_VEC2 cache_b;
|
||||
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = ACC_TYPE(0.0f);
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
|
||||
sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -360,20 +360,22 @@ void main() {
|
|||
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
|
||||
}
|
||||
}
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (uint j = 0; j < TN; j++) {
|
||||
cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
|
||||
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx]));
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];
|
||||
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
|
||||
// [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
|
||||
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
|
||||
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
|
||||
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -388,8 +390,9 @@ void main() {
|
|||
}
|
||||
}
|
||||
#else
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
|
||||
sums[i].x = clamp(sums[i].x, -ACC_TYPE_MAX, ACC_TYPE_MAX);
|
||||
sums[i].y = clamp(sums[i].y, -ACC_TYPE_MAX, ACC_TYPE_MAX);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
|
@ -463,14 +466,21 @@ void main() {
|
|||
|
||||
const u16vec2 row_idx = row_ids[row_i - ic * BN];
|
||||
#endif // MUL_MAT_ID
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
|
||||
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
|
||||
#ifdef MUL_MAT_ID
|
||||
if (dr_warp + cr < p.M) {
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
||||
if (dr_warp + 2 * cr < p.M) {
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
|
||||
}
|
||||
if (dr_warp + 2 * cr + 1 < p.M) {
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
|
||||
}
|
||||
#else
|
||||
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
||||
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
||||
if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {
|
||||
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
|
||||
}
|
||||
if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {
|
||||
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
|
||||
}
|
||||
#endif // MUL_MAT_ID
|
||||
}
|
||||
|
|
|
|||
|
|
@ -629,9 +629,6 @@ void process_shaders() {
|
|||
}
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
if (tname == "f32") {
|
||||
continue;
|
||||
}
|
||||
if (tname == "bf16") continue;
|
||||
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
|
|
@ -648,7 +645,7 @@ void process_shaders() {
|
|||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
||||
} else if (tname == "q4_0" || tname == "q8_0") {
|
||||
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
||||
|
|
@ -657,7 +654,7 @@ void process_shaders() {
|
|||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
|
||||
} else if (tname == "q4_0" || tname == "q8_0") {
|
||||
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
#include <map>
|
||||
|
||||
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_CLIP, "clip" }, // dummy, only used by llama-quantize
|
||||
{ LLM_ARCH_LLAMA, "llama" },
|
||||
{ LLM_ARCH_LLAMA4, "llama4" },
|
||||
{ LLM_ARCH_DECI, "deci" },
|
||||
|
|
@ -275,6 +276,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
|||
};
|
||||
|
||||
static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
|
||||
{
|
||||
LLM_ARCH_CLIP,
|
||||
{},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_LLAMA,
|
||||
{
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
//
|
||||
|
||||
enum llm_arch {
|
||||
LLM_ARCH_CLIP,
|
||||
LLM_ARCH_LLAMA,
|
||||
LLM_ARCH_LLAMA4,
|
||||
LLM_ARCH_DECI,
|
||||
|
|
|
|||
|
|
@ -483,7 +483,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
ml.get_key(LLM_KV_GENERAL_NAME, name, false);
|
||||
|
||||
// everything past this point is not vocab-related
|
||||
if (hparams.vocab_only) {
|
||||
// for CLIP models, we only need to load tensors, no hparams
|
||||
if (hparams.vocab_only || ml.get_arch() == LLM_ARCH_CLIP) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -20075,6 +20076,7 @@ int32_t llama_n_head(const llama_model * model) {
|
|||
llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
switch (model->arch) {
|
||||
// these models do not use RoPE
|
||||
case LLM_ARCH_CLIP:
|
||||
case LLM_ARCH_GPT2:
|
||||
case LLM_ARCH_GPTJ:
|
||||
case LLM_ARCH_MPT:
|
||||
|
|
|
|||
|
|
@ -704,6 +704,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
});
|
||||
}
|
||||
|
||||
bool is_clip_model = false;
|
||||
for (const auto * it : tensors) {
|
||||
const struct ggml_tensor * tensor = it->tensor;
|
||||
|
||||
|
|
@ -717,12 +718,14 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
} else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
|
||||
qs.has_output = true;
|
||||
}
|
||||
|
||||
is_clip_model |= name.rfind("mm.", 0) == 0; // check the "mm." prefix
|
||||
}
|
||||
|
||||
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
|
||||
|
||||
// sanity checks for models that have attention layers
|
||||
if (qs.n_attention_wv != 0)
|
||||
if (qs.n_attention_wv != 0 && !is_clip_model)
|
||||
{
|
||||
const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
|
||||
// attention layers have a non-zero number of kv heads
|
||||
|
|
@ -884,6 +887,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
// do not quantize relative position bias (T5)
|
||||
quantize &= name.find("attn_rel_b.weight") == std::string::npos;
|
||||
|
||||
// do not quantize specific multimodal tensors
|
||||
quantize &= name.find(".position_embd.") == std::string::npos;
|
||||
|
||||
ggml_type new_type;
|
||||
void * new_data;
|
||||
size_t new_size;
|
||||
|
|
|
|||
|
|
@ -155,6 +155,9 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>
|
|||
} catch(const std::exception & e) {
|
||||
throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
|
||||
}
|
||||
if (model.arch == LLM_ARCH_CLIP) {
|
||||
throw std::runtime_error("CLIP cannot be used as main model, use it with --mmproj instead");
|
||||
}
|
||||
try {
|
||||
model.load_vocab(ml);
|
||||
} catch(const std::exception & e) {
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -3812,7 +3812,7 @@ struct server_context {
|
|||
if (slot.n_past > 0 && slot.n_past < (int) slot.prompt.tokens.size()) {
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
|
||||
if (pos_min == -1) {
|
||||
SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
|
||||
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
|
||||
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
|
||||
}
|
||||
|
||||
|
|
@ -3839,14 +3839,14 @@ struct server_context {
|
|||
|
||||
{
|
||||
const auto token = slot.prompt.tokens[i];
|
||||
const auto piece = common_token_to_piece(ctx, token);
|
||||
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
|
||||
ss0 << piece;
|
||||
st0 << std::setw(8) << token;
|
||||
}
|
||||
|
||||
{
|
||||
const auto token = slot.task->tokens[i];
|
||||
const auto piece = common_token_to_piece(ctx, token);
|
||||
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
|
||||
ss1 << piece;
|
||||
st1 << std::setw(8) << token;
|
||||
}
|
||||
|
|
@ -3860,7 +3860,7 @@ struct server_context {
|
|||
}
|
||||
|
||||
if (pos_min > pos_min_thold) {
|
||||
SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
|
||||
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
|
||||
|
||||
// search for a context checkpoint
|
||||
const auto it = std::find_if(
|
||||
|
|
@ -4028,7 +4028,7 @@ struct server_context {
|
|||
}
|
||||
}
|
||||
|
||||
// SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
|
||||
// SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());
|
||||
|
||||
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_past / slot.n_prompt_tokens());
|
||||
|
||||
|
|
|
|||
|
|
@ -1237,9 +1237,10 @@ public:
|
|||
// allowed to resize ^ ^
|
||||
// disallowed to resize ^ ^ ^
|
||||
if (n > 0) {
|
||||
llama_token last_token = tokens[n - 1];
|
||||
// make sure we never remove tokens in the middle of an image
|
||||
if (last_token == LLAMA_TOKEN_NULL) {
|
||||
// note that the case where we keep a full image at the end is allowed:
|
||||
// tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] != LLAMA_TOKEN_NULL
|
||||
if (tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] == LLAMA_TOKEN_NULL) {
|
||||
find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,8 +14,7 @@
|
|||
import { ChatSettingsFooter, ChatSettingsFields } from '$lib/components/app';
|
||||
import * as Dialog from '$lib/components/ui/dialog';
|
||||
import { ScrollArea } from '$lib/components/ui/scroll-area';
|
||||
import { SETTING_CONFIG_DEFAULT } from '$lib/constants/settings-config';
|
||||
import { config, updateMultipleConfig, resetConfig } from '$lib/stores/settings.svelte';
|
||||
import { config, updateMultipleConfig } from '$lib/stores/settings.svelte';
|
||||
import { setMode } from 'mode-watcher';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
|
|
@ -267,16 +266,13 @@
|
|||
}
|
||||
|
||||
function handleReset() {
|
||||
resetConfig();
|
||||
localConfig = { ...config() };
|
||||
|
||||
localConfig = { ...SETTING_CONFIG_DEFAULT };
|
||||
|
||||
setMode(SETTING_CONFIG_DEFAULT.theme as 'light' | 'dark' | 'system');
|
||||
originalTheme = SETTING_CONFIG_DEFAULT.theme as string;
|
||||
setMode(localConfig.theme as 'light' | 'dark' | 'system');
|
||||
originalTheme = localConfig.theme as string;
|
||||
}
|
||||
|
||||
function handleSave() {
|
||||
// Validate custom JSON if provided
|
||||
if (localConfig.custom && typeof localConfig.custom === 'string' && localConfig.custom.trim()) {
|
||||
try {
|
||||
JSON.parse(localConfig.custom);
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
<script lang="ts">
|
||||
import { RotateCcw } from '@lucide/svelte';
|
||||
import { Checkbox } from '$lib/components/ui/checkbox';
|
||||
import { Input } from '$lib/components/ui/input';
|
||||
import Label from '$lib/components/ui/label/label.svelte';
|
||||
|
|
@ -6,6 +7,9 @@
|
|||
import { Textarea } from '$lib/components/ui/textarea';
|
||||
import { SETTING_CONFIG_DEFAULT, SETTING_CONFIG_INFO } from '$lib/constants/settings-config';
|
||||
import { supportsVision } from '$lib/stores/server.svelte';
|
||||
import { getParameterInfo, resetParameterToServerDefault } from '$lib/stores/settings.svelte';
|
||||
import { ParameterSyncService } from '$lib/services/parameter-sync';
|
||||
import ParameterSourceIndicator from './ParameterSourceIndicator.svelte';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
|
|
@ -16,22 +20,77 @@
|
|||
}
|
||||
|
||||
let { fields, localConfig, onConfigChange, onThemeChange }: Props = $props();
|
||||
|
||||
// Helper function to get parameter source info for syncable parameters
|
||||
function getParameterSourceInfo(key: string) {
|
||||
if (!ParameterSyncService.canSyncParameter(key)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return getParameterInfo(key);
|
||||
}
|
||||
</script>
|
||||
|
||||
{#each fields as field (field.key)}
|
||||
<div class="space-y-2">
|
||||
{#if field.type === 'input'}
|
||||
<Label for={field.key} class="block text-sm font-medium">
|
||||
{field.label}
|
||||
</Label>
|
||||
{@const paramInfo = getParameterSourceInfo(field.key)}
|
||||
{@const currentValue = String(localConfig[field.key] ?? '')}
|
||||
{@const propsDefault = paramInfo?.serverDefault}
|
||||
{@const isCustomRealTime = (() => {
|
||||
if (!paramInfo || propsDefault === undefined) return false;
|
||||
|
||||
<Input
|
||||
id={field.key}
|
||||
value={String(localConfig[field.key] ?? '')}
|
||||
onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
|
||||
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`}
|
||||
class="w-full md:max-w-md"
|
||||
/>
|
||||
// Apply same rounding logic for real-time comparison
|
||||
const inputValue = currentValue;
|
||||
const numericInput = parseFloat(inputValue);
|
||||
const normalizedInput = !isNaN(numericInput)
|
||||
? Math.round(numericInput * 1000000) / 1000000
|
||||
: inputValue;
|
||||
const normalizedDefault =
|
||||
typeof propsDefault === 'number'
|
||||
? Math.round(propsDefault * 1000000) / 1000000
|
||||
: propsDefault;
|
||||
|
||||
return normalizedInput !== normalizedDefault;
|
||||
})()}
|
||||
|
||||
<div class="flex items-center gap-2">
|
||||
<Label for={field.key} class="text-sm font-medium">
|
||||
{field.label}
|
||||
</Label>
|
||||
{#if isCustomRealTime}
|
||||
<ParameterSourceIndicator />
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div class="relative w-full md:max-w-md">
|
||||
<Input
|
||||
id={field.key}
|
||||
value={currentValue}
|
||||
oninput={(e) => {
|
||||
// Update local config immediately for real-time badge feedback
|
||||
onConfigChange(field.key, e.currentTarget.value);
|
||||
}}
|
||||
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`}
|
||||
class="w-full {isCustomRealTime ? 'pr-8' : ''}"
|
||||
/>
|
||||
{#if isCustomRealTime}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => {
|
||||
resetParameterToServerDefault(field.key);
|
||||
// Trigger UI update by calling onConfigChange with the default value
|
||||
const defaultValue = propsDefault ?? SETTING_CONFIG_DEFAULT[field.key];
|
||||
onConfigChange(field.key, String(defaultValue));
|
||||
}}
|
||||
class="absolute top-1/2 right-2 inline-flex h-5 w-5 -translate-y-1/2 items-center justify-center rounded transition-colors hover:bg-muted"
|
||||
aria-label="Reset to default"
|
||||
title="Reset to default"
|
||||
>
|
||||
<RotateCcw class="h-3 w-3" />
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
{field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
|
|
@ -59,14 +118,28 @@
|
|||
(opt: { value: string; label: string; icon?: Component }) =>
|
||||
opt.value === localConfig[field.key]
|
||||
)}
|
||||
{@const paramInfo = getParameterSourceInfo(field.key)}
|
||||
{@const currentValue = localConfig[field.key]}
|
||||
{@const propsDefault = paramInfo?.serverDefault}
|
||||
{@const isCustomRealTime = (() => {
|
||||
if (!paramInfo || propsDefault === undefined) return false;
|
||||
|
||||
<Label for={field.key} class="block text-sm font-medium">
|
||||
{field.label}
|
||||
</Label>
|
||||
// For select fields, do direct comparison (no rounding needed)
|
||||
return currentValue !== propsDefault;
|
||||
})()}
|
||||
|
||||
<div class="flex items-center gap-2">
|
||||
<Label for={field.key} class="text-sm font-medium">
|
||||
{field.label}
|
||||
</Label>
|
||||
{#if isCustomRealTime}
|
||||
<ParameterSourceIndicator />
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<Select.Root
|
||||
type="single"
|
||||
value={localConfig[field.key]}
|
||||
value={currentValue}
|
||||
onValueChange={(value) => {
|
||||
if (field.key === 'theme' && value && onThemeChange) {
|
||||
onThemeChange(value);
|
||||
|
|
@ -75,16 +148,34 @@
|
|||
}
|
||||
}}
|
||||
>
|
||||
<Select.Trigger class="w-full md:w-auto md:max-w-md">
|
||||
<div class="flex items-center gap-2">
|
||||
{#if selectedOption?.icon}
|
||||
{@const IconComponent = selectedOption.icon}
|
||||
<IconComponent class="h-4 w-4" />
|
||||
{/if}
|
||||
<div class="relative w-full md:w-auto md:max-w-md">
|
||||
<Select.Trigger class="w-full">
|
||||
<div class="flex items-center gap-2">
|
||||
{#if selectedOption?.icon}
|
||||
{@const IconComponent = selectedOption.icon}
|
||||
<IconComponent class="h-4 w-4" />
|
||||
{/if}
|
||||
|
||||
{selectedOption?.label || `Select ${field.label.toLowerCase()}`}
|
||||
</div>
|
||||
</Select.Trigger>
|
||||
{selectedOption?.label || `Select ${field.label.toLowerCase()}`}
|
||||
</div>
|
||||
</Select.Trigger>
|
||||
{#if isCustomRealTime}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => {
|
||||
resetParameterToServerDefault(field.key);
|
||||
// Trigger UI update by calling onConfigChange with the default value
|
||||
const defaultValue = propsDefault ?? SETTING_CONFIG_DEFAULT[field.key];
|
||||
onConfigChange(field.key, String(defaultValue));
|
||||
}}
|
||||
class="absolute top-1/2 right-8 inline-flex h-5 w-5 -translate-y-1/2 items-center justify-center rounded transition-colors hover:bg-muted"
|
||||
aria-label="Reset to default"
|
||||
title="Reset to default"
|
||||
>
|
||||
<RotateCcw class="h-3 w-3" />
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
<Select.Content>
|
||||
{#if field.options}
|
||||
{#each field.options as option (option.value)}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
<script lang="ts">
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||
import { forceSyncWithServerDefaults } from '$lib/stores/settings.svelte';
|
||||
import { RotateCcw } from '@lucide/svelte';
|
||||
|
||||
interface Props {
|
||||
onReset?: () => void;
|
||||
|
|
@ -16,7 +18,9 @@
|
|||
}
|
||||
|
||||
function handleConfirmReset() {
|
||||
forceSyncWithServerDefaults();
|
||||
onReset?.();
|
||||
|
||||
showResetDialog = false;
|
||||
}
|
||||
|
||||
|
|
@ -26,7 +30,13 @@
|
|||
</script>
|
||||
|
||||
<div class="flex justify-between border-t border-border/30 p-6">
|
||||
<Button variant="outline" onclick={handleResetClick}>Reset to default</Button>
|
||||
<div class="flex gap-2">
|
||||
<Button variant="outline" onclick={handleResetClick}>
|
||||
<RotateCcw class="h-3 w-3" />
|
||||
|
||||
Reset to default
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<Button onclick={handleSave}>Save settings</Button>
|
||||
</div>
|
||||
|
|
@ -36,8 +46,9 @@
|
|||
<AlertDialog.Header>
|
||||
<AlertDialog.Title>Reset Settings to Default</AlertDialog.Title>
|
||||
<AlertDialog.Description>
|
||||
Are you sure you want to reset all settings to their default values? This action cannot be
|
||||
undone and will permanently remove all your custom configurations.
|
||||
Are you sure you want to reset all settings to their default values? This will reset all
|
||||
parameters to the values provided by the server's /props endpoint and remove all your custom
|
||||
configurations.
|
||||
</AlertDialog.Description>
|
||||
</AlertDialog.Header>
|
||||
<AlertDialog.Footer>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,18 @@
|
|||
<script lang="ts">
|
||||
import { Wrench } from '@lucide/svelte';
|
||||
import { Badge } from '$lib/components/ui/badge';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
}
|
||||
|
||||
let { class: className = '' }: Props = $props();
|
||||
</script>
|
||||
|
||||
<Badge
|
||||
variant="secondary"
|
||||
class="h-5 bg-orange-100 px-1.5 py-0.5 text-xs text-orange-800 dark:bg-orange-900 dark:text-orange-200 {className}"
|
||||
>
|
||||
<Wrench class="mr-1 h-3 w-3" />
|
||||
Custom
|
||||
</Badge>
|
||||
|
|
@ -25,6 +25,7 @@ export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte';
|
|||
export { default as ChatSettingsDialog } from './chat/ChatSettings/ChatSettingsDialog.svelte';
|
||||
export { default as ChatSettingsFooter } from './chat/ChatSettings/ChatSettingsFooter.svelte';
|
||||
export { default as ChatSettingsFields } from './chat/ChatSettings/ChatSettingsFields.svelte';
|
||||
export { default as ParameterSourceIndicator } from './chat/ChatSettings/ParameterSourceIndicator.svelte';
|
||||
|
||||
export { default as ChatSidebar } from './chat/ChatSidebar/ChatSidebar.svelte';
|
||||
export { default as ChatSidebarConversationItem } from './chat/ChatSidebar/ChatSidebarConversationItem.svelte';
|
||||
|
|
|
|||
2
tools/server/webui/src/lib/constants/precision.ts
Normal file
2
tools/server/webui/src/lib/constants/precision.ts
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
export const PRECISION_MULTIPLIER = 1000000;
|
||||
export const PRECISION_DECIMAL_PLACES = 6;
|
||||
135
tools/server/webui/src/lib/services/parameter-sync.spec.ts
Normal file
135
tools/server/webui/src/lib/services/parameter-sync.spec.ts
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
import { describe, it, expect } from 'vitest';
|
||||
import { ParameterSyncService } from './parameter-sync';
|
||||
import type { ApiLlamaCppServerProps } from '$lib/types/api';
|
||||
|
||||
describe('ParameterSyncService', () => {
|
||||
describe('roundFloatingPoint', () => {
|
||||
it('should fix JavaScript floating-point precision issues', () => {
|
||||
// Test the specific values from the screenshot
|
||||
const mockServerParams = {
|
||||
top_p: 0.949999988079071,
|
||||
min_p: 0.009999999776482582,
|
||||
temperature: 0.800000011920929,
|
||||
top_k: 40,
|
||||
samplers: ['top_k', 'typ_p', 'top_p', 'min_p', 'temperature']
|
||||
};
|
||||
|
||||
const result = ParameterSyncService.extractServerDefaults({
|
||||
...mockServerParams,
|
||||
// Add other required fields to match the API type
|
||||
n_predict: 512,
|
||||
seed: -1,
|
||||
dynatemp_range: 0.0,
|
||||
dynatemp_exponent: 1.0,
|
||||
xtc_probability: 0.0,
|
||||
xtc_threshold: 0.1,
|
||||
typ_p: 1.0,
|
||||
repeat_last_n: 64,
|
||||
repeat_penalty: 1.0,
|
||||
presence_penalty: 0.0,
|
||||
frequency_penalty: 0.0,
|
||||
dry_multiplier: 0.0,
|
||||
dry_base: 1.75,
|
||||
dry_allowed_length: 2,
|
||||
dry_penalty_last_n: -1,
|
||||
mirostat: 0,
|
||||
mirostat_tau: 5.0,
|
||||
mirostat_eta: 0.1,
|
||||
stop: [],
|
||||
max_tokens: -1,
|
||||
n_keep: 0,
|
||||
n_discard: 0,
|
||||
ignore_eos: false,
|
||||
stream: true,
|
||||
logit_bias: [],
|
||||
n_probs: 0,
|
||||
min_keep: 0,
|
||||
grammar: '',
|
||||
grammar_lazy: false,
|
||||
grammar_triggers: [],
|
||||
preserved_tokens: [],
|
||||
chat_format: '',
|
||||
reasoning_format: '',
|
||||
reasoning_in_content: false,
|
||||
thinking_forced_open: false,
|
||||
'speculative.n_max': 0,
|
||||
'speculative.n_min': 0,
|
||||
'speculative.p_min': 0.0,
|
||||
timings_per_token: false,
|
||||
post_sampling_probs: false,
|
||||
lora: [],
|
||||
top_n_sigma: 0.0,
|
||||
dry_sequence_breakers: []
|
||||
} as ApiLlamaCppServerProps['default_generation_settings']['params']);
|
||||
|
||||
// Check that the problematic floating-point values are rounded correctly
|
||||
expect(result.top_p).toBe(0.95);
|
||||
expect(result.min_p).toBe(0.01);
|
||||
expect(result.temperature).toBe(0.8);
|
||||
expect(result.top_k).toBe(40); // Integer should remain unchanged
|
||||
expect(result.samplers).toBe('top_k;typ_p;top_p;min_p;temperature');
|
||||
});
|
||||
|
||||
it('should preserve non-numeric values', () => {
|
||||
const mockServerParams = {
|
||||
samplers: ['top_k', 'temperature'],
|
||||
max_tokens: -1,
|
||||
temperature: 0.7
|
||||
};
|
||||
|
||||
const result = ParameterSyncService.extractServerDefaults({
|
||||
...mockServerParams,
|
||||
// Minimal required fields
|
||||
n_predict: 512,
|
||||
seed: -1,
|
||||
dynatemp_range: 0.0,
|
||||
dynatemp_exponent: 1.0,
|
||||
top_k: 40,
|
||||
top_p: 0.95,
|
||||
min_p: 0.05,
|
||||
xtc_probability: 0.0,
|
||||
xtc_threshold: 0.1,
|
||||
typ_p: 1.0,
|
||||
repeat_last_n: 64,
|
||||
repeat_penalty: 1.0,
|
||||
presence_penalty: 0.0,
|
||||
frequency_penalty: 0.0,
|
||||
dry_multiplier: 0.0,
|
||||
dry_base: 1.75,
|
||||
dry_allowed_length: 2,
|
||||
dry_penalty_last_n: -1,
|
||||
mirostat: 0,
|
||||
mirostat_tau: 5.0,
|
||||
mirostat_eta: 0.1,
|
||||
stop: [],
|
||||
n_keep: 0,
|
||||
n_discard: 0,
|
||||
ignore_eos: false,
|
||||
stream: true,
|
||||
logit_bias: [],
|
||||
n_probs: 0,
|
||||
min_keep: 0,
|
||||
grammar: '',
|
||||
grammar_lazy: false,
|
||||
grammar_triggers: [],
|
||||
preserved_tokens: [],
|
||||
chat_format: '',
|
||||
reasoning_format: '',
|
||||
reasoning_in_content: false,
|
||||
thinking_forced_open: false,
|
||||
'speculative.n_max': 0,
|
||||
'speculative.n_min': 0,
|
||||
'speculative.p_min': 0.0,
|
||||
timings_per_token: false,
|
||||
post_sampling_probs: false,
|
||||
lora: [],
|
||||
top_n_sigma: 0.0,
|
||||
dry_sequence_breakers: []
|
||||
} as ApiLlamaCppServerProps['default_generation_settings']['params']);
|
||||
|
||||
expect(result.samplers).toBe('top_k;temperature');
|
||||
expect(result.max_tokens).toBe(-1);
|
||||
expect(result.temperature).toBe(0.7);
|
||||
});
|
||||
});
|
||||
});
|
||||
202
tools/server/webui/src/lib/services/parameter-sync.ts
Normal file
202
tools/server/webui/src/lib/services/parameter-sync.ts
Normal file
|
|
@ -0,0 +1,202 @@
|
|||
/**
|
||||
* ParameterSyncService - Handles synchronization between server defaults and user settings
|
||||
*
|
||||
* This service manages the complex logic of merging server-provided default parameters
|
||||
* with user-configured overrides, ensuring the UI reflects the actual server state
|
||||
* while preserving user customizations.
|
||||
*
|
||||
* **Key Responsibilities:**
|
||||
* - Extract syncable parameters from server props
|
||||
* - Merge server defaults with user overrides
|
||||
* - Track parameter sources (server, user, default)
|
||||
* - Provide sync utilities for settings store integration
|
||||
*/
|
||||
|
||||
import type { ApiLlamaCppServerProps } from '$lib/types/api';
|
||||
import { normalizeFloatingPoint } from '$lib/utils/precision';
|
||||
|
||||
export type ParameterSource = 'default' | 'custom';
|
||||
export type ParameterValue = string | number | boolean;
|
||||
export type ParameterRecord = Record<string, ParameterValue>;
|
||||
|
||||
export interface ParameterInfo {
|
||||
value: string | number | boolean;
|
||||
source: ParameterSource;
|
||||
serverDefault?: string | number | boolean;
|
||||
userOverride?: string | number | boolean;
|
||||
}
|
||||
|
||||
export interface SyncableParameter {
|
||||
key: string;
|
||||
serverKey: string;
|
||||
type: 'number' | 'string' | 'boolean';
|
||||
canSync: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Mapping of webui setting keys to server parameter keys
|
||||
* Only parameters that should be synced from server are included
|
||||
*/
|
||||
export const SYNCABLE_PARAMETERS: SyncableParameter[] = [
|
||||
{ key: 'temperature', serverKey: 'temperature', type: 'number', canSync: true },
|
||||
{ key: 'top_k', serverKey: 'top_k', type: 'number', canSync: true },
|
||||
{ key: 'top_p', serverKey: 'top_p', type: 'number', canSync: true },
|
||||
{ key: 'min_p', serverKey: 'min_p', type: 'number', canSync: true },
|
||||
{ key: 'dynatemp_range', serverKey: 'dynatemp_range', type: 'number', canSync: true },
|
||||
{ key: 'dynatemp_exponent', serverKey: 'dynatemp_exponent', type: 'number', canSync: true },
|
||||
{ key: 'xtc_probability', serverKey: 'xtc_probability', type: 'number', canSync: true },
|
||||
{ key: 'xtc_threshold', serverKey: 'xtc_threshold', type: 'number', canSync: true },
|
||||
{ key: 'typ_p', serverKey: 'typ_p', type: 'number', canSync: true },
|
||||
{ key: 'repeat_last_n', serverKey: 'repeat_last_n', type: 'number', canSync: true },
|
||||
{ key: 'repeat_penalty', serverKey: 'repeat_penalty', type: 'number', canSync: true },
|
||||
{ key: 'presence_penalty', serverKey: 'presence_penalty', type: 'number', canSync: true },
|
||||
{ key: 'frequency_penalty', serverKey: 'frequency_penalty', type: 'number', canSync: true },
|
||||
{ key: 'dry_multiplier', serverKey: 'dry_multiplier', type: 'number', canSync: true },
|
||||
{ key: 'dry_base', serverKey: 'dry_base', type: 'number', canSync: true },
|
||||
{ key: 'dry_allowed_length', serverKey: 'dry_allowed_length', type: 'number', canSync: true },
|
||||
{ key: 'dry_penalty_last_n', serverKey: 'dry_penalty_last_n', type: 'number', canSync: true },
|
||||
{ key: 'max_tokens', serverKey: 'max_tokens', type: 'number', canSync: true },
|
||||
{ key: 'samplers', serverKey: 'samplers', type: 'string', canSync: true }
|
||||
];
|
||||
|
||||
export class ParameterSyncService {
|
||||
/**
|
||||
* Round floating-point numbers to avoid JavaScript precision issues
|
||||
*/
|
||||
private static roundFloatingPoint(value: ParameterValue): ParameterValue {
|
||||
return normalizeFloatingPoint(value) as ParameterValue;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract server default parameters that can be synced
|
||||
*/
|
||||
static extractServerDefaults(
|
||||
serverParams: ApiLlamaCppServerProps['default_generation_settings']['params'] | null
|
||||
): ParameterRecord {
|
||||
if (!serverParams) return {};
|
||||
|
||||
const extracted: ParameterRecord = {};
|
||||
|
||||
for (const param of SYNCABLE_PARAMETERS) {
|
||||
if (param.canSync && param.serverKey in serverParams) {
|
||||
const value = (serverParams as unknown as Record<string, ParameterValue>)[param.serverKey];
|
||||
if (value !== undefined) {
|
||||
// Apply precision rounding to avoid JavaScript floating-point issues
|
||||
extracted[param.key] = this.roundFloatingPoint(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle samplers array conversion to string
|
||||
if (serverParams.samplers && Array.isArray(serverParams.samplers)) {
|
||||
extracted.samplers = serverParams.samplers.join(';');
|
||||
}
|
||||
|
||||
return extracted;
|
||||
}
|
||||
|
||||
/**
|
||||
* Merge server defaults with current user settings
|
||||
* Returns updated settings that respect user overrides while using server defaults
|
||||
*/
|
||||
static mergeWithServerDefaults(
|
||||
currentSettings: ParameterRecord,
|
||||
serverDefaults: ParameterRecord,
|
||||
userOverrides: Set<string> = new Set()
|
||||
): ParameterRecord {
|
||||
const merged = { ...currentSettings };
|
||||
|
||||
for (const [key, serverValue] of Object.entries(serverDefaults)) {
|
||||
// Only update if user hasn't explicitly overridden this parameter
|
||||
if (!userOverrides.has(key)) {
|
||||
merged[key] = this.roundFloatingPoint(serverValue);
|
||||
}
|
||||
}
|
||||
|
||||
return merged;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get parameter information including source and values
|
||||
*/
|
||||
static getParameterInfo(
|
||||
key: string,
|
||||
currentValue: ParameterValue,
|
||||
propsDefaults: ParameterRecord,
|
||||
userOverrides: Set<string>
|
||||
): ParameterInfo {
|
||||
const hasPropsDefault = propsDefaults[key] !== undefined;
|
||||
const isUserOverride = userOverrides.has(key);
|
||||
|
||||
// Simple logic: either using default (from props) or custom (user override)
|
||||
const source: ParameterSource = isUserOverride ? 'custom' : 'default';
|
||||
|
||||
return {
|
||||
value: currentValue,
|
||||
source,
|
||||
serverDefault: hasPropsDefault ? propsDefaults[key] : undefined, // Keep same field name for compatibility
|
||||
userOverride: isUserOverride ? currentValue : undefined
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a parameter can be synced from server
|
||||
*/
|
||||
static canSyncParameter(key: string): boolean {
|
||||
return SYNCABLE_PARAMETERS.some((param) => param.key === key && param.canSync);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all syncable parameter keys
|
||||
*/
|
||||
static getSyncableParameterKeys(): string[] {
|
||||
return SYNCABLE_PARAMETERS.filter((param) => param.canSync).map((param) => param.key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate server parameter value
|
||||
*/
|
||||
static validateServerParameter(key: string, value: ParameterValue): boolean {
|
||||
const param = SYNCABLE_PARAMETERS.find((p) => p.key === key);
|
||||
if (!param) return false;
|
||||
|
||||
switch (param.type) {
|
||||
case 'number':
|
||||
return typeof value === 'number' && !isNaN(value);
|
||||
case 'string':
|
||||
return typeof value === 'string';
|
||||
case 'boolean':
|
||||
return typeof value === 'boolean';
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a diff between current settings and server defaults
|
||||
*/
|
||||
static createParameterDiff(
|
||||
currentSettings: ParameterRecord,
|
||||
serverDefaults: ParameterRecord
|
||||
): Record<string, { current: ParameterValue; server: ParameterValue; differs: boolean }> {
|
||||
const diff: Record<
|
||||
string,
|
||||
{ current: ParameterValue; server: ParameterValue; differs: boolean }
|
||||
> = {};
|
||||
|
||||
for (const key of this.getSyncableParameterKeys()) {
|
||||
const currentValue = currentSettings[key];
|
||||
const serverValue = serverDefaults[key];
|
||||
|
||||
if (serverValue !== undefined) {
|
||||
diff[key] = {
|
||||
current: currentValue,
|
||||
server: serverValue,
|
||||
differs: currentValue !== serverValue
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return diff;
|
||||
}
|
||||
}
|
||||
|
|
@ -125,6 +125,12 @@ class ServerStore {
|
|||
return this._slotsEndpointAvailable;
|
||||
}
|
||||
|
||||
get serverDefaultParams():
|
||||
| ApiLlamaCppServerProps['default_generation_settings']['params']
|
||||
| null {
|
||||
return this._serverProps?.default_generation_settings?.params || null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if slots endpoint is available based on server properties and endpoint support
|
||||
*/
|
||||
|
|
@ -273,3 +279,4 @@ export const supportedModalities = () => serverStore.supportedModalities;
|
|||
export const supportsVision = () => serverStore.supportsVision;
|
||||
export const supportsAudio = () => serverStore.supportsAudio;
|
||||
export const slotsEndpointAvailable = () => serverStore.slotsEndpointAvailable;
|
||||
export const serverDefaultParams = () => serverStore.serverDefaultParams;
|
||||
|
|
|
|||
|
|
@ -33,11 +33,25 @@
|
|||
|
||||
import { browser } from '$app/environment';
|
||||
import { SETTING_CONFIG_DEFAULT } from '$lib/constants/settings-config';
|
||||
import { normalizeFloatingPoint } from '$lib/utils/precision';
|
||||
import { ParameterSyncService } from '$lib/services/parameter-sync';
|
||||
import { serverStore } from '$lib/stores/server.svelte';
|
||||
import { setConfigValue, getConfigValue, configToParameterRecord } from '$lib/utils/config-helpers';
|
||||
|
||||
class SettingsStore {
|
||||
config = $state<SettingsConfigType>({ ...SETTING_CONFIG_DEFAULT });
|
||||
theme = $state<string>('auto');
|
||||
isInitialized = $state(false);
|
||||
userOverrides = $state<Set<string>>(new Set());
|
||||
|
||||
/**
|
||||
* Helper method to get server defaults with null safety
|
||||
* Centralizes the pattern of getting and extracting server defaults
|
||||
*/
|
||||
private getServerDefaults(): Record<string, string | number | boolean> {
|
||||
const serverParams = serverStore.serverDefaultParams;
|
||||
return serverParams ? ParameterSyncService.extractServerDefaults(serverParams) : {};
|
||||
}
|
||||
|
||||
constructor() {
|
||||
if (browser) {
|
||||
|
|
@ -67,14 +81,20 @@ class SettingsStore {
|
|||
|
||||
try {
|
||||
const savedVal = JSON.parse(localStorage.getItem('config') || '{}');
|
||||
|
||||
// Merge with defaults to prevent breaking changes
|
||||
this.config = {
|
||||
...SETTING_CONFIG_DEFAULT,
|
||||
...savedVal
|
||||
};
|
||||
|
||||
// Load user overrides
|
||||
const savedOverrides = JSON.parse(localStorage.getItem('userOverrides') || '[]');
|
||||
this.userOverrides = new Set(savedOverrides);
|
||||
} catch (error) {
|
||||
console.warn('Failed to parse config from localStorage, using defaults:', error);
|
||||
this.config = { ...SETTING_CONFIG_DEFAULT };
|
||||
this.userOverrides = new Set();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -86,14 +106,30 @@ class SettingsStore {
|
|||
|
||||
this.theme = localStorage.getItem('theme') || 'auto';
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a specific configuration setting
|
||||
* @param key - The configuration key to update
|
||||
* @param value - The new value for the configuration key
|
||||
*/
|
||||
updateConfig<K extends keyof SettingsConfigType>(key: K, value: SettingsConfigType[K]) {
|
||||
updateConfig<K extends keyof SettingsConfigType>(key: K, value: SettingsConfigType[K]): void {
|
||||
this.config[key] = value;
|
||||
|
||||
if (ParameterSyncService.canSyncParameter(key as string)) {
|
||||
const propsDefaults = this.getServerDefaults();
|
||||
const propsDefault = propsDefaults[key as string];
|
||||
|
||||
if (propsDefault !== undefined) {
|
||||
const normalizedValue = normalizeFloatingPoint(value);
|
||||
const normalizedDefault = normalizeFloatingPoint(propsDefault);
|
||||
|
||||
if (normalizedValue === normalizedDefault) {
|
||||
this.userOverrides.delete(key as string);
|
||||
} else {
|
||||
this.userOverrides.add(key as string);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.saveConfig();
|
||||
}
|
||||
|
||||
|
|
@ -103,6 +139,26 @@ class SettingsStore {
|
|||
*/
|
||||
updateMultipleConfig(updates: Partial<SettingsConfigType>) {
|
||||
Object.assign(this.config, updates);
|
||||
|
||||
const propsDefaults = this.getServerDefaults();
|
||||
|
||||
for (const [key, value] of Object.entries(updates)) {
|
||||
if (ParameterSyncService.canSyncParameter(key)) {
|
||||
const propsDefault = propsDefaults[key];
|
||||
|
||||
if (propsDefault !== undefined) {
|
||||
const normalizedValue = normalizeFloatingPoint(value);
|
||||
const normalizedDefault = normalizeFloatingPoint(propsDefault);
|
||||
|
||||
if (normalizedValue === normalizedDefault) {
|
||||
this.userOverrides.delete(key);
|
||||
} else {
|
||||
this.userOverrides.add(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.saveConfig();
|
||||
}
|
||||
|
||||
|
|
@ -114,6 +170,8 @@ class SettingsStore {
|
|||
|
||||
try {
|
||||
localStorage.setItem('config', JSON.stringify(this.config));
|
||||
|
||||
localStorage.setItem('userOverrides', JSON.stringify(Array.from(this.userOverrides)));
|
||||
} catch (error) {
|
||||
console.error('Failed to save config to localStorage:', error);
|
||||
}
|
||||
|
|
@ -185,6 +243,129 @@ class SettingsStore {
|
|||
getAllConfig(): SettingsConfigType {
|
||||
return { ...this.config };
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize settings with props defaults when server properties are first loaded
|
||||
* This sets up the default values from /props endpoint
|
||||
*/
|
||||
syncWithServerDefaults(): void {
|
||||
const serverParams = serverStore.serverDefaultParams;
|
||||
if (!serverParams) {
|
||||
console.warn('No server parameters available for initialization');
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const propsDefaults = this.getServerDefaults();
|
||||
|
||||
for (const [key, propsValue] of Object.entries(propsDefaults)) {
|
||||
const currentValue = getConfigValue(this.config, key);
|
||||
|
||||
const normalizedCurrent = normalizeFloatingPoint(currentValue);
|
||||
const normalizedDefault = normalizeFloatingPoint(propsValue);
|
||||
|
||||
if (normalizedCurrent === normalizedDefault) {
|
||||
this.userOverrides.delete(key);
|
||||
setConfigValue(this.config, key, propsValue);
|
||||
} else if (!this.userOverrides.has(key)) {
|
||||
setConfigValue(this.config, key, propsValue);
|
||||
}
|
||||
}
|
||||
|
||||
this.saveConfig();
|
||||
console.log('Settings initialized with props defaults:', propsDefaults);
|
||||
console.log('Current user overrides after sync:', Array.from(this.userOverrides));
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear all user overrides (for debugging)
|
||||
*/
|
||||
clearAllUserOverrides(): void {
|
||||
this.userOverrides.clear();
|
||||
this.saveConfig();
|
||||
console.log('Cleared all user overrides');
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset all parameters to their default values (from props)
|
||||
* This is used by the "Reset to Default" functionality
|
||||
* Prioritizes server defaults from /props, falls back to webui defaults
|
||||
*/
|
||||
forceSyncWithServerDefaults(): void {
|
||||
const propsDefaults = this.getServerDefaults();
|
||||
const syncableKeys = ParameterSyncService.getSyncableParameterKeys();
|
||||
|
||||
for (const key of syncableKeys) {
|
||||
if (propsDefaults[key] !== undefined) {
|
||||
const normalizedValue = normalizeFloatingPoint(propsDefaults[key]);
|
||||
|
||||
setConfigValue(this.config, key, normalizedValue);
|
||||
} else {
|
||||
if (key in SETTING_CONFIG_DEFAULT) {
|
||||
const defaultValue = getConfigValue(SETTING_CONFIG_DEFAULT, key);
|
||||
|
||||
setConfigValue(this.config, key, defaultValue);
|
||||
}
|
||||
}
|
||||
|
||||
this.userOverrides.delete(key);
|
||||
}
|
||||
|
||||
this.saveConfig();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get parameter information including source for a specific parameter
|
||||
*/
|
||||
getParameterInfo(key: string) {
|
||||
const propsDefaults = this.getServerDefaults();
|
||||
const currentValue = getConfigValue(this.config, key);
|
||||
|
||||
return ParameterSyncService.getParameterInfo(
|
||||
key,
|
||||
currentValue ?? '',
|
||||
propsDefaults,
|
||||
this.userOverrides
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset a parameter to server default (or webui default if no server default)
|
||||
*/
|
||||
resetParameterToServerDefault(key: string): void {
|
||||
const serverDefaults = this.getServerDefaults();
|
||||
|
||||
if (serverDefaults[key] !== undefined) {
|
||||
const value = normalizeFloatingPoint(serverDefaults[key]);
|
||||
|
||||
this.config[key as keyof SettingsConfigType] =
|
||||
value as SettingsConfigType[keyof SettingsConfigType];
|
||||
} else {
|
||||
if (key in SETTING_CONFIG_DEFAULT) {
|
||||
const defaultValue = getConfigValue(SETTING_CONFIG_DEFAULT, key);
|
||||
|
||||
setConfigValue(this.config, key, defaultValue);
|
||||
}
|
||||
}
|
||||
|
||||
this.userOverrides.delete(key);
|
||||
this.saveConfig();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get diff between current settings and server defaults
|
||||
*/
|
||||
getParameterDiff() {
|
||||
const serverDefaults = this.getServerDefaults();
|
||||
if (Object.keys(serverDefaults).length === 0) return {};
|
||||
|
||||
const configAsRecord = configToParameterRecord(
|
||||
this.config,
|
||||
ParameterSyncService.getSyncableParameterKeys()
|
||||
);
|
||||
|
||||
return ParameterSyncService.createParameterDiff(configAsRecord, serverDefaults);
|
||||
}
|
||||
}
|
||||
|
||||
// Create and export the settings store instance
|
||||
|
|
@ -204,3 +385,11 @@ export const resetTheme = settingsStore.resetTheme.bind(settingsStore);
|
|||
export const resetAll = settingsStore.resetAll.bind(settingsStore);
|
||||
export const getConfig = settingsStore.getConfig.bind(settingsStore);
|
||||
export const getAllConfig = settingsStore.getAllConfig.bind(settingsStore);
|
||||
export const syncWithServerDefaults = settingsStore.syncWithServerDefaults.bind(settingsStore);
|
||||
export const forceSyncWithServerDefaults =
|
||||
settingsStore.forceSyncWithServerDefaults.bind(settingsStore);
|
||||
export const getParameterInfo = settingsStore.getParameterInfo.bind(settingsStore);
|
||||
export const resetParameterToServerDefault =
|
||||
settingsStore.resetParameterToServerDefault.bind(settingsStore);
|
||||
export const getParameterDiff = settingsStore.getParameterDiff.bind(settingsStore);
|
||||
export const clearAllUserOverrides = settingsStore.clearAllUserOverrides.bind(settingsStore);
|
||||
|
|
|
|||
53
tools/server/webui/src/lib/utils/config-helpers.ts
Normal file
53
tools/server/webui/src/lib/utils/config-helpers.ts
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Type-safe configuration helpers
|
||||
*
|
||||
* Provides utilities for safely accessing and modifying configuration objects
|
||||
* with dynamic keys while maintaining TypeScript type safety.
|
||||
*/
|
||||
|
||||
import type { SettingsConfigType } from '$lib/types/settings';
|
||||
|
||||
/**
|
||||
* Type-safe helper to access config properties dynamically
|
||||
* Provides better type safety than direct casting to Record
|
||||
*/
|
||||
export function setConfigValue<T extends SettingsConfigType>(
|
||||
config: T,
|
||||
key: string,
|
||||
value: unknown
|
||||
): void {
|
||||
if (key in config) {
|
||||
(config as Record<string, unknown>)[key] = value;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Type-safe helper to get config values dynamically
|
||||
*/
|
||||
export function getConfigValue<T extends SettingsConfigType>(
|
||||
config: T,
|
||||
key: string
|
||||
): string | number | boolean | undefined {
|
||||
const value = (config as Record<string, unknown>)[key];
|
||||
return value as string | number | boolean | undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert a SettingsConfigType to a ParameterRecord for specific keys
|
||||
* Useful for parameter synchronization operations
|
||||
*/
|
||||
export function configToParameterRecord<T extends SettingsConfigType>(
|
||||
config: T,
|
||||
keys: string[]
|
||||
): Record<string, string | number | boolean> {
|
||||
const record: Record<string, string | number | boolean> = {};
|
||||
|
||||
for (const key of keys) {
|
||||
const value = getConfigValue(config, key);
|
||||
if (value !== undefined) {
|
||||
record[key] = value;
|
||||
}
|
||||
}
|
||||
|
||||
return record;
|
||||
}
|
||||
25
tools/server/webui/src/lib/utils/precision.ts
Normal file
25
tools/server/webui/src/lib/utils/precision.ts
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
/**
|
||||
* Floating-point precision utilities
|
||||
*
|
||||
* Provides functions to normalize floating-point numbers for consistent comparison
|
||||
* and display, addressing JavaScript's floating-point precision issues.
|
||||
*/
|
||||
|
||||
import { PRECISION_MULTIPLIER } from '$lib/constants/precision';
|
||||
|
||||
/**
|
||||
* Normalize floating-point numbers for consistent comparison
|
||||
* Addresses JavaScript floating-point precision issues (e.g., 0.949999988079071 → 0.95)
|
||||
*/
|
||||
export function normalizeFloatingPoint(value: unknown): unknown {
|
||||
return typeof value === 'number'
|
||||
? Math.round(value * PRECISION_MULTIPLIER) / PRECISION_MULTIPLIER
|
||||
: value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Type-safe version that only accepts numbers
|
||||
*/
|
||||
export function normalizeNumber(value: number): number {
|
||||
return Math.round(value * PRECISION_MULTIPLIER) / PRECISION_MULTIPLIER;
|
||||
}
|
||||
|
|
@ -9,7 +9,7 @@
|
|||
} from '$lib/stores/chat.svelte';
|
||||
import * as Sidebar from '$lib/components/ui/sidebar/index.js';
|
||||
import { serverStore } from '$lib/stores/server.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { config, settingsStore } from '$lib/stores/settings.svelte';
|
||||
import { ModeWatcher } from 'mode-watcher';
|
||||
import { Toaster } from 'svelte-sonner';
|
||||
import { goto } from '$app/navigation';
|
||||
|
|
@ -95,6 +95,15 @@
|
|||
serverStore.fetchServerProps();
|
||||
});
|
||||
|
||||
// Sync settings when server props are loaded
|
||||
$effect(() => {
|
||||
const serverProps = serverStore.serverProps;
|
||||
|
||||
if (serverProps?.default_generation_settings?.params) {
|
||||
settingsStore.syncWithServerDefaults();
|
||||
}
|
||||
});
|
||||
|
||||
// Monitor API key changes and redirect to error page if removed or changed when required
|
||||
$effect(() => {
|
||||
const apiKey = config().apiKey;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue