From 0754b7b6fe6109909786bdaa763111167b7410c8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 2 May 2026 18:03:25 +0300 Subject: [PATCH] server : avoid checkpoint data host copies (#22558) * server : avoid checkpoint data host copies * llama : refactor llama_io_read_i --- src/llama-context.cpp | 76 ++++++++++++++++++++++++++++----- src/llama-io.cpp | 9 +++- src/llama-io.h | 4 +- src/llama-kv-cache.cpp | 53 +++++++++++------------ src/llama-memory-recurrent.cpp | 36 ++++++++-------- tools/server/server-context.cpp | 26 +++++------ 6 files changed, 132 insertions(+), 72 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 8126249e1..d584415ee 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2253,6 +2253,28 @@ public: llama_io_write_buffer( uint8_t * p, size_t len) : ptr(p), buf_size(len) {} + ~llama_io_write_buffer() { +#if 1 + // TODO: add backend support to batch tensor_get? or some other way to speed this up + for (const auto & info : winfos) { + ggml_backend_tensor_get(info.tensor, info.ptr, info.offset, info.size); + } +#else + // flush the writes asynchronously + // this helps on Macs, but on other devices - it does not. 
just an example + std::vector<std::future<void>> futures; + futures.reserve(winfos.size()); + for (const auto & info : winfos) { + futures.push_back(std::async(std::launch::async, [info]() { + ggml_backend_tensor_get(info.tensor, info.ptr, info.offset, info.size); + })); + } + for (auto & f : futures) { + f.wait(); + } +#endif + } + void write(const void * src, size_t size) override { if (size > buf_size) { throw std::runtime_error("unexpectedly reached end of buffer"); } @@ -2267,7 +2289,10 @@ public: if (size > buf_size) { throw std::runtime_error("unexpectedly reached end of buffer"); } - ggml_backend_tensor_get(tensor, ptr, offset, size); + + // save the write for later during destruction + winfos.push_back({tensor, ptr, size, offset}); + ptr += size; size_written += size; buf_size -= size; @@ -2281,25 +2306,48 @@ private: uint8_t * ptr; size_t buf_size = 0; size_t size_written = 0; + + struct write_info { + const ggml_tensor * tensor; + uint8_t * ptr; + size_t size; + size_t offset; + }; + std::vector<write_info> winfos; }; class llama_io_read_buffer : public llama_io_read_i { public: llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} - const uint8_t * read(size_t size) override { - const uint8_t * base_ptr = ptr; + ~llama_io_read_buffer() { + // flush the reads + for (const auto & info : rinfos) { + ggml_backend_tensor_set(info.tensor, info.ptr, info.offset, info.size); + } + } + + void read(void * dst, size_t size) override { if (size > buf_size) { throw std::runtime_error("unexpectedly reached end of buffer"); } + memcpy(dst, ptr, size); ptr += size; size_read += size; buf_size -= size; - return base_ptr; } - void read_to(void * dst, size_t size) override { - memcpy(dst, read(size), size); + void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + + // save for later during destruction + rinfos.push_back({tensor, ptr, size, offset}); + + ptr += 
size; + size_read += size; + buf_size -= size; } size_t n_bytes() override { @@ -2310,6 +2358,14 @@ private: const uint8_t * ptr; size_t buf_size = 0; size_t size_read = 0; + + struct read_info { + ggml_tensor * tensor; + const uint8_t * ptr; + size_t size; + size_t offset; + }; + std::vector<read_info> rinfos; }; class llama_io_write_file : public llama_io_write_i { @@ -2341,15 +2397,15 @@ class llama_io_read_file : public llama_io_read_i { public: llama_io_read_file(llama_file * f) : file(f) {} - void read_to(void * dst, size_t size) override { + void read(void * dst, size_t size) override { file->read_raw(dst, size); size_read += size; } - const uint8_t * read(size_t size) override { + void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { temp_buffer.resize(size); - read_to(temp_buffer.data(), size); - return temp_buffer.data(); + read(temp_buffer.data(), size); + ggml_backend_tensor_set(tensor, temp_buffer.data(), offset, size); } size_t n_bytes() override { diff --git a/src/llama-io.cpp b/src/llama-io.cpp index 7ad70d163..5ec463494 100644 --- a/src/llama-io.cpp +++ b/src/llama-io.cpp @@ -1,5 +1,7 @@ #include "llama-io.h" +#include <vector> + void llama_io_write_i::write_string(const std::string & str) { uint32_t str_size = str.size(); @@ -9,7 +11,10 @@ void llama_io_read_i::read_string(std::string & str) { uint32_t str_size; - read_to(&str_size, sizeof(str_size)); + read(&str_size, sizeof(str_size)); - str.assign((const char *) read(str_size), str_size); + std::vector<char> buf(str_size); + read(buf.data(), str_size); + + str.assign(buf.data(), str_size); } diff --git a/src/llama-io.h b/src/llama-io.h index ce9216b83..1e77a2578 100644 --- a/src/llama-io.h +++ b/src/llama-io.h @@ -25,8 +25,8 @@ public: llama_io_read_i() = default; virtual ~llama_io_read_i() = default; - virtual const uint8_t * read(size_t size) = 0; - virtual void read_to(void * dst, size_t size) = 0; + virtual void read(void * dst, 
size_t size) = 0; + virtual void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) = 0; // bytes read so far virtual size_t n_bytes() = 0; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 09102f549..666cca12f 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1900,14 +1900,14 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); uint32_t n_stream_cur; - io.read_to(&n_stream_cur, sizeof(n_stream_cur)); + io.read(&n_stream_cur, sizeof(n_stream_cur)); if (n_stream_cur != n_stream) { throw std::runtime_error("n_stream mismatch"); } for (uint32_t s = 0; s < n_stream; ++s) { uint32_t cell_count; - io.read_to(&cell_count, sizeof(cell_count)); + io.read(&cell_count, sizeof(cell_count)); if (cell_count == 0) { continue; @@ -2082,8 +2082,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 llama_pos pos; uint32_t n_seq_id; - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); + io.read(&pos, sizeof(pos)); + io.read(&n_seq_id, sizeof(n_seq_id)); if (n_seq_id != 1) { LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); @@ -2092,7 +2092,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 if (hparams.n_pos_per_embd() > 1) { llama_kv_cell_ext ext; - io.read_to(&ext, sizeof(ext)); + io.read(&ext, sizeof(ext)); ubatch.pos[i + ubatch.n_tokens] = ext.y; ubatch.pos[i + ubatch.n_tokens*2] = ext.x; @@ -2101,7 +2101,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 // read the sequence id, but directly discard it - we will use dest_seq_id instead { llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); + io.read(&seq_id, sizeof(seq_id)); } ubatch.pos[i] = pos; @@ -2143,20 +2143,20 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 llama_pos pos; 
uint32_t n_seq_id; - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); + io.read(&pos, sizeof(pos)); + io.read(&n_seq_id, sizeof(n_seq_id)); cells.pos_set(i, pos); if (hparams.n_pos_per_embd() > 1) { llama_kv_cell_ext ext; - io.read_to(&ext, sizeof(ext)); + io.read(&ext, sizeof(ext)); cells.ext_set(i, ext); } for (uint32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); + io.read(&seq_id, sizeof(seq_id)); if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) { LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max); @@ -2189,8 +2189,8 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 uint32_t v_trans; uint32_t n_layer; - io.read_to(&v_trans, sizeof(v_trans)); - io.read_to(&n_layer, sizeof(n_layer)); + io.read(&v_trans, sizeof(v_trans)); + io.read(&n_layer, sizeof(n_layer)); if (n_layer != layers.size()) { LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size()); @@ -2217,7 +2217,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read type of key int32_t k_type_i_ref; - io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + io.read(&k_type_i_ref, sizeof(k_type_i_ref)); const int32_t k_type_i = (int32_t) k->type; if (k_type_i != k_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); @@ -2226,7 +2226,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read row size of key uint64_t k_size_row_ref; - io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + io.read(&k_size_row_ref, sizeof(k_size_row_ref)); const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa); if (k_size_row != k_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); @@ -2236,13 +2236,12 @@ 
bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 if (cell_count) { if (sinfo.is_contiguous()) { // Fast path: contiguous cells, single memcpy - ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row); + io.read_tensor(k, sinfo.head() * k_size_row, cell_count * k_size_row); } else { // Slow path: scatter to non-contiguous positions - const void * src = io.read(cell_count * k_size_row); for (uint32_t i = 0; i < cell_count; ++i) { const size_t dst_offset = sinfo.idxs[0][i] * k_size_row; - ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row); + io.read_tensor(k, dst_offset, k_size_row); } } } @@ -2261,7 +2260,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read type of value int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + io.read(&v_type_i_ref, sizeof(v_type_i_ref)); const int32_t v_type_i = (int32_t) v->type; if (v_type_i != v_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); @@ -2270,7 +2269,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read row size of value uint64_t v_size_row_ref; - io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + io.read(&v_size_row_ref, sizeof(v_size_row_ref)); const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa); if (v_size_row != v_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); @@ -2280,13 +2279,12 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 if (cell_count) { if (sinfo.is_contiguous()) { // Fast path: contiguous cells, single memcpy - ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row); + io.read_tensor(v, sinfo.head() * v_size_row, 
cell_count * v_size_row); } else { // Slow path: scatter to non-contiguous positions - const void * src = io.read(cell_count * v_size_row); for (uint32_t i = 0; i < cell_count; ++i) { const size_t dst_offset = sinfo.idxs[0][i] * v_size_row; - ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row); + io.read_tensor(v, dst_offset, v_size_row); } } } @@ -2305,7 +2303,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read type of value int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + io.read(&v_type_i_ref, sizeof(v_type_i_ref)); const int32_t v_type_i = (int32_t) v->type; if (v_type_i != v_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); @@ -2314,7 +2312,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read element size of value uint32_t v_size_el_ref; - io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + io.read(&v_size_el_ref, sizeof(v_size_el_ref)); const size_t v_size_el = ggml_type_size(v->type); if (v_size_el != v_size_el_ref) { LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); @@ -2323,7 +2321,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read GQA embedding size uint32_t n_embd_v_gqa_ref; - io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + io.read(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); if (n_embd_v_gqa != n_embd_v_gqa_ref) { LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); return false; @@ -2335,15 +2333,14 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 const uint32_t h = sinfo.head(); for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { const size_t dst_offset = (h + j * cells.size()) * v_size_el; - 
ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + io.read_tensor(v, dst_offset, cell_count * v_size_el); } } else { // Slow path: scatter to non-contiguous positions for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const void * src = io.read(cell_count * v_size_el); for (uint32_t i = 0; i < cell_count; ++i) { const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el; - ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el); + io.read_tensor(v, dst_offset, v_size_el); } } } diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 9287fe45e..4b4fdeb6d 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -743,7 +743,7 @@ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_i GGML_UNUSED(flags); uint32_t cell_count; - io.read_to(&cell_count, sizeof(cell_count)); + io.read(&cell_count, sizeof(cell_count)); bool res = true; @@ -879,8 +879,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell llama_pos pos; uint32_t n_seq_id; - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); + io.read(&pos, sizeof(pos)); + io.read(&n_seq_id, sizeof(n_seq_id)); if (n_seq_id != 0) { LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); @@ -920,14 +920,14 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell llama_pos pos; uint32_t n_seq_id; - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); + io.read(&pos, sizeof(pos)); + io.read(&n_seq_id, sizeof(n_seq_id)); cell.pos = pos; for (uint32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); + io.read(&seq_id, sizeof(seq_id)); if (seq_id < 0 || (uint32_t) seq_id >= this->n_seq_max) { LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, this->n_seq_max); @@ -961,8 +961,8 @@ 
bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) { uint32_t s_trans; uint32_t n_layer; - io.read_to(&s_trans, sizeof(s_trans)); - io.read_to(&n_layer, sizeof(n_layer)); + io.read(&s_trans, sizeof(s_trans)); + io.read(&n_layer, sizeof(n_layer)); if (n_layer != hparams.n_layer) { LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); @@ -984,7 +984,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read type of key int32_t r_type_i_ref; - io.read_to(&r_type_i_ref, sizeof(r_type_i_ref)); + io.read(&r_type_i_ref, sizeof(r_type_i_ref)); const int32_t r_type_i = (int32_t) r_l[il]->type; if (r_type_i != r_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il); @@ -993,7 +993,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read row size of key uint64_t r_size_row_ref; - io.read_to(&r_size_row_ref, sizeof(r_size_row_ref)); + io.read(&r_size_row_ref, sizeof(r_size_row_ref)); const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); if (r_size_row != r_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il); @@ -1002,7 +1002,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // Read and set the keys for the whole cell range - ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row); + io.read_tensor(r_l[il], head * r_size_row, cell_count * r_size_row); } } @@ -1013,7 +1013,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read type of value int32_t s_type_i_ref; - io.read_to(&s_type_i_ref, sizeof(s_type_i_ref)); + 
io.read(&s_type_i_ref, sizeof(s_type_i_ref)); const int32_t s_type_i = (int32_t)s_l[il]->type; if (s_type_i != s_type_i_ref) { @@ -1023,7 +1023,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read row size of value uint64_t s_size_row_ref; - io.read_to(&s_size_row_ref, sizeof(s_size_row_ref)); + io.read(&s_size_row_ref, sizeof(s_size_row_ref)); const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); if (s_size_row != s_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il); @@ -1032,7 +1032,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // Read and set the values for the whole cell range - ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row); + io.read_tensor(s_l[il], head * s_size_row, cell_count * s_size_row); } } } else { @@ -1045,7 +1045,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read type of value int32_t s_type_i_ref; - io.read_to(&s_type_i_ref, sizeof(s_type_i_ref)); + io.read(&s_type_i_ref, sizeof(s_type_i_ref)); const int32_t s_type_i = (int32_t)s_l[il]->type; if (s_type_i != s_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il); @@ -1054,7 +1054,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read element size of value uint32_t s_size_el_ref; - io.read_to(&s_size_el_ref, sizeof(s_size_el_ref)); + io.read(&s_size_el_ref, sizeof(s_size_el_ref)); const size_t s_size_el = ggml_type_size(s_l[il]->type); if (s_size_el != s_size_el_ref) { LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il); @@ -1063,7 +1063,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell 
// Read state embedding size uint32_t n_embd_s_ref; - io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref)); + io.read(&n_embd_s_ref, sizeof(n_embd_s_ref)); if (n_embd_s != n_embd_s_ref) { LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il); return false; @@ -1073,7 +1073,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // For each row in the transposed matrix, read the values for the whole cell range for (uint32_t j = 0; j < n_embd_s; ++j) { const size_t dst_offset = (head + j * size) * s_size_el; - ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el); + io.read_tensor(s_l[il], dst_offset, cell_count * s_size_el); } } } diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 2d3003f03..d21e9c2ee 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -36,7 +36,7 @@ using json = nlohmann::ordered_json; constexpr int HTTP_POLLING_SECONDS = 1; -static server_prompt_checkpoint server_get_checkpoint(llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1) { +static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1) { if (pos_min == -1) { pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id); } @@ -46,19 +46,15 @@ static server_prompt_checkpoint server_get_checkpoint(llama_context * ctx, int i const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - auto cur = server_prompt_checkpoint { - /*.pos_min = */ pos_min, - /*.pos_max = */ pos_max, - /*.n_tokens = */ n_tokens, - /*.data = */ std::vector<uint8_t>(checkpoint_size), - }; + ckpt.pos_min = pos_min; + ckpt.pos_max = pos_max; + ckpt.n_tokens = n_tokens; + ckpt.data.resize(checkpoint_size); - const size_t n = 
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data.data(), checkpoint_size, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); if (n != checkpoint_size) { GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n); } - - return cur; } // state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 @@ -364,7 +360,12 @@ struct server_slot { if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { const auto n_tokens = prompt.tokens.size(); - spec_ckpt = server_get_checkpoint(ctx, this->id, n_tokens); + //const int64_t t_start = ggml_time_us(); + + server_prompt_checkpoint_update(spec_ckpt, ctx, this->id, n_tokens); + + //const int64_t t_total = ggml_time_us() - t_start; + //printf("checkpoint total: %f ms\n", t_total / 1000.0); SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n", spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024); @@ -1836,7 +1837,8 @@ private: slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); } - const auto & cur = slot.prompt.checkpoints.emplace_back(server_get_checkpoint(ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max)); + auto & cur = slot.prompt.checkpoints.emplace_back(); + server_prompt_checkpoint_update(cur, ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max); SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",