diff --git a/common/arg.cpp b/common/arg.cpp
index 1227aeb2a..81c4005c5 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2254,9 +2254,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
     add_opt(common_arg(
         {"-dt", "--defrag-thold"}, "N",
-        string_format("KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold),
+        string_format("KV cache defragmentation threshold (DEPRECATED)"),
         [](common_params & params, const std::string & value) {
-            params.defrag_thold = std::stof(value);
+            GGML_UNUSED(params);
+            GGML_UNUSED(value);
+            LOG_WRN("DEPRECATED: --defrag-thold is deprecated and no longer necessary to specify\n");
         }
     ).set_env("LLAMA_ARG_DEFRAG_THOLD"));
     add_opt(common_arg(
diff --git a/common/common.cpp b/common/common.cpp
index decabcc2e..fdce1dcde 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1152,7 +1152,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.yarn_orig_ctx     = params.yarn_orig_ctx;
     cparams.pooling_type      = params.pooling_type;
     cparams.attention_type    = params.attention_type;
-    cparams.defrag_thold      = params.defrag_thold;
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.offload_kqv       = !params.no_kv_offload;
diff --git a/common/common.h b/common/common.h
index 614e41a24..390dda5e5 100644
--- a/common/common.h
+++ b/common/common.h
@@ -288,7 +288,6 @@ struct common_params {
     float   yarn_beta_fast = 32.0f; // YaRN low correction dim
     float   yarn_beta_slow = 1.0f;  // YaRN high correction dim
     int32_t yarn_orig_ctx  = 0;     // YaRN original context length
-    float   defrag_thold   = 0.1f;  // KV cache defragmentation threshold

     // offload params
     std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
diff --git a/examples/llama.vim b/examples/llama.vim
index af3fd3935..736802d36 100644
--- a/examples/llama.vim
+++ b/examples/llama.vim
@@ -17,7 +17,7 @@
 "
 " start the llama.cpp server with a FIM-compatible model. for example:
 "
-"   $ llama-server -m {model.gguf} --port 8012 -ngl 99 -fa -dt 0.1 --ubatch-size 512 --batch-size 1024 --cache-reuse 256
+"   $ llama-server -m {model.gguf} --port 8012 -ngl 99 -fa --ubatch-size 512 --batch-size 1024 --cache-reuse 256
 "
 "   --batch-size [512, model max context]
 "
diff --git a/include/llama.h b/include/llama.h
index 662e0971d..c5622cc16 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -312,7 +312,7 @@ extern "C" {
         float    yarn_beta_fast;   // YaRN low correction dim
         float    yarn_beta_slow;   // YaRN high correction dim
         uint32_t yarn_orig_ctx;    // YaRN original context size
-        float    defrag_thold;     // defragment the KV cache if holes/size > thold, <= 0 disabled (default)
+        float    defrag_thold;     // [DEPRECATED] defragment the KV cache if holes/size > thold, <= 0 disabled (default)

         ggml_backend_sched_eval_callback cb_eval;
         void * cb_eval_user_data;
diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py
index 8366f89a0..0141e0a35 100755
--- a/scripts/compare-llama-bench.py
+++ b/scripts/compare-llama-bench.py
@@ -28,7 +28,6 @@ LLAMA_BENCH_DB_FIELDS = [
     "model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads",
     "cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers", "split_mode",
     "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides",
-    "defrag_thold",
     "use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth",
     "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts",
 ]
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index e8e8b3450..18cf25079 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -39,7 +39,6 @@ llama_context::llama_context(
     cparams.yarn_attn_factor = params.yarn_attn_factor;
     cparams.yarn_beta_fast   = params.yarn_beta_fast;
     cparams.yarn_beta_slow   = params.yarn_beta_slow;
-    cparams.defrag_thold     = params.defrag_thold;
     cparams.embeddings       = params.embeddings;
     cparams.offload_kqv      = params.offload_kqv;
     cparams.flash_attn       = params.flash_attn;
@@ -978,7 +977,7 @@ int llama_context::decode(const llama_batch & batch_inp) {

     bool did_optimize = false;

-    // handle any pending defrags/shifts
+    // handle any pending shifts/copies
     memory_update(false);

     llama_memory_context_ptr mctx;
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 38750affc..dbbaba9f6 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -24,7 +24,6 @@ struct llama_cparams {
     float yarn_attn_factor;
     float yarn_beta_fast;
     float yarn_beta_slow;
-    float defrag_thold;

     bool embeddings;
     bool causal_attn;
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index bb490cf9e..70ddd5f4b 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -525,39 +525,11 @@ llama_memory_context_ptr llama_kv_cache::init_full() {
 }

 llama_memory_context_ptr llama_kv_cache::init_update(llama_context * lctx, bool optimize) {
+    GGML_UNUSED(optimize);
+
     bool do_shift = get_has_shift();

-    defrag_info dinfo;
-
-    // see if we need to defrag
-    if (n_stream == 1) {
-        // note : for now do not consider defrag for n_stream > 1
-        const auto & cells = v_cells[seq_to_stream[0]];
-
-        bool do_defrag = optimize;
-
-        const auto thold = lctx->get_cparams().defrag_thold;
-
-        if (!do_defrag && thold > 0.0f) {
-            const auto n_kv = cells.used_max_p1();
-
-            // - do not defrag small contexts (i.e. < 2048 tokens)
-            // - count the padding towards the number of used tokens
-            const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
-
-            if (fragmentation > thold) {
-                LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
-
-                do_defrag = true;
-            }
-        }
-
-        if (do_defrag) {
-            dinfo = defrag_prepare(lctx->graph_max_nodes());
-        }
-    }
-
-    return std::make_unique<llama_kv_cache_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
+    return std::make_unique<llama_kv_cache_context>(this, lctx, do_shift, std::move(sc_info));
 }

 llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ubatch> & ubatches) {
@@ -629,7 +601,7 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama
     return res;
 }

-bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
+bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) {
     bool updated = false;

     auto * sched = lctx->get_sched();
@@ -699,53 +671,6 @@ bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const defrag_in
         }
     }

-    if (!dinfo.empty()) {
-        LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
-
-        // note: for now do not consider defrag for n_stream > 1
-        auto & cells = v_cells[seq_to_stream[0]];
-        auto & head  = v_heads[seq_to_stream[0]];
-
-        // apply moves:
-        {
-            const auto n_kv = dinfo.ids.size();
-
-            for (uint32_t i = 0; i < n_kv; ++i) {
-                assert(dinfo.ids[i] <= n_kv);
-
-                if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
-                    continue;
-                }
-
-                cells.mv(i, dinfo.ids[i]);
-            }
-
-            // reset the head so we can find the first free slot during the next ubatch
-            head = 0;
-        }
-
-        ggml_backend_sched_reset(sched);
-
-        auto * res = lctx->get_gf_res_reserve();
-
-        res->reset();
-
-        auto * gf = build_graph_defrag(res, lctx, dinfo);
-        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
-            LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
-            return updated;
-        }
-
-        res->set_inputs(nullptr);
-
-        if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
-            LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
-            return updated;
-        }
-
-        updated = true;
-    }
-
     return updated;
 }

@@ -1525,283 +1450,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
     return gf;
 }

-ggml_cgraph * llama_kv_cache::build_graph_defrag(
-        llm_graph_result * res,
-        llama_context * lctx,
-        const defrag_info & dinfo) const {
-    auto * ctx = res->get_ctx();
-    auto * gf  = res->get_gf();
-
-    GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
-
-    const auto & cells = v_cells[0];
-
-    const auto & ids = dinfo.ids;
-
-    const auto & cparams = lctx->get_cparams();
-
-#if 0
-    // CPU defrag
-    //
-    // TODO: optimizations are possible:
-    //       - multiple threads
-    //       - avoid copying to the host memory when already there
-    //
-    // likely not worth the effort, as we have ggml_graph based defrag
-
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-
-    const uint32_t kv_size = size;
-
-    std::vector<uint8_t> buf_k;
-    std::vector<uint8_t> buf_v;
-
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
-        const size_t k_size     = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
-
-        const size_t v_size_el = ggml_type_size(v_l[il]->type);
-        const size_t v_size    = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
-
-        buf_k.resize(k_size);
-        buf_v.resize(v_size);
-
-        ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
-        ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
-
-        // batch move [i, i+nm) to [id, id+nm)
-        // note: cells can move only to a lower index
-        for (uint32_t i = 0; i < n_kv; ++i) {
-            const uint32_t id = ids[i];
-
-            if (i == id || id == n_kv) {
-                continue;
-            }
-
-            uint32_t nm = 1;
-
-            while (i + nm < n_kv && ids[i + nm] == id + nm) {
-                nm++;
-            }
-
-            // move keys
-            {
-                const int64_t os =  i*k_size_row;
-                const int64_t od = id*k_size_row;
-
-                memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
-            }
-
-            // move values (note: they are transposed)
-            {
-                const int64_t os =  i;
-                const int64_t od = id;
-
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
-                }
-            }
-
-            i += nm - 1;
-        }
-
-        ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
-        ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
-    }
-#else
-    for (uint32_t i = 0; i < ids.size(); ++i) {
-        const uint32_t id = ids[i];
-
-        if (i == id || id == ids.size()) {
-            continue;
-        }
-
-        uint32_t nm = 1;
-
-        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
-            nm++;
-        }
-
-        for (const auto & layer : layers) {
-            const uint32_t il = layer.il;
-
-            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-            ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k,
-                    n_embd_k_gqa, nm,
-                    ggml_row_size(layer.k->type, n_embd_k_gqa),
-                    ggml_row_size(layer.k->type, n_embd_k_gqa*i));
-
-            ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k,
-                    n_embd_k_gqa, nm,
-                    ggml_row_size(layer.k->type, n_embd_k_gqa),
-                    ggml_row_size(layer.k->type, n_embd_k_gqa*id));
-
-            ggml_tensor * view_v_src;
-            ggml_tensor * view_v_dst;
-
-            if (cparams.flash_attn) {
-                // NOTE: the V cache is not transposed when using flash attention
-                view_v_src = ggml_view_2d(ctx, layer.v,
-                        n_embd_v_gqa, nm,
-                        ggml_row_size(layer.v->type, n_embd_v_gqa),
-                        ggml_row_size(layer.v->type, n_embd_v_gqa*i));
-
-                view_v_dst = ggml_view_2d(ctx, layer.v,
-                        n_embd_v_gqa, nm,
-                        ggml_row_size(layer.v->type, n_embd_v_gqa),
-                        ggml_row_size(layer.v->type, n_embd_v_gqa*id));
-            } else {
-                view_v_src = ggml_view_2d(ctx, layer.v,
-                        nm, n_embd_v_gqa,
-                        ggml_row_size(layer.v->type, cells.size()),
-                        ggml_row_size(layer.v->type, i));
-
-                view_v_dst = ggml_view_2d(ctx, layer.v,
-                        nm, n_embd_v_gqa,
-                        ggml_row_size(layer.v->type, cells.size()),
-                        ggml_row_size(layer.v->type, id));
-            }
-
-            ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
-            ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
-        }
-
-        i += nm - 1;
-    }
-    //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
-#endif
-
-    return gf;
-}
-
-llama_kv_cache::defrag_info llama_kv_cache::defrag_prepare(int32_t n_max_nodes) const {
-    GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
-
-    const auto & cells = v_cells[0];
-
-    const uint32_t n_layer = layers.size();
-
-    const uint32_t n_kv   = cells.used_max_p1();
-    const uint32_t n_used = cells.get_used();
-
-    assert(n_used <= n_kv);
-
-    //const int64_t t_start = ggml_time_us();
-
-    // number of cells moved
-    uint32_t n_moves = 0;
-
-    // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
-    //   - source view, destination view, copy operation
-    //   - x2 for keys and values
-    //const uint32_t max_moves = max_nodes()/(6*n_layer);
-    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
-
-    // determine which KV cells to move where
-    defrag_info res;
-    auto & ids = res.ids;
-
-    ids.resize(n_kv, n_kv);
-
-    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
-        if (!cells.is_empty(i0)) {
-            ids[i0] = i0;
-
-            continue;
-        }
-
-        // found a hole - fill it with data from the end of the cache
-        uint32_t nh = 1;
-
-        // determine the size of the hole
-        while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
-            nh++;
-        }
-
-        uint32_t nf = 0;
-        uint32_t is = n_kv - 1;
-
-        // starting from the end, find nh non-empty cells
-        for (; is > i0; --is) {
-            if (cells.is_empty(is) || ids[is] != n_kv) {
-                continue;
-            }
-
-            // non-empty cell which is not yet moved
-            nf++;
-
-            if (nf == nh) {
-                break;
-            }
-        }
-
-        // this can only happen if `n_used` is not accurate, which would be a bug
-        GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
-
-        nf = 0;
-
-        uint32_t i1 = is;
-
-        // are we moving a continuous block of memory?
-        bool cont = false;
-
-        // should we stop searching for the next move?
-        bool stop = false;
-
-        // go back and move the nf cells to the hole
-        for (; i1 < n_kv; ++i1) {
-            if (cells.is_empty(i1) || ids[i1] != n_kv) {
-                if (n_moves == max_moves) {
-                    stop = true;
-                    break;
-                }
-
-                cont = false;
-                continue;
-            }
-
-            // this cell goes to (i0 + nf)
-            ids[i1] = i0 + nf;
-
-            if (!cont) {
-                n_moves++;
-                cont = true;
-            }
-
-            nf++;
-
-            if (nf == nh) {
-                break;
-            }
-        }
-
-        if (stop || n_moves == max_moves) {
-            break;
-        }
-
-        //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
-
-        i0 += nh - 1;
-    }
-
-    if (n_moves == 0) {
-        return {};
-    }
-
-    LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
-
-    LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
-
-    return res;
-}
-
 bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
     assert(p0 >= 0 && p1 >= 0);

@@ -2300,9 +1948,8 @@ llama_kv_cache_context::llama_kv_cache_context(
         llama_kv_cache * kv,
         llama_context * lctx,
         bool do_shift,
-        defrag_info dinfo,
-        stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) {
-    if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) {
+        stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), sc_info(std::move(sc_info)) {
+    if (!do_shift && this->sc_info.empty()) {
         status = LLAMA_MEMORY_STATUS_NO_UPDATE;
     }
 }
@@ -2330,7 +1977,7 @@ bool llama_kv_cache_context::apply() {

     // no ubatches -> this is a KV cache update
     if (ubatches.empty()) {
-        kv->update(lctx, do_shift, dinfo, sc_info);
+        kv->update(lctx, do_shift, sc_info);

         return true;
     }
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index 5ca618e1b..297a0973d 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -24,17 +24,6 @@ public:
     // this callback is used to filter out layers that should not be included in the cache
     using layer_filter_cb = std::function<bool(int32_t il)>;

-    struct defrag_info {
-        bool empty() const {
-            return ids.empty();
-        }
-
-        // contains information about which cell moves where:
-        //  - cell i moves to ids[i]
-        //  - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
-        std::vector<uint32_t> ids;
-    };
-
     struct stream_copy_info {
         bool empty() const {
             assert(ssrc.size() == sdst.size());
@@ -173,7 +162,7 @@ public:
     // return empty vector on failure
     slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);

-    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
+    bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info);

     // find a slot of kv cells that can hold the ubatch
     // if cont == true, then the slot must be continuous
@@ -254,9 +243,6 @@ private:
     // model layer id -> KV cache layer id
     std::unordered_map<int32_t, int32_t> map_layer_ids;

-    // return non-empty vector if cells have been moved
-    defrag_info defrag_prepare(int32_t n_max_nodes) const;
-
     size_t total_size() const;

     size_t size_k_bytes() const;
@@ -277,11 +263,6 @@ private:
             llm_graph_result * res,
             llama_context * lctx) const;

-    ggml_cgraph * build_graph_defrag(
-            llm_graph_result * res,
-            llama_context * lctx,
-            const defrag_info & dinfo) const;
-
     struct cell_ranges_t {
         uint32_t strm;

@@ -299,7 +280,6 @@ class llama_kv_cache_context : public llama_memory_context_i {
 public:
     // some shorthands
     using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
-    using defrag_info = llama_kv_cache::defrag_info;
     using stream_copy_info = llama_kv_cache::stream_copy_info;

     // used for errors
@@ -314,7 +294,6 @@ public:
             llama_kv_cache * kv,
             llama_context * lctx,
             bool do_shift,
-            defrag_info dinfo,
             stream_copy_info sc_info);

     // used to create a batch procesing context from a batch
@@ -374,8 +353,6 @@ private:

     bool do_shift = false;

-    defrag_info dinfo;
-
     stream_copy_info sc_info;

     //
diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h
index 2651e3033..8f6bf0145 100644
--- a/src/llama-kv-cells.h
+++ b/src/llama-kv-cells.h
@@ -77,24 +77,24 @@ public:
     }

     // move cell isrc to idst (used during defrag)
-    void mv(uint32_t isrc, uint32_t idst) {
-        assert(isrc < pos.size());
-        assert(idst < pos.size());
+    //void mv(uint32_t isrc, uint32_t idst) {
+    //    assert(isrc < pos.size());
+    //    assert(idst < pos.size());

-        assert(pos[idst] == -1);
-        assert(pos[isrc] != -1);
+    //    assert(pos[idst] == -1);
+    //    assert(pos[isrc] != -1);

-        pos  [idst] = pos  [isrc];
-        shift[idst] = shift[isrc];
-        seq  [idst] = seq  [isrc];
+    //    pos  [idst] = pos  [isrc];
+    //    shift[idst] = shift[isrc];
+    //    seq  [idst] = seq  [isrc];

-        pos  [isrc] = -1;
-        shift[isrc] = 0;
-        seq  [isrc].reset();
+    //    pos  [isrc] = -1;
+    //    shift[isrc] = 0;
+    //    seq  [isrc].reset();

-        used.erase (isrc);
-        used.insert(idst);
-    }
+    //    used.erase (isrc);
+    //    used.insert(idst);
+    //}

     // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
     llama_kv_cells cp(uint32_t i, uint32_t n) const {
diff --git a/src/llama-memory.h b/src/llama-memory.h
index 42a7145c2..94d858bcc 100644
--- a/src/llama-memory.h
+++ b/src/llama-memory.h
@@ -77,7 +77,7 @@ struct llama_memory_i {
     // simulate full cache, used for allocating worst-case compute buffers
     virtual llama_memory_context_ptr init_full() = 0;

-    // prepare for any pending memory updates, such as shifts, defrags, etc.
+    // prepare for any pending memory updates, such as shifts, copies, etc.
     // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
     virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;

diff --git a/tools/llama-bench/README.md b/tools/llama-bench/README.md
index 31a273087..bf7fd29c8 100644
--- a/tools/llama-bench/README.md
+++ b/tools/llama-bench/README.md
@@ -43,7 +43,6 @@ test parameters:
   -ub, --ubatch-size <n>                    (default: 512)
   -ctk, --cache-type-k <t>                  (default: f16)
   -ctv, --cache-type-v <t>                  (default: f16)
-  -dt, --defrag-thold <f>                   (default: -1)
   -t, --threads <n>                         (default: system dependent)
   -C, --cpu-mask <hex,hex>                  (default: 0x0)
   --cpu-strict <0|1>                        (default: 0)
diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp
index 10b48c556..9378706a1 100644
--- a/tools/llama-bench/llama-bench.cpp
+++ b/tools/llama-bench/llama-bench.cpp
@@ -245,7 +245,6 @@ struct cmd_params {
     std::vector<int>         n_ubatch;
     std::vector<ggml_type>   type_k;
     std::vector<ggml_type>   type_v;
-    std::vector<float>       defrag_thold;
     std::vector<int>         n_threads;
     std::vector<std::string> cpu_mask;
     std::vector<bool>        cpu_strict;
@@ -282,7 +281,6 @@ static const cmd_params cmd_params_defaults = {
     /* n_ubatch             */ { 512 },
     /* type_k               */ { GGML_TYPE_F16 },
     /* type_v               */ { GGML_TYPE_F16 },
-    /* defrag_thold         */ { -1.0f },
     /* n_threads            */ { cpu_get_num_math() },
     /* cpu_mask             */ { "0x0" },
     /* cpu_strict           */ { false },
@@ -346,8 +344,6 @@ static void print_usage(int /* argc */, char ** argv) {
            join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
     printf("  -ctv, --cache-type-v <t>                  (default: %s)\n",
            join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
-    printf("  -dt, --defrag-thold <f>                   (default: %s)\n",
-           join(cmd_params_defaults.defrag_thold, ",").c_str());
     printf("  -t, --threads <n>                         (default: %s)\n",
            join(cmd_params_defaults.n_threads, ",").c_str());
     printf("  -C, --cpu-mask <hex,hex>                  (default: %s)\n",
@@ -533,13 +529,6 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                 break;
             }
             params.type_v.insert(params.type_v.end(), types.begin(), types.end());
-        } else if (arg == "-dt" || arg == "--defrag-thold") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto p = string_split<float>(argv[i], split_delim);
-            params.defrag_thold.insert(params.defrag_thold.end(), p.begin(), p.end());
         } else if (arg == "-t" || arg == "--threads") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -849,9 +838,6 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.type_v.empty()) {
         params.type_v = cmd_params_defaults.type_v;
     }
-    if (params.defrag_thold.empty()) {
-        params.defrag_thold = cmd_params_defaults.defrag_thold;
-    }
     if (params.n_gpu_layers.empty()) {
         params.n_gpu_layers = cmd_params_defaults.n_gpu_layers;
     }
@@ -910,7 +896,6 @@ struct cmd_params_instance {
     int           n_ubatch;
     ggml_type     type_k;
     ggml_type     type_v;
-    float         defrag_thold;
     int           n_threads;
     std::string   cpu_mask;
     bool          cpu_strict;
@@ -1007,7 +992,6 @@ struct cmd_params_instance {
         cparams.n_ubatch     = n_ubatch;
         cparams.type_k       = type_k;
         cparams.type_v       = type_v;
-        cparams.defrag_thold = defrag_thold;
         cparams.offload_kqv  = !no_kv_offload;
         cparams.flash_attn   = flash_attn;
         cparams.embeddings   = embeddings;
@@ -1037,7 +1021,6 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & nub : params.n_ubatch)
     for (const auto & tk : params.type_k)
     for (const auto & tv : params.type_v)
-    for (const auto & defrag_thold : params.defrag_thold)
     for (const auto & nkvo : params.no_kv_offload)
     for (const auto & fa : params.flash_attn)
     for (const auto & nt : params.n_threads)
@@ -1058,7 +1041,6 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .n_ubatch     = */ nub,
                 /* .type_k       = */ tk,
                 /* .type_v       = */ tv,
-                /* .defrag_thold = */ defrag_thold,
                 /* .n_threads    = */ nt,
                 /* .cpu_mask     = */ cm,
                 /* .cpu_strict   = */ cs,
@@ -1091,7 +1073,6 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .n_ubatch     = */ nub,
                 /* .type_k       = */ tk,
                 /* .type_v       = */ tv,
-                /* .defrag_thold = */ defrag_thold,
                 /* .n_threads    = */ nt,
                 /* .cpu_mask     = */ cm,
                 /* .cpu_strict   = */ cs,
@@ -1124,7 +1105,6 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .n_ubatch     = */ nub,
                 /* .type_k       = */ tk,
                 /* .type_v       = */ tv,
-                /* .defrag_thold = */ defrag_thold,
                 /* .n_threads    = */ nt,
                 /* .cpu_mask     = */ cm,
                 /* .cpu_strict   = */ cs,
@@ -1166,7 +1146,6 @@ struct test {
     int              poll;
     ggml_type        type_k;
     ggml_type        type_v;
-    float            defrag_thold;
     int              n_gpu_layers;
     llama_split_mode split_mode;
     int              main_gpu;
@@ -1201,7 +1180,6 @@ struct test {
         poll           = inst.poll;
        type_k         = inst.type_k;
         type_v         = inst.type_v;
-        defrag_thold   = inst.defrag_thold;
         n_gpu_layers   = inst.n_gpu_layers;
         split_mode     = inst.split_mode;
         main_gpu       = inst.main_gpu;
@@ -1257,7 +1235,6 @@ struct test {
             "model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads",
             "cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers", "split_mode",
             "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides",
-            "defrag_thold",
             "use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth",
             "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts",
         };
@@ -1277,7 +1254,7 @@ struct test {
             field == "use_mmap" || field == "embeddings") {
             return BOOL;
         }
-        if (field == "avg_ts" || field == "stddev_ts" || field == "defrag_thold") {
+        if (field == "avg_ts" || field == "stddev_ts") {
             return FLOAT;
         }
         return STRING;
@@ -1344,7 +1321,6 @@ struct test {
             std::to_string(flash_attn),
             tensor_split_str,
             tensor_buft_overrides_str,
-            std::to_string(defrag_thold),
             std::to_string(use_mmap),
             std::to_string(embeddings),
             std::to_string(no_op_offload),
@@ -1611,9 +1587,6 @@ struct markdown_printer : public printer {
         if (params.type_v.size() > 1 || params.type_v != cmd_params_defaults.type_v) {
             fields.emplace_back("type_v");
         }
-        if (params.defrag_thold.size() > 1 || params.defrag_thold != cmd_params_defaults.defrag_thold) {
-            fields.emplace_back("defrag_thold");
-        }
         if (params.main_gpu.size() > 1 || params.main_gpu != cmd_params_defaults.main_gpu) {
             fields.emplace_back("main_gpu");
         }
diff --git a/tools/server/README.md b/tools/server/README.md
index 86844225f..baf3730ad 100644
--- a/tools/server/README.md
+++ b/tools/server/README.md
@@ -66,7 +66,7 @@ The project is under active development, and we are [looking for feedback and co
 | `-nkvo, --no-kv-offload` | disable KV offload<br/>(env: LLAMA_ARG_NO_KV_OFFLOAD) |
 | `-ctk, --cache-type-k TYPE` | KV cache data type for K<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K) |
 | `-ctv, --cache-type-v TYPE` | KV cache data type for V<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V) |
-| `-dt, --defrag-thold N` | KV cache defragmentation threshold (default: 0.1, < 0 - disabled)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
+| `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
 | `-np, --parallel N` | number of parallel sequences to decode (default: 1)<br/>(env: LLAMA_ARG_N_PARALLEL) |
 | `--mlock` | force system to keep model in RAM rather than swapping or compressing<br/>(env: LLAMA_ARG_MLOCK) |
 | `--no-mmap` | do not memory-map model (slower load but may reduce pageouts if not using mlock)<br/>(env: LLAMA_ARG_NO_MMAP) |
diff --git a/tools/server/bench/bench.py b/tools/server/bench/bench.py
index 5cc6f92ab..0c57a2df0 100644
--- a/tools/server/bench/bench.py
+++ b/tools/server/bench/bench.py
@@ -274,7 +274,6 @@ def start_server_background(args):
     server_args.extend(['--batch-size', args.batch_size])
     server_args.extend(['--ubatch-size', args.ubatch_size])
     server_args.extend(['--n-predict', args.max_tokens * 2])
-    server_args.extend(['--defrag-thold', "0.1"])
     server_args.append('--cont-batching')
     server_args.append('--metrics')
     server_args.append('--flash-attn')
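
Note (not part of the patch): for users of the C API, the practical effect of this change is that `llama_context_params.defrag_thold` is still present but no longer read, and `--defrag-thold` / `LLAMA_ARG_DEFRAG_THOLD` only print a deprecation warning. Below is a minimal caller-side sketch using only the field names visible in the `llama.h` hunk above; the helper function and the concrete values are illustrative, not part of the library.

```cpp
#include "llama.h"

// Hypothetical helper showing how an existing caller can adapt.
static llama_context_params make_ctx_params() {
    llama_context_params cparams = llama_context_default_params();

    cparams.n_ctx = 8192;   // unrelated example setting

    // Before this patch, callers could tune the defrag trigger:
    //     cparams.defrag_thold = 0.1f;
    // After it, the field is ignored (kept only for source compatibility),
    // so the assignment can simply be dropped; leaving it in place is harmless.

    return cparams;
}
```

Command lines follow the same pattern: dropping `-dt 0.1` (as the llama.vim example above now does) is sufficient, and keeping the flag only triggers the new deprecation warning.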