format kvc2, delete quant_configs, move model_configs to ~/.ktransformers

This commit is contained in:
qiyuxinlin 2025-04-08 10:06:07 +00:00
parent 9dd24ecd72
commit 64de784328
31 changed files with 853 additions and 878 deletions

View file

@ -77,7 +77,6 @@ GPUPageCache::GPUPageCache(GPUPageCacheConfig& config) : config(config) {
gpu_only_occupations.resize(config.total_kvcache_pages, false);
}
num_free_pages = config.total_kvcache_pages;
for (size_t i = 0; i < config.layer_count; i++) {
if (config.k_cache_on)
@ -248,18 +247,19 @@ void GPUPageCache::append_col_to_request(std::vector<std::shared_ptr<CudaStreamM
auto gpu_block_idx = k_handles[0][at]->gpu_block_idx.value();
for (size_t layer = 0; layer < config.layer_count; layer++) {
for (size_t which_gpu = 0; which_gpu < config.gpu_devices_id.size(); which_gpu++) {
if (config.k_cache_on) {
assert(k_handles[layer][at]->data != nullptr);
reqs[which_gpu]->sizes.push_back(tp_size[which_gpu]);
reqs[which_gpu]->host_mem_addresses.push_back(offset_by_bytes(k_handles[layer][at]->data, tp_offset[which_gpu]));
reqs[which_gpu]->host_mem_addresses.push_back(
offset_by_bytes(k_handles[layer][at]->data, tp_offset[which_gpu]));
reqs[which_gpu]->device_mem_addresses.push_back(k_cache[which_gpu][layer][gpu_block_idx].data_ptr());
}
if (config.v_cache_on) {
assert(v_handles[layer][at]->data != nullptr);
reqs[which_gpu]->sizes.push_back(tp_size[which_gpu]);
reqs[which_gpu]->host_mem_addresses.push_back(offset_by_bytes(v_handles[layer][at]->data, tp_offset[which_gpu]));
reqs[which_gpu]->host_mem_addresses.push_back(
offset_by_bytes(v_handles[layer][at]->data, tp_offset[which_gpu]));
reqs[which_gpu]->device_mem_addresses.push_back(v_cache[which_gpu][layer][gpu_block_idx].data_ptr());
}
}

View file

@ -1,16 +1,16 @@
#pragma once
#include "prometheus/counter.h"
#include "prometheus/exposer.h"
#include "prometheus/gauge.h"
#include "prometheus/histogram.h"
#include "prometheus/registry.h"
#include <atomic>
#include <chrono>
#include <memory>
#include <string>
#include <thread>
#include <vector>
#include "prometheus/counter.h"
#include "prometheus/exposer.h"
#include "prometheus/gauge.h"
#include "prometheus/histogram.h"
#include "prometheus/registry.h"
#include "utils/timer.hpp"

View file

@ -1,8 +1,8 @@
#ifndef __MODEL_CONFIG_HPP_
#define __MODEL_CONFIG_HPP_
#include <iostream>
#include "nlohmann/json.hpp"
#include <iostream>
#include <filesystem>
#include <fstream>
@ -23,10 +23,13 @@ class ModelConfig {
size_t num_key_value_heads;
size_t vocab_size;
NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size, max_position_embeddings, model_type,
num_attention_heads, num_hidden_layers, num_key_value_heads, vocab_size);
NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size,
max_position_embeddings, model_type,
num_attention_heads, num_hidden_layers,
num_key_value_heads, vocab_size);
void load_from(std::filesystem::path path) {
std::cout << "Load from " << path << std::endl;
std::ifstream i(path);
nlohmann::json j;
i >> j;
@ -43,7 +46,9 @@ class QuantConfig {
// For GEMV
QuantType type_of_dot_vector = NoQuantType;
inline bool can_be_used_as_matrix() { return type_of_dot_vector != NoQuantType; }
inline bool can_be_used_as_matrix() {
return type_of_dot_vector != NoQuantType;
}
bool can_be_used_as_vector;
@ -56,8 +61,11 @@ class QuantConfig {
URL reference = "";
NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(QuantConfig, name, type_of_dot_vector, can_be_used_as_vector,
bytes_per_element, has_scale, has_min, block_element_count,
NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(QuantConfig, name,
type_of_dot_vector,
can_be_used_as_vector,
bytes_per_element, has_scale,
has_min, block_element_count,
block_element_size, reference);
};
@ -65,15 +73,19 @@ inline std::map<QuantType, QuantConfig> quant_configs;
inline std::map<ModelName, ModelConfig> model_configs;
inline void load_quant_configs(std::filesystem::path path) {
nlohmann::json j;
if (std::filesystem::exists(path)) {
std::cout << __FUNCTION__ << " from " << path << std::endl;
std::ifstream i(path);
nlohmann::json j;
i >> j;
quant_configs = j.get<std::map<QuantType, QuantConfig>>();
std::cout << "Loaded Quant Configs" << std::endl;
for (auto &[k, v] : quant_configs) {
std::cout << " - " << k << std::endl;
}
} else {
std::cout << __FUNCTION__ << " no file at " << path << std::endl;
}
}
inline void dump_quant_configs(std::filesystem::path path) {
@ -83,15 +95,19 @@ inline void dump_quant_configs(std::filesystem::path path) {
}
inline void load_model_configs(std::filesystem::path path) {
nlohmann::json j;
if (std::filesystem::exists(path)) {
std::cout << __FUNCTION__ << " from " << path << std::endl;
std::ifstream i(path);
nlohmann::json j;
i >> j;
model_configs = j.get<std::map<ModelName, ModelConfig>>();
std::cout << "Loaded Model Configs" << std::endl;
for (auto &[k, v] : model_configs) {
std::cout << " - " << k << std::endl;
}
} else {
std::cout << __FUNCTION__ << " no file at " << path << std::endl;
}
}
inline void dump_model_configs(std::filesystem::path path) {

View file

@ -18,12 +18,13 @@ PageAlignedMemoryPool::PageAlignedMemoryPool(size_t size_in_bytes) {
page_per_block = total_pages / Blocks;
for (size_t block_index = 0; block_index < Blocks; block_index++) {
first_page[block_index] = reinterpret_cast<void*>(reinterpret_cast<intptr_t>(data) + static_cast<intptr_t>(block_index) * page_per_block * PageSize);
first_page[block_index] = reinterpret_cast<void*>(reinterpret_cast<intptr_t>(data) +
static_cast<intptr_t>(block_index) * page_per_block * PageSize);
count_page[block_index] =
block_index == Blocks - 1 ? (total_pages - page_per_block * (Blocks - 1)) : page_per_block;
SPDLOG_DEBUG("first_page[{}] = {}, count_page[{}] = {}",
block_index, reinterpret_cast<intptr_t>(first_page[block_index]) - reinterpret_cast<intptr_t>(data),
block_index, count_page[block_index]);
SPDLOG_DEBUG("first_page[{}] = {}, count_page[{}] = {}", block_index,
reinterpret_cast<intptr_t>(first_page[block_index]) - reinterpret_cast<intptr_t>(data), block_index,
count_page[block_index]);
bitmap[block_index].resize(count_page[block_index], 0);
}
SPDLOG_INFO("PageAlignedMemoryPool with size {} Mbytes, {} pages", total_size / (1 << 20), page_count());
@ -119,5 +120,6 @@ void PageAlignedMemoryPool::defragment() {}
/// 调试打印
std::string PageAlignedMemoryPool::debug() {
return fmt::format("PageAlignedMemoryPool: total_size: {}MB, allocated: {}, alloc/free count: {}/{}\n",
readable_number(total_size), readable_number(size_t(allocated)), size_t(alloc_count), size_t(free_count));
readable_number(total_size), readable_number(size_t(allocated)), size_t(alloc_count),
size_t(free_count));
}

View file

@ -1,12 +1,12 @@
#pragma once
#include <assert.h>
#include <algorithm> // std::sort
#include <atomic>
#include <bitset>
#include <cstddef> // size_t
#include <mutex> // std::mutex
#include <vector>
#include <assert.h>
#include <bitset>
#include <atomic>
constexpr size_t PageSize = 4096;
@ -30,6 +30,7 @@ struct PageAlignedMemoryPool {
size_t count_page[Blocks];
std::vector<int8_t> bitmap[Blocks];
void* alloc_in_block(size_t block_index, size_t alloc_size);
public:
/// 构造函数和析构函数
explicit PageAlignedMemoryPool(size_t size_in_bytes);

View file

@ -579,7 +579,6 @@ struct PrefixTree {
return re;
}
std::shared_ptr<Prefix> new_prefix_node(Prefix* prev, TokenLength prev_match_length, Token* data, TokenLength length,
bool need_lock = true) {
std::unique_lock<std::shared_mutex> ul;
@ -700,9 +699,7 @@ struct DoubleCacheHandle : public DoubleCacheHandleInterface {
}
}
}
std::vector<MatchStatus> matched_status() override {
assert(false);
}
std::vector<MatchStatus> matched_status() override { assert(false); }
bool any_match() {
if (enable_alt) {
@ -1066,7 +1063,6 @@ struct DoubleCacheHandle : public DoubleCacheHandleInterface {
};
struct KVC2 : KVC2Interface {
KVC2Config config;
std::shared_ptr<Metrics> met;
@ -1694,9 +1690,11 @@ void GPUPageCache::gpu_background_flush() {
if (col_uls.empty())
continue;
for (size_t l = 0; l < config.layer_count; l++) {
if (config.k_cache_on && (occupations[l][i]->gpu_cc.dirty.load() == false || occupations[l][i]->cpu_cc.dirty.load()))
if (config.k_cache_on &&
(occupations[l][i]->gpu_cc.dirty.load() == false || occupations[l][i]->cpu_cc.dirty.load()))
goto next_gpu_page;
if (config.v_cache_on && (v_occupations[l][i]->gpu_cc.dirty.load() == false || v_occupations[l][i]->cpu_cc.dirty.load()))
if (config.v_cache_on &&
(v_occupations[l][i]->gpu_cc.dirty.load() == false || v_occupations[l][i]->cpu_cc.dirty.load()))
goto next_gpu_page;
}

View file

@ -11,7 +11,6 @@
#include "common.hpp"
int main(int argc, char* argv[]) {
qw25_7B_gpu_config.v_cache_on = false;
config.gpu_cache_config = qw25_7B_gpu_config;
config.v_cache_on = false;

View file

@ -1,16 +1,15 @@
#include <unistd.h>
#include <iostream>
#include <random>
#include <thread>
#include <vector>
#include <random>
#include <unistd.h>
#include "page_aligned_memory_pool.cpp"
#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG
#define FMT_HEADER_ONLY
#include "spdlog/spdlog.h"
// 每个线程执行的任务
void thread_task(PageAlignedMemoryPool& pool) {
std::mt19937 gen(123);
@ -36,7 +35,6 @@ void thread_task(PageAlignedMemoryPool& pool) {
int main(int argc, char* argv[]) {
spdlog::set_level(spdlog::level::debug);
// 创建一个内存池
PageAlignedMemoryPool pool(40ll * 1024 * 1024 * 1024); // 40 G

View file

@ -1,18 +1,16 @@
#include "utils/periodic_task.hpp"
#include <chrono>
#include <cstdio>
#include <iostream>
#include <thread>
#include <future>
#include <atomic>
#include <cassert>
#include <chrono>
#include <cstdio>
#include <future>
#include <iostream>
#include <thread>
#include "utils/periodic_task.hpp"
// 1. 任务是否按预期执行
void testPeriodicTaskExecution() {
std::atomic<int> execution_count{0};
auto task = [&execution_count]() {
execution_count++;
};
auto task = [&execution_count]() { execution_count++; };
periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(50));
@ -26,9 +24,7 @@ void testPeriodicTaskExecution() {
// 2. 提前唤醒任务的功能
void testWakeUpImmediately() {
std::atomic<int> execution_count{0};
auto task = [&execution_count]() {
execution_count++;
};
auto task = [&execution_count]() { execution_count++; };
periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200));
@ -63,9 +59,7 @@ void testWakeUpWait() {
// 4. 任务抛出异常的处理
void testTaskExceptionHandling() {
auto task = []() {
throw std::runtime_error("Test exception");
};
auto task = []() { throw std::runtime_error("Test exception"); };
periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200));
@ -98,9 +92,7 @@ void testTaskStop() {
// 6. 高频唤醒的情况下任务执行是否正常
void testHighFrequencyWakeUp() {
std::atomic<int> execution_count{0};
auto task = [&execution_count]() {
execution_count++;
};
auto task = [&execution_count]() { execution_count++; };
periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200));

View file

@ -1,8 +1,8 @@
#include "scheduler.h"
#include <memory>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <memory>
#include "scheduler.h"
#include <torch/extension.h>
@ -16,19 +16,25 @@ PYBIND11_MODULE(sched_ext, m) {
.def_readwrite("layer_count", &scheduler::ModelSettings::layer_count)
.def_readwrite("num_k_heads", &scheduler::ModelSettings::num_k_heads)
.def_readwrite("k_head_dim", &scheduler::ModelSettings::k_head_dim)
.def_readwrite("bytes_per_params", &scheduler::ModelSettings::bytes_per_params)
.def_readwrite("bytes_per_kv_cache_element", &scheduler::ModelSettings::bytes_per_kv_cache_element)
.def_readwrite("bytes_per_params",
&scheduler::ModelSettings::bytes_per_params)
.def_readwrite("bytes_per_kv_cache_element",
&scheduler::ModelSettings::bytes_per_kv_cache_element)
.def("params_size", &scheduler::ModelSettings::params_nbytes)
.def("bytes_per_token_kv_cache", &scheduler::ModelSettings::bytes_per_token_kv_cache)
.def("bytes_per_token_kv_cache",
&scheduler::ModelSettings::bytes_per_token_kv_cache)
// 添加 pickle 支持
.def(py::pickle(
[](const scheduler::ModelSettings &self) { // __getstate__
return py::make_tuple(self.params_count, self.layer_count, self.num_k_heads, self.k_head_dim,
self.bytes_per_params, self.bytes_per_kv_cache_element);
return py::make_tuple(self.params_count, self.layer_count,
self.num_k_heads, self.k_head_dim,
self.bytes_per_params,
self.bytes_per_kv_cache_element);
},
[](py::tuple t) { // __setstate__
if (t.size() != 6)
throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
throw std::runtime_error("Invalid state! t.size() = " +
std::to_string(t.size()));
scheduler::ModelSettings ms;
ms.params_count = t[0].cast<size_t>();
ms.layer_count = t[1].cast<size_t>();
@ -42,20 +48,22 @@ PYBIND11_MODULE(sched_ext, m) {
py::class_<scheduler::SampleOptions>(m, "SampleOptions")
.def(py::init<>())
.def_readwrite("temperature", &scheduler::SampleOptions::temperature)
.def_readwrite("top_p", &scheduler::SampleOptions::top_p) // 确保 top_p 也能被访问
.def_readwrite("top_p",
&scheduler::SampleOptions::top_p) // 确保 top_p 也能被访问
.def(py::pickle(
[](const scheduler::SampleOptions &self) {
return py::make_tuple(self.temperature, self.top_p); // 序列化 temperature 和 top_p
return py::make_tuple(self.temperature,
self.top_p); // 序列化 temperature 和 top_p
},
[](py::tuple t) {
if (t.size() != 2) // 确保解包时参数数量匹配
throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
throw std::runtime_error("Invalid state! t.size() = " +
std::to_string(t.size()));
scheduler::SampleOptions so;
so.temperature = t[0].cast<double>();
so.top_p = t[1].cast<double>(); // 反序列化 top_p
return so;
}
));
}));
py::class_<scheduler::Settings>(m, "Settings")
.def(py::init<>())
@ -65,33 +73,43 @@ PYBIND11_MODULE(sched_ext, m) {
.def_readwrite("page_size", &scheduler::Settings::page_size)
.def_readwrite("gpu_device_id", &scheduler::Settings::gpu_device_id)
.def_readwrite("gpu_memory_size", &scheduler::Settings::gpu_memory_size)
.def_readwrite("memory_utilization_percentage", &scheduler::Settings::memory_utilization_percentage)
.def_readwrite("memory_utilization_percentage",
&scheduler::Settings::memory_utilization_percentage)
.def_readwrite("max_batch_size", &scheduler::Settings::max_batch_size)
.def_readwrite("recommended_chunk_prefill_token_count",
.def_readwrite(
"recommended_chunk_prefill_token_count",
&scheduler::Settings::recommended_chunk_prefill_token_count)
.def_readwrite("sample_options", &scheduler::Settings::sample_options)
.def_readwrite("sched_metrics_port", &scheduler::Settings::sched_metrics_port)
.def_readwrite("sched_metrics_port",
&scheduler::Settings::sched_metrics_port)
.def_readwrite("gpu_only", &scheduler::Settings::gpu_only)
.def_readwrite("use_self_defined_head_dim", &scheduler::Settings::use_self_defined_head_dim)
.def_readwrite("self_defined_head_dim", &scheduler::Settings::self_defined_head_dim)
.def_readwrite("full_kv_cache_on_each_gpu", &scheduler::Settings::full_kv_cache_on_each_gpu)
.def_readwrite("use_self_defined_head_dim",
&scheduler::Settings::use_self_defined_head_dim)
.def_readwrite("self_defined_head_dim",
&scheduler::Settings::self_defined_head_dim)
.def_readwrite("full_kv_cache_on_each_gpu",
&scheduler::Settings::full_kv_cache_on_each_gpu)
.def_readwrite("k_cache_on", &scheduler::Settings::k_cache_on)
.def_readwrite("v_cache_on", &scheduler::Settings::v_cache_on)
.def_readwrite("kvc2_config_path", &scheduler::Settings::kvc2_config_path)
.def_readwrite("kvc2_root_path", &scheduler::Settings::kvc2_root_path)
.def_readwrite("memory_pool_size_GB", &scheduler::Settings::memory_pool_size_GB)
.def_readwrite("memory_pool_size_GB",
&scheduler::Settings::memory_pool_size_GB)
.def_readwrite("evict_count", &scheduler::Settings::evict_count)
.def_readwrite("strategy_name", &scheduler::Settings::strategy_name)
.def_readwrite("kvc2_metrics_port", &scheduler::Settings::kvc2_metrics_port)
.def_readwrite("kvc2_metrics_port",
&scheduler::Settings::kvc2_metrics_port)
.def_readwrite("load_from_disk", &scheduler::Settings::load_from_disk)
.def_readwrite("save_to_disk", &scheduler::Settings::save_to_disk)
// derived
.def_readwrite("gpu_device_count", &scheduler::Settings::gpu_device_count)
.def_readwrite("total_kvcache_pages", &scheduler::Settings::total_kvcache_pages)
.def_readwrite("total_kvcache_pages",
&scheduler::Settings::total_kvcache_pages)
.def_readwrite("devices", &scheduler::Settings::devices)
.def("auto_derive", &scheduler::Settings::auto_derive);
py::class_<scheduler::BatchQueryTodo, std::shared_ptr<scheduler::BatchQueryTodo>>(m, "BatchQueryTodo")
py::class_<scheduler::BatchQueryTodo,
std::shared_ptr<scheduler::BatchQueryTodo>>(m, "BatchQueryTodo")
.def(py::init<>())
.def_readwrite("query_ids", &scheduler::BatchQueryTodo::query_ids)
.def_readwrite("query_tokens", &scheduler::BatchQueryTodo::query_tokens)
@ -99,31 +117,42 @@ PYBIND11_MODULE(sched_ext, m) {
.def_readwrite("block_indexes", &scheduler::BatchQueryTodo::block_indexes)
.def_readwrite("attn_masks", &scheduler::BatchQueryTodo::attn_masks)
.def_readwrite("rope_ranges", &scheduler::BatchQueryTodo::rope_ranges)
.def_readwrite("sample_options", &scheduler::BatchQueryTodo::sample_options)
.def_readwrite("prefill_mini_batches", &scheduler::BatchQueryTodo::prefill_mini_batches)
.def_readwrite("decode_mini_batches", &scheduler::BatchQueryTodo::decode_mini_batches)
.def_readwrite("sample_options",
&scheduler::BatchQueryTodo::sample_options)
.def_readwrite("prefill_mini_batches",
&scheduler::BatchQueryTodo::prefill_mini_batches)
.def_readwrite("decode_mini_batches",
&scheduler::BatchQueryTodo::decode_mini_batches)
.def_readwrite("stop_criteria", &scheduler::BatchQueryTodo::stop_criteria)
.def("debug", &scheduler::BatchQueryTodo::debug)
.def(py::pickle(
[](const scheduler::BatchQueryTodo &self) {
return py::make_tuple(self.query_ids, self.query_tokens, self.query_lengths, self.block_indexes,
self.attn_masks, self.rope_ranges, self.sample_options, self.prefill_mini_batches,
return py::make_tuple(
self.query_ids, self.query_tokens, self.query_lengths,
self.block_indexes, self.attn_masks, self.rope_ranges,
self.sample_options, self.prefill_mini_batches,
self.decode_mini_batches, self.stop_criteria);
},
[](py::tuple t) {
if (t.size() != 10)
throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
throw std::runtime_error("Invalid state! t.size() = " +
std::to_string(t.size()));
scheduler::BatchQueryTodo bqt;
bqt.query_ids = t[0].cast<std::vector<scheduler::QueryID>>();
bqt.query_tokens = t[1].cast<std::vector<torch::Tensor>>();
bqt.query_lengths = t[2].cast<std::vector<scheduler::TokenLength>>();
bqt.query_lengths =
t[2].cast<std::vector<scheduler::TokenLength>>();
bqt.block_indexes = t[3].cast<std::vector<torch::Tensor>>();
bqt.attn_masks = t[4].cast<std::optional<torch::Tensor>>();
bqt.rope_ranges = t[5].cast<std::optional<torch::Tensor>>();
bqt.sample_options = t[6].cast<std::vector<scheduler::SampleOptions>>();
bqt.prefill_mini_batches = t[7].cast<std::vector<scheduler::PrefillTask>>();
bqt.decode_mini_batches = t[8].cast<std::vector<std::vector<scheduler::QueryID>>>();
bqt.stop_criteria = t[9].cast<std::vector<std::vector<std::vector<int>>>>();
bqt.sample_options =
t[6].cast<std::vector<scheduler::SampleOptions>>();
bqt.prefill_mini_batches =
t[7].cast<std::vector<scheduler::PrefillTask>>();
bqt.decode_mini_batches =
t[8].cast<std::vector<std::vector<scheduler::QueryID>>>();
bqt.stop_criteria =
t[9].cast<std::vector<std::vector<std::vector<int>>>>();
return bqt;
}));
@ -133,16 +162,20 @@ PYBIND11_MODULE(sched_ext, m) {
.def_readwrite("ok", &scheduler::QueryUpdate::ok)
.def_readwrite("is_prefill", &scheduler::QueryUpdate::is_prefill)
.def_readwrite("decode_done", &scheduler::QueryUpdate::decode_done)
.def_readwrite("active_position", &scheduler::QueryUpdate::active_position)
.def_readwrite("generated_token", &scheduler::QueryUpdate::generated_token)
.def_readwrite("active_position",
&scheduler::QueryUpdate::active_position)
.def_readwrite("generated_token",
&scheduler::QueryUpdate::generated_token)
.def(py::pickle(
[](const scheduler::QueryUpdate &self) {
return py::make_tuple(self.id, self.ok, self.is_prefill, self.decode_done, self.active_position,
return py::make_tuple(self.id, self.ok, self.is_prefill,
self.decode_done, self.active_position,
self.generated_token);
},
[](py::tuple t) {
if (t.size() != 6)
throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
throw std::runtime_error("Invalid state! t.size() = " +
std::to_string(t.size()));
scheduler::QueryUpdate qu;
qu.id = t[0].cast<scheduler::QueryID>();
qu.ok = t[1].cast<bool>();
@ -156,8 +189,7 @@ PYBIND11_MODULE(sched_ext, m) {
py::class_<scheduler::InferenceContext>(m, "InferenceContext")
.def(py::init<>())
.def_readwrite("k_cache", &scheduler::InferenceContext::k_cache)
.def_readwrite("v_cache", &scheduler::InferenceContext::v_cache)
;
.def_readwrite("v_cache", &scheduler::InferenceContext::v_cache);
py::class_<scheduler::QueryAdd>(m, "QueryAdd")
.def(py::init<>())
@ -176,12 +208,15 @@ PYBIND11_MODULE(sched_ext, m) {
[](const scheduler::QueryAdd &self) {
return py::make_tuple(self.query_token,
// self.attn_mask,
self.query_length, self.estimated_length, self.sample_options, self.user_id,
self.SLO_TTFT_ms, self.SLO_TBT_ms, self.stop_criteria);
self.query_length, self.estimated_length,
self.sample_options, self.user_id,
self.SLO_TTFT_ms, self.SLO_TBT_ms,
self.stop_criteria);
},
[](py::tuple t) {
if (t.size() != 8)
throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
throw std::runtime_error("Invalid state! t.size() = " +
std::to_string(t.size()));
scheduler::QueryAdd qa;
qa.query_token = t[0].cast<std::vector<scheduler::Token>>();
// qa.attn_mask = t[1].cast<torch::Tensor>();
@ -195,14 +230,20 @@ PYBIND11_MODULE(sched_ext, m) {
return qa;
}));
py::class_<scheduler::Scheduler, std::shared_ptr<scheduler::Scheduler>>(m, "Scheduler")
py::class_<scheduler::Scheduler, std::shared_ptr<scheduler::Scheduler>>(
m, "Scheduler")
.def("init", &scheduler::Scheduler::init)
.def("run", &scheduler::Scheduler::run)
.def("stop", &scheduler::Scheduler::stop)
.def("add_query", &scheduler::Scheduler::add_query, py::call_guard<py::gil_scoped_release>())
.def("cancel_query", &scheduler::Scheduler::cancel_query, py::call_guard<py::gil_scoped_release>())
.def("update_last_batch", &scheduler::Scheduler::update_last_batch, py::call_guard<py::gil_scoped_release>())
.def("get_inference_context", &scheduler::Scheduler::get_inference_context);
.def("add_query", &scheduler::Scheduler::add_query,
py::call_guard<py::gil_scoped_release>())
.def("cancel_query", &scheduler::Scheduler::cancel_query,
py::call_guard<py::gil_scoped_release>())
.def("update_last_batch", &scheduler::Scheduler::update_last_batch,
py::call_guard<py::gil_scoped_release>())
.def("get_inference_context",
&scheduler::Scheduler::get_inference_context);
m.def("create_scheduler", &scheduler::create_scheduler, "Create a new Scheduler instance");
m.def("create_scheduler", &scheduler::create_scheduler,
"Create a new Scheduler instance");
}

View file

@ -4,11 +4,11 @@
// 构造函数
Metrics::Metrics(const MetricsConfig &config)
: registry_(std::make_shared<prometheus::Registry>()),
exposer_(config.endpoint),
stop_uptime_thread_(false),
exposer_(config.endpoint), stop_uptime_thread_(false),
start_time_(std::chrono::steady_clock::now()) {
// 定义统一的桶大小,最大为 10000 ms (10 s)
std::vector<double> common_buckets = {0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0, 5.0,
std::vector<double> common_buckets = {
0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0, 5.0,
10.0, 50.0, 100.0, 500.0, 1000.0, 5000.0, 10000.0}; // 毫秒
// 注册 TTFT_ms Histogram
@ -26,46 +26,57 @@ Metrics::Metrics(const MetricsConfig& config)
TBT_ms = &TBT_family.Add({{"model", config.model_name}}, common_buckets);
// 注册 schedule_time Histogram
auto& schedule_time_family = prometheus::BuildHistogram()
auto &schedule_time_family =
prometheus::BuildHistogram()
.Name(std::string(METRIC_PREFIX) + "_schedule_time_ms")
.Help("Time to generate schedule in milliseconds")
.Register(*registry_);
schedule_time = &schedule_time_family.Add({{"model", config.model_name}}, common_buckets);
schedule_time =
&schedule_time_family.Add({{"model", config.model_name}}, common_buckets);
// 注册 generated_tokens Counter
auto& generated_tokens_family = prometheus::BuildCounter()
auto &generated_tokens_family =
prometheus::BuildCounter()
.Name(std::string(METRIC_PREFIX) + "_generated_tokens_total")
.Help("Total generated tokens")
.Register(*registry_);
generated_tokens = &generated_tokens_family.Add({{"model", config.model_name}});
generated_tokens =
&generated_tokens_family.Add({{"model", config.model_name}});
// 注册 throughput_query Gauge
auto& throughput_query_family = prometheus::BuildGauge()
auto &throughput_query_family =
prometheus::BuildGauge()
.Name(std::string(METRIC_PREFIX) + "_throughput_query")
.Help("Throughput per second based on queries")
.Register(*registry_);
throughput_query = &throughput_query_family.Add({{"model", config.model_name}});
throughput_query =
&throughput_query_family.Add({{"model", config.model_name}});
// 注册 throughput_generated_tokens Gauge
auto& throughput_generated_tokens_family = prometheus::BuildGauge()
auto &throughput_generated_tokens_family =
prometheus::BuildGauge()
.Name(std::string(METRIC_PREFIX) + "_throughput_generated_tokens")
.Help("Throughput per second based on generated tokens")
.Register(*registry_);
throughput_generated_tokens = &throughput_generated_tokens_family.Add({{"model", config.model_name}});
throughput_generated_tokens =
&throughput_generated_tokens_family.Add({{"model", config.model_name}});
// 注册 event_count Counter family
event_count_family_ = &prometheus::BuildCounter()
event_count_family_ =
&prometheus::BuildCounter()
.Name(std::string(METRIC_PREFIX) + "_event_count_total")
.Help("Count of various events")
.Register(*registry_);
batch_count_family_ = &prometheus::BuildCounter()
batch_count_family_ =
&prometheus::BuildCounter()
.Name(std::string(METRIC_PREFIX) + "_batch_count_total")
.Help("Count of various batch by status")
.Register(*registry_);
// 注册 query_count Counter family
query_count_family_ = &prometheus::BuildCounter()
query_count_family_ =
&prometheus::BuildCounter()
.Name(std::string(METRIC_PREFIX) + "_query_count_total")
.Help("Count of queries by status")
.Register(*registry_);
@ -78,13 +89,14 @@ Metrics::Metrics(const MetricsConfig& config)
uptime_ms = &uptime_family.Add({{"model", config.model_name}});
// 注册 GPU 利用率 Gauges
auto& gpu_util_family = prometheus::BuildGauge()
auto &gpu_util_family =
prometheus::BuildGauge()
.Name(std::string(METRIC_PREFIX) + "_gpu_utilization_ratio")
.Help("Current GPU utilization ratio (0 to 1)")
.Register(*registry_);
for (size_t i = 0; i < config.gpu_count; ++i) {
gpu_utilization_gauges.push_back(
&gpu_util_family.Add({{"gpu_id", std::to_string(i)}, {"model", config.model_name}}));
gpu_utilization_gauges.push_back(&gpu_util_family.Add(
{{"gpu_id", std::to_string(i)}, {"model", config.model_name}}));
}
// 将 Registry 注册到 Exposer 中
@ -95,16 +107,15 @@ Metrics::Metrics(const MetricsConfig& config)
}
// 析构函数
Metrics::~Metrics() {
StopUptimeUpdater();
}
Metrics::~Metrics() { StopUptimeUpdater(); }
// 启动 uptime 更新线程
void Metrics::StartUptimeUpdater() {
uptime_thread_ = std::thread([this]() {
while (!stop_uptime_thread_) {
auto now = std::chrono::steady_clock::now();
std::chrono::duration<double, std::milli> uptime_duration = now - start_time_;
std::chrono::duration<double, std::milli> uptime_duration =
now - start_time_;
uptime_ms->Set(uptime_duration.count());
// fn_every_sec(this);
std::this_thread::sleep_for(std::chrono::seconds(1));
@ -127,7 +138,8 @@ prometheus::Counter* Metrics::event_count(const std::string& type) {
// 获取 query_count 指标
prometheus::Counter *Metrics::query_count(const std::string &status) {
return &query_count_family_->Add({{"status", status}}); // 可根据需要添加更多标签
return &query_count_family_->Add(
{{"status", status}}); // 可根据需要添加更多标签
}
prometheus::Counter *Metrics::batch_count(const std::string &type) {

View file

@ -1,14 +1,14 @@
#ifndef Metrics_H
#define Metrics_H
#include <atomic>
#include <chrono>
#include <memory>
#include <prometheus/counter.h>
#include <prometheus/exposer.h>
#include <prometheus/gauge.h>
#include <prometheus/histogram.h>
#include <prometheus/registry.h>
#include <atomic>
#include <chrono>
#include <memory>
#include <string>
#include <thread>
#include <vector>
@ -78,7 +78,10 @@ class Metrics {
struct HistogramTimerWrapper {
prometheus::Histogram *histogram;
Timer timer;
inline HistogramTimerWrapper(prometheus::Histogram* histogram) : histogram(histogram), timer() { timer.start(); }
inline HistogramTimerWrapper(prometheus::Histogram *histogram)
: histogram(histogram), timer() {
timer.start();
}
inline ~HistogramTimerWrapper() { histogram->Observe(timer.elapsedMs()); }
};

View file

@ -1,8 +1,8 @@
#ifndef __MODEL_CONFIG_HPP_
#define __MODEL_CONFIG_HPP_
#include <iostream>
#include "nlohmann/json.hpp"
#include <iostream>
#include <filesystem>
#include <fstream>
@ -23,10 +23,13 @@ class ModelConfig {
size_t num_key_value_heads;
size_t vocab_size;
NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size, max_position_embeddings, model_type,
num_attention_heads, num_hidden_layers, num_key_value_heads, vocab_size);
NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size,
max_position_embeddings, model_type,
num_attention_heads, num_hidden_layers,
num_key_value_heads, vocab_size);
void load_from(std::filesystem::path path) {
std::cout << "Load from " << path << std::endl;
std::ifstream i(path);
nlohmann::json j;
i >> j;
@ -43,7 +46,9 @@ class QuantConfig {
// For GEMV
QuantType type_of_dot_vector = NoQuantType;
inline bool can_be_used_as_matrix() { return type_of_dot_vector != NoQuantType; }
inline bool can_be_used_as_matrix() {
return type_of_dot_vector != NoQuantType;
}
bool can_be_used_as_vector;
@ -56,8 +61,11 @@ class QuantConfig {
URL reference = "";
NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(QuantConfig, name, type_of_dot_vector, can_be_used_as_vector,
bytes_per_element, has_scale, has_min, block_element_count,
NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(QuantConfig, name,
type_of_dot_vector,
can_be_used_as_vector,
bytes_per_element, has_scale,
has_min, block_element_count,
block_element_size, reference);
};
@ -70,15 +78,14 @@ inline void load_quant_configs(std::filesystem::path path) {
std::cout << __FUNCTION__ << " from " << path << std::endl;
std::ifstream i(path);
i >> j;
} else {
std::cout << __FUNCTION__ << " create new at " << path << std::endl;
}
quant_configs = j.get<std::map<QuantType, QuantConfig>>();
std::cout << "Loaded Quant Configs" << std::endl;
for (auto &[k, v] : quant_configs) {
std::cout << " - " << k << std::endl;
}
} else {
std::cout << __FUNCTION__ << " no file at " << path << std::endl;
}
}
inline void dump_quant_configs(std::filesystem::path path) {
@ -93,15 +100,14 @@ inline void load_model_configs(std::filesystem::path path) {
std::cout << __FUNCTION__ << " from " << path << std::endl;
std::ifstream i(path);
i >> j;
} else {
std::cout << __FUNCTION__ << " create new at " << path << std::endl;
}
model_configs = j.get<std::map<ModelName, ModelConfig>>();
std::cout << "Loaded Model Configs" << std::endl;
for (auto &[k, v] : model_configs) {
std::cout << " - " << k << std::endl;
}
} else {
std::cout << __FUNCTION__ << " no file at " << path << std::endl;
}
}
inline void dump_model_configs(std::filesystem::path path) {

View file

@ -3,20 +3,20 @@
#include "nlohmann/json.hpp"
#include "spdlog/spdlog.h"
#include <optional>
#include "scheduler.h"
#include <optional>
#include <atomic>
#include <cassert>
#include <future>
#include <memory>
#include <queue>
#include "arithmetic.hpp"
#include "atomic_ptr_with_flags.hpp"
#include "easy_format.hpp"
#include "metrics.h"
#include "mpsc.hpp"
#include "timer.hpp"
#include <atomic>
#include <cassert>
#include <future>
#include <memory>
#include <queue>
#include "kvc2.h"
@ -28,7 +28,8 @@ void Settings::auto_derive() {
gpu_device_count = gpu_device_id.size();
if (torch::cuda::is_available()) {
size_t gpu_count = torch::cuda::device_count();
SPDLOG_INFO("Number of available GPUs: {}, want {}", gpu_count, gpu_device_count);
SPDLOG_INFO("Number of available GPUs: {}, want {}", gpu_count,
gpu_device_count);
if (gpu_count < gpu_device_count) {
SPDLOG_ERROR("Not enough GPUs available.");
exit(0);
@ -42,37 +43,49 @@ void Settings::auto_derive() {
}
if (model_settings.num_k_heads % gpu_device_count != 0) {
SPDLOG_ERROR("num_k_heads {} is not divisible by gpu_device_count {}", model_settings.num_k_heads,
gpu_device_count);
SPDLOG_ERROR("num_k_heads {} is not divisible by gpu_device_count {}",
model_settings.num_k_heads, gpu_device_count);
assert(false);
}
size_t gpu_memory_available = gpu_memory_size * memory_utilization_percentage;
if (gpu_memory_available * gpu_device_count < model_settings.params_nbytes()) {
SPDLOG_ERROR("GPU memory size {}G is smaller than {}G", gpu_memory_available * gpu_device_count / 1e9,
if (gpu_memory_available * gpu_device_count <
model_settings.params_nbytes()) {
SPDLOG_ERROR("GPU memory size {}G is smaller than {}G",
gpu_memory_available * gpu_device_count / 1e9,
model_settings.params_nbytes() / 1e9);
assert(false);
}
assert(model_settings.k_head_dim % model_settings.num_k_heads == 0);
size_t head_per_gpu = model_settings.num_k_heads / gpu_device_count;
size_t gpu_memory_for_kv_cache = gpu_memory_available /*- model_settings.params_nbytes() / gpu_device_count*/;
SPDLOG_INFO("Each GPU Total: {}MiB, Model Params: {}MiB, KVCache: {}MiB, Left: {}MiB", gpu_memory_size / (1 << 20),
model_settings.params_nbytes() / gpu_device_count / (1 << 20), gpu_memory_for_kv_cache / (1 << 20),
size_t gpu_memory_for_kv_cache =
gpu_memory_available /*- model_settings.params_nbytes() /
gpu_device_count*/
;
SPDLOG_INFO(
"Each GPU Total: {}MiB, Model Params: {}MiB, KVCache: {}MiB, Left: {}MiB",
gpu_memory_size / (1 << 20),
model_settings.params_nbytes() / gpu_device_count / (1 << 20),
gpu_memory_for_kv_cache / (1 << 20),
(gpu_memory_size - gpu_memory_available) / (1 << 20));
size_t kv_cache_on_cnt = (size_t)(k_cache_on) + (size_t)(v_cache_on);
size_t max_total_kvcache_pages =
gpu_memory_for_kv_cache / (kv_cache_on_cnt * head_per_gpu * model_settings.k_head_dim *
model_settings.bytes_per_kv_cache_element * page_size * model_settings.layer_count);
gpu_memory_for_kv_cache /
(kv_cache_on_cnt * head_per_gpu * model_settings.k_head_dim *
model_settings.bytes_per_kv_cache_element * page_size *
model_settings.layer_count);
if (total_kvcache_pages.has_value()) {
if (total_kvcache_pages.value() > max_total_kvcache_pages) {
SPDLOG_ERROR("total_kvcache_pages {} is larger than max_total_kvcache_pages {}", total_kvcache_pages.value(),
max_total_kvcache_pages);
SPDLOG_ERROR(
"total_kvcache_pages {} is larger than max_total_kvcache_pages {}",
total_kvcache_pages.value(), max_total_kvcache_pages);
assert(false);
}
} else {
total_kvcache_pages = max_total_kvcache_pages;
SPDLOG_INFO("total_kvcache_pages is auto derived as {}", max_total_kvcache_pages);
SPDLOG_INFO("total_kvcache_pages is auto derived as {}",
max_total_kvcache_pages);
}
if (page_size % 256 != 0) {
@ -139,39 +152,41 @@ struct Query {
void to_status(Status to);
void export_metrics() { ctx.met->query_count(status_to_string(plan_status))->Increment(1); }
void export_metrics() {
ctx.met->query_count(status_to_string(plan_status))->Increment(1);
}
Query(QueryID id, QueryAdd query_add, QueryContext context)
: id(id),
prompt_length(query_add.query_length),
no_kvcache_from(0),
: id(id), prompt_length(query_add.query_length), no_kvcache_from(0),
estimated_length(query_add.estimated_length),
sample_options(query_add.sample_options),
user_id(query_add.user_id),
SLO_TTFT_ms(query_add.SLO_TTFT_ms),
SLO_TBT_ms(query_add.SLO_TBT_ms),
stop_criteria(query_add.stop_criteria),
ctx(context) {
sample_options(query_add.sample_options), user_id(query_add.user_id),
SLO_TTFT_ms(query_add.SLO_TTFT_ms), SLO_TBT_ms(query_add.SLO_TBT_ms),
stop_criteria(query_add.stop_criteria), ctx(context) {
std::vector<int64_t> shape = {int64_t(query_add.estimated_length)};
query_token = torch::zeros(shape, torch::TensorOptions().dtype(torch::kInt32));
query_token =
torch::zeros(shape, torch::TensorOptions().dtype(torch::kInt32));
assert(query_token.is_contiguous());
if (query_token.is_contiguous() == false) {
SPDLOG_ERROR("Query Token must be contiguous!");
exit(1);
}
memcpy(query_token.data_ptr(), query_add.query_token.data(), query_add.query_length * sizeof(Token));
memcpy(query_token.data_ptr(), query_add.query_token.data(),
query_add.query_length * sizeof(Token));
no_kvcache_from = 0; // maybe match prefix later
export_metrics();
}
Token& token_at(size_t idx) { return reinterpret_cast<Token*>(query_token.data_ptr())[idx]; }
Token &token_at(size_t idx) {
return reinterpret_cast<Token *>(query_token.data_ptr())[idx];
}
void absorb_update(const QueryUpdate &update) {
SPDLOG_DEBUG("{}", update.debug());
active_position = update.active_position;
kvc2_handle->append_tokens(&token_at(0), active_position); // active_position is length -1
kvc2_handle->append_tokens(&token_at(0),
active_position); // active_position is length -1
if (update.is_prefill) {
if (active_position == prompt_length) {
token_at(active_position) = update.generated_token;
@ -195,7 +210,9 @@ struct Query {
}
}
void absorb_decode_task([[maybe_unused]] const QueryID& task) { this->plan_position += 1; }
void absorb_decode_task([[maybe_unused]] const QueryID &task) {
this->plan_position += 1;
}
PrefillTask get_prefill_task(size_t prefill_length) {
if (prefill_length + plan_position > prompt_length) {
@ -225,16 +242,19 @@ struct Query {
void debug() {
std::string status_string = status_to_string(plan_status);
SPDLOG_DEBUG(
"Query {}, prompt_length {}, estimated_length {}, plan status {}, plan position {} "
SPDLOG_DEBUG("Query {}, prompt_length {}, estimated_length {}, plan status "
"{}, plan position {} "
"active position {}",
id, prompt_length, estimated_length, status_string, plan_position, active_position);
id, prompt_length, estimated_length, status_string,
plan_position, active_position);
}
};
std::string QueryUpdate::debug() const {
return fmt::format("Query {}, ok {}, is_prefill {}, done {}, active_position {}, gen token {}", id, ok, is_prefill,
decode_done, active_position, generated_token);
return fmt::format("Query {}, ok {}, is_prefill {}, done {}, active_position "
"{}, gen token {}",
id, ok, is_prefill, decode_done, active_position,
generated_token);
}
using Q = std::shared_ptr<Query>;
@ -258,18 +278,22 @@ struct KVC2_Maintainer {
.total_kvcache_pages = settings.total_kvcache_pages.value(),
.num_token_per_page = settings.page_size,
.num_k_heads = settings.model_settings.num_k_heads,
.k_head_dim =
settings.use_self_defined_head_dim ? settings.self_defined_head_dim : settings.model_settings.k_head_dim,
.k_head_dim = settings.use_self_defined_head_dim
? settings.self_defined_head_dim
: settings.model_settings.k_head_dim,
.full_kv_cache_on_each_gpu = settings.full_kv_cache_on_each_gpu,
.k_cache_on = settings.k_cache_on,
.v_cache_on = settings.v_cache_on,
.tensor_type = torch::kBFloat16,
};
auto model_configs_path = std::filesystem::path(settings.kvc2_config_path) / "model_configs.json";
auto model_configs_path =
std::filesystem::path(settings.kvc2_config_path) / "model_configs.json";
load_model_configs(model_configs_path);
auto my_model_config = ModelConfig();
my_model_config.load_from(std::filesystem::path(settings.model_settings.model_path) / "config.json");
my_model_config.load_from(
std::filesystem::path(settings.model_settings.model_path) /
"config.json");
model_configs[settings.model_name] = my_model_config;
dump_model_configs(model_configs_path);
@ -317,43 +341,36 @@ struct EventQueryStatus{
};
struct EventSchedule {};
using Event = std::variant<EventAddQuery, EventUpdateQuery, EventTakenBatch, EventPrepare, EventPrepared,
EventQueryStatus, EventSchedule>;
using Event =
std::variant<EventAddQuery, EventUpdateQuery, EventTakenBatch, EventPrepare,
EventPrepared, EventQueryStatus, EventSchedule>;
template <typename T>
std::string event_name(const T& event);
template <typename T> std::string event_name(const T &event);
template <>
std::string event_name(const EventAddQuery&) {
template <> std::string event_name(const EventAddQuery &) {
return "EventAddQuery";
}
template <>
std::string event_name(const EventUpdateQuery&) {
template <> std::string event_name(const EventUpdateQuery &) {
return "EventUpdateQuery";
}
template <>
std::string event_name(const EventTakenBatch&) {
template <> std::string event_name(const EventTakenBatch &) {
return "EventTakenBatch";
}
template <>
std::string event_name(const EventPrepare&) {
template <> std::string event_name(const EventPrepare &) {
return "EventPrepare";
}
template <>
std::string event_name(const EventPrepared&) {
template <> std::string event_name(const EventPrepared &) {
return "EventPrepared";
}
template <>
std::string event_name(const EventQueryStatus&) {
template <> std::string event_name(const EventQueryStatus &) {
return "EventQueryStatus";
}
template <>
std::string event_name(const EventSchedule&) {
template <> std::string event_name(const EventSchedule &) {
return "EventSchedule";
}
@ -397,10 +414,10 @@ struct QueryMaintainer : public Scheduler {
decode_num += 1;
}
}
if (prefill_num >= 2 || (prefill_num ==1 && settings.max_batch_size - 2 < decode_num)) {
if (prefill_num >= 2 ||
(prefill_num == 1 && settings.max_batch_size - 2 < decode_num)) {
preill_length = settings.recommended_chunk_prefill_token_count;
}
else {
} else {
preill_length = settings.recommended_chunk_prefill_token_count * 2;
}
for (auto &q : queries) {
@ -439,8 +456,7 @@ struct QueryMaintainer : public Scheduler {
// Interface
void init(Settings settings) override {
SPDLOG_INFO(
"\nScheduler Settings:\n"
SPDLOG_INFO("\nScheduler Settings:\n"
" model_name: {}\n"
" quant_type: {}\n"
" model_path: {}\n"
@ -466,19 +482,28 @@ struct QueryMaintainer : public Scheduler {
" save_to_disk: {}\n"
" strategy_name: {}\n"
" gpu_device_count: {}\n",
settings.model_name, settings.quant_type, settings.model_settings.model_path,
settings.model_settings.params_count, settings.model_settings.layer_count, settings.model_settings.num_k_heads,
settings.model_settings.k_head_dim, settings.model_settings.bytes_per_params,
settings.model_name, settings.quant_type,
settings.model_settings.model_path,
settings.model_settings.params_count,
settings.model_settings.layer_count,
settings.model_settings.num_k_heads,
settings.model_settings.k_head_dim,
settings.model_settings.bytes_per_params,
settings.model_settings.bytes_per_kv_cache_element,
settings.page_size, format_vector(settings.gpu_device_id), readable_number(settings.gpu_memory_size),
settings.memory_utilization_percentage, settings.max_batch_size, settings.recommended_chunk_prefill_token_count,
settings.sched_metrics_port, settings.kvc2_config_path, settings.kvc2_root_path, settings.memory_pool_size_GB,
settings.evict_count, settings.kvc2_metrics_port, settings.load_from_disk, settings.save_to_disk,
settings.page_size, format_vector(settings.gpu_device_id),
readable_number(settings.gpu_memory_size),
settings.memory_utilization_percentage, settings.max_batch_size,
settings.recommended_chunk_prefill_token_count,
settings.sched_metrics_port, settings.kvc2_config_path,
settings.kvc2_root_path, settings.memory_pool_size_GB,
settings.evict_count, settings.kvc2_metrics_port,
settings.load_from_disk, settings.save_to_disk,
settings.strategy_name, settings.gpu_device_count);
this->settings = settings;
kvc2_maintainer = std::shared_ptr<KVC2_Maintainer>(new KVC2_Maintainer(settings));
kvc2_maintainer =
std::shared_ptr<KVC2_Maintainer>(new KVC2_Maintainer(settings));
MetricsConfig met_conf = {
.endpoint = "0.0.0.0:" + std::to_string(settings.sched_metrics_port),
.model_name = settings.model_name,
@ -522,7 +547,8 @@ struct QueryMaintainer : public Scheduler {
// Here this function update last batch results and get the next batch
// in most cases, the batch is ready,
// if not, busy wait to get it
std::shared_ptr<BatchQueryTodo> update_last_batch(BatchQueryUpdate updates) override {
std::shared_ptr<BatchQueryTodo>
update_last_batch(BatchQueryUpdate updates) override {
event_loop_queue.enqueue(updates);
// Busy Wait
@ -547,7 +573,8 @@ struct QueryMaintainer : public Scheduler {
InferenceContext re;
re.k_cache = kvc2_maintainer->k_cache;
re.v_cache = kvc2_maintainer->v_cache;
// kvc2_maintainer->k_cache[0][0][0][0][0][0] = 42; // test whether we pass this to inference loop
// kvc2_maintainer->k_cache[0][0][0][0][0][0] = 42; // test whether we pass
// this to inference loop
return re;
}
@ -557,7 +584,8 @@ struct QueryMaintainer : public Scheduler {
virtual void strategy_prepare(const EventPrepare &prepare) = 0;
virtual void strategy_prepared(const EventPrepared &prepared) = 0;
virtual void strategy_query_status(const EventQueryStatus &query_status) = 0;
virtual void strategy_schedule(const EventSchedule& event, BatchQueryTodo* new_batch) = 0;
virtual void strategy_schedule(const EventSchedule &event,
BatchQueryTodo *new_batch) = 0;
void tackle_event(EventAddQuery &event) {
auto &query_add = event.first;
@ -578,10 +606,13 @@ struct QueryMaintainer : public Scheduler {
exit(1);
}
auto q = query_map[u.id];
if (q->plan_status == Query::Status::Prefill || q->plan_status == Query::Status::Decode) {
if (q->plan_status == Query::Status::Prefill ||
q->plan_status == Query::Status::Decode) {
q->absorb_update(u);
} else {
SPDLOG_DEBUG("Query {} is not in Prefill or Decode status, do not update it", u.id);
SPDLOG_DEBUG(
"Query {} is not in Prefill or Decode status, do not update it",
u.id);
}
}
strategy_update_query(update);
@ -606,7 +637,9 @@ struct QueryMaintainer : public Scheduler {
void tackle_event(const EventPrepare &event) { strategy_prepare(event); }
void tackle_event(const EventPrepared &event) { strategy_prepared(event); }
void tackle_event(const EventQueryStatus& event) { strategy_query_status(event); }
void tackle_event(const EventQueryStatus &event) {
strategy_query_status(event);
}
void tackle_event(const EventSchedule &event) {
// SPDLOG_INFO("Tackle Schedule Event");
@ -660,7 +693,8 @@ struct QueryMaintainer : public Scheduler {
}
},
event);
if (event_loop_queue.size() == 0 && std::holds_alternative<EventSchedule>(event) == false) {
if (event_loop_queue.size() == 0 &&
std::holds_alternative<EventSchedule>(event) == false) {
// if this is not a schedule event, we need to schedule one
event_loop_queue.enqueue(EventSchedule());
}
@ -684,12 +718,16 @@ void Query::to_status(Status to) {
break;
case Preparing:
SPDLOG_INFO("Preparing Query {} {}", id,
prepare_try_count > 0 ? (std::to_string(prepare_try_count) + " Try") : "");
prepare_try_count > 0
? (std::to_string(prepare_try_count) + " Try")
: "");
prepare_try_count += 1;
ctx.kvc2_interface->lookup_to_gpu_async(
ctx.model_name, ctx.quant_type, static_cast<kvc2::Token*>(query_token.data_ptr()), prompt_length,
estimated_length, [this](std::shared_ptr<kvc2::DoubleCacheHandleInterface> handle) {
ctx.model_name, ctx.quant_type,
static_cast<kvc2::Token *>(query_token.data_ptr()), prompt_length,
estimated_length,
[this](std::shared_ptr<kvc2::DoubleCacheHandleInterface> handle) {
if (handle == nullptr) {
SPDLOG_INFO("Get handle from kvc2 Failed.");
this->after_load(false);
@ -734,10 +772,13 @@ void Query::to_status(Status to) {
void Query::after_load(bool ok) {
if (ok) {
size_t page_count = div_up(estimated_length, ctx.query_maintainer->settings.page_size);
size_t page_count =
div_up(estimated_length, ctx.query_maintainer->settings.page_size);
std::vector<int64_t> shape;
shape.push_back(page_count);
block_index = torch::zeros(shape, torch::TensorOptions().dtype(torch::kInt32)).contiguous();
block_index =
torch::zeros(shape, torch::TensorOptions().dtype(torch::kInt32))
.contiguous();
auto ptr = reinterpret_cast<int32_t *>(block_index.data_ptr());
auto vec_idx = kvc2_handle->get_gpu_block_idx();
for (size_t i = 0; i < vec_idx.size(); i++) {
@ -826,10 +867,10 @@ struct FCFS_single_prefill : public QueryMaintainer {
wait_done_prepare = std::nullopt;
}
}
}
void strategy_schedule([[maybe_unused]] const EventSchedule& event, BatchQueryTodo* new_batch) override {
void strategy_schedule([[maybe_unused]] const EventSchedule &event,
BatchQueryTodo *new_batch) override {
bool have_prefill = false;
for (auto &q : active_query) {
if (q->plan_status == Query::Prefill) {
@ -837,7 +878,8 @@ struct FCFS_single_prefill : public QueryMaintainer {
}
}
if (have_prefill == false && ready_queue.empty() == false && active_query.size() < settings.max_batch_size) {
if (have_prefill == false && ready_queue.empty() == false &&
active_query.size() < settings.max_batch_size) {
auto &next_q = ready_queue.front();
ready_queue.pop();
@ -855,7 +897,8 @@ struct FCFS_single_prefill : public QueryMaintainer {
};
struct FCFS : public FCFS_single_prefill {
void strategy_schedule([[maybe_unused]] const EventSchedule& event, BatchQueryTodo* new_batch) override {
void strategy_schedule([[maybe_unused]] const EventSchedule &event,
BatchQueryTodo *new_batch) override {
int prefill_count = 0;
const int max_prefill_count = 2;
for (auto &q : active_query) {
@ -900,7 +943,8 @@ std::shared_ptr<Scheduler> create_scheduler(Settings settings) {
}
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(SampleOptions, temperature, top_p);
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(QueryAdd, query_token, query_length, estimated_length, sample_options, user_id,
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(QueryAdd, query_token, query_length,
estimated_length, sample_options, user_id,
SLO_TTFT_ms, SLO_TBT_ms);
std::string QueryAdd::serialize() {

View file

@ -1,10 +1,10 @@
#pragma once
#include <torch/torch.h>
#include "model_config.h"
#include <cstdint>
#include <memory>
#include <optional>
#include <torch/torch.h>
#include <vector>
#include "model_config.h"
namespace scheduler {
@ -28,7 +28,9 @@ struct ModelSettings {
double bytes_per_kv_cache_element;
inline size_t params_nbytes() { return params_count * bytes_per_params; }
inline size_t bytes_per_token_kv_cache() { return bytes_per_kv_cache_element * num_k_heads * k_head_dim; }
inline size_t bytes_per_token_kv_cache() {
return bytes_per_kv_cache_element * num_k_heads * k_head_dim;
}
};
struct SampleOptions {
@ -37,7 +39,8 @@ struct SampleOptions {
};
struct Settings {
// something is aukward here, kvc2 only use model_name and quant_type to get model infos.
// something is aukward here, kvc2 only use model_name and quant_type to get
// model infos.
ModelName model_name;
QuantType quant_type;
// model_setting is ignore by kvc2
@ -79,14 +82,16 @@ struct Settings {
void auto_derive();
};
using PrefillTask = std::tuple<QueryID, TokenLength, TokenLength>; // id, start, length
using PrefillTask =
std::tuple<QueryID, TokenLength, TokenLength>; // id, start, length
struct BatchQueryTodo {
// query
std::vector<QueryID> query_ids;
std::vector<torch::Tensor> query_tokens;
std::vector<TokenLength> query_lengths;
std::vector<torch::Tensor> block_indexes; // (max_num_blocks_per_seq), dtype torch.int32.
std::vector<torch::Tensor>
block_indexes; // (max_num_blocks_per_seq), dtype torch.int32.
std::optional<torch::Tensor> attn_masks;
std::optional<torch::Tensor> rope_ranges;
std::vector<SampleOptions> sample_options;
@ -94,8 +99,10 @@ struct BatchQueryTodo {
// mini batches, adjacent two mini batches are executed together
// tasks count must be <=2, because of flash infer attention
std::vector<PrefillTask> prefill_mini_batches; // prefill minibatch only has 1 prefill
std::vector<std::vector<QueryID>> decode_mini_batches; // decode minibatch has multiple decode
std::vector<PrefillTask>
prefill_mini_batches; // prefill minibatch only has 1 prefill
std::vector<std::vector<QueryID>>
decode_mini_batches; // decode minibatch has multiple decode
std::string debug();
bool empty();
@ -156,7 +163,8 @@ class Scheduler {
virtual void cancel_query(QueryID id) = 0;
// inference loop call this
virtual std::shared_ptr<BatchQueryTodo> update_last_batch(BatchQueryUpdate updates) = 0;
virtual std::shared_ptr<BatchQueryTodo>
update_last_batch(BatchQueryUpdate updates) = 0;
virtual InferenceContext get_inference_context() = 0;
virtual ~Scheduler() = default;

View file

@ -1,7 +1,6 @@
#include <type_traits>
template <typename T, typename U>
T div_up(T x, U by) {
template <typename T, typename U> T div_up(T x, U by) {
static_assert(std::is_integral_v<T>);
static_assert(std::is_integral_v<U>);
return (x + by - 1) / by;

View file

@ -1,28 +1,35 @@
#include <atomic>
template <typename T>
struct AtomicPtrWithFlag {
template <typename T> struct AtomicPtrWithFlag {
constexpr static uint64_t mask = 1ull << 63;
std::atomic_uint64_t ptr = 0;
std::pair<T*, bool> load(std::memory_order order = std::memory_order_seq_cst) {
std::pair<T *, bool>
load(std::memory_order order = std::memory_order_seq_cst) {
uint64_t val = ptr.load(order);
return {reinterpret_cast<T *>(val & (~mask)), val & mask};
}
void store(T* p, bool flag, std::memory_order order = std::memory_order_seq_cst) {
void store(T *p, bool flag,
std::memory_order order = std::memory_order_seq_cst) {
ptr.store(reinterpret_cast<uint64_t>(p) | (flag ? mask : 0), order);
}
std::pair<T*, bool> exchange(T* p, bool flag, std::memory_order order = std::memory_order_seq_cst) {
uint64_t val = ptr.exchange(reinterpret_cast<uint64_t>(p) | (flag ? mask : 0), order);
std::pair<T *, bool>
exchange(T *p, bool flag,
std::memory_order order = std::memory_order_seq_cst) {
uint64_t val =
ptr.exchange(reinterpret_cast<uint64_t>(p) | (flag ? mask : 0), order);
return {reinterpret_cast<T *>(val & (~mask)), val & mask};
}
std::pair<T*, bool> touch_load(std::memory_order order = std::memory_order_seq_cst) {
std::pair<T *, bool>
touch_load(std::memory_order order = std::memory_order_seq_cst) {
uint64_t val = ptr.fetch_and(~mask, order);
return {reinterpret_cast<T *>(val & (~mask)), val & mask};
}
bool check_flag(std::memory_order order = std::memory_order_seq_cst) { return ptr.load(order) & mask; }
bool check_flag(std::memory_order order = std::memory_order_seq_cst) {
return ptr.load(order) & mask;
}
};

View file

@ -57,7 +57,8 @@ inline std::vector<std::string> parse_csv_line(const std::string& line) {
* @return A vector of pairs, each containing a column name and a vector of data
* for that column.
*/
inline std::vector<std::pair<std::string, std::vector<std::string>>> read_csv(const std::string& filename) {
inline std::vector<std::pair<std::string, std::vector<std::string>>>
read_csv(const std::string &filename) {
std::cout << "Reading CSV file: " << filename << std::endl;
// Open the file
std::ifstream file(filename);
@ -107,7 +108,8 @@ inline std::vector<std::pair<std::string, std::vector<std::string>>> read_csv(co
chunk_starts.push_back(content_size);
// Create threads to parse each chunk
std::vector<std::vector<std::vector<std::string>>> thread_results(num_threads);
std::vector<std::vector<std::vector<std::string>>> thread_results(
num_threads);
std::vector<std::thread> threads;
for (unsigned int i = 0; i < num_threads; ++i) {
@ -158,7 +160,8 @@ inline std::vector<std::pair<std::string, std::vector<std::string>>> read_csv(co
* @param data A vector of pairs, each containing a column name and a vector of
* data for that column.
*/
inline void write_csv(const std::string& filename,
inline void write_csv(
const std::string &filename,
const std::vector<std::pair<std::string, std::vector<std::string>>> &data) {
std::cout << "Writing CSV file: " << filename << std::endl;
@ -204,7 +207,8 @@ inline void write_csv(const std::string& filename,
pos += 2;
}
}
if (escaped_field.find(',') != std::string::npos || escaped_field.find('\n') != std::string::npos) {
if (escaped_field.find(',') != std::string::npos ||
escaped_field.find('\n') != std::string::npos) {
needs_quotes = true;
}
if (needs_quotes) {

View file

@ -2,8 +2,7 @@
#include <string>
#include <vector>
template <typename T>
std::string format_vector(const std::vector<T>& v) {
template <typename T> std::string format_vector(const std::vector<T> &v) {
std::ostringstream oss;
if (v.empty())
return "[]";

View file

@ -4,8 +4,7 @@
#include <optional>
#include <semaphore>
template <typename T>
class MPSCQueue {
template <typename T> class MPSCQueue {
struct Node {
T data;
std::atomic<Node *> next;
@ -59,16 +58,16 @@ class MPSCQueue {
size_t size() { return enqueue_count.load() - dequeue_count; }
};
template <typename T>
class MPSCQueueConsumerLock {
template <typename T> class MPSCQueueConsumerLock {
MPSCQueue<T> queue;
std::counting_semaphore<> sema{0};
public:
void enqueue(T data) {
queue.enqueue(std::move(data));
// std::atomic_thread_fence(std::memory_order_seq_cst);// Inserting this because the memory order might be wrong, I
// am also not that sure about this.
// std::atomic_thread_fence(std::memory_order_seq_cst);// Inserting this
// because the memory order might be wrong, I am also not that sure about
// this.
sema.release();
}
@ -76,7 +75,9 @@ class MPSCQueueConsumerLock {
auto re = queue.dequeue();
if (re.has_value()) {
while (sema.try_acquire() == false) {
std::cerr << __FILE__ << ":" << __FUNCTION__ << " sema try acquire should be success, retrying, please check"
std::cerr
<< __FILE__ << ":" << __FUNCTION__
<< " sema try acquire should be success, retrying, please check"
<< std::endl;
// assert(false);
}
@ -91,7 +92,9 @@ class MPSCQueueConsumerLock {
auto re = queue.dequeue();
if (re.has_value()) {
while (sema.try_acquire() == false) {
std::cerr << __FILE__ << ":" << __FUNCTION__ << " sema try acquire should be success, retrying, please check"
std::cerr
<< __FILE__ << ":" << __FUNCTION__
<< " sema try acquire should be success, retrying, please check"
<< std::endl;
// assert(false);
}

View file

@ -9,24 +9,30 @@
class Statistics {
public:
// Increment the counter for a given key by a specified value (default is 1)
void increment_counter(const std::string& key, int64_t value = 1) { counters_[key] += value; }
void increment_counter(const std::string &key, int64_t value = 1) {
counters_[key] += value;
}
int64_t &get_counter(const std::string &key) { return counters_[key]; }
// Start the timer for a given key
void start_timer(const std::string& key) { active_timers_[key] = std::chrono::high_resolution_clock::now(); }
void start_timer(const std::string &key) {
active_timers_[key] = std::chrono::high_resolution_clock::now();
}
// Stop the timer for a given key and update the total time and count
void stop_timer(const std::string &key) {
auto start_it = active_timers_.find(key);
if (start_it != active_timers_.end()) {
auto duration = std::chrono::high_resolution_clock::now() - start_it->second;
auto duration =
std::chrono::high_resolution_clock::now() - start_it->second;
timings_[key].total_time += duration;
timings_[key].count += 1;
active_timers_.erase(start_it);
} else {
// Handle error: stop_timer called without a matching start_timer
std::cerr << "Warning: stop_timer called for key '" << key << "' without a matching start_timer.\n";
std::cerr << "Warning: stop_timer called for key '" << key
<< "' without a matching start_timer.\n";
}
}
@ -40,7 +46,10 @@ class Statistics {
for (const auto &kv : timings_) {
std::cout << " " << kv.first << ": count = " << kv.second.count
<< ", total_time = " << kv.second.total_time.count() << "s"
<< ", average_time = " << (kv.second.count > 0 ? kv.second.total_time.count() / kv.second.count : 0)
<< ", average_time = "
<< (kv.second.count > 0
? kv.second.total_time.count() / kv.second.count
: 0)
<< "s\n";
}
}
@ -52,14 +61,17 @@ class Statistics {
// Struct to hold timing information for a key
struct TimingInfo {
int64_t count = 0;
std::chrono::duration<double> total_time = std::chrono::duration<double>::zero();
std::chrono::duration<double> total_time =
std::chrono::duration<double>::zero();
};
// Mapping from key to timing information
std::unordered_map<std::string, TimingInfo> timings_;
// Mapping from key to the start time of active timers
std::unordered_map<std::string, std::chrono::high_resolution_clock::time_point> active_timers_;
std::unordered_map<std::string,
std::chrono::high_resolution_clock::time_point>
active_timers_;
};
#endif // STATISTICS_HPP

View file

@ -1,4 +1,5 @@
#pragma once
#include "readable_number.hpp"
#include <cassert>
#include <chrono>
#include <iomanip>
@ -6,7 +7,6 @@
#include <map>
#include <sstream>
#include <string>
#include "readable_number.hpp"
inline std::string doubleToStringR2(double value) {
std::stringstream stream;
@ -49,10 +49,14 @@ class Timer {
endTime = m_endTime;
}
return std::chrono::duration_cast<std::chrono::nanoseconds>(endTime - m_startTime).count();
return std::chrono::duration_cast<std::chrono::nanoseconds>(endTime -
m_startTime)
.count();
}
void printElapsedMilliseconds() { std::cout << elapsedNs() / 1e6 << " ms" << std::endl; }
void printElapsedMilliseconds() {
std::cout << elapsedNs() / 1e6 << " ms" << std::endl;
}
static std::string ns_to_string(double duration) {
auto nano_sec = duration;

View file

@ -1,122 +0,0 @@
{
"DeepSeek-Coder-V2-Instruct": {
"hidden_size": 5120,
"intermediate_size": 12288,
"max_position_embeddings": 163840,
"model_type": "deepseek_v2",
"num_attention_heads": 128,
"num_hidden_layers": 60,
"num_key_value_heads": 128,
"vocab_size": 102400
},
"DeepSeek-R1": {
"hidden_size": 7168,
"intermediate_size": 18432,
"max_position_embeddings": 163840,
"model_type": "deepseek_v3",
"num_attention_heads": 128,
"num_hidden_layers": 61,
"num_key_value_heads": 128,
"vocab_size": 129280
},
"DeepSeek-V2-Lite-Chat": {
"hidden_size": 2048,
"intermediate_size": 10944,
"max_position_embeddings": 163840,
"model_type": "deepseek_v2",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"num_key_value_heads": 16,
"vocab_size": 102400
},
"DeepSeek-V3": {
"hidden_size": 7168,
"intermediate_size": 18432,
"max_position_embeddings": 163840,
"model_type": "deepseek_v3",
"num_attention_heads": 128,
"num_hidden_layers": 3,
"num_key_value_heads": 128,
"vocab_size": 129280
},
"DeepSeek-V3-bf16": {
"hidden_size": 7168,
"intermediate_size": 18432,
"max_position_embeddings": 163840,
"model_type": "deepseek_v3",
"num_attention_heads": 128,
"num_hidden_layers": 61,
"num_key_value_heads": 128,
"vocab_size": 129280
},
"LLaMA-2-7B-32K": {
"hidden_size": 4096,
"intermediate_size": 11008,
"max_position_embeddings": 32768,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 32,
"vocab_size": 32000
},
"Moonlight-16B-A3B-Instruct": {
"hidden_size": 2048,
"intermediate_size": 11264,
"max_position_embeddings": 8192,
"model_type": "deepseek_v3",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"num_key_value_heads": 16,
"vocab_size": 163840
},
"Qwen2.5-32B-Instruct": {
"hidden_size": 5120,
"intermediate_size": 27648,
"max_position_embeddings": 32768,
"model_type": "qwen2",
"num_attention_heads": 40,
"num_hidden_layers": 64,
"num_key_value_heads": 8,
"vocab_size": 152064
},
"Qwen2.5-32B-Instruct-GPTQ-Int4": {
"hidden_size": 5120,
"intermediate_size": 27648,
"max_position_embeddings": 32768,
"model_type": "qwen2",
"num_attention_heads": 40,
"num_hidden_layers": 64,
"num_key_value_heads": 8,
"vocab_size": 152064
},
"Qwen2.5-7B-Instruct": {
"hidden_size": 3584,
"intermediate_size": 18944,
"max_position_embeddings": 32768,
"model_type": "qwen2",
"num_attention_heads": 28,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"vocab_size": 152064
},
"Qwen2.5-7B-Instruct-GPTQ-Int4": {
"hidden_size": 3584,
"intermediate_size": 18944,
"max_position_embeddings": 32768,
"model_type": "qwen2",
"num_attention_heads": 28,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"vocab_size": 152064
},
"qwen2-72b-instruct": {
"hidden_size": 8192,
"intermediate_size": 29568,
"max_position_embeddings": 32768,
"model_type": "qwen2",
"num_attention_heads": 64,
"num_hidden_layers": 80,
"num_key_value_heads": 8,
"vocab_size": 152064
}
}

View file

@ -1,57 +0,0 @@
{
"BF16": {
"block_element_count": 1,
"block_element_size": 2,
"bytes_per_element": 2.0,
"can_be_used_as_vector": true,
"has_min": false,
"has_scale": false,
"name": "BF16",
"reference": "",
"type_of_dot_vector": "BF16"
},
"FP16": {
"block_element_count": 1,
"block_element_size": 2,
"bytes_per_element": 2.0,
"can_be_used_as_vector": true,
"has_min": false,
"has_scale": false,
"name": "FP16",
"reference": "",
"type_of_dot_vector": "FP16"
},
"FP32": {
"block_element_count": 1,
"block_element_size": 4,
"bytes_per_element": 4.0,
"can_be_used_as_vector": true,
"has_min": false,
"has_scale": false,
"name": "FP32",
"reference": "",
"type_of_dot_vector": "FP32"
},
"Q4_0": {
"block_element_count": 32,
"block_element_size": 18,
"bytes_per_element": 0.5625,
"can_be_used_as_vector": false,
"has_min": false,
"has_scale": true,
"name": "Q4_0",
"reference": "https://huggingface.co/docs/hub/gguf",
"type_of_dot_vector": "Q8_0"
},
"Q8_0": {
"block_element_count": 32,
"block_element_size": 34,
"bytes_per_element": 1.0625,
"can_be_used_as_vector": true,
"has_min": false,
"has_scale": true,
"name": "Q8_0",
"reference": "https://huggingface.co/docs/hub/gguf",
"type_of_dot_vector": "Q8_0"
}
}

View file

@ -70,6 +70,9 @@ class ArgumentParser:
parser.add_argument("--batch_size", type=int, default=self.cfg.batch_size)
parser.add_argument("--cache_lens", type=int, default=self.cfg.cache_lens)
# kvc2 config
parser.add_argument("--kvc2_config_dir", type=str, default=self.cfg.kvc2_config_dir)
# log configs
# log level: debug, info, warn, error, crit
parser.add_argument("--log_dir", type=str, default=self.cfg.log_dir)

View file

@ -7,9 +7,7 @@ import sys, os
import yaml, json
from time import sleep
current_dir = os.path.dirname(__file__)
# sched_path = os.path.abspath(os.path.join(current_dir, '../../../build/balance_serve/sched'))
# sys.path.insert(0, sched_path)
import sched_ext
from transformers import AutoConfig
@ -52,8 +50,7 @@ def create_sched_settings(args):
settings.v_cache_on = False
settings.kvc2_root_path = '/mnt/data/persist-kvc'
settings.kvc2_config_path = os.path.join(current_dir, "..", "..", "configs")
print(os.path.join(current_dir, "..", "..", "configs"))
settings.kvc2_config_path = args.kvc2_config_dir
settings.memory_pool_size_GB = args.cpu_memory_size_GB
settings.evict_count = 40
settings.kvc2_metrics_port = args.kvc2_metrics_port

View file

@ -34,12 +34,15 @@ class Config(metaclass=Singleton):
user_path: str = os.path.expanduser("~")
localstore_path: str = os.path.join(user_path, ".ktransformers")
kvc2_config_dir = os.path.join(localstore_path, "kvc2")
config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME)
if not os.path.exists(config_yaml):
print(f"Can't find config file, {config_yaml}")
exit(-1)
if not os.path.exists(localstore_path):
os.mkdir(localstore_path)
if not os.path.exists(kvc2_config_dir):
os.mkdir(kvc2_config_dir)
if not os.path.exists(config_path):
shutil.copyfile(config_yaml, config_path)
with open(config_path, "r", encoding="utf-8") as fp:
@ -62,10 +65,13 @@ class Config(metaclass=Singleton):
self.localstore_path: str = os.path.join(self.user_path, ".ktransformers")
# log configs
self.log_dir = os.path.join(self.localstore_path, cfg["log"]["dir"])
if not os.path.exists(self.log_dir):
os.mkdir(self.log_dir)
self.log_file = cfg["log"]["file"]
self.log_level = cfg["log"]["level"]
self.backup_count = cfg["log"]["backup_count"]
self.kvc2_config_dir = os.path.join(self.localstore_path, "kvc2")
# server configs
self.server: dict = cfg.get("server", {})
self.server_ip = self.server.get("ip", "0.0.0.0")