diff --git a/common/arg.cpp b/common/arg.cpp index 115752704..1e1fcdfcf 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1679,7 +1679,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.warmup = false; } - ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING})); + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL})); add_opt(common_arg( {"--spm-infill"}, string_format( diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 15e019a10..e88076ccc 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2645,7 +2645,7 @@ class Qwen2Model(TextModel): yield from super().modify_tensors(data_torch, name, bid) -@ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration") +@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration") class Qwen2VLModel(TextModel): model_arch = gguf.MODEL_ARCH.QWEN2VL @@ -2669,7 +2669,7 @@ class Qwen2VLModel(TextModel): return [(self.map_tensor_name(name), data_torch)] -@ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration") +@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration") class Qwen2VLVisionModel(VisionModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index e4a6d946d..661bc5f39 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -534,14 +534,15 @@ extern "C" { GGML_UNARY_OP_STEP, GGML_UNARY_OP_TANH, GGML_UNARY_OP_ELU, - GGML_UNARY_OP_RELU, GGML_UNARY_OP_SIGMOID, GGML_UNARY_OP_GELU, + GGML_UNARY_OP_GELU_ERF, GGML_UNARY_OP_GELU_QUICK, GGML_UNARY_OP_SILU, GGML_UNARY_OP_HARDSWISH, GGML_UNARY_OP_HARDSIGMOID, GGML_UNARY_OP_EXP, + GGML_UNARY_OP_RELU, GGML_UNARY_OP_COUNT, }; @@ -1037,6 +1038,16 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // GELU using erf (error function) when possible + // some backends may fallback to approximation based on Abramowitz and Stegun formula + GGML_API struct ggml_tensor * ggml_gelu_erf( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_gelu_erf_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_gelu_quick( struct ggml_context * ctx, struct ggml_tensor * a); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 4498332f1..16833abc8 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2216,6 +2216,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 955fec59a..26501b711 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -2691,6 +2691,109 @@ static void ggml_compute_forward_gelu( } } +// ggml_compute_forward_gelu_erf + +static void ggml_compute_forward_gelu_erf_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int 
nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_gelu_erf_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_gelu_erf_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_gelu_erf_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_gelu_erf( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_gelu_erf_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_gelu_erf_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_gelu_quick static void ggml_compute_forward_gelu_quick_f32( @@ -7749,6 +7852,10 @@ void ggml_compute_forward_unary( { ggml_compute_forward_gelu(params, dst); } break; + case GGML_UNARY_OP_GELU_ERF: + { + ggml_compute_forward_gelu_erf(params, dst); + } break; case GGML_UNARY_OP_GELU_QUICK: { ggml_compute_forward_gelu_quick(params, dst); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 23cbb3051..c77349ebe 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -428,6 +428,7 @@ inline static void ggml_vec_exp_f16 (const int n, ggml_fp16_t * y, const ggml_fp static const float GELU_COEF_A = 0.044715f; static const float GELU_QUICK_COEF = -1.702f; static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; +static const float SQRT_2_INV = 0.70710678118654752440084436210484f; inline static float ggml_gelu_f32(float x) { return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); @@ -440,6 +441,14 @@ inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp } } +inline static void ggml_vec_gelu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + float xi = GGML_FP16_TO_FP32(x[i]); + float res = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV)); + y[i] = GGML_FP32_TO_FP16(res); + } +} + #ifdef GGML_GELU_FP16 inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { uint16_t t; @@ -463,6 +472,13 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { } #endif +inline static 
void ggml_vec_gelu_erf_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        float xi = x[i];
+        y[i] = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
+    }
+}
+
 inline static float ggml_gelu_quick_f32(float x) {
     return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
 }
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 7c3683745..c100787e2 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -149,6 +149,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SIGMOID,
     GGML_METAL_KERNEL_TYPE_GELU,
     GGML_METAL_KERNEL_TYPE_GELU_4,
+    GGML_METAL_KERNEL_TYPE_GELU_ERF,
+    GGML_METAL_KERNEL_TYPE_GELU_ERF_4,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
     GGML_METAL_KERNEL_TYPE_SILU,
@@ -1103,6 +1105,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF, gelu_erf, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF_4, gelu_erf_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
@@ -1613,6 +1617,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             case GGML_UNARY_OP_RELU:
             case GGML_UNARY_OP_SIGMOID:
             case GGML_UNARY_OP_GELU:
+            case GGML_UNARY_OP_GELU_ERF:
             case GGML_UNARY_OP_GELU_QUICK:
             case GGML_UNARY_OP_SILU:
             case GGML_UNARY_OP_ELU:
@@ -2251,6 +2256,25 @@ static bool ggml_metal_encode_node(
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
+            case GGML_UNARY_OP_GELU_ERF:
+            {
+                int64_t n = ggml_nelements(dst);
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                if (n % 4 == 0) {
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF_4].pipeline;
+                    n /= 4;
+                } else {
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF].pipeline;
+                }
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
             case GGML_UNARY_OP_GELU_QUICK:
             {
                 int64_t n = ggml_nelements(dst);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index f18473dcb..59899550e 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -856,6 +856,7 @@ kernel void kernel_tanh(
 constant float GELU_COEF_A = 0.044715f;
 constant float GELU_QUICK_COEF = -1.702f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
 
 kernel void kernel_gelu(
     device const float * src0,
@@ -897,6 +898,42 @@ kernel void kernel_gelu_quick_4(
     dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
 }
 
+// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
+// ref: https://www.johndcook.com/blog/python_erf/
+constant float p_erf = 0.3275911f;
+constant float a1_erf = 0.254829592f;
+constant float a2_erf = -0.284496736f;
+constant float a3_erf = 1.421413741f;
+constant float a4_erf = -1.453152027f;
+constant float a5_erf = 1.061405429f;
+
+template <typename T>
+T 
erf_approx(T x) { + T sign_x = sign(x); + x = fabs(x); + T t = 1.0f / (1.0f + p_erf * x); + T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); + return sign_x * y; +} + +kernel void kernel_gelu_erf( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV)); +} + +kernel void kernel_gelu_erf_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV)); +} + kernel void kernel_silu( device const float * src0, device float * dst, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index ccc600c8f..028b5237a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1112,9 +1112,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "HARDSWISH", "HARDSIGMOID", "EXP", + "GELU_ERF", }; -static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14"); +static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); @@ -2514,6 +2515,20 @@ struct ggml_tensor * ggml_gelu_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU); } +// ggml_gelu_erf + +struct ggml_tensor * ggml_gelu_erf( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_ERF); +} + +struct ggml_tensor * ggml_gelu_erf_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_ERF); +} + // ggml_gelu_quick struct ggml_tensor * ggml_gelu_quick( diff --git a/gguf-py/gguf/gguf_reader.py b/gguf-py/gguf/gguf_reader.py index 5991cdb76..d87e8f723 100644 --- a/gguf-py/gguf/gguf_reader.py +++ b/gguf-py/gguf/gguf_reader.py @@ -251,7 +251,7 @@ class GGUFReader: offs += curr_size return offs - orig_offs, aparts, data_idxs, types # We can't deal with this one. - raise ValueError('Unknown/unhandled field type {gtype}') + raise ValueError(f'Unknown/unhandled field type {gtype}') def _get_tensor_info_field(self, orig_offs: int) -> ReaderField: offs = orig_offs diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index ff50d3de3..1dcce18f7 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -896,7 +896,7 @@ class GGUFWriter: def add_remove_extra_whitespaces(self, value: bool) -> None: self.add_bool(Keys.Tokenizer.REMOVE_EXTRA_WS, value) - def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None: + def add_precompiled_charsmap(self, charsmap: bytes) -> None: self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap) def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None: diff --git a/include/llama.h b/include/llama.h index 8437a301c..a3512e539 100644 --- a/include/llama.h +++ b/include/llama.h @@ -612,10 +612,12 @@ extern "C" { // Returns the number of tokens in the KV cache (slow, use only for debug) // If a KV cell has multiple sequences assigned to it, it will be counted multiple times - LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx); + DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx), + "Use llama_kv_self_seq_pos_max() instead"); // Returns the number of used KV cells (i.e. 
have at least one sequence assigned to them)
-    LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
+    DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx),
+            "Use llama_kv_self_seq_pos_max() instead");
 
     // Clear the KV cache - both cell info is erased and KV data is zeroed
     LLAMA_API void llama_kv_self_clear(
diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index a88b2fe30..b98e3256c 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -1,5 +1,6 @@
 #include "llama-batch.h"
 
+#include <cassert>
 #include <cstring>
 #include <algorithm>
@@ -281,9 +282,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
     batch = in_batch;
     GGML_ASSERT(batch.n_tokens > 0);
     if (!batch.pos) {
+        assert(p0 >= 0);
         pos.resize(batch.n_tokens);
         for (int32_t i = 0; i < batch.n_tokens; i++) {
-            pos[i] = i + p0;
+            pos[i] = p0 + i;
         }
         batch.pos = pos.data();
     }
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index a1dec63aa..2cad3ff69 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -857,11 +857,17 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -1;
     }
 
+    if (!inp_batch.pos) {
+        if (inp_batch.seq_id) {
+            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
+            return -1;
+        }
+    }
+
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
     // temporary allocate memory for the input batch if needed
-    // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
 
     const llama_batch & batch = batch_allocr.batch;
 
@@ -2292,22 +2298,47 @@ int32_t llama_apply_adapter_cvec(
 // kv cache
 //
 
+// deprecated
 int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
         return 0;
     }
 
-    return kv->get_n_tokens();
+    int32_t res = 0;
+
+    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+        const llama_pos p0 = kv->seq_pos_min(s);
+        const llama_pos p1 = kv->seq_pos_max(s);
+
+        if (p0 >= 0) {
+            res += (p1 - p0) + 1;
+        }
+    }
+
+    return res;
 }
 
+// deprecated
+// note: this is the same as above - will be removed anyway, so it's ok
 int32_t llama_kv_self_used_cells(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
         return 0;
     }
 
-    return kv->get_used_cells();
+    int32_t res = 0;
+
+    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+        const llama_pos p0 = kv->seq_pos_min(s);
+        const llama_pos p1 = kv->seq_pos_max(s);
+
+        if (p0 >= 0) {
+            res += (p1 - p0) + 1;
+        }
+    }
+
+    return res;
 }
 
 void llama_kv_self_clear(llama_context * ctx) {
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 8c9e7080e..42bc02855 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1236,8 +1236,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
     {
-        GGML_ASSERT(hparams.n_swa_pattern == 1 && "Use llama_kv_cache_unified_iswa for SWA");
-        GGML_ASSERT(hparams.n_swa == 0 && "Use llama_kv_cache_unified_iswa for SWA");
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
         const auto n_kv = kv_self->get_n();
 
@@ -1312,8 +1311,8 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     inp->self_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } - if (hparams.n_swa_pattern > 1) { - GGML_ASSERT(hparams.n_swa > 0 && "Use llama_kv_cache_unified for non-SWA"); + { + GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA"); const auto n_kv = kv_self->get_kv_swa()->get_n(); diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 90dfe7a7f..4f84e56b3 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -72,7 +72,7 @@ uint32_t llama_hparams::n_embd_v_s() const { bool llama_hparams::is_swa(uint32_t il) const { if (il < n_layer) { - return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1); + return n_swa_pattern == 0 || (il % n_swa_pattern < (n_swa_pattern - 1)); } GGML_ABORT("fatal error"); diff --git a/src/llama-hparams.h b/src/llama-hparams.h index f865cbaea..5222eedcf 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -104,7 +104,18 @@ struct llama_hparams { llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; uint32_t n_swa = 0; // the size of the sliding window (0 - no SWA) - uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention + uint32_t n_swa_pattern = 1; // this value n means that every nth layer is dense (i.e. non-SWA) + // by default n == 1, all layers are dense + // note that if n_swa_pattern == 0, all layers are SWA + // example: n_swa_pattern = 3 + // il == 0: swa + // il == 1: swa + // il == 2: dense + // il == 3: swa + // il == 4: swa + // il == 5: dense + // il == 6: swa + // etc ... // for State Space Models uint32_t ssm_d_conv = 0; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index ca7390280..81235a4af 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -30,13 +30,14 @@ llama_kv_cache_unified::llama_kv_cache_unified( bool v_trans, bool offload, uint32_t kv_size, - uint32_t padding, + uint32_t n_seq_max, + uint32_t n_pad, uint32_t n_swa, - llama_swa_type swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding), n_swa(n_swa), swa_type(swa_type) { - GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding"); + llama_swa_type swa_type) : + model(model), hparams(model.hparams), v_trans(v_trans), + n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { - this->type_k = type_k; - this->type_v = type_v; + GGML_ASSERT(kv_size % n_pad == 0); // create a context for each buffer type std::map ctx_map; @@ -129,8 +130,8 @@ llama_kv_cache_unified::llama_kv_cache_unified( const size_t memory_size_k = size_k_bytes(); const size_t memory_size_v = size_v_bytes(); - LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6d cells, %3d layers), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } @@ -442,7 +443,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { void llama_kv_cache_unified::defrag_sched(float thold) { // - do not defrag small contexts (i.e. < 2048 tokens) // - count the padding towards the number of used tokens - const float fragmentation = n >= 2048 ? 
std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f; + const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f; // queue defragmentation for next llama_kv_cache_update if (fragmentation > thold) { @@ -558,7 +559,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding))); + n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad))); #ifdef FIND_SLOT_DEBUG LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa); @@ -567,20 +568,6 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { return true; } -int32_t llama_kv_cache_unified::get_n_tokens() const { - int32_t result = 0; - - for (uint32_t i = 0; i < size; i++) { - result += cells[i].seq_id.size(); - } - - return result; -} - -int32_t llama_kv_cache_unified::get_used_cells() const { - return used; -} - bool llama_kv_cache_unified::get_can_shift() const { return true; } @@ -802,16 +789,6 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama } } -llama_pos llama_kv_cache_unified::get_pos_max() const { - llama_pos pos_max = -1; - - for (const auto & cell : cells) { - pos_max = std::max(pos_max, cell.pos); - } - - return pos_max; -} - size_t llama_kv_cache_unified::total_size() const { size_t size = 0; @@ -1501,11 +1478,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell llama_seq_id seq_id; io.read_to(&seq_id, sizeof(seq_id)); - // TODO: llama_kv_cache_unified should have a notion of max sequences - //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { - if (seq_id < 0) { - //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); - LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id); + if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) { + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max); return false; } @@ -1655,17 +1629,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( ggml_type type_v, bool v_trans, bool offload, - uint32_t kv_size, bool swa_full, + uint32_t kv_size, uint32_t n_seq_max, uint32_t n_batch, - uint32_t padding) : hparams(model.hparams) { + uint32_t n_pad) : hparams(model.hparams) { llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; const uint32_t size_base = kv_size; - uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding)); + uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad)); // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning if (swa_full) { @@ -1680,14 +1654,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( kv_base = std::make_unique( model, std::move(filter_base), type_k, type_v, - v_trans, offload, size_base, padding, + v_trans, offload, size_base, n_seq_max, n_pad, 0, LLAMA_SWA_TYPE_NONE); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, 
size_swa); kv_swa = std::make_unique( model, std::move(filter_swa), type_k, type_v, - v_trans, offload, size_swa, padding, + v_trans, offload, size_swa, n_seq_max, n_pad, hparams.n_swa, hparams.swa_type); } @@ -1810,18 +1784,6 @@ bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) { return res; } -int32_t llama_kv_cache_unified_iswa::get_n_tokens() const { - return kv_base->get_n_tokens(); -} - -int32_t llama_kv_cache_unified_iswa::get_used_cells() const { - return kv_base->get_used_cells(); -} - -llama_pos llama_kv_cache_unified_iswa::get_pos_max() const { - return kv_base->get_pos_max(); -} - bool llama_kv_cache_unified_iswa::get_can_shift() const { return kv_base->get_size() == kv_swa->get_size(); } @@ -1853,19 +1815,17 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( ggml_type type_k, ggml_type type_v, bool offload, - uint32_t kv_size) : hparams(model.hparams) { + uint32_t kv_size, + uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { const int32_t n_layer = hparams.n_layer; - LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n", - __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer); + LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n", + __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer); head = 0; size = kv_size; used = 0; - this->type_k = type_k; - this->type_v = type_v; - cells.clear(); cells.resize(kv_size); @@ -2203,8 +2163,8 @@ void llama_kv_cache_recurrent::commit() { pending.ranges.clear(); } -bool llama_kv_cache_recurrent::update(llama_context & lctx) { - GGML_UNUSED(lctx); +bool llama_kv_cache_recurrent::update(llama_context & ctx) { + GGML_UNUSED(ctx); return false; } @@ -2265,7 +2225,7 @@ bool llama_kv_cache_recurrent::find_slot( if (seq_id < 0 || (uint32_t) seq_id >= size) { // too big seq_id // TODO: would it be possible to resize the cache instead? 
- LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size); + LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max); return false; } if (j > 0) { @@ -2408,29 +2368,6 @@ bool llama_kv_cache_recurrent::find_slot( return n >= n_seqs; } -int32_t llama_kv_cache_recurrent::get_n_tokens() const { - int32_t result = 0; - - for (uint32_t i = 0; i < size; i++) { - result += cells[i].seq_id.size(); - } - - return result; -} - -int32_t llama_kv_cache_recurrent::get_used_cells() const { - return used; -} - -llama_pos llama_kv_cache_recurrent::get_pos_max() const { - llama_pos pos_max = -1; - for (const auto & cell : cells) { - pos_max = std::max(pos_max, cell.pos); - } - - return pos_max; -} - bool llama_kv_cache_recurrent::get_can_shift() const { return false; } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index bd0485bc6..191a1090a 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -55,10 +55,7 @@ struct llama_kv_cache : public llama_memory_i { // ============================================================================================================= // getters - virtual int32_t get_n_tokens() const = 0; - virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache - virtual llama_pos get_pos_max() const = 0; - virtual bool get_can_shift() const = 0; + virtual bool get_can_shift() const = 0; bool get_can_edit() const override { return get_can_shift(); } @@ -108,7 +105,8 @@ public: bool v_trans, bool offload, uint32_t kv_size, - uint32_t padding, + uint32_t n_seq_max, + uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type); @@ -150,12 +148,6 @@ public: // to the first cell of the slot. 
bool find_slot(const llama_ubatch & batch) override; - int32_t get_n_tokens() const override; - int32_t get_used_cells() const override; - - // TODO: better data structures to reduce the cost of this operation - llama_pos get_pos_max() const override; - bool get_can_shift() const override; // state write/load @@ -228,16 +220,15 @@ private: // computed before each graph build uint32_t n = 0; - // required padding - uint32_t padding = 1; + const uint32_t n_seq_max = 1; - ggml_type type_k = GGML_TYPE_F16; - ggml_type type_v = GGML_TYPE_F16; + // required padding + const uint32_t n_pad = 1; // SWA - uint32_t n_swa = 0; + const uint32_t n_swa = 0; - llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; + const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; std::vector ctxs; std::vector bufs; @@ -317,11 +308,11 @@ public: ggml_type type_v, bool v_trans, bool offload, - uint32_t kv_size, bool swa_full, + uint32_t kv_size, uint32_t n_seq_max, uint32_t n_batch, - uint32_t padding); + uint32_t n_pad); ~llama_kv_cache_unified_iswa() = default; @@ -358,12 +349,6 @@ public: bool find_slot(const llama_ubatch & batch) override; - int32_t get_n_tokens() const override; - int32_t get_used_cells() const override; - - // TODO: better data structures to reduce the cost of this operation - llama_pos get_pos_max() const override; - bool get_can_shift() const override; // state write/load @@ -432,7 +417,8 @@ public: ggml_type type_k, ggml_type type_v, bool offload, - uint32_t kv_size); + uint32_t kv_size, + uint32_t n_seq_max); ~llama_kv_cache_recurrent() = default; @@ -444,7 +430,7 @@ public: bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; - void seq_keep(llama_seq_id seq_id) override; + void seq_keep(llama_seq_id seq_id) override; void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; @@ -458,7 +444,7 @@ public: void restore() override; void commit() override; - bool update(llama_context & lctx) override; + bool update(llama_context & ctx) override; void defrag_sched(float thold) override; @@ -469,12 +455,6 @@ public: bool find_slot(const llama_ubatch & batch) override; - int32_t get_n_tokens() const override; - int32_t get_used_cells() const override; - - // TODO: better data structures to reduce the cost of this operation - llama_pos get_pos_max() const override; - bool get_can_shift() const override; // TODO: temporary methods - they are not really const as they do const_cast<>, fix this @@ -514,8 +494,7 @@ private: std::vector ranges; } pending; - ggml_type type_k = GGML_TYPE_F16; - ggml_type type_v = GGML_TYPE_F16; + const uint32_t n_seq_max = 1; std::vector ctxs; std::vector bufs; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 295d4ca18..143073539 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -858,43 +858,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } - // for backward compatibility ; see: https://github.com/ggerganov/llama.cpp/pull/8931 - if ((hparams.n_layer == 32 || hparams.n_layer == 40) && hparams.n_ctx_train == 4096) { - // default value for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct - LLAMA_LOG_WARN("%s: assuming n_swa = 2047 for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct\n", __func__); + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, 
false);
-                    hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
-
-                    hparams.n_swa = 2047;
-                } else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) {
-                    // default value for Phi-3-mini-128k-instruct
-                    LLAMA_LOG_WARN("%s: assuming no SWA for Phi-3-mini-128k-instruct\n", __func__);
+                if (found_swa && hparams.n_swa > 0) {
+                    LLAMA_LOG_WARN("%s: Phi SWA is currently disabled - results might be suboptimal for some models (see %s)\n",
+                            __func__, "https://github.com/ggml-org/llama.cpp/pull/13676");
+                    // TODO: fix conversion scripts to correctly populate `n_swa` and `n_swa_pattern`
                     hparams.swa_type = LLAMA_SWA_TYPE_NONE;
-
-                    hparams.n_swa = hparams.n_ctx_train;
-                    hparams.n_swa_pattern = 1;
-                } else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) {
-                    // default value for Phi-3-medium-128k-instruct
-                    LLAMA_LOG_WARN("%s: assuming no SWA for Phi-3-medium-128k-instruct\n", __func__);
-
-                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
-
-                    hparams.n_swa = hparams.n_ctx_train;
-                    hparams.n_swa_pattern = 1;
-                }
-
-                bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
-                if (!found_swa && hparams.n_swa == 0) {
-                    throw std::runtime_error("invalid value for sliding_window");
-                }
-
-                if (hparams.n_swa > hparams.n_ctx_train) {
-                    LLAMA_LOG_WARN("%s: unexpected n_swa: %d >= %d, disabling SWA\n", __func__, hparams.n_swa, hparams.n_ctx_train);
-
-                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
-
-                    hparams.n_swa = hparams.n_ctx_train;
+                    hparams.n_swa = 0;
                     hparams.n_swa_pattern = 1;
                 }
             } break;
@@ -7468,8 +7441,9 @@ struct llm_build_phi2 : public llm_graph_context {
     }
 };
 
-struct llm_build_phi3_iswa : public llm_graph_context {
-    llm_build_phi3_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+template<bool iswa>
+struct llm_build_phi3 : public llm_graph_context {
+    llm_build_phi3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -7483,7 +7457,14 @@ struct llm_build_phi3_iswa : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
+        inp_attn_type * inp_attn = nullptr;
+
+        if constexpr (iswa) {
+            inp_attn = build_attn_inp_kv_unified_iswa();
+        } else {
+            inp_attn = build_attn_inp_kv_unified();
+        }
 
         for (int il = 0; il < n_layer; ++il) {
             auto * residual = inpL;
@@ -13322,7 +13303,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         GGML_TYPE_F32,
                         GGML_TYPE_F32,
                         cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max));
+                        std::max((uint32_t) 1, cparams.n_seq_max),
+                        cparams.n_seq_max);
             } break;
         default:
            {
@@ -13332,19 +13314,23 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
 
                 LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
-                if (hparams.n_swa > 0) {
+                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                    GGML_ASSERT(hparams.n_swa_pattern != 1);
+
                     res = new llama_kv_cache_unified_iswa(
                             *this,
                             params.type_k,
                             params.type_v,
                             !cparams.flash_attn,
                             cparams.offload_kqv,
-                            cparams.n_ctx,
                             params.swa_full,
+                            cparams.n_ctx,
                             cparams.n_seq_max,
                             cparams.n_batch,
                             padding);
                 } else {
+                    GGML_ASSERT(hparams.n_swa_pattern == 1);
+
                     res = new llama_kv_cache_unified(
                             *this,
                             nullptr,
@@ -13353,6 +13339,7 @@ llama_memory_i * 
llama_model::create_memory(const llama_memory_params & params, !cparams.flash_attn, cparams.offload_kqv, cparams.n_ctx, + cparams.n_seq_max, padding, hparams.n_swa, hparams.swa_type); @@ -13453,7 +13440,11 @@ llm_graph_result_ptr llama_model::build_graph( case LLM_ARCH_PHI3: case LLM_ARCH_PHIMOE: { - llm = std::make_unique(*this, params, gf); + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + llm = std::make_unique> (*this, params, gf); + } else { + llm = std::make_unique>(*this, params, gf); + } } break; case LLM_ARCH_PLAMO: { diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index e9e2d442e..6ed73a5f5 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -95,3 +95,5 @@ bool clip_is_pixtral(const struct clip_ctx * ctx); void set_clip_uses_gpu(bool usegpu); bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); + +bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype) ; \ No newline at end of file diff --git a/tools/server/server.cpp b/tools/server/server.cpp index f8b7ff062..1a08e30d2 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -951,7 +951,7 @@ struct server_task_result_cmpl_partial : server_task_result { } json to_json_oaicompat_chat() { - bool first = n_decoded == 0; + bool first = n_decoded == 1; std::time_t t = std::time(0); json choices; @@ -962,15 +962,18 @@ struct server_task_result_cmpl_partial : server_task_result { {"delta", json{{"role", "assistant"}}}}}); } else { // We have to send this as two updates to conform to openai behavior + // initial_ret is the role message for stream=True json initial_ret = json{{"choices", json::array({json{ {"finish_reason", nullptr}, {"index", 0}, {"delta", json{ - {"role", "assistant"} + {"role", "assistant"}, + {"content", ""} }}}})}, {"created", t}, {"id", oaicompat_cmpl_id}, {"model", oaicompat_model}, + {"system_fingerprint", build_info}, {"object", "chat.completion.chunk"}}; json second_ret = json{ @@ -982,8 +985,19 @@ struct server_task_result_cmpl_partial : server_task_result { {"created", t}, {"id", oaicompat_cmpl_id}, {"model", oaicompat_model}, + {"system_fingerprint", build_info}, {"object", "chat.completion.chunk"}}; + if (prob_output.probs.size() > 0) { + second_ret["choices"][0]["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + + if (timings.prompt_n >= 0) { + second_ret.push_back({"timings", timings.to_json()}); + } + return std::vector({initial_ret, second_ret}); } } else { @@ -1137,9 +1151,6 @@ struct server_task_result_metrics : server_task_result { int n_tasks_deferred; int64_t t_start; - int32_t kv_cache_tokens_count; - int32_t kv_cache_used_cells; - // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields uint64_t n_prompt_tokens_processed_total = 0; uint64_t t_prompt_processing_total = 0; @@ -1179,9 +1190,6 @@ struct server_task_result_metrics : server_task_result { { "n_decode_total", n_decode_total }, { "n_busy_slots_total", n_busy_slots_total }, - { "kv_cache_tokens_count", kv_cache_tokens_count }, - { "kv_cache_used_cells", kv_cache_used_cells }, - { "slots", slots_data }, }; } @@ -2771,9 +2779,6 @@ struct server_context { res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size(); res->t_start = metrics.t_start; - res->kv_cache_tokens_count = llama_kv_self_n_tokens(ctx); - res->kv_cache_used_cells = llama_kv_self_used_cells(ctx); - res->n_prompt_tokens_processed_total = 
metrics.n_prompt_tokens_processed_total; res->t_prompt_processing_total = metrics.t_prompt_processing_total; res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; @@ -3336,6 +3341,37 @@ struct server_context { common_set_adapter_lora(ctx, slot_batched->lora); } + const bool do_encode = (params_base.embedding || params_base.reranking); + + // pad the batch so that batch.n_tokens >= n_slots + // TODO: temporary workaround for https://github.com/ggml-org/llama.cpp/issues/13689 + if (do_encode) { + const int n_slots = slots.size(); + + if (batch.n_tokens < n_slots) { + std::set seq_ids; + for (int j = 0; j < batch.n_tokens; ++j) { + seq_ids.insert(batch.seq_id[j][0]); + } + + // find unused sequence id + llama_seq_id seq_id = -1; + for (int i = 0; i < n_slots; ++i) { + if (seq_ids.find(i) == seq_ids.end()) { + seq_id = i; + } + } + + const int n_add = n_slots - batch.n_tokens; + + SRV_WRN("adding %d dummy tokens to the batch, seq_id = %d\n", n_add, seq_id); + + for (int j = 0; j < n_add; ++j) { + common_batch_add(batch, 0, j, { seq_id }, false); + } + } + } + // process the created batch of tokens for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); @@ -3352,7 +3388,7 @@ struct server_context { int ret = 0; - if (params_base.embedding || params_base.reranking) { + if (do_encode) { ret = llama_encode(ctx, batch_view); } else { ret = llama_decode(ctx, batch_view); @@ -3361,14 +3397,29 @@ struct server_context { metrics.on_decoded(slots); if (ret != 0) { - if (n_batch == 1 || ret < 0) { - // if you get here, it means the KV cache is full - try increasing it via the context size - SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); - for (auto & slot : slots) { - slot.release(); - send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size."); + { + std::string err; + + if (n_batch == 1 && ret == 1) { + err = "Context size has been exceeded."; + } + + if (ret == -1) { + err = "Invalid input batch."; + } + + if (ret < -1) { + err = "Compute error."; + } + + if (!err.empty()) { + SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); + for (auto & slot : slots) { + slot.release(); + send_error(slot, err); + } + break; } - break; // break loop of n_batch } // retry with half the batch size to try to find a free slot in the KV cache @@ -3702,6 +3753,7 @@ int main(int argc, char ** argv) { "/health", "/models", "/v1/models", + "/api/tags" }; // If API key is not set, skip validation @@ -3740,7 +3792,7 @@ int main(int argc, char ** argv) { if (req.path == "/" || tmp.back() == "html") { res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); res.status = 503; - } else if (req.path == "/models" || req.path == "/v1/models") { + } else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") { // allow the models endpoint to be accessed during loading return true; } else { @@ -3883,14 +3935,6 @@ int main(int argc, char ** argv) { {"name", "predicted_tokens_seconds"}, {"help", "Average generation throughput in tokens/s."}, {"value", res_metrics->n_tokens_predicted ? 1.e3 / res_metrics->t_tokens_generation * res_metrics->n_tokens_predicted : 0.} - },{ - {"name", "kv_cache_usage_ratio"}, - {"help", "KV-cache usage. 1 means 100 percent usage."}, - {"value", 1. 
* res_metrics->kv_cache_used_cells / params.n_ctx} - },{ - {"name", "kv_cache_tokens"}, - {"help", "KV-cache tokens."}, - {"value", (uint64_t) res_metrics->kv_cache_tokens_count} },{ {"name", "requests_processing"}, {"help", "Number of requests processing."}, @@ -4086,6 +4130,19 @@ int main(int argc, char ** argv) { { "llama.context_length", ctx_server.slots.back().n_ctx, }, } }, + {"modelfile", ""}, + {"parameters", ""}, + {"template", common_chat_templates_source(ctx_server.chat_templates.get())}, + {"details", { + {"parent_model", ""}, + {"format", "gguf"}, + {"family", ""}, + {"families", {""}}, + {"parameter_size", ""}, + {"quantization_level", ""} + }}, + {"model_info", ""}, + {"capabilities", {"completion"}} }; res_ok(res, data); @@ -4411,6 +4468,28 @@ int main(int argc, char ** argv) { } json models = { + {"models", { + { + {"name", params.model_alias.empty() ? params.model.path : params.model_alias}, + {"model", params.model_alias.empty() ? params.model.path : params.model_alias}, + {"modified_at", ""}, + {"size", ""}, + {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash + {"type", "model"}, + {"description", ""}, + {"tags", {""}}, + {"capabilities", {"completion"}}, + {"parameters", ""}, + {"details", { + {"parent_model", ""}, + {"format", "gguf"}, + {"family", ""}, + {"families", {""}}, + {"parameter_size", ""}, + {"quantization_level", ""} + }} + } + }}, {"object", "list"}, {"data", { { @@ -4420,7 +4499,7 @@ int main(int argc, char ** argv) { {"owned_by", "llamacpp"}, {"meta", model_meta}, }, - }} + }} }; res_ok(res, models); @@ -4748,11 +4827,13 @@ int main(int argc, char ** argv) { svr->Post("/api/show", handle_api_show); svr->Get ("/models", handle_models); // public endpoint (no API key check) svr->Get ("/v1/models", handle_models); // public endpoint (no API key check) + svr->Get ("/api/tags", handle_models); // ollama specific endpoint. 
public endpoint (no API key check) svr->Post("/completion", handle_completions); // legacy svr->Post("/completions", handle_completions); svr->Post("/v1/completions", handle_completions_oai); svr->Post("/chat/completions", handle_chat_completions); svr->Post("/v1/chat/completions", handle_chat_completions); + svr->Post("/api/chat", handle_chat_completions); // ollama specific endpoint svr->Post("/infill", handle_infill); svr->Post("/embedding", handle_embeddings); // legacy svr->Post("/embeddings", handle_embeddings); diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 491cb3a5d..bab5d005d 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -71,8 +71,14 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte }) content = "" last_cmpl_id = None - for data in res: + for i, data in enumerate(res): choice = data["choices"][0] + if i == 0: + # Check first role message for stream=True + assert choice["delta"]["content"] == "" + assert choice["delta"]["role"] == "assistant" + else: + assert "role" not in choice["delta"] assert data["system_fingerprint"].startswith("b") assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future if last_cmpl_id is None: @@ -242,12 +248,18 @@ def test_chat_completion_with_timings_per_token(): "stream": True, "timings_per_token": True, }) - for data in res: - assert "timings" in data - assert "prompt_per_second" in data["timings"] - assert "predicted_per_second" in data["timings"] - assert "predicted_n" in data["timings"] - assert data["timings"]["predicted_n"] <= 10 + for i, data in enumerate(res): + if i == 0: + # Check first role message for stream=True + assert data["choices"][0]["delta"]["content"] == "" + assert data["choices"][0]["delta"]["role"] == "assistant" + else: + assert "role" not in data["choices"][0]["delta"] + assert "timings" in data + assert "prompt_per_second" in data["timings"] + assert "predicted_per_second" in data["timings"] + assert "predicted_n" in data["timings"] + assert data["timings"]["predicted_n"] <= 10 def test_logprobs(): @@ -295,17 +307,23 @@ def test_logprobs_stream(): ) output_text = '' aggregated_text = '' - for data in res: + for i, data in enumerate(res): choice = data.choices[0] - if choice.finish_reason is None: - if choice.delta.content: - output_text += choice.delta.content - assert choice.logprobs is not None - assert choice.logprobs.content is not None - for token in choice.logprobs.content: - aggregated_text += token.token - assert token.logprob <= 0.0 - assert token.bytes is not None - assert token.top_logprobs is not None - assert len(token.top_logprobs) > 0 + if i == 0: + # Check first role message for stream=True + assert choice.delta.content == "" + assert choice.delta.role == "assistant" + else: + assert choice.delta.role is None + if choice.finish_reason is None: + if choice.delta.content: + output_text += choice.delta.content + assert choice.logprobs is not None + assert choice.logprobs.content is not None + for token in choice.logprobs.content: + aggregated_text += token.token + assert token.logprob <= 0.0 + assert token.bytes is not None + assert token.top_logprobs is not None + assert len(token.top_logprobs) > 0 assert aggregated_text == output_text
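
Illustrative sketch, not part of the patch above: the new GGML_UNARY_OP_GELU_ERF operator introduced in ggml.h/ggml-cpu can be exercised through the public API roughly as follows. Only ggml_gelu_erf() comes from this patch; the surrounding calls (ggml_init, ggml_new_tensor_1d, ggml_graph_compute_with_ctx, ...) are assumed to come from the existing ggml / ggml-cpu headers, and the tensor size, memory budget and thread count are arbitrary example values.

// minimal CPU-only sketch of the exact-erf GELU added by this patch
#include "ggml.h"
#include "ggml-cpu.h"
#include <stdio.h>

int main(void) {
    // small scratch context; 16 MiB is plenty for an 8-element graph
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // input tensor with a few sample values around zero
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    for (int i = 0; i < 8; ++i) {
        ggml_set_f32_1d(x, i, (float) i - 4.0f);
    }

    // exact GELU: y = 0.5*x*(1 + erf(x/sqrt(2))) - the op behind ggml_gelu_erf()
    struct ggml_tensor * y = ggml_gelu_erf(ctx, x);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    for (int i = 0; i < 8; ++i) {
        printf("gelu_erf(%5.1f) = %8.5f\n", ggml_get_f32_1d(x, i), ggml_get_f32_1d(y, i));
    }

    ggml_free(ctx);
    return 0;
}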