server : add SWA checkpoints (#15293)

* server : add SWA checkpoints

ggml-ci

* cont : server clean-up

* server : handle state restore fails

* llama : add extended llama_state_seq_ API

* server : do not make checkpoints if --swa-full

ggml-ci

* llama : remove flags value for NONE

* server : configure number of SWA checkpoints with CLI arg

ggml-ci

* args : fix scope of new argument
This commit is contained in:
Georgi Gerganov 2025-08-14 14:59:50 +03:00 committed by GitHub
parent 3973163bff
commit d32e03f449
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 206 additions and 54 deletions

View file

@ -1507,6 +1507,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.swa_full = true;
}
).set_env("LLAMA_ARG_SWA_FULL"));
add_opt(common_arg(
{"--swa-checkpoints"}, "N",
string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
[](common_params & params, int value) {
params.n_swa_checkpoints = value;
}
).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--kv-unified", "-kvu"},
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"

View file

@ -418,6 +418,7 @@ struct common_params {
int32_t timeout_write = timeout_read; // http write timeout in seconds
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot
std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT

View file

@ -870,6 +870,29 @@ extern "C" {
size_t n_token_capacity,
size_t * n_token_count_out);
#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1
typedef uint32_t llama_state_seq_flags;
LLAMA_API size_t llama_state_seq_get_size_ext(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags);
LLAMA_API size_t llama_state_seq_get_data_ext(
struct llama_context * ctx,
uint8_t * dst,
size_t size,
llama_seq_id seq_id,
llama_state_seq_flags flags);
LLAMA_API size_t llama_state_seq_set_data_ext(
struct llama_context * ctx,
const uint8_t * src,
size_t size,
llama_seq_id dest_seq_id,
llama_state_seq_flags flags);
//
// Decoding
//

View file

@ -1657,30 +1657,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
}
}
size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) {
llama_io_write_dummy io;
try {
return state_seq_write_data(io, seq_id);
return state_seq_write_data(io, seq_id, flags);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
return 0;
}
}
size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) {
llama_io_write_buffer io(dst, size);
try {
return state_seq_write_data(io, seq_id);
return state_seq_write_data(io, seq_id, flags);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
return 0;
}
}
size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) {
llama_io_read_buffer io(src, size);
try {
return state_seq_read_data(io, seq_id);
return state_seq_read_data(io, seq_id, flags);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
return 0;
@ -1778,7 +1778,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file
{
const size_t state_size = file.size() - file.tell();
llama_io_read_file io(&file);
const size_t nread = state_seq_read_data(io, seq_id);
const size_t nread = state_seq_read_data(io, seq_id, 0);
if (!nread) {
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
return 0;
@ -1802,7 +1802,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file
// save the context state using stream saving
llama_io_write_file io(&file);
state_seq_write_data(io, seq_id);
state_seq_write_data(io, seq_id, 0);
const size_t res = file.tell();
GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
@ -1971,21 +1971,21 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
return io.n_bytes();
}
size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(seq_id);
if (memory) {
memory->state_write(io, seq_id);
memory->state_write(io, seq_id, flags);
}
return io.n_bytes();
}
size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(seq_id);
if (memory) {
memory->state_read(io, seq_id);
memory->state_read(io, seq_id, flags);
}
return io.n_bytes();
@ -2801,19 +2801,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const
}
size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
return ctx->state_seq_get_size(seq_id);
return llama_state_seq_get_size_ext(ctx, seq_id, 0);
}
size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
ctx->synchronize();
return ctx->state_seq_get_data(seq_id, dst, size);
return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0);
}
size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0);
}
size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
return ctx->state_seq_get_size(seq_id, flags);
}
size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
ctx->synchronize();
return ctx->state_seq_set_data(seq_id, src, size);
return ctx->state_seq_get_data(seq_id, dst, size, flags);
}
size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
ctx->synchronize();
return ctx->state_seq_set_data(seq_id, src, size, flags);
}
size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {

View file

@ -111,9 +111,9 @@ struct llama_context {
size_t state_get_data( uint8_t * dst, size_t size);
size_t state_set_data(const uint8_t * src, size_t size);
size_t state_seq_get_size(llama_seq_id seq_id);
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags);
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags);
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags);
bool state_load_file(
const char * filepath,
@ -213,8 +213,8 @@ private:
size_t state_write_data(llama_io_write_i & io);
size_t state_read_data (llama_io_read_i & io);
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
//
// members

View file

@ -194,14 +194,20 @@ bool llama_kv_cache_unified_iswa::get_can_shift() const {
return kv_base->get_size() == kv_swa->get_size();
}
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
kv_base->state_write(io, seq_id);
kv_swa ->state_write(io, seq_id);
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
kv_base->state_write(io, seq_id, flags);
}
kv_swa->state_write(io, seq_id, flags);
}
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
kv_base->state_read(io, seq_id);
kv_swa ->state_read(io, seq_id);
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
kv_base->state_read(io, seq_id, flags);
}
kv_swa->state_read(io, seq_id, flags);
}
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {

View file

@ -56,8 +56,8 @@ public:
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
//
// llama_kv_cache_unified_iswa specific API

View file

@ -1828,7 +1828,9 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
return false;
}
void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);
io.write(&n_stream, sizeof(n_stream));
for (uint32_t s = 0; s < n_stream; ++s) {
@ -1879,7 +1881,9 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
}
}
void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(flags);
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
uint32_t n_stream_cur;

View file

@ -136,8 +136,8 @@ public:
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
//
// llama_kv_cache_unified specific API

View file

@ -165,12 +165,16 @@ llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
}
void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);
mem_attn->state_write(io, seq_id);
mem_recr->state_write(io, seq_id);
}
void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(flags);
mem_attn->state_read(io, seq_id);
mem_recr->state_read(io, seq_id);
}

View file

@ -74,8 +74,8 @@ public:
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
//
// llama_memory_hybrid specific API

View file

@ -680,7 +680,9 @@ size_t llama_memory_recurrent::size_s_bytes() const {
return size_s_bytes;
}
void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
uint32_t cell_count = 0;
@ -718,7 +720,9 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
state_write_data(io, cell_ranges);
}
void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(flags);
uint32_t cell_count;
io.read_to(&cell_count, sizeof(cell_count));

View file

@ -63,8 +63,8 @@ public:
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
uint32_t size = 0; // total number of cells, shared across all sequences

View file

@ -104,8 +104,8 @@ struct llama_memory_i {
// state write/read
//
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;
};
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;

View file

@ -692,6 +692,13 @@ struct completion_token_output {
}
};
struct swa_checkpoint {
llama_pos pos_min;
llama_pos pos_max;
std::vector<uint8_t> data;
};
struct server_task_result_cmpl_final : server_task_result {
int index = 0;
@ -1336,6 +1343,8 @@ struct server_slot {
std::vector<completion_token_output> generated_token_probs;
std::vector<swa_checkpoint> swa_checkpoints;
bool has_next_token = true;
bool has_new_line = false;
bool truncated = false;
@ -3293,6 +3302,8 @@ struct server_context {
slot.n_past = 0;
}
const auto n_swa = llama_model_n_swa(model);
if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
if (pos_min == -1) {
@ -3300,12 +3311,58 @@ struct server_context {
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
}
const auto n_swa = llama_model_n_swa(model);
if (pos_min > std::max(0, slot.n_past - n_swa)) {
const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
if (pos_min > pos_min_thold) {
SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
// search for a SWA checkpoint
const auto it = std::find_if(
slot.swa_checkpoints.rbegin(),
slot.swa_checkpoints.rend(),
[&](const auto & cur) {
return cur.pos_min <= pos_min_thold;
}
);
bool do_reset = it == slot.swa_checkpoints.rend();
if (!do_reset) {
// restore the checkpoint
const size_t swa_size = it->data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
if (n != swa_size) {
SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
do_reset = true;
} else {
slot.n_past = std::min(slot.n_past, it->pos_max);
SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
}
}
if (do_reset) {
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
slot.n_past = 0;
slot.swa_checkpoints.clear();
}
}
}
if (n_swa > 0) {
const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
// erase any checkpoints with pos_min > pos_min_thold
for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
const auto & cur = slot.swa_checkpoints[i];
if (cur.pos_min > pos_min_thold) {
slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);
SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
}
}
}
}
@ -3519,6 +3576,39 @@ struct server_context {
// prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING;
// make a checkpoint with the SWA memory
// checkpoints are needed only if we are not using "--swa-full"
if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
{
const auto & cur = slot.swa_checkpoints.back();
SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
}
slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
}
const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
/*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
/*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
/*.data = */ std::vector<uint8_t>(swa_size),
});
llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
float size_total = 0.0f;
for (const auto & checkpoint : slot.swa_checkpoints) {
size_total += (float) checkpoint.data.size() / 1024 / 1024;
}
SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
}
} else if (slot.state != SLOT_STATE_GENERATING) {
continue; // continue loop of slots
}