mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-07 09:02:04 +00:00
Merge commit '2fa51c19b0' into concedo_experimental
# Conflicts: # .github/actions/windows-setup-cuda/action.yml # .github/workflows/build-linux-cross.yml # .github/workflows/release.yml # README.md # docs/build-riscv64-spacemit.md # examples/model-conversion/logits.cpp # ggml/CMakeLists.txt # ggml/src/ggml-cpu/CMakeLists.txt # models/templates/Kimi-K2-Instruct.jinja # models/templates/Kimi-K2-Thinking.jinja # tests/test-chat.cpp # tools/server/README.md
This commit is contained in:
commit
278e45becf
21 changed files with 584 additions and 214 deletions
|
|
@ -724,16 +724,10 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons
|
|||
if (reasoning_unclosed) {
|
||||
if (auto pos = content.find(end_think); pos == std::string::npos && builder.pos() != builder.input().size()) {
|
||||
unclosed_reasoning_content += content;
|
||||
if (form.allow_toolcall_in_think) {
|
||||
builder.move_to(tc->groups[0].begin);
|
||||
if (!builder.try_consume_xml_tool_calls(form)) {
|
||||
unclosed_reasoning_content += tool_call_start;
|
||||
builder.move_to(tc->groups[0].end);
|
||||
}
|
||||
} else {
|
||||
if (!(form.allow_toolcall_in_think && tc)) {
|
||||
unclosed_reasoning_content += tool_call_start;
|
||||
continue;
|
||||
}
|
||||
continue;
|
||||
} else {
|
||||
reasoning_unclosed = false;
|
||||
std::string reasoning_content;
|
||||
|
|
@ -781,8 +775,12 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons
|
|||
}
|
||||
} else {
|
||||
// This <tool_call> start is in thinking block, skip this tool call
|
||||
auto pos = think_start + start_think.size();
|
||||
unclosed_reasoning_content = content.substr(pos) + tool_call_start;
|
||||
// This <tool_call> start is in thinking block
|
||||
if (form.allow_toolcall_in_think) {
|
||||
unclosed_reasoning_content = content.substr(think_start + start_think.size());
|
||||
} else {
|
||||
unclosed_reasoning_content = content.substr(think_start + start_think.size()) + tool_call_start;
|
||||
}
|
||||
reasoning_unclosed = true;
|
||||
content.resize(think_start);
|
||||
toolcall_in_think = true;
|
||||
|
|
@ -805,14 +803,35 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons
|
|||
}
|
||||
|
||||
// remove potential partial suffix
|
||||
if (content.size() > 0 && builder.pos() == builder.input().size() && unclosed_reasoning_content.empty()) {
|
||||
rstrip(content);
|
||||
trim_potential_partial_word(content);
|
||||
rstrip(content);
|
||||
if (builder.pos() == builder.input().size()) {
|
||||
if (unclosed_reasoning_content.empty()) {
|
||||
rstrip(content);
|
||||
trim_potential_partial_word(content);
|
||||
rstrip(content);
|
||||
} else {
|
||||
rstrip(unclosed_reasoning_content);
|
||||
trim_potential_partial_word(unclosed_reasoning_content);
|
||||
rstrip(unclosed_reasoning_content);
|
||||
}
|
||||
}
|
||||
|
||||
// consume unclosed_reasoning_content if allow_toolcall_in_think is set
|
||||
if (form.allow_toolcall_in_think && !unclosed_reasoning_content.empty()) {
|
||||
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
||||
builder.add_reasoning_content(unclosed_reasoning_content);
|
||||
} else {
|
||||
if (content.empty()) {
|
||||
content = start_think + unclosed_reasoning_content;
|
||||
} else {
|
||||
content += "\n\n" + start_think;
|
||||
content += unclosed_reasoning_content;
|
||||
}
|
||||
}
|
||||
unclosed_reasoning_content.clear();
|
||||
}
|
||||
|
||||
// Add content
|
||||
if (content.size() != 0) {
|
||||
if (!content.empty()) {
|
||||
// If there are multiple content blocks
|
||||
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content && builder.result().content.size() != 0) {
|
||||
builder.add_content("\n\n");
|
||||
|
|
@ -820,7 +839,7 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons
|
|||
builder.add_content(content);
|
||||
}
|
||||
|
||||
// This <tool_call> start is in thinking block, skip this tool call
|
||||
// This <tool_call> start is in thinking block and toolcall_in_think not set, skip this tool call
|
||||
if (toolcall_in_think && !form.allow_toolcall_in_think) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -829,7 +848,7 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons
|
|||
if (!tc) {
|
||||
GGML_ASSERT(builder.pos() == builder.input().size());
|
||||
GGML_ASSERT(unclosed_reasoning_content.empty());
|
||||
GGML_ASSERT(!reasoning_unclosed);
|
||||
if (!form.allow_toolcall_in_think) GGML_ASSERT(!reasoning_unclosed);
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -854,7 +873,6 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons
|
|||
|
||||
/**
|
||||
* Parse content uses reasoning and XML-Style tool call
|
||||
* TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
|
||||
*/
|
||||
void common_chat_msg_parser::consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think, const std::string & end_think) {
|
||||
parse_msg_with_xml_tool_calls(*this, form, start_think, end_think);
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ struct xml_tool_call_format {
|
|||
std::optional<std::string> last_val_end = std::nullopt;
|
||||
std::optional<std::string> last_tool_end = std::nullopt;
|
||||
bool trim_raw_argval = false;
|
||||
bool allow_toolcall_in_think = false; // TODO: UNTESTED!!!
|
||||
bool allow_toolcall_in_think = false;
|
||||
};
|
||||
|
||||
// make a GBNF that accept any strings except those containing any of the forbidden strings.
|
||||
|
|
|
|||
|
|
@ -917,12 +917,13 @@ static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
|
|||
form.tool_start = "<|tool_call_begin|>";
|
||||
form.tool_sep = "<|tool_call_argument_begin|>{";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.key_val_sep = "\":";
|
||||
form.val_end = ",";
|
||||
form.tool_end = "}<|tool_call_end|>";
|
||||
form.scope_end = "<|tool_calls_section_end|>";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
form.allow_toolcall_in_think = true;
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||
|
|
|
|||
|
|
@ -494,6 +494,15 @@ static inline void ggml_thread_cpu_relax(void) {
|
|||
static inline void ggml_thread_cpu_relax(void) {
|
||||
_mm_pause();
|
||||
}
|
||||
#elif defined(__riscv)
|
||||
static inline void ggml_thread_cpu_relax(void) {
|
||||
#ifdef __riscv_zihintpause
|
||||
__asm__ __volatile__ ("pause");
|
||||
#else
|
||||
/* Encoding of the pause instruction */
|
||||
__asm__ __volatile__ (".4byte 0x100000F");
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
static inline void ggml_thread_cpu_relax(void) {;}
|
||||
#endif
|
||||
|
|
|
|||
37
ggml/src/ggml-cuda/fill.cu
Normal file
37
ggml/src/ggml-cuda/fill.cu
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
#include "fill.cuh"
|
||||
#include "convert.cuh"
|
||||
|
||||
#define CUDA_FILL_BLOCK_SIZE 256
|
||||
|
||||
template <typename T>
|
||||
static __global__ void fill_kernel(T * __restrict__ dst, const int64_t k, const T value) {
|
||||
const int64_t i = (int64_t)blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
dst[i] = value;
|
||||
}
|
||||
|
||||
void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
void * dst_d = dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
|
||||
float value;
|
||||
memcpy(&value, dst->op_params, sizeof(float));
|
||||
|
||||
const int64_t k = ggml_nelements(dst);
|
||||
const int64_t num_blocks = (k + CUDA_FILL_BLOCK_SIZE - 1) / CUDA_FILL_BLOCK_SIZE;
|
||||
|
||||
switch (dst->type) {
|
||||
case GGML_TYPE_F32:
|
||||
fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, stream>>>((float *)dst_d, k, value);
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, stream>>>((half *)dst_d, k, ggml_cuda_cast<half>(value));
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("unsupported type");
|
||||
}
|
||||
}
|
||||
3
ggml/src/ggml-cuda/fill.cuh
Normal file
3
ggml/src/ggml-cuda/fill.cuh
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
@ -58,6 +58,7 @@ bool g_mul_mat_q = true;
|
|||
#include "ggml-cuda/solve_tri.cuh"
|
||||
#include "ggml-cuda/tri.cuh"
|
||||
#include "ggml-cuda/cumsum.cuh"
|
||||
#include "ggml-cuda/fill.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
|
@ -2743,6 +2744,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_OP_SOLVE_TRI:
|
||||
ggml_cuda_op_solve_tri(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_FILL:
|
||||
ggml_cuda_op_fill(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
@ -4630,6 +4634,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
case GGML_OP_FILL:
|
||||
case GGML_OP_CUMSUM:
|
||||
case GGML_OP_TRI:
|
||||
return true;
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
#include "solve_tri.cuh"
|
||||
|
||||
#define MAX_N_FAST 64
|
||||
#define MAX_K_FAST 32
|
||||
|
||||
// ======================
|
||||
// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
|
||||
|
|
@ -48,65 +47,58 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
|
|||
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
|
||||
|
||||
__shared__ float sA[MAX_N_FAST * MAX_N_FAST];
|
||||
__shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
|
||||
|
||||
const int offset = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < n * n; i += k * WARP_SIZE) {
|
||||
int i0 = i + offset;
|
||||
const int i0 = i + offset;
|
||||
if (i0 < n * n) {
|
||||
sA[i0] = A_batch[i0];
|
||||
}
|
||||
}
|
||||
|
||||
const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < rows_per_warp; i++) {
|
||||
const int i0 = lane + i * WARP_SIZE;
|
||||
if (i0 < n) {
|
||||
sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
|
||||
float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
|
||||
|
||||
const int half = WARP_SIZE;
|
||||
const int nrows_low = (n < half) ? n : half;
|
||||
|
||||
#pragma unroll
|
||||
for (int row = 0; row < n; ++row) {
|
||||
for (int row = 0; row < nrows_low; ++row) {
|
||||
float sum = 0.0f;
|
||||
|
||||
{
|
||||
int j = lane;
|
||||
if (j < row) {
|
||||
sum += sA[row * n + j] * sXt[col_idx * n + j];
|
||||
}
|
||||
if (lane < row) {
|
||||
sum += sA[row * n + lane] * x_low;
|
||||
}
|
||||
if (row >= WARP_SIZE) {
|
||||
int j = WARP_SIZE + lane;
|
||||
if (j < row) {
|
||||
sum += sA[row * n + j] * sXt[col_idx * n + j];
|
||||
}
|
||||
}
|
||||
|
||||
sum = warp_reduce_sum(sum);
|
||||
|
||||
if (lane == 0) {
|
||||
const float b_val = sXt[col_idx * n + row];
|
||||
const float a_diag = sA[row * n + row];
|
||||
// no safeguards for division by zero because that indicates corrupt
|
||||
// data anyway
|
||||
sXt[col_idx * n + row] = (b_val - sum) / a_diag;
|
||||
if (lane == row) {
|
||||
x_low = (x_low - sum) / sA[row * n + row];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int row = half; row < n; ++row) {
|
||||
float sum = sA[row * n + lane] * x_low;
|
||||
const int j = half + lane;
|
||||
if (j < row) {
|
||||
sum += sA[row * n + j] * x_high;
|
||||
}
|
||||
sum = warp_reduce_sum(sum);
|
||||
|
||||
if (lane == row - half) {
|
||||
x_high = (x_high - sum) / sA[row * n + row];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < rows_per_warp; i++) {
|
||||
const int i0 = lane + i * WARP_SIZE;
|
||||
if (i0 < n) {
|
||||
X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
|
||||
for (int rr = 0; rr < 2; ++rr) {
|
||||
const int row = rr * WARP_SIZE + lane;
|
||||
if (row < n) {
|
||||
const float val = (row < half) ? x_low : x_high;
|
||||
X_batch[row * k + col_idx] = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,35 +7,85 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
|||
|
||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||
|
||||
void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
|
||||
void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i,
|
||||
const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
|
||||
// Compute starting index in matrix B for this superblock
|
||||
const uint y_idx = i * QUANT_K + 32 * ib32;
|
||||
|
||||
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
|
||||
|
||||
// Precompute indices for quantization lookup tables
|
||||
const uint qh_base = 2 * ib32;
|
||||
const uint qs_base = 4 * ib32;
|
||||
const uint sc_index = ib32 / 2;
|
||||
const uint sc_shift = 6 * (ib32 & 1);
|
||||
|
||||
// Loop over rows in the superblock
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
// Load per-block scales and shift for quantization
|
||||
const uint16_t[4] scales = data_a[ibi].scales;
|
||||
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
|
||||
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
|
||||
const uint sc = data_a[ibi].scales[sc_index] >> sc_shift;
|
||||
|
||||
const uint sc = data_a[ibi].scales[ib32 / 2] >> (6 * (ib32 & 1));
|
||||
// Temporary caches for decoding
|
||||
FLOAT_TYPE dl_cache[4];
|
||||
uint16_t gvf_cache[4];
|
||||
float delta_cache[4];
|
||||
|
||||
// Precompute the multiplier and lookup values for 4 sub-blocks
|
||||
[[unroll]] for (uint l = 0; l < 4; ++l) {
|
||||
const uint qh = data_a[ibi].qh[2 * ib32 + l / 2] >> (4 * (l&1));
|
||||
const uint qs = data_a[ibi].qs[4 * ib32 + l];
|
||||
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
||||
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1);
|
||||
dl_cache[l] = FLOAT_TYPE(d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1));
|
||||
const uint qh = data_a[ibi].qh[qh_base + l / 2] >> (4 * (l & 1));
|
||||
const uint qs = data_a[ibi].qs[qs_base + l];
|
||||
gvf_cache[l] = iq1s_grid[qs | ((qh & 7) << 8)];
|
||||
delta_cache[l] = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
||||
}
|
||||
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
||||
// Loop over columns of the output
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
// Compute base index for matrix B
|
||||
const uint base_b_idx = (j * p.batch_stride_b + b_offset + y_idx) / 4;
|
||||
vec4 b_vals[8];
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
|
||||
vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
|
||||
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
|
||||
[[unroll]] for (int k = 0; k < 4; ++k) {
|
||||
sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta,
|
||||
fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum));
|
||||
}
|
||||
temp[j][n] = fma(dl, sum, temp[j][n]);
|
||||
// Load 8 vec4 values from matrix B
|
||||
[[unroll]] for (int idx = 0; idx < 8; ++idx) {
|
||||
b_vals[idx] = vec4(data_b_v4[base_b_idx + idx]);
|
||||
}
|
||||
|
||||
FLOAT_TYPE col_sum = FLOAT_TYPE(0.0);
|
||||
|
||||
// Loop over sub-blocks
|
||||
[[unroll]] for (uint l = 0; l < 4; ++l) {
|
||||
const uint16_t grid = gvf_cache[l];
|
||||
const float dl = dl_cache[l];
|
||||
|
||||
// Decode 8 2-bit fbits from gvf_cache
|
||||
float f0 = float(bitfieldExtract(grid, 0, 2));
|
||||
float f1 = float(bitfieldExtract(grid, 2, 2));
|
||||
float f2 = float(bitfieldExtract(grid, 4, 2));
|
||||
float f3 = float(bitfieldExtract(grid, 6, 2));
|
||||
float f4 = float(bitfieldExtract(grid, 8, 2));
|
||||
float f5 = float(bitfieldExtract(grid, 10, 2));
|
||||
float f6 = float(bitfieldExtract(grid, 12, 2));
|
||||
float f7 = float(bitfieldExtract(grid, 14, 2));
|
||||
|
||||
// Pack into vec4 for vectorized FMA
|
||||
const vec4 fbits_v0 = vec4(f0, f1, f2, f3);
|
||||
const vec4 fbits_v1 = vec4(f4, f5, f6, f7);
|
||||
const vec4 delta_v = vec4(delta_cache[l]);
|
||||
|
||||
// Vectorized fused multiply-add
|
||||
vec4 sum_v = fma(b_vals[2*l + 0], fbits_v0 + delta_v, vec4(0.0));
|
||||
sum_v = fma(b_vals[2*l + 1], fbits_v1 + delta_v, sum_v);
|
||||
|
||||
// Horizontal add to get scalar sum
|
||||
FLOAT_TYPE sum = sum_v.x + sum_v.y + sum_v.z + sum_v.w;
|
||||
|
||||
// Accumulate to column sum
|
||||
col_sum = fma(dl, sum, col_sum);
|
||||
}
|
||||
// Write result to temporary buffer
|
||||
temp[j][n] += col_sum;
|
||||
}
|
||||
ibi += num_blocks_per_row;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -251,7 +251,10 @@ llama_context::llama_context(
|
|||
|
||||
LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
|
||||
|
||||
const size_t max_nodes = this->graph_max_nodes();
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
const size_t max_nodes = this->graph_max_nodes(n_tokens);
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
|
||||
|
||||
|
|
@ -309,9 +312,6 @@ llama_context::llama_context(
|
|||
|
||||
cross.v_embd.clear();
|
||||
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
// avoid reserving graphs with zero outputs - assume one output per sequence
|
||||
n_outputs = n_seqs;
|
||||
|
||||
|
|
@ -1396,9 +1396,9 @@ void llama_context::output_reorder() {
|
|||
// graph
|
||||
//
|
||||
|
||||
uint32_t llama_context::graph_max_nodes() const {
|
||||
uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
|
||||
if (model.arch == LLM_ARCH_QWEN3NEXT) {
|
||||
return std::max<uint32_t>(8192u, 32u*model.n_tensors());
|
||||
return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
|
||||
}
|
||||
return std::max<uint32_t>(1024u, 8u*model.n_tensors());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -197,7 +197,7 @@ private:
|
|||
//
|
||||
|
||||
public:
|
||||
uint32_t graph_max_nodes() const;
|
||||
uint32_t graph_max_nodes(uint32_t n_tokens) const;
|
||||
|
||||
// can reuse the llm_graph_result instance of the context (for example to update a memory module)
|
||||
llm_graph_result * get_gf_res_reserve() const;
|
||||
|
|
|
|||
|
|
@ -1733,6 +1733,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
}
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
|
||||
|
||||
// (optional) temperature tuning - used by mistral-large
|
||||
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 27: type = LLM_TYPE_16B; break;
|
||||
case 60: type = LLM_TYPE_236B; break;
|
||||
|
|
|
|||
|
|
@ -30,6 +30,12 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
|
|||
// {n_embd, n_tokens}
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// (optional) temperature tuning - used by mistral-large
|
||||
ggml_tensor * inp_attn_scale = nullptr;
|
||||
if (hparams.f_attn_temp_scale != 0.0f) {
|
||||
inp_attn_scale = build_inp_attn_scale();
|
||||
}
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
|
|
@ -128,6 +134,12 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
|
|||
ggml_tensor * Vcur = kv_cmpr;
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
if (inp_attn_scale) {
|
||||
// apply llama 4 temperature scaling
|
||||
Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
|
||||
cb(Qcur, "Qcur_attn_temp_scaled", il);
|
||||
}
|
||||
|
||||
// note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group)
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
|
|
@ -160,6 +172,12 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
|
|||
ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
if (inp_attn_scale) {
|
||||
// apply llama 4 temperature scaling
|
||||
Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
|
||||
cb(Qcur, "Qcur_attn_temp_scaled", il);
|
||||
}
|
||||
|
||||
// note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups)
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
|
|
|
|||
177
tools/server/README-dev.md
Normal file
177
tools/server/README-dev.md
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
# llama-server Development Documentation
|
||||
|
||||
This document provides an in-depth technical overview of `llama-server`, intended for maintainers and contributors.
|
||||
|
||||
If you are an end user consuming `llama-server` as a product, please refer to the main [README](./README.md) instead.
|
||||
|
||||
## Backend
|
||||
|
||||
### Overview
|
||||
|
||||
The server supports two primary operating modes:
|
||||
|
||||
- **Inference mode**: The default mode for performing inference with a single loaded GGUF model.
|
||||
- **Router mode**: Enables management of multiple inference server instances behind a single API endpoint. Requests are automatically routed to the appropriate backend instance based on the requested model.
|
||||
|
||||
The core architecture consists of the following components:
|
||||
|
||||
- `server_context`: Holds the primary inference state, including the main `llama_context` and all active slots.
|
||||
- `server_slot`: An abstraction over a single “sequence” in llama.cpp, responsible for managing individual parallel inference requests.
|
||||
- `server_routes`: Middleware layer between `server_context` and the HTTP interface; handles JSON parsing/formatting and request routing logic.
|
||||
- `server_http_context`: Implements the HTTP server using `cpp-httplib`.
|
||||
- `server_queue`: Thread-safe queue used by HTTP workers to submit new tasks to `server_context`.
|
||||
- `server_response`: Thread-safe queue used by `server_context` to return results to HTTP workers.
|
||||
- `server_response_reader`: Higher-level wrapper around the two queues above for cleaner code.
|
||||
- `server_task`: Unit of work pushed into `server_queue`.
|
||||
- `server_task_result`: Unit of result pushed into `server_response`.
|
||||
- `server_tokens`: Unified representation of token sequences (supports both text and multimodal tokens); used by `server_task` and `server_slot`.
|
||||
- `server_prompt_checkpoint`: For recurrent (e.g., RWKV) and SWA models, stores snapshots of KV cache state. Enables reuse when subsequent requests share the same prompt prefix, saving redundant computation.
|
||||
- `server_models`: Standalone component for managing multiple backend instances (used in router mode). It is completely independent of `server_context`.
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
API_User <--> server_http_context
|
||||
server_http_context <-- router mode --> server_models
|
||||
server_http_context <-- inference mode --> server_routes
|
||||
server_routes -- server_task --> server_queue
|
||||
subgraph server_context
|
||||
server_queue --> server_slot
|
||||
server_slot -- server_task_result --> server_response
|
||||
server_slot[multiple server_slot]
|
||||
end
|
||||
server_response --> server_routes
|
||||
```
|
||||
|
||||
### Batching
|
||||
|
||||
The server context maintains a single batch shared across all slots. When `update_slots()` is invoked, the system iterates through all active slots to populate this batch. For each slot, either a generated token from the previous decoding step or available prompt tokens are added to the batch.
|
||||
|
||||
Batching constraints apply: slots can only be batched together if they share compatible configurations. For instance, slots using a specific LoRA adapter can be batched with each other, but not with slots using a different LoRA adapter or no adapter at all.
|
||||
|
||||
Once the batch reaches capacity or all slots have been processed, `llama_decode` is called to execute the inference. This operation represents the primary computational bottleneck in `update_slots()`.
|
||||
|
||||
Following decoding, the system either retrieves embeddings or samples the next token using `common_sampler_sample`. If a slot has remaining prompt tokens to process, it yields until the next `update_slots()` iteration.
|
||||
|
||||
### Thread Management
|
||||
|
||||
`server_context` runs on a dedicated single thread. Because it is single-threaded, heavy post-processing (especially after token generation) should be avoided, as it directly impacts multi-sequence throughput.
|
||||
|
||||
Each incoming HTTP request is handled by its own thread managed by the HTTP library. The following operations are performed in HTTP worker threads:
|
||||
|
||||
- JSON request parsing
|
||||
- Chat template application
|
||||
- Tokenization
|
||||
- Conversion of `server_task_result` into final JSON response
|
||||
- Error formatting into JSON
|
||||
- Tracking of partial/incremental responses (e.g., streaming tool calls or reasoning steps)
|
||||
|
||||
**Best practices to follow:**
|
||||
|
||||
- All JSON formatting and chat template logic must stay in the HTTP layer.
|
||||
- Avoid passing raw JSON between the HTTP layer and `server_slot`. Instead, parse everything into native C++ types as early as possible.
|
||||
|
||||
### Example trace of a request
|
||||
|
||||
Here is an example trace of an API request for text completion:
|
||||
|
||||
- A request arrives at the HTTP layer.
|
||||
- The request is routed to the corresponding handler inside `server_routes`. In this case, `handle_completions_impl` is invoked.
|
||||
- The handler parses the input request, constructs a new `server_task`, and passes it to `server_res_generator`.
|
||||
- `server_res_generator` creates a new `task_result_state` for each task:
|
||||
- `task_result_state` stays in the HTTP layer, responsible for keeping track of the current state of the response (e.g., parsing tool calls or thinking messages).
|
||||
- `server_task` is moved into `server_queue` inside `server_context`.
|
||||
- `server_context` launches the task by moving it into an available slot (see `launch_slot_with_task()`).
|
||||
- `update_slot()` processes the task as described in the "Batching" section above.
|
||||
- Results may be sent using `send_partial_response` or `send_final_response`, which creates a new `server_task_result` and pushes it to the response queue.
|
||||
- At the same time, `server_res_generator` listens to the response queue and retrieves this response.
|
||||
- As the response is stateless, `server_res_generator` calls `response->update()` to update the response with the current state.
|
||||
- `server_res_generator` then calls `response->to_json()` and passes the response to the HTTP layer.
|
||||
|
||||
### Testing
|
||||
|
||||
`llama-server` includes an automated test suite based on `pytest`.
|
||||
|
||||
The framework automatically starts a `llama-server` instance, sends requests, and validates responses.
|
||||
|
||||
For detailed instructions, see the [test documentation](./tests/README.md).
|
||||
|
||||
### Notable Related PRs
|
||||
|
||||
- Initial server implementation: https://github.com/ggml-org/llama.cpp/pull/1443
|
||||
- Parallel decoding support: https://github.com/ggml-org/llama.cpp/pull/3228
|
||||
- Refactor introducing `server_queue` and `server_response`: https://github.com/ggml-org/llama.cpp/pull/5065
|
||||
- Reranking endpoint: https://github.com/ggml-org/llama.cpp/pull/9510
|
||||
- Multimodal model support (`libmtmd`): https://github.com/ggml-org/llama.cpp/pull/12898
|
||||
- Unified KV cache handling: https://github.com/ggml-org/llama.cpp/pull/16736
|
||||
- Separation of HTTP logic into dedicated files: https://github.com/ggml-org/llama.cpp/pull/17216
|
||||
- Large-scale code base split into smaller files: https://github.com/ggml-org/llama.cpp/pull/17362
|
||||
- Introduction of router mode: https://github.com/ggml-org/llama.cpp/pull/17470
|
||||
- Speculative decoding: https://github.com/ggml-org/llama.cpp/pull/17808 and rework in https://github.com/ggml-org/llama.cpp/pull/17808
|
||||
|
||||
|
||||
|
||||
|
||||
## Web UI
|
||||
|
||||
The project includes a web-based user interface for interacting with `llama-server`. It supports both single-model (`MODEL` mode) and multi-model (`ROUTER` mode) operation.
|
||||
|
||||
The SvelteKit-based Web UI is introduced in this PR: https://github.com/ggml-org/llama.cpp/pull/14839
|
||||
|
||||
### Features
|
||||
|
||||
- **Chat interface** with streaming responses
|
||||
- **Multi-model support** (ROUTER mode) - switch between models, auto-load on selection
|
||||
- **Modality validation** - ensures selected model supports conversation's attachments (images, audio)
|
||||
- **Conversation management** - branching, regeneration, editing with history preservation
|
||||
- **Attachment support** - images, audio, PDFs (with vision/text fallback)
|
||||
- **Configurable parameters** - temperature, top_p, etc. synced with server defaults
|
||||
- **Dark/light theme**
|
||||
|
||||
### Tech Stack
|
||||
|
||||
- **SvelteKit** - frontend framework with Svelte 5 runes for reactive state
|
||||
- **TailwindCSS** + **shadcn-svelte** - styling and UI components
|
||||
- **Vite** - build tooling
|
||||
- **IndexedDB** (Dexie) - local storage for conversations
|
||||
- **LocalStorage** - user settings persistence
|
||||
|
||||
### Architecture
|
||||
|
||||
The WebUI follows a layered architecture:
|
||||
|
||||
```
|
||||
Routes → Components → Hooks → Stores → Services → Storage/API
|
||||
```
|
||||
|
||||
- **Stores** - reactive state management (`chatStore`, `conversationsStore`, `modelsStore`, `serverStore`, `settingsStore`)
|
||||
- **Services** - stateless API/database communication (`ChatService`, `ModelsService`, `PropsService`, `DatabaseService`)
|
||||
- **Hooks** - reusable logic (`useModelChangeValidation`, `useProcessingState`)
|
||||
|
||||
For detailed architecture diagrams, see [`tools/server/webui/docs/`](webui/docs/):
|
||||
|
||||
- `high-level-architecture.mmd` - full architecture with all modules
|
||||
- `high-level-architecture-simplified.mmd` - simplified overview
|
||||
- `data-flow-simplified-model-mode.mmd` - data flow for single-model mode
|
||||
- `data-flow-simplified-router-mode.mmd` - data flow for multi-model mode
|
||||
- `flows/*.mmd` - detailed per-domain flows (chat, conversations, models, etc.)
|
||||
|
||||
### Development
|
||||
|
||||
```sh
|
||||
# make sure you have Node.js installed
|
||||
cd tools/server/webui
|
||||
npm i
|
||||
|
||||
# run dev server (with hot reload)
|
||||
npm run dev
|
||||
|
||||
# run tests
|
||||
npm run test
|
||||
|
||||
# build production bundle
|
||||
npm run build
|
||||
```
|
||||
|
||||
After `public/index.html.gz` has been generated, rebuild `llama-server` as described in the [build](#build) section to include the updated UI.
|
||||
|
||||
**Note:** The Vite dev server automatically proxies API requests to `http://localhost:8080`. Make sure `llama-server` is running on that port during development.
|
||||
|
|
@ -18,11 +18,13 @@ const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "
|
|||
using json = nlohmann::ordered_json;
|
||||
|
||||
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
#define SLT_CNT(slot, fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
|
||||
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
|
||||
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
|
||||
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
|
|
|||
|
|
@ -102,6 +102,11 @@ struct server_slot {
|
|||
std::string generated_text;
|
||||
llama_tokens generated_tokens;
|
||||
|
||||
// idx of draft tokens in the main batch
|
||||
// non-empty if we went to evaluate draft tokens
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/17808
|
||||
std::vector<int32_t> i_batch_dft;
|
||||
|
||||
std::vector<completion_token_output> generated_token_probs;
|
||||
|
||||
bool has_next_token = true;
|
||||
|
|
@ -150,7 +155,8 @@ struct server_slot {
|
|||
|
||||
struct common_sampler * smpl = nullptr;
|
||||
|
||||
llama_token sampled;
|
||||
llama_token sampled; // in speculative mode, this is the last accepted token
|
||||
llama_tokens drafted;
|
||||
|
||||
// stats
|
||||
size_t n_sent_text = 0; // number of sent text character
|
||||
|
|
@ -180,6 +186,8 @@ struct server_slot {
|
|||
stopping_word = "";
|
||||
n_sent_text = 0;
|
||||
|
||||
drafted.clear();
|
||||
i_batch_dft.clear();
|
||||
generated_tokens.clear();
|
||||
generated_token_probs.clear();
|
||||
json_schema = json();
|
||||
|
|
@ -255,6 +263,31 @@ struct server_slot {
|
|||
generated_token_probs.push_back(token);
|
||||
}
|
||||
|
||||
int get_n_draft_max() const {
|
||||
if (!can_speculate()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// determine the max draft that fits the current slot state
|
||||
int n_draft_max = task->params.speculative.n_max;
|
||||
|
||||
// note: slot.prompt is not yet expanded with the `id` token sampled above
|
||||
// also, need to leave space for 1 extra token to allow context shifts
|
||||
n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2);
|
||||
|
||||
if (n_remaining > 0) {
|
||||
n_draft_max = std::min(n_draft_max, n_remaining - 1);
|
||||
}
|
||||
|
||||
SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);
|
||||
|
||||
if (n_draft_max < task->params.speculative.n_min) {
|
||||
SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
|
||||
n_draft_max = 0;
|
||||
}
|
||||
return n_draft_max;
|
||||
}
|
||||
|
||||
// note: a slot can also be either a parent or a child
|
||||
bool is_parent() const {
|
||||
return is_processing() && task->n_children > 0;
|
||||
|
|
@ -353,8 +386,7 @@ struct server_slot {
|
|||
|
||||
if (n_draft_total > 0) {
|
||||
const float draft_ratio = (float) n_draft_accepted / n_draft_total;
|
||||
SLT_INF(*this,
|
||||
"\n"
|
||||
SLT_CNT(*this,
|
||||
"draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
|
||||
draft_ratio, n_draft_accepted, n_draft_total
|
||||
);
|
||||
|
|
@ -1774,14 +1806,57 @@ struct server_context_impl {
|
|||
continue;
|
||||
}
|
||||
|
||||
slot.i_batch = batch.n_tokens;
|
||||
// generate draft tokens in speculative decoding mode
|
||||
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
|
||||
// perform the speculative drafting for all sequences at the same time in a single batch
|
||||
int n_draft_max = slot.get_n_draft_max();
|
||||
if (n_draft_max > 0) {
|
||||
if (mctx) {
|
||||
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
|
||||
GGML_ABORT("not supported by multimodal");
|
||||
}
|
||||
|
||||
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
|
||||
struct common_speculative_params params_spec;
|
||||
params_spec.n_draft = n_draft_max;
|
||||
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
|
||||
params_spec.p_min = slot.task->params.speculative.p_min;
|
||||
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
|
||||
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
|
||||
|
||||
slot.prompt.tokens.push_back(slot.sampled);
|
||||
// add the sampled token to the batch
|
||||
slot.i_batch_dft.push_back(batch.n_tokens);
|
||||
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
|
||||
slot.prompt.tokens.push_back(slot.sampled);
|
||||
|
||||
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
|
||||
slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
|
||||
if (slot.task->params.speculative.n_min > (int) draft.size()) {
|
||||
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
|
||||
// fallback to normal decoding
|
||||
slot.i_batch = slot.i_batch_dft[0];
|
||||
slot.drafted.clear();
|
||||
slot.i_batch_dft.clear();
|
||||
} else {
|
||||
// keep track of total number of drafted tokens tested
|
||||
slot.n_draft_total += draft.size();
|
||||
|
||||
// add all drafted tokens to the batch
|
||||
for (size_t i = 0; i < draft.size(); i++) {
|
||||
slot.i_batch_dft.push_back(batch.n_tokens);
|
||||
common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true);
|
||||
slot.prompt.tokens.push_back(draft[i]);
|
||||
}
|
||||
slot.drafted = std::move(draft);
|
||||
}
|
||||
} else {
|
||||
// no speculative decoding
|
||||
slot.i_batch = batch.n_tokens;
|
||||
|
||||
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
|
||||
|
||||
slot.prompt.tokens.push_back(slot.sampled);
|
||||
|
||||
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
|
||||
slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
|
||||
}
|
||||
}
|
||||
|
||||
// process in chunks of params.n_batch
|
||||
|
|
@ -1880,8 +1955,18 @@ struct server_context_impl {
|
|||
n_past = std::min(n_past, slot.alora_invocation_start - 1);
|
||||
}
|
||||
|
||||
const auto n_cache_reuse = slot.task->params.n_cache_reuse;
|
||||
|
||||
const bool can_cache_reuse =
|
||||
llama_memory_can_shift(llama_get_memory(ctx)) &&
|
||||
!slot.prompt.tokens.has_mtmd;
|
||||
|
||||
if (!can_cache_reuse && n_cache_reuse > 0) {
|
||||
SLT_WRN(slot, "cache reuse is not supported - ignoring n_cache_reuse = %d\n", n_cache_reuse);
|
||||
}
|
||||
|
||||
// reuse chunks from the cached prompt by shifting their KV cache in the new position
|
||||
if (params_base.n_cache_reuse > 0) {
|
||||
if (can_cache_reuse && n_cache_reuse > 0) {
|
||||
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
|
||||
|
||||
size_t head_c = n_past; // cache
|
||||
|
|
@ -1892,7 +1977,7 @@ struct server_context_impl {
|
|||
GGML_ABORT("not supported by multimodal");
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past);
|
||||
SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", n_cache_reuse, n_past);
|
||||
|
||||
while (head_c < slot.prompt.tokens.size() &&
|
||||
head_p < input_tokens.size()) {
|
||||
|
|
@ -1901,11 +1986,10 @@ struct server_context_impl {
|
|||
while (head_c + n_match < slot.prompt.tokens.size() &&
|
||||
head_p + n_match < input_tokens.size() &&
|
||||
slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) {
|
||||
|
||||
n_match++;
|
||||
}
|
||||
|
||||
if (n_match >= (size_t) params_base.n_cache_reuse) {
|
||||
if (n_match >= (size_t) n_cache_reuse) {
|
||||
SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
|
||||
//for (size_t i = head_p; i < head_p + n_match; i++) {
|
||||
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
|
||||
|
|
@ -2336,6 +2420,10 @@ struct server_context_impl {
|
|||
// on successful decode, restore the original batch size
|
||||
n_batch = llama_n_batch(ctx);
|
||||
|
||||
// technically, measuring the time here excludes the sampling time for the last batch
|
||||
// but on the other hand, we don't want to do too many system calls to measure the time, so it's ok
|
||||
const int64_t t_current = ggml_time_us();
|
||||
|
||||
for (auto & slot : slots) {
|
||||
// may need to copy state to other slots
|
||||
if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
|
||||
|
|
@ -2390,6 +2478,10 @@ struct server_context_impl {
|
|||
continue; // continue loop of slots
|
||||
}
|
||||
|
||||
if (slot.i_batch_dft.size() > 0) {
|
||||
continue; // sample using speculative decoding
|
||||
}
|
||||
|
||||
const int tok_idx = slot.i_batch - i;
|
||||
|
||||
llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
|
||||
|
|
@ -2400,8 +2492,6 @@ struct server_context_impl {
|
|||
|
||||
slot.n_decoded += 1;
|
||||
|
||||
const int64_t t_current = ggml_time_us();
|
||||
|
||||
if (slot.n_decoded == 1) {
|
||||
slot.t_start_generation = t_current;
|
||||
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
|
||||
|
|
@ -2430,84 +2520,32 @@ struct server_context_impl {
|
|||
}
|
||||
}
|
||||
|
||||
// do speculative decoding
|
||||
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
|
||||
// perform the speculative drafting for all sequences at the same time in a single batch
|
||||
// speculative decoding - main model sample and accept
|
||||
for (auto & slot : slots) {
|
||||
if (!slot.is_processing() || !slot.can_speculate()) {
|
||||
if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (slot.state != SLOT_STATE_GENERATING) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (mctx) {
|
||||
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
|
||||
GGML_ABORT("not supported by multimodal");
|
||||
}
|
||||
|
||||
// determine the max draft that fits the current slot state
|
||||
int n_draft_max = slot.task->params.speculative.n_max;
|
||||
|
||||
// note: slot.prompt is not yet expanded with the `id` token sampled above
|
||||
// also, need to leave space for 1 extra token to allow context shifts
|
||||
n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2);
|
||||
|
||||
if (slot.n_remaining > 0) {
|
||||
n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
|
||||
|
||||
if (n_draft_max < slot.task->params.speculative.n_min) {
|
||||
SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min);
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
llama_token id = slot.sampled;
|
||||
|
||||
struct common_speculative_params params_spec;
|
||||
params_spec.n_draft = n_draft_max;
|
||||
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
|
||||
params_spec.p_min = slot.task->params.speculative.p_min;
|
||||
|
||||
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
|
||||
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
|
||||
|
||||
// ignore small drafts
|
||||
if (slot.task->params.speculative.n_min > (int) draft.size()) {
|
||||
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// keep track of total number of drafted tokens tested
|
||||
slot.n_draft_total += draft.size();
|
||||
|
||||
// construct the speculation batch
|
||||
common_batch_clear(slot.batch_spec);
|
||||
common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true);
|
||||
|
||||
for (size_t i = 0; i < draft.size(); ++i) {
|
||||
common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true);
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
|
||||
|
||||
llama_decode(ctx, slot.batch_spec);
|
||||
size_t n_draft = slot.drafted.size();
|
||||
|
||||
// the accepted tokens from the speculation
|
||||
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
|
||||
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, slot.i_batch_dft, slot.drafted);
|
||||
slot.i_batch_dft.clear();
|
||||
slot.drafted.clear();
|
||||
|
||||
slot.n_decoded += ids.size();
|
||||
|
||||
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
||||
|
||||
// update how many tokens out of those tested were accepted
|
||||
slot.n_draft_accepted += ids.size() - 1;
|
||||
|
||||
slot.prompt.tokens.push_back(id);
|
||||
// rollback to the state before sampling the draft tokens
|
||||
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
|
||||
|
||||
// add accepted tokens to the prompt
|
||||
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
|
||||
slot.sampled = ids.back(); // last accepted token
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
|
||||
|
||||
|
|
@ -2530,7 +2568,7 @@ struct server_context_impl {
|
|||
}
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens());
|
||||
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) slot.drafted.size(), slot.prompt.n_tokens());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2551,6 +2589,10 @@ struct server_context_impl {
|
|||
int get_slot_n_ctx() {
|
||||
return slots.back().n_ctx;
|
||||
}
|
||||
|
||||
server_response_reader get_response_reader() {
|
||||
return server_response_reader(queue_tasks, queue_results, HTTP_POLLING_SECONDS);
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
|
|
@ -2580,8 +2622,8 @@ llama_context * server_context::get_llama_context() const {
|
|||
return impl->ctx;
|
||||
}
|
||||
|
||||
std::pair<server_queue &, server_response &> server_context::get_queues() {
|
||||
return { impl->queue_tasks, impl->queue_results };
|
||||
server_response_reader server_context::get_response_reader() {
|
||||
return impl->get_response_reader();
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -2590,7 +2632,7 @@ std::pair<server_queue &, server_response &> server_context::get_queues() {
|
|||
struct server_res_generator : server_http_res {
|
||||
server_response_reader rd;
|
||||
server_res_generator(server_context_impl & ctx_server)
|
||||
: rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS) {}
|
||||
: rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {}
|
||||
void ok(const json & response_data) {
|
||||
status = 200;
|
||||
data = safe_json_to_str(response_data);
|
||||
|
|
@ -2623,9 +2665,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
|||
try {
|
||||
std::vector<server_task> tasks;
|
||||
|
||||
// tracking generation state and partial tool calls
|
||||
std::vector<task_result_state> states;
|
||||
|
||||
const auto & prompt = data.at("prompt");
|
||||
// TODO: this log can become very long, put it behind a flag or think about a more compact format
|
||||
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
|
||||
|
|
@ -2641,7 +2680,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
|||
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
|
||||
}
|
||||
tasks.reserve(inputs.size());
|
||||
states.reserve(inputs.size());
|
||||
int idx = 0;
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
server_task task = server_task(type);
|
||||
|
|
@ -2660,7 +2698,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
|||
task.params.res_type = res_type;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
task.params.oaicompat_model = ctx_server.model_name;
|
||||
states.push_back(task.params.oaicompat_chat_syntax);
|
||||
|
||||
if (task.params.n_cmpl > 1) {
|
||||
task.n_children = task.params.n_cmpl - 1;
|
||||
|
|
@ -2669,7 +2706,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
|||
task.id,
|
||||
ctx_server.queue_tasks.get_new_id(),
|
||||
idx++);
|
||||
states.push_back(child.params.oaicompat_chat_syntax);
|
||||
tasks.push_back(std::move(child));
|
||||
}
|
||||
}
|
||||
|
|
@ -2677,7 +2713,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
|||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
rd.set_states(std::move(states));
|
||||
rd.post_tasks(std::move(tasks));
|
||||
} catch (const std::exception & e) {
|
||||
res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
|
||||
|
|
@ -3407,7 +3442,7 @@ void server_routes::init_routes() {
|
|||
|
||||
// create and queue the task
|
||||
json responses = json::array();
|
||||
server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
|
||||
server_response_reader rd = ctx_server.get_response_reader();
|
||||
{
|
||||
std::vector<server_task> tasks;
|
||||
tasks.reserve(documents.size());
|
||||
|
|
@ -3667,7 +3702,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_embeddings_impl(cons
|
|||
|
||||
// create and queue the task
|
||||
json responses = json::array();
|
||||
server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
|
||||
server_response_reader rd = ctx_server.get_response_reader();
|
||||
{
|
||||
std::vector<server_task> tasks;
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
|
|
|
|||
|
|
@ -31,9 +31,8 @@ struct server_context {
|
|||
// get the underlaying llama_context
|
||||
llama_context * get_llama_context() const;
|
||||
|
||||
// get the underlaying queue_tasks and queue_results
|
||||
// used by CLI application
|
||||
std::pair<server_queue &, server_response &> get_queues();
|
||||
// get a new response reader, used by CLI application
|
||||
server_response_reader get_response_reader();
|
||||
};
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -271,12 +271,21 @@ void server_response::terminate() {
|
|||
// server_response_reader
|
||||
//
|
||||
|
||||
void server_response_reader::set_states(std::vector<task_result_state> && states) {
|
||||
this->states = std::move(states);
|
||||
void server_response_reader::post_task(server_task && task) {
|
||||
GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
|
||||
id_tasks.insert(task.id);
|
||||
states.push_back(task.create_state());
|
||||
queue_results.add_waiting_task_id(task.id);
|
||||
queue_tasks.post(std::move(task));
|
||||
}
|
||||
|
||||
void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
|
||||
GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
|
||||
id_tasks = server_task::get_list_id(tasks);
|
||||
states.reserve(tasks.size());
|
||||
for (size_t i = 0; i < tasks.size(); i++) {
|
||||
states.push_back(tasks[i].create_state());
|
||||
}
|
||||
queue_results.add_waiting_tasks(tasks);
|
||||
queue_tasks.post(std::move(tasks));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -129,13 +129,13 @@ struct server_response_reader {
|
|||
std::vector<task_result_state> states;
|
||||
|
||||
// should_stop function will be called each polling_interval_seconds
|
||||
server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
|
||||
: queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
|
||||
server_response_reader(server_queue & queue_tasks, server_response & queue_results, int polling_interval_seconds)
|
||||
: queue_tasks(queue_tasks), queue_results(queue_results), polling_interval_seconds(polling_interval_seconds) {}
|
||||
~server_response_reader() {
|
||||
stop();
|
||||
}
|
||||
|
||||
void set_states(std::vector<task_result_state> && states);
|
||||
void post_task(server_task && tasks);
|
||||
void post_tasks(std::vector<server_task> && tasks);
|
||||
bool has_next() const;
|
||||
|
||||
|
|
|
|||
|
|
@ -155,11 +155,12 @@ task_params server_task::params_from_json_cmpl(
|
|||
|
||||
// Sampling parameter defaults are loaded from the global server context (but individual requests can still them)
|
||||
task_params defaults;
|
||||
defaults.sampling = params_base.sampling;
|
||||
defaults.speculative = params_base.speculative;
|
||||
defaults.n_keep = params_base.n_keep;
|
||||
defaults.n_predict = params_base.n_predict;
|
||||
defaults.antiprompt = params_base.antiprompt;
|
||||
defaults.sampling = params_base.sampling;
|
||||
defaults.speculative = params_base.speculative;
|
||||
defaults.n_keep = params_base.n_keep;
|
||||
defaults.n_predict = params_base.n_predict;
|
||||
defaults.n_cache_reuse = params_base.n_cache_reuse;
|
||||
defaults.antiprompt = params_base.antiprompt;
|
||||
|
||||
// enabling this will output extra debug information in the HTTP responses from the server
|
||||
params.verbose = params_base.verbosity > 9;
|
||||
|
|
@ -176,6 +177,7 @@ task_params server_task::params_from_json_cmpl(
|
|||
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
|
||||
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
|
||||
params.n_cmpl = json_value(data, "n_cmpl", json_value(data, "n", 1));
|
||||
params.n_cache_reuse = json_value(data, "n_cache_reuse", defaults.n_cache_reuse);
|
||||
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
|
||||
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
|
||||
params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
|
||||
|
|
|
|||
|
|
@ -55,6 +55,8 @@ struct task_params {
|
|||
int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
|
||||
int32_t n_cmpl = 1; // number of completions to generate from this prompt
|
||||
|
||||
int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)
|
||||
|
||||
int64_t t_max_prompt_ms = -1; // TODO: implement
|
||||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||
|
||||
|
|
@ -62,18 +64,19 @@ struct task_params {
|
|||
|
||||
std::vector<std::string> antiprompt;
|
||||
std::vector<std::string> response_fields;
|
||||
bool timings_per_token = false;
|
||||
|
||||
bool timings_per_token = false;
|
||||
bool post_sampling_probs = false;
|
||||
|
||||
struct common_params_sampling sampling;
|
||||
struct common_params_speculative speculative;
|
||||
|
||||
// response formatting
|
||||
bool verbose = false;
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_syntax oaicompat_chat_syntax;
|
||||
bool verbose = false;
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_syntax oaicompat_chat_syntax;
|
||||
|
||||
// Embeddings
|
||||
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
|
||||
|
|
@ -82,6 +85,25 @@ struct task_params {
|
|||
json to_json(bool only_metrics = false) const;
|
||||
};
|
||||
|
||||
// struct for tracking the state of a task (e.g., for streaming)
|
||||
struct task_result_state {
|
||||
// tracking diffs for partial tool calls
|
||||
std::vector<common_chat_msg_diff> diffs;
|
||||
common_chat_syntax oaicompat_chat_syntax;
|
||||
common_chat_msg chat_msg;
|
||||
std::string generated_text; // append new chunks of generated text here
|
||||
std::vector<std::string> generated_tool_call_ids;
|
||||
|
||||
task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
|
||||
: oaicompat_chat_syntax(oaicompat_chat_syntax) {}
|
||||
|
||||
// parse partial tool calls and update the internal state
|
||||
common_chat_msg update_chat_msg(
|
||||
const std::string & text_added,
|
||||
bool is_partial,
|
||||
std::vector<common_chat_msg_diff> & diffs);
|
||||
};
|
||||
|
||||
struct server_task {
|
||||
int id = -1; // to be filled by server_queue
|
||||
int index = -1; // used when there are multiple prompts (batch request)
|
||||
|
|
@ -146,6 +168,12 @@ struct server_task {
|
|||
copy.tokens = tokens.clone();
|
||||
return copy;
|
||||
}
|
||||
|
||||
// the task will be moved into queue, then onto slots
|
||||
// however, the state must be kept by caller (e.g., HTTP thread)
|
||||
task_result_state create_state() const {
|
||||
return task_result_state(params.oaicompat_chat_syntax);
|
||||
}
|
||||
};
|
||||
|
||||
struct result_timings {
|
||||
|
|
@ -177,25 +205,6 @@ struct result_prompt_progress {
|
|||
json to_json() const;
|
||||
};
|
||||
|
||||
// struct for tracking the state of a task (e.g., for streaming)
|
||||
struct task_result_state {
|
||||
// tracking diffs for partial tool calls
|
||||
std::vector<common_chat_msg_diff> diffs;
|
||||
common_chat_syntax oaicompat_chat_syntax;
|
||||
common_chat_msg chat_msg;
|
||||
std::string generated_text; // append new chunks of generated text here
|
||||
std::vector<std::string> generated_tool_call_ids;
|
||||
|
||||
task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
|
||||
: oaicompat_chat_syntax(oaicompat_chat_syntax) {}
|
||||
|
||||
// parse partial tool calls and update the internal state
|
||||
common_chat_msg update_chat_msg(
|
||||
const std::string & text_added,
|
||||
bool is_partial,
|
||||
std::vector<common_chat_msg_diff> & diffs);
|
||||
};
|
||||
|
||||
struct server_task_result {
|
||||
int id = -1;
|
||||
int id_slot = -1;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue