diff --git a/common/common.cpp b/common/common.cpp index 7fc695545..7bcc20c76 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -942,7 +942,7 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } - if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) { + if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) { LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__); params.ctx_shift = false; } @@ -1049,7 +1049,7 @@ struct common_init_result common_init_from_params(common_params & params) { if (llama_model_has_decoder(model)) { llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch))); } - llama_kv_self_clear(lctx); + llama_memory_clear(llama_get_memory(lctx), true); llama_synchronize(lctx); llama_perf_context_reset(lctx); llama_set_warmup(lctx, false); diff --git a/common/speculative.cpp b/common/speculative.cpp index ccad70fa9..843bd1ddb 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -144,6 +144,8 @@ llama_tokens common_speculative_gen_draft( auto & smpl = spec->smpl; auto & prompt = spec->prompt; + auto * mem = llama_get_memory(ctx); + int reuse_i = 0; int reuse_n = 0; @@ -173,7 +175,7 @@ llama_tokens common_speculative_gen_draft( result.reserve(params.n_draft); if (reuse_n == 0) { - llama_kv_self_clear(ctx); + llama_memory_clear(mem, false); prompt.clear(); } else { @@ -192,14 +194,14 @@ llama_tokens common_speculative_gen_draft( } if (reuse_i > 0) { - llama_kv_self_seq_rm (ctx, 0, 0, reuse_i); - llama_kv_self_seq_add(ctx, 0, reuse_i, -1, -reuse_i); + llama_memory_seq_rm (mem, 0, 0, reuse_i); + llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i); prompt.erase(prompt.begin(), prompt.begin() + reuse_i); } if (reuse_n < (int) prompt.size()) { - llama_kv_self_seq_rm (ctx, 0, reuse_n, -1); + llama_memory_seq_rm (mem, 0, reuse_n, -1); prompt.erase(prompt.begin() + reuse_n, prompt.end()); } diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index ea6da69ba..681929d27 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -37,7 +37,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); // clear previous kv_cache values (irrelevant for embeddings) - llama_memory_clear(llama_get_memory(ctx)); + llama_memory_clear(llama_get_memory(ctx), true); // run model LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); @@ -335,4 +335,4 @@ int main(int argc, char ** argv) { llama_backend_free(); return 0; -} \ No newline at end of file +} diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index e76acd6de..b583556a3 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -616,9 +616,8 @@ static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; ggml_cuda_set_device(ctx->device); - CUDA_CHECK(cudaDeviceSynchronize()); - CUDA_CHECK(cudaMemset(ctx->dev_ptr, value, buffer->size)); - CUDA_CHECK(cudaDeviceSynchronize()); + CUDA_CHECK(cudaMemsetAsync(ctx->dev_ptr, value, buffer->size, cudaStreamPerThread)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = { @@ -1145,7 +1144,6 @@ typedef void (*ggml_cuda_op_mul_mat_t)( static cudaError_t ggml_cuda_cpy_tensor_2d( void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) { - GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer)); const char * src_ptr = (const char *) src->data; char * dst_ptr = (char *) dst; @@ -1428,8 +1426,6 @@ static void ggml_cuda_op_mul_mat( const int64_t nb2 = dst->nb[2]; const int64_t nb3 = dst->nb[3]; - GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer)); - GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer)); ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context; ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context; @@ -1751,7 +1747,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); - GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer)); + GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft)); GGML_ASSERT(src0->type == GGML_TYPE_F16); // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst. diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp index 88ec13ea2..8b952db43 100644 --- a/ggml/src/ggml-sycl/quants.hpp +++ b/ggml/src/ggml-sycl/quants.hpp @@ -14,12 +14,13 @@ #ifndef GGML_SYCL_QUANTS_HPP #define GGML_SYCL_QUANTS_HPP +#include + #include "ggml-common.h" #include "ggml.h" namespace ggml_sycl_reordered { - // The reordered block moves quants (qs) and scales(d) to two // uniform regions of memory that is contiguous in the same tensor. // What this means is that instead of having: @@ -32,7 +33,6 @@ namespace ggml_sycl_reordered { template struct block_q_t; - // qk number of weights / quants in a block // qr number of weights in a byte (described as 'before dequantization') // for quantization types that has low and high bits split, qr is calculated with @@ -47,10 +47,12 @@ template <> struct block_q_t { static constexpr uint32_t vdr_mmvq = 2; }; - static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); } + static constexpr std::pair get_block_offset(const int block_index, const int /* nblocks */) { + return { block_index * (traits::qk / traits::qr), 0 }; + } - static constexpr int get_d_offset(int nrows, int ncols, const int block_index) { - return (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half); + static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { + return { (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half), 0 }; } static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } @@ -64,20 +66,46 @@ template <> struct block_q_t { static constexpr uint32_t vdr_mmvq = 2; }; - static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); } + static constexpr std::pair get_block_offset(const int block_index, const int /* nblocks */) { + return { block_index * (traits::qk / traits::qr), 0 }; + } - static constexpr int get_d_offset(int nrows, int ncols, const int block_index) { + static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { auto nblocks = (nrows * (ncols / traits::qk)); - return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)); + return { nblocks * (QK_K / 2), + (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) }; } static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; } - - constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; } }; +template <> struct block_q_t { + struct traits { + static constexpr uint32_t qk = QK_K; + static constexpr uint32_t qi = QI6_K; + static constexpr uint32_t qr = QR6_K; + static constexpr uint32_t vdr_mmvq = 1; + }; + + static constexpr std::pair get_block_offset(const int block_index, const int n_blocks) { + auto low_bits_index = block_index * (traits::qk / traits::qr); + // the index of high bits it's after all low bits + auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4)); + return { low_bits_index, high_bits_index }; + } + + static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { + auto nblocks = (nrows * (ncols / traits::qk)); + auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4); + auto block_scales = total_qs_bytes + block_index * (QK_K / 16); + auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16); + return { block_scales, sb_scale }; + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } +}; } // namespace ggml_sycl_reordered #endif // GGML_SYCL_QUANTS_HPP diff --git a/include/llama.h b/include/llama.h index 3c4d189ce..499eaac1f 100644 --- a/include/llama.h +++ b/include/llama.h @@ -628,7 +628,10 @@ extern "C" { // // Clear the memory contents - LLAMA_API void llama_memory_clear(llama_memory_t mem); + // If data == true, the data buffers will also be cleared together with the metadata + LLAMA_API void llama_memory_clear( + llama_memory_t mem, + bool data); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails @@ -708,74 +711,82 @@ extern "C" { "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); // Clear the KV cache - both cell info is erased and KV data is zeroed - LLAMA_API void llama_kv_self_clear( - struct llama_context * ctx); + DEPRECATED(LLAMA_API void llama_kv_self_clear( + struct llama_context * ctx), + "Use llama_memory_clear() instead"); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails // seq_id < 0 : match any sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API bool llama_kv_self_seq_rm( + DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, - llama_pos p1); + llama_pos p1), + "Use llama_memory_seq_rm() instead"); // Copy all tokens that belong to the specified sequence to another sequence // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_self_seq_cp( + DEPRECATED(LLAMA_API void llama_kv_self_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, - llama_pos p1); + llama_pos p1), + "Use llama_memory_seq_cp() instead"); // Removes all tokens that do not belong to the specified sequence - LLAMA_API void llama_kv_self_seq_keep( + DEPRECATED(LLAMA_API void llama_kv_self_seq_keep( struct llama_context * ctx, - llama_seq_id seq_id); + llama_seq_id seq_id), + "Use llama_memory_seq_keep() instead"); // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly: // - lazily on next llama_decode() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_self_seq_add( + DEPRECATED(LLAMA_API void llama_kv_self_seq_add( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, - llama_pos delta); + llama_pos delta), + "Use llama_memory_seq_add() instead"); // Integer division of the positions by factor of `d > 1` // If the KV cache is RoPEd, the KV data is updated accordingly: // - lazily on next llama_decode() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_self_seq_div( + DEPRECATED(void llama_kv_self_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, - int d); + int d), + "Use llama_memory_seq_div() instead"); // Returns the smallest position present in the KV cache for the specified sequence // This is typically non-zero only for SWA caches // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache // Return -1 if the sequence is empty - LLAMA_API llama_pos llama_kv_self_seq_pos_min( + DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min( struct llama_context * ctx, - llama_seq_id seq_id); + llama_seq_id seq_id), + "Use llama_memory_seq_pos_min() instead"); // Returns the largest position present in the KV cache for the specified sequence // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache // Return -1 if the sequence is empty - LLAMA_API llama_pos llama_kv_self_seq_pos_max( + DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max( struct llama_context * ctx, - llama_seq_id seq_id); + llama_seq_id seq_id), + "Use llama_memory_seq_pos_max() instead"); // Defragment the KV cache // This will be applied: @@ -784,7 +795,8 @@ extern "C" { "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'"); // Check if the context supports KV cache shifting - LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx); + DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx), + "use llama_memory_can_shift() instead"); // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx), diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index c0590e105..43fa60a80 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -200,7 +200,6 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, - { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" }, { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, @@ -1707,8 +1706,14 @@ static const std::map LLM_TENSOR_INFOS = { LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} std::string LLM_KV::operator()(llm_kv kv) const { - return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix) - : ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); + std::string name = ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); + + if (suffix != nullptr) { + name += "."; + name += suffix; + } + + return name; } std::string LLM_TN_IMPL::str() const { diff --git a/src/llama-arch.h b/src/llama-arch.h index 930cb4eca..f3825528a 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -196,7 +196,6 @@ enum llm_kv { LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_RWKV, LLM_KV_TOKENIZER_CHAT_TEMPLATE, - LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, LLM_KV_TOKENIZER_FIM_PRE_ID, LLM_KV_TOKENIZER_FIM_SUF_ID, LLM_KV_TOKENIZER_FIM_MID_ID, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ea9064204..ed4b1cecc 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -123,7 +123,7 @@ llama_context::llama_context( __func__, n_ctx_per_seq, hparams.n_ctx_train); } - if (!params.swa_full && cparams.n_seq_max > 1) { + if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) { LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n", __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573"); } @@ -422,6 +422,7 @@ llama_memory_t llama_context::get_memory() const { return memory.get(); } +// deprecated void llama_context::kv_self_defrag_sched() { if (!memory) { return; @@ -430,6 +431,7 @@ void llama_context::kv_self_defrag_sched() { memory_force_optimize = true; } +// deprecated bool llama_context::kv_self_update(bool optimize) { if (!memory) { return false; @@ -2053,7 +2055,7 @@ void llama_context::opt_epoch_iter( const uint32_t n_batch = std::min(this->n_batch(), n_ctx); const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); - memory->clear(); + memory->clear(true); for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { batch.n_tokens = n_batch; @@ -2426,8 +2428,12 @@ llama_memory_t llama_get_memory(const struct llama_context * ctx) { return ctx->get_memory(); } -void llama_memory_clear(llama_memory_t mem) { - mem->clear(); +void llama_memory_clear(llama_memory_t mem, bool data) { + if (!mem) { + return; + } + + mem->clear(data); } bool llama_memory_seq_rm( @@ -2435,6 +2441,10 @@ bool llama_memory_seq_rm( llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + if (!mem) { + return true; + } + return mem->seq_rm(seq_id, p0, p1); } @@ -2444,12 +2454,20 @@ void llama_memory_seq_cp( llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (!mem) { + return; + } + mem->seq_cp(seq_id_src, seq_id_dst, p0, p1); } void llama_memory_seq_keep( llama_memory_t mem, llama_seq_id seq_id) { + if (!mem) { + return; + } + mem->seq_keep(seq_id); } @@ -2459,6 +2477,10 @@ void llama_memory_seq_add( llama_pos p0, llama_pos p1, llama_pos delta) { + if (!mem) { + return; + } + mem->seq_add(seq_id, p0, p1, delta); } @@ -2468,22 +2490,38 @@ void llama_memory_seq_div( llama_pos p0, llama_pos p1, int d) { + if (!mem) { + return; + } + mem->seq_div(seq_id, p0, p1, d); } llama_pos llama_memory_seq_pos_min( llama_memory_t mem, llama_seq_id seq_id) { + if (!mem) { + return -1; + } + return mem->seq_pos_min(seq_id); } llama_pos llama_memory_seq_pos_max( llama_memory_t mem, llama_seq_id seq_id) { + if (!mem) { + return -1; + } + return mem->seq_pos_max(seq_id); } bool llama_memory_can_shift(llama_memory_t mem) { + if (!mem) { + return false; + } + return mem->get_can_shift(); } @@ -2534,15 +2572,17 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) { return res; } +// deprecated void llama_kv_self_clear(llama_context * ctx) { auto * kv = llama_get_memory(ctx); if (!kv) { return; } - llama_memory_clear(kv); + llama_memory_clear(kv, true); } +// deprecated bool llama_kv_self_seq_rm( llama_context * ctx, llama_seq_id seq_id, @@ -2556,6 +2596,7 @@ bool llama_kv_self_seq_rm( return llama_memory_seq_rm(kv, seq_id, p0, p1); } +// deprecated void llama_kv_self_seq_cp( llama_context * ctx, llama_seq_id seq_id_src, @@ -2570,6 +2611,7 @@ void llama_kv_self_seq_cp( llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1); } +// deprecated void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { auto * kv = llama_get_memory(ctx); if (!kv) { @@ -2579,6 +2621,7 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { llama_memory_seq_keep(kv, seq_id); } +// deprecated void llama_kv_self_seq_add( llama_context * ctx, llama_seq_id seq_id, @@ -2593,6 +2636,7 @@ void llama_kv_self_seq_add( llama_memory_seq_add(kv, seq_id, p0, p1, delta); } +// deprecated void llama_kv_self_seq_div( llama_context * ctx, llama_seq_id seq_id, @@ -2607,6 +2651,7 @@ void llama_kv_self_seq_div( llama_memory_seq_div(kv, seq_id, p0, p1, d); } +// deprecated llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) { auto * kv = llama_get_memory(ctx); if (!kv) { @@ -2616,6 +2661,7 @@ llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) { return llama_memory_seq_pos_min(kv, seq_id); } +// deprecated llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { auto * kv = llama_get_memory(ctx); if (!kv) { @@ -2631,6 +2677,7 @@ void llama_kv_self_defrag(llama_context * ctx) { ctx->kv_self_defrag_sched(); } +// deprecated bool llama_kv_self_can_shift(const llama_context * ctx) { auto * kv = llama_get_memory(ctx); if (!kv) { diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 50fe0c44d..b7148e054 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -659,6 +659,20 @@ ggml_tensor * llm_graph_context::build_ffn( cur = ggml_mul(ctx0, x0, x1); cb(cur, "ffn_mul", il); } break; + case LLM_FFN_GEGLU: + { + // Split into two equal parts + int64_t split_point = cur->ne[0] / 2; + // TODO: these conts should not be needed + ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); + ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); + + x0 = ggml_gelu(ctx0, x0); + cb(x0, "ffn_gelu", il); + + cur = ggml_mul(ctx0, x0, x1); + cb(cur, "ffn_geglu", il); + } break; } if (gate && type_gate == LLM_FFN_PAR) { diff --git a/src/llama-graph.h b/src/llama-graph.h index 2b1cfa5b7..28da6a522 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -36,6 +36,7 @@ enum llm_ffn_op_type { LLM_FFN_RELU, LLM_FFN_RELU_SQR, LLM_FFN_SWIGLU, + LLM_FFN_GEGLU, }; enum llm_ffn_gate_type { diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp index 7cfb3ea05..87fed8ded 100644 --- a/src/llama-kv-cache-recurrent.cpp +++ b/src/llama-kv-cache-recurrent.cpp @@ -117,18 +117,21 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( } } -void llama_kv_cache_recurrent::clear() { +void llama_kv_cache_recurrent::clear(bool data) { for (int32_t i = 0; i < (int32_t) size; ++i) { cells[i].pos = -1; cells[i].seq_id.clear(); cells[i].src = -1; cells[i].tail = -1; } + head = 0; used = 0; - for (auto & buf : bufs) { - ggml_backend_buffer_clear(buf.get(), 0); + if (data) { + for (auto & buf : bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } } } @@ -723,7 +726,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq if (!res) { if (seq_id == -1) { - clear(); + clear(true); } else { seq_rm(seq_id, -1, -1); } @@ -880,7 +883,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce return false; } - clear(); + clear(true); for (uint32_t i = 0; i < cell_count; ++i) { kv_cell & cell = cells[i]; diff --git a/src/llama-kv-cache-recurrent.h b/src/llama-kv-cache-recurrent.h index cb813dfe8..d1da12256 100644 --- a/src/llama-kv-cache-recurrent.h +++ b/src/llama-kv-cache-recurrent.h @@ -39,7 +39,7 @@ public: llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; - void clear() override; + void clear(bool data) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp index 3aa606c84..28d182654 100644 --- a/src/llama-kv-cache-unified-iswa.cpp +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -52,9 +52,9 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( hparams.n_swa, hparams.swa_type); } -void llama_kv_cache_unified_iswa::clear() { - kv_base->clear(); - kv_swa ->clear(); +void llama_kv_cache_unified_iswa::clear(bool data) { + kv_base->clear(data); + kv_swa ->clear(data); } bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h index 3fabcd6b8..3dbf33ed7 100644 --- a/src/llama-kv-cache-unified-iswa.h +++ b/src/llama-kv-cache-unified-iswa.h @@ -43,7 +43,7 @@ public: bool get_can_shift() const override; - void clear() override; + void clear(bool data) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index d26c0e1d7..4f1fb3ab5 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -129,13 +129,15 @@ llama_kv_cache_unified::llama_kv_cache_unified( } } -void llama_kv_cache_unified::clear() { +void llama_kv_cache_unified::clear(bool data) { cells.reset(); head = 0; - for (auto & buf : bufs) { - ggml_backend_buffer_clear(buf.get(), 0); + if (data) { + for (auto & buf : bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } } } @@ -1319,7 +1321,7 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i if (!res) { if (seq_id == -1) { - clear(); + clear(true); } else { seq_rm(seq_id, -1, -1); } @@ -1500,7 +1502,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell return false; } - clear(); + clear(true); for (uint32_t i = 0; i < cell_count; ++i) { llama_pos pos; diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index d01a9abd7..49f410ef6 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -68,7 +68,7 @@ public: bool get_can_shift() const override; - void clear() override; + void clear(bool data) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; diff --git a/src/llama-memory.h b/src/llama-memory.h index 5993b59be..991aae781 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -90,7 +90,8 @@ struct llama_memory_i { // ops // - virtual void clear() = 0; + // if data == true, the data buffers will also be cleared together with the metadata + virtual void clear(bool data) = 0; virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0; virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index aefc4f24f..730ecd7bd 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13888,7 +13888,7 @@ uint64_t llama_model_size(const llama_model * model) { } const char * llama_model_chat_template(const llama_model * model, const char * name) { - const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N) + const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE) : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE); const auto & it = model->gguf_kv.find(key); if (it == model->gguf_kv.end()) { diff --git a/tools/main/main.cpp b/tools/main/main.cpp index 0dd5a0751..64b157b90 100644 --- a/tools/main/main.cpp +++ b/tools/main/main.cpp @@ -148,6 +148,8 @@ int main(int argc, char ** argv) { return 1; } + auto * mem = llama_get_memory(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); auto chat_templates = common_chat_templates_init(model, params.chat_template); @@ -352,7 +354,7 @@ int main(int argc, char ** argv) { } // remove any "future" tokens that we might have inherited from the previous session - llama_kv_self_seq_rm(ctx, -1, n_matching_session_tokens, -1); + llama_memory_seq_rm(mem, -1, n_matching_session_tokens, -1); } LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n", @@ -600,8 +602,8 @@ int main(int argc, char ** argv) { LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", n_past, n_left, n_ctx, params.n_keep, n_discard); - llama_kv_self_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard); - llama_kv_self_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard); + llama_memory_seq_rm (mem, 0, params.n_keep , params.n_keep + n_discard); + llama_memory_seq_add(mem, 0, params.n_keep + n_discard, n_past, -n_discard); n_past -= n_discard; @@ -624,9 +626,9 @@ int main(int argc, char ** argv) { LOG_DBG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n); LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd); - llama_kv_self_seq_add(ctx, 0, ga_i, n_past, ib*bd); - llama_kv_self_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); - llama_kv_self_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); + llama_memory_seq_add(mem, 0, ga_i, n_past, ib*bd); + llama_memory_seq_div(mem, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); + llama_memory_seq_add(mem, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); n_past -= bd; diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index 40deab5ab..599e682e0 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -342,7 +342,7 @@ int main(int argc, char ** argv) { } if (line == "/clear") { ctx.n_past = 0; - llama_kv_self_seq_rm(ctx.lctx, 0, 1, -1); // keep BOS + llama_memory_seq_rm(llama_get_memory(ctx.lctx), 0, 1, -1); // keep BOS LOG("Chat history cleared\n\n"); continue; } diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index f8e304342..35b9e702f 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 9038df4c3..77dcbc11b 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2006,7 +2006,7 @@ struct server_context { } } - if (!llama_kv_self_can_shift(ctx)) { + if (!llama_memory_can_shift(llama_get_memory(ctx))) { if (params_base.ctx_shift) { params_base.ctx_shift = false; SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled"); @@ -2142,7 +2142,8 @@ struct server_context { // find the slot that has been least recently used if (ret == nullptr) { - int64_t t_last = ggml_time_us(); + int64_t t_last = -1; + for (server_slot & slot : slots) { // skip the slot if it is not available if (slot.is_processing()) { @@ -2150,7 +2151,7 @@ struct server_context { } // select the current slot if the criteria match - if (slot.t_last_used < t_last) { + if (!ret || slot.t_last_used <= t_last) { t_last = slot.t_last_used; ret = &slot; } @@ -2224,7 +2225,7 @@ struct server_context { SRV_DBG("%s", "clearing KV cache\n"); // clear the entire KV cache - llama_kv_self_clear(ctx); + llama_memory_clear(llama_get_memory(ctx), true); clean_kv_cache = false; } @@ -2910,7 +2911,7 @@ struct server_context { // Erase token cache const size_t n_erased = slot->cache_tokens.size(); - llama_kv_self_seq_rm(ctx, slot->id, -1, -1); + llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1); slot->cache_tokens.clear(); auto res = std::make_unique(); @@ -2985,8 +2986,8 @@ struct server_context { SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); - llama_kv_self_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard); - llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard); + llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard); + llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.n_past, -n_discard); // add generated tokens to cache { @@ -3189,8 +3190,8 @@ struct server_context { const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; - llama_kv_self_seq_rm (ctx, slot.id, head_p, head_c); - llama_kv_self_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift); + llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c); + llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); for (size_t i = 0; i < n_match; i++) { slot.cache_tokens.set_token(head_p + i, slot.cache_tokens[head_c + i]); @@ -3212,7 +3213,7 @@ struct server_context { } if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) { - const auto pos_min = llama_kv_self_seq_pos_min(ctx, slot.id); + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); if (pos_min == -1) { SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min); GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); @@ -3247,9 +3248,9 @@ struct server_context { } // keep only the common part - if (!llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1)) { + if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) { // could not partially delete (likely using a non-Transformer model) - llama_kv_self_seq_rm(ctx, slot.id, -1, -1); + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); // there is no common part left slot.n_past = 0; @@ -3589,7 +3590,7 @@ struct server_context { slot.cache_tokens.push_back(id); slot.cache_tokens.insert({ids.begin(), ids.end() - 1}); - llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1); + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1); for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result; diff --git a/tools/server/webui/src/App.tsx b/tools/server/webui/src/App.tsx index 02f1719d3..8dfcf4907 100644 --- a/tools/server/webui/src/App.tsx +++ b/tools/server/webui/src/App.tsx @@ -32,7 +32,7 @@ function AppLayout() { <>