mirror of
https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-12 09:49:40 +00:00
fix context shifting
This commit is contained in:
parent
07c4966a80
commit
c54a6a0132
8 changed files with 397 additions and 73 deletions
292
src/llama.cpp
292
src/llama.cpp
|
@ -121,6 +121,17 @@ struct Timer {
|
|||
// helpers
|
||||
//
|
||||
|
||||
template<typename LocalFn, typename RemoteFn>
|
||||
bool kv_cache_op(bool flag,
|
||||
LocalFn local_fn,
|
||||
RemoteFn remote_fn,
|
||||
bool is_last_dev) {
|
||||
if (!flag) return false;
|
||||
local_fn();
|
||||
if (!is_last_dev) remote_fn();
|
||||
return true;
|
||||
}
|
||||
|
||||
// trim whitespace from the beginning and end of a string
|
||||
static std::string trim(const std::string & str) {
|
||||
size_t start = 0;
|
||||
|
@ -4157,7 +4168,7 @@ static bool llama_kv_cache_find_slot(
|
|||
}
|
||||
|
||||
if (n_tested >= cache.size) {
|
||||
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
||||
LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -10629,7 +10640,7 @@ struct llm_build_context {
|
|||
cb(lctx.inp_K_shift, "K_shift", -1);
|
||||
ggml_set_input(lctx.inp_K_shift);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
for (int il = 0; il < (int)kv_self.k_l.size(); ++il) {
|
||||
const int64_t n_head_kv = hparams.n_head_kv(il);
|
||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
||||
|
@ -10642,13 +10653,19 @@ struct llm_build_context {
|
|||
|
||||
struct ggml_tensor * tmp;
|
||||
if (ggml_is_quantized(k->type)) {
|
||||
|
||||
#ifdef GGML_USE_METAL
|
||||
GGML_ABORT("The option --cache-type-k is not supported on Metal\n");
|
||||
#endif
|
||||
|
||||
// dequantize to f32 -> RoPE -> quantize back
|
||||
tmp = ggml_cast(ctx0, k, GGML_TYPE_F32);
|
||||
cb(tmp, "K_f32", il);
|
||||
|
||||
for (auto * backend : lctx.backends) {
|
||||
// Figure out which backend KV cache belongs to
|
||||
if (ggml_backend_supports_buft(backend, lctx.model.buft_layer[il].buft)) {
|
||||
ggml_backend_sched_set_tensor_backend(lctx.sched.at(0), tmp, backend); // todo.
|
||||
ggml_backend_sched_set_tensor_backend(lctx.sched[0], tmp, backend);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -17769,7 +17786,39 @@ struct input_tensors {
|
|||
};
|
||||
|
||||
struct sync_meta {
|
||||
int32_t n_tokens = 0;
|
||||
int32_t n_tokens = 0;
|
||||
llama_pos * pos = nullptr;
|
||||
uint32_t n_ctx = 0;
|
||||
|
||||
// signal to clear the kv cache
|
||||
bool clear_kv_cache = false;
|
||||
|
||||
// signal to remove a kv cache sequence
|
||||
bool kv_seq_rm = false;
|
||||
llama_seq_id rm_seq_id = 0;
|
||||
llama_pos rm_p0 = 0;
|
||||
llama_pos rm_p1 = 0;
|
||||
|
||||
// signal to add a kv cache sequence
|
||||
bool kv_seq_add = false;
|
||||
llama_seq_id add_seq_id = 0;
|
||||
llama_pos add_p0 = 0;
|
||||
llama_pos add_p1 = 0;
|
||||
llama_pos add_delta = 0;
|
||||
|
||||
// signal to copy a kv cache sequence
|
||||
bool kv_seq_cp = false;
|
||||
llama_seq_id cp_src_seq_id = 0;
|
||||
llama_seq_id cp_dst_seq_id = 0;
|
||||
llama_pos cp_p0 = 0;
|
||||
llama_pos cp_p1 = 0;
|
||||
|
||||
// signal to divide the kv cache range
|
||||
bool kv_seq_div = false;
|
||||
llama_seq_id div_seq_id = 0;
|
||||
llama_pos div_p0 = 0;
|
||||
llama_pos div_p1 = 0;
|
||||
int div_factor = 1;
|
||||
};
|
||||
|
||||
static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
||||
|
@ -17781,6 +17830,11 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
|||
send_msgs.emplace_back("n_tokens", strlen("n_tokens"));
|
||||
send_msgs.emplace_back(&(meta->n_tokens), sizeof(meta->n_tokens));
|
||||
|
||||
if (meta->pos != nullptr) {
|
||||
send_msgs.emplace_back("pos", strlen("pos"));
|
||||
send_msgs.emplace_back(meta->pos, meta->n_ctx * sizeof(llama_pos));
|
||||
}
|
||||
|
||||
zmq::send_multipart(socket, send_msgs);
|
||||
} catch (const zmq::error_t& e) {
|
||||
LLAMA_LOG_INFO("Failed to send meta data: %s\n", e.what());
|
||||
|
@ -17797,6 +17851,49 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
|||
|
||||
socket.set(zmq::sockopt::rcvtimeo, -1);
|
||||
|
||||
const std::string cmd = recv_msgs[0].to_string();
|
||||
size_t idx = 1;
|
||||
|
||||
if (cmd == "clear_kv_cache" && recv_msgs.size() == 1) {
|
||||
meta->clear_kv_cache = true;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (cmd == "kv_seq_rm" && recv_msgs.size() == 4) {
|
||||
meta->kv_seq_rm = true;
|
||||
std::memcpy(&meta->rm_seq_id, recv_msgs[idx++].data(), sizeof(meta->rm_seq_id));
|
||||
std::memcpy(&meta->rm_p0, recv_msgs[idx++].data(), sizeof(meta->rm_p0));
|
||||
std::memcpy(&meta->rm_p1, recv_msgs[idx++].data(), sizeof(meta->rm_p1));
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (cmd == "kv_seq_add" && recv_msgs.size() == 5) {
|
||||
meta->kv_seq_add = true;
|
||||
std::memcpy(&meta->add_seq_id, recv_msgs[idx++].data(), sizeof(meta->add_seq_id));
|
||||
std::memcpy(&meta->add_p0, recv_msgs[idx++].data(), sizeof(meta->add_p0));
|
||||
std::memcpy(&meta->add_p1, recv_msgs[idx++].data(), sizeof(meta->add_p1));
|
||||
std::memcpy(&meta->add_delta, recv_msgs[idx++].data(), sizeof(meta->add_delta));
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (cmd == "kv_seq_cp" && recv_msgs.size() == 5) {
|
||||
meta->kv_seq_cp = true;
|
||||
std::memcpy(&meta->cp_src_seq_id, recv_msgs[idx++].data(), sizeof(meta->cp_src_seq_id));
|
||||
std::memcpy(&meta->cp_dst_seq_id, recv_msgs[idx++].data(), sizeof(meta->cp_dst_seq_id));
|
||||
std::memcpy(&meta->cp_p0, recv_msgs[idx++].data(), sizeof(meta->cp_p0));
|
||||
std::memcpy(&meta->cp_p1, recv_msgs[idx++].data(), sizeof(meta->cp_p1));
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (cmd == "kv_seq_div" && recv_msgs.size() == 5) {
|
||||
meta->kv_seq_div = true;
|
||||
std::memcpy(&meta->div_seq_id, recv_msgs[idx++].data(), sizeof(meta->div_seq_id));
|
||||
std::memcpy(&meta->div_p0, recv_msgs[idx++].data(), sizeof(meta->div_p0));
|
||||
std::memcpy(&meta->div_p1, recv_msgs[idx++].data(), sizeof(meta->div_p1));
|
||||
std::memcpy(&meta->div_factor, recv_msgs[idx++].data(), sizeof(meta->div_factor));
|
||||
return 0;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < recv_msgs.size(); i += 2) {
|
||||
std::string key = recv_msgs[i].to_string();
|
||||
zmq::message_t & data_msg = recv_msgs[i + 1];
|
||||
|
@ -17805,6 +17902,11 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
|||
GGML_ASSERT(data_msg.size() == sizeof(meta->n_tokens));
|
||||
std::memcpy(&(meta->n_tokens), data_msg.data(), sizeof(meta->n_tokens));
|
||||
}
|
||||
|
||||
if (key == "pos") {
|
||||
meta->pos = (llama_pos *) malloc(meta->n_ctx * sizeof(llama_pos));
|
||||
std::memcpy(meta->pos, data_msg.data(), meta->n_ctx * sizeof(llama_pos));
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
@ -18069,15 +18171,66 @@ static int llama_decode_internal(
|
|||
}
|
||||
|
||||
sync_meta meta;
|
||||
meta.n_ctx = cparams.n_ctx;
|
||||
bool is_last_dev = (my_rank == n_world - 1);
|
||||
|
||||
if (my_rank != 0) {
|
||||
if (llama_recv_meta(*lctx.recv_socket, &meta) == -1) {
|
||||
return -1;
|
||||
}
|
||||
batch_all.n_tokens = meta.n_tokens;
|
||||
|
||||
if (meta.n_tokens > 0) {
|
||||
batch_all.n_tokens = meta.n_tokens;
|
||||
if (meta.pos != nullptr) {
|
||||
batch_all.pos = (llama_pos *) malloc(cparams.n_ctx * sizeof(llama_pos));
|
||||
std::memcpy(batch_all.pos, meta.pos, cparams.n_ctx * sizeof(llama_pos));
|
||||
}
|
||||
}
|
||||
|
||||
if (kv_cache_op(meta.clear_kv_cache,
|
||||
[&]{ llama_kv_cache_clear (&lctx); },
|
||||
[&]{ llama_send_kv_cache_clear (&lctx); },
|
||||
is_last_dev)) {
|
||||
LLAMA_LOG_INFO("%s: received signal kv_cache_clear\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (kv_cache_op(meta.kv_seq_rm,
|
||||
[&]{ llama_kv_cache_seq_rm (&lctx, meta.rm_seq_id, meta.rm_p0, meta.rm_p1); },
|
||||
[&]{ llama_send_kv_cache_seq_rm (&lctx, meta.rm_seq_id, meta.rm_p0, meta.rm_p1); },
|
||||
is_last_dev)) {
|
||||
LLAMA_LOG_INFO("%s: received signal kv_cache_seq_rm\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (kv_cache_op(meta.kv_seq_add,
|
||||
[&]{ llama_kv_cache_seq_add (&lctx, meta.add_seq_id, meta.add_p0, meta.add_p1, meta.add_delta); },
|
||||
[&]{ llama_send_kv_cache_seq_add(&lctx, meta.add_seq_id, meta.add_p0, meta.add_p1, meta.add_delta); },
|
||||
is_last_dev)) {
|
||||
LLAMA_LOG_INFO("%s: received signal kv_cache_seq_add\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (kv_cache_op(meta.kv_seq_cp,
|
||||
[&]{ llama_kv_cache_seq_cp (&lctx, meta.cp_src_seq_id, meta.cp_dst_seq_id, meta.cp_p0, meta.cp_p1); },
|
||||
[&]{ llama_send_kv_cache_seq_cp (&lctx, meta.cp_src_seq_id, meta.cp_dst_seq_id, meta.cp_p0, meta.cp_p1); },
|
||||
is_last_dev)) {
|
||||
LLAMA_LOG_INFO("%s: received signal kv_cache_seq_cp\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (kv_cache_op(meta.kv_seq_div,
|
||||
[&]{ llama_kv_cache_seq_div (&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); },
|
||||
[&]{ llama_send_kv_cache_seq_div(&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); },
|
||||
is_last_dev)) {
|
||||
LLAMA_LOG_INFO("%s: received signal kv_cache_seq_div\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
if (my_rank != n_world - 1) {
|
||||
if (!is_last_dev) {
|
||||
meta.n_tokens = batch_all.n_tokens;
|
||||
meta.pos = batch_all.pos;
|
||||
llama_send_meta(*lctx.send_socket, &meta);
|
||||
}
|
||||
|
||||
|
@ -18803,22 +18956,20 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
|||
|
||||
// apply K-shift if needed
|
||||
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
|
||||
throw std::runtime_error("shift not supported\n");
|
||||
|
||||
if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA
|
||||
GGML_ABORT("Deepseek2 does not support K-shift");
|
||||
}
|
||||
|
||||
{
|
||||
ggml_backend_sched_reset(lctx.sched.at(0)); // todo.
|
||||
for (size_t i = 0; i < lctx.sched.size(); ++i) {
|
||||
ggml_backend_sched_reset(lctx.sched[i]);
|
||||
|
||||
ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
|
||||
|
||||
ggml_backend_sched_alloc_graph(lctx.sched.at(0), gf); // todo.
|
||||
ggml_backend_sched_alloc_graph(lctx.sched[i], gf);
|
||||
|
||||
llama_set_k_shift(lctx);
|
||||
|
||||
llama_graph_compute(lctx, gf, lctx.sched.at(0), lctx.cparams.n_threads, lctx.threadpool); // todo.
|
||||
llama_graph_compute(lctx, gf, lctx.sched[i], lctx.cparams.n_threads, lctx.threadpool);
|
||||
|
||||
need_reserve = true;
|
||||
}
|
||||
|
@ -18845,8 +18996,6 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
|||
|
||||
// reserve a worst case graph again
|
||||
if (need_reserve) {
|
||||
throw std::runtime_error("reserve not supported\n");
|
||||
|
||||
// TODO: extract to a function
|
||||
// build worst-case graph
|
||||
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
||||
|
@ -18854,13 +19003,11 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
|||
llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
std::vector<ggml_cgraph *> gf = llama_build_graph(lctx, ubatch, true);
|
||||
|
||||
// initialize scheduler with the worst-case graph
|
||||
ggml_backend_sched_reset(lctx.sched[0]); // todo.
|
||||
GGML_ASSERT(lctx.sched.size() == gf.size());
|
||||
|
||||
bool ok = true;
|
||||
GGML_ASSERT(lctx.sched.size() == gf.size());
|
||||
for (size_t i = 0; i < gf.size(); ++i) {
|
||||
ggml_backend_sched_reset(lctx.sched[i]);
|
||||
ok = ok & ggml_backend_sched_reserve(lctx.sched[i], gf[i]);
|
||||
}
|
||||
if (!ok) {
|
||||
|
@ -20201,6 +20348,8 @@ void llama_init_sockets(struct llama_context * ctx, uint32_t n_world, uint32_t m
|
|||
LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what());
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
}
|
||||
|
||||
int llama_gather_device_info(struct llama_context * ctx, struct device_info * dev_info_set) {
|
||||
|
@ -20264,36 +20413,47 @@ int llama_send_device_info(struct llama_context * ctx, struct device_info * dev_
|
|||
|
||||
int llama_bcast_startup_args(llama_context * ctx, uint32_t rank, startup_args * args) {
|
||||
int32_t n_world = ctx->cparams.n_world;
|
||||
if (n_world == 1) {
|
||||
return 0;
|
||||
}
|
||||
GGML_ASSERT(n_world > 0);
|
||||
GGML_ASSERT(ctx != nullptr && ctx->send_socket != nullptr);
|
||||
|
||||
if (rank == 0){
|
||||
// send
|
||||
try {
|
||||
std::vector<zmq::message_t> send_msgs;
|
||||
|
||||
send_msgs.emplace_back("should_profile", strlen("should_profile"));
|
||||
send_msgs.emplace_back(&args->should_profile, sizeof(args->should_profile));
|
||||
|
||||
send_msgs.emplace_back("n_ctx", strlen("n_ctx"));
|
||||
send_msgs.emplace_back(&args->n_ctx, sizeof(args->n_ctx));
|
||||
|
||||
zmq::send_multipart(*ctx->send_socket, send_msgs);
|
||||
} catch (const zmq::error_t& e) {
|
||||
LLAMA_LOG_INFO("Failed to send data: %s\n", e.what());
|
||||
return -1;
|
||||
}
|
||||
}else {
|
||||
} else {
|
||||
// receive
|
||||
std::vector<zmq::message_t> recv_msgs;
|
||||
if (!zmq::recv_multipart(*ctx->recv_socket, std::back_inserter(recv_msgs))) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
GGML_ASSERT(recv_msgs[0].to_string() == "should_profile");
|
||||
GGML_ASSERT(recv_msgs[1].size() == sizeof(bool));
|
||||
bool should_profile = *static_cast<bool*>(recv_msgs[1].data());
|
||||
args->should_profile = should_profile;
|
||||
|
||||
GGML_ASSERT(recv_msgs[2].to_string() == "n_ctx");
|
||||
GGML_ASSERT(recv_msgs[3].size() == sizeof(uint32_t));
|
||||
uint32_t n_ctx = *static_cast<uint32_t*>(recv_msgs[3].data());
|
||||
args->n_ctx = n_ctx;
|
||||
|
||||
if ((int)rank != (int)n_world - 1){
|
||||
// send
|
||||
try {
|
||||
zmq::send_multipart(*ctx->send_socket, recv_msgs);
|
||||
} catch (const zmq::error_t& e) {
|
||||
} catch (const zmq::error_t & e) {
|
||||
LLAMA_LOG_INFO("Failed to send data: %s\n", e.what());
|
||||
return -1;
|
||||
}
|
||||
|
@ -21910,10 +22070,42 @@ void llama_kv_cache_clear(struct llama_context * ctx) {
|
|||
llama_kv_cache_clear(ctx->kv_self);
|
||||
}
|
||||
|
||||
void llama_send_kv_cache_clear(struct llama_context * ctx) {
|
||||
if (ctx->send_socket == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
std::vector<zmq::message_t> send_msgs;
|
||||
const char * cmd = "clear_kv_cache";
|
||||
send_msgs.emplace_back(cmd, strlen(cmd));
|
||||
zmq::send_multipart(*ctx->send_socket, send_msgs);
|
||||
} catch (const zmq::error_t & e) {
|
||||
LLAMA_LOG_INFO("Failed to send KV cache clear signal: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
|
||||
}
|
||||
|
||||
void llama_send_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
if (ctx->send_socket == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
std::vector<zmq::message_t> msgs;
|
||||
msgs.emplace_back("kv_seq_rm", strlen("kv_seq_rm"));
|
||||
msgs.emplace_back(&seq_id, sizeof(seq_id));
|
||||
msgs.emplace_back(&p0, sizeof(p0));
|
||||
msgs.emplace_back(&p1, sizeof(p1));
|
||||
zmq::send_multipart(*ctx->send_socket, msgs);
|
||||
} catch (const zmq::error_t & e) {
|
||||
LLAMA_LOG_WARN("Failed to send kv_seq_rm: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||
if (seq_id_src == seq_id_dst) {
|
||||
return;
|
||||
|
@ -21921,6 +22113,24 @@ void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src,
|
|||
llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
void llama_send_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||
if (ctx->send_socket == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
std::vector<zmq::message_t> msgs;
|
||||
msgs.emplace_back("kv_seq_cp", strlen("kv_seq_cp"));
|
||||
msgs.emplace_back(&seq_id_src, sizeof(seq_id_src));
|
||||
msgs.emplace_back(&seq_id_dst, sizeof(seq_id_dst));
|
||||
msgs.emplace_back(&p0, sizeof(p0));
|
||||
msgs.emplace_back(&p1, sizeof(p1));
|
||||
zmq::send_multipart(*ctx->send_socket, msgs);
|
||||
} catch (const zmq::error_t & e) {
|
||||
LLAMA_LOG_WARN("Failed to send kv_seq_cp: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
|
||||
llama_kv_cache_seq_keep(ctx->kv_self, seq_id);
|
||||
}
|
||||
|
@ -21933,6 +22143,24 @@ void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, lla
|
|||
llama_kv_cache_seq_add(ctx->kv_self, seq_id, p0, p1, delta);
|
||||
}
|
||||
|
||||
void llama_send_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
|
||||
if (ctx->send_socket == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
std::vector<zmq::message_t> msgs;
|
||||
msgs.emplace_back("kv_seq_add", strlen("kv_seq_add"));
|
||||
msgs.emplace_back(&seq_id, sizeof(seq_id));
|
||||
msgs.emplace_back(&p0, sizeof(p0));
|
||||
msgs.emplace_back(&p1, sizeof(p1));
|
||||
msgs.emplace_back(&delta, sizeof(delta));
|
||||
zmq::send_multipart(*ctx->send_socket, msgs);
|
||||
} catch (const zmq::error_t & e) {
|
||||
LLAMA_LOG_WARN("Failed to send kv_seq_add: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
if (d == 1) {
|
||||
return;
|
||||
|
@ -21941,6 +22169,24 @@ void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, lla
|
|||
llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
void llama_send_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
if (ctx->send_socket == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
std::vector<zmq::message_t> msgs;
|
||||
msgs.emplace_back("kv_seq_div", strlen("kv_seq_div"));
|
||||
msgs.emplace_back(&seq_id, sizeof(seq_id));
|
||||
msgs.emplace_back(&p0, sizeof(p0));
|
||||
msgs.emplace_back(&p1, sizeof(p1));
|
||||
msgs.emplace_back(&d, sizeof(d));
|
||||
zmq::send_multipart(*ctx->send_socket, msgs);
|
||||
} catch (const zmq::error_t & e) {
|
||||
LLAMA_LOG_WARN("Failed to send kv_seq_div: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) {
|
||||
return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue