kvcache-ai-ktransformers/csrc/balance_serve/sched/scheduler.cpp

#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO
#define FMT_HEADER_ONLY
#include "nlohmann/json.hpp"
#include "spdlog/spdlog.h"
#include "scheduler.h"
#include <optional>
#include "arithmetic.hpp"
#include "atomic_ptr_with_flags.hpp"
#include "easy_format.hpp"
#include "metrics.h"
#include "mpsc.hpp"
#include "timer.hpp"
#include <atomic>
#include <cassert>
#include <future>
#include <memory>
#include <queue>
#include "kvc2.h"
using json = nlohmann::json;
namespace scheduler {
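// Derive runtime settings from the model and hardware configuration:
// validate GPU availability, shard KV heads across devices, check that the
// model parameters fit in GPU memory, size the KV cache page pool, and
// sanity-check the page size.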
void Settings::auto_derive() {
gpu_device_count = gpu_device_id.size();
if (torch::cuda::is_available()) {
size_t gpu_count = torch::cuda::device_count();
SPDLOG_INFO("Number of available GPUs: {}, want {}", gpu_count,
gpu_device_count);
if (gpu_count < gpu_device_count) {
SPDLOG_ERROR("Not enough GPUs available.");
exit(0);
}
for (size_t i = 0; i < gpu_device_count; i++) {
devices.push_back(torch::Device(torch::kCUDA, gpu_device_id[i]));
}
} else {
SPDLOG_ERROR("CUDA is not available on this system.");
exit(1);
}
if (model_settings.num_k_heads % gpu_device_count != 0) {
SPDLOG_ERROR("num_k_heads {} is not divisible by gpu_device_count {}",
model_settings.num_k_heads, gpu_device_count);
assert(false);
}
size_t gpu_memory_available = gpu_memory_size * memory_utilization_percentage;
if (gpu_memory_available * gpu_device_count <
model_settings.params_nbytes()) {
SPDLOG_ERROR("GPU memory size {}G is smaller than {}G",
gpu_memory_available * gpu_device_count / 1e9,
model_settings.params_nbytes() / 1e9);
assert(false);
}
assert(model_settings.k_head_dim % model_settings.num_k_heads == 0);
size_t head_per_gpu = model_settings.num_k_heads / gpu_device_count;
size_t gpu_memory_for_kv_cache =
gpu_memory_available /* - model_settings.params_nbytes() / gpu_device_count */;
SPDLOG_INFO(
"Each GPU Total: {}MiB, Model Params: {}MiB, KVCache: {}MiB, Left: {}MiB",
gpu_memory_size / (1 << 20),
model_settings.params_nbytes() / gpu_device_count / (1 << 20),
gpu_memory_for_kv_cache / (1 << 20),
(gpu_memory_size - gpu_memory_available) / (1 << 20));
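// One KV cache page on a single GPU holds page_size tokens for every layer,
// so it occupies kv_cache_on_cnt * head_per_gpu * k_head_dim *
// bytes_per_kv_cache_element * page_size * layer_count bytes; the page budget
// below is the memory reserved for KV cache divided by that.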
size_t kv_cache_on_cnt = (size_t)(k_cache_on) + (size_t)(v_cache_on);
size_t max_total_kvcache_pages =
gpu_memory_for_kv_cache /
(kv_cache_on_cnt * head_per_gpu * model_settings.k_head_dim *
model_settings.bytes_per_kv_cache_element * page_size *
model_settings.layer_count);
if (total_kvcache_pages.has_value()) {
if (total_kvcache_pages.value() > max_total_kvcache_pages) {
SPDLOG_ERROR(
"total_kvcache_pages {} is larger than max_total_kvcache_pages {}",
total_kvcache_pages.value(), max_total_kvcache_pages);
assert(false);
}
} else {
total_kvcache_pages = max_total_kvcache_pages;
SPDLOG_INFO("total_kvcache_pages is auto derived as {}",
max_total_kvcache_pages);
}
if (page_size % 256 != 0) {
SPDLOG_ERROR("page_size {} is not divisible by 256", page_size);
assert(false);
}
if (page_size < 256) {
SPDLOG_ERROR("page_size {} is smaller than 256", page_size);
assert(false);
}
}
std::string BatchQueryTodo::debug() {
std::string re = "BatchQueryTodo: ";
re += "QueryIDs: ";
for (auto &id : query_ids) {
re += std::to_string(id) + " ";
}
return re;
}
bool BatchQueryTodo::empty() {
return prefill_mini_batches.empty() && decode_mini_batches.empty();
}
struct QueryMaintainer;
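// One inference request. Its status advances in order:
// Received -> Preparing (async KV cache lookup/load) -> Ready -> Prefill ->
// Decode -> Done. plan_position tracks progress as planned into batches,
// while active_position tracks progress actually reported by inference.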
struct Query {
QueryID id;
torch::Tensor query_token;
TokenLength prompt_length;
TokenLength no_kvcache_from;
TokenLength estimated_length;
SampleOptions sample_options;
UserID user_id;
std::optional<int> SLO_TTFT_ms;
std::optional<int> SLO_TBT_ms;
std::vector<std::vector<int>> stop_criteria;
// status
// Query status changes in this order
enum Status { Received, Preparing, Ready, Prefill, Decode, Done };
Status plan_status = Received;
TokenLength active_position; // first position not yet covered by kvcache (as executed)
TokenLength plan_position; // first position not yet covered by kvcache (as planned)
size_t prepare_try_count = 0;
std::shared_ptr<kvc2::DoubleCacheHandleInterface> kvc2_handle = nullptr;
// derived from kvc2_handle
torch::Tensor block_index; // block indexes
struct QueryContext {
ModelName model_name;
QuantType quant_type;
kvc2::KVC2Interface *kvc2_interface;
QueryMaintainer *query_maintainer;
Metrics *met;
} ctx;
void after_load(bool ok);
void to_status(Status to);
void export_metrics() {
ctx.met->query_count(status_to_string(plan_status))->Increment(1);
}
Query(QueryID id, QueryAdd query_add, QueryContext context)
: id(id), prompt_length(query_add.query_length), no_kvcache_from(0),
estimated_length(query_add.estimated_length),
sample_options(query_add.sample_options), user_id(query_add.user_id),
SLO_TTFT_ms(query_add.SLO_TTFT_ms), SLO_TBT_ms(query_add.SLO_TBT_ms),
stop_criteria(query_add.stop_criteria), ctx(context) {
std::vector<int64_t> shape = {int64_t(query_add.estimated_length)};
query_token =
torch::zeros(shape, torch::TensorOptions().dtype(torch::kInt32));
assert(query_token.is_contiguous());
if (query_token.is_contiguous() == false) {
SPDLOG_ERROR("Query Token must be contiguous!");
exit(1);
}
memcpy(query_token.data_ptr(), query_add.query_token.data(),
query_add.query_length * sizeof(Token));
no_kvcache_from = 0; // maybe match prefix later
export_metrics();
}
Token &token_at(size_t idx) {
return reinterpret_cast<Token *>(query_token.data_ptr())[idx];
}
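// Apply a per-step result from the inference side: advance active_position,
// feed the now-computed tokens back into the KV cache handle, record the
// generated token (a prefill step only yields one once the whole prompt has
// been processed), and finish the query when decoding is done or the
// estimated length is reached.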
void absorb_update(const QueryUpdate &update) {
SPDLOG_DEBUG("{}", update.debug());
active_position = update.active_position;
kvc2_handle->append_tokens(&token_at(0),
active_position); // active_position is the length minus 1
if (update.is_prefill) {
if (active_position == prompt_length) {
token_at(active_position) = update.generated_token;
ctx.met->generated_tokens->Increment(1);
}
} else {
token_at(active_position) = update.generated_token;
ctx.met->generated_tokens->Increment(1);
}
if (update.decode_done || active_position == estimated_length - 1) {
to_status(Done);
}
}
void absorb_prefill_task(const PrefillTask &task) {
auto &[id, start, length] = task;
this->plan_position = start + length;
if (this->plan_position == prompt_length) {
to_status(Decode);
}
}
void absorb_decode_task([[maybe_unused]] const QueryID &task) {
this->plan_position += 1;
}
PrefillTask get_prefill_task(size_t prefill_length) {
if (prefill_length + plan_position > prompt_length) {
prefill_length = prompt_length - plan_position;
}
return {id, plan_position, prefill_length};
}
static std::string status_to_string(Status status) {
switch (status) {
case Received:
return "Received";
case Preparing:
return "Preparing";
case Ready:
return "Ready";
case Prefill:
return "Prefill";
case Decode:
return "Decode";
case Done:
return "Done";
}
assert(false);
return "Unknown";
}
void debug() {
std::string status_string = status_to_string(plan_status);
SPDLOG_DEBUG("Query {}, prompt_length {}, estimated_length {}, plan status "
"{}, plan position {} "
"active position {}",
id, prompt_length, estimated_length, status_string,
plan_position, active_position);
}
};
std::string QueryUpdate::debug() const {
return fmt::format("Query {}, ok {}, is_prefill {}, done {}, active_position "
"{}, gen token {}",
id, ok, is_prefill, decode_done, active_position,
generated_token);
}
using Q = std::shared_ptr<Query>;
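// Owns the kvc2 KV cache backend: builds the GPU page cache configuration
// from the scheduler settings, registers this model's config.json in the
// shared model_configs.json, creates the kvc2 instance, and keeps the
// k/v cache tensors it hands out.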
struct KVC2_Maintainer {
Settings settings;
std::vector<torch::Tensor> k_cache;
std::vector<torch::Tensor> v_cache;
std::shared_ptr<kvc2::KVC2Interface> kvc2_interface;
KVC2_Maintainer(Settings settings) : settings(settings) {
// SPDLOG_WARN("Creating KVC2 Instance {}", settings.kvc2_root_path);
assert(settings.kvc2_root_path.size() > 0);
// SPDLOG_WARN("Sizeof KVC2Config {} upper", sizeof(kvc2::KVC2Config));
kvc2::GPUPageCacheConfig gpu_cache_config{
.gpu_only = settings.gpu_only,
.gpu_devices_id = settings.gpu_device_id,
.layer_count = settings.model_settings.layer_count,
.total_kvcache_pages = settings.total_kvcache_pages.value(),
.num_token_per_page = settings.page_size,
.num_k_heads = settings.model_settings.num_k_heads,
.k_head_dim = settings.use_self_defined_head_dim
? settings.self_defined_head_dim
: settings.model_settings.k_head_dim,
.full_kv_cache_on_each_gpu = settings.full_kv_cache_on_each_gpu,
.k_cache_on = settings.k_cache_on,
.v_cache_on = settings.v_cache_on,
.tensor_type = torch::kBFloat16,
};
auto model_configs_path =
std::filesystem::path(settings.kvc2_config_path) / "model_configs.json";
load_model_configs(model_configs_path);
auto my_model_config = ModelConfig();
my_model_config.load_from(
std::filesystem::path(settings.model_settings.model_path) /
"config.json");
model_configs[settings.model_name] = my_model_config;
dump_model_configs(model_configs_path);
kvc2::KVC2Config kvc2_config = {
.k_cache_on = settings.k_cache_on,
.v_cache_on = settings.v_cache_on,
.gpu_only = settings.gpu_only,
.load_from_disk = settings.load_from_disk,
.save_to_disk = settings.save_to_disk,
.path = settings.kvc2_root_path,
.config_path = settings.kvc2_config_path,
.num_token_per_page = settings.page_size,
.memory_pool_size = size_t(settings.memory_pool_size_GB * 1e9),
.evict_count = settings.evict_count,
.gpu_cache_config = gpu_cache_config,
.metrics_port = settings.kvc2_metrics_port,
};
kvc2_interface = kvc2::create_kvc2(kvc2_config);
if (settings.load_from_disk)
kvc2_interface->load();
SPDLOG_DEBUG("KVC2 created ok");
auto [k_cache, v_cache] = kvc2_interface->get_kvcache();
this->k_cache = k_cache;
this->v_cache = v_cache;
}
};
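// Events consumed by the scheduler's single event-loop thread (see run()):
// query admission, per-batch result updates, batch hand-off notifications,
// KV cache preparation requests/results, query status changes, and explicit
// (re)scheduling. Producers push them into an MPSC queue.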
using EventAddQuery = std::pair<QueryAdd, std::promise<QueryID> *>;
using EventUpdateQuery = BatchQueryUpdate;
using EventTakenBatch = std::shared_ptr<BatchQueryTodo>;
struct EventPrepare {
QueryID query_id;
bool first_try;
};
struct EventPrepared {
QueryID query_id;
bool ok;
};
struct EventQueryStatus {
QueryID query_id;
Query::Status now_status;
};
struct EventSchedule {};
using Event =
std::variant<EventAddQuery, EventUpdateQuery, EventTakenBatch, EventPrepare,
EventPrepared, EventQueryStatus, EventSchedule>;
template <typename T> std::string event_name(const T &event);
template <> std::string event_name(const EventAddQuery &) {
return "EventAddQuery";
}
template <> std::string event_name(const EventUpdateQuery &) {
return "EventUpdateQuery";
}
template <> std::string event_name(const EventTakenBatch &) {
return "EventTakenBatch";
}
template <> std::string event_name(const EventPrepare &) {
return "EventPrepare";
}
template <> std::string event_name(const EventPrepared &) {
return "EventPrepared";
}
template <> std::string event_name(const EventQueryStatus &) {
return "EventQueryStatus";
}
template <> std::string event_name(const EventSchedule &) {
return "EventSchedule";
}
// event_name for the Event variant, dispatched via std::visit
std::string event_name(const Event &event) {
return std::visit([](const auto &e) { return event_name(e); }, event);
}
static_assert(std::is_copy_constructible<Event>::value);
static_assert(std::is_move_constructible<Event>::value);
struct QueryMaintainer : public Scheduler {
// only get access by event loop
Settings settings;
QueryID query_id_counter = NoQueryID + 1;
std::map<QueryID, Q> query_map;
std::shared_ptr<KVC2_Maintainer> kvc2_maintainer;
std::shared_ptr<Metrics> met;
// multi-thread visit
std::atomic_bool stop_flag = false;
// TODO consider correctness of event loop
MPSCQueueConsumerLock<Event> event_loop_queue;
// std::binary_semaphore batch_ready{0};
AtomicPtrWithFlag<BatchQueryTodo> next_batch;
QueryMaintainer() = default;
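// Assemble the next batch from the currently active queries: each prefilling
// query gets a chunked prefill task (the chunk is doubled when at most one
// query is prefilling and the decode load leaves spare capacity), decoding
// queries are packed into at most two decode mini-batches of up to
// max_batch_size - 1 entries, and per-query tokens, block indexes and
// sampling options are attached.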
void gen_batch_query_todo(BatchQueryTodo *re, const std::set<Q> &queries) {
std::vector<std::vector<QueryID>> d_batch(2);
size_t last_decode_batch = 0;
size_t prefill_num = 0;
size_t decode_num = 0;
size_t prefill_length = 0;
for (auto &q : queries) {
if (q->plan_status == Query::Prefill) {
prefill_num += 1;
}
if (q->plan_status == Query::Decode) {
decode_num += 1;
}
}
if (prefill_num >= 2 ||
(prefill_num == 1 && settings.max_batch_size - 2 < decode_num)) {
prefill_length = settings.recommended_chunk_prefill_token_count;
} else {
prefill_length = settings.recommended_chunk_prefill_token_count * 2;
}
for (auto &q : queries) {
re->query_ids.push_back(q->id);
re->query_tokens.push_back(q->query_token);
re->query_lengths.push_back(q->prompt_length);
if (q->plan_status == Query::Prefill) {
re->prefill_mini_batches.push_back(q->get_prefill_task(prefill_length));
assert(re->prefill_mini_batches.size() <= 2);
}
if (q->plan_status == Query::Decode) {
d_batch[last_decode_batch].push_back(q->id);
// last_decode_batch = 1 - last_decode_batch;
if (d_batch[last_decode_batch].size() == settings.max_batch_size - 1) {
last_decode_batch += 1;
assert(last_decode_batch < 2);
}
}
re->block_indexes.push_back(q->block_index);
re->sample_options.push_back(q->sample_options);
re->stop_criteria.push_back(q->stop_criteria);
}
re->attn_masks = std::nullopt;
re->rope_ranges = std::nullopt;
for (auto &b : d_batch) {
if (b.empty())
continue;
re->decode_mini_batches.push_back(b);
}
met->batch_count("Generated")->Increment(1);
}
// Interface
void init(Settings settings) override {
SPDLOG_INFO("\nScheduler Settings:\n"
" model_name: {}\n"
" quant_type: {}\n"
" model_path: {}\n"
" params_count: {}\n"
" layer_count: {}\n"
" num_k_heads: {}\n"
" k_head_dim: {}\n"
" bytes_per_params: {}\n"
" bytes_per_kv_cache_element: {}\n"
" page_size: {}\n"
" gpu_device_id: {}\n"
" gpu_memory_size: {}\n"
" memory_utilization_percentage: {}\n"
" max_batch_size: {}\n"
" recommended_chunk_prefill_token_count: {}\n"
" sched_metrics_port: {}\n"
" kvc2_config_path: {}\n"
" kvc2_root_path: {}\n"
" memory_pool_size_GB: {}\n"
" evict_count: {}\n"
" kvc2_metrics_port: {}\n"
" load_from_disk: {}\n"
" save_to_disk: {}\n"
" strategy_name: {}\n"
" gpu_device_count: {}\n",
settings.model_name, settings.quant_type,
settings.model_settings.model_path,
settings.model_settings.params_count,
settings.model_settings.layer_count,
settings.model_settings.num_k_heads,
settings.model_settings.k_head_dim,
settings.model_settings.bytes_per_params,
settings.model_settings.bytes_per_kv_cache_element,
settings.page_size, format_vector(settings.gpu_device_id),
readable_number(settings.gpu_memory_size),
settings.memory_utilization_percentage, settings.max_batch_size,
settings.recommended_chunk_prefill_token_count,
settings.sched_metrics_port, settings.kvc2_config_path,
settings.kvc2_root_path, settings.memory_pool_size_GB,
settings.evict_count, settings.kvc2_metrics_port,
settings.load_from_disk, settings.save_to_disk,
settings.strategy_name, settings.gpu_device_count);
this->settings = settings;
kvc2_maintainer =
std::shared_ptr<KVC2_Maintainer>(new KVC2_Maintainer(settings));
MetricsConfig met_conf = {
.endpoint = "0.0.0.0:" + std::to_string(settings.sched_metrics_port),
.model_name = settings.model_name,
.gpu_count = settings.gpu_device_count,
};
SPDLOG_INFO("Creating scheduler metrics exporter on {}", met_conf.endpoint);
met = std::make_shared<Metrics>(met_conf);
met->fn_every_sec = [](Metrics *met) {
auto generated_tokens = met->generated_tokens->Collect().counter.value;
SPDLOG_INFO("Last Sec Generated Tokens {}", generated_tokens);
};
}
Query::QueryContext get_query_context() {
return Query::QueryContext{
.model_name = settings.model_name,
.quant_type = settings.quant_type,
.kvc2_interface = kvc2_maintainer->kvc2_interface.get(),
.query_maintainer = this,
.met = met.get(),
};
}
QueryID add_query(QueryAdd query_add) override {
std::promise<QueryID> p;
event_loop_queue.enqueue(EventAddQuery(query_add, &p));
return p.get_future().get();
}
void cancel_query(QueryID id) override {
SPDLOG_INFO("Cancel Query");
SPDLOG_INFO("sched:{} Cancel Query", fmt::ptr(this));
auto it = query_map.find(id);
if (it == query_map.end()) {
SPDLOG_ERROR("Query {} is not found", id);
return;
}
query_map.erase(it);
}
// Report the results of the last batch and fetch the next one.
// In most cases the next batch is already prepared; if not, busy-wait
// until the event loop publishes it.
std::shared_ptr<BatchQueryTodo>
update_last_batch(BatchQueryUpdate updates) override {
event_loop_queue.enqueue(updates);
// Busy Wait
while (true) {
auto [ptr, is_new] = next_batch.touch_load();
// SPDLOG_INFO("ptr {} is_new {}", fmt::ptr(ptr), is_new);
if (is_new) {
// SPDLOG_DEBUG("New Batch {}", fmt::ptr(ptr));
auto re = std::shared_ptr<BatchQueryTodo>(ptr);
event_loop_queue.enqueue(re);
return re;
} else {
// // here to busy wait
// SPDLOG_INFO("Not New");
// using namespace std::chrono_literals;
// std::this_thread::sleep_for(1s);
}
}
}
InferenceContext get_inference_context() override {
InferenceContext re;
re.k_cache = kvc2_maintainer->k_cache;
re.v_cache = kvc2_maintainer->v_cache;
// kvc2_maintainer->k_cache[0][0][0][0][0][0] = 42; // test whether we pass
// this to inference loop
return re;
}
virtual void strategy_add_query(Q new_query) = 0;
virtual void strategy_update_query(const EventUpdateQuery &update) = 0;
virtual void strategy_taken_batch(const EventTakenBatch &batch) = 0;
virtual void strategy_prepare(const EventPrepare &prepare) = 0;
virtual void strategy_prepared(const EventPrepared &prepared) = 0;
virtual void strategy_query_status(const EventQueryStatus &query_status) = 0;
virtual void strategy_schedule(const EventSchedule &event,
BatchQueryTodo *new_batch) = 0;
void tackle_event(EventAddQuery &event) {
auto &query_add = event.first;
QueryID id = query_id_counter;
event.second->set_value(id);
query_id_counter += 1;
Q new_query(new Query(id, query_add, get_query_context()));
query_map[id] = new_query;
SPDLOG_INFO("New Query {} is added", id);
strategy_add_query(new_query);
}
void tackle_event(const EventUpdateQuery &update) {
// SPDLOG_INFO("Tackle Update Query");
for (auto &u : update) {
if (u.ok == false) {
SPDLOG_ERROR("Query {} is not exectued OK", u.id);
exit(1);
}
auto q = query_map[u.id];
if (q->plan_status == Query::Status::Prefill ||
q->plan_status == Query::Status::Decode) {
q->absorb_update(u);
} else {
SPDLOG_DEBUG(
"Query {} is not in Prefill or Decode status, do not update it",
u.id);
}
}
strategy_update_query(update);
}
void tackle_event(const EventTakenBatch &batch) {
met->batch_count("Taken")->Increment(1);
for (auto &task : batch->prefill_mini_batches) {
auto [id, s, l] = task;
if (l == 0)
continue;
query_map.at(id)->absorb_prefill_task(task);
}
for (auto &mini_batch : batch->decode_mini_batches) {
for (auto &id : mini_batch) {
query_map.at(id)->absorb_decode_task(id);
}
}
strategy_taken_batch(batch);
}
void tackle_event(const EventPrepare &event) { strategy_prepare(event); }
void tackle_event(const EventPrepared &event) { strategy_prepared(event); }
void tackle_event(const EventQueryStatus &event) {
strategy_query_status(event);
}
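// Run the strategy to plan a new batch and publish it through the
// next_batch pointer-with-flag. If the previously published batch was never
// picked up by update_last_batch (its "new" flag is still set), it is freed
// here instead.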
void tackle_event(const EventSchedule &event) {
// SPDLOG_INFO("Tackle Schedule Event");
HistogramTimerWrapper t(met->schedule_time);
BatchQueryTodo *new_batch = new BatchQueryTodo;
strategy_schedule(event, new_batch);
// if (new_batch->query_ids.empty()) {
// SPDLOG_INFO("Nothing todo");
// delete new_batch;
// return;
// }
auto [old_batch, flag] = next_batch.exchange(new_batch, true);
if (new_batch->empty() == false) {
SPDLOG_DEBUG("set new batch {}", fmt::ptr(new_batch));
}
if (flag) {
SPDLOG_INFO("Batch {} is not consumed", fmt::ptr(old_batch));
delete old_batch;
}
}
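// Event loop: dequeue one event at a time and dispatch it by type. Whenever
// the queue drains after handling a non-schedule event, an EventSchedule is
// enqueued so that a fresh batch is always being planned for the next
// update_last_batch call.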
void run() override {
std::thread([this]() {
SPDLOG_WARN("Starting Scheduler Event Loop");
while (stop_flag.load() == false) {
auto event = event_loop_queue.dequeue();
met->event_count(event_name(event))->Increment(1);
std::visit(
[this](auto event) {
using T = std::decay_t<decltype(event)>;
// SPDLOG_INFO("Event Loop: {}", typeid(T).name());
if constexpr (std::is_same_v<T, EventAddQuery>) {
tackle_event(event);
} else if constexpr (std::is_same_v<T, EventUpdateQuery>) {
tackle_event(event);
} else if constexpr (std::is_same_v<T, EventTakenBatch>) {
tackle_event(event);
} else if constexpr (std::is_same_v<T, EventPrepare>) {
tackle_event(event);
} else if constexpr (std::is_same_v<T, EventPrepared>) {
tackle_event(event);
} else if constexpr (std::is_same_v<T, EventQueryStatus>) {
tackle_event(event);
} else if constexpr (std::is_same_v<T, EventSchedule>) {
tackle_event(event);
} else {
SPDLOG_ERROR("Should not be here");
assert(false);
}
},
event);
if (event_loop_queue.size() == 0 &&
std::holds_alternative<EventSchedule>(event) == false) {
// if this is not a schedule event, we need to schedule one
event_loop_queue.enqueue(EventSchedule());
}
}
}).detach();
}
void stop() override { stop_flag.store(true); }
~QueryMaintainer() {
kvc2_maintainer->kvc2_interface->save();
stop();
}
};
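// Perform the side effects of a status transition. Preparing kicks off an
// asynchronous KV cache lookup/load via kvc2; Prefill starts planning from
// the prefix-matched length, stepping back one token when the whole prompt
// was matched (presumably so at least one prompt token is still run to
// produce the first logits); Done drops the kvc2 handle and notifies the
// event loop.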
void Query::to_status(Status to) {
SPDLOG_DEBUG("Calling to status query {}, to {}", id, status_to_string(to));
switch (to) {
case Received:
assert(false);
break;
case Preparing:
SPDLOG_INFO("Preparing Query {} {}", id,
prepare_try_count > 0
? (std::to_string(prepare_try_count) + " Try")
: "");
prepare_try_count += 1;
ctx.kvc2_interface->lookup_to_gpu_async(
ctx.model_name, ctx.quant_type,
static_cast<kvc2::Token *>(query_token.data_ptr()), prompt_length,
estimated_length,
[this](std::shared_ptr<kvc2::DoubleCacheHandleInterface> handle) {
if (handle == nullptr) {
SPDLOG_INFO("Get handle from kvc2 Failed.");
this->after_load(false);
} else {
SPDLOG_INFO("Get handle from kvc2 Success.");
this->kvc2_handle = handle;
this->to_status(Ready);
this->after_load(true);
}
});
break;
case Ready:
SPDLOG_INFO("Ready Query {}", id);
break;
case Prefill:
SPDLOG_INFO("Prefilling Query {}", id);
// assert(plan_status == Received);
plan_position = kvc2_handle->matched_length();
if (prompt_length - plan_position == 0) {
assert(prompt_length > 0);
plan_position -= 1;
}
break;
case Decode:
SPDLOG_INFO("Decoding Query {}", id);
// assert(plan_status == Prefill);
break;
case Done:
SPDLOG_INFO("Finish Query {}", id);
kvc2_handle = nullptr;
ctx.query_maintainer->event_loop_queue.enqueue(EventQueryStatus{
.query_id = id,
.now_status = to,
});
// assert(plan_status == Decode);
break;
}
plan_status = to;
export_metrics();
}
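// Completion callback for the async KV cache lookup. On success, copy the
// GPU block indices into a contiguous int32 tensor (the block table handed
// to inference), remember how much of the prompt was prefix-matched, and
// notify the event loop with EventPrepared. On failure, re-enqueue an
// EventPrepare so the preparation is retried later.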
void Query::after_load(bool ok) {
if (ok) {
size_t page_count =
div_up(estimated_length, ctx.query_maintainer->settings.page_size);
std::vector<int64_t> shape;
shape.push_back(page_count);
block_index =
torch::zeros(shape, torch::TensorOptions().dtype(torch::kInt32))
.contiguous();
auto ptr = reinterpret_cast<int32_t *>(block_index.data_ptr());
auto vec_idx = kvc2_handle->get_gpu_block_idx();
for (size_t i = 0; i < vec_idx.size(); i++) {
ptr[i] = vec_idx[i];
}
no_kvcache_from = kvc2_handle->matched_length();
}
if (ok) {
ctx.query_maintainer->event_loop_queue.enqueue(EventPrepared{
.query_id = id,
.ok = ok,
});
} else {
ctx.query_maintainer->event_loop_queue.enqueue(EventPrepare{
.query_id = id,
.first_try = false,
});
}
}
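// First-come-first-served strategy with at most one prefilling query at a
// time. Queries are prepared (KV cache lookup/load) one by one; a prepared
// query is promoted to Prefill only when nothing else is prefilling and the
// active set is below max_batch_size. A preparation that fails after the
// first try is retried only once some query finishes, presumably to wait
// for GPU cache pages to be released.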
struct FCFS_single_prefill : public QueryMaintainer {
std::queue<Q> queue;
std::queue<Q> ready_queue;
bool has_query_preparing = false;
std::optional<EventPrepare> wait_done_prepare = std::nullopt;
std::set<Q> active_query; // on going queries for LLMs
// interface all these are executed in a single thread
void strategy_add_query(Q new_query) override {
queue.push(new_query);
if (has_query_preparing == false) {
has_query_preparing = true;
auto next_q = queue.front();
queue.pop();
event_loop_queue.enqueue(EventPrepare{next_q->id, true});
}
}
void strategy_update_query(const EventUpdateQuery &update) override {
for (auto u : update) {
auto &q = query_map[u.id];
if (q->plan_status == Query::Done) {
active_query.erase(q);
}
}
}
void strategy_taken_batch(const EventTakenBatch &batch) override {
for (auto &q : batch->query_ids) {
if (query_map[q]->plan_status != Query::Done) {
active_query.insert(query_map[q]);
}
}
}
void strategy_prepare(const EventPrepare &prepare) override {
if (prepare.first_try) {
auto &q = query_map[prepare.query_id];
q->to_status(Query::Preparing);
} else {
assert(wait_done_prepare.has_value() == false);
wait_done_prepare = prepare;
wait_done_prepare->first_try = true;
}
}
void strategy_prepared(const EventPrepared &prepared) override {
assert(prepared.ok);
ready_queue.push(query_map[prepared.query_id]);
if (queue.empty() == false) {
auto next_q_prepare = queue.front();
queue.pop();
event_loop_queue.enqueue(EventPrepare{next_q_prepare->id, true});
} else {
has_query_preparing = false;
}
}
void strategy_query_status(const EventQueryStatus &query_status) override {
if (query_status.now_status == Query::Done) {
if (wait_done_prepare.has_value()) {
event_loop_queue.enqueue(wait_done_prepare.value());
wait_done_prepare = std::nullopt;
}
}
}
void strategy_schedule([[maybe_unused]] const EventSchedule &event,
BatchQueryTodo *new_batch) override {
bool have_prefill = false;
for (auto &q : active_query) {
if (q->plan_status == Query::Prefill) {
have_prefill = true;
}
}
if (have_prefill == false && ready_queue.empty() == false &&
active_query.size() < settings.max_batch_size) {
auto &next_q = ready_queue.front();
ready_queue.pop();
SPDLOG_INFO("Active query {}", next_q->id);
active_query.insert(next_q);
next_q->to_status(Query::Prefill);
}
if (active_query.empty() == false)
SPDLOG_INFO("Active Query Size {}", active_query.size());
for (auto &q : active_query) {
q->debug();
}
gen_batch_query_todo(new_batch, active_query);
}
};
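// FCFS variant that allows up to two queries to prefill concurrently,
// matching the two prefill mini-batches gen_batch_query_todo can emit.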
struct FCFS : public FCFS_single_prefill {
void strategy_schedule([[maybe_unused]] const EventSchedule &event,
BatchQueryTodo *new_batch) override {
int prefill_count = 0;
const int max_prefill_count = 2;
for (auto &q : active_query) {
if (q->plan_status == Query::Prefill) {
prefill_count += 1;
}
}
while (prefill_count < max_prefill_count && ready_queue.empty() == false &&
active_query.size() < settings.max_batch_size) {
auto next_q = ready_queue.front();
ready_queue.pop();
SPDLOG_INFO("Active query {}", next_q->id);
active_query.insert(next_q);
next_q->to_status(Query::Prefill);
prefill_count += 1;
}
if (active_query.empty() == false) {
SPDLOG_DEBUG("Active Query Size {}", active_query.size());
}
for (auto &q : active_query) {
q->debug();
}
gen_batch_query_todo(new_batch, active_query);
}
};
std::shared_ptr<Scheduler> create_scheduler(Settings settings) {
spdlog::set_level(spdlog::level::debug);
std::shared_ptr<Scheduler> re;
SPDLOG_INFO("Using Strategy {}", settings.strategy_name);
if (settings.strategy_name == "FCFS-single-prefill") {
re = std::shared_ptr<Scheduler>(new FCFS_single_prefill());
} else if (settings.strategy_name == "FCFS") {
re = std::shared_ptr<Scheduler>(new FCFS());
} else {
SPDLOG_ERROR("Unknown strategy {}", settings.strategy_name);
}
re->init(settings);
return re;
}
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(SampleOptions, temperature, top_p);
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(QueryAdd, query_token, query_length,
estimated_length, sample_options, user_id,
SLO_TTFT_ms, SLO_TBT_ms);
std::string QueryAdd::serialize() {
json j = *this;
return j.dump();
}
QueryAdd QueryAdd::deserialize(const std::string &input) {
json j = json::parse(input);
return j.get<QueryAdd>();
}
}; // namespace scheduler