diff --git a/common/chat-parser-xml-toolcall.cpp b/common/chat-parser-xml-toolcall.cpp index 734989555..a80900ff8 100644 --- a/common/chat-parser-xml-toolcall.cpp +++ b/common/chat-parser-xml-toolcall.cpp @@ -724,16 +724,10 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons if (reasoning_unclosed) { if (auto pos = content.find(end_think); pos == std::string::npos && builder.pos() != builder.input().size()) { unclosed_reasoning_content += content; - if (form.allow_toolcall_in_think) { - builder.move_to(tc->groups[0].begin); - if (!builder.try_consume_xml_tool_calls(form)) { - unclosed_reasoning_content += tool_call_start; - builder.move_to(tc->groups[0].end); - } - } else { + if (!(form.allow_toolcall_in_think && tc)) { unclosed_reasoning_content += tool_call_start; + continue; } - continue; } else { reasoning_unclosed = false; std::string reasoning_content; @@ -781,8 +775,12 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons } } else { // This start is in thinking block, skip this tool call - auto pos = think_start + start_think.size(); - unclosed_reasoning_content = content.substr(pos) + tool_call_start; + // This start is in thinking block + if (form.allow_toolcall_in_think) { + unclosed_reasoning_content = content.substr(think_start + start_think.size()); + } else { + unclosed_reasoning_content = content.substr(think_start + start_think.size()) + tool_call_start; + } reasoning_unclosed = true; content.resize(think_start); toolcall_in_think = true; @@ -805,14 +803,35 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons } // remove potential partial suffix - if (content.size() > 0 && builder.pos() == builder.input().size() && unclosed_reasoning_content.empty()) { - rstrip(content); - trim_potential_partial_word(content); - rstrip(content); + if (builder.pos() == builder.input().size()) { + if (unclosed_reasoning_content.empty()) { + rstrip(content); + trim_potential_partial_word(content); + rstrip(content); + } else { + rstrip(unclosed_reasoning_content); + trim_potential_partial_word(unclosed_reasoning_content); + rstrip(unclosed_reasoning_content); + } + } + + // consume unclosed_reasoning_content if allow_toolcall_in_think is set + if (form.allow_toolcall_in_think && !unclosed_reasoning_content.empty()) { + if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) { + builder.add_reasoning_content(unclosed_reasoning_content); + } else { + if (content.empty()) { + content = start_think + unclosed_reasoning_content; + } else { + content += "\n\n" + start_think; + content += unclosed_reasoning_content; + } + } + unclosed_reasoning_content.clear(); } // Add content - if (content.size() != 0) { + if (!content.empty()) { // If there are multiple content blocks if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content && builder.result().content.size() != 0) { builder.add_content("\n\n"); @@ -820,7 +839,7 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons builder.add_content(content); } - // This start is in thinking block, skip this tool call + // This start is in thinking block and toolcall_in_think not set, skip this tool call if (toolcall_in_think && !form.allow_toolcall_in_think) { continue; } @@ -829,7 +848,7 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons if (!tc) { GGML_ASSERT(builder.pos() == 
builder.input().size()); GGML_ASSERT(unclosed_reasoning_content.empty()); - GGML_ASSERT(!reasoning_unclosed); + if (!form.allow_toolcall_in_think) GGML_ASSERT(!reasoning_unclosed); break; } @@ -854,7 +873,6 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons /** * Parse content uses reasoning and XML-Style tool call - * TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed. */ void common_chat_msg_parser::consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think, const std::string & end_think) { parse_msg_with_xml_tool_calls(*this, form, start_think, end_think); diff --git a/common/chat-parser-xml-toolcall.h b/common/chat-parser-xml-toolcall.h index 67face2b9..b309fb667 100644 --- a/common/chat-parser-xml-toolcall.h +++ b/common/chat-parser-xml-toolcall.h @@ -31,7 +31,7 @@ struct xml_tool_call_format { std::optional last_val_end = std::nullopt; std::optional last_tool_end = std::nullopt; bool trim_raw_argval = false; - bool allow_toolcall_in_think = false; // TODO: UNTESTED!!! + bool allow_toolcall_in_think = false; }; // make a GBNF that accept any strings except those containing any of the forbidden strings. diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index fe3e80037..d740dac06 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -917,12 +917,13 @@ static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) { form.tool_start = "<|tool_call_begin|>"; form.tool_sep = "<|tool_call_argument_begin|>{"; form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; + form.key_val_sep = "\":"; + form.val_end = ","; form.tool_end = "}<|tool_call_end|>"; form.scope_end = "<|tool_calls_section_end|>"; form.raw_argval = false; form.last_val_end = ""; + form.allow_toolcall_in_think = true; return form; })(); builder.consume_reasoning_with_xml_tool_calls(form, "", ""); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 549f13876..2223b33dc 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -494,6 +494,15 @@ static inline void ggml_thread_cpu_relax(void) { static inline void ggml_thread_cpu_relax(void) { _mm_pause(); } +#elif defined(__riscv) +static inline void ggml_thread_cpu_relax(void) { + #ifdef __riscv_zihintpause + __asm__ __volatile__ ("pause"); + #else + /* Encoding of the pause instruction */ + __asm__ __volatile__ (".4byte 0x100000F"); + #endif +} #else static inline void ggml_thread_cpu_relax(void) {;} #endif diff --git a/ggml/src/ggml-cuda/fill.cu b/ggml/src/ggml-cuda/fill.cu new file mode 100644 index 000000000..eb8ccb780 --- /dev/null +++ b/ggml/src/ggml-cuda/fill.cu @@ -0,0 +1,37 @@ +#include "fill.cuh" +#include "convert.cuh" + +#define CUDA_FILL_BLOCK_SIZE 256 + +template +static __global__ void fill_kernel(T * __restrict__ dst, const int64_t k, const T value) { + const int64_t i = (int64_t)blockDim.x * blockIdx.x + threadIdx.x; + if (i >= k) { + return; + } + dst[i] = value; +} + +void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + void * dst_d = dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(dst)); + + float value; + memcpy(&value, dst->op_params, sizeof(float)); + + const int64_t k = ggml_nelements(dst); + const int64_t num_blocks = (k + CUDA_FILL_BLOCK_SIZE - 1) / CUDA_FILL_BLOCK_SIZE; + + switch (dst->type) { + case GGML_TYPE_F32: + fill_kernel<<>>((float 
*)dst_d, k, value); + break; + case GGML_TYPE_F16: + fill_kernel<<>>((half *)dst_d, k, ggml_cuda_cast(value)); + break; + default: + GGML_ABORT("unsupported type"); + } +} diff --git a/ggml/src/ggml-cuda/fill.cuh b/ggml/src/ggml-cuda/fill.cuh new file mode 100644 index 000000000..8443c8362 --- /dev/null +++ b/ggml/src/ggml-cuda/fill.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 25cbf8238..d85fb2e1f 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -58,6 +58,7 @@ bool g_mul_mat_q = true; #include "ggml-cuda/solve_tri.cuh" #include "ggml-cuda/tri.cuh" #include "ggml-cuda/cumsum.cuh" +#include "ggml-cuda/fill.cuh" #include "ggml.h" #include @@ -2743,6 +2744,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SOLVE_TRI: ggml_cuda_op_solve_tri(ctx, dst); break; + case GGML_OP_FILL: + ggml_cuda_op_fill(ctx, dst); + break; default: return false; } @@ -4630,6 +4634,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: + case GGML_OP_FILL: case GGML_OP_CUMSUM: case GGML_OP_TRI: return true; diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu index 2e2b39720..e161d4dc4 100644 --- a/ggml/src/ggml-cuda/solve_tri.cu +++ b/ggml/src/ggml-cuda/solve_tri.cu @@ -3,7 +3,6 @@ #include "solve_tri.cuh" #define MAX_N_FAST 64 -#define MAX_K_FAST 32 // ====================== // Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction @@ -48,65 +47,58 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A, float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); __shared__ float sA[MAX_N_FAST * MAX_N_FAST]; - __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)]; const int offset = threadIdx.x + threadIdx.y * blockDim.x; #pragma unroll for (int i = 0; i < n * n; i += k * WARP_SIZE) { - int i0 = i + offset; + const int i0 = i + offset; if (i0 < n * n) { sA[i0] = A_batch[i0]; } } - const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE; - -#pragma unroll - for (int i = 0; i < rows_per_warp; i++) { - const int i0 = lane + i * WARP_SIZE; - if (i0 < n) { - sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx]; - } - } - __syncthreads(); + float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f; + float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f; + + const int half = WARP_SIZE; + const int nrows_low = (n < half) ? 
n : half; + #pragma unroll - for (int row = 0; row < n; ++row) { + for (int row = 0; row < nrows_low; ++row) { float sum = 0.0f; - - { - int j = lane; - if (j < row) { - sum += sA[row * n + j] * sXt[col_idx * n + j]; - } + if (lane < row) { + sum += sA[row * n + lane] * x_low; } - if (row >= WARP_SIZE) { - int j = WARP_SIZE + lane; - if (j < row) { - sum += sA[row * n + j] * sXt[col_idx * n + j]; - } - } - sum = warp_reduce_sum(sum); - if (lane == 0) { - const float b_val = sXt[col_idx * n + row]; - const float a_diag = sA[row * n + row]; - // no safeguards for division by zero because that indicates corrupt - // data anyway - sXt[col_idx * n + row] = (b_val - sum) / a_diag; + if (lane == row) { + x_low = (x_low - sum) / sA[row * n + row]; } } - __syncthreads(); +#pragma unroll + for (int row = half; row < n; ++row) { + float sum = sA[row * n + lane] * x_low; + const int j = half + lane; + if (j < row) { + sum += sA[row * n + j] * x_high; + } + sum = warp_reduce_sum(sum); + + if (lane == row - half) { + x_high = (x_high - sum) / sA[row * n + row]; + } + } #pragma unroll - for (int i = 0; i < rows_per_warp; i++) { - const int i0 = lane + i * WARP_SIZE; - if (i0 < n) { - X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0]; + for (int rr = 0; rr < 2; ++rr) { + const int row = rr * WARP_SIZE + lane; + if (row < n) { + const float val = (row < half) ? x_low : x_high; + X_batch[row * k + col_idx] = val; } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp index 4cb292380..e5cc7ff86 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp @@ -7,35 +7,85 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; -void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, + const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + // Compute starting index in matrix B for this superblock const uint y_idx = i * QUANT_K + 32 * ib32; - uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + + // Precompute indices for quantization lookup tables + const uint qh_base = 2 * ib32; + const uint qs_base = 4 * ib32; + const uint sc_index = ib32 / 2; + const uint sc_shift = 6 * (ib32 & 1); + + // Loop over rows in the superblock [[unroll]] for (uint n = 0; n < num_rows; ++n) { + // Load per-block scales and shift for quantization const uint16_t[4] scales = data_a[ibi].scales; const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + const uint sc = data_a[ibi].scales[sc_index] >> sc_shift; - const uint sc = data_a[ibi].scales[ib32 / 2] >> (6 * (ib32 & 1)); + // Temporary caches for decoding + FLOAT_TYPE dl_cache[4]; + uint16_t gvf_cache[4]; + float delta_cache[4]; + + // Precompute the multiplier and lookup values for 4 sub-blocks [[unroll]] for (uint l = 0; l < 4; ++l) { - const uint qh = data_a[ibi].qh[2 * ib32 + l / 2] >> (4 * (l&1)); - const uint qs = data_a[ibi].qs[4 * ib32 + l]; - const float delta = ((qh & 8) != 0) ? 
-IQ1M_DELTA : IQ1M_DELTA; - const float dl = d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1); + dl_cache[l] = FLOAT_TYPE(d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1)); + const uint qh = data_a[ibi].qh[qh_base + l / 2] >> (4 * (l & 1)); + const uint qs = data_a[ibi].qs[qs_base + l]; + gvf_cache[l] = iq1s_grid[qs | ((qh & 7) << 8)]; + delta_cache[l] = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + } - const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + // Loop over columns of the output + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + // Compute base index for matrix B + const uint base_b_idx = (j * p.batch_stride_b + b_offset + y_idx) / 4; + vec4 b_vals[8]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); - vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); - - FLOAT_TYPE sum = FLOAT_TYPE(0.0); - [[unroll]] for (int k = 0; k < 4; ++k) { - sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta, - fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum)); - } - temp[j][n] = fma(dl, sum, temp[j][n]); + // Load 8 vec4 values from matrix B + [[unroll]] for (int idx = 0; idx < 8; ++idx) { + b_vals[idx] = vec4(data_b_v4[base_b_idx + idx]); } + + FLOAT_TYPE col_sum = FLOAT_TYPE(0.0); + + // Loop over sub-blocks + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint16_t grid = gvf_cache[l]; + const float dl = dl_cache[l]; + + // Decode 8 2-bit fbits from gvf_cache + float f0 = float(bitfieldExtract(grid, 0, 2)); + float f1 = float(bitfieldExtract(grid, 2, 2)); + float f2 = float(bitfieldExtract(grid, 4, 2)); + float f3 = float(bitfieldExtract(grid, 6, 2)); + float f4 = float(bitfieldExtract(grid, 8, 2)); + float f5 = float(bitfieldExtract(grid, 10, 2)); + float f6 = float(bitfieldExtract(grid, 12, 2)); + float f7 = float(bitfieldExtract(grid, 14, 2)); + + // Pack into vec4 for vectorized FMA + const vec4 fbits_v0 = vec4(f0, f1, f2, f3); + const vec4 fbits_v1 = vec4(f4, f5, f6, f7); + const vec4 delta_v = vec4(delta_cache[l]); + + // Vectorized fused multiply-add + vec4 sum_v = fma(b_vals[2*l + 0], fbits_v0 + delta_v, vec4(0.0)); + sum_v = fma(b_vals[2*l + 1], fbits_v1 + delta_v, sum_v); + + // Horizontal add to get scalar sum + FLOAT_TYPE sum = sum_v.x + sum_v.y + sum_v.z + sum_v.w; + + // Accumulate to column sum + col_sum = fma(dl, sum, col_sum); + } + // Write result to temporary buffer + temp[j][n] += col_sum; } ibi += num_blocks_per_row; } diff --git a/src/llama-context.cpp b/src/llama-context.cpp index fed4b5cf5..cf93ae314 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -251,7 +251,10 @@ llama_context::llama_context( LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size()); - const size_t max_nodes = this->graph_max_nodes(); + const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + + const size_t max_nodes = this->graph_max_nodes(n_tokens); LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); @@ -309,9 +312,6 @@ llama_context::llama_context( cross.v_embd.clear(); - const uint32_t n_seqs = cparams.n_seq_max; - const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - // avoid reserving graphs with zero outputs - assume one output per sequence n_outputs = n_seqs; @@ -1396,9 +1396,9 @@ void llama_context::output_reorder() { // graph // -uint32_t llama_context::graph_max_nodes() const { +uint32_t 
llama_context::graph_max_nodes(uint32_t n_tokens) const { if (model.arch == LLM_ARCH_QWEN3NEXT) { - return std::max(8192u, 32u*model.n_tensors()); + return std::max(n_tokens * 40, 32u * model.n_tensors()); } return std::max(1024u, 8u*model.n_tensors()); } diff --git a/src/llama-context.h b/src/llama-context.h index 20cbd7895..cd26eafe1 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -197,7 +197,7 @@ private: // public: - uint32_t graph_max_nodes() const; + uint32_t graph_max_nodes(uint32_t n_tokens) const; // can reuse the llm_graph_result instance of the context (for example to update a memory module) llm_graph_result * get_gf_res_reserve() const; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 035cc06d6..be3eb5205 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1733,6 +1733,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { } ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); + // (optional) temperature tuning - used by mistral-large + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false); + switch (hparams.n_layer) { case 27: type = LLM_TYPE_16B; break; case 60: type = LLM_TYPE_236B; break; diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index 0b41f7ba8..dbaa8297b 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -30,6 +30,12 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // {n_embd, n_tokens} inpL = build_inp_embd(model.tok_embd); + // (optional) temperature tuning - used by mistral-large + ggml_tensor * inp_attn_scale = nullptr; + if (hparams.f_attn_temp_scale != 0.0f) { + inp_attn_scale = build_inp_attn_scale(); + } + // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); @@ -128,6 +134,12 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr ggml_tensor * Vcur = kv_cmpr; cb(Vcur, "Vcur", il); + if (inp_attn_scale) { + // apply llama 4 temperature scaling + Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale); + cb(Qcur, "Qcur_attn_temp_scaled", il); + } + // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group) cur = build_attn(inp_attn, model.layers[il].wo, NULL, @@ -160,6 +172,12 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0); cb(Kcur, "Kcur", il); + if (inp_attn_scale) { + // apply llama 4 temperature scaling + Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale); + cb(Qcur, "Qcur_attn_temp_scaled", il); + } + // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups) cur = build_attn(inp_attn, model.layers[il].wo, NULL, diff --git a/tools/server/README-dev.md b/tools/server/README-dev.md new file mode 100644 index 000000000..fbcd6bc1f --- /dev/null +++ b/tools/server/README-dev.md @@ -0,0 +1,177 @@ +# llama-server Development Documentation + +This document provides an in-depth technical overview of `llama-server`, intended for maintainers and contributors. + +If you are an end user consuming `llama-server` as a product, please refer to the main [README](./README.md) instead. + +## Backend + +### Overview + +The server supports two primary operating modes: + +- **Inference mode**: The default mode for performing inference with a single loaded GGUF model. 
+
+- **Router mode**: Enables management of multiple inference server instances behind a single API endpoint. Requests are automatically routed to the appropriate backend instance based on the requested model.
+
+The core architecture consists of the following components:
+
+- `server_context`: Holds the primary inference state, including the main `llama_context` and all active slots.
+- `server_slot`: An abstraction over a single “sequence” in llama.cpp, responsible for managing individual parallel inference requests.
+- `server_routes`: Middleware layer between `server_context` and the HTTP interface; handles JSON parsing/formatting and request routing logic.
+- `server_http_context`: Implements the HTTP server using `cpp-httplib`.
+- `server_queue`: Thread-safe queue used by HTTP workers to submit new tasks to `server_context`.
+- `server_response`: Thread-safe queue used by `server_context` to return results to HTTP workers.
+- `server_response_reader`: Higher-level wrapper around the two queues above for cleaner code.
+- `server_task`: Unit of work pushed into `server_queue`.
+- `server_task_result`: Unit of result pushed into `server_response`.
+- `server_tokens`: Unified representation of token sequences (supports both text and multimodal tokens); used by `server_task` and `server_slot`.
+- `server_prompt_checkpoint`: For recurrent (e.g., RWKV) and SWA models, stores snapshots of KV cache state. Enables reuse when subsequent requests share the same prompt prefix, saving redundant computation.
+- `server_models`: Standalone component for managing multiple backend instances (used in router mode). It is completely independent of `server_context`.
+
+```mermaid
+graph TD
+    API_User <--> server_http_context
+    server_http_context <-- router mode --> server_models
+    server_http_context <-- inference mode --> server_routes
+    server_routes -- server_task --> server_queue
+    subgraph server_context
+        server_queue --> server_slot
+        server_slot -- server_task_result --> server_response
+        server_slot[multiple server_slot]
+    end
+    server_response --> server_routes
+```
+
+### Batching
+
+The server context maintains a single batch shared across all slots. When `update_slots()` is invoked, the system iterates through all active slots to populate this batch. For each slot, either a generated token from the previous decoding step or available prompt tokens are added to the batch.
+
+Batching constraints apply: slots can only be batched together if they share compatible configurations. For instance, slots using a specific LoRA adapter can be batched with each other, but not with slots using a different LoRA adapter or no adapter at all.
+
+Once the batch reaches capacity or all slots have been processed, `llama_decode` is called to execute the inference. This operation represents the primary computational bottleneck in `update_slots()`.
+
+Following decoding, the system either retrieves embeddings or samples the next token using `common_sampler_sample`. If a slot has remaining prompt tokens to process, it yields until the next `update_slots()` iteration.
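+
+Below is a deliberately simplified sketch of one `update_slots()` iteration, written to match the description above rather than the actual implementation in `server-context.cpp`. Error handling, prompt caching, checkpoints and the real batching-compatibility checks are omitted, and the helpers `has_prompt_tokens_left()` / `next_prompt_token()` are illustrative names, not the real API:
+
+```cpp
+// sketch only - see update_slots() in server-context.cpp for the real logic
+common_batch_clear(batch);
+
+for (auto & slot : slots) {
+    if (!slot.is_processing()) {
+        continue;
+    }
+    if (slot.state == SLOT_STATE_GENERATING) {
+        // generation phase: feed back the token sampled in the previous iteration
+        slot.i_batch = batch.n_tokens;
+        common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
+        slot.prompt.tokens.push_back(slot.sampled);
+    } else {
+        // prompt phase: add as many pending prompt tokens as the batch still has room for
+        while (slot.has_prompt_tokens_left() && batch.n_tokens < n_batch) { // illustrative helper
+            common_batch_add(batch, slot.next_prompt_token(), slot.prompt.tokens.pos_next(), { slot.id }, false);
+        }
+    }
+}
+
+// one decode call shared by all slots - the main computational cost
+llama_decode(ctx, batch);
+
+for (auto & slot : slots) {
+    if (slot.state == SLOT_STATE_GENERATING) {
+        // sample the next token; embedding slots read back embeddings instead
+        slot.sampled = common_sampler_sample(slot.smpl, ctx, slot.i_batch);
+    }
+}
+```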
+
+### Thread Management
+
+`server_context` runs on a dedicated single thread. Because it is single-threaded, heavy post-processing (especially after token generation) should be avoided, as it directly impacts multi-sequence throughput.
+
+Each incoming HTTP request is handled by its own thread managed by the HTTP library. The following operations are performed in HTTP worker threads:
+
+- JSON request parsing
+- Chat template application
+- Tokenization
+- Conversion of `server_task_result` into final JSON response
+- Error formatting into JSON
+- Tracking of partial/incremental responses (e.g., streaming tool calls or reasoning steps)
+
+**Best practices to follow:**
+
+- All JSON formatting and chat template logic must stay in the HTTP layer.
+- Avoid passing raw JSON between the HTTP layer and `server_slot`. Instead, parse everything into native C++ types as early as possible.
+
+### Example trace of a request
+
+Here is an example trace of an API request for text completion:
+
+- A request arrives at the HTTP layer.
+- The request is routed to the corresponding handler inside `server_routes`. In this case, `handle_completions_impl` is invoked.
+- The handler parses the input request, constructs a new `server_task`, and passes it to `server_res_generator`.
+- `server_res_generator` creates a new `task_result_state` for each task:
+  - `task_result_state` stays in the HTTP layer, responsible for keeping track of the current state of the response (e.g., parsing tool calls or thinking messages).
+  - `server_task` is moved into `server_queue` inside `server_context`.
+- `server_context` launches the task by moving it into an available slot (see `launch_slot_with_task()`).
+- `update_slots()` processes the task as described in the "Batching" section above.
+- Results may be sent using `send_partial_response` or `send_final_response`, each of which creates a new `server_task_result` and pushes it to the response queue.
+- At the same time, `server_res_generator` listens to the response queue and retrieves this response.
+- As the response is stateless, `server_res_generator` calls `response->update()` to update the response with the current state.
+- `server_res_generator` then calls `response->to_json()` and passes the response to the HTTP layer.
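+
+The same flow, reduced to a sketch from the point of view of an HTTP worker thread. Streaming, error paths and the exact result API are simplified; `parse_completion_params()`, `tokenize_prompt()`, `wait_for_next_result()` and `send_chunk()` are placeholder names used only for illustration:
+
+```cpp
+// sketch only - see handle_completions_impl() for the real logic
+server_response_reader rd = ctx_server.get_response_reader();
+
+server_task task(SERVER_TASK_TYPE_COMPLETION);
+task.params = parse_completion_params(body);    // placeholder for the real JSON -> task_params parsing
+task.tokens = tokenize_prompt(body);            // placeholder for tokenize_input_prompts(...)
+
+task_result_state state = task.create_state();  // the state stays in the HTTP layer
+rd.post_task(std::move(task));                  // the task is moved into server_queue
+
+while (rd.has_next()) {
+    auto result = rd.wait_for_next_result();    // placeholder for the reader's polling call
+    result->update(state);                      // apply the HTTP-side state (partial tool calls, reasoning, ...)
+    send_chunk(result->to_json());              // placeholder: format and stream the chunk back to the client
+}
+```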
+
+### Testing
+
+`llama-server` includes an automated test suite based on `pytest`.
+
+The framework automatically starts a `llama-server` instance, sends requests, and validates responses.
+
+For detailed instructions, see the [test documentation](./tests/README.md).
+
+### Notable Related PRs
+
+- Initial server implementation: https://github.com/ggml-org/llama.cpp/pull/1443
+- Parallel decoding support: https://github.com/ggml-org/llama.cpp/pull/3228
+- Refactor introducing `server_queue` and `server_response`: https://github.com/ggml-org/llama.cpp/pull/5065
+- Reranking endpoint: https://github.com/ggml-org/llama.cpp/pull/9510
+- Multimodal model support (`libmtmd`): https://github.com/ggml-org/llama.cpp/pull/12898
+- Unified KV cache handling: https://github.com/ggml-org/llama.cpp/pull/16736
+- Separation of HTTP logic into dedicated files: https://github.com/ggml-org/llama.cpp/pull/17216
+- Large-scale code base split into smaller files: https://github.com/ggml-org/llama.cpp/pull/17362
+- Introduction of router mode: https://github.com/ggml-org/llama.cpp/pull/17470
+- Speculative decoding: https://github.com/ggml-org/llama.cpp/pull/17808 and rework in https://github.com/ggml-org/llama.cpp/pull/17808
+
+
+
+
+## Web UI
+
+The project includes a web-based user interface for interacting with `llama-server`. It supports both single-model (`MODEL` mode) and multi-model (`ROUTER` mode) operation.
+
+The SvelteKit-based Web UI is introduced in this PR: https://github.com/ggml-org/llama.cpp/pull/14839
+
+### Features
+
+- **Chat interface** with streaming responses
+- **Multi-model support** (ROUTER mode) - switch between models, auto-load on selection
+- **Modality validation** - ensures selected model supports conversation's attachments (images, audio)
+- **Conversation management** - branching, regeneration, editing with history preservation
+- **Attachment support** - images, audio, PDFs (with vision/text fallback)
+- **Configurable parameters** - temperature, top_p, etc. synced with server defaults
+- **Dark/light theme**
+
+### Tech Stack
+
+- **SvelteKit** - frontend framework with Svelte 5 runes for reactive state
+- **TailwindCSS** + **shadcn-svelte** - styling and UI components
+- **Vite** - build tooling
+- **IndexedDB** (Dexie) - local storage for conversations
+- **LocalStorage** - user settings persistence
+
+### Architecture
+
+The WebUI follows a layered architecture:
+
+```
+Routes → Components → Hooks → Stores → Services → Storage/API
+```
+
+- **Stores** - reactive state management (`chatStore`, `conversationsStore`, `modelsStore`, `serverStore`, `settingsStore`)
+- **Services** - stateless API/database communication (`ChatService`, `ModelsService`, `PropsService`, `DatabaseService`)
+- **Hooks** - reusable logic (`useModelChangeValidation`, `useProcessingState`)
+
+For detailed architecture diagrams, see [`tools/server/webui/docs/`](webui/docs/):
+
+- `high-level-architecture.mmd` - full architecture with all modules
+- `high-level-architecture-simplified.mmd` - simplified overview
+- `data-flow-simplified-model-mode.mmd` - data flow for single-model mode
+- `data-flow-simplified-router-mode.mmd` - data flow for multi-model mode
+- `flows/*.mmd` - detailed per-domain flows (chat, conversations, models, etc.)
+
+### Development
+
+```sh
+# make sure you have Node.js installed
+cd tools/server/webui
+npm i
+
+# run dev server (with hot reload)
+npm run dev
+
+# run tests
+npm run test
+
+# build production bundle
+npm run build
+```
+
+After `public/index.html.gz` has been generated, rebuild `llama-server` as described in the [build](#build) section to include the updated UI.
+
+**Note:** The Vite dev server automatically proxies API requests to `http://localhost:8080`. Make sure `llama-server` is running on that port during development.
diff --git a/tools/server/server-common.h b/tools/server/server-common.h
index 0c4d84ffa..0629bb5ed 100644
--- a/tools/server/server-common.h
+++ b/tools/server/server-common.h
@@ -18,11 +18,13 @@ const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
 using json = nlohmann::ordered_json;
 
 #define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+#define SLT_CNT(slot, fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
 #define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
 #define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
 #define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
 
 #define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_CNT(fmt, ...)
LOG_CNT("" fmt, __VA_ARGS__) #define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 12a4e94e5..4578f8d7a 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -102,6 +102,11 @@ struct server_slot { std::string generated_text; llama_tokens generated_tokens; + // idx of draft tokens in the main batch + // non-empty if we went to evaluate draft tokens + // ref: https://github.com/ggml-org/llama.cpp/pull/17808 + std::vector i_batch_dft; + std::vector generated_token_probs; bool has_next_token = true; @@ -150,7 +155,8 @@ struct server_slot { struct common_sampler * smpl = nullptr; - llama_token sampled; + llama_token sampled; // in speculative mode, this is the last accepted token + llama_tokens drafted; // stats size_t n_sent_text = 0; // number of sent text character @@ -180,6 +186,8 @@ struct server_slot { stopping_word = ""; n_sent_text = 0; + drafted.clear(); + i_batch_dft.clear(); generated_tokens.clear(); generated_token_probs.clear(); json_schema = json(); @@ -255,6 +263,31 @@ struct server_slot { generated_token_probs.push_back(token); } + int get_n_draft_max() const { + if (!can_speculate()) { + return 0; + } + + // determine the max draft that fits the current slot state + int n_draft_max = task->params.speculative.n_max; + + // note: slot.prompt is not yet expanded with the `id` token sampled above + // also, need to leave space for 1 extra token to allow context shifts + n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2); + + if (n_remaining > 0) { + n_draft_max = std::min(n_draft_max, n_remaining - 1); + } + + SLT_DBG(*this, "max possible draft: %d\n", n_draft_max); + + if (n_draft_max < task->params.speculative.n_min) { + SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min); + n_draft_max = 0; + } + return n_draft_max; + } + // note: a slot can also be either a parent or a child bool is_parent() const { return is_processing() && task->n_children > 0; @@ -353,8 +386,7 @@ struct server_slot { if (n_draft_total > 0) { const float draft_ratio = (float) n_draft_accepted / n_draft_total; - SLT_INF(*this, - "\n" + SLT_CNT(*this, "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n", draft_ratio, n_draft_accepted, n_draft_total ); @@ -1774,14 +1806,57 @@ struct server_context_impl { continue; } - slot.i_batch = batch.n_tokens; + // generate draft tokens in speculative decoding mode + // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] + // perform the speculative drafting for all sequences at the same time in a single batch + int n_draft_max = slot.get_n_draft_max(); + if (n_draft_max > 0) { + if (mctx) { + // we should never reach this, as speculative is automatically disabled if mmproj is loaded + GGML_ABORT("not supported by multimodal"); + } - common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); + struct common_speculative_params params_spec; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max; + params_spec.p_min = slot.task->params.speculative.p_min; + const llama_tokens & cached_text_tokens = 
slot.prompt.tokens.get_text_tokens(); + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled); - slot.prompt.tokens.push_back(slot.sampled); + // add the sampled token to the batch + slot.i_batch_dft.push_back(batch.n_tokens); + common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); + slot.prompt.tokens.push_back(slot.sampled); - SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n", - slot.n_ctx, slot.prompt.n_tokens(), slot.truncated); + if (slot.task->params.speculative.n_min > (int) draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); + // fallback to normal decoding + slot.i_batch = slot.i_batch_dft[0]; + slot.drafted.clear(); + slot.i_batch_dft.clear(); + } else { + // keep track of total number of drafted tokens tested + slot.n_draft_total += draft.size(); + + // add all drafted tokens to the batch + for (size_t i = 0; i < draft.size(); i++) { + slot.i_batch_dft.push_back(batch.n_tokens); + common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true); + slot.prompt.tokens.push_back(draft[i]); + } + slot.drafted = std::move(draft); + } + } else { + // no speculative decoding + slot.i_batch = batch.n_tokens; + + common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); + + slot.prompt.tokens.push_back(slot.sampled); + + SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n", + slot.n_ctx, slot.prompt.n_tokens(), slot.truncated); + } } // process in chunks of params.n_batch @@ -1880,8 +1955,18 @@ struct server_context_impl { n_past = std::min(n_past, slot.alora_invocation_start - 1); } + const auto n_cache_reuse = slot.task->params.n_cache_reuse; + + const bool can_cache_reuse = + llama_memory_can_shift(llama_get_memory(ctx)) && + !slot.prompt.tokens.has_mtmd; + + if (!can_cache_reuse && n_cache_reuse > 0) { + SLT_WRN(slot, "cache reuse is not supported - ignoring n_cache_reuse = %d\n", n_cache_reuse); + } + // reuse chunks from the cached prompt by shifting their KV cache in the new position - if (params_base.n_cache_reuse > 0) { + if (can_cache_reuse && n_cache_reuse > 0) { GGML_ASSERT(!slot.prompt.tokens.has_mtmd); size_t head_c = n_past; // cache @@ -1892,7 +1977,7 @@ struct server_context_impl { GGML_ABORT("not supported by multimodal"); } - SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past); + SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", n_cache_reuse, n_past); while (head_c < slot.prompt.tokens.size() && head_p < input_tokens.size()) { @@ -1901,11 +1986,10 @@ struct server_context_impl { while (head_c + n_match < slot.prompt.tokens.size() && head_p + n_match < input_tokens.size() && slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) { - n_match++; } - if (n_match >= (size_t) params_base.n_cache_reuse) { + if (n_match >= (size_t) n_cache_reuse) { SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); //for (size_t i = head_p; i < head_p + n_match; i++) { // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); @@ -2336,6 +2420,10 @@ struct server_context_impl { // on successful decode, restore the original batch size n_batch = llama_n_batch(ctx); + // 
technically, measuring the time here excludes the sampling time for the last batch + // but on the other hand, we don't want to do too many system calls to measure the time, so it's ok + const int64_t t_current = ggml_time_us(); + for (auto & slot : slots) { // may need to copy state to other slots if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) { @@ -2390,6 +2478,10 @@ struct server_context_impl { continue; // continue loop of slots } + if (slot.i_batch_dft.size() > 0) { + continue; // sample using speculative decoding + } + const int tok_idx = slot.i_batch - i; llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); @@ -2400,8 +2492,6 @@ struct server_context_impl { slot.n_decoded += 1; - const int64_t t_current = ggml_time_us(); - if (slot.n_decoded == 1) { slot.t_start_generation = t_current; slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; @@ -2430,84 +2520,32 @@ struct server_context_impl { } } - // do speculative decoding - // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] - // perform the speculative drafting for all sequences at the same time in a single batch + // speculative decoding - main model sample and accept for (auto & slot : slots) { - if (!slot.is_processing() || !slot.can_speculate()) { + if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) { continue; } - if (slot.state != SLOT_STATE_GENERATING) { - continue; - } - - if (mctx) { - // we should never reach this, as speculative is automatically disabled if mmproj is loaded - GGML_ABORT("not supported by multimodal"); - } - - // determine the max draft that fits the current slot state - int n_draft_max = slot.task->params.speculative.n_max; - - // note: slot.prompt is not yet expanded with the `id` token sampled above - // also, need to leave space for 1 extra token to allow context shifts - n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2); - - if (slot.n_remaining > 0) { - n_draft_max = std::min(n_draft_max, slot.n_remaining - 1); - } - - SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); - - if (n_draft_max < slot.task->params.speculative.n_min) { - SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min); - - continue; - } - - llama_token id = slot.sampled; - - struct common_speculative_params params_spec; - params_spec.n_draft = n_draft_max; - params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max; - params_spec.p_min = slot.task->params.speculative.p_min; - - const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); - llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); - - // ignore small drafts - if (slot.task->params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); - - continue; - } - - // keep track of total number of drafted tokens tested - slot.n_draft_total += draft.size(); - - // construct the speculation batch - common_batch_clear(slot.batch_spec); - common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true); - - for (size_t i = 0; i < draft.size(); ++i) { - common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true); - } - - SLT_DBG(slot, "decoding speculative batch, size = %d\n", 
slot.batch_spec.n_tokens); - - llama_decode(ctx, slot.batch_spec); + size_t n_draft = slot.drafted.size(); // the accepted tokens from the speculation - const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); + const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, slot.i_batch_dft, slot.drafted); + slot.i_batch_dft.clear(); + slot.drafted.clear(); slot.n_decoded += ids.size(); + slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; + // update how many tokens out of those tested were accepted slot.n_draft_accepted += ids.size() - 1; - slot.prompt.tokens.push_back(id); + // rollback to the state before sampling the draft tokens + slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft); + + // add accepted tokens to the prompt slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); + slot.sampled = ids.back(); // last accepted token llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1); @@ -2530,7 +2568,7 @@ struct server_context_impl { } } - SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens()); + SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) slot.drafted.size(), slot.prompt.n_tokens()); } } @@ -2551,6 +2589,10 @@ struct server_context_impl { int get_slot_n_ctx() { return slots.back().n_ctx; } + + server_response_reader get_response_reader() { + return server_response_reader(queue_tasks, queue_results, HTTP_POLLING_SECONDS); + } }; // @@ -2580,8 +2622,8 @@ llama_context * server_context::get_llama_context() const { return impl->ctx; } -std::pair server_context::get_queues() { - return { impl->queue_tasks, impl->queue_results }; +server_response_reader server_context::get_response_reader() { + return impl->get_response_reader(); } @@ -2590,7 +2632,7 @@ std::pair server_context::get_queues() { struct server_res_generator : server_http_res { server_response_reader rd; server_res_generator(server_context_impl & ctx_server) - : rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS) {} + : rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {} void ok(const json & response_data) { status = 200; data = safe_json_to_str(response_data); @@ -2623,9 +2665,6 @@ static std::unique_ptr handle_completions_impl( try { std::vector tasks; - // tracking generation state and partial tool calls - std::vector states; - const auto & prompt = data.at("prompt"); // TODO: this log can become very long, put it behind a flag or think about a more compact format //SRV_DBG("Prompt: %s\n", prompt.is_string() ? 
prompt.get().c_str() : prompt.dump(2).c_str()); @@ -2641,7 +2680,6 @@ static std::unique_ptr handle_completions_impl( inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); } tasks.reserve(inputs.size()); - states.reserve(inputs.size()); int idx = 0; for (size_t i = 0; i < inputs.size(); i++) { server_task task = server_task(type); @@ -2660,7 +2698,6 @@ static std::unique_ptr handle_completions_impl( task.params.res_type = res_type; task.params.oaicompat_cmpl_id = completion_id; task.params.oaicompat_model = ctx_server.model_name; - states.push_back(task.params.oaicompat_chat_syntax); if (task.params.n_cmpl > 1) { task.n_children = task.params.n_cmpl - 1; @@ -2669,7 +2706,6 @@ static std::unique_ptr handle_completions_impl( task.id, ctx_server.queue_tasks.get_new_id(), idx++); - states.push_back(child.params.oaicompat_chat_syntax); tasks.push_back(std::move(child)); } } @@ -2677,7 +2713,6 @@ static std::unique_ptr handle_completions_impl( tasks.push_back(std::move(task)); } - rd.set_states(std::move(states)); rd.post_tasks(std::move(tasks)); } catch (const std::exception & e) { res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); @@ -3407,7 +3442,7 @@ void server_routes::init_routes() { // create and queue the task json responses = json::array(); - server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS); + server_response_reader rd = ctx_server.get_response_reader(); { std::vector tasks; tasks.reserve(documents.size()); @@ -3667,7 +3702,7 @@ std::unique_ptr server_routes::handle_embeddings_impl(cons // create and queue the task json responses = json::array(); - server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS); + server_response_reader rd = ctx_server.get_response_reader(); { std::vector tasks; for (size_t i = 0; i < tokenized_prompts.size(); i++) { diff --git a/tools/server/server-context.h b/tools/server/server-context.h index 05b4afaee..eaa138087 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -31,9 +31,8 @@ struct server_context { // get the underlaying llama_context llama_context * get_llama_context() const; - // get the underlaying queue_tasks and queue_results - // used by CLI application - std::pair get_queues(); + // get a new response reader, used by CLI application + server_response_reader get_response_reader(); }; diff --git a/tools/server/server-queue.cpp b/tools/server/server-queue.cpp index 10196128d..3cceb2bbe 100644 --- a/tools/server/server-queue.cpp +++ b/tools/server/server-queue.cpp @@ -271,12 +271,21 @@ void server_response::terminate() { // server_response_reader // -void server_response_reader::set_states(std::vector && states) { - this->states = std::move(states); +void server_response_reader::post_task(server_task && task) { + GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader"); + id_tasks.insert(task.id); + states.push_back(task.create_state()); + queue_results.add_waiting_task_id(task.id); + queue_tasks.post(std::move(task)); } void server_response_reader::post_tasks(std::vector && tasks) { + GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader"); id_tasks = server_task::get_list_id(tasks); + states.reserve(tasks.size()); + for (size_t i = 0; i < tasks.size(); i++) { + states.push_back(tasks[i].create_state()); + } queue_results.add_waiting_tasks(tasks); queue_tasks.post(std::move(tasks)); } diff --git 
a/tools/server/server-queue.h b/tools/server/server-queue.h index a5c3179d8..726eadf4e 100644 --- a/tools/server/server-queue.h +++ b/tools/server/server-queue.h @@ -129,13 +129,13 @@ struct server_response_reader { std::vector states; // should_stop function will be called each polling_interval_seconds - server_response_reader(std::pair server_queues, int polling_interval_seconds) - : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {} + server_response_reader(server_queue & queue_tasks, server_response & queue_results, int polling_interval_seconds) + : queue_tasks(queue_tasks), queue_results(queue_results), polling_interval_seconds(polling_interval_seconds) {} ~server_response_reader() { stop(); } - void set_states(std::vector && states); + void post_task(server_task && tasks); void post_tasks(std::vector && tasks); bool has_next() const; diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index c401f47a7..360826062 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -155,11 +155,12 @@ task_params server_task::params_from_json_cmpl( // Sampling parameter defaults are loaded from the global server context (but individual requests can still them) task_params defaults; - defaults.sampling = params_base.sampling; - defaults.speculative = params_base.speculative; - defaults.n_keep = params_base.n_keep; - defaults.n_predict = params_base.n_predict; - defaults.antiprompt = params_base.antiprompt; + defaults.sampling = params_base.sampling; + defaults.speculative = params_base.speculative; + defaults.n_keep = params_base.n_keep; + defaults.n_predict = params_base.n_predict; + defaults.n_cache_reuse = params_base.n_cache_reuse; + defaults.antiprompt = params_base.antiprompt; // enabling this will output extra debug information in the HTTP responses from the server params.verbose = params_base.verbosity > 9; @@ -176,6 +177,7 @@ task_params server_task::params_from_json_cmpl( params.n_keep = json_value(data, "n_keep", defaults.n_keep); params.n_discard = json_value(data, "n_discard", defaults.n_discard); params.n_cmpl = json_value(data, "n_cmpl", json_value(data, "n", 1)); + params.n_cache_reuse = json_value(data, "n_cache_reuse", defaults.n_cache_reuse); //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); params.response_fields = json_value(data, "response_fields", std::vector()); diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 4e4840fc8..9011ff944 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -55,6 +55,8 @@ struct task_params { int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters int32_t n_cmpl = 1; // number of completions to generate from this prompt + int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled) + int64_t t_max_prompt_ms = -1; // TODO: implement int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit @@ -62,18 +64,19 @@ struct task_params { std::vector antiprompt; std::vector response_fields; - bool timings_per_token = false; + + bool timings_per_token = false; bool post_sampling_probs = false; struct common_params_sampling sampling; struct common_params_speculative speculative; // response formatting - bool verbose = 
false; - task_response_type res_type = TASK_RESPONSE_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_syntax oaicompat_chat_syntax; + bool verbose = false; + task_response_type res_type = TASK_RESPONSE_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_syntax oaicompat_chat_syntax; // Embeddings int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) @@ -82,6 +85,25 @@ struct task_params { json to_json(bool only_metrics = false) const; }; +// struct for tracking the state of a task (e.g., for streaming) +struct task_result_state { + // tracking diffs for partial tool calls + std::vector diffs; + common_chat_syntax oaicompat_chat_syntax; + common_chat_msg chat_msg; + std::string generated_text; // append new chunks of generated text here + std::vector generated_tool_call_ids; + + task_result_state(const common_chat_syntax & oaicompat_chat_syntax) + : oaicompat_chat_syntax(oaicompat_chat_syntax) {} + + // parse partial tool calls and update the internal state + common_chat_msg update_chat_msg( + const std::string & text_added, + bool is_partial, + std::vector & diffs); +}; + struct server_task { int id = -1; // to be filled by server_queue int index = -1; // used when there are multiple prompts (batch request) @@ -146,6 +168,12 @@ struct server_task { copy.tokens = tokens.clone(); return copy; } + + // the task will be moved into queue, then onto slots + // however, the state must be kept by caller (e.g., HTTP thread) + task_result_state create_state() const { + return task_result_state(params.oaicompat_chat_syntax); + } }; struct result_timings { @@ -177,25 +205,6 @@ struct result_prompt_progress { json to_json() const; }; -// struct for tracking the state of a task (e.g., for streaming) -struct task_result_state { - // tracking diffs for partial tool calls - std::vector diffs; - common_chat_syntax oaicompat_chat_syntax; - common_chat_msg chat_msg; - std::string generated_text; // append new chunks of generated text here - std::vector generated_tool_call_ids; - - task_result_state(const common_chat_syntax & oaicompat_chat_syntax) - : oaicompat_chat_syntax(oaicompat_chat_syntax) {} - - // parse partial tool calls and update the internal state - common_chat_msg update_chat_msg( - const std::string & text_added, - bool is_partial, - std::vector & diffs); -}; - struct server_task_result { int id = -1; int id_slot = -1;