Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-11 09:34:37 +00:00)
server : add SWA checkpoints (#15293)
* server : add SWA checkpoints ggml-ci
* cont : server clean-up
* server : handle state restore fails
* llama : add extended llama_state_seq_ API
* server : do not make checkpoints if --swa-full ggml-ci
* llama : remove flags value for NONE
* server : configure number of SWA checkpoints with CLI arg ggml-ci
* args : fix scope of new argument
Parent: 3973163bff
Commit: d32e03f449
15 changed files with 206 additions and 54 deletions
@@ -1507,6 +1507,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.swa_full = true;
         }
     ).set_env("LLAMA_ARG_SWA_FULL"));
+    add_opt(common_arg(
+        {"--swa-checkpoints"}, "N",
+        string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
+        [](common_params & params, int value) {
+            params.n_swa_checkpoints = value;
+        }
+    ).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--kv-unified", "-kvu"},
         string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
@@ -413,11 +413,12 @@ struct common_params {
     std::string cls_sep = "\t"; // separator of classification sequences

     // server params
     int32_t port = 8080; // server listens on this network port
     int32_t timeout_read = 600; // http read timeout in seconds
     int32_t timeout_write = timeout_read; // http write timeout in seconds
     int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
     int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
+    int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot

     std::string hostname = "127.0.0.1";
     std::string public_path = ""; // NOLINT
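Together with the argument registration above, this gives the server a `--swa-checkpoints N` option (environment variable `LLAMA_ARG_SWA_CHECKPOINTS`, registered for the server example only). With no override the server keeps up to three SWA checkpoints per slot, and setting the value to 0 disables checkpointing entirely, since the server only creates checkpoints when `n_swa_checkpoints > 0`.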
@@ -870,6 +870,29 @@ extern "C" {
            size_t   n_token_capacity,
            size_t * n_token_count_out);

+#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1
+
+    typedef uint32_t llama_state_seq_flags;
+
+    LLAMA_API size_t llama_state_seq_get_size_ext(
+            struct llama_context * ctx,
+            llama_seq_id seq_id,
+            llama_state_seq_flags flags);
+
+    LLAMA_API size_t llama_state_seq_get_data_ext(
+            struct llama_context * ctx,
+            uint8_t * dst,
+            size_t size,
+            llama_seq_id seq_id,
+            llama_state_seq_flags flags);
+
+    LLAMA_API size_t llama_state_seq_set_data_ext(
+            struct llama_context * ctx,
+            const uint8_t * src,
+            size_t size,
+            llama_seq_id dest_seq_id,
+            llama_state_seq_flags flags);
+
     //
     // Decoding
     //
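The new declarations make it possible to serialize or restore only the sliding-window-attention portion of a sequence's state by passing `LLAMA_STATE_SEQ_FLAGS_SWA_ONLY`, while `flags = 0` keeps the previous full-state behavior. A minimal sketch of the intended round trip, assuming a valid `llama_context *`; the helper names are illustrative and not part of this commit:

```cpp
#include <cstdint>
#include <vector>

#include "llama.h"

// Illustrative helpers (not part of this commit): snapshot and restore only the
// SWA portion of one sequence's state using the new *_ext API.
static std::vector<uint8_t> swa_snapshot(llama_context * ctx, llama_seq_id seq_id) {
    const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);

    std::vector<uint8_t> buf(size);
    llama_state_seq_get_data_ext(ctx, buf.data(), buf.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);

    return buf;
}

static bool swa_restore(llama_context * ctx, llama_seq_id seq_id, const std::vector<uint8_t> & buf) {
    // set_data returns the number of bytes read, or 0 on failure; the server
    // treats anything other than the full buffer size as a failed restore
    const size_t n = llama_state_seq_set_data_ext(ctx, buf.data(), buf.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
    return n == buf.size();
}
```

The non-`_ext` wrappers later in this commit simply forward to these functions with `flags = 0`.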
@@ -1657,30 +1657,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
     }
 }

-size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
+size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) {
     llama_io_write_dummy io;
     try {
-        return state_seq_write_data(io, seq_id);
+        return state_seq_write_data(io, seq_id, flags);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
         return 0;
     }
 }

-size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
+size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) {
     llama_io_write_buffer io(dst, size);
     try {
-        return state_seq_write_data(io, seq_id);
+        return state_seq_write_data(io, seq_id, flags);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
         return 0;
     }
 }

-size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
+size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) {
     llama_io_read_buffer io(src, size);
     try {
-        return state_seq_read_data(io, seq_id);
+        return state_seq_read_data(io, seq_id, flags);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
         return 0;
@@ -1778,7 +1778,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file
     {
         const size_t state_size = file.size() - file.tell();
         llama_io_read_file io(&file);
-        const size_t nread = state_seq_read_data(io, seq_id);
+        const size_t nread = state_seq_read_data(io, seq_id, 0);
         if (!nread) {
             LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
             return 0;
@@ -1802,7 +1802,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file

     // save the context state using stream saving
     llama_io_write_file io(&file);
-    state_seq_write_data(io, seq_id);
+    state_seq_write_data(io, seq_id, 0);

     const size_t res = file.tell();
     GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
@@ -1971,21 +1971,21 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
     return io.n_bytes();
 }

-size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
+size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
     GGML_UNUSED(seq_id);

     if (memory) {
-        memory->state_write(io, seq_id);
+        memory->state_write(io, seq_id, flags);
     }

     return io.n_bytes();
 }

-size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
+size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
     GGML_UNUSED(seq_id);

     if (memory) {
-        memory->state_read(io, seq_id);
+        memory->state_read(io, seq_id, flags);
     }

     return io.n_bytes();
@@ -2801,19 +2801,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const
 }

 size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
-    return ctx->state_seq_get_size(seq_id);
+    return llama_state_seq_get_size_ext(ctx, seq_id, 0);
 }

 size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
-    ctx->synchronize();
-
-    return ctx->state_seq_get_data(seq_id, dst, size);
+    return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0);
 }

 size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
+    return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0);
+}
+
+size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    return ctx->state_seq_get_size(seq_id, flags);
+}
+
+size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
     ctx->synchronize();

-    return ctx->state_seq_set_data(seq_id, src, size);
+    return ctx->state_seq_get_data(seq_id, dst, size, flags);
+}
+
+size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    ctx->synchronize();
+
+    return ctx->state_seq_set_data(seq_id, src, size, flags);
 }

 size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
@@ -111,9 +111,9 @@ struct llama_context {
     size_t state_get_data(      uint8_t * dst, size_t size);
     size_t state_set_data(const uint8_t * src, size_t size);

-    size_t state_seq_get_size(llama_seq_id seq_id);
-    size_t state_seq_get_data(llama_seq_id seq_id,       uint8_t * dst, size_t size);
-    size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
+    size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags);
+    size_t state_seq_get_data(llama_seq_id seq_id,       uint8_t * dst, size_t size, llama_state_seq_flags flags);
+    size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags);

     bool state_load_file(
             const char * filepath,
@@ -213,8 +213,8 @@ private:
     size_t state_write_data(llama_io_write_i & io);
     size_t state_read_data (llama_io_read_i  & io);

-    size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
-    size_t state_seq_read_data (llama_io_read_i  & io, llama_seq_id seq_id);
+    size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
+    size_t state_seq_read_data (llama_io_read_i  & io, llama_seq_id seq_id, llama_state_seq_flags flags);

     //
     // members
@@ -194,14 +194,20 @@ bool llama_kv_cache_unified_iswa::get_can_shift() const {
     return kv_base->get_size() == kv_swa->get_size();
 }

-void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
-    kv_base->state_write(io, seq_id);
-    kv_swa ->state_write(io, seq_id);
+void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+        kv_base->state_write(io, seq_id, flags);
+    }
+
+    kv_swa->state_write(io, seq_id, flags);
 }

-void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
-    kv_base->state_read(io, seq_id);
-    kv_swa ->state_read(io, seq_id);
+void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+        kv_base->state_read(io, seq_id, flags);
+    }
+
+    kv_swa->state_read(io, seq_id, flags);
 }

 llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
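The guard on `LLAMA_STATE_SEQ_FLAGS_SWA_ONLY` above is what makes SWA checkpoints cheap: when the flag is set, the base (full-attention) cache is skipped and only the sliding-window cache is serialized. A hedged sketch of how a caller could observe the difference, assuming a context backed by this iSWA cache; the helper is illustrative and not part of this commit:

```cpp
#include <cstdio>

#include "llama.h"

// Illustrative helper (not part of this commit): compare the full and SWA-only
// serialized sizes for one sequence.
static void print_seq_state_sizes(llama_context * ctx, llama_seq_id seq_id) {
    const size_t size_full = llama_state_seq_get_size_ext(ctx, seq_id, 0);
    const size_t size_swa  = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);

    // the SWA-only size covers just kv_swa, so for models whose sliding window is
    // much smaller than the context it should be a small fraction of the full size
    printf("seq %d: full = %zu bytes, swa-only = %zu bytes\n", seq_id, size_full, size_swa);
}
```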
@@ -56,8 +56,8 @@ public:

     // state write/load

-    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

     //
     // llama_kv_cache_unified_iswa specific API
@@ -1828,7 +1828,9 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
     return false;
 }

-void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+    GGML_UNUSED(flags);
+
     io.write(&n_stream, sizeof(n_stream));

     for (uint32_t s = 0; s < n_stream; ++s) {
@@ -1879,7 +1881,9 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
     }
 }

-void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    GGML_UNUSED(flags);
+
     GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));

     uint32_t n_stream_cur;
@@ -136,8 +136,8 @@ public:

     // state write/load

-    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

     //
     // llama_kv_cache_unified specific API
@@ -165,12 +165,16 @@ llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
     return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
 }

-void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+    GGML_UNUSED(flags);
+
     mem_attn->state_write(io, seq_id);
     mem_recr->state_write(io, seq_id);
 }

-void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    GGML_UNUSED(flags);
+
     mem_attn->state_read(io, seq_id);
     mem_recr->state_read(io, seq_id);
 }
@@ -74,8 +74,8 @@ public:

     // state write/load

-    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

     //
     // llama_memory_hybrid specific API
@@ -680,7 +680,9 @@ size_t llama_memory_recurrent::size_s_bytes() const {
     return size_s_bytes;
 }

-void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+    GGML_UNUSED(flags);
+
     std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
     uint32_t cell_count = 0;

@@ -718,7 +720,9 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
     state_write_data(io, cell_ranges);
 }

-void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    GGML_UNUSED(flags);
+
     uint32_t cell_count;
     io.read_to(&cell_count, sizeof(cell_count));

@@ -63,8 +63,8 @@ public:

     // state write/load

-    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

     uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
     uint32_t size = 0; // total number of cells, shared across all sequences
@@ -104,8 +104,8 @@ struct llama_memory_i {
     // state write/read
     //

-    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
-    virtual void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) = 0;
+    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0;
+    virtual void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;
 };

 using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
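Note that the `flags` parameter defaults to `0` both on the `llama_memory_i` interface and on every override in this commit, so existing implementations and call sites that pass only `io` and `seq_id` keep working unchanged; only the SWA-aware server paths opt into `LLAMA_STATE_SEQ_FLAGS_SWA_ONLY`.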
@@ -692,6 +692,13 @@ struct completion_token_output {
     }
 };

+struct swa_checkpoint {
+    llama_pos pos_min;
+    llama_pos pos_max;
+
+    std::vector<uint8_t> data;
+};
+
 struct server_task_result_cmpl_final : server_task_result {
     int index = 0;

@@ -1336,6 +1343,8 @@ struct server_slot {

     std::vector<completion_token_output> generated_token_probs;

+    std::vector<swa_checkpoint> swa_checkpoints;
+
     bool has_next_token = true;
     bool has_new_line = false;
     bool truncated = false;
@@ -3293,6 +3302,8 @@ struct server_context {
                     slot.n_past = 0;
                 }

+                const auto n_swa = llama_model_n_swa(model);
+
                 if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
                     const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                     if (pos_min == -1) {
@@ -3300,12 +3311,58 @@ struct server_context {
                         GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
                     }

-                    const auto n_swa = llama_model_n_swa(model);
-                    if (pos_min > std::max(0, slot.n_past - n_swa)) {
+                    const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
+
+                    if (pos_min > pos_min_thold) {
                         SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
-                        SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
-                                "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-                        slot.n_past = 0;
+
+                        // search for a SWA checkpoint
+                        const auto it = std::find_if(
+                                slot.swa_checkpoints.rbegin(),
+                                slot.swa_checkpoints.rend(),
+                                [&](const auto & cur) {
+                                    return cur.pos_min <= pos_min_thold;
+                                }
+                                );
+
+                        bool do_reset = it == slot.swa_checkpoints.rend();
+
+                        if (!do_reset) {
+                            // restore the checkpoint
+                            const size_t swa_size = it->data.size();
+                            const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+
+                            if (n != swa_size) {
+                                SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                                do_reset = true;
+                            } else {
+                                slot.n_past = std::min(slot.n_past, it->pos_max);
+
+                                SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                            }
+                        }
+
+                        if (do_reset) {
+                            SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
+                                    "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+
+                            slot.n_past = 0;
+                            slot.swa_checkpoints.clear();
+                        }
+                    }
+                }
+
+                if (n_swa > 0) {
+                    const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
+
+                    // erase any checkpoints with pos_min > pos_min_thold
+                    for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
+                        const auto & cur = slot.swa_checkpoints[i];
+                        if (cur.pos_min > pos_min_thold) {
+                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);
+
+                            SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                        }
                     }
                 }
             }
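In short, when the SWA cache no longer covers the reusable prompt prefix (`pos_min > slot.n_past - n_swa`), the slot searches its checkpoints from newest to oldest for one whose `pos_min` is at or below that threshold, restores it with `llama_state_seq_set_data_ext`, and rolls `slot.n_past` back to the checkpoint's `pos_max`; only if no usable checkpoint exists (or the restore fails) does it fall back to full prompt re-processing. A condensed, hedged sketch of that decision in isolation (the struct is re-declared so the sketch stands alone, and `restore_swa` stands in for the actual state-restore call):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

struct swa_checkpoint {
    int32_t pos_min; // llama_pos is int32_t
    int32_t pos_max;
    std::vector<uint8_t> data;
};

// Condensed illustration of the server's restore branch: returns the new n_past,
// or 0 when a full re-process is required. restore_swa stands in for
// llama_state_seq_set_data_ext(..., LLAMA_STATE_SEQ_FLAGS_SWA_ONLY).
template <typename RestoreFn>
int32_t resume_from_checkpoint(std::vector<swa_checkpoint> & cps, int32_t n_past, int32_t n_swa, RestoreFn restore_swa) {
    const int32_t pos_min_thold = std::max(0, n_past - n_swa);

    // newest checkpoint that still covers the positions the SWA cache needs
    const auto it = std::find_if(cps.rbegin(), cps.rend(),
            [&](const swa_checkpoint & cur) { return cur.pos_min <= pos_min_thold; });

    if (it != cps.rend() && restore_swa(it->data)) {
        return std::min(n_past, it->pos_max); // resume generation from the checkpoint
    }

    // no usable checkpoint (or restore failed) -> full prompt re-processing
    cps.clear();
    return 0;
}
```

The real code additionally prunes any checkpoints whose `pos_min` exceeds the threshold after this decision, so stale snapshots do not accumulate.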
@@ -3519,6 +3576,39 @@ struct server_context {

                 // prompt evaluated for next-token prediction
                 slot.state = SLOT_STATE_GENERATING;
+
+                // make a checkpoint with the SWA memory
+                // checkpoints are needed only if we are not using "--swa-full"
+                if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
+                    if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
+                        {
+                            const auto & cur = slot.swa_checkpoints.back();
+
+                            SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
+                                    cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                        }
+
+                        slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
+                    }
+
+                    const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+
+                    auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
+                        /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
+                        /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
+                        /*.data    = */ std::vector<uint8_t>(swa_size),
+                    });
+
+                    llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+
+                    float size_total = 0.0f;
+                    for (const auto & checkpoint : slot.swa_checkpoints) {
+                        size_total += (float) checkpoint.data.size() / 1024 / 1024;
+                    }
+
+                    SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
+                            cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
+                }
             } else if (slot.state != SLOT_STATE_GENERATING) {
                 continue; // continue loop of slots
             }
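Checkpoint creation happens right after the prompt is evaluated and the slot enters generation, and only when the model has a SWA window, `--swa-full` is not set, and the checkpoint budget is positive; once the per-slot cap is reached, the oldest entry is evicted before the new snapshot is appended. With the default of three checkpoints per slot the extra memory cost is therefore bounded by roughly three SWA-only snapshots per active slot, and the `SWA checkpoint create` log line reports each snapshot's size plus the running per-slot total in MiB, which is the easiest way to gauge the overhead for a given model and `--swa-checkpoints` setting.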