mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-17 04:09:19 +00:00
server : avoid checkpoint data host copies (#22558)
* server : avoid checkpoint data host copies * llama : refactor llama_io_read_i
This commit is contained in:
parent
09294365a9
commit
0754b7b6fe
6 changed files with 132 additions and 72 deletions
|
|
@ -2253,6 +2253,28 @@ public:
|
|||
llama_io_write_buffer(
|
||||
uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
|
||||
|
||||
~llama_io_write_buffer() {
|
||||
#if 1
|
||||
// TODO: add backend support to batch tensor_get? or some other way to speed this up
|
||||
for (const auto & info : winfos) {
|
||||
ggml_backend_tensor_get(info.tensor, info.ptr, info.offset, info.size);
|
||||
}
|
||||
#else
|
||||
// flush the writes asynchronously
|
||||
// this helps on Macs, but on other devices - it does not. just an example
|
||||
std::vector<std::future<void>> futures;
|
||||
futures.reserve(winfos.size());
|
||||
for (const auto & info : winfos) {
|
||||
futures.push_back(std::async(std::launch::async, [info]() {
|
||||
ggml_backend_tensor_get(info.tensor, info.ptr, info.offset, info.size);
|
||||
}));
|
||||
}
|
||||
for (auto & f : futures) {
|
||||
f.wait();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void write(const void * src, size_t size) override {
|
||||
if (size > buf_size) {
|
||||
throw std::runtime_error("unexpectedly reached end of buffer");
|
||||
|
|
@ -2267,7 +2289,10 @@ public:
|
|||
if (size > buf_size) {
|
||||
throw std::runtime_error("unexpectedly reached end of buffer");
|
||||
}
|
||||
ggml_backend_tensor_get(tensor, ptr, offset, size);
|
||||
|
||||
// save the write for later during destruction
|
||||
winfos.push_back({tensor, ptr, size, offset});
|
||||
|
||||
ptr += size;
|
||||
size_written += size;
|
||||
buf_size -= size;
|
||||
|
|
@ -2281,25 +2306,48 @@ private:
|
|||
uint8_t * ptr;
|
||||
size_t buf_size = 0;
|
||||
size_t size_written = 0;
|
||||
|
||||
struct write_info {
|
||||
const ggml_tensor * tensor;
|
||||
uint8_t * ptr;
|
||||
size_t size;
|
||||
size_t offset;
|
||||
};
|
||||
std::vector<write_info> winfos;
|
||||
};
|
||||
|
||||
class llama_io_read_buffer : public llama_io_read_i {
|
||||
public:
|
||||
llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
|
||||
|
||||
const uint8_t * read(size_t size) override {
|
||||
const uint8_t * base_ptr = ptr;
|
||||
~llama_io_read_buffer() {
|
||||
// flush the reads
|
||||
for (const auto & info : rinfos) {
|
||||
ggml_backend_tensor_set(info.tensor, info.ptr, info.offset, info.size);
|
||||
}
|
||||
}
|
||||
|
||||
void read(void * dst, size_t size) override {
|
||||
if (size > buf_size) {
|
||||
throw std::runtime_error("unexpectedly reached end of buffer");
|
||||
}
|
||||
memcpy(dst, ptr, size);
|
||||
ptr += size;
|
||||
size_read += size;
|
||||
buf_size -= size;
|
||||
return base_ptr;
|
||||
}
|
||||
|
||||
void read_to(void * dst, size_t size) override {
|
||||
memcpy(dst, read(size), size);
|
||||
void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override {
|
||||
if (size > buf_size) {
|
||||
throw std::runtime_error("unexpectedly reached end of buffer");
|
||||
}
|
||||
|
||||
// save for later during destruction
|
||||
rinfos.push_back({tensor, ptr, size, offset});
|
||||
|
||||
ptr += size;
|
||||
size_read += size;
|
||||
buf_size -= size;
|
||||
}
|
||||
|
||||
size_t n_bytes() override {
|
||||
|
|
@ -2310,6 +2358,14 @@ private:
|
|||
const uint8_t * ptr;
|
||||
size_t buf_size = 0;
|
||||
size_t size_read = 0;
|
||||
|
||||
struct read_info {
|
||||
ggml_tensor * tensor;
|
||||
const uint8_t * ptr;
|
||||
size_t size;
|
||||
size_t offset;
|
||||
};
|
||||
std::vector<read_info> rinfos;
|
||||
};
|
||||
|
||||
class llama_io_write_file : public llama_io_write_i {
|
||||
|
|
@ -2341,15 +2397,15 @@ class llama_io_read_file : public llama_io_read_i {
|
|||
public:
|
||||
llama_io_read_file(llama_file * f) : file(f) {}
|
||||
|
||||
void read_to(void * dst, size_t size) override {
|
||||
void read(void * dst, size_t size) override {
|
||||
file->read_raw(dst, size);
|
||||
size_read += size;
|
||||
}
|
||||
|
||||
const uint8_t * read(size_t size) override {
|
||||
void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override {
|
||||
temp_buffer.resize(size);
|
||||
read_to(temp_buffer.data(), size);
|
||||
return temp_buffer.data();
|
||||
read(temp_buffer.data(), size);
|
||||
ggml_backend_tensor_set(tensor, temp_buffer.data(), offset, size);
|
||||
}
|
||||
|
||||
size_t n_bytes() override {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
#include "llama-io.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
void llama_io_write_i::write_string(const std::string & str) {
|
||||
uint32_t str_size = str.size();
|
||||
|
||||
|
|
@ -9,7 +11,10 @@ void llama_io_write_i::write_string(const std::string & str) {
|
|||
|
||||
void llama_io_read_i::read_string(std::string & str) {
|
||||
uint32_t str_size;
|
||||
read_to(&str_size, sizeof(str_size));
|
||||
read(&str_size, sizeof(str_size));
|
||||
|
||||
str.assign((const char *) read(str_size), str_size);
|
||||
std::vector<char> buf(str_size);
|
||||
read(buf.data(), str_size);
|
||||
|
||||
str.assign(buf.data(), str_size);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,8 +25,8 @@ public:
|
|||
llama_io_read_i() = default;
|
||||
virtual ~llama_io_read_i() = default;
|
||||
|
||||
virtual const uint8_t * read(size_t size) = 0;
|
||||
virtual void read_to(void * dst, size_t size) = 0;
|
||||
virtual void read(void * dst, size_t size) = 0;
|
||||
virtual void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) = 0;
|
||||
|
||||
// bytes read so far
|
||||
virtual size_t n_bytes() = 0;
|
||||
|
|
|
|||
|
|
@ -1900,14 +1900,14 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama
|
|||
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
|
||||
|
||||
uint32_t n_stream_cur;
|
||||
io.read_to(&n_stream_cur, sizeof(n_stream_cur));
|
||||
io.read(&n_stream_cur, sizeof(n_stream_cur));
|
||||
if (n_stream_cur != n_stream) {
|
||||
throw std::runtime_error("n_stream mismatch");
|
||||
}
|
||||
|
||||
for (uint32_t s = 0; s < n_stream; ++s) {
|
||||
uint32_t cell_count;
|
||||
io.read_to(&cell_count, sizeof(cell_count));
|
||||
io.read(&cell_count, sizeof(cell_count));
|
||||
|
||||
if (cell_count == 0) {
|
||||
continue;
|
||||
|
|
@ -2082,8 +2082,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
|||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
|
||||
io.read_to(&pos, sizeof(pos));
|
||||
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
io.read(&pos, sizeof(pos));
|
||||
io.read(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
if (n_seq_id != 1) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
||||
|
|
@ -2092,7 +2092,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
|||
|
||||
if (hparams.n_pos_per_embd() > 1) {
|
||||
llama_kv_cell_ext ext;
|
||||
io.read_to(&ext, sizeof(ext));
|
||||
io.read(&ext, sizeof(ext));
|
||||
|
||||
ubatch.pos[i + ubatch.n_tokens] = ext.y;
|
||||
ubatch.pos[i + ubatch.n_tokens*2] = ext.x;
|
||||
|
|
@ -2101,7 +2101,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
|||
// read the sequence id, but directly discard it - we will use dest_seq_id instead
|
||||
{
|
||||
llama_seq_id seq_id;
|
||||
io.read_to(&seq_id, sizeof(seq_id));
|
||||
io.read(&seq_id, sizeof(seq_id));
|
||||
}
|
||||
|
||||
ubatch.pos[i] = pos;
|
||||
|
|
@ -2143,20 +2143,20 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
|||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
|
||||
io.read_to(&pos, sizeof(pos));
|
||||
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
io.read(&pos, sizeof(pos));
|
||||
io.read(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
cells.pos_set(i, pos);
|
||||
|
||||
if (hparams.n_pos_per_embd() > 1) {
|
||||
llama_kv_cell_ext ext;
|
||||
io.read_to(&ext, sizeof(ext));
|
||||
io.read(&ext, sizeof(ext));
|
||||
cells.ext_set(i, ext);
|
||||
}
|
||||
|
||||
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
||||
llama_seq_id seq_id;
|
||||
io.read_to(&seq_id, sizeof(seq_id));
|
||||
io.read(&seq_id, sizeof(seq_id));
|
||||
|
||||
if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
|
||||
|
|
@ -2189,8 +2189,8 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|||
uint32_t v_trans;
|
||||
uint32_t n_layer;
|
||||
|
||||
io.read_to(&v_trans, sizeof(v_trans));
|
||||
io.read_to(&n_layer, sizeof(n_layer));
|
||||
io.read(&v_trans, sizeof(v_trans));
|
||||
io.read(&n_layer, sizeof(n_layer));
|
||||
|
||||
if (n_layer != layers.size()) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
|
||||
|
|
@ -2217,7 +2217,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|||
|
||||
// Read type of key
|
||||
int32_t k_type_i_ref;
|
||||
io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
|
||||
io.read(&k_type_i_ref, sizeof(k_type_i_ref));
|
||||
const int32_t k_type_i = (int32_t) k->type;
|
||||
if (k_type_i != k_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
|
||||
|
|
@ -2226,7 +2226,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|||
|
||||
// Read row size of key
|
||||
uint64_t k_size_row_ref;
|
||||
io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
|
||||
io.read(&k_size_row_ref, sizeof(k_size_row_ref));
|
||||
const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
|
||||
if (k_size_row != k_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
|
||||
|
|
@ -2236,13 +2236,12 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|||
if (cell_count) {
|
||||
if (sinfo.is_contiguous()) {
|
||||
// Fast path: contiguous cells, single memcpy
|
||||
ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row);
|
||||
io.read_tensor(k, sinfo.head() * k_size_row, cell_count * k_size_row);
|
||||
} else {
|
||||
// Slow path: scatter to non-contiguous positions
|
||||
const void * src = io.read(cell_count * k_size_row);
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
const size_t dst_offset = sinfo.idxs[0][i] * k_size_row;
|
||||
ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row);
|
||||
io.read_tensor(k, dst_offset, k_size_row);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2261,7 +2260,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
||||
io.read(&v_type_i_ref, sizeof(v_type_i_ref));
|
||||
const int32_t v_type_i = (int32_t) v->type;
|
||||
if (v_type_i != v_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
||||
|
|
@ -2270,7 +2269,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|||
|
||||
// Read row size of value
|
||||
uint64_t v_size_row_ref;
|
||||
io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
|
||||
io.read(&v_size_row_ref, sizeof(v_size_row_ref));
|
||||
const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
|
||||
if (v_size_row != v_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
|
||||
|
|
@ -2280,13 +2279,12 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|||
if (cell_count) {
|
||||
if (sinfo.is_contiguous()) {
|
||||
// Fast path: contiguous cells, single memcpy
|
||||
ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row);
|
||||
io.read_tensor(v, sinfo.head() * v_size_row, cell_count * v_size_row);
|
||||
} else {
|
||||
// Slow path: scatter to non-contiguous positions
|
||||
const void * src = io.read(cell_count * v_size_row);
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
const size_t dst_offset = sinfo.idxs[0][i] * v_size_row;
|
||||
ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row);
|
||||
io.read_tensor(v, dst_offset, v_size_row);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2305,7 +2303,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
||||
io.read(&v_type_i_ref, sizeof(v_type_i_ref));
|
||||
const int32_t v_type_i = (int32_t) v->type;
|
||||
if (v_type_i != v_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
||||
|
|
@ -2314,7 +2312,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|||
|
||||
// Read element size of value
|
||||
uint32_t v_size_el_ref;
|
||||
io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
|
||||
io.read(&v_size_el_ref, sizeof(v_size_el_ref));
|
||||
const size_t v_size_el = ggml_type_size(v->type);
|
||||
if (v_size_el != v_size_el_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
|
||||
|
|
@ -2323,7 +2321,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|||
|
||||
// Read GQA embedding size
|
||||
uint32_t n_embd_v_gqa_ref;
|
||||
io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
|
||||
io.read(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
|
||||
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
|
||||
return false;
|
||||
|
|
@ -2335,15 +2333,14 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|||
const uint32_t h = sinfo.head();
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
const size_t dst_offset = (h + j * cells.size()) * v_size_el;
|
||||
ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
||||
io.read_tensor(v, dst_offset, cell_count * v_size_el);
|
||||
}
|
||||
} else {
|
||||
// Slow path: scatter to non-contiguous positions
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
const void * src = io.read(cell_count * v_size_el);
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el;
|
||||
ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el);
|
||||
io.read_tensor(v, dst_offset, v_size_el);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -743,7 +743,7 @@ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_i
|
|||
GGML_UNUSED(flags);
|
||||
|
||||
uint32_t cell_count;
|
||||
io.read_to(&cell_count, sizeof(cell_count));
|
||||
io.read(&cell_count, sizeof(cell_count));
|
||||
|
||||
bool res = true;
|
||||
|
||||
|
|
@ -879,8 +879,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
|
||||
io.read_to(&pos, sizeof(pos));
|
||||
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
io.read(&pos, sizeof(pos));
|
||||
io.read(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
if (n_seq_id != 0) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
||||
|
|
@ -920,14 +920,14 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
|
||||
io.read_to(&pos, sizeof(pos));
|
||||
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
io.read(&pos, sizeof(pos));
|
||||
io.read(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
cell.pos = pos;
|
||||
|
||||
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
||||
llama_seq_id seq_id;
|
||||
io.read_to(&seq_id, sizeof(seq_id));
|
||||
io.read(&seq_id, sizeof(seq_id));
|
||||
|
||||
if (seq_id < 0 || (uint32_t) seq_id >= this->n_seq_max) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, this->n_seq_max);
|
||||
|
|
@ -961,8 +961,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|||
bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
|
||||
uint32_t s_trans;
|
||||
uint32_t n_layer;
|
||||
io.read_to(&s_trans, sizeof(s_trans));
|
||||
io.read_to(&n_layer, sizeof(n_layer));
|
||||
io.read(&s_trans, sizeof(s_trans));
|
||||
io.read(&n_layer, sizeof(n_layer));
|
||||
|
||||
if (n_layer != hparams.n_layer) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
|
||||
|
|
@ -984,7 +984,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
|||
|
||||
// Read type of key
|
||||
int32_t r_type_i_ref;
|
||||
io.read_to(&r_type_i_ref, sizeof(r_type_i_ref));
|
||||
io.read(&r_type_i_ref, sizeof(r_type_i_ref));
|
||||
const int32_t r_type_i = (int32_t) r_l[il]->type;
|
||||
if (r_type_i != r_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
|
||||
|
|
@ -993,7 +993,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
|||
|
||||
// Read row size of key
|
||||
uint64_t r_size_row_ref;
|
||||
io.read_to(&r_size_row_ref, sizeof(r_size_row_ref));
|
||||
io.read(&r_size_row_ref, sizeof(r_size_row_ref));
|
||||
const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
|
||||
if (r_size_row != r_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
|
||||
|
|
@ -1002,7 +1002,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
|||
|
||||
if (cell_count) {
|
||||
// Read and set the keys for the whole cell range
|
||||
ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row);
|
||||
io.read_tensor(r_l[il], head * r_size_row, cell_count * r_size_row);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1013,7 +1013,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
|||
|
||||
// Read type of value
|
||||
int32_t s_type_i_ref;
|
||||
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||
io.read(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
||||
|
||||
if (s_type_i != s_type_i_ref) {
|
||||
|
|
@ -1023,7 +1023,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
|||
|
||||
// Read row size of value
|
||||
uint64_t s_size_row_ref;
|
||||
io.read_to(&s_size_row_ref, sizeof(s_size_row_ref));
|
||||
io.read(&s_size_row_ref, sizeof(s_size_row_ref));
|
||||
const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
|
||||
if (s_size_row != s_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
|
||||
|
|
@ -1032,7 +1032,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
|||
|
||||
if (cell_count) {
|
||||
// Read and set the values for the whole cell range
|
||||
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row);
|
||||
io.read_tensor(s_l[il], head * s_size_row, cell_count * s_size_row);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
|
@ -1045,7 +1045,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
|||
|
||||
// Read type of value
|
||||
int32_t s_type_i_ref;
|
||||
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||
io.read(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
||||
if (s_type_i != s_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
|
||||
|
|
@ -1054,7 +1054,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
|||
|
||||
// Read element size of value
|
||||
uint32_t s_size_el_ref;
|
||||
io.read_to(&s_size_el_ref, sizeof(s_size_el_ref));
|
||||
io.read(&s_size_el_ref, sizeof(s_size_el_ref));
|
||||
const size_t s_size_el = ggml_type_size(s_l[il]->type);
|
||||
if (s_size_el != s_size_el_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il);
|
||||
|
|
@ -1063,7 +1063,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
|||
|
||||
// Read state embedding size
|
||||
uint32_t n_embd_s_ref;
|
||||
io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref));
|
||||
io.read(&n_embd_s_ref, sizeof(n_embd_s_ref));
|
||||
if (n_embd_s != n_embd_s_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il);
|
||||
return false;
|
||||
|
|
@ -1073,7 +1073,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
|||
// For each row in the transposed matrix, read the values for the whole cell range
|
||||
for (uint32_t j = 0; j < n_embd_s; ++j) {
|
||||
const size_t dst_offset = (head + j * size) * s_size_el;
|
||||
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el);
|
||||
io.read_tensor(s_l[il], dst_offset, cell_count * s_size_el);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ using json = nlohmann::ordered_json;
|
|||
|
||||
constexpr int HTTP_POLLING_SECONDS = 1;
|
||||
|
||||
static server_prompt_checkpoint server_get_checkpoint(llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1) {
|
||||
static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1) {
|
||||
if (pos_min == -1) {
|
||||
pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id);
|
||||
}
|
||||
|
|
@ -46,19 +46,15 @@ static server_prompt_checkpoint server_get_checkpoint(llama_context * ctx, int i
|
|||
|
||||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
auto cur = server_prompt_checkpoint {
|
||||
/*.pos_min = */ pos_min,
|
||||
/*.pos_max = */ pos_max,
|
||||
/*.n_tokens = */ n_tokens,
|
||||
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
||||
};
|
||||
ckpt.pos_min = pos_min;
|
||||
ckpt.pos_max = pos_max;
|
||||
ckpt.n_tokens = n_tokens;
|
||||
ckpt.data.resize(checkpoint_size);
|
||||
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data.data(), checkpoint_size, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
if (n != checkpoint_size) {
|
||||
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
|
||||
}
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
|
||||
|
|
@ -364,7 +360,12 @@ struct server_slot {
|
|||
if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
||||
const auto n_tokens = prompt.tokens.size();
|
||||
|
||||
spec_ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
|
||||
//const int64_t t_start = ggml_time_us();
|
||||
|
||||
server_prompt_checkpoint_update(spec_ckpt, ctx, this->id, n_tokens);
|
||||
|
||||
//const int64_t t_total = ggml_time_us() - t_start;
|
||||
//printf("checkpoint total: %f ms\n", t_total / 1000.0);
|
||||
|
||||
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
|
||||
spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
|
||||
|
|
@ -1836,7 +1837,8 @@ private:
|
|||
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
|
||||
}
|
||||
|
||||
const auto & cur = slot.prompt.checkpoints.emplace_back(server_get_checkpoint(ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max));
|
||||
auto & cur = slot.prompt.checkpoints.emplace_back();
|
||||
server_prompt_checkpoint_update(cur, ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max);
|
||||
|
||||
SLT_WRN(slot,
|
||||
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue