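// Python bindings (pybind11) for the ktransformers scheduler, exposed as the
// "sched_ext" module. They cover the configuration structs (ModelSettings,
// SampleOptions, Settings), the batch/query data types (BatchQueryTodo,
// QueryUpdate, QueryAdd, InferenceContext), and the Scheduler itself; most of
// the data types also get pickle support so they can be sent across processes.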
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <memory>

#include "scheduler.h"

#include <torch/extension.h>

namespace py = pybind11;
PYBIND11_MODULE(sched_ext, m) {
  py::class_<scheduler::ModelSettings>(m, "ModelSettings")
      .def(py::init<>())
      .def_readwrite("model_path", &scheduler::ModelSettings::model_path)
      .def_readwrite("params_count", &scheduler::ModelSettings::params_count)
      .def_readwrite("layer_count", &scheduler::ModelSettings::layer_count)
      .def_readwrite("num_k_heads", &scheduler::ModelSettings::num_k_heads)
      .def_readwrite("k_head_dim", &scheduler::ModelSettings::k_head_dim)
      .def_readwrite("bytes_per_params", &scheduler::ModelSettings::bytes_per_params)
      .def_readwrite("bytes_per_kv_cache_element", &scheduler::ModelSettings::bytes_per_kv_cache_element)
      .def("params_size", &scheduler::ModelSettings::params_nbytes)
      .def("bytes_per_token_kv_cache", &scheduler::ModelSettings::bytes_per_token_kv_cache)
      // Add pickle support.
      .def(py::pickle(
          [](const scheduler::ModelSettings& self) {  // __getstate__
            return py::make_tuple(self.params_count, self.layer_count, self.num_k_heads, self.k_head_dim,
                                  self.bytes_per_params, self.bytes_per_kv_cache_element);
          },
          [](py::tuple t) {  // __setstate__
            if (t.size() != 6)
              throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
            scheduler::ModelSettings ms;
            ms.params_count = t[0].cast<size_t>();
            ms.layer_count = t[1].cast<size_t>();
            ms.num_k_heads = t[2].cast<size_t>();
            ms.k_head_dim = t[3].cast<size_t>();
            ms.bytes_per_params = t[4].cast<double>();
            ms.bytes_per_kv_cache_element = t[5].cast<double>();
            return ms;
          }));
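  // Note: model_path is not part of the pickled tuple above, so a ModelSettings
  // that round-trips through pickle comes back without its model_path.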

  py::class_<scheduler::SampleOptions>(m, "SampleOptions")
      .def(py::init<>())
      .def_readwrite("temperature", &scheduler::SampleOptions::temperature)
      .def_readwrite("top_p", &scheduler::SampleOptions::top_p)  // make sure top_p is accessible too
      .def(py::pickle(
          [](const scheduler::SampleOptions& self) {
            return py::make_tuple(self.temperature, self.top_p);  // serialize temperature and top_p
          },
          [](py::tuple t) {
            if (t.size() != 2)  // make sure the state has the expected number of fields
              throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
            scheduler::SampleOptions so;
            so.temperature = t[0].cast<double>();
            so.top_p = t[1].cast<double>();  // deserialize top_p
            return so;
          }));

  py::class_<scheduler::Settings>(m, "Settings")
      .def(py::init<>())
      .def_readwrite("model_name", &scheduler::Settings::model_name)
      .def_readwrite("quant_type", &scheduler::Settings::quant_type)
      .def_readwrite("model_settings", &scheduler::Settings::model_settings)
      .def_readwrite("page_size", &scheduler::Settings::page_size)
      .def_readwrite("gpu_device_id", &scheduler::Settings::gpu_device_id)
      .def_readwrite("gpu_memory_size", &scheduler::Settings::gpu_memory_size)
      .def_readwrite("memory_utilization_percentage", &scheduler::Settings::memory_utilization_percentage)
      .def_readwrite("max_batch_size", &scheduler::Settings::max_batch_size)
      .def_readwrite("recommended_chunk_prefill_token_count",
                     &scheduler::Settings::recommended_chunk_prefill_token_count)
      .def_readwrite("sample_options", &scheduler::Settings::sample_options)
      .def_readwrite("sched_metrics_port", &scheduler::Settings::sched_metrics_port)
      .def_readwrite("gpu_only", &scheduler::Settings::gpu_only)
      .def_readwrite("use_self_defined_head_dim", &scheduler::Settings::use_self_defined_head_dim)
      .def_readwrite("self_defined_head_dim", &scheduler::Settings::self_defined_head_dim)
      .def_readwrite("full_kv_cache_on_each_gpu", &scheduler::Settings::full_kv_cache_on_each_gpu)
      .def_readwrite("k_cache_on", &scheduler::Settings::k_cache_on)
      .def_readwrite("v_cache_on", &scheduler::Settings::v_cache_on)
      .def_readwrite("kvc2_config_path", &scheduler::Settings::kvc2_config_path)
      .def_readwrite("kvc2_root_path", &scheduler::Settings::kvc2_root_path)
      .def_readwrite("memory_pool_size_GB", &scheduler::Settings::memory_pool_size_GB)
      .def_readwrite("evict_count", &scheduler::Settings::evict_count)
      .def_readwrite("strategy_name", &scheduler::Settings::strategy_name)
      .def_readwrite("kvc2_metrics_port", &scheduler::Settings::kvc2_metrics_port)
      .def_readwrite("load_from_disk", &scheduler::Settings::load_from_disk)
      .def_readwrite("save_to_disk", &scheduler::Settings::save_to_disk)
      // derived
      .def_readwrite("gpu_device_count", &scheduler::Settings::gpu_device_count)
      .def_readwrite("total_kvcache_pages", &scheduler::Settings::total_kvcache_pages)
      .def_readwrite("devices", &scheduler::Settings::devices)
      .def("auto_derive", &scheduler::Settings::auto_derive);

  py::class_<scheduler::BatchQueryTodo, std::shared_ptr<scheduler::BatchQueryTodo>>(m, "BatchQueryTodo")
      .def(py::init<>())
      .def_readwrite("query_ids", &scheduler::BatchQueryTodo::query_ids)
      .def_readwrite("query_tokens", &scheduler::BatchQueryTodo::query_tokens)
      .def_readwrite("query_lengths", &scheduler::BatchQueryTodo::query_lengths)
      .def_readwrite("block_indexes", &scheduler::BatchQueryTodo::block_indexes)
      .def_readwrite("attn_masks", &scheduler::BatchQueryTodo::attn_masks)
      .def_readwrite("rope_ranges", &scheduler::BatchQueryTodo::rope_ranges)
      .def_readwrite("sample_options", &scheduler::BatchQueryTodo::sample_options)
      .def_readwrite("prefill_mini_batches", &scheduler::BatchQueryTodo::prefill_mini_batches)
      .def_readwrite("decode_mini_batches", &scheduler::BatchQueryTodo::decode_mini_batches)
      .def_readwrite("stop_criteria", &scheduler::BatchQueryTodo::stop_criteria)
      .def("debug", &scheduler::BatchQueryTodo::debug)
      .def(py::pickle(
          [](const scheduler::BatchQueryTodo& self) {
            return py::make_tuple(self.query_ids, self.query_tokens, self.query_lengths, self.block_indexes,
                                  self.attn_masks, self.rope_ranges, self.sample_options, self.prefill_mini_batches,
                                  self.decode_mini_batches, self.stop_criteria);
          },
          [](py::tuple t) {
            if (t.size() != 10)
              throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
            scheduler::BatchQueryTodo bqt;
            bqt.query_ids = t[0].cast<std::vector<scheduler::QueryID>>();
            bqt.query_tokens = t[1].cast<std::vector<torch::Tensor>>();
            bqt.query_lengths = t[2].cast<std::vector<scheduler::TokenLength>>();
            bqt.block_indexes = t[3].cast<std::vector<torch::Tensor>>();
            bqt.attn_masks = t[4].cast<std::optional<torch::Tensor>>();
            bqt.rope_ranges = t[5].cast<std::optional<torch::Tensor>>();
            bqt.sample_options = t[6].cast<std::vector<scheduler::SampleOptions>>();
            bqt.prefill_mini_batches = t[7].cast<std::vector<scheduler::PrefillTask>>();
            bqt.decode_mini_batches = t[8].cast<std::vector<std::vector<scheduler::QueryID>>>();
            bqt.stop_criteria = t[9].cast<std::vector<std::vector<std::vector<int>>>>();
            return bqt;
          }));
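  // BatchQueryTodo is held by std::shared_ptr (see the class_ holder above), so
  // batches handed to Python share ownership with the C++ side instead of being
  // copied; the pickle support lets them cross process boundaries as well.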

  py::class_<scheduler::QueryUpdate>(m, "QueryUpdate")
      .def(py::init<>())
      .def_readwrite("id", &scheduler::QueryUpdate::id)
      .def_readwrite("ok", &scheduler::QueryUpdate::ok)
      .def_readwrite("is_prefill", &scheduler::QueryUpdate::is_prefill)
      .def_readwrite("decode_done", &scheduler::QueryUpdate::decode_done)
      .def_readwrite("active_position", &scheduler::QueryUpdate::active_position)
      .def_readwrite("generated_token", &scheduler::QueryUpdate::generated_token)
      .def(py::pickle(
          [](const scheduler::QueryUpdate& self) {
            return py::make_tuple(self.id, self.ok, self.is_prefill, self.decode_done, self.active_position,
                                  self.generated_token);
          },
          [](py::tuple t) {
            if (t.size() != 6)
              throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
            scheduler::QueryUpdate qu;
            qu.id = t[0].cast<scheduler::QueryID>();
            qu.ok = t[1].cast<bool>();
            qu.is_prefill = t[2].cast<bool>();
            qu.decode_done = t[3].cast<bool>();
            qu.active_position = t[4].cast<scheduler::TokenLength>();
            qu.generated_token = t[5].cast<scheduler::Token>();
            return qu;
          }));

  py::class_<scheduler::InferenceContext>(m, "InferenceContext")
      .def(py::init<>())
      .def_readwrite("k_cache", &scheduler::InferenceContext::k_cache)
      .def_readwrite("v_cache", &scheduler::InferenceContext::v_cache);

  py::class_<scheduler::QueryAdd>(m, "QueryAdd")
      .def(py::init<>())
      .def_readwrite("query_token", &scheduler::QueryAdd::query_token)
      // .def_readwrite("attn_mask", &scheduler::QueryAdd::attn_mask)
      .def_readwrite("query_length", &scheduler::QueryAdd::query_length)
      .def_readwrite("estimated_length", &scheduler::QueryAdd::estimated_length)
      .def_readwrite("sample_options", &scheduler::QueryAdd::sample_options)
      .def_readwrite("user_id", &scheduler::QueryAdd::user_id)
      .def_readwrite("SLO_TTFT_ms", &scheduler::QueryAdd::SLO_TTFT_ms)
      .def_readwrite("SLO_TBT_ms", &scheduler::QueryAdd::SLO_TBT_ms)
      .def_readwrite("stop_criteria", &scheduler::QueryAdd::stop_criteria)
      .def("serialize", &scheduler::QueryAdd::serialize)
      .def_static("deserialize", &scheduler::QueryAdd::deserialize)
      .def(py::pickle(
          [](const scheduler::QueryAdd& self) {
            return py::make_tuple(self.query_token,
                                  // self.attn_mask,
                                  self.query_length, self.estimated_length, self.sample_options, self.user_id,
                                  self.SLO_TTFT_ms, self.SLO_TBT_ms, self.stop_criteria);
          },
          [](py::tuple t) {
            if (t.size() != 8)
              throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
            scheduler::QueryAdd qa;
            qa.query_token = t[0].cast<std::vector<scheduler::Token>>();
            // qa.attn_mask = t[1].cast<torch::Tensor>();
            qa.query_length = t[1].cast<scheduler::TokenLength>();
            qa.estimated_length = t[2].cast<scheduler::TokenLength>();
            qa.sample_options = t[3].cast<scheduler::SampleOptions>();
            qa.user_id = t[4].cast<scheduler::UserID>();
            qa.SLO_TTFT_ms = t[5].cast<int>();
            qa.SLO_TBT_ms = t[6].cast<int>();
            qa.stop_criteria = t[7].cast<std::vector<std::vector<int>>>();
            return qa;
          }));
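  // Note: attn_mask is currently disabled (commented out above), so the pickled
  // tuple has 8 fields and every index after query_token is shifted down by one
  // relative to the commented-out code.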

  py::class_<scheduler::Scheduler, std::shared_ptr<scheduler::Scheduler>>(m, "Scheduler")
      .def("init", &scheduler::Scheduler::init)
      .def("run", &scheduler::Scheduler::run)
      .def("stop", &scheduler::Scheduler::stop)
      .def("add_query", &scheduler::Scheduler::add_query, py::call_guard<py::gil_scoped_release>())
      .def("cancel_query", &scheduler::Scheduler::cancel_query, py::call_guard<py::gil_scoped_release>())
      .def("update_last_batch", &scheduler::Scheduler::update_last_batch, py::call_guard<py::gil_scoped_release>())
      .def("get_inference_context", &scheduler::Scheduler::get_inference_context);

  m.def("create_scheduler", &scheduler::create_scheduler, "Create a new Scheduler instance");
}
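
// For orientation, a Python-side driver might look roughly like the sketch
// below. The create_scheduler/init signatures and the add_query return value
// are assumptions based only on the names bound above:
//
//   import sched_ext
//
//   settings = sched_ext.Settings()
//   settings.model_settings = sched_ext.ModelSettings()
//   settings.auto_derive()
//
//   sched = sched_ext.create_scheduler(settings)  # signature assumed
//   sched.run()
//
//   q = sched_ext.QueryAdd()
//   q.query_token = [1, 2, 3]        # hypothetical prompt token ids
//   q.query_length = 3
//   q.estimated_length = 64
//   q.sample_options = sched_ext.SampleOptions()
//   sched.add_query(q)               # releases the GIL while it runs
//
//   sched.stop()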