diff --git a/csrc/balance_serve/kvc2/src/async_store.cpp b/csrc/balance_serve/kvc2/src/async_store.cpp
index b5400c9..033fa86 100644
--- a/csrc/balance_serve/kvc2/src/async_store.cpp
+++ b/csrc/balance_serve/kvc2/src/async_store.cpp
@@ -35,23 +35,23 @@ struct ArrayStore {
     if (to <= size) {
       return;
     }
-    //TODO: extend file
+    // TODO: extend file
    size = to;
-    //LOG_INFO("Extend file to `, size `", to, size_in_bytes());
+    // LOG_INFO("Extend file to `, size `", to, size_in_bytes());
   }
 
   ArrayStore(size_t element_size, size_t size, std::filesystem::path data_path)
       : element_size(element_size), element_size_aligned((element_size + DeviceBlockSize - 1) / DeviceBlockSize),
         data_path(data_path) {
-    //TODO: prefix cache
+    // TODO: prefix cache
   }
 
   void read(size_t index, void* buffer) {
-    //TODO: read from file
+    // TODO: read from file
   }
 
   void write(size_t index, void* buffer) {
-    //TODO: write to file
+    // TODO: write to file
   }
 };
 
@@ -98,15 +98,15 @@ struct IODealerImpl {
   IODealerImpl(bool use_io_uring, int IO_DEPTH) : use_io_uring(use_io_uring), IO_DEPTH(IO_DEPTH) {}
 
   void queue_consumer() {
-    //TODO:
+    // TODO:
   }
 
   void io_perf() {
-    //TODO:
+    // TODO:
   }
 
   void io_dealer() {
-    //TODO:
+    // TODO:
   }
 };
 
@@ -130,7 +130,7 @@ void IODealer::stop() {
   if (io_impl->stop) {
     return;
   }
-  //LOG_INFO("Stopping IO Dealer");
+  // LOG_INFO("Stopping IO Dealer");
   io_impl->stop = true;
 }
diff --git a/csrc/balance_serve/kvc2/src/gpu_cache.cpp b/csrc/balance_serve/kvc2/src/gpu_cache.cpp
index 2bfe945..76cf1a7 100644
--- a/csrc/balance_serve/kvc2/src/gpu_cache.cpp
+++ b/csrc/balance_serve/kvc2/src/gpu_cache.cpp
@@ -77,7 +77,6 @@ GPUPageCache::GPUPageCache(GPUPageCacheConfig& config) : config(config) {
     gpu_only_occupations.resize(config.total_kvcache_pages, false);
   }
 
-  num_free_pages = config.total_kvcache_pages;
   for (size_t i = 0; i < config.layer_count; i++) {
     if (config.k_cache_on)
@@ -248,18 +247,19 @@ void GPUPageCache::append_col_to_request(std::vectorgpu_block_idx.value();
   for (size_t layer = 0; layer < config.layer_count; layer++) {
     for (size_t which_gpu = 0; which_gpu < config.gpu_devices_id.size(); which_gpu++) {
-
       if (config.k_cache_on) {
         assert(k_handles[layer][at]->data != nullptr);
         reqs[which_gpu]->sizes.push_back(tp_size[which_gpu]);
-        reqs[which_gpu]->host_mem_addresses.push_back(offset_by_bytes(k_handles[layer][at]->data, tp_offset[which_gpu]));
+        reqs[which_gpu]->host_mem_addresses.push_back(
+            offset_by_bytes(k_handles[layer][at]->data, tp_offset[which_gpu]));
         reqs[which_gpu]->device_mem_addresses.push_back(k_cache[which_gpu][layer][gpu_block_idx].data_ptr());
       }
-
+
       if (config.v_cache_on) {
         assert(v_handles[layer][at]->data != nullptr);
         reqs[which_gpu]->sizes.push_back(tp_size[which_gpu]);
-        reqs[which_gpu]->host_mem_addresses.push_back(offset_by_bytes(v_handles[layer][at]->data, tp_offset[which_gpu]));
+        reqs[which_gpu]->host_mem_addresses.push_back(
+            offset_by_bytes(v_handles[layer][at]->data, tp_offset[which_gpu]));
         reqs[which_gpu]->device_mem_addresses.push_back(v_cache[which_gpu][layer][gpu_block_idx].data_ptr());
       }
     }
diff --git a/csrc/balance_serve/kvc2/src/metrics.h b/csrc/balance_serve/kvc2/src/metrics.h
index fa88785..5910d8f 100644
--- a/csrc/balance_serve/kvc2/src/metrics.h
+++ b/csrc/balance_serve/kvc2/src/metrics.h
@@ -1,16 +1,16 @@
 #pragma once
 
-#include "prometheus/counter.h"
-#include "prometheus/exposer.h"
-#include "prometheus/gauge.h"
-#include "prometheus/histogram.h"
-#include "prometheus/registry.h"
 #include
 #include
 #include
 #include
 #include
 #include
+#include "prometheus/counter.h"
+#include "prometheus/exposer.h"
+#include "prometheus/gauge.h"
+#include "prometheus/histogram.h"
+#include "prometheus/registry.h"
 
 #include "utils/timer.hpp"
diff --git a/csrc/balance_serve/kvc2/src/model_config.h b/csrc/balance_serve/kvc2/src/model_config.h
index 7ad1d90..e7512c4 100644
--- a/csrc/balance_serve/kvc2/src/model_config.h
+++ b/csrc/balance_serve/kvc2/src/model_config.h
@@ -1,8 +1,8 @@
 #ifndef __MODEL_CONFIG_HPP_
 #define __MODEL_CONFIG_HPP_
 
-#include
 #include "nlohmann/json.hpp"
+#include
 
 #include
 #include
@@ -13,7 +13,7 @@ using ModelName = std::string;
 
 // We must assure this can be load by config.json
 class ModelConfig {
- public:
+public:
   DimSize hidden_size;
   DimSize intermediate_size;
   size_t max_position_embeddings;
@@ -23,10 +23,13 @@ class ModelConfig {
   size_t num_key_value_heads;
   size_t vocab_size;
 
-  NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size, max_position_embeddings, model_type,
-                                 num_attention_heads, num_hidden_layers, num_key_value_heads, vocab_size);
+  NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size,
+                                 max_position_embeddings, model_type,
+                                 num_attention_heads, num_hidden_layers,
+                                 num_key_value_heads, vocab_size);
 
   void load_from(std::filesystem::path path) {
+    std::cout << "Load from " << path << std::endl;
     std::ifstream i(path);
     nlohmann::json j;
     i >> j;
@@ -38,12 +41,14 @@ using QuantType = std::string;
 static const QuantType NoQuantType = "";
 
 class QuantConfig {
- public:
+public:
   QuantType name;
 
   // For GEMV
   QuantType type_of_dot_vector = NoQuantType;
-  inline bool can_be_used_as_matrix() { return type_of_dot_vector != NoQuantType; }
+  inline bool can_be_used_as_matrix() {
+    return type_of_dot_vector != NoQuantType;
+  }
 
   bool can_be_used_as_vector;
@@ -56,8 +61,11 @@ class QuantConfig {
 
   URL reference = "";
 
-  NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(QuantConfig, name, type_of_dot_vector, can_be_used_as_vector,
-                                              bytes_per_element, has_scale, has_min, block_element_count,
+  NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(QuantConfig, name,
+                                              type_of_dot_vector,
+                                              can_be_used_as_vector,
+                                              bytes_per_element, has_scale,
+                                              has_min, block_element_count,
                                               block_element_size, reference);
 };
 
@@ -65,14 +73,18 @@ inline std::map<QuantType, QuantConfig> quant_configs;
 inline std::map<ModelName, ModelConfig> model_configs;
 
 inline void load_quant_configs(std::filesystem::path path) {
-  std::cout << __FUNCTION__ << " from " << path << std::endl;
-  std::ifstream i(path);
   nlohmann::json j;
-  i >> j;
-  quant_configs = j.get<std::map<QuantType, QuantConfig>>();
-  std::cout << "Loaded Quant Configs" << std::endl;
-  for (auto& [k, v] : quant_configs) {
-    std::cout << " - " << k << std::endl;
+  if (std::filesystem::exists(path)) {
+    std::cout << __FUNCTION__ << " from " << path << std::endl;
+    std::ifstream i(path);
+    i >> j;
+    quant_configs = j.get<std::map<QuantType, QuantConfig>>();
+    std::cout << "Loaded Quant Configs" << std::endl;
+    for (auto &[k, v] : quant_configs) {
+      std::cout << " - " << k << std::endl;
+    }
+  } else {
+    std::cout << __FUNCTION__ << " no file at " << path << std::endl;
   }
 }
 
@@ -83,14 +95,18 @@ inline void dump_quant_configs(std::filesystem::path path) {
 }
 
 inline void load_model_configs(std::filesystem::path path) {
-  std::cout << __FUNCTION__ << " from " << path << std::endl;
-  std::ifstream i(path);
   nlohmann::json j;
-  i >> j;
-  model_configs = j.get<std::map<ModelName, ModelConfig>>();
-  std::cout << "Loaded Model Configs" << std::endl;
-  for (auto& [k, v] : model_configs) {
-    std::cout << " - " << k << std::endl;
+  if (std::filesystem::exists(path)) {
+    std::cout << __FUNCTION__ << " from " << path << std::endl;
+    std::ifstream i(path);
+    i >> j;
+    model_configs = j.get<std::map<ModelName, ModelConfig>>();
+    std::cout << "Loaded Model Configs" << std::endl;
+    for (auto &[k, v] : model_configs) {
+      std::cout << " - " << k << std::endl;
+    }
+  } else {
+    std::cout << __FUNCTION__ << " no file at " << path << std::endl;
   }
 }
diff --git a/csrc/balance_serve/kvc2/src/page_aligned_memory_pool.cpp b/csrc/balance_serve/kvc2/src/page_aligned_memory_pool.cpp
index d70ed2e..bbb44db 100644
--- a/csrc/balance_serve/kvc2/src/page_aligned_memory_pool.cpp
+++ b/csrc/balance_serve/kvc2/src/page_aligned_memory_pool.cpp
@@ -17,13 +17,14 @@ PageAlignedMemoryPool::PageAlignedMemoryPool(size_t size_in_bytes) {
   assert(total_pages >= Blocks);
   page_per_block = total_pages / Blocks;
-  for (size_t block_index = 0; block_index < Blocks; block_index ++) {
-    first_page[block_index] = reinterpret_cast(reinterpret_cast(data) + static_cast(block_index) * page_per_block * PageSize);
+  for (size_t block_index = 0; block_index < Blocks; block_index++) {
+    first_page[block_index] = reinterpret_cast(reinterpret_cast(data) +
+                                               static_cast(block_index) * page_per_block * PageSize);
     count_page[block_index] =
-       block_index == Blocks - 1 ? (total_pages - page_per_block * (Blocks - 1)) : page_per_block;
-    SPDLOG_DEBUG("first_page[{}] = {}, count_page[{}] = {}",
-                 block_index, reinterpret_cast(first_page[block_index]) - reinterpret_cast(data),
-                 block_index, count_page[block_index]);
+        block_index == Blocks - 1 ? (total_pages - page_per_block * (Blocks - 1)) : page_per_block;
+    SPDLOG_DEBUG("first_page[{}] = {}, count_page[{}] = {}", block_index,
+                 reinterpret_cast(first_page[block_index]) - reinterpret_cast(data), block_index,
+                 count_page[block_index]);
     bitmap[block_index].resize(count_page[block_index], 0);
   }
   SPDLOG_INFO("PageAlignedMemoryPool with size {} Mbytes, {} pages", total_size / (1 << 20), page_count());
@@ -53,7 +54,7 @@ void* PageAlignedMemoryPool::alloc_in_block(size_t block_index, size_t alloc_siz
   size_t free_pages = 0;
   for (size_t i = 0; i < count_page[block_index]; i++) {
     if (bitmap[block_index][i] == 0) {
-      free_pages ++;
+      free_pages++;
       if (free_pages == alloc_size) {
         size_t page_index = i + 1 - free_pages;
         for (size_t page = page_index; page < page_index + alloc_size; page++) {
@@ -73,7 +74,7 @@ void* PageAlignedMemoryPool::alloc_in_block(size_t block_index, size_t alloc_siz
 void* PageAlignedMemoryPool::alloc(size_t size) {
   size_t alloc_size = div_up(size, PageSize);
   auto cnt = now_block.fetch_add(1, std::memory_order_relaxed);
-  for (size_t i = 0; i < Blocks; i ++) {
+  for (size_t i = 0; i < Blocks; i++) {
     auto result = alloc_in_block((i + cnt) % Blocks, alloc_size);
     if (result != nullptr) {
       allocated.fetch_add(alloc_size * PageSize, std::memory_order_relaxed);
@@ -119,5 +120,6 @@ void PageAlignedMemoryPool::defragment() {}
 /// 调试打印
 std::string PageAlignedMemoryPool::debug() {
   return fmt::format("PageAlignedMemoryPool: total_size: {}MB, allocated: {}, alloc/free count: {}/{}\n",
-                     readable_number(total_size), readable_number(size_t(allocated)), size_t(alloc_count), size_t(free_count));
+                     readable_number(total_size), readable_number(size_t(allocated)), size_t(alloc_count),
+                     size_t(free_count));
 }
diff --git a/csrc/balance_serve/kvc2/src/page_aligned_memory_pool.h b/csrc/balance_serve/kvc2/src/page_aligned_memory_pool.h
index c65a740..4daaac1 100644
--- a/csrc/balance_serve/kvc2/src/page_aligned_memory_pool.h
+++ b/csrc/balance_serve/kvc2/src/page_aligned_memory_pool.h
@@ -1,12 +1,12 @@
 #pragma once
 
-#include <algorithm>  // std::sort
-#include <cstddef>    // size_t
-#include <mutex>      // std::mutex
-#include
 #include
-#include
+#include <algorithm> // std::sort
 #include
+#include
+#include <cstddef> // size_t
+#include <mutex>   // std::mutex
+#include
 
 constexpr size_t PageSize = 4096;
@@ -18,7 +18,7 @@ struct PageAlignedMemoryPool {
   void* data = nullptr;
   size_t total_size = 0, total_pages = 0;
-
+
   std::atomic_size_t now_block = 0;
   std::atomic_size_t allocated = 0;  // allocated_size
   std::atomic_size_t alloc_count = 0;
@@ -26,10 +26,11 @@ struct PageAlignedMemoryPool {
   std::mutex lock[Blocks];
   size_t page_per_block = 0;
-  void *first_page[Blocks];
+  void* first_page[Blocks];
   size_t count_page[Blocks];
   std::vector bitmap[Blocks];
 
   void* alloc_in_block(size_t block_index, size_t alloc_size);
+
  public:
   /// 构造函数和析构函数
   explicit PageAlignedMemoryPool(size_t size_in_bytes);
diff --git a/csrc/balance_serve/kvc2/src/prefix.cpp b/csrc/balance_serve/kvc2/src/prefix.cpp
index add1cd4..9518866 100644
--- a/csrc/balance_serve/kvc2/src/prefix.cpp
+++ b/csrc/balance_serve/kvc2/src/prefix.cpp
@@ -339,7 +339,7 @@ struct Prefix {
   void update_location(CacheInfo info, Location location) { locations.location_map[info] = location; }
 
-  Prefix* to_first_prefix_without_disk_locations(CacheInfo k_info/*, CacheInfo v_info*/) {  // just k_info
+  Prefix* to_first_prefix_without_disk_locations(CacheInfo k_info /*, CacheInfo v_info*/) {  // just k_info
     auto now_prefix = this;
     while (now_prefix->prev != nullptr) {
       auto& prev = now_prefix->prev;
@@ -561,7 +561,7 @@ struct PrefixTree {
     if (need_lock) {
       sl = std::shared_lock(rw_lock);
     }
-    //TODO: prefix cache
+    // TODO: prefix cache
   }
 
   PrefixMatch look_up_or_insert(Token* data, TokenLength length) {
@@ -579,7 +579,6 @@ struct PrefixTree {
     return re;
   }
 
-
   std::shared_ptr new_prefix_node(Prefix* prev, TokenLength prev_match_length, Token* data, TokenLength length,
                                   bool need_lock = true) {
     std::unique_lock ul;
@@ -700,9 +699,7 @@ struct DoubleCacheHandle : public DoubleCacheHandleInterface {
       }
     }
   }
-  std::vector matched_status() override {
-    assert(false);
-  }
+  std::vector matched_status() override { assert(false); }
 
   bool any_match() {
     if (enable_alt) {
@@ -1066,7 +1063,6 @@ struct DoubleCacheHandle : public DoubleCacheHandleInterface {
 };
 
 struct KVC2 : KVC2Interface {
-
   KVC2Config config;
   std::shared_ptr met;
 
@@ -1194,7 +1190,7 @@ struct KVC2 : KVC2Interface {
       auto v_loc = disk_cache->allocate(h->v_info(), div_up(new_length, NumTokenPerBlock));
       h->k_seg_locs.add_location(now_prefix->start_length / NumTokenPerBlock, k_loc);
       h->v_seg_locs.add_location(now_prefix->start_length / NumTokenPerBlock, v_loc);
-
+
       // split it to prefix trees
       for (auto tail = h->match.prefix; tail != now_prefix->prev; tail = tail->prev) {
         TokenLength local_ids_length = tail->local_length();
@@ -1207,7 +1203,7 @@ struct KVC2 : KVC2Interface {
       // allocate a big space on disk
       auto k_loc = disk_cache->allocate(h->k_info(), div_up(new_length, NumTokenPerBlock));
       h->k_seg_locs.add_location(now_prefix->start_length / NumTokenPerBlock, k_loc);
-
+
       // split it to prefix trees
       for (auto tail = h->match.prefix; tail != now_prefix->prev; tail = tail->prev) {
         TokenLength local_ids_length = tail->local_length();
@@ -1231,7 +1227,7 @@ struct KVC2 : KVC2Interface {
     h->kvc2_top = this;
     h->set_cache_info(model_name, quant_type, config.k_cache_on, config.v_cache_on);
     h->ids = Tokens(id, id + length);
-
+
     if (config.k_cache_on)
       h->set_raw_handles(true, k_cache);
     if (config.v_cache_on)
@@ -1261,7 +1257,7 @@ struct KVC2 : KVC2Interface {
     re->kvc2_top = this;
     SPDLOG_DEBUG("Lookup TokenLength {}", length);
     if (config.gpu_only == false) {
-      //TODO:
+      // TODO:
     }
     return re;
   };
@@
-1694,9 +1690,11 @@ void GPUPageCache::gpu_background_flush() { if (col_uls.empty()) continue; for (size_t l = 0; l < config.layer_count; l++) { - if (config.k_cache_on && (occupations[l][i]->gpu_cc.dirty.load() == false || occupations[l][i]->cpu_cc.dirty.load())) + if (config.k_cache_on && + (occupations[l][i]->gpu_cc.dirty.load() == false || occupations[l][i]->cpu_cc.dirty.load())) goto next_gpu_page; - if (config.v_cache_on && (v_occupations[l][i]->gpu_cc.dirty.load() == false || v_occupations[l][i]->cpu_cc.dirty.load())) + if (config.v_cache_on && + (v_occupations[l][i]->gpu_cc.dirty.load() == false || v_occupations[l][i]->cpu_cc.dirty.load())) goto next_gpu_page; } diff --git a/csrc/balance_serve/kvc2/test/kvc2test/common.hpp b/csrc/balance_serve/kvc2/test/kvc2test/common.hpp index 29e37dd..5d01250 100644 --- a/csrc/balance_serve/kvc2/test/kvc2test/common.hpp +++ b/csrc/balance_serve/kvc2/test/kvc2test/common.hpp @@ -139,18 +139,18 @@ std::vector random_ids(size_t length, std::mt19937& gen) { return re; } -std::vector slice(std::vector& h1,size_t start,size_t end){ +std::vector slice(std::vector& h1, size_t start, size_t end) { std::vector re; - for(auto&l:h1){ + for (auto& l : h1) { layer_data new_layer; - new_layer.insert(new_layer.end(),l.begin()+start,l.begin()+end); + new_layer.insert(new_layer.end(), l.begin() + start, l.begin() + end); re.push_back(new_layer); } return re; } void cmp_handle_data(std::vector h1, std::vector h2, - std::optional blocks = std::nullopt) { + std::optional blocks = std::nullopt) { assert(h1.size() == h2.size()); for (size_t i = 0; i < h1.size(); i++) { diff --git a/csrc/balance_serve/kvc2/test/kvc2test/flush-back.cpp b/csrc/balance_serve/kvc2/test/kvc2test/flush-back.cpp index cd94b11..c601d7c 100644 --- a/csrc/balance_serve/kvc2/test/kvc2test/flush-back.cpp +++ b/csrc/balance_serve/kvc2/test/kvc2test/flush-back.cpp @@ -7,9 +7,9 @@ int main(int argc, char* argv[]) { config.gpu_cache_config->total_kvcache_pages = 12; auto kvc2 = kvc2::create_kvc2(config); -// #pragma omp parallel for + // #pragma omp parallel for for (size_t ti = 0; ti < 2; ti++) { - SPDLOG_WARN("Test {}",ti); + SPDLOG_WARN("Test {}", ti); auto [kcache, vcache] = kvc2->get_kvcache(); std::mt19937 gen(ti + 123); size_t total_page = 10; diff --git a/csrc/balance_serve/kvc2/test/kvc2test/lookup-gpu-mt-without-vcache.cpp b/csrc/balance_serve/kvc2/test/kvc2test/lookup-gpu-mt-without-vcache.cpp index a60de53..f8ea430 100644 --- a/csrc/balance_serve/kvc2/test/kvc2test/lookup-gpu-mt-without-vcache.cpp +++ b/csrc/balance_serve/kvc2/test/kvc2test/lookup-gpu-mt-without-vcache.cpp @@ -14,7 +14,7 @@ int main(int argc, char* argv[]) { qw25_7B_gpu_config.v_cache_on = false; config.gpu_cache_config = qw25_7B_gpu_config; config.v_cache_on = false; - + init(argc, argv); spdlog::set_level(spdlog::level::debug); auto kvc2 = kvc2::create_kvc2(config); diff --git a/csrc/balance_serve/kvc2/test/kvc2test/lookup-without-vcache.cpp b/csrc/balance_serve/kvc2/test/kvc2test/lookup-without-vcache.cpp index febf5a8..285580a 100644 --- a/csrc/balance_serve/kvc2/test/kvc2test/lookup-without-vcache.cpp +++ b/csrc/balance_serve/kvc2/test/kvc2test/lookup-without-vcache.cpp @@ -11,11 +11,10 @@ #include "common.hpp" int main(int argc, char* argv[]) { - qw25_7B_gpu_config.v_cache_on = false; config.gpu_cache_config = qw25_7B_gpu_config; config.v_cache_on = false; - + init(argc, argv); spdlog::set_level(spdlog::level::debug); auto kvc2 = kvc2::create_kvc2(config); diff --git a/csrc/balance_serve/kvc2/test/page_pool_test.cpp 
b/csrc/balance_serve/kvc2/test/page_pool_test.cpp index 5bccbca..bc337f4 100644 --- a/csrc/balance_serve/kvc2/test/page_pool_test.cpp +++ b/csrc/balance_serve/kvc2/test/page_pool_test.cpp @@ -1,16 +1,15 @@ +#include #include +#include #include #include -#include -#include #include "page_aligned_memory_pool.cpp" #define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG #define FMT_HEADER_ONLY #include "spdlog/spdlog.h" - // 每个线程执行的任务 void thread_task(PageAlignedMemoryPool& pool) { std::mt19937 gen(123); @@ -22,8 +21,8 @@ void thread_task(PageAlignedMemoryPool& pool) { void* ptr = pool.alloc(size); // SPDLOG_DEBUG(pool.debug()); if (ptr) { - pool.free(ptr, size); - // allocated.push_back({ptr, size}); + pool.free(ptr, size); + // allocated.push_back({ptr, size}); } // sleep((int)(gen() % 1000) / 1000.0); } @@ -35,21 +34,20 @@ void thread_task(PageAlignedMemoryPool& pool) { int main(int argc, char* argv[]) { spdlog::set_level(spdlog::level::debug); - // 创建一个内存池 - PageAlignedMemoryPool pool(40ll * 1024 * 1024 * 1024); // 40 G + PageAlignedMemoryPool pool(40ll * 1024 * 1024 * 1024); // 40 G // 创建线程 const int num_threads = 32; std::vector threads; for (int i = 0; i < num_threads; ++i) { - threads.emplace_back(thread_task, std::ref(pool)); + threads.emplace_back(thread_task, std::ref(pool)); } // 等待所有线程完成 for (auto& t : threads) { - t.join(); + t.join(); } // 输出调试信息 diff --git a/csrc/balance_serve/kvc2/test/test_periodic_task.cpp b/csrc/balance_serve/kvc2/test/test_periodic_task.cpp index a2f8f89..e11457a 100644 --- a/csrc/balance_serve/kvc2/test/test_periodic_task.cpp +++ b/csrc/balance_serve/kvc2/test/test_periodic_task.cpp @@ -1,171 +1,163 @@ -#include "utils/periodic_task.hpp" -#include -#include -#include -#include -#include #include #include +#include +#include +#include +#include +#include +#include "utils/periodic_task.hpp" // 1. 任务是否按预期执行 void testPeriodicTaskExecution() { - std::atomic execution_count{0}; - auto task = [&execution_count]() { - execution_count++; - }; + std::atomic execution_count{0}; + auto task = [&execution_count]() { execution_count++; }; - periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(50)); + periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(50)); - std::this_thread::sleep_for(std::chrono::seconds(2)); + std::this_thread::sleep_for(std::chrono::seconds(2)); - assert(execution_count >= 20); // 确保任务执行了至少 20 次 - std::cout << "Test 1 passed: Task executed periodically." << std::endl; - std::cout << "Task executed " << execution_count.load() << " times." << std::endl; + assert(execution_count >= 20); // 确保任务执行了至少 20 次 + std::cout << "Test 1 passed: Task executed periodically." << std::endl; + std::cout << "Task executed " << execution_count.load() << " times." << std::endl; } // 2. 
提前唤醒任务的功能 void testWakeUpImmediately() { - std::atomic execution_count{0}; - auto task = [&execution_count]() { - execution_count++; - }; + std::atomic execution_count{0}; + auto task = [&execution_count]() { execution_count++; }; - periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200)); + periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200)); - // 提前唤醒任务 - periodic_task.wakeUp(); - std::this_thread::sleep_for(std::chrono::milliseconds(50)); // 等待任务执行 + // 提前唤醒任务 + periodic_task.wakeUp(); + std::this_thread::sleep_for(std::chrono::milliseconds(50)); // 等待任务执行 - std::cout << "Execution count after wakeUp: " << execution_count.load() << std::endl; - assert(execution_count == 1); // 确保任务立即执行 - std::cout << "Test 2 passed: Task woke up immediately." << std::endl; + std::cout << "Execution count after wakeUp: " << execution_count.load() << std::endl; + assert(execution_count == 1); // 确保任务立即执行 + std::cout << "Test 2 passed: Task woke up immediately." << std::endl; } // 3. wakeUpWait() 的等待功能 void testWakeUpWait() { - std::promise promise; - std::future future = promise.get_future(); - auto task = [&promise]() { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // 模拟任务执行 - promise.set_value(); // 任务完成时设置 promise - }; + std::promise promise; + std::future future = promise.get_future(); + auto task = [&promise]() { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // 模拟任务执行 + promise.set_value(); // 任务完成时设置 promise + }; - periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200)); + periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200)); - // 调用 wakeUpWait 并等待任务完成 - std::future wakeup_future = periodic_task.wakeUpWait(); - wakeup_future.wait(); // 等待任务完成 + // 调用 wakeUpWait 并等待任务完成 + std::future wakeup_future = periodic_task.wakeUpWait(); + wakeup_future.wait(); // 等待任务完成 - assert(wakeup_future.valid()); // 确保 future 是有效的 - std::cout << "Test 3 passed: wakeUpWait() works correctly." << std::endl; - std::cout << "wakeUpWait() future is valid." << std::endl; + assert(wakeup_future.valid()); // 确保 future 是有效的 + std::cout << "Test 3 passed: wakeUpWait() works correctly." << std::endl; + std::cout << "wakeUpWait() future is valid." << std::endl; } // 4. 任务抛出异常的处理 void testTaskExceptionHandling() { - auto task = []() { - throw std::runtime_error("Test exception"); - }; + auto task = []() { throw std::runtime_error("Test exception"); }; - periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200)); + periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200)); - std::this_thread::sleep_for(std::chrono::milliseconds(300)); // 等待一段时间 + std::this_thread::sleep_for(std::chrono::milliseconds(300)); // 等待一段时间 - std::cout << "Test 4 passed: Task exception is handled correctly." << std::endl; - std::cout << "Exception handled and task did not crash." << std::endl; + std::cout << "Test 4 passed: Task exception is handled correctly." << std::endl; + std::cout << "Exception handled and task did not crash." << std::endl; } // 5. 
线程是否能正确停止 void testTaskStop() { - std::atomic stopped{false}; - auto task = [&stopped]() { - while (!stopped) { - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - } - }; + std::atomic stopped{false}; + auto task = [&stopped]() { + while (!stopped) { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + }; - periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(100)); + periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(100)); - std::this_thread::sleep_for(std::chrono::seconds(1)); // 运行一段时间 + std::this_thread::sleep_for(std::chrono::seconds(1)); // 运行一段时间 - stopped = true; // 请求停止 - std::this_thread::sleep_for(std::chrono::milliseconds(50)); // 等待线程停止 + stopped = true; // 请求停止 + std::this_thread::sleep_for(std::chrono::milliseconds(50)); // 等待线程停止 - std::cout << "Test 5 passed: Task thread stops correctly." << std::endl; - std::cout << "Task has been stopped successfully." << std::endl; + std::cout << "Test 5 passed: Task thread stops correctly." << std::endl; + std::cout << "Task has been stopped successfully." << std::endl; } // 6. 高频唤醒的情况下任务执行是否正常 void testHighFrequencyWakeUp() { - std::atomic execution_count{0}; - auto task = [&execution_count]() { - execution_count++; - }; + std::atomic execution_count{0}; + auto task = [&execution_count]() { execution_count++; }; - periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200)); + periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200)); - for (int i = 0; i < 100; ++i) { - periodic_task.wakeUp(); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); // 每 10 毫秒唤醒一次 - } + for (int i = 0; i < 100; ++i) { + periodic_task.wakeUp(); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); // 每 10 毫秒唤醒一次 + } - std::this_thread::sleep_for(std::chrono::seconds(1)); // 等待任务执行完成 + std::this_thread::sleep_for(std::chrono::seconds(1)); // 等待任务执行完成 - assert(execution_count > 50); // 确保任务至少执行了 50 次 - std::cout << "Test 6 passed: Task handles frequent wake ups correctly." << std::endl; - std::cout << "Task executed " << execution_count.load() << " times." << std::endl; + assert(execution_count > 50); // 确保任务至少执行了 50 次 + std::cout << "Test 6 passed: Task handles frequent wake ups correctly." << std::endl; + std::cout << "Task executed " << execution_count.load() << " times." << std::endl; } // 7. 多个 wakeUpWait() 调用的处理 void testMultipleWakeUpWait() { - std::atomic execution_count{0}; - auto task = [&execution_count]() { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // 模拟任务执行 - execution_count++; - }; + std::atomic execution_count{0}; + auto task = [&execution_count]() { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // 模拟任务执行 + execution_count++; + }; - periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200)); + periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200)); - // 同时调用两个 wakeUpWait - std::future future1 = periodic_task.wakeUpWait(); - std::future future2 = periodic_task.wakeUpWait(); + // 同时调用两个 wakeUpWait + std::future future1 = periodic_task.wakeUpWait(); + std::future future2 = periodic_task.wakeUpWait(); - future1.wait(); - future2.wait(); + future1.wait(); + future2.wait(); - assert(execution_count == 1); // 确保任务只执行了一次 - std::cout << "Test 7 passed: Multiple wakeUpWait() calls are handled correctly." << std::endl; - std::cout << "Task executed " << execution_count.load() << " times." 
<< std::endl; + assert(execution_count == 1); // 确保任务只执行了一次 + std::cout << "Test 7 passed: Multiple wakeUpWait() calls are handled correctly." << std::endl; + std::cout << "Task executed " << execution_count.load() << " times." << std::endl; } // 8. 任务函数为空的边界情况 void testEmptyTaskFunction() { - auto task = []() { - // 空任务函数 - }; + auto task = []() { + // 空任务函数 + }; - periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(100)); + periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(100)); - std::this_thread::sleep_for(std::chrono::seconds(1)); // 等待一段时间 + std::this_thread::sleep_for(std::chrono::seconds(1)); // 等待一段时间 - std::cout << "Test 8 passed: Empty task function works correctly." << std::endl; - std::cout << "Empty task function executed without issues." << std::endl; + std::cout << "Test 8 passed: Empty task function works correctly." << std::endl; + std::cout << "Empty task function executed without issues." << std::endl; } int main() { - std::cout << "Starting tests..." << std::endl; + std::cout << "Starting tests..." << std::endl; - // testWakeUpImmediately(); - testPeriodicTaskExecution(); - testWakeUpImmediately(); - testWakeUpWait(); - testTaskExceptionHandling(); - testTaskStop(); - testHighFrequencyWakeUp(); - testMultipleWakeUpWait(); - testEmptyTaskFunction(); + // testWakeUpImmediately(); + testPeriodicTaskExecution(); + testWakeUpImmediately(); + testWakeUpWait(); + testTaskExceptionHandling(); + testTaskStop(); + testHighFrequencyWakeUp(); + testMultipleWakeUpWait(); + testEmptyTaskFunction(); - std::cout << "All tests passed!" << std::endl; + std::cout << "All tests passed!" << std::endl; - return 0; + return 0; } diff --git a/csrc/balance_serve/sched/bind.cpp b/csrc/balance_serve/sched/bind.cpp index d280e44..13733ff 100644 --- a/csrc/balance_serve/sched/bind.cpp +++ b/csrc/balance_serve/sched/bind.cpp @@ -1,8 +1,8 @@ +#include "scheduler.h" +#include #include #include #include -#include -#include "scheduler.h" #include @@ -16,19 +16,25 @@ PYBIND11_MODULE(sched_ext, m) { .def_readwrite("layer_count", &scheduler::ModelSettings::layer_count) .def_readwrite("num_k_heads", &scheduler::ModelSettings::num_k_heads) .def_readwrite("k_head_dim", &scheduler::ModelSettings::k_head_dim) - .def_readwrite("bytes_per_params", &scheduler::ModelSettings::bytes_per_params) - .def_readwrite("bytes_per_kv_cache_element", &scheduler::ModelSettings::bytes_per_kv_cache_element) + .def_readwrite("bytes_per_params", + &scheduler::ModelSettings::bytes_per_params) + .def_readwrite("bytes_per_kv_cache_element", + &scheduler::ModelSettings::bytes_per_kv_cache_element) .def("params_size", &scheduler::ModelSettings::params_nbytes) - .def("bytes_per_token_kv_cache", &scheduler::ModelSettings::bytes_per_token_kv_cache) + .def("bytes_per_token_kv_cache", + &scheduler::ModelSettings::bytes_per_token_kv_cache) // 添加 pickle 支持 .def(py::pickle( - [](const scheduler::ModelSettings& self) { // __getstate__ - return py::make_tuple(self.params_count, self.layer_count, self.num_k_heads, self.k_head_dim, - self.bytes_per_params, self.bytes_per_kv_cache_element); + [](const scheduler::ModelSettings &self) { // __getstate__ + return py::make_tuple(self.params_count, self.layer_count, + self.num_k_heads, self.k_head_dim, + self.bytes_per_params, + self.bytes_per_kv_cache_element); }, - [](py::tuple t) { // __setstate__ + [](py::tuple t) { // __setstate__ if (t.size() != 6) - throw std::runtime_error("Invalid state! 
t.size() = " + std::to_string(t.size())); + throw std::runtime_error("Invalid state! t.size() = " + + std::to_string(t.size())); scheduler::ModelSettings ms; ms.params_count = t[0].cast(); ms.layer_count = t[1].cast(); @@ -40,22 +46,24 @@ PYBIND11_MODULE(sched_ext, m) { })); py::class_(m, "SampleOptions") - .def(py::init<>()) - .def_readwrite("temperature", &scheduler::SampleOptions::temperature) - .def_readwrite("top_p", &scheduler::SampleOptions::top_p) // 确保 top_p 也能被访问 - .def(py::pickle( - [](const scheduler::SampleOptions& self) { - return py::make_tuple(self.temperature, self.top_p); // 序列化 temperature 和 top_p - }, - [](py::tuple t) { - if (t.size() != 2) // 确保解包时参数数量匹配 - throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size())); + .def(py::init<>()) + .def_readwrite("temperature", &scheduler::SampleOptions::temperature) + .def_readwrite("top_p", + &scheduler::SampleOptions::top_p) // 确保 top_p 也能被访问 + .def(py::pickle( + [](const scheduler::SampleOptions &self) { + return py::make_tuple(self.temperature, + self.top_p); // 序列化 temperature 和 top_p + }, + [](py::tuple t) { + if (t.size() != 2) // 确保解包时参数数量匹配 + throw std::runtime_error("Invalid state! t.size() = " + + std::to_string(t.size())); scheduler::SampleOptions so; so.temperature = t[0].cast(); - so.top_p = t[1].cast(); // 反序列化 top_p + so.top_p = t[1].cast(); // 反序列化 top_p return so; - } - )); + })); py::class_(m, "Settings") .def(py::init<>()) @@ -65,33 +73,43 @@ PYBIND11_MODULE(sched_ext, m) { .def_readwrite("page_size", &scheduler::Settings::page_size) .def_readwrite("gpu_device_id", &scheduler::Settings::gpu_device_id) .def_readwrite("gpu_memory_size", &scheduler::Settings::gpu_memory_size) - .def_readwrite("memory_utilization_percentage", &scheduler::Settings::memory_utilization_percentage) + .def_readwrite("memory_utilization_percentage", + &scheduler::Settings::memory_utilization_percentage) .def_readwrite("max_batch_size", &scheduler::Settings::max_batch_size) - .def_readwrite("recommended_chunk_prefill_token_count", - &scheduler::Settings::recommended_chunk_prefill_token_count) + .def_readwrite( + "recommended_chunk_prefill_token_count", + &scheduler::Settings::recommended_chunk_prefill_token_count) .def_readwrite("sample_options", &scheduler::Settings::sample_options) - .def_readwrite("sched_metrics_port", &scheduler::Settings::sched_metrics_port) + .def_readwrite("sched_metrics_port", + &scheduler::Settings::sched_metrics_port) .def_readwrite("gpu_only", &scheduler::Settings::gpu_only) - .def_readwrite("use_self_defined_head_dim", &scheduler::Settings::use_self_defined_head_dim) - .def_readwrite("self_defined_head_dim", &scheduler::Settings::self_defined_head_dim) - .def_readwrite("full_kv_cache_on_each_gpu", &scheduler::Settings::full_kv_cache_on_each_gpu) + .def_readwrite("use_self_defined_head_dim", + &scheduler::Settings::use_self_defined_head_dim) + .def_readwrite("self_defined_head_dim", + &scheduler::Settings::self_defined_head_dim) + .def_readwrite("full_kv_cache_on_each_gpu", + &scheduler::Settings::full_kv_cache_on_each_gpu) .def_readwrite("k_cache_on", &scheduler::Settings::k_cache_on) .def_readwrite("v_cache_on", &scheduler::Settings::v_cache_on) .def_readwrite("kvc2_config_path", &scheduler::Settings::kvc2_config_path) .def_readwrite("kvc2_root_path", &scheduler::Settings::kvc2_root_path) - .def_readwrite("memory_pool_size_GB", &scheduler::Settings::memory_pool_size_GB) + .def_readwrite("memory_pool_size_GB", + &scheduler::Settings::memory_pool_size_GB) .def_readwrite("evict_count", 
&scheduler::Settings::evict_count) .def_readwrite("strategy_name", &scheduler::Settings::strategy_name) - .def_readwrite("kvc2_metrics_port", &scheduler::Settings::kvc2_metrics_port) + .def_readwrite("kvc2_metrics_port", + &scheduler::Settings::kvc2_metrics_port) .def_readwrite("load_from_disk", &scheduler::Settings::load_from_disk) .def_readwrite("save_to_disk", &scheduler::Settings::save_to_disk) // derived .def_readwrite("gpu_device_count", &scheduler::Settings::gpu_device_count) - .def_readwrite("total_kvcache_pages", &scheduler::Settings::total_kvcache_pages) + .def_readwrite("total_kvcache_pages", + &scheduler::Settings::total_kvcache_pages) .def_readwrite("devices", &scheduler::Settings::devices) .def("auto_derive", &scheduler::Settings::auto_derive); - py::class_>(m, "BatchQueryTodo") + py::class_>(m, "BatchQueryTodo") .def(py::init<>()) .def_readwrite("query_ids", &scheduler::BatchQueryTodo::query_ids) .def_readwrite("query_tokens", &scheduler::BatchQueryTodo::query_tokens) @@ -99,31 +117,42 @@ PYBIND11_MODULE(sched_ext, m) { .def_readwrite("block_indexes", &scheduler::BatchQueryTodo::block_indexes) .def_readwrite("attn_masks", &scheduler::BatchQueryTodo::attn_masks) .def_readwrite("rope_ranges", &scheduler::BatchQueryTodo::rope_ranges) - .def_readwrite("sample_options", &scheduler::BatchQueryTodo::sample_options) - .def_readwrite("prefill_mini_batches", &scheduler::BatchQueryTodo::prefill_mini_batches) - .def_readwrite("decode_mini_batches", &scheduler::BatchQueryTodo::decode_mini_batches) + .def_readwrite("sample_options", + &scheduler::BatchQueryTodo::sample_options) + .def_readwrite("prefill_mini_batches", + &scheduler::BatchQueryTodo::prefill_mini_batches) + .def_readwrite("decode_mini_batches", + &scheduler::BatchQueryTodo::decode_mini_batches) .def_readwrite("stop_criteria", &scheduler::BatchQueryTodo::stop_criteria) .def("debug", &scheduler::BatchQueryTodo::debug) .def(py::pickle( - [](const scheduler::BatchQueryTodo& self) { - return py::make_tuple(self.query_ids, self.query_tokens, self.query_lengths, self.block_indexes, - self.attn_masks, self.rope_ranges, self.sample_options, self.prefill_mini_batches, - self.decode_mini_batches, self.stop_criteria); + [](const scheduler::BatchQueryTodo &self) { + return py::make_tuple( + self.query_ids, self.query_tokens, self.query_lengths, + self.block_indexes, self.attn_masks, self.rope_ranges, + self.sample_options, self.prefill_mini_batches, + self.decode_mini_batches, self.stop_criteria); }, [](py::tuple t) { if (t.size() != 10) - throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size())); + throw std::runtime_error("Invalid state! 
t.size() = " + + std::to_string(t.size())); scheduler::BatchQueryTodo bqt; bqt.query_ids = t[0].cast>(); bqt.query_tokens = t[1].cast>(); - bqt.query_lengths = t[2].cast>(); + bqt.query_lengths = + t[2].cast>(); bqt.block_indexes = t[3].cast>(); bqt.attn_masks = t[4].cast>(); bqt.rope_ranges = t[5].cast>(); - bqt.sample_options = t[6].cast>(); - bqt.prefill_mini_batches = t[7].cast>(); - bqt.decode_mini_batches = t[8].cast>>(); - bqt.stop_criteria = t[9].cast>>>(); + bqt.sample_options = + t[6].cast>(); + bqt.prefill_mini_batches = + t[7].cast>(); + bqt.decode_mini_batches = + t[8].cast>>(); + bqt.stop_criteria = + t[9].cast>>>(); return bqt; })); @@ -133,16 +162,20 @@ PYBIND11_MODULE(sched_ext, m) { .def_readwrite("ok", &scheduler::QueryUpdate::ok) .def_readwrite("is_prefill", &scheduler::QueryUpdate::is_prefill) .def_readwrite("decode_done", &scheduler::QueryUpdate::decode_done) - .def_readwrite("active_position", &scheduler::QueryUpdate::active_position) - .def_readwrite("generated_token", &scheduler::QueryUpdate::generated_token) + .def_readwrite("active_position", + &scheduler::QueryUpdate::active_position) + .def_readwrite("generated_token", + &scheduler::QueryUpdate::generated_token) .def(py::pickle( - [](const scheduler::QueryUpdate& self) { - return py::make_tuple(self.id, self.ok, self.is_prefill, self.decode_done, self.active_position, + [](const scheduler::QueryUpdate &self) { + return py::make_tuple(self.id, self.ok, self.is_prefill, + self.decode_done, self.active_position, self.generated_token); }, [](py::tuple t) { if (t.size() != 6) - throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size())); + throw std::runtime_error("Invalid state! t.size() = " + + std::to_string(t.size())); scheduler::QueryUpdate qu; qu.id = t[0].cast(); qu.ok = t[1].cast(); @@ -156,8 +189,7 @@ PYBIND11_MODULE(sched_ext, m) { py::class_(m, "InferenceContext") .def(py::init<>()) .def_readwrite("k_cache", &scheduler::InferenceContext::k_cache) - .def_readwrite("v_cache", &scheduler::InferenceContext::v_cache) - ; + .def_readwrite("v_cache", &scheduler::InferenceContext::v_cache); py::class_(m, "QueryAdd") .def(py::init<>()) @@ -173,15 +205,18 @@ PYBIND11_MODULE(sched_ext, m) { .def("serialize", &scheduler::QueryAdd::serialize) .def_static("deserialize", &scheduler::QueryAdd::deserialize) .def(py::pickle( - [](const scheduler::QueryAdd& self) { + [](const scheduler::QueryAdd &self) { return py::make_tuple(self.query_token, // self.attn_mask, - self.query_length, self.estimated_length, self.sample_options, self.user_id, - self.SLO_TTFT_ms, self.SLO_TBT_ms, self.stop_criteria); + self.query_length, self.estimated_length, + self.sample_options, self.user_id, + self.SLO_TTFT_ms, self.SLO_TBT_ms, + self.stop_criteria); }, [](py::tuple t) { if (t.size() != 8) - throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size())); + throw std::runtime_error("Invalid state! 
t.size() = " + + std::to_string(t.size())); scheduler::QueryAdd qa; qa.query_token = t[0].cast>(); // qa.attn_mask = t[1].cast(); @@ -195,14 +230,20 @@ PYBIND11_MODULE(sched_ext, m) { return qa; })); - py::class_>(m, "Scheduler") + py::class_>( + m, "Scheduler") .def("init", &scheduler::Scheduler::init) .def("run", &scheduler::Scheduler::run) .def("stop", &scheduler::Scheduler::stop) - .def("add_query", &scheduler::Scheduler::add_query, py::call_guard()) - .def("cancel_query", &scheduler::Scheduler::cancel_query, py::call_guard()) - .def("update_last_batch", &scheduler::Scheduler::update_last_batch, py::call_guard()) - .def("get_inference_context", &scheduler::Scheduler::get_inference_context); + .def("add_query", &scheduler::Scheduler::add_query, + py::call_guard()) + .def("cancel_query", &scheduler::Scheduler::cancel_query, + py::call_guard()) + .def("update_last_batch", &scheduler::Scheduler::update_last_batch, + py::call_guard()) + .def("get_inference_context", + &scheduler::Scheduler::get_inference_context); - m.def("create_scheduler", &scheduler::create_scheduler, "Create a new Scheduler instance"); + m.def("create_scheduler", &scheduler::create_scheduler, + "Create a new Scheduler instance"); } diff --git a/csrc/balance_serve/sched/metrics.cpp b/csrc/balance_serve/sched/metrics.cpp index 8a6f259..eafbbde 100644 --- a/csrc/balance_serve/sched/metrics.cpp +++ b/csrc/balance_serve/sched/metrics.cpp @@ -2,89 +2,101 @@ #include // 构造函数 -Metrics::Metrics(const MetricsConfig& config) +Metrics::Metrics(const MetricsConfig &config) : registry_(std::make_shared()), - exposer_(config.endpoint), - stop_uptime_thread_(false), + exposer_(config.endpoint), stop_uptime_thread_(false), start_time_(std::chrono::steady_clock::now()) { // 定义统一的桶大小,最大为 10000 ms (10 s) - std::vector common_buckets = {0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0, 5.0, - 10.0, 50.0, 100.0, 500.0, 1000.0, 5000.0, 10000.0}; // 毫秒 + std::vector common_buckets = { + 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0, 5.0, + 10.0, 50.0, 100.0, 500.0, 1000.0, 5000.0, 10000.0}; // 毫秒 // 注册 TTFT_ms Histogram - auto& TTFT_family = prometheus::BuildHistogram() + auto &TTFT_family = prometheus::BuildHistogram() .Name(std::string(METRIC_PREFIX) + "_TTFT_ms") .Help("Time to first token in milliseconds") .Register(*registry_); TTFT_ms = &TTFT_family.Add({{"model", config.model_name}}, common_buckets); // 注册 TBT_ms Histogram - auto& TBT_family = prometheus::BuildHistogram() + auto &TBT_family = prometheus::BuildHistogram() .Name(std::string(METRIC_PREFIX) + "_TBT_ms") .Help("Time between tokens in milliseconds") .Register(*registry_); TBT_ms = &TBT_family.Add({{"model", config.model_name}}, common_buckets); // 注册 schedule_time Histogram - auto& schedule_time_family = prometheus::BuildHistogram() - .Name(std::string(METRIC_PREFIX) + "_schedule_time_ms") - .Help("Time to generate schedule in milliseconds") - .Register(*registry_); - schedule_time = &schedule_time_family.Add({{"model", config.model_name}}, common_buckets); + auto &schedule_time_family = + prometheus::BuildHistogram() + .Name(std::string(METRIC_PREFIX) + "_schedule_time_ms") + .Help("Time to generate schedule in milliseconds") + .Register(*registry_); + schedule_time = + &schedule_time_family.Add({{"model", config.model_name}}, common_buckets); // 注册 generated_tokens Counter - auto& generated_tokens_family = prometheus::BuildCounter() - .Name(std::string(METRIC_PREFIX) + "_generated_tokens_total") - .Help("Total generated tokens") - .Register(*registry_); - generated_tokens = 
&generated_tokens_family.Add({{"model", config.model_name}}); + auto &generated_tokens_family = + prometheus::BuildCounter() + .Name(std::string(METRIC_PREFIX) + "_generated_tokens_total") + .Help("Total generated tokens") + .Register(*registry_); + generated_tokens = + &generated_tokens_family.Add({{"model", config.model_name}}); // 注册 throughput_query Gauge - auto& throughput_query_family = prometheus::BuildGauge() - .Name(std::string(METRIC_PREFIX) + "_throughput_query") - .Help("Throughput per second based on queries") - .Register(*registry_); - throughput_query = &throughput_query_family.Add({{"model", config.model_name}}); + auto &throughput_query_family = + prometheus::BuildGauge() + .Name(std::string(METRIC_PREFIX) + "_throughput_query") + .Help("Throughput per second based on queries") + .Register(*registry_); + throughput_query = + &throughput_query_family.Add({{"model", config.model_name}}); // 注册 throughput_generated_tokens Gauge - auto& throughput_generated_tokens_family = prometheus::BuildGauge() - .Name(std::string(METRIC_PREFIX) + "_throughput_generated_tokens") - .Help("Throughput per second based on generated tokens") - .Register(*registry_); - throughput_generated_tokens = &throughput_generated_tokens_family.Add({{"model", config.model_name}}); + auto &throughput_generated_tokens_family = + prometheus::BuildGauge() + .Name(std::string(METRIC_PREFIX) + "_throughput_generated_tokens") + .Help("Throughput per second based on generated tokens") + .Register(*registry_); + throughput_generated_tokens = + &throughput_generated_tokens_family.Add({{"model", config.model_name}}); // 注册 event_count Counter family - event_count_family_ = &prometheus::BuildCounter() - .Name(std::string(METRIC_PREFIX) + "_event_count_total") - .Help("Count of various events") - .Register(*registry_); + event_count_family_ = + &prometheus::BuildCounter() + .Name(std::string(METRIC_PREFIX) + "_event_count_total") + .Help("Count of various events") + .Register(*registry_); - batch_count_family_ = &prometheus::BuildCounter() - .Name(std::string(METRIC_PREFIX) + "_batch_count_total") - .Help("Count of various batch by status") - .Register(*registry_); + batch_count_family_ = + &prometheus::BuildCounter() + .Name(std::string(METRIC_PREFIX) + "_batch_count_total") + .Help("Count of various batch by status") + .Register(*registry_); // 注册 query_count Counter family - query_count_family_ = &prometheus::BuildCounter() - .Name(std::string(METRIC_PREFIX) + "_query_count_total") - .Help("Count of queries by status") - .Register(*registry_); + query_count_family_ = + &prometheus::BuildCounter() + .Name(std::string(METRIC_PREFIX) + "_query_count_total") + .Help("Count of queries by status") + .Register(*registry_); // 注册 uptime_ms Gauge - auto& uptime_family = prometheus::BuildGauge() + auto &uptime_family = prometheus::BuildGauge() .Name(std::string(METRIC_PREFIX) + "_uptime_ms") .Help("Uptime of the scheduler in milliseconds") .Register(*registry_); uptime_ms = &uptime_family.Add({{"model", config.model_name}}); // 注册 GPU 利用率 Gauges - auto& gpu_util_family = prometheus::BuildGauge() - .Name(std::string(METRIC_PREFIX) + "_gpu_utilization_ratio") - .Help("Current GPU utilization ratio (0 to 1)") - .Register(*registry_); + auto &gpu_util_family = + prometheus::BuildGauge() + .Name(std::string(METRIC_PREFIX) + "_gpu_utilization_ratio") + .Help("Current GPU utilization ratio (0 to 1)") + .Register(*registry_); for (size_t i = 0; i < config.gpu_count; ++i) { - gpu_utilization_gauges.push_back( - 
&gpu_util_family.Add({{"gpu_id", std::to_string(i)}, {"model", config.model_name}})); + gpu_utilization_gauges.push_back(&gpu_util_family.Add( + {{"gpu_id", std::to_string(i)}, {"model", config.model_name}})); } // 将 Registry 注册到 Exposer 中 @@ -95,16 +107,15 @@ Metrics::Metrics(const MetricsConfig& config) } // 析构函数 -Metrics::~Metrics() { - StopUptimeUpdater(); -} +Metrics::~Metrics() { StopUptimeUpdater(); } // 启动 uptime 更新线程 void Metrics::StartUptimeUpdater() { uptime_thread_ = std::thread([this]() { while (!stop_uptime_thread_) { auto now = std::chrono::steady_clock::now(); - std::chrono::duration uptime_duration = now - start_time_; + std::chrono::duration uptime_duration = + now - start_time_; uptime_ms->Set(uptime_duration.count()); // fn_every_sec(this); std::this_thread::sleep_for(std::chrono::seconds(1)); @@ -121,15 +132,16 @@ void Metrics::StopUptimeUpdater() { } // 获取 event_count 指标 -prometheus::Counter* Metrics::event_count(const std::string& type) { - return &event_count_family_->Add({{"type", type}}); // 可根据需要添加更多标签 +prometheus::Counter *Metrics::event_count(const std::string &type) { + return &event_count_family_->Add({{"type", type}}); // 可根据需要添加更多标签 } // 获取 query_count 指标 -prometheus::Counter* Metrics::query_count(const std::string& status) { - return &query_count_family_->Add({{"status", status}}); // 可根据需要添加更多标签 +prometheus::Counter *Metrics::query_count(const std::string &status) { + return &query_count_family_->Add( + {{"status", status}}); // 可根据需要添加更多标签 } -prometheus::Counter* Metrics::batch_count(const std::string& type) { +prometheus::Counter *Metrics::batch_count(const std::string &type) { return &batch_count_family_->Add({{"type", type}}); } diff --git a/csrc/balance_serve/sched/metrics.h b/csrc/balance_serve/sched/metrics.h index 226684c..bfd18a3 100644 --- a/csrc/balance_serve/sched/metrics.h +++ b/csrc/balance_serve/sched/metrics.h @@ -1,14 +1,14 @@ #ifndef Metrics_H #define Metrics_H +#include +#include +#include #include #include #include #include #include -#include -#include -#include #include #include #include @@ -21,46 +21,46 @@ class Metrics; // 配置结构体 struct MetricsConfig { std::string endpoint; - std::string model_name; // 模型名称,如 "gpt-4" - size_t gpu_count; // GPU数量 + std::string model_name; // 模型名称,如 "gpt-4" + size_t gpu_count; // GPU数量 }; // Metrics 类,根据配置初始化 Prometheus 指标 class Metrics { - public: +public: // 构造函数传入 MetricsConfig - Metrics(const MetricsConfig& config); + Metrics(const MetricsConfig &config); ~Metrics(); // 禁止拷贝和赋值 - Metrics(const Metrics&) = delete; - Metrics& operator=(const Metrics&) = delete; + Metrics(const Metrics &) = delete; + Metrics &operator=(const Metrics &) = delete; - std::function fn_every_sec; + std::function fn_every_sec; // 指标指针 - prometheus::Gauge* uptime_ms; - prometheus::Histogram* TTFT_ms; - prometheus::Histogram* TBT_ms; - prometheus::Histogram* schedule_time; - prometheus::Gauge* throughput_query; - prometheus::Gauge* throughput_generated_tokens; - prometheus::Counter* generated_tokens; - std::vector gpu_utilization_gauges; + prometheus::Gauge *uptime_ms; + prometheus::Histogram *TTFT_ms; + prometheus::Histogram *TBT_ms; + prometheus::Histogram *schedule_time; + prometheus::Gauge *throughput_query; + prometheus::Gauge *throughput_generated_tokens; + prometheus::Counter *generated_tokens; + std::vector gpu_utilization_gauges; // 计数器家族 - prometheus::Counter* event_count(const std::string& type); - prometheus::Counter* query_count(const std::string& status); - prometheus::Counter* batch_count(const std::string& 
type); + prometheus::Counter *event_count(const std::string &type); + prometheus::Counter *query_count(const std::string &status); + prometheus::Counter *batch_count(const std::string &type); - private: +private: std::shared_ptr registry_; prometheus::Exposer exposer_; // 计数器家族 - prometheus::Family* event_count_family_; - prometheus::Family* batch_count_family_; - prometheus::Family* query_count_family_; + prometheus::Family *event_count_family_; + prometheus::Family *batch_count_family_; + prometheus::Family *query_count_family_; // 线程和控制变量用于更新 uptime_ms std::thread uptime_thread_; @@ -76,10 +76,13 @@ class Metrics { }; struct HistogramTimerWrapper { - prometheus::Histogram* histogram; + prometheus::Histogram *histogram; Timer timer; - inline HistogramTimerWrapper(prometheus::Histogram* histogram) : histogram(histogram), timer() { timer.start(); } + inline HistogramTimerWrapper(prometheus::Histogram *histogram) + : histogram(histogram), timer() { + timer.start(); + } inline ~HistogramTimerWrapper() { histogram->Observe(timer.elapsedMs()); } }; -#endif // Metrics_H +#endif // Metrics_H diff --git a/csrc/balance_serve/sched/model_config.h b/csrc/balance_serve/sched/model_config.h index ff6e915..e7512c4 100644 --- a/csrc/balance_serve/sched/model_config.h +++ b/csrc/balance_serve/sched/model_config.h @@ -1,8 +1,8 @@ #ifndef __MODEL_CONFIG_HPP_ #define __MODEL_CONFIG_HPP_ -#include #include "nlohmann/json.hpp" +#include #include #include @@ -13,7 +13,7 @@ using ModelName = std::string; // We must assure this can be load by config.json class ModelConfig { - public: +public: DimSize hidden_size; DimSize intermediate_size; size_t max_position_embeddings; @@ -23,10 +23,13 @@ class ModelConfig { size_t num_key_value_heads; size_t vocab_size; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size, max_position_embeddings, model_type, - num_attention_heads, num_hidden_layers, num_key_value_heads, vocab_size); + NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size, + max_position_embeddings, model_type, + num_attention_heads, num_hidden_layers, + num_key_value_heads, vocab_size); void load_from(std::filesystem::path path) { + std::cout << "Load from " << path << std::endl; std::ifstream i(path); nlohmann::json j; i >> j; @@ -38,12 +41,14 @@ using QuantType = std::string; static const QuantType NoQuantType = ""; class QuantConfig { - public: +public: QuantType name; // For GEMV QuantType type_of_dot_vector = NoQuantType; - inline bool can_be_used_as_matrix() { return type_of_dot_vector != NoQuantType; } + inline bool can_be_used_as_matrix() { + return type_of_dot_vector != NoQuantType; + } bool can_be_used_as_vector; @@ -56,8 +61,11 @@ class QuantConfig { URL reference = ""; - NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(QuantConfig, name, type_of_dot_vector, can_be_used_as_vector, - bytes_per_element, has_scale, has_min, block_element_count, + NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(QuantConfig, name, + type_of_dot_vector, + can_be_used_as_vector, + bytes_per_element, has_scale, + has_min, block_element_count, block_element_size, reference); }; @@ -70,14 +78,13 @@ inline void load_quant_configs(std::filesystem::path path) { std::cout << __FUNCTION__ << " from " << path << std::endl; std::ifstream i(path); i >> j; + quant_configs = j.get>(); + std::cout << "Loaded Quant Configs" << std::endl; + for (auto &[k, v] : quant_configs) { + std::cout << " - " << k << std::endl; + } } else { - std::cout << __FUNCTION__ << " create new at " << path << std::endl; - 
} - - quant_configs = j.get>(); - std::cout << "Loaded Quant Configs" << std::endl; - for (auto& [k, v] : quant_configs) { - std::cout << " - " << k << std::endl; + std::cout << __FUNCTION__ << " no file at " << path << std::endl; } } @@ -93,14 +100,13 @@ inline void load_model_configs(std::filesystem::path path) { std::cout << __FUNCTION__ << " from " << path << std::endl; std::ifstream i(path); i >> j; + model_configs = j.get>(); + std::cout << "Loaded Model Configs" << std::endl; + for (auto &[k, v] : model_configs) { + std::cout << " - " << k << std::endl; + } } else { - std::cout << __FUNCTION__ << " create new at " << path << std::endl; - } - - model_configs = j.get>(); - std::cout << "Loaded Model Configs" << std::endl; - for (auto& [k, v] : model_configs) { - std::cout << " - " << k << std::endl; + std::cout << __FUNCTION__ << " no file at " << path << std::endl; } } diff --git a/csrc/balance_serve/sched/scheduler.cpp b/csrc/balance_serve/sched/scheduler.cpp index 266157e..1410c14 100644 --- a/csrc/balance_serve/sched/scheduler.cpp +++ b/csrc/balance_serve/sched/scheduler.cpp @@ -3,20 +3,20 @@ #include "nlohmann/json.hpp" #include "spdlog/spdlog.h" -#include #include "scheduler.h" +#include -#include -#include -#include -#include -#include #include "arithmetic.hpp" #include "atomic_ptr_with_flags.hpp" #include "easy_format.hpp" #include "metrics.h" #include "mpsc.hpp" #include "timer.hpp" +#include +#include +#include +#include +#include #include "kvc2.h" @@ -28,7 +28,8 @@ void Settings::auto_derive() { gpu_device_count = gpu_device_id.size(); if (torch::cuda::is_available()) { size_t gpu_count = torch::cuda::device_count(); - SPDLOG_INFO("Number of available GPUs: {}, want {}", gpu_count, gpu_device_count); + SPDLOG_INFO("Number of available GPUs: {}, want {}", gpu_count, + gpu_device_count); if (gpu_count < gpu_device_count) { SPDLOG_ERROR("Not enough GPUs available."); exit(0); @@ -42,37 +43,49 @@ void Settings::auto_derive() { } if (model_settings.num_k_heads % gpu_device_count != 0) { - SPDLOG_ERROR("num_k_heads {} is not divisible by gpu_device_count {}", model_settings.num_k_heads, - gpu_device_count); + SPDLOG_ERROR("num_k_heads {} is not divisible by gpu_device_count {}", + model_settings.num_k_heads, gpu_device_count); assert(false); } size_t gpu_memory_available = gpu_memory_size * memory_utilization_percentage; - if (gpu_memory_available * gpu_device_count < model_settings.params_nbytes()) { - SPDLOG_ERROR("GPU memory size {}G is smaller than {}G", gpu_memory_available * gpu_device_count / 1e9, + if (gpu_memory_available * gpu_device_count < + model_settings.params_nbytes()) { + SPDLOG_ERROR("GPU memory size {}G is smaller than {}G", + gpu_memory_available * gpu_device_count / 1e9, model_settings.params_nbytes() / 1e9); assert(false); } assert(model_settings.k_head_dim % model_settings.num_k_heads == 0); size_t head_per_gpu = model_settings.num_k_heads / gpu_device_count; - size_t gpu_memory_for_kv_cache = gpu_memory_available /*- model_settings.params_nbytes() / gpu_device_count*/; - SPDLOG_INFO("Each GPU Total: {}MiB, Model Params: {}MiB, KVCache: {}MiB, Left: {}MiB", gpu_memory_size / (1 << 20), - model_settings.params_nbytes() / gpu_device_count / (1 << 20), gpu_memory_for_kv_cache / (1 << 20), - (gpu_memory_size - gpu_memory_available) / (1 << 20)); + size_t gpu_memory_for_kv_cache = + gpu_memory_available /*- model_settings.params_nbytes() / + gpu_device_count*/ + ; + SPDLOG_INFO( + "Each GPU Total: {}MiB, Model Params: {}MiB, KVCache: {}MiB, Left: {}MiB", + 
gpu_memory_size / (1 << 20), + model_settings.params_nbytes() / gpu_device_count / (1 << 20), + gpu_memory_for_kv_cache / (1 << 20), + (gpu_memory_size - gpu_memory_available) / (1 << 20)); size_t kv_cache_on_cnt = (size_t)(k_cache_on) + (size_t)(v_cache_on); size_t max_total_kvcache_pages = - gpu_memory_for_kv_cache / (kv_cache_on_cnt * head_per_gpu * model_settings.k_head_dim * - model_settings.bytes_per_kv_cache_element * page_size * model_settings.layer_count); + gpu_memory_for_kv_cache / + (kv_cache_on_cnt * head_per_gpu * model_settings.k_head_dim * + model_settings.bytes_per_kv_cache_element * page_size * + model_settings.layer_count); if (total_kvcache_pages.has_value()) { if (total_kvcache_pages.value() > max_total_kvcache_pages) { - SPDLOG_ERROR("total_kvcache_pages {} is larger than max_total_kvcache_pages {}", total_kvcache_pages.value(), - max_total_kvcache_pages); + SPDLOG_ERROR( + "total_kvcache_pages {} is larger than max_total_kvcache_pages {}", + total_kvcache_pages.value(), max_total_kvcache_pages); assert(false); } } else { total_kvcache_pages = max_total_kvcache_pages; - SPDLOG_INFO("total_kvcache_pages is auto derived as {}", max_total_kvcache_pages); + SPDLOG_INFO("total_kvcache_pages is auto derived as {}", + max_total_kvcache_pages); } if (page_size % 256 != 0) { @@ -88,7 +101,7 @@ void Settings::auto_derive() { std::string BatchQueryTodo::debug() { std::string re = "BatchQueryTodo: "; re += "QueryIDs: "; - for (auto& id : query_ids) { + for (auto &id : query_ids) { re += std::to_string(id) + " "; } return re; @@ -119,59 +132,61 @@ struct Query { // Query status changed by this order enum Status { Received, Preparing, Ready, Prefill, Decode, Done }; Status plan_status = Received; - TokenLength active_position; // the position where no kvcache now - TokenLength plan_position; // the position where no kvcache now, in plan + TokenLength active_position; // the position where no kvcache now + TokenLength plan_position; // the position where no kvcache now, in plan size_t prepare_try_count = 0; std::shared_ptr kvc2_handle = nullptr; // derived from kvc2_handle - torch::Tensor block_index; // block indexes + torch::Tensor block_index; // block indexes struct QueryContext { ModelName model_name; QuantType quant_type; - kvc2::KVC2Interface* kvc2_interface; - QueryMaintainer* query_maintainer; - Metrics* met; + kvc2::KVC2Interface *kvc2_interface; + QueryMaintainer *query_maintainer; + Metrics *met; } ctx; void after_load(bool ok); void to_status(Status to); - void export_metrics() { ctx.met->query_count(status_to_string(plan_status))->Increment(1); } + void export_metrics() { + ctx.met->query_count(status_to_string(plan_status))->Increment(1); + } Query(QueryID id, QueryAdd query_add, QueryContext context) - : id(id), - prompt_length(query_add.query_length), - no_kvcache_from(0), + : id(id), prompt_length(query_add.query_length), no_kvcache_from(0), estimated_length(query_add.estimated_length), - sample_options(query_add.sample_options), - user_id(query_add.user_id), - SLO_TTFT_ms(query_add.SLO_TTFT_ms), - SLO_TBT_ms(query_add.SLO_TBT_ms), - stop_criteria(query_add.stop_criteria), - ctx(context) { + sample_options(query_add.sample_options), user_id(query_add.user_id), + SLO_TTFT_ms(query_add.SLO_TTFT_ms), SLO_TBT_ms(query_add.SLO_TBT_ms), + stop_criteria(query_add.stop_criteria), ctx(context) { std::vector shape = {int64_t(query_add.estimated_length)}; - query_token = torch::zeros(shape, torch::TensorOptions().dtype(torch::kInt32)); + query_token = + torch::zeros(shape, 
torch::TensorOptions().dtype(torch::kInt32)); assert(query_token.is_contiguous()); if (query_token.is_contiguous() == false) { SPDLOG_ERROR("Query Token must be contiguous!"); exit(1); } - memcpy(query_token.data_ptr(), query_add.query_token.data(), query_add.query_length * sizeof(Token)); + memcpy(query_token.data_ptr(), query_add.query_token.data(), + query_add.query_length * sizeof(Token)); - no_kvcache_from = 0; // maybe match prefix later + no_kvcache_from = 0; // maybe match prefix later export_metrics(); } - Token& token_at(size_t idx) { return reinterpret_cast(query_token.data_ptr())[idx]; } + Token &token_at(size_t idx) { + return reinterpret_cast(query_token.data_ptr())[idx]; + } - void absorb_update(const QueryUpdate& update) { + void absorb_update(const QueryUpdate &update) { SPDLOG_DEBUG("{}", update.debug()); active_position = update.active_position; - kvc2_handle->append_tokens(&token_at(0), active_position); // active_position is length -1 + kvc2_handle->append_tokens(&token_at(0), + active_position); // active_position is length -1 if (update.is_prefill) { if (active_position == prompt_length) { token_at(active_position) = update.generated_token; @@ -187,15 +202,17 @@ struct Query { } } - void absorb_prefill_task(const PrefillTask& task) { - auto& [id, start, length] = task; + void absorb_prefill_task(const PrefillTask &task) { + auto &[id, start, length] = task; this->plan_position = start + length; if (this->plan_position == prompt_length) { to_status(Decode); } } - void absorb_decode_task([[maybe_unused]] const QueryID& task) { this->plan_position += 1; } + void absorb_decode_task([[maybe_unused]] const QueryID &task) { + this->plan_position += 1; + } PrefillTask get_prefill_task(size_t prefill_length) { if (prefill_length + plan_position > prompt_length) { @@ -206,18 +223,18 @@ struct Query { static std::string status_to_string(Status status) { switch (status) { - case Received: - return "Received"; - case Preparing: - return "Preparing"; - case Ready: - return "Ready"; - case Prefill: - return "Prefill"; - case Decode: - return "Decode"; - case Done: - return "Done"; + case Received: + return "Received"; + case Preparing: + return "Preparing"; + case Ready: + return "Ready"; + case Prefill: + return "Prefill"; + case Decode: + return "Decode"; + case Done: + return "Done"; } assert(false); } @@ -225,16 +242,19 @@ struct Query { void debug() { std::string status_string = status_to_string(plan_status); - SPDLOG_DEBUG( - "Query {}, prompt_length {}, estimated_length {}, plan status {}, plan position {} " - "active position {}", - id, prompt_length, estimated_length, status_string, plan_position, active_position); + SPDLOG_DEBUG("Query {}, prompt_length {}, estimated_length {}, plan status " + "{}, plan position {} " + "active position {}", + id, prompt_length, estimated_length, status_string, + plan_position, active_position); } }; std::string QueryUpdate::debug() const { - return fmt::format("Query {}, ok {}, is_prefill {}, done {}, active_position {}, gen token {}", id, ok, is_prefill, - decode_done, active_position, generated_token); + return fmt::format("Query {}, ok {}, is_prefill {}, done {}, active_position " + "{}, gen token {}", + id, ok, is_prefill, decode_done, active_position, + generated_token); } using Q = std::shared_ptr; @@ -258,18 +278,22 @@ struct KVC2_Maintainer { .total_kvcache_pages = settings.total_kvcache_pages.value(), .num_token_per_page = settings.page_size, .num_k_heads = settings.model_settings.num_k_heads, - .k_head_dim = - 
settings.use_self_defined_head_dim ? settings.self_defined_head_dim : settings.model_settings.k_head_dim, + .k_head_dim = settings.use_self_defined_head_dim + ? settings.self_defined_head_dim + : settings.model_settings.k_head_dim, .full_kv_cache_on_each_gpu = settings.full_kv_cache_on_each_gpu, .k_cache_on = settings.k_cache_on, .v_cache_on = settings.v_cache_on, .tensor_type = torch::kBFloat16, }; - auto model_configs_path = std::filesystem::path(settings.kvc2_config_path) / "model_configs.json"; + auto model_configs_path = + std::filesystem::path(settings.kvc2_config_path) / "model_configs.json"; load_model_configs(model_configs_path); auto my_model_config = ModelConfig(); - my_model_config.load_from(std::filesystem::path(settings.model_settings.model_path) / "config.json"); + my_model_config.load_from( + std::filesystem::path(settings.model_settings.model_path) / + "config.json"); model_configs[settings.model_name] = my_model_config; dump_model_configs(model_configs_path); @@ -299,7 +323,7 @@ struct KVC2_Maintainer { } }; -using EventAddQuery = std::pair*>; +using EventAddQuery = std::pair *>; using EventUpdateQuery = BatchQueryUpdate; using EventTakenBatch = std::shared_ptr; struct EventPrepare { @@ -311,55 +335,48 @@ struct EventPrepared { bool ok; }; -struct EventQueryStatus{ +struct EventQueryStatus { QueryID query_id; Query::Status now_status; }; struct EventSchedule {}; -using Event = std::variant; +using Event = + std::variant; -template -std::string event_name(const T& event); +template std::string event_name(const T &event); -template <> -std::string event_name(const EventAddQuery&) { +template <> std::string event_name(const EventAddQuery &) { return "EventAddQuery"; } -template <> -std::string event_name(const EventUpdateQuery&) { +template <> std::string event_name(const EventUpdateQuery &) { return "EventUpdateQuery"; } -template <> -std::string event_name(const EventTakenBatch&) { +template <> std::string event_name(const EventTakenBatch &) { return "EventTakenBatch"; } -template <> -std::string event_name(const EventPrepare&) { +template <> std::string event_name(const EventPrepare &) { return "EventPrepare"; } -template <> -std::string event_name(const EventPrepared&) { +template <> std::string event_name(const EventPrepared &) { return "EventPrepared"; } -template <> -std::string event_name(const EventQueryStatus&) { +template <> std::string event_name(const EventQueryStatus &) { return "EventQueryStatus"; } -template <> -std::string event_name(const EventSchedule&) { +template <> std::string event_name(const EventSchedule &) { return "EventSchedule"; } // 用 std::visit 实现对 variant 的 event_name -std::string event_name(const Event& event) { - return std::visit([](const auto& e) { return event_name(e); }, event); +std::string event_name(const Event &event) { + return std::visit([](const auto &e) { return event_name(e); }, event); } static_assert(std::is_copy_constructible::value); @@ -383,13 +400,13 @@ struct QueryMaintainer : public Scheduler { QueryMaintainer() = default; - void gen_batch_query_todo(BatchQueryTodo* re, const std::set& queries) { + void gen_batch_query_todo(BatchQueryTodo *re, const std::set &queries) { std::vector> d_batch(2); size_t last_decode_batch = 0; size_t prefill_num = 0; size_t decode_num = 0; size_t preill_length = 0; - for (auto& q : queries) { + for (auto &q : queries) { if (q->plan_status == Query::Prefill) { prefill_num += 1; } @@ -397,13 +414,13 @@ struct QueryMaintainer : public Scheduler { decode_num += 1; } } - if (prefill_num >= 2 || 
(prefill_num ==1 && settings.max_batch_size - 2 < decode_num)) { - preill_length = settings.recommended_chunk_prefill_token_count; - } - else { + if (prefill_num >= 2 || + (prefill_num == 1 && settings.max_batch_size - 2 < decode_num)) { + preill_length = settings.recommended_chunk_prefill_token_count; + } else { preill_length = settings.recommended_chunk_prefill_token_count * 2; } - for (auto& q : queries) { + for (auto &q : queries) { re->query_ids.push_back(q->id); re->query_tokens.push_back(q->query_token); re->query_lengths.push_back(q->prompt_length); @@ -427,7 +444,7 @@ struct QueryMaintainer : public Scheduler { re->attn_masks = std::nullopt; re->rope_ranges = std::nullopt; - for (auto& b : d_batch) { + for (auto &b : d_batch) { if (b.empty()) continue; re->decode_mini_batches.push_back(b); @@ -439,46 +456,54 @@ struct QueryMaintainer : public Scheduler { // Interface void init(Settings settings) override { - SPDLOG_INFO( - "\nScheduler Settings:\n" - " model_name: {}\n" - " quant_type: {}\n" - " model_path: {}\n" - " params_count: {}\n" - " layer_count: {}\n" - " num_k_heads: {}\n" - " k_head_dim: {}\n" - " bytes_per_params: {}\n" - " bytes_per_kv_cache_element: {}\n" - " page_size: {}\n" - " gpu_device_id: {}\n" - " gpu_memory_size: {}\n" - " memory_utilization_percentage: {}\n" - " max_batch_size: {}\n" - " recommended_chunk_prefill_token_count: {}\n" - " sched_metrics_port: {}\n" - " kvc2_config_path: {}\n" - " kvc2_root_path: {}\n" - " memory_pool_size_GB: {}\n" - " evict_count: {}\n" - " kvc2_metrics_port: {}\n" - " load_from_disk: {}\n" - " save_to_disk: {}\n" - " strategy_name: {}\n" - " gpu_device_count: {}\n", - settings.model_name, settings.quant_type, settings.model_settings.model_path, - settings.model_settings.params_count, settings.model_settings.layer_count, settings.model_settings.num_k_heads, - settings.model_settings.k_head_dim, settings.model_settings.bytes_per_params, - settings.model_settings.bytes_per_kv_cache_element, + SPDLOG_INFO("\nScheduler Settings:\n" + " model_name: {}\n" + " quant_type: {}\n" + " model_path: {}\n" + " params_count: {}\n" + " layer_count: {}\n" + " num_k_heads: {}\n" + " k_head_dim: {}\n" + " bytes_per_params: {}\n" + " bytes_per_kv_cache_element: {}\n" + " page_size: {}\n" + " gpu_device_id: {}\n" + " gpu_memory_size: {}\n" + " memory_utilization_percentage: {}\n" + " max_batch_size: {}\n" + " recommended_chunk_prefill_token_count: {}\n" + " sched_metrics_port: {}\n" + " kvc2_config_path: {}\n" + " kvc2_root_path: {}\n" + " memory_pool_size_GB: {}\n" + " evict_count: {}\n" + " kvc2_metrics_port: {}\n" + " load_from_disk: {}\n" + " save_to_disk: {}\n" + " strategy_name: {}\n" + " gpu_device_count: {}\n", + settings.model_name, settings.quant_type, + settings.model_settings.model_path, + settings.model_settings.params_count, + settings.model_settings.layer_count, + settings.model_settings.num_k_heads, + settings.model_settings.k_head_dim, + settings.model_settings.bytes_per_params, + settings.model_settings.bytes_per_kv_cache_element, - settings.page_size, format_vector(settings.gpu_device_id), readable_number(settings.gpu_memory_size), - settings.memory_utilization_percentage, settings.max_batch_size, settings.recommended_chunk_prefill_token_count, - settings.sched_metrics_port, settings.kvc2_config_path, settings.kvc2_root_path, settings.memory_pool_size_GB, - settings.evict_count, settings.kvc2_metrics_port, settings.load_from_disk, settings.save_to_disk, - settings.strategy_name, settings.gpu_device_count); + settings.page_size, 
format_vector(settings.gpu_device_id), + readable_number(settings.gpu_memory_size), + settings.memory_utilization_percentage, settings.max_batch_size, + settings.recommended_chunk_prefill_token_count, + settings.sched_metrics_port, settings.kvc2_config_path, + settings.kvc2_root_path, settings.memory_pool_size_GB, + settings.evict_count, settings.kvc2_metrics_port, + settings.load_from_disk, settings.save_to_disk, + settings.strategy_name, settings.gpu_device_count); this->settings = settings; - kvc2_maintainer = std::shared_ptr(new KVC2_Maintainer(settings)); + kvc2_maintainer = + std::shared_ptr(new KVC2_Maintainer(settings)); MetricsConfig met_conf = { .endpoint = "0.0.0.0:" + std::to_string(settings.sched_metrics_port), .model_name = settings.model_name, @@ -487,7 +512,7 @@ struct QueryMaintainer : public Scheduler { SPDLOG_INFO("Creating scheduler metrics exporter on {}", met_conf.endpoint); met = std::make_shared(met_conf); - met->fn_every_sec = [](Metrics* met) { + met->fn_every_sec = [](Metrics *met) { auto generated_tokens = met->generated_tokens->Collect().counter.value; SPDLOG_INFO("Last Sec Generated Tokens {}", generated_tokens); }; @@ -522,7 +547,8 @@ struct QueryMaintainer : public Scheduler { // Here this function update last batch results and get the next batch // in most cases, the batch is ready, // if not, busy wait to get it - std::shared_ptr update_last_batch(BatchQueryUpdate updates) override { + std::shared_ptr + update_last_batch(BatchQueryUpdate updates) override { event_loop_queue.enqueue(updates); // Busy Wait @@ -547,20 +573,22 @@ struct QueryMaintainer : public Scheduler { InferenceContext re; re.k_cache = kvc2_maintainer->k_cache; re.v_cache = kvc2_maintainer->v_cache; - // kvc2_maintainer->k_cache[0][0][0][0][0][0] = 42; // test whether we pass this to inference loop + // kvc2_maintainer->k_cache[0][0][0][0][0][0] = 42; // test whether we pass + // this to inference loop return re; } virtual void strategy_add_query(Q new_query) = 0; - virtual void strategy_update_query(const EventUpdateQuery& update) = 0; - virtual void strategy_taken_batch(const EventTakenBatch& batch) = 0; - virtual void strategy_prepare(const EventPrepare& prepare) = 0; - virtual void strategy_prepared(const EventPrepared& prepared) = 0; - virtual void strategy_query_status(const EventQueryStatus& query_status) = 0; - virtual void strategy_schedule(const EventSchedule& event, BatchQueryTodo* new_batch) = 0; + virtual void strategy_update_query(const EventUpdateQuery &update) = 0; + virtual void strategy_taken_batch(const EventTakenBatch &batch) = 0; + virtual void strategy_prepare(const EventPrepare &prepare) = 0; + virtual void strategy_prepared(const EventPrepared &prepared) = 0; + virtual void strategy_query_status(const EventQueryStatus &query_status) = 0; + virtual void strategy_schedule(const EventSchedule &event, + BatchQueryTodo *new_batch) = 0; - void tackle_event(EventAddQuery& event) { - auto& query_add = event.first; + void tackle_event(EventAddQuery &event) { + auto &query_add = event.first; QueryID id = query_id_counter; event.second->set_value(id); query_id_counter += 1; @@ -570,33 +598,36 @@ struct QueryMaintainer : public Scheduler { strategy_add_query(new_query); } - void tackle_event(const EventUpdateQuery& update) { + void tackle_event(const EventUpdateQuery &update) { // SPDLOG_INFO("Tackle Update Query"); - for (auto& u : update) { + for (auto &u : update) { if (u.ok == false) { SPDLOG_ERROR("Query {} is not exectued OK", u.id); exit(1); } auto q = query_map[u.id]; - 
if (q->plan_status == Query::Status::Prefill || q->plan_status == Query::Status::Decode) { + if (q->plan_status == Query::Status::Prefill || + q->plan_status == Query::Status::Decode) { q->absorb_update(u); } else { - SPDLOG_DEBUG("Query {} is not in Prefill or Decode status, do not update it", u.id); + SPDLOG_DEBUG( + "Query {} is not in Prefill or Decode status, do not update it", + u.id); } } strategy_update_query(update); } - void tackle_event(const EventTakenBatch& batch) { + void tackle_event(const EventTakenBatch &batch) { met->batch_count("Taken")->Increment(1); - for (auto& task : batch->prefill_mini_batches) { + for (auto &task : batch->prefill_mini_batches) { auto [id, s, l] = task; if (l == 0) continue; query_map.at(id)->absorb_prefill_task(task); } - for (auto& mini_batch : batch->decode_mini_batches) { - for (auto& id : mini_batch) { + for (auto &mini_batch : batch->decode_mini_batches) { + for (auto &id : mini_batch) { query_map.at(id)->absorb_decode_task(id); } } @@ -604,16 +635,18 @@ struct QueryMaintainer : public Scheduler { strategy_taken_batch(batch); } - void tackle_event(const EventPrepare& event) { strategy_prepare(event); } - void tackle_event(const EventPrepared& event) { strategy_prepared(event); } - void tackle_event(const EventQueryStatus& event) { strategy_query_status(event); } + void tackle_event(const EventPrepare &event) { strategy_prepare(event); } + void tackle_event(const EventPrepared &event) { strategy_prepared(event); } + void tackle_event(const EventQueryStatus &event) { + strategy_query_status(event); + } - void tackle_event(const EventSchedule& event) { + void tackle_event(const EventSchedule &event) { // SPDLOG_INFO("Tackle Schedule Event"); HistogramTimerWrapper t(met->schedule_time); - BatchQueryTodo* new_batch = new BatchQueryTodo; + BatchQueryTodo *new_batch = new BatchQueryTodo; strategy_schedule(event, new_batch); // if (new_batch->query_ids.empty()) { // SPDLOG_INFO("Nothing todo"); @@ -660,7 +693,8 @@ struct QueryMaintainer : public Scheduler { } }, event); - if (event_loop_queue.size() == 0 && std::holds_alternative(event) == false) { + if (event_loop_queue.size() == 0 && + std::holds_alternative(event) == false) { // if this is not a schedule event, we need to schedule one event_loop_queue.enqueue(EventSchedule()); } @@ -679,54 +713,58 @@ struct QueryMaintainer : public Scheduler { void Query::to_status(Status to) { SPDLOG_DEBUG("Calling to status query {}, to {}", id, status_to_string(to)); switch (to) { - case Received: - assert(false); - break; - case Preparing: - SPDLOG_INFO("Preparing Query {} {}", id, - prepare_try_count > 0 ? (std::to_string(prepare_try_count) + " Try") : ""); - prepare_try_count += 1; + case Received: + assert(false); + break; + case Preparing: + SPDLOG_INFO("Preparing Query {} {}", id, + prepare_try_count > 0 + ? 
(std::to_string(prepare_try_count) + " Try") + : ""); + prepare_try_count += 1; - ctx.kvc2_interface->lookup_to_gpu_async( - ctx.model_name, ctx.quant_type, static_cast(query_token.data_ptr()), prompt_length, - estimated_length, [this](std::shared_ptr handle) { - if (handle == nullptr) { - SPDLOG_INFO("Get handle from kvc2 Failed."); - this->after_load(false); - } else { - SPDLOG_INFO("Get handle from kvc2 Success."); - this->kvc2_handle = handle; - this->to_status(Ready); - this->after_load(true); - } - }); - break; - case Ready: - SPDLOG_INFO("Ready Query {}", id); - break; - case Prefill: - SPDLOG_INFO("Prefilling Query {}", id); - // assert(plan_status == Received); - plan_position = kvc2_handle->matched_length(); + ctx.kvc2_interface->lookup_to_gpu_async( + ctx.model_name, ctx.quant_type, + static_cast(query_token.data_ptr()), prompt_length, + estimated_length, + [this](std::shared_ptr handle) { + if (handle == nullptr) { + SPDLOG_INFO("Get handle from kvc2 Failed."); + this->after_load(false); + } else { + SPDLOG_INFO("Get handle from kvc2 Success."); + this->kvc2_handle = handle; + this->to_status(Ready); + this->after_load(true); + } + }); + break; + case Ready: + SPDLOG_INFO("Ready Query {}", id); + break; + case Prefill: + SPDLOG_INFO("Prefilling Query {}", id); + // assert(plan_status == Received); + plan_position = kvc2_handle->matched_length(); - if (prompt_length - plan_position == 0) { - assert(prompt_length > 0); - plan_position -= 1; - } - break; - case Decode: - SPDLOG_INFO("Decoding Query {}", id); - // assert(plan_status == Prefill); - break; - case Done: - SPDLOG_INFO("Finish Query {}", id); - kvc2_handle = nullptr; - ctx.query_maintainer->event_loop_queue.enqueue(EventQueryStatus{ + if (prompt_length - plan_position == 0) { + assert(prompt_length > 0); + plan_position -= 1; + } + break; + case Decode: + SPDLOG_INFO("Decoding Query {}", id); + // assert(plan_status == Prefill); + break; + case Done: + SPDLOG_INFO("Finish Query {}", id); + kvc2_handle = nullptr; + ctx.query_maintainer->event_loop_queue.enqueue(EventQueryStatus{ .query_id = id, .now_status = to, - }); - // assert(plan_status == Decode); - break; + }); + // assert(plan_status == Decode); + break; } plan_status = to; export_metrics(); @@ -734,11 +772,14 @@ void Query::to_status(Status to) { void Query::after_load(bool ok) { if (ok) { - size_t page_count = div_up(estimated_length, ctx.query_maintainer->settings.page_size); + size_t page_count = + div_up(estimated_length, ctx.query_maintainer->settings.page_size); std::vector shape; shape.push_back(page_count); - block_index = torch::zeros(shape, torch::TensorOptions().dtype(torch::kInt32)).contiguous(); - auto ptr = reinterpret_cast(block_index.data_ptr()); + block_index = + torch::zeros(shape, torch::TensorOptions().dtype(torch::kInt32)) + .contiguous(); + auto ptr = reinterpret_cast(block_index.data_ptr()); auto vec_idx = kvc2_handle->get_gpu_block_idx(); for (size_t i = 0; i < vec_idx.size(); i++) { ptr[i] = vec_idx[i]; @@ -765,7 +806,7 @@ struct FCFS_single_prefill : public QueryMaintainer { bool has_query_preparing = false; std::optional wait_done_prepare = std::nullopt; - std::set active_query; // on going queries for LLMs + std::set active_query; // on going queries for LLMs // interface all these are executed in a single thread void strategy_add_query(Q new_query) override { @@ -774,71 +815,72 @@ struct FCFS_single_prefill : public QueryMaintainer { has_query_preparing = true; auto next_q = queue.front(); queue.pop(); - 
event_loop_queue.enqueue(EventPrepare{next_q->id,true}); + event_loop_queue.enqueue(EventPrepare{next_q->id, true}); } } - void strategy_update_query(const EventUpdateQuery& update) override { + void strategy_update_query(const EventUpdateQuery &update) override { for (auto u : update) { - auto& q = query_map[u.id]; + auto &q = query_map[u.id]; if (q->plan_status == Query::Done) { active_query.erase(q); } } } - void strategy_taken_batch(const EventTakenBatch& batch) override { - for (auto& q : batch->query_ids) { + void strategy_taken_batch(const EventTakenBatch &batch) override { + for (auto &q : batch->query_ids) { if (query_map[q]->plan_status != Query::Done) { active_query.insert(query_map[q]); } } } - void strategy_prepare(const EventPrepare& prepare) override { - if(prepare.first_try){ - auto& q = query_map[prepare.query_id]; + void strategy_prepare(const EventPrepare &prepare) override { + if (prepare.first_try) { + auto &q = query_map[prepare.query_id]; q->to_status(Query::Preparing); - }else{ - assert(wait_done_prepare.has_value()==false); + } else { + assert(wait_done_prepare.has_value() == false); wait_done_prepare = prepare; wait_done_prepare->first_try = true; } } - void strategy_prepared(const EventPrepared& prepared) override { + void strategy_prepared(const EventPrepared &prepared) override { assert(prepared.ok); ready_queue.push(query_map[prepared.query_id]); if (queue.empty() == false) { auto next_q_prepare = queue.front(); queue.pop(); - event_loop_queue.enqueue(EventPrepare{next_q_prepare->id,true}); + event_loop_queue.enqueue(EventPrepare{next_q_prepare->id, true}); } else { has_query_preparing = false; } } - void strategy_query_status(const EventQueryStatus& query_status) override{ - if(query_status.now_status==Query::Done){ - if(wait_done_prepare.has_value()){ + void strategy_query_status(const EventQueryStatus &query_status) override { + if (query_status.now_status == Query::Done) { + if (wait_done_prepare.has_value()) { event_loop_queue.enqueue(wait_done_prepare.value()); wait_done_prepare = std::nullopt; } } - } - void strategy_schedule([[maybe_unused]] const EventSchedule& event, BatchQueryTodo* new_batch) override { + void strategy_schedule([[maybe_unused]] const EventSchedule &event, + BatchQueryTodo *new_batch) override { bool have_prefill = false; - for (auto& q : active_query) { + for (auto &q : active_query) { if (q->plan_status == Query::Prefill) { have_prefill = true; } } - if (have_prefill == false && ready_queue.empty() == false && active_query.size() < settings.max_batch_size) { - auto& next_q = ready_queue.front(); + if (have_prefill == false && ready_queue.empty() == false && + active_query.size() < settings.max_batch_size) { + auto &next_q = ready_queue.front(); ready_queue.pop(); SPDLOG_INFO("Active query {}", next_q->id); @@ -847,7 +889,7 @@ struct FCFS_single_prefill : public QueryMaintainer { } if (active_query.empty() == false) SPDLOG_INFO("Active Query Size {}", active_query.size()); - for (auto& q : active_query) { + for (auto &q : active_query) { q->debug(); } gen_batch_query_todo(new_batch, active_query); @@ -855,10 +897,11 @@ struct FCFS_single_prefill : public QueryMaintainer { }; struct FCFS : public FCFS_single_prefill { - void strategy_schedule([[maybe_unused]] const EventSchedule& event, BatchQueryTodo* new_batch) override { + void strategy_schedule([[maybe_unused]] const EventSchedule &event, + BatchQueryTodo *new_batch) override { int prefill_count = 0; const int max_prefill_count = 2; - for (auto& q : active_query) { + for (auto &q 
: active_query) { if (q->plan_status == Query::Prefill) { prefill_count += 1; } @@ -877,7 +920,7 @@ struct FCFS : public FCFS_single_prefill { if (active_query.empty() == false) { SPDLOG_DEBUG("Active Query Size {}", active_query.size()); } - for (auto& q : active_query) { + for (auto &q : active_query) { q->debug(); } gen_batch_query_todo(new_batch, active_query); @@ -900,7 +943,8 @@ std::shared_ptr create_scheduler(Settings settings) { } NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(SampleOptions, temperature, top_p); -NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(QueryAdd, query_token, query_length, estimated_length, sample_options, user_id, +NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(QueryAdd, query_token, query_length, + estimated_length, sample_options, user_id, SLO_TTFT_ms, SLO_TBT_ms); std::string QueryAdd::serialize() { @@ -908,9 +952,9 @@ std::string QueryAdd::serialize() { return j.dump(); } -QueryAdd QueryAdd::deserialize(const std::string& input) { +QueryAdd QueryAdd::deserialize(const std::string &input) { json j = json::parse(input); return j.get(); } -}; // namespace scheduler +}; // namespace scheduler diff --git a/csrc/balance_serve/sched/scheduler.h b/csrc/balance_serve/sched/scheduler.h index 0889c08..6e8662c 100644 --- a/csrc/balance_serve/sched/scheduler.h +++ b/csrc/balance_serve/sched/scheduler.h @@ -1,10 +1,10 @@ #pragma once -#include +#include "model_config.h" #include #include #include +#include #include -#include "model_config.h" namespace scheduler { @@ -28,7 +28,9 @@ struct ModelSettings { double bytes_per_kv_cache_element; inline size_t params_nbytes() { return params_count * bytes_per_params; } - inline size_t bytes_per_token_kv_cache() { return bytes_per_kv_cache_element * num_k_heads * k_head_dim; } + inline size_t bytes_per_token_kv_cache() { + return bytes_per_kv_cache_element * num_k_heads * k_head_dim; + } }; struct SampleOptions { @@ -37,15 +39,16 @@ struct SampleOptions { }; struct Settings { - // something is aukward here, kvc2 only use model_name and quant_type to get model infos. + // something is aukward here, kvc2 only use model_name and quant_type to get + // model infos. ModelName model_name; QuantType quant_type; // model_setting is ignore by kvc2 ModelSettings model_settings; - size_t page_size = 256; // how many token in a page - std::vector gpu_device_id; // - size_t gpu_memory_size; // memory size in bytes of each GPU, each + size_t page_size = 256; // how many token in a page + std::vector gpu_device_id; // + size_t gpu_memory_size; // memory size in bytes of each GPU, each double memory_utilization_percentage; size_t max_batch_size = 256; @@ -79,14 +82,16 @@ struct Settings { void auto_derive(); }; -using PrefillTask = std::tuple; // id, start, length +using PrefillTask = + std::tuple; // id, start, length struct BatchQueryTodo { // query std::vector query_ids; std::vector query_tokens; std::vector query_lengths; - std::vector block_indexes; // (max_num_blocks_per_seq), dtype torch.int32. + std::vector + block_indexes; // (max_num_blocks_per_seq), dtype torch.int32. 
std::optional attn_masks; std::optional rope_ranges; std::vector sample_options; @@ -94,8 +99,10 @@ struct BatchQueryTodo { // mini batches, adjacent two mini batches are executed together // tasks count must be <=2, because of flash infer attention - std::vector prefill_mini_batches; // prefill minibatch only has 1 prefill - std::vector> decode_mini_batches; // decode minibatch has multiple decode + std::vector + prefill_mini_batches; // prefill minibatch only has 1 prefill + std::vector> + decode_mini_batches; // decode minibatch has multiple decode std::string debug(); bool empty(); @@ -105,9 +112,9 @@ struct QueryUpdate { QueryID id; bool ok; bool is_prefill; - bool decode_done; // no use for now - TokenLength active_position; // the position where no kvcache now, - // kvcache[active_position] == None + bool decode_done; // no use for now + TokenLength active_position; // the position where no kvcache now, + // kvcache[active_position] == None Token generated_token; @@ -117,8 +124,8 @@ struct QueryUpdate { using BatchQueryUpdate = std::vector; struct InferenceContext { - std::vector k_cache; // [gpu num] (layer_count, num blocks, - // page size, kheadnum, head_dim) + std::vector k_cache; // [gpu num] (layer_count, num blocks, + // page size, kheadnum, head_dim) std::vector v_cache; }; @@ -127,7 +134,7 @@ constexpr UserID NoUser = -1; const int MAX_SLO_TIME = 1e9; struct QueryAdd { - std::vector query_token; // int here + std::vector query_token; // int here // torch::Tensor attn_mask; TokenLength query_length; TokenLength estimated_length; @@ -141,11 +148,11 @@ struct QueryAdd { int SLO_TBT_ms = MAX_SLO_TIME; std::string serialize(); - static QueryAdd deserialize(const std::string& input); + static QueryAdd deserialize(const std::string &input); }; class Scheduler { - public: +public: virtual void init(Settings settings) = 0; virtual void run() = 0; @@ -156,7 +163,8 @@ class Scheduler { virtual void cancel_query(QueryID id) = 0; // inference loop call this - virtual std::shared_ptr update_last_batch(BatchQueryUpdate updates) = 0; + virtual std::shared_ptr + update_last_batch(BatchQueryUpdate updates) = 0; virtual InferenceContext get_inference_context() = 0; virtual ~Scheduler() = default; @@ -164,4 +172,4 @@ class Scheduler { std::shared_ptr create_scheduler(Settings settings); -}; // namespace scheduler \ No newline at end of file +}; // namespace scheduler \ No newline at end of file diff --git a/csrc/balance_serve/sched/utils/arithmetic.hpp b/csrc/balance_serve/sched/utils/arithmetic.hpp index 7562f56..1eda61e 100644 --- a/csrc/balance_serve/sched/utils/arithmetic.hpp +++ b/csrc/balance_serve/sched/utils/arithmetic.hpp @@ -1,7 +1,6 @@ #include -template -T div_up(T x, U by) { +template T div_up(T x, U by) { static_assert(std::is_integral_v); static_assert(std::is_integral_v); return (x + by - 1) / by; diff --git a/csrc/balance_serve/sched/utils/atomic_ptr_with_flags.hpp b/csrc/balance_serve/sched/utils/atomic_ptr_with_flags.hpp index f0c98bf..95aeea2 100644 --- a/csrc/balance_serve/sched/utils/atomic_ptr_with_flags.hpp +++ b/csrc/balance_serve/sched/utils/atomic_ptr_with_flags.hpp @@ -1,28 +1,35 @@ #include -template -struct AtomicPtrWithFlag { +template struct AtomicPtrWithFlag { constexpr static uint64_t mask = 1ull << 63; std::atomic_uint64_t ptr = 0; - std::pair load(std::memory_order order = std::memory_order_seq_cst) { + std::pair + load(std::memory_order order = std::memory_order_seq_cst) { uint64_t val = ptr.load(order); - return {reinterpret_cast(val & (~mask)), val & 
mask}; + return {reinterpret_cast(val & (~mask)), val & mask}; } - void store(T* p, bool flag, std::memory_order order = std::memory_order_seq_cst) { + void store(T *p, bool flag, + std::memory_order order = std::memory_order_seq_cst) { ptr.store(reinterpret_cast(p) | (flag ? mask : 0), order); } - std::pair exchange(T* p, bool flag, std::memory_order order = std::memory_order_seq_cst) { - uint64_t val = ptr.exchange(reinterpret_cast(p) | (flag ? mask : 0), order); - return {reinterpret_cast(val & (~mask)), val & mask}; + std::pair + exchange(T *p, bool flag, + std::memory_order order = std::memory_order_seq_cst) { + uint64_t val = + ptr.exchange(reinterpret_cast(p) | (flag ? mask : 0), order); + return {reinterpret_cast(val & (~mask)), val & mask}; } - std::pair touch_load(std::memory_order order = std::memory_order_seq_cst) { + std::pair + touch_load(std::memory_order order = std::memory_order_seq_cst) { uint64_t val = ptr.fetch_and(~mask, order); - return {reinterpret_cast(val & (~mask)), val & mask}; + return {reinterpret_cast(val & (~mask)), val & mask}; } - bool check_flag(std::memory_order order = std::memory_order_seq_cst) { return ptr.load(order) & mask; } + bool check_flag(std::memory_order order = std::memory_order_seq_cst) { + return ptr.load(order) & mask; + } }; diff --git a/csrc/balance_serve/sched/utils/csv.hpp b/csrc/balance_serve/sched/utils/csv.hpp index d9dd558..6a7cc98 100644 --- a/csrc/balance_serve/sched/utils/csv.hpp +++ b/csrc/balance_serve/sched/utils/csv.hpp @@ -19,7 +19,7 @@ namespace csv { * @param line The CSV line to parse. * @return A vector of strings, each representing a field in the CSV line. */ -inline std::vector parse_csv_line(const std::string& line) { +inline std::vector parse_csv_line(const std::string &line) { std::vector result; std::string field; bool in_quotes = false; @@ -57,7 +57,8 @@ inline std::vector parse_csv_line(const std::string& line) { * @return A vector of pairs, each containing a column name and a vector of data * for that column. 
*/ -inline std::vector>> read_csv(const std::string& filename) { +inline std::vector>> +read_csv(const std::string &filename) { std::cout << "Reading CSV file: " << filename << std::endl; // Open the file std::ifstream file(filename); @@ -72,7 +73,7 @@ inline std::vector>> read_csv(co // Prepare the result vector with column names std::vector>> result; - for (const auto& name : column_names) { + for (const auto &name : column_names) { result.emplace_back(name, std::vector()); } @@ -84,7 +85,7 @@ inline std::vector>> read_csv(co // Determine the number of threads to use unsigned int num_threads = std::thread::hardware_concurrency(); if (num_threads == 0) - num_threads = 4; // Default to 4 threads if hardware_concurrency returns 0 + num_threads = 4; // Default to 4 threads if hardware_concurrency returns 0 // Calculate chunk start positions based on content size std::vector chunk_starts; @@ -100,14 +101,15 @@ inline std::vector>> read_csv(co ++pos; } if (pos < content_size) { - ++pos; // Skip the newline character + ++pos; // Skip the newline character } chunk_starts.push_back(pos); } chunk_starts.push_back(content_size); // Create threads to parse each chunk - std::vector>> thread_results(num_threads); + std::vector>> thread_results( + num_threads); std::vector threads; for (unsigned int i = 0; i < num_threads; ++i) { @@ -133,13 +135,13 @@ inline std::vector>> read_csv(co } // Wait for all threads to finish - for (auto& t : threads) { + for (auto &t : threads) { t.join(); } // Combine the results from all threads into the final result - for (const auto& local_result : thread_results) { - for (const auto& row : local_result) { + for (const auto &local_result : thread_results) { + for (const auto &row : local_result) { for (size_t i = 0; i < row.size(); ++i) { if (i < result.size()) { result[i].second.push_back(row[i]); @@ -158,8 +160,9 @@ inline std::vector>> read_csv(co * @param data A vector of pairs, each containing a column name and a vector of * data for that column. 
*/ -inline void write_csv(const std::string& filename, - const std::vector>>& data) { +inline void write_csv( + const std::string &filename, + const std::vector>> &data) { std::cout << "Writing CSV file: " << filename << std::endl; // Open the file for writing @@ -170,10 +173,10 @@ inline void write_csv(const std::string& filename, // Check that all columns have the same number of rows if (data.empty()) { - return; // Nothing to write + return; // Nothing to write } size_t num_rows = data[0].second.size(); - for (const auto& column : data) { + for (const auto &column : data) { if (column.second.size() != num_rows) { throw std::runtime_error("All columns must have the same number of rows"); } @@ -191,7 +194,7 @@ inline void write_csv(const std::string& filename, // Write the data rows for (size_t row = 0; row < num_rows; ++row) { for (size_t col = 0; col < data.size(); ++col) { - const std::string& field = data[col].second[row]; + const std::string &field = data[col].second[row]; // Handle CSV escaping std::string escaped_field = field; bool needs_quotes = false; @@ -204,7 +207,8 @@ inline void write_csv(const std::string& filename, pos += 2; } } - if (escaped_field.find(',') != std::string::npos || escaped_field.find('\n') != std::string::npos) { + if (escaped_field.find(',') != std::string::npos || + escaped_field.find('\n') != std::string::npos) { needs_quotes = true; } if (needs_quotes) { @@ -220,6 +224,6 @@ inline void write_csv(const std::string& filename, } } -} // namespace csv +} // namespace csv -#endif // CSV_READER_HPP +#endif // CSV_READER_HPP diff --git a/csrc/balance_serve/sched/utils/easy_format.hpp b/csrc/balance_serve/sched/utils/easy_format.hpp index d541410..ae60624 100644 --- a/csrc/balance_serve/sched/utils/easy_format.hpp +++ b/csrc/balance_serve/sched/utils/easy_format.hpp @@ -2,15 +2,14 @@ #include #include -template -std::string format_vector(const std::vector& v) { +template std::string format_vector(const std::vector &v) { std::ostringstream oss; if (v.empty()) return "[]"; for (size_t i = 0; i < v.size(); ++i) { oss << v[i]; if (i < v.size() - 1) - oss << ", "; // 逗号分隔 + oss << ", "; // 逗号分隔 } return oss.str(); } diff --git a/csrc/balance_serve/sched/utils/mpsc.hpp b/csrc/balance_serve/sched/utils/mpsc.hpp index 4a3b476..10aac0d 100644 --- a/csrc/balance_serve/sched/utils/mpsc.hpp +++ b/csrc/balance_serve/sched/utils/mpsc.hpp @@ -4,32 +4,31 @@ #include #include -template -class MPSCQueue { +template class MPSCQueue { struct Node { T data; - std::atomic next; + std::atomic next; Node() : next(nullptr) {} Node(T data_) : data(std::move(data_)), next(nullptr) {} }; - std::atomic head; - Node* tail; + std::atomic head; + Node *tail; - public: +public: std::atomic_size_t enqueue_count = 0; size_t dequeue_count = 0; MPSCQueue() { - Node* dummy = new Node(); + Node *dummy = new Node(); head.store(dummy, std::memory_order_seq_cst); tail = dummy; } ~MPSCQueue() { - Node* node = tail; + Node *node = tail; while (node) { - Node* next = node->next.load(std::memory_order_seq_cst); + Node *next = node->next.load(std::memory_order_seq_cst); delete node; node = next; } @@ -38,14 +37,14 @@ class MPSCQueue { // 生产者调用 void enqueue(T data) { enqueue_count.fetch_add(1); - Node* node = new Node(std::move(data)); - Node* prev_head = head.exchange(node, std::memory_order_seq_cst); + Node *node = new Node(std::move(data)); + Node *prev_head = head.exchange(node, std::memory_order_seq_cst); prev_head->next.store(node, std::memory_order_seq_cst); } // 消费者调用 std::optional dequeue() { - 
Node* next = tail->next.load(std::memory_order_seq_cst); + Node *next = tail->next.load(std::memory_order_seq_cst); if (next) { T res = std::move(next->data); delete tail; @@ -59,16 +58,16 @@ class MPSCQueue { size_t size() { return enqueue_count.load() - dequeue_count; } }; -template -class MPSCQueueConsumerLock { +template class MPSCQueueConsumerLock { MPSCQueue queue; std::counting_semaphore<> sema{0}; - public: +public: void enqueue(T data) { queue.enqueue(std::move(data)); - // std::atomic_thread_fence(std::memory_order_seq_cst);// Inserting this because the memory order might be wrong, I - // am also not that sure about this. + // std::atomic_thread_fence(std::memory_order_seq_cst);// Inserting this + // because the memory order might be wrong, I am also not that sure about + // this. sema.release(); } @@ -76,8 +75,10 @@ class MPSCQueueConsumerLock { auto re = queue.dequeue(); if (re.has_value()) { while (sema.try_acquire() == false) { - std::cerr << __FILE__ << ":" << __FUNCTION__ << " sema try acquire should be success, retrying, please check" - << std::endl; + std::cerr + << __FILE__ << ":" << __FUNCTION__ + << " sema try acquire should be success, retrying, please check" + << std::endl; // assert(false); } return re.value(); @@ -91,8 +92,10 @@ class MPSCQueueConsumerLock { auto re = queue.dequeue(); if (re.has_value()) { while (sema.try_acquire() == false) { - std::cerr << __FILE__ << ":" << __FUNCTION__ << " sema try acquire should be success, retrying, please check" - << std::endl; + std::cerr + << __FILE__ << ":" << __FUNCTION__ + << " sema try acquire should be success, retrying, please check" + << std::endl; // assert(false); } return re.value(); diff --git a/csrc/balance_serve/sched/utils/statistics.hpp b/csrc/balance_serve/sched/utils/statistics.hpp index 98e82a7..bcf5532 100644 --- a/csrc/balance_serve/sched/utils/statistics.hpp +++ b/csrc/balance_serve/sched/utils/statistics.hpp @@ -7,59 +7,71 @@ #include class Statistics { - public: +public: // Increment the counter for a given key by a specified value (default is 1) - void increment_counter(const std::string& key, int64_t value = 1) { counters_[key] += value; } + void increment_counter(const std::string &key, int64_t value = 1) { + counters_[key] += value; + } - int64_t& get_counter(const std::string& key) { return counters_[key]; } + int64_t &get_counter(const std::string &key) { return counters_[key]; } // Start the timer for a given key - void start_timer(const std::string& key) { active_timers_[key] = std::chrono::high_resolution_clock::now(); } + void start_timer(const std::string &key) { + active_timers_[key] = std::chrono::high_resolution_clock::now(); + } // Stop the timer for a given key and update the total time and count - void stop_timer(const std::string& key) { + void stop_timer(const std::string &key) { auto start_it = active_timers_.find(key); if (start_it != active_timers_.end()) { - auto duration = std::chrono::high_resolution_clock::now() - start_it->second; + auto duration = + std::chrono::high_resolution_clock::now() - start_it->second; timings_[key].total_time += duration; timings_[key].count += 1; active_timers_.erase(start_it); } else { // Handle error: stop_timer called without a matching start_timer - std::cerr << "Warning: stop_timer called for key '" << key << "' without a matching start_timer.\n"; + std::cerr << "Warning: stop_timer called for key '" << key + << "' without a matching start_timer.\n"; } } // Print out the collected statistical information void report() const { std::cout << 
"Counters:\n"; - for (const auto& kv : counters_) { + for (const auto &kv : counters_) { std::cout << " " << kv.first << ": " << kv.second << "\n"; } std::cout << "\nTimers:\n"; - for (const auto& kv : timings_) { + for (const auto &kv : timings_) { std::cout << " " << kv.first << ": count = " << kv.second.count << ", total_time = " << kv.second.total_time.count() << "s" - << ", average_time = " << (kv.second.count > 0 ? kv.second.total_time.count() / kv.second.count : 0) + << ", average_time = " + << (kv.second.count > 0 + ? kv.second.total_time.count() / kv.second.count + : 0) << "s\n"; } } - private: +private: // Mapping from key to counter std::unordered_map counters_; // Struct to hold timing information for a key struct TimingInfo { int64_t count = 0; - std::chrono::duration total_time = std::chrono::duration::zero(); + std::chrono::duration total_time = + std::chrono::duration::zero(); }; // Mapping from key to timing information std::unordered_map timings_; // Mapping from key to the start time of active timers - std::unordered_map active_timers_; + std::unordered_map + active_timers_; }; -#endif // STATISTICS_HPP +#endif // STATISTICS_HPP diff --git a/csrc/balance_serve/sched/utils/timer.hpp b/csrc/balance_serve/sched/utils/timer.hpp index 7ec4fc5..9cfb205 100644 --- a/csrc/balance_serve/sched/utils/timer.hpp +++ b/csrc/balance_serve/sched/utils/timer.hpp @@ -1,4 +1,5 @@ #pragma once +#include "readable_number.hpp" #include #include #include @@ -6,7 +7,6 @@ #include #include #include -#include "readable_number.hpp" inline std::string doubleToStringR2(double value) { std::stringstream stream; @@ -15,7 +15,7 @@ inline std::string doubleToStringR2(double value) { } class Timer { - public: +public: std::string name; bool tmp_timer = false; @@ -49,10 +49,14 @@ class Timer { endTime = m_endTime; } - return std::chrono::duration_cast(endTime - m_startTime).count(); + return std::chrono::duration_cast(endTime - + m_startTime) + .count(); } - void printElapsedMilliseconds() { std::cout << elapsedNs() / 1e6 << " ms" << std::endl; } + void printElapsedMilliseconds() { + std::cout << elapsedNs() / 1e6 << " ms" << std::endl; + } static std::string ns_to_string(double duration) { auto nano_sec = duration; @@ -100,13 +104,13 @@ class Timer { return readable_number(ops) + "op/s"; } - void merge(Timer& other) { + void merge(Timer &other) { assert(m_isRunning == false); assert(other.m_isRunning == false); m_runningNs += other.runningTimeNs(); } - private: +private: std::chrono::time_point m_startTime; std::chrono::time_point m_endTime; bool m_isRunning = false; @@ -114,14 +118,14 @@ class Timer { }; class Counter { - public: +public: Counter() {} std::map counters; - void inc(const char* name, size_t num) { counters[name] += num; }; + void inc(const char *name, size_t num) { counters[name] += num; }; void print() { - for (auto& p : counters) { + for (auto &p : counters) { std::cout << p.first << " : " << p.second << std::endl; } }; diff --git a/ktransformers/configs/model_configs.json b/ktransformers/configs/model_configs.json deleted file mode 100644 index 6ce80b0..0000000 --- a/ktransformers/configs/model_configs.json +++ /dev/null @@ -1,122 +0,0 @@ -{ - "DeepSeek-Coder-V2-Instruct": { - "hidden_size": 5120, - "intermediate_size": 12288, - "max_position_embeddings": 163840, - "model_type": "deepseek_v2", - "num_attention_heads": 128, - "num_hidden_layers": 60, - "num_key_value_heads": 128, - "vocab_size": 102400 - }, - "DeepSeek-R1": { - "hidden_size": 7168, - "intermediate_size": 18432, - 
"max_position_embeddings": 163840, - "model_type": "deepseek_v3", - "num_attention_heads": 128, - "num_hidden_layers": 61, - "num_key_value_heads": 128, - "vocab_size": 129280 - }, - "DeepSeek-V2-Lite-Chat": { - "hidden_size": 2048, - "intermediate_size": 10944, - "max_position_embeddings": 163840, - "model_type": "deepseek_v2", - "num_attention_heads": 16, - "num_hidden_layers": 27, - "num_key_value_heads": 16, - "vocab_size": 102400 - }, - "DeepSeek-V3": { - "hidden_size": 7168, - "intermediate_size": 18432, - "max_position_embeddings": 163840, - "model_type": "deepseek_v3", - "num_attention_heads": 128, - "num_hidden_layers": 3, - "num_key_value_heads": 128, - "vocab_size": 129280 - }, - "DeepSeek-V3-bf16": { - "hidden_size": 7168, - "intermediate_size": 18432, - "max_position_embeddings": 163840, - "model_type": "deepseek_v3", - "num_attention_heads": 128, - "num_hidden_layers": 61, - "num_key_value_heads": 128, - "vocab_size": 129280 - }, - "LLaMA-2-7B-32K": { - "hidden_size": 4096, - "intermediate_size": 11008, - "max_position_embeddings": 32768, - "model_type": "llama", - "num_attention_heads": 32, - "num_hidden_layers": 32, - "num_key_value_heads": 32, - "vocab_size": 32000 - }, - "Moonlight-16B-A3B-Instruct": { - "hidden_size": 2048, - "intermediate_size": 11264, - "max_position_embeddings": 8192, - "model_type": "deepseek_v3", - "num_attention_heads": 16, - "num_hidden_layers": 27, - "num_key_value_heads": 16, - "vocab_size": 163840 - }, - "Qwen2.5-32B-Instruct": { - "hidden_size": 5120, - "intermediate_size": 27648, - "max_position_embeddings": 32768, - "model_type": "qwen2", - "num_attention_heads": 40, - "num_hidden_layers": 64, - "num_key_value_heads": 8, - "vocab_size": 152064 - }, - "Qwen2.5-32B-Instruct-GPTQ-Int4": { - "hidden_size": 5120, - "intermediate_size": 27648, - "max_position_embeddings": 32768, - "model_type": "qwen2", - "num_attention_heads": 40, - "num_hidden_layers": 64, - "num_key_value_heads": 8, - "vocab_size": 152064 - }, - "Qwen2.5-7B-Instruct": { - "hidden_size": 3584, - "intermediate_size": 18944, - "max_position_embeddings": 32768, - "model_type": "qwen2", - "num_attention_heads": 28, - "num_hidden_layers": 28, - "num_key_value_heads": 4, - "vocab_size": 152064 - }, - "Qwen2.5-7B-Instruct-GPTQ-Int4": { - "hidden_size": 3584, - "intermediate_size": 18944, - "max_position_embeddings": 32768, - "model_type": "qwen2", - "num_attention_heads": 28, - "num_hidden_layers": 28, - "num_key_value_heads": 4, - "vocab_size": 152064 - }, - "qwen2-72b-instruct": { - "hidden_size": 8192, - "intermediate_size": 29568, - "max_position_embeddings": 32768, - "model_type": "qwen2", - "num_attention_heads": 64, - "num_hidden_layers": 80, - "num_key_value_heads": 8, - "vocab_size": 152064 - } -} \ No newline at end of file diff --git a/ktransformers/configs/quant_configs.json b/ktransformers/configs/quant_configs.json deleted file mode 100644 index 191df5a..0000000 --- a/ktransformers/configs/quant_configs.json +++ /dev/null @@ -1,57 +0,0 @@ -{ - "BF16": { - "block_element_count": 1, - "block_element_size": 2, - "bytes_per_element": 2.0, - "can_be_used_as_vector": true, - "has_min": false, - "has_scale": false, - "name": "BF16", - "reference": "", - "type_of_dot_vector": "BF16" - }, - "FP16": { - "block_element_count": 1, - "block_element_size": 2, - "bytes_per_element": 2.0, - "can_be_used_as_vector": true, - "has_min": false, - "has_scale": false, - "name": "FP16", - "reference": "", - "type_of_dot_vector": "FP16" - }, - "FP32": { - "block_element_count": 1, - 
"block_element_size": 4, - "bytes_per_element": 4.0, - "can_be_used_as_vector": true, - "has_min": false, - "has_scale": false, - "name": "FP32", - "reference": "", - "type_of_dot_vector": "FP32" - }, - "Q4_0": { - "block_element_count": 32, - "block_element_size": 18, - "bytes_per_element": 0.5625, - "can_be_used_as_vector": false, - "has_min": false, - "has_scale": true, - "name": "Q4_0", - "reference": "https://huggingface.co/docs/hub/gguf", - "type_of_dot_vector": "Q8_0" - }, - "Q8_0": { - "block_element_count": 32, - "block_element_size": 34, - "bytes_per_element": 1.0625, - "can_be_used_as_vector": true, - "has_min": false, - "has_scale": true, - "name": "Q8_0", - "reference": "https://huggingface.co/docs/hub/gguf", - "type_of_dot_vector": "Q8_0" - } -} \ No newline at end of file diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py index e60da10..0536ec9 100644 --- a/ktransformers/server/args.py +++ b/ktransformers/server/args.py @@ -70,6 +70,9 @@ class ArgumentParser: parser.add_argument("--batch_size", type=int, default=self.cfg.batch_size) parser.add_argument("--cache_lens", type=int, default=self.cfg.cache_lens) + # kvc2 config + parser.add_argument("--kvc2_config_dir", type=str, default=self.cfg.kvc2_config_dir) + # log configs # log level: debug, info, warn, error, crit parser.add_argument("--log_dir", type=str, default=self.cfg.log_dir) diff --git a/ktransformers/server/balance_serve/settings.py b/ktransformers/server/balance_serve/settings.py index 0a90a86..a79cdac 100644 --- a/ktransformers/server/balance_serve/settings.py +++ b/ktransformers/server/balance_serve/settings.py @@ -7,9 +7,7 @@ import sys, os import yaml, json from time import sleep -current_dir = os.path.dirname(__file__) -# sched_path = os.path.abspath(os.path.join(current_dir, '../../../build/balance_serve/sched')) -# sys.path.insert(0, sched_path) + import sched_ext from transformers import AutoConfig @@ -52,8 +50,7 @@ def create_sched_settings(args): settings.v_cache_on = False settings.kvc2_root_path = '/mnt/data/persist-kvc' - settings.kvc2_config_path = os.path.join(current_dir, "..", "..", "configs") - print(os.path.join(current_dir, "..", "..", "configs")) + settings.kvc2_config_path = args.kvc2_config_dir settings.memory_pool_size_GB = args.cpu_memory_size_GB settings.evict_count = 40 settings.kvc2_metrics_port = args.kvc2_metrics_port diff --git a/ktransformers/server/config/config.py b/ktransformers/server/config/config.py index f3d82a5..055be06 100644 --- a/ktransformers/server/config/config.py +++ b/ktransformers/server/config/config.py @@ -34,12 +34,15 @@ class Config(metaclass=Singleton): user_path: str = os.path.expanduser("~") localstore_path: str = os.path.join(user_path, ".ktransformers") + kvc2_config_dir = os.path.join(localstore_path, "kvc2") config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME) if not os.path.exists(config_yaml): print(f"Can't find config file, {config_yaml}") exit(-1) if not os.path.exists(localstore_path): os.mkdir(localstore_path) + if not os.path.exists(kvc2_config_dir): + os.mkdir(kvc2_config_dir) if not os.path.exists(config_path): shutil.copyfile(config_yaml, config_path) with open(config_path, "r", encoding="utf-8") as fp: @@ -62,10 +65,13 @@ class Config(metaclass=Singleton): self.localstore_path: str = os.path.join(self.user_path, ".ktransformers") # log configs self.log_dir = os.path.join(self.localstore_path, cfg["log"]["dir"]) + if not os.path.exists(self.log_dir): + os.mkdir(self.log_dir) self.log_file = 
cfg["log"]["file"] self.log_level = cfg["log"]["level"] self.backup_count = cfg["log"]["backup_count"] + self.kvc2_config_dir = os.path.join(self.localstore_path, "kvc2") # server configs self.server: dict = cfg.get("server", {}) self.server_ip = self.server.get("ip", "0.0.0.0")