#include "scheduler.h"

#include <memory>
#include <stdexcept>
#include <string>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>

namespace py = pybind11;

// Python bindings for the C++ scheduler (extension module "sched_ext").
//
// NOTE(review): this translation unit was recovered from a copy in which
// every template-argument list ("<...>") had been stripped.  The
// py::class_ template parameters, holder types, header names, and the
// py::call_guard argument were reconstructed from context.  All pickle
// casts use decltype(<member>) so they follow whatever field types
// scheduler.h actually declares — confirm against the header.
PYBIND11_MODULE(sched_ext, m) {
  // ---- ModelSettings ------------------------------------------------------
  py::class_<scheduler::ModelSettings>(m, "ModelSettings")
      .def(py::init<>())
      .def_readwrite("model_path", &scheduler::ModelSettings::model_path)
      .def_readwrite("params_count", &scheduler::ModelSettings::params_count)
      .def_readwrite("layer_count", &scheduler::ModelSettings::layer_count)
      .def_readwrite("num_k_heads", &scheduler::ModelSettings::num_k_heads)
      .def_readwrite("k_head_dim", &scheduler::ModelSettings::k_head_dim)
      .def_readwrite("bytes_per_params",
                     &scheduler::ModelSettings::bytes_per_params)
      .def_readwrite("bytes_per_kv_cache_element",
                     &scheduler::ModelSettings::bytes_per_kv_cache_element)
      .def("params_size", &scheduler::ModelSettings::params_nbytes)
      .def("bytes_per_token_kv_cache",
           &scheduler::ModelSettings::bytes_per_token_kv_cache)
      // Pickle support.  model_path is intentionally absent from the state
      // tuple (only the six numeric fields round-trip) — presumably because
      // the path is host-specific; TODO confirm this is deliberate.
      .def(py::pickle(
          [](const scheduler::ModelSettings& self) {  // __getstate__
            return py::make_tuple(self.params_count, self.layer_count,
                                  self.num_k_heads, self.k_head_dim,
                                  self.bytes_per_params,
                                  self.bytes_per_kv_cache_element);
          },
          [](py::tuple t) {  // __setstate__
            if (t.size() != 6)
              throw std::runtime_error("Invalid state! t.size() = " +
                                       std::to_string(t.size()));
            scheduler::ModelSettings ms;
            ms.params_count = t[0].cast<decltype(ms.params_count)>();
            ms.layer_count = t[1].cast<decltype(ms.layer_count)>();
            ms.num_k_heads = t[2].cast<decltype(ms.num_k_heads)>();
            ms.k_head_dim = t[3].cast<decltype(ms.k_head_dim)>();
            ms.bytes_per_params = t[4].cast<decltype(ms.bytes_per_params)>();
            ms.bytes_per_kv_cache_element =
                t[5].cast<decltype(ms.bytes_per_kv_cache_element)>();
            return ms;
          }));

  // ---- SampleOptions ------------------------------------------------------
  py::class_<scheduler::SampleOptions>(m, "SampleOptions")
      .def(py::init<>())
      .def_readwrite("temperature", &scheduler::SampleOptions::temperature)
      .def_readwrite("top_p", &scheduler::SampleOptions::top_p)
      .def(py::pickle(
          [](const scheduler::SampleOptions& self) {  // __getstate__
            return py::make_tuple(self.temperature, self.top_p);
          },
          [](py::tuple t) {  // __setstate__
            if (t.size() != 2)
              throw std::runtime_error("Invalid state! t.size() = " +
                                       std::to_string(t.size()));
            scheduler::SampleOptions so;
            so.temperature = t[0].cast<decltype(so.temperature)>();
            so.top_p = t[1].cast<decltype(so.top_p)>();
            return so;
          }));

  // ---- Settings -----------------------------------------------------------
  // Plain read/write mirror of the C++ config struct; "derived" fields are
  // filled in by auto_derive() from the user-supplied ones.
  py::class_<scheduler::Settings>(m, "Settings")
      .def(py::init<>())
      .def_readwrite("model_name", &scheduler::Settings::model_name)
      .def_readwrite("quant_type", &scheduler::Settings::quant_type)
      .def_readwrite("model_settings", &scheduler::Settings::model_settings)
      .def_readwrite("page_size", &scheduler::Settings::page_size)
      .def_readwrite("gpu_device_id", &scheduler::Settings::gpu_device_id)
      .def_readwrite("gpu_memory_size", &scheduler::Settings::gpu_memory_size)
      .def_readwrite("memory_utilization_percentage",
                     &scheduler::Settings::memory_utilization_percentage)
      .def_readwrite("max_batch_size", &scheduler::Settings::max_batch_size)
      .def_readwrite(
          "recommended_chunk_prefill_token_count",
          &scheduler::Settings::recommended_chunk_prefill_token_count)
      .def_readwrite("sample_options", &scheduler::Settings::sample_options)
      .def_readwrite("sched_metrics_port",
                     &scheduler::Settings::sched_metrics_port)
      .def_readwrite("gpu_only", &scheduler::Settings::gpu_only)
      .def_readwrite("use_self_defined_head_dim",
                     &scheduler::Settings::use_self_defined_head_dim)
      .def_readwrite("self_defined_head_dim",
                     &scheduler::Settings::self_defined_head_dim)
      .def_readwrite("full_kv_cache_on_each_gpu",
                     &scheduler::Settings::full_kv_cache_on_each_gpu)
      .def_readwrite("k_cache_on", &scheduler::Settings::k_cache_on)
      .def_readwrite("v_cache_on", &scheduler::Settings::v_cache_on)
      .def_readwrite("kvc2_config_path", &scheduler::Settings::kvc2_config_path)
      .def_readwrite("kvc2_root_path", &scheduler::Settings::kvc2_root_path)
      .def_readwrite("memory_pool_size_GB",
                     &scheduler::Settings::memory_pool_size_GB)
      .def_readwrite("evict_count", &scheduler::Settings::evict_count)
      .def_readwrite("strategy_name", &scheduler::Settings::strategy_name)
      .def_readwrite("kvc2_metrics_port",
                     &scheduler::Settings::kvc2_metrics_port)
      .def_readwrite("load_from_disk", &scheduler::Settings::load_from_disk)
      .def_readwrite("save_to_disk", &scheduler::Settings::save_to_disk)
      // derived
      .def_readwrite("gpu_device_count", &scheduler::Settings::gpu_device_count)
      .def_readwrite("total_kvcache_pages",
                     &scheduler::Settings::total_kvcache_pages)
      .def_readwrite("devices", &scheduler::Settings::devices)
      .def("auto_derive", &scheduler::Settings::auto_derive);

  // ---- BatchQueryTodo -----------------------------------------------------
  // Held by shared_ptr: instances are shared between the scheduler thread
  // and Python.
  py::class_<scheduler::BatchQueryTodo,
             std::shared_ptr<scheduler::BatchQueryTodo>>(m, "BatchQueryTodo")
      .def(py::init<>())
      .def_readwrite("query_ids", &scheduler::BatchQueryTodo::query_ids)
      .def_readwrite("query_tokens", &scheduler::BatchQueryTodo::query_tokens)
      .def_readwrite("query_lengths", &scheduler::BatchQueryTodo::query_lengths)
      .def_readwrite("block_indexes", &scheduler::BatchQueryTodo::block_indexes)
      .def_readwrite("attn_masks", &scheduler::BatchQueryTodo::attn_masks)
      .def_readwrite("rope_ranges", &scheduler::BatchQueryTodo::rope_ranges)
      .def_readwrite("sample_options",
                     &scheduler::BatchQueryTodo::sample_options)
      .def_readwrite("prefill_mini_batches",
                     &scheduler::BatchQueryTodo::prefill_mini_batches)
      .def_readwrite("decode_mini_batches",
                     &scheduler::BatchQueryTodo::decode_mini_batches)
      .def_readwrite("stop_criteria", &scheduler::BatchQueryTodo::stop_criteria)
      .def("debug", &scheduler::BatchQueryTodo::debug)
      .def(py::pickle(
          [](const scheduler::BatchQueryTodo& self) {  // __getstate__
            return py::make_tuple(
                self.query_ids, self.query_tokens, self.query_lengths,
                self.block_indexes, self.attn_masks, self.rope_ranges,
                self.sample_options, self.prefill_mini_batches,
                self.decode_mini_batches, self.stop_criteria);
          },
          [](py::tuple t) {  // __setstate__
            if (t.size() != 10)
              throw std::runtime_error("Invalid state! t.size() = " +
                                       std::to_string(t.size()));
            scheduler::BatchQueryTodo bqt;
            bqt.query_ids = t[0].cast<decltype(bqt.query_ids)>();
            bqt.query_tokens = t[1].cast<decltype(bqt.query_tokens)>();
            bqt.query_lengths = t[2].cast<decltype(bqt.query_lengths)>();
            bqt.block_indexes = t[3].cast<decltype(bqt.block_indexes)>();
            bqt.attn_masks = t[4].cast<decltype(bqt.attn_masks)>();
            bqt.rope_ranges = t[5].cast<decltype(bqt.rope_ranges)>();
            bqt.sample_options = t[6].cast<decltype(bqt.sample_options)>();
            bqt.prefill_mini_batches =
                t[7].cast<decltype(bqt.prefill_mini_batches)>();
            bqt.decode_mini_batches =
                t[8].cast<decltype(bqt.decode_mini_batches)>();
            bqt.stop_criteria = t[9].cast<decltype(bqt.stop_criteria)>();
            return bqt;
          }));

  // ---- QueryUpdate --------------------------------------------------------
  py::class_<scheduler::QueryUpdate>(m, "QueryUpdate")
      .def(py::init<>())
      .def_readwrite("id", &scheduler::QueryUpdate::id)
      .def_readwrite("ok", &scheduler::QueryUpdate::ok)
      .def_readwrite("is_prefill", &scheduler::QueryUpdate::is_prefill)
      .def_readwrite("decode_done", &scheduler::QueryUpdate::decode_done)
      .def_readwrite("active_position",
                     &scheduler::QueryUpdate::active_position)
      .def_readwrite("generated_token",
                     &scheduler::QueryUpdate::generated_token)
      .def(py::pickle(
          [](const scheduler::QueryUpdate& self) {  // __getstate__
            return py::make_tuple(self.id, self.ok, self.is_prefill,
                                  self.decode_done, self.active_position,
                                  self.generated_token);
          },
          [](py::tuple t) {  // __setstate__
            if (t.size() != 6)
              throw std::runtime_error("Invalid state! t.size() = " +
                                       std::to_string(t.size()));
            scheduler::QueryUpdate qu;
            qu.id = t[0].cast<decltype(qu.id)>();
            qu.ok = t[1].cast<decltype(qu.ok)>();
            qu.is_prefill = t[2].cast<decltype(qu.is_prefill)>();
            qu.decode_done = t[3].cast<decltype(qu.decode_done)>();
            qu.active_position = t[4].cast<decltype(qu.active_position)>();
            qu.generated_token = t[5].cast<decltype(qu.generated_token)>();
            return qu;
          }));

  // ---- InferenceContext ---------------------------------------------------
  py::class_<scheduler::InferenceContext>(m, "InferenceContext")
      .def(py::init<>())
      .def_readwrite("k_cache", &scheduler::InferenceContext::k_cache)
      .def_readwrite("v_cache", &scheduler::InferenceContext::v_cache);

  // ---- QueryAdd -----------------------------------------------------------
  py::class_<scheduler::QueryAdd>(m, "QueryAdd")
      .def(py::init<>())
      .def_readwrite("query_token", &scheduler::QueryAdd::query_token)
      // .def_readwrite("attn_mask", &scheduler::QueryAdd::attn_mask)
      .def_readwrite("query_length", &scheduler::QueryAdd::query_length)
      .def_readwrite("estimated_length",
                     &scheduler::QueryAdd::estimated_length)
      .def_readwrite("sample_options", &scheduler::QueryAdd::sample_options)
      .def_readwrite("user_id", &scheduler::QueryAdd::user_id)
      .def_readwrite("SLO_TTFT_ms", &scheduler::QueryAdd::SLO_TTFT_ms)
      .def_readwrite("SLO_TBT_ms", &scheduler::QueryAdd::SLO_TBT_ms)
      .def_readwrite("stop_criteria", &scheduler::QueryAdd::stop_criteria)
      .def("serialize", &scheduler::QueryAdd::serialize)
      .def_static("deserialize", &scheduler::QueryAdd::deserialize)
      .def(py::pickle(
          [](const scheduler::QueryAdd& self) {  // __getstate__
            return py::make_tuple(self.query_token,
                                  // self.attn_mask,
                                  self.query_length, self.estimated_length,
                                  self.sample_options, self.user_id,
                                  self.SLO_TTFT_ms, self.SLO_TBT_ms,
                                  self.stop_criteria);
          },
          [](py::tuple t) {  // __setstate__
            if (t.size() != 8)
              throw std::runtime_error("Invalid state! t.size() = " +
                                       std::to_string(t.size()));
            scheduler::QueryAdd qa;
            qa.query_token = t[0].cast<decltype(qa.query_token)>();
            // qa.attn_mask = ...;  // disabled, matching the binding above
            qa.query_length = t[1].cast<decltype(qa.query_length)>();
            qa.estimated_length = t[2].cast<decltype(qa.estimated_length)>();
            qa.sample_options = t[3].cast<decltype(qa.sample_options)>();
            qa.user_id = t[4].cast<decltype(qa.user_id)>();
            qa.SLO_TTFT_ms = t[5].cast<decltype(qa.SLO_TTFT_ms)>();
            qa.SLO_TBT_ms = t[6].cast<decltype(qa.SLO_TBT_ms)>();
            qa.stop_criteria = t[7].cast<decltype(qa.stop_criteria)>();
            return qa;
          }));

  // ---- Scheduler ----------------------------------------------------------
  // Held by shared_ptr (created via create_scheduler below).  The blocking
  // entry points release the GIL so Python threads keep running while the
  // C++ scheduler works.
  py::class_<scheduler::Scheduler, std::shared_ptr<scheduler::Scheduler>>(
      m, "Scheduler")
      .def("init", &scheduler::Scheduler::init)
      .def("run", &scheduler::Scheduler::run)
      .def("stop", &scheduler::Scheduler::stop)
      .def("add_query", &scheduler::Scheduler::add_query,
           py::call_guard<py::gil_scoped_release>())
      .def("cancel_query", &scheduler::Scheduler::cancel_query,
           py::call_guard<py::gil_scoped_release>())
      .def("update_last_batch", &scheduler::Scheduler::update_last_batch,
           py::call_guard<py::gil_scoped_release>())
      .def("get_inference_context",
           &scheduler::Scheduler::get_inference_context);

  m.def("create_scheduler", &scheduler::create_scheduler,
        "Create a new Scheduler instance");
}