fix context shifting

This commit is contained in:
Lizonghang 2025-05-19 16:58:35 +04:00
parent 07c4966a80
commit c54a6a0132
8 changed files with 397 additions and 73 deletions

View file

@ -121,6 +121,17 @@ struct Timer {
// helpers
//
template<typename LocalFn, typename RemoteFn>
bool kv_cache_op(bool flag,
LocalFn local_fn,
RemoteFn remote_fn,
bool is_last_dev) {
if (!flag) return false;
local_fn();
if (!is_last_dev) remote_fn();
return true;
}
// trim whitespace from the beginning and end of a string
static std::string trim(const std::string & str) {
size_t start = 0;
@ -4157,7 +4168,7 @@ static bool llama_kv_cache_find_slot(
}
if (n_tested >= cache.size) {
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
return false;
}
}
@ -10629,7 +10640,7 @@ struct llm_build_context {
cb(lctx.inp_K_shift, "K_shift", -1);
ggml_set_input(lctx.inp_K_shift);
for (int il = 0; il < n_layer; ++il) {
for (int il = 0; il < (int)kv_self.k_l.size(); ++il) {
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
struct ggml_tensor * rope_factors = build_rope_factors(il);
@ -10642,13 +10653,19 @@ struct llm_build_context {
struct ggml_tensor * tmp;
if (ggml_is_quantized(k->type)) {
#ifdef GGML_USE_METAL
GGML_ABORT("The option --cache-type-k is not supported on Metal\n");
#endif
// dequantize to f32 -> RoPE -> quantize back
tmp = ggml_cast(ctx0, k, GGML_TYPE_F32);
cb(tmp, "K_f32", il);
for (auto * backend : lctx.backends) {
// Figure out which backend KV cache belongs to
if (ggml_backend_supports_buft(backend, lctx.model.buft_layer[il].buft)) {
ggml_backend_sched_set_tensor_backend(lctx.sched.at(0), tmp, backend); // todo.
ggml_backend_sched_set_tensor_backend(lctx.sched[0], tmp, backend);
break;
}
}
@ -17769,7 +17786,39 @@ struct input_tensors {
};
struct sync_meta {
int32_t n_tokens = 0;
int32_t n_tokens = 0;
llama_pos * pos = nullptr;
uint32_t n_ctx = 0;
// signal to clear the kv cache
bool clear_kv_cache = false;
// signal to remove a kv cache sequence
bool kv_seq_rm = false;
llama_seq_id rm_seq_id = 0;
llama_pos rm_p0 = 0;
llama_pos rm_p1 = 0;
// signal to add a kv cache sequence
bool kv_seq_add = false;
llama_seq_id add_seq_id = 0;
llama_pos add_p0 = 0;
llama_pos add_p1 = 0;
llama_pos add_delta = 0;
// signal to copy a kv cache sequence
bool kv_seq_cp = false;
llama_seq_id cp_src_seq_id = 0;
llama_seq_id cp_dst_seq_id = 0;
llama_pos cp_p0 = 0;
llama_pos cp_p1 = 0;
// signal to divide the kv cache range
bool kv_seq_div = false;
llama_seq_id div_seq_id = 0;
llama_pos div_p0 = 0;
llama_pos div_p1 = 0;
int div_factor = 1;
};
static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
@ -17781,6 +17830,11 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
send_msgs.emplace_back("n_tokens", strlen("n_tokens"));
send_msgs.emplace_back(&(meta->n_tokens), sizeof(meta->n_tokens));
if (meta->pos != nullptr) {
send_msgs.emplace_back("pos", strlen("pos"));
send_msgs.emplace_back(meta->pos, meta->n_ctx * sizeof(llama_pos));
}
zmq::send_multipart(socket, send_msgs);
} catch (const zmq::error_t& e) {
LLAMA_LOG_INFO("Failed to send meta data: %s\n", e.what());
@ -17797,6 +17851,49 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
socket.set(zmq::sockopt::rcvtimeo, -1);
const std::string cmd = recv_msgs[0].to_string();
size_t idx = 1;
if (cmd == "clear_kv_cache" && recv_msgs.size() == 1) {
meta->clear_kv_cache = true;
return 0;
}
if (cmd == "kv_seq_rm" && recv_msgs.size() == 4) {
meta->kv_seq_rm = true;
std::memcpy(&meta->rm_seq_id, recv_msgs[idx++].data(), sizeof(meta->rm_seq_id));
std::memcpy(&meta->rm_p0, recv_msgs[idx++].data(), sizeof(meta->rm_p0));
std::memcpy(&meta->rm_p1, recv_msgs[idx++].data(), sizeof(meta->rm_p1));
return 0;
}
if (cmd == "kv_seq_add" && recv_msgs.size() == 5) {
meta->kv_seq_add = true;
std::memcpy(&meta->add_seq_id, recv_msgs[idx++].data(), sizeof(meta->add_seq_id));
std::memcpy(&meta->add_p0, recv_msgs[idx++].data(), sizeof(meta->add_p0));
std::memcpy(&meta->add_p1, recv_msgs[idx++].data(), sizeof(meta->add_p1));
std::memcpy(&meta->add_delta, recv_msgs[idx++].data(), sizeof(meta->add_delta));
return 0;
}
if (cmd == "kv_seq_cp" && recv_msgs.size() == 5) {
meta->kv_seq_cp = true;
std::memcpy(&meta->cp_src_seq_id, recv_msgs[idx++].data(), sizeof(meta->cp_src_seq_id));
std::memcpy(&meta->cp_dst_seq_id, recv_msgs[idx++].data(), sizeof(meta->cp_dst_seq_id));
std::memcpy(&meta->cp_p0, recv_msgs[idx++].data(), sizeof(meta->cp_p0));
std::memcpy(&meta->cp_p1, recv_msgs[idx++].data(), sizeof(meta->cp_p1));
return 0;
}
if (cmd == "kv_seq_div" && recv_msgs.size() == 5) {
meta->kv_seq_div = true;
std::memcpy(&meta->div_seq_id, recv_msgs[idx++].data(), sizeof(meta->div_seq_id));
std::memcpy(&meta->div_p0, recv_msgs[idx++].data(), sizeof(meta->div_p0));
std::memcpy(&meta->div_p1, recv_msgs[idx++].data(), sizeof(meta->div_p1));
std::memcpy(&meta->div_factor, recv_msgs[idx++].data(), sizeof(meta->div_factor));
return 0;
}
for (size_t i = 0; i < recv_msgs.size(); i += 2) {
std::string key = recv_msgs[i].to_string();
zmq::message_t & data_msg = recv_msgs[i + 1];
@ -17805,6 +17902,11 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
GGML_ASSERT(data_msg.size() == sizeof(meta->n_tokens));
std::memcpy(&(meta->n_tokens), data_msg.data(), sizeof(meta->n_tokens));
}
if (key == "pos") {
meta->pos = (llama_pos *) malloc(meta->n_ctx * sizeof(llama_pos));
std::memcpy(meta->pos, data_msg.data(), meta->n_ctx * sizeof(llama_pos));
}
}
return 0;
}
@ -18069,15 +18171,66 @@ static int llama_decode_internal(
}
sync_meta meta;
meta.n_ctx = cparams.n_ctx;
bool is_last_dev = (my_rank == n_world - 1);
if (my_rank != 0) {
if (llama_recv_meta(*lctx.recv_socket, &meta) == -1) {
return -1;
}
batch_all.n_tokens = meta.n_tokens;
if (meta.n_tokens > 0) {
batch_all.n_tokens = meta.n_tokens;
if (meta.pos != nullptr) {
batch_all.pos = (llama_pos *) malloc(cparams.n_ctx * sizeof(llama_pos));
std::memcpy(batch_all.pos, meta.pos, cparams.n_ctx * sizeof(llama_pos));
}
}
if (kv_cache_op(meta.clear_kv_cache,
[&]{ llama_kv_cache_clear (&lctx); },
[&]{ llama_send_kv_cache_clear (&lctx); },
is_last_dev)) {
LLAMA_LOG_INFO("%s: received signal kv_cache_clear\n", __func__);
return -1;
}
if (kv_cache_op(meta.kv_seq_rm,
[&]{ llama_kv_cache_seq_rm (&lctx, meta.rm_seq_id, meta.rm_p0, meta.rm_p1); },
[&]{ llama_send_kv_cache_seq_rm (&lctx, meta.rm_seq_id, meta.rm_p0, meta.rm_p1); },
is_last_dev)) {
LLAMA_LOG_INFO("%s: received signal kv_cache_seq_rm\n", __func__);
return -1;
}
if (kv_cache_op(meta.kv_seq_add,
[&]{ llama_kv_cache_seq_add (&lctx, meta.add_seq_id, meta.add_p0, meta.add_p1, meta.add_delta); },
[&]{ llama_send_kv_cache_seq_add(&lctx, meta.add_seq_id, meta.add_p0, meta.add_p1, meta.add_delta); },
is_last_dev)) {
LLAMA_LOG_INFO("%s: received signal kv_cache_seq_add\n", __func__);
return -1;
}
if (kv_cache_op(meta.kv_seq_cp,
[&]{ llama_kv_cache_seq_cp (&lctx, meta.cp_src_seq_id, meta.cp_dst_seq_id, meta.cp_p0, meta.cp_p1); },
[&]{ llama_send_kv_cache_seq_cp (&lctx, meta.cp_src_seq_id, meta.cp_dst_seq_id, meta.cp_p0, meta.cp_p1); },
is_last_dev)) {
LLAMA_LOG_INFO("%s: received signal kv_cache_seq_cp\n", __func__);
return -1;
}
if (kv_cache_op(meta.kv_seq_div,
[&]{ llama_kv_cache_seq_div (&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); },
[&]{ llama_send_kv_cache_seq_div(&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); },
is_last_dev)) {
LLAMA_LOG_INFO("%s: received signal kv_cache_seq_div\n", __func__);
return -1;
}
}
if (my_rank != n_world - 1) {
if (!is_last_dev) {
meta.n_tokens = batch_all.n_tokens;
meta.pos = batch_all.pos;
llama_send_meta(*lctx.send_socket, &meta);
}
@ -18803,22 +18956,20 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
// apply K-shift if needed
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
throw std::runtime_error("shift not supported\n");
if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA
GGML_ABORT("Deepseek2 does not support K-shift");
}
{
ggml_backend_sched_reset(lctx.sched.at(0)); // todo.
for (size_t i = 0; i < lctx.sched.size(); ++i) {
ggml_backend_sched_reset(lctx.sched[i]);
ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
ggml_backend_sched_alloc_graph(lctx.sched.at(0), gf); // todo.
ggml_backend_sched_alloc_graph(lctx.sched[i], gf);
llama_set_k_shift(lctx);
llama_graph_compute(lctx, gf, lctx.sched.at(0), lctx.cparams.n_threads, lctx.threadpool); // todo.
llama_graph_compute(lctx, gf, lctx.sched[i], lctx.cparams.n_threads, lctx.threadpool);
need_reserve = true;
}
@ -18845,8 +18996,6 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
// reserve a worst case graph again
if (need_reserve) {
throw std::runtime_error("reserve not supported\n");
// TODO: extract to a function
// build worst-case graph
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
@ -18854,13 +19003,11 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
std::vector<ggml_cgraph *> gf = llama_build_graph(lctx, ubatch, true);
// initialize scheduler with the worst-case graph
ggml_backend_sched_reset(lctx.sched[0]); // todo.
GGML_ASSERT(lctx.sched.size() == gf.size());
bool ok = true;
GGML_ASSERT(lctx.sched.size() == gf.size());
for (size_t i = 0; i < gf.size(); ++i) {
ggml_backend_sched_reset(lctx.sched[i]);
ok = ok & ggml_backend_sched_reserve(lctx.sched[i], gf[i]);
}
if (!ok) {
@ -20201,6 +20348,8 @@ void llama_init_sockets(struct llama_context * ctx, uint32_t n_world, uint32_t m
LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what());
exit(1);
}
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
int llama_gather_device_info(struct llama_context * ctx, struct device_info * dev_info_set) {
@ -20264,36 +20413,47 @@ int llama_send_device_info(struct llama_context * ctx, struct device_info * dev_
int llama_bcast_startup_args(llama_context * ctx, uint32_t rank, startup_args * args) {
int32_t n_world = ctx->cparams.n_world;
if (n_world == 1) {
return 0;
}
GGML_ASSERT(n_world > 0);
GGML_ASSERT(ctx != nullptr && ctx->send_socket != nullptr);
if (rank == 0){
// send
try {
std::vector<zmq::message_t> send_msgs;
send_msgs.emplace_back("should_profile", strlen("should_profile"));
send_msgs.emplace_back(&args->should_profile, sizeof(args->should_profile));
send_msgs.emplace_back("n_ctx", strlen("n_ctx"));
send_msgs.emplace_back(&args->n_ctx, sizeof(args->n_ctx));
zmq::send_multipart(*ctx->send_socket, send_msgs);
} catch (const zmq::error_t& e) {
LLAMA_LOG_INFO("Failed to send data: %s\n", e.what());
return -1;
}
}else {
} else {
// receive
std::vector<zmq::message_t> recv_msgs;
if (!zmq::recv_multipart(*ctx->recv_socket, std::back_inserter(recv_msgs))) {
return -1;
}
GGML_ASSERT(recv_msgs[0].to_string() == "should_profile");
GGML_ASSERT(recv_msgs[1].size() == sizeof(bool));
bool should_profile = *static_cast<bool*>(recv_msgs[1].data());
args->should_profile = should_profile;
GGML_ASSERT(recv_msgs[2].to_string() == "n_ctx");
GGML_ASSERT(recv_msgs[3].size() == sizeof(uint32_t));
uint32_t n_ctx = *static_cast<uint32_t*>(recv_msgs[3].data());
args->n_ctx = n_ctx;
if ((int)rank != (int)n_world - 1){
// send
try {
zmq::send_multipart(*ctx->send_socket, recv_msgs);
} catch (const zmq::error_t& e) {
} catch (const zmq::error_t & e) {
LLAMA_LOG_INFO("Failed to send data: %s\n", e.what());
return -1;
}
@ -21910,10 +22070,42 @@ void llama_kv_cache_clear(struct llama_context * ctx) {
llama_kv_cache_clear(ctx->kv_self);
}
void llama_send_kv_cache_clear(struct llama_context * ctx) {
if (ctx->send_socket == nullptr) {
return;
}
try {
std::vector<zmq::message_t> send_msgs;
const char * cmd = "clear_kv_cache";
send_msgs.emplace_back(cmd, strlen(cmd));
zmq::send_multipart(*ctx->send_socket, send_msgs);
} catch (const zmq::error_t & e) {
LLAMA_LOG_INFO("Failed to send KV cache clear signal: %s\n", e.what());
}
}
bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
}
void llama_send_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
if (ctx->send_socket == nullptr) {
return;
}
try {
std::vector<zmq::message_t> msgs;
msgs.emplace_back("kv_seq_rm", strlen("kv_seq_rm"));
msgs.emplace_back(&seq_id, sizeof(seq_id));
msgs.emplace_back(&p0, sizeof(p0));
msgs.emplace_back(&p1, sizeof(p1));
zmq::send_multipart(*ctx->send_socket, msgs);
} catch (const zmq::error_t & e) {
LLAMA_LOG_WARN("Failed to send kv_seq_rm: %s\n", e.what());
}
}
void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
if (seq_id_src == seq_id_dst) {
return;
@ -21921,6 +22113,24 @@ void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src,
llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
}
void llama_send_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
if (ctx->send_socket == nullptr) {
return;
}
try {
std::vector<zmq::message_t> msgs;
msgs.emplace_back("kv_seq_cp", strlen("kv_seq_cp"));
msgs.emplace_back(&seq_id_src, sizeof(seq_id_src));
msgs.emplace_back(&seq_id_dst, sizeof(seq_id_dst));
msgs.emplace_back(&p0, sizeof(p0));
msgs.emplace_back(&p1, sizeof(p1));
zmq::send_multipart(*ctx->send_socket, msgs);
} catch (const zmq::error_t & e) {
LLAMA_LOG_WARN("Failed to send kv_seq_cp: %s\n", e.what());
}
}
void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
llama_kv_cache_seq_keep(ctx->kv_self, seq_id);
}
@ -21933,6 +22143,24 @@ void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, lla
llama_kv_cache_seq_add(ctx->kv_self, seq_id, p0, p1, delta);
}
void llama_send_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
if (ctx->send_socket == nullptr) {
return;
}
try {
std::vector<zmq::message_t> msgs;
msgs.emplace_back("kv_seq_add", strlen("kv_seq_add"));
msgs.emplace_back(&seq_id, sizeof(seq_id));
msgs.emplace_back(&p0, sizeof(p0));
msgs.emplace_back(&p1, sizeof(p1));
msgs.emplace_back(&delta, sizeof(delta));
zmq::send_multipart(*ctx->send_socket, msgs);
} catch (const zmq::error_t & e) {
LLAMA_LOG_WARN("Failed to send kv_seq_add: %s\n", e.what());
}
}
void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
if (d == 1) {
return;
@ -21941,6 +22169,24 @@ void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, lla
llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
}
void llama_send_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
if (ctx->send_socket == nullptr) {
return;
}
try {
std::vector<zmq::message_t> msgs;
msgs.emplace_back("kv_seq_div", strlen("kv_seq_div"));
msgs.emplace_back(&seq_id, sizeof(seq_id));
msgs.emplace_back(&p0, sizeof(p0));
msgs.emplace_back(&p1, sizeof(p1));
msgs.emplace_back(&d, sizeof(d));
zmq::send_multipart(*ctx->send_socket, msgs);
} catch (const zmq::error_t & e) {
LLAMA_LOG_WARN("Failed to send kv_seq_div: %s\n", e.what());
}
}
llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
}