mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-05 12:09:48 +00:00

format kvc2, delete quant_configs, move model_configs to ~/.ktransformers

This commit is contained in:
parent 9dd24ecd72
commit 64de784328

31 changed files with 853 additions and 878 deletions
@@ -35,23 +35,23 @@ struct ArrayStore {
     if (to <= size) {
       return;
     }
-    //TODO: extend file
+    // TODO: extend file
     size = to;
-    //LOG_INFO("Extend file to `, size `", to, size_in_bytes());
+    // LOG_INFO("Extend file to `, size `", to, size_in_bytes());
   }

   ArrayStore(size_t element_size, size_t size, std::filesystem::path data_path)
       : element_size(element_size),
         element_size_aligned((element_size + DeviceBlockSize - 1) / DeviceBlockSize),
         data_path(data_path) {
-    //TODO: prefix cache
+    // TODO: prefix cache
   }

   void read(size_t index, void* buffer) {
-    //TODO: read from file
+    // TODO: read from file
   }
   void write(size_t index, void* buffer) {
-    //TODO: write to file
+    // TODO: write to file
   }
 };
@@ -98,15 +98,15 @@ struct IODealerImpl {
   IODealerImpl(bool use_io_uring, int IO_DEPTH) : use_io_uring(use_io_uring), IO_DEPTH(IO_DEPTH) {}

   void queue_consumer() {
-    //TODO:
+    // TODO:
   }

   void io_perf() {
-    //TODO:
+    // TODO:
   }

   void io_dealer() {
-    //TODO:
+    // TODO:
   }
 };
@@ -130,7 +130,7 @@ void IODealer::stop() {
   if (io_impl->stop) {
     return;
   }
-  //LOG_INFO("Stopping IO Dealer");
+  // LOG_INFO("Stopping IO Dealer");
   io_impl->stop = true;
 }
@@ -77,7 +77,6 @@ GPUPageCache::GPUPageCache(GPUPageCacheConfig& config) : config(config) {
     gpu_only_occupations.resize(config.total_kvcache_pages, false);
   }

-
   num_free_pages = config.total_kvcache_pages;
   for (size_t i = 0; i < config.layer_count; i++) {
     if (config.k_cache_on)
@@ -248,18 +247,19 @@ void GPUPageCache::append_col_to_request(std::vector<std::shared_ptr<CudaStreamM
   auto gpu_block_idx = k_handles[0][at]->gpu_block_idx.value();
   for (size_t layer = 0; layer < config.layer_count; layer++) {
     for (size_t which_gpu = 0; which_gpu < config.gpu_devices_id.size(); which_gpu++) {

       if (config.k_cache_on) {
         assert(k_handles[layer][at]->data != nullptr);
         reqs[which_gpu]->sizes.push_back(tp_size[which_gpu]);
-        reqs[which_gpu]->host_mem_addresses.push_back(offset_by_bytes(k_handles[layer][at]->data, tp_offset[which_gpu]));
+        reqs[which_gpu]->host_mem_addresses.push_back(
+            offset_by_bytes(k_handles[layer][at]->data, tp_offset[which_gpu]));
         reqs[which_gpu]->device_mem_addresses.push_back(k_cache[which_gpu][layer][gpu_block_idx].data_ptr());
       }

       if (config.v_cache_on) {
         assert(v_handles[layer][at]->data != nullptr);
         reqs[which_gpu]->sizes.push_back(tp_size[which_gpu]);
-        reqs[which_gpu]->host_mem_addresses.push_back(offset_by_bytes(v_handles[layer][at]->data, tp_offset[which_gpu]));
+        reqs[which_gpu]->host_mem_addresses.push_back(
+            offset_by_bytes(v_handles[layer][at]->data, tp_offset[which_gpu]));
         reqs[which_gpu]->device_mem_addresses.push_back(v_cache[which_gpu][layer][gpu_block_idx].data_ptr());
       }
     }
@@ -1,16 +1,16 @@
 #pragma once

-#include "prometheus/counter.h"
-#include "prometheus/exposer.h"
-#include "prometheus/gauge.h"
-#include "prometheus/histogram.h"
-#include "prometheus/registry.h"
 #include <atomic>
 #include <chrono>
 #include <memory>
 #include <string>
 #include <thread>
 #include <vector>
+#include "prometheus/counter.h"
+#include "prometheus/exposer.h"
+#include "prometheus/gauge.h"
+#include "prometheus/histogram.h"
+#include "prometheus/registry.h"

 #include "utils/timer.hpp"
@@ -1,8 +1,8 @@
 #ifndef __MODEL_CONFIG_HPP_
 #define __MODEL_CONFIG_HPP_

-#include <iostream>
 #include "nlohmann/json.hpp"
+#include <iostream>

 #include <filesystem>
 #include <fstream>
@@ -13,7 +13,7 @@ using ModelName = std::string;

 // We must assure this can be load by config.json
 class ModelConfig {
- public:
+public:
   DimSize hidden_size;
   DimSize intermediate_size;
   size_t max_position_embeddings;
@@ -23,10 +23,13 @@ class ModelConfig {
   size_t num_key_value_heads;
   size_t vocab_size;

-  NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size, max_position_embeddings, model_type,
-                                 num_attention_heads, num_hidden_layers, num_key_value_heads, vocab_size);
+  NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size,
+                                 max_position_embeddings, model_type,
+                                 num_attention_heads, num_hidden_layers,
+                                 num_key_value_heads, vocab_size);

   void load_from(std::filesystem::path path) {
     std::cout << "Load from " << path << std::endl;
     std::ifstream i(path);
     nlohmann::json j;
     i >> j;
@@ -38,12 +41,14 @@ using QuantType = std::string;
 static const QuantType NoQuantType = "";

 class QuantConfig {
- public:
+public:
   QuantType name;

   // For GEMV
   QuantType type_of_dot_vector = NoQuantType;
-  inline bool can_be_used_as_matrix() { return type_of_dot_vector != NoQuantType; }
+  inline bool can_be_used_as_matrix() {
+    return type_of_dot_vector != NoQuantType;
+  }

   bool can_be_used_as_vector;
@@ -56,8 +61,11 @@ class QuantConfig {

   URL reference = "";

-  NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(QuantConfig, name, type_of_dot_vector, can_be_used_as_vector,
-                                              bytes_per_element, has_scale, has_min, block_element_count,
+  NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(QuantConfig, name,
+                                              type_of_dot_vector,
+                                              can_be_used_as_vector,
+                                              bytes_per_element, has_scale,
+                                              has_min, block_element_count,
                                               block_element_size, reference);
 };
@@ -65,14 +73,18 @@ inline std::map<QuantType, QuantConfig> quant_configs;
 inline std::map<ModelName, ModelConfig> model_configs;

 inline void load_quant_configs(std::filesystem::path path) {
-  std::cout << __FUNCTION__ << " from " << path << std::endl;
-  std::ifstream i(path);
   nlohmann::json j;
-  i >> j;
-  quant_configs = j.get<std::map<QuantType, QuantConfig>>();
-  std::cout << "Loaded Quant Configs" << std::endl;
-  for (auto& [k, v] : quant_configs) {
-    std::cout << " - " << k << std::endl;
+  if (std::filesystem::exists(path)) {
+    std::cout << __FUNCTION__ << " from " << path << std::endl;
+    std::ifstream i(path);
+    i >> j;
+    quant_configs = j.get<std::map<QuantType, QuantConfig>>();
+    std::cout << "Loaded Quant Configs" << std::endl;
+    for (auto &[k, v] : quant_configs) {
+      std::cout << " - " << k << std::endl;
+    }
+  } else {
+    std::cout << __FUNCTION__ << " no file at " << path << std::endl;
   }
 }
@@ -83,14 +95,18 @@ inline void dump_quant_configs(std::filesystem::path path) {
 }

 inline void load_model_configs(std::filesystem::path path) {
-  std::cout << __FUNCTION__ << " from " << path << std::endl;
-  std::ifstream i(path);
   nlohmann::json j;
-  i >> j;
-  model_configs = j.get<std::map<ModelName, ModelConfig>>();
-  std::cout << "Loaded Model Configs" << std::endl;
-  for (auto& [k, v] : model_configs) {
-    std::cout << " - " << k << std::endl;
+  if (std::filesystem::exists(path)) {
+    std::cout << __FUNCTION__ << " from " << path << std::endl;
+    std::ifstream i(path);
+    i >> j;
+    model_configs = j.get<std::map<ModelName, ModelConfig>>();
+    std::cout << "Loaded Model Configs" << std::endl;
+    for (auto &[k, v] : model_configs) {
+      std::cout << " - " << k << std::endl;
+    }
+  } else {
+    std::cout << __FUNCTION__ << " no file at " << path << std::endl;
   }
 }
@@ -17,13 +17,14 @@ PageAlignedMemoryPool::PageAlignedMemoryPool(size_t size_in_bytes) {
   assert(total_pages >= Blocks);
   page_per_block = total_pages / Blocks;

-  for (size_t block_index = 0; block_index < Blocks; block_index ++) {
-    first_page[block_index] = reinterpret_cast<void*>(reinterpret_cast<intptr_t>(data) + static_cast<intptr_t>(block_index) * page_per_block * PageSize);
+  for (size_t block_index = 0; block_index < Blocks; block_index++) {
+    first_page[block_index] = reinterpret_cast<void*>(reinterpret_cast<intptr_t>(data) +
+                                                      static_cast<intptr_t>(block_index) * page_per_block * PageSize);
     count_page[block_index] =
-        block_index == Blocks - 1 ? (total_pages - page_per_block * (Blocks - 1)) : page_per_block;
-    SPDLOG_DEBUG("first_page[{}] = {}, count_page[{}] = {}",
-                 block_index, reinterpret_cast<intptr_t>(first_page[block_index]) - reinterpret_cast<intptr_t>(data),
-                 block_index, count_page[block_index]);
+        block_index == Blocks - 1 ? (total_pages - page_per_block * (Blocks - 1)) : page_per_block;
+    SPDLOG_DEBUG("first_page[{}] = {}, count_page[{}] = {}", block_index,
+                 reinterpret_cast<intptr_t>(first_page[block_index]) - reinterpret_cast<intptr_t>(data), block_index,
+                 count_page[block_index]);
     bitmap[block_index].resize(count_page[block_index], 0);
   }
   SPDLOG_INFO("PageAlignedMemoryPool with size {} Mbytes, {} pages", total_size / (1 << 20), page_count());
@@ -53,7 +54,7 @@ void* PageAlignedMemoryPool::alloc_in_block(size_t block_index, size_t alloc_siz
   size_t free_pages = 0;
   for (size_t i = 0; i < count_page[block_index]; i++) {
     if (bitmap[block_index][i] == 0) {
-      free_pages ++;
+      free_pages++;
       if (free_pages == alloc_size) {
         size_t page_index = i + 1 - free_pages;
         for (size_t page = page_index; page < page_index + alloc_size; page++) {
@@ -73,7 +74,7 @@ void* PageAlignedMemoryPool::alloc_in_block(size_t block_index, size_t alloc_siz
 void* PageAlignedMemoryPool::alloc(size_t size) {
   size_t alloc_size = div_up(size, PageSize);
   auto cnt = now_block.fetch_add(1, std::memory_order_relaxed);
-  for (size_t i = 0; i < Blocks; i ++) {
+  for (size_t i = 0; i < Blocks; i++) {
     auto result = alloc_in_block((i + cnt) % Blocks, alloc_size);
     if (result != nullptr) {
       allocated.fetch_add(alloc_size * PageSize, std::memory_order_relaxed);
@@ -119,5 +120,6 @@ void PageAlignedMemoryPool::defragment() {}
 /// Debug print
 std::string PageAlignedMemoryPool::debug() {
   return fmt::format("PageAlignedMemoryPool: total_size: {}MB, allocated: {}, alloc/free count: {}/{}\n",
-                     readable_number(total_size), readable_number(size_t(allocated)), size_t(alloc_count), size_t(free_count));
+                     readable_number(total_size), readable_number(size_t(allocated)), size_t(alloc_count),
+                     size_t(free_count));
 }
@@ -1,12 +1,12 @@
 #pragma once

-#include <algorithm>  // std::sort
-#include <cstddef>    // size_t
-#include <mutex>      // std::mutex
-#include <vector>
 #include <assert.h>
-#include <bitset>
+#include <algorithm>  // std::sort
+#include <atomic>
+#include <bitset>
+#include <cstddef>  // size_t
+#include <mutex>    // std::mutex
+#include <vector>

 constexpr size_t PageSize = 4096;
@@ -26,10 +26,11 @@ struct PageAlignedMemoryPool {

   std::mutex lock[Blocks];
   size_t page_per_block = 0;
-  void *first_page[Blocks];
+  void* first_page[Blocks];
   size_t count_page[Blocks];
   std::vector<int8_t> bitmap[Blocks];
   void* alloc_in_block(size_t block_index, size_t alloc_size);
+
 public:
   /// Constructor and destructor
   explicit PageAlignedMemoryPool(size_t size_in_bytes);
@@ -339,7 +339,7 @@ struct Prefix {

   void update_location(CacheInfo info, Location location) { locations.location_map[info] = location; }

-  Prefix* to_first_prefix_without_disk_locations(CacheInfo k_info/*, CacheInfo v_info*/) {  // just k_info
+  Prefix* to_first_prefix_without_disk_locations(CacheInfo k_info /*, CacheInfo v_info*/) {  // just k_info
     auto now_prefix = this;
     while (now_prefix->prev != nullptr) {
       auto& prev = now_prefix->prev;
@@ -561,7 +561,7 @@ struct PrefixTree {
     if (need_lock) {
       sl = std::shared_lock<std::shared_mutex>(rw_lock);
     }
-    //TODO: prefix cache
+    // TODO: prefix cache
   }

   PrefixMatch look_up_or_insert(Token* data, TokenLength length) {
@@ -579,7 +579,6 @@ struct PrefixTree {
     return re;
   }

-
   std::shared_ptr<Prefix> new_prefix_node(Prefix* prev, TokenLength prev_match_length, Token* data, TokenLength length,
                                           bool need_lock = true) {
     std::unique_lock<std::shared_mutex> ul;
@@ -700,9 +699,7 @@ struct DoubleCacheHandle : public DoubleCacheHandleInterface {
       }
     }
   }
-  std::vector<MatchStatus> matched_status() override {
-    assert(false);
-  }
+  std::vector<MatchStatus> matched_status() override { assert(false); }

   bool any_match() {
     if (enable_alt) {
@@ -1066,7 +1063,6 @@ struct DoubleCacheHandle : public DoubleCacheHandleInterface {
 };

 struct KVC2 : KVC2Interface {
-
   KVC2Config config;
   std::shared_ptr<Metrics> met;
|
|||
re->kvc2_top = this;
|
||||
SPDLOG_DEBUG("Lookup TokenLength {}", length);
|
||||
if (config.gpu_only == false) {
|
||||
//TODO:
|
||||
// TODO:
|
||||
}
|
||||
return re;
|
||||
};
|
||||
|
@@ -1694,9 +1690,11 @@ void GPUPageCache::gpu_background_flush() {
       if (col_uls.empty())
         continue;
       for (size_t l = 0; l < config.layer_count; l++) {
-        if (config.k_cache_on && (occupations[l][i]->gpu_cc.dirty.load() == false || occupations[l][i]->cpu_cc.dirty.load()))
+        if (config.k_cache_on &&
+            (occupations[l][i]->gpu_cc.dirty.load() == false || occupations[l][i]->cpu_cc.dirty.load()))
           goto next_gpu_page;
-        if (config.v_cache_on && (v_occupations[l][i]->gpu_cc.dirty.load() == false || v_occupations[l][i]->cpu_cc.dirty.load()))
+        if (config.v_cache_on &&
+            (v_occupations[l][i]->gpu_cc.dirty.load() == false || v_occupations[l][i]->cpu_cc.dirty.load()))
           goto next_gpu_page;
       }
@@ -139,18 +139,18 @@ std::vector<Token> random_ids(size_t length, std::mt19937& gen) {
   return re;
 }

-std::vector<layer_data> slice(std::vector<layer_data>& h1,size_t start,size_t end){
+std::vector<layer_data> slice(std::vector<layer_data>& h1, size_t start, size_t end) {
   std::vector<layer_data> re;
-  for(auto&l:h1){
+  for (auto& l : h1) {
     layer_data new_layer;
-    new_layer.insert(new_layer.end(),l.begin()+start,l.begin()+end);
+    new_layer.insert(new_layer.end(), l.begin() + start, l.begin() + end);
     re.push_back(new_layer);
   }
   return re;
 }

 void cmp_handle_data(std::vector<layer_data> h1, std::vector<layer_data> h2,
-                      std::optional<size_t> blocks = std::nullopt) {
+                     std::optional<size_t> blocks = std::nullopt) {
   assert(h1.size() == h2.size());

   for (size_t i = 0; i < h1.size(); i++) {
@@ -7,9 +7,9 @@ int main(int argc, char* argv[]) {
   config.gpu_cache_config->total_kvcache_pages = 12;
   auto kvc2 = kvc2::create_kvc2(config);

-// #pragma omp parallel for
+  // #pragma omp parallel for
   for (size_t ti = 0; ti < 2; ti++) {
-    SPDLOG_WARN("Test {}",ti);
+    SPDLOG_WARN("Test {}", ti);
     auto [kcache, vcache] = kvc2->get_kvcache();
     std::mt19937 gen(ti + 123);
     size_t total_page = 10;
@@ -11,7 +11,6 @@
 #include "common.hpp"

-
 int main(int argc, char* argv[]) {
   qw25_7B_gpu_config.v_cache_on = false;
   config.gpu_cache_config = qw25_7B_gpu_config;
   config.v_cache_on = false;
@@ -1,16 +1,15 @@
-
-#include <unistd.h>
 #include <iostream>
-#include <random>
 #include <thread>
 #include <vector>
+#include <random>
+#include <unistd.h>
 #include "page_aligned_memory_pool.cpp"

 #define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG
 #define FMT_HEADER_ONLY
 #include "spdlog/spdlog.h"


 // Task run by each thread
 void thread_task(PageAlignedMemoryPool& pool) {
   std::mt19937 gen(123);
@@ -22,8 +21,8 @@ void thread_task(PageAlignedMemoryPool& pool) {
     void* ptr = pool.alloc(size);
     // SPDLOG_DEBUG(pool.debug());
     if (ptr) {
-        pool.free(ptr, size);
-        // allocated.push_back({ptr, size});
+      pool.free(ptr, size);
+      // allocated.push_back({ptr, size});
     }
     // sleep((int)(gen() % 1000) / 1000.0);
   }
|
|||
int main(int argc, char* argv[]) {
|
||||
spdlog::set_level(spdlog::level::debug);
|
||||
|
||||
|
||||
// 创建一个内存池
|
||||
PageAlignedMemoryPool pool(40ll * 1024 * 1024 * 1024); // 40 G
|
||||
PageAlignedMemoryPool pool(40ll * 1024 * 1024 * 1024); // 40 G
|
||||
|
||||
// 创建线程
|
||||
const int num_threads = 32;
|
||||
std::vector<std::thread> threads;
|
||||
for (int i = 0; i < num_threads; ++i) {
|
||||
threads.emplace_back(thread_task, std::ref(pool));
|
||||
threads.emplace_back(thread_task, std::ref(pool));
|
||||
}
|
||||
|
||||
// 等待所有线程完成
|
||||
for (auto& t : threads) {
|
||||
t.join();
|
||||
t.join();
|
||||
}
|
||||
|
||||
// 输出调试信息
|
||||
|
|
|
@@ -1,171 +1,163 @@
-#include "utils/periodic_task.hpp"
-#include <chrono>
-#include <cstdio>
-#include <iostream>
-#include <thread>
-#include <future>
 #include <atomic>
 #include <cassert>
+#include <chrono>
+#include <cstdio>
+#include <future>
+#include <iostream>
+#include <thread>
+#include "utils/periodic_task.hpp"

 // 1. Whether the task executes as expected
 void testPeriodicTaskExecution() {
-    std::atomic<int> execution_count{0};
-    auto task = [&execution_count]() {
-        execution_count++;
-    };
+  std::atomic<int> execution_count{0};
+  auto task = [&execution_count]() { execution_count++; };

-    periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(50));
+  periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(50));

-    std::this_thread::sleep_for(std::chrono::seconds(2));
+  std::this_thread::sleep_for(std::chrono::seconds(2));

-    assert(execution_count >= 20);  // make sure the task ran at least 20 times
-    std::cout << "Test 1 passed: Task executed periodically." << std::endl;
-    std::cout << "Task executed " << execution_count.load() << " times." << std::endl;
+  assert(execution_count >= 20); // make sure the task ran at least 20 times
+  std::cout << "Test 1 passed: Task executed periodically." << std::endl;
+  std::cout << "Task executed " << execution_count.load() << " times." << std::endl;
 }

 // 2. Waking the task up early
 void testWakeUpImmediately() {
-    std::atomic<int> execution_count{0};
-    auto task = [&execution_count]() {
-        execution_count++;
-    };
+  std::atomic<int> execution_count{0};
+  auto task = [&execution_count]() { execution_count++; };

-    periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200));
+  periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200));

-    // Wake the task up early
-    periodic_task.wakeUp();
-    std::this_thread::sleep_for(std::chrono::milliseconds(50));  // wait for the task to run
+  // Wake the task up early
+  periodic_task.wakeUp();
+  std::this_thread::sleep_for(std::chrono::milliseconds(50)); // wait for the task to run

-    std::cout << "Execution count after wakeUp: " << execution_count.load() << std::endl;
-    assert(execution_count == 1);  // make sure the task ran immediately
-    std::cout << "Test 2 passed: Task woke up immediately." << std::endl;
+  std::cout << "Execution count after wakeUp: " << execution_count.load() << std::endl;
+  assert(execution_count == 1); // make sure the task ran immediately
+  std::cout << "Test 2 passed: Task woke up immediately." << std::endl;
 }

 // 3. Waiting behavior of wakeUpWait()
 void testWakeUpWait() {
-    std::promise<void> promise;
-    std::future<void> future = promise.get_future();
-    auto task = [&promise]() {
-        std::this_thread::sleep_for(std::chrono::milliseconds(100));  // simulate task execution
-        promise.set_value();  // set the promise when the task completes
-    };
+  std::promise<void> promise;
+  std::future<void> future = promise.get_future();
+  auto task = [&promise]() {
+    std::this_thread::sleep_for(std::chrono::milliseconds(100)); // simulate task execution
+    promise.set_value(); // set the promise when the task completes
+  };

-    periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200));
+  periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200));

-    // Call wakeUpWait and wait for the task to finish
-    std::future<void> wakeup_future = periodic_task.wakeUpWait();
-    wakeup_future.wait();  // wait for the task to finish
+  // Call wakeUpWait and wait for the task to finish
+  std::future<void> wakeup_future = periodic_task.wakeUpWait();
+  wakeup_future.wait(); // wait for the task to finish

-    assert(wakeup_future.valid());  // make sure the future is valid
-    std::cout << "Test 3 passed: wakeUpWait() works correctly." << std::endl;
-    std::cout << "wakeUpWait() future is valid." << std::endl;
+  assert(wakeup_future.valid()); // make sure the future is valid
+  std::cout << "Test 3 passed: wakeUpWait() works correctly." << std::endl;
+  std::cout << "wakeUpWait() future is valid." << std::endl;
 }

 // 4. Handling exceptions thrown by the task
 void testTaskExceptionHandling() {
-    auto task = []() {
-        throw std::runtime_error("Test exception");
-    };
+  auto task = []() { throw std::runtime_error("Test exception"); };

-    periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200));
+  periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200));

-    std::this_thread::sleep_for(std::chrono::milliseconds(300));  // wait for a while
+  std::this_thread::sleep_for(std::chrono::milliseconds(300)); // wait for a while

-    std::cout << "Test 4 passed: Task exception is handled correctly." << std::endl;
-    std::cout << "Exception handled and task did not crash." << std::endl;
+  std::cout << "Test 4 passed: Task exception is handled correctly." << std::endl;
+  std::cout << "Exception handled and task did not crash." << std::endl;
 }

 // 5. Whether the thread stops correctly
 void testTaskStop() {
-    std::atomic<bool> stopped{false};
-    auto task = [&stopped]() {
-        while (!stopped) {
-            std::this_thread::sleep_for(std::chrono::milliseconds(50));
-        }
-    };
+  std::atomic<bool> stopped{false};
+  auto task = [&stopped]() {
+    while (!stopped) {
+      std::this_thread::sleep_for(std::chrono::milliseconds(50));
+    }
+  };

-    periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(100));
+  periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(100));

-    std::this_thread::sleep_for(std::chrono::seconds(1));  // run for a while
+  std::this_thread::sleep_for(std::chrono::seconds(1)); // run for a while

-    stopped = true;  // request stop
-    std::this_thread::sleep_for(std::chrono::milliseconds(50));  // wait for the thread to stop
+  stopped = true; // request stop
+  std::this_thread::sleep_for(std::chrono::milliseconds(50)); // wait for the thread to stop

-    std::cout << "Test 5 passed: Task thread stops correctly." << std::endl;
-    std::cout << "Task has been stopped successfully." << std::endl;
+  std::cout << "Test 5 passed: Task thread stops correctly." << std::endl;
+  std::cout << "Task has been stopped successfully." << std::endl;
 }

 // 6. Whether the task runs correctly under high-frequency wake-ups
 void testHighFrequencyWakeUp() {
-    std::atomic<int> execution_count{0};
-    auto task = [&execution_count]() {
-        execution_count++;
-    };
+  std::atomic<int> execution_count{0};
+  auto task = [&execution_count]() { execution_count++; };

-    periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200));
+  periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200));

-    for (int i = 0; i < 100; ++i) {
-        periodic_task.wakeUp();
-        std::this_thread::sleep_for(std::chrono::milliseconds(10));  // wake up every 10 ms
-    }
+  for (int i = 0; i < 100; ++i) {
+    periodic_task.wakeUp();
+    std::this_thread::sleep_for(std::chrono::milliseconds(10)); // wake up every 10 ms
+  }

-    std::this_thread::sleep_for(std::chrono::seconds(1));  // wait for task execution to finish
+  std::this_thread::sleep_for(std::chrono::seconds(1)); // wait for task execution to finish

-    assert(execution_count > 50);  // make sure the task ran at least 50 times
-    std::cout << "Test 6 passed: Task handles frequent wake ups correctly." << std::endl;
-    std::cout << "Task executed " << execution_count.load() << " times." << std::endl;
+  assert(execution_count > 50); // make sure the task ran at least 50 times
+  std::cout << "Test 6 passed: Task handles frequent wake ups correctly." << std::endl;
+  std::cout << "Task executed " << execution_count.load() << " times." << std::endl;
 }

 // 7. Handling multiple wakeUpWait() calls
 void testMultipleWakeUpWait() {
-    std::atomic<int> execution_count{0};
-    auto task = [&execution_count]() {
-        std::this_thread::sleep_for(std::chrono::milliseconds(100));  // simulate task execution
-        execution_count++;
-    };
+  std::atomic<int> execution_count{0};
+  auto task = [&execution_count]() {
+    std::this_thread::sleep_for(std::chrono::milliseconds(100)); // simulate task execution
+    execution_count++;
+  };

-    periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200));
+  periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200));

-    // Call two wakeUpWait at the same time
-    std::future<void> future1 = periodic_task.wakeUpWait();
-    std::future<void> future2 = periodic_task.wakeUpWait();
+  // Call two wakeUpWait at the same time
+  std::future<void> future1 = periodic_task.wakeUpWait();
+  std::future<void> future2 = periodic_task.wakeUpWait();

-    future1.wait();
-    future2.wait();
+  future1.wait();
+  future2.wait();

-    assert(execution_count == 1);  // make sure the task ran only once
-    std::cout << "Test 7 passed: Multiple wakeUpWait() calls are handled correctly." << std::endl;
-    std::cout << "Task executed " << execution_count.load() << " times." << std::endl;
+  assert(execution_count == 1); // make sure the task ran only once
+  std::cout << "Test 7 passed: Multiple wakeUpWait() calls are handled correctly." << std::endl;
+  std::cout << "Task executed " << execution_count.load() << " times." << std::endl;
 }

 // 8. Edge case of an empty task function
 void testEmptyTaskFunction() {
-    auto task = []() {
-        // empty task function
-    };
+  auto task = []() {
+    // empty task function
+  };

-    periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(100));
+  periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(100));

-    std::this_thread::sleep_for(std::chrono::seconds(1));  // wait for a while
+  std::this_thread::sleep_for(std::chrono::seconds(1)); // wait for a while

-    std::cout << "Test 8 passed: Empty task function works correctly." << std::endl;
-    std::cout << "Empty task function executed without issues." << std::endl;
+  std::cout << "Test 8 passed: Empty task function works correctly." << std::endl;
+  std::cout << "Empty task function executed without issues." << std::endl;
 }

 int main() {
-    std::cout << "Starting tests..." << std::endl;
+  std::cout << "Starting tests..." << std::endl;

-    // testWakeUpImmediately();
-    testPeriodicTaskExecution();
-    testWakeUpImmediately();
-    testWakeUpWait();
-    testTaskExceptionHandling();
-    testTaskStop();
-    testHighFrequencyWakeUp();
-    testMultipleWakeUpWait();
-    testEmptyTaskFunction();
+  // testWakeUpImmediately();
+  testPeriodicTaskExecution();
+  testWakeUpImmediately();
+  testWakeUpWait();
+  testTaskExceptionHandling();
+  testTaskStop();
+  testHighFrequencyWakeUp();
+  testMultipleWakeUpWait();
+  testEmptyTaskFunction();

-    std::cout << "All tests passed!" << std::endl;
+  std::cout << "All tests passed!" << std::endl;

-    return 0;
+  return 0;
 }
@@ -1,8 +1,8 @@
-#include "scheduler.h"
-#include <memory>
 #include <pybind11/numpy.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
+#include <memory>
+#include "scheduler.h"

 #include <torch/extension.h>
@@ -16,19 +16,25 @@ PYBIND11_MODULE(sched_ext, m) {
       .def_readwrite("layer_count", &scheduler::ModelSettings::layer_count)
       .def_readwrite("num_k_heads", &scheduler::ModelSettings::num_k_heads)
       .def_readwrite("k_head_dim", &scheduler::ModelSettings::k_head_dim)
-      .def_readwrite("bytes_per_params", &scheduler::ModelSettings::bytes_per_params)
-      .def_readwrite("bytes_per_kv_cache_element", &scheduler::ModelSettings::bytes_per_kv_cache_element)
+      .def_readwrite("bytes_per_params",
+                     &scheduler::ModelSettings::bytes_per_params)
+      .def_readwrite("bytes_per_kv_cache_element",
+                     &scheduler::ModelSettings::bytes_per_kv_cache_element)
       .def("params_size", &scheduler::ModelSettings::params_nbytes)
-      .def("bytes_per_token_kv_cache", &scheduler::ModelSettings::bytes_per_token_kv_cache)
+      .def("bytes_per_token_kv_cache",
+           &scheduler::ModelSettings::bytes_per_token_kv_cache)
       // Add pickle support
       .def(py::pickle(
-          [](const scheduler::ModelSettings& self) {  // __getstate__
-            return py::make_tuple(self.params_count, self.layer_count, self.num_k_heads, self.k_head_dim,
-                                  self.bytes_per_params, self.bytes_per_kv_cache_element);
+          [](const scheduler::ModelSettings &self) { // __getstate__
+            return py::make_tuple(self.params_count, self.layer_count,
+                                  self.num_k_heads, self.k_head_dim,
+                                  self.bytes_per_params,
+                                  self.bytes_per_kv_cache_element);
           },
-          [](py::tuple t) {  // __setstate__
+          [](py::tuple t) { // __setstate__
             if (t.size() != 6)
-              throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
+              throw std::runtime_error("Invalid state! t.size() = " +
+                                       std::to_string(t.size()));
             scheduler::ModelSettings ms;
             ms.params_count = t[0].cast<size_t>();
             ms.layer_count = t[1].cast<size_t>();
@@ -40,22 +46,24 @@ PYBIND11_MODULE(sched_ext, m) {
       }));

   py::class_<scheduler::SampleOptions>(m, "SampleOptions")
-    .def(py::init<>())
-    .def_readwrite("temperature", &scheduler::SampleOptions::temperature)
-    .def_readwrite("top_p", &scheduler::SampleOptions::top_p)  // make sure top_p is accessible too
-    .def(py::pickle(
-        [](const scheduler::SampleOptions& self) {
-          return py::make_tuple(self.temperature, self.top_p);  // serialize temperature and top_p
-        },
-        [](py::tuple t) {
-          if (t.size() != 2)  // make sure the argument count matches when unpacking
-            throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
+      .def(py::init<>())
+      .def_readwrite("temperature", &scheduler::SampleOptions::temperature)
+      .def_readwrite("top_p",
+                     &scheduler::SampleOptions::top_p) // make sure top_p is accessible too
+      .def(py::pickle(
+          [](const scheduler::SampleOptions &self) {
+            return py::make_tuple(self.temperature,
+                                  self.top_p); // serialize temperature and top_p
+          },
+          [](py::tuple t) {
+            if (t.size() != 2) // make sure the argument count matches when unpacking
+              throw std::runtime_error("Invalid state! t.size() = " +
+                                       std::to_string(t.size()));
             scheduler::SampleOptions so;
             so.temperature = t[0].cast<double>();
-            so.top_p = t[1].cast<double>();  // deserialize top_p
+            so.top_p = t[1].cast<double>(); // deserialize top_p
             return so;
-          }
-          ));
+          }));

   py::class_<scheduler::Settings>(m, "Settings")
       .def(py::init<>())
@@ -65,33 +73,43 @@ PYBIND11_MODULE(sched_ext, m) {
       .def_readwrite("page_size", &scheduler::Settings::page_size)
       .def_readwrite("gpu_device_id", &scheduler::Settings::gpu_device_id)
      .def_readwrite("gpu_memory_size", &scheduler::Settings::gpu_memory_size)
-      .def_readwrite("memory_utilization_percentage", &scheduler::Settings::memory_utilization_percentage)
+      .def_readwrite("memory_utilization_percentage",
+                     &scheduler::Settings::memory_utilization_percentage)
       .def_readwrite("max_batch_size", &scheduler::Settings::max_batch_size)
-      .def_readwrite("recommended_chunk_prefill_token_count",
-                     &scheduler::Settings::recommended_chunk_prefill_token_count)
+      .def_readwrite(
+          "recommended_chunk_prefill_token_count",
+          &scheduler::Settings::recommended_chunk_prefill_token_count)
       .def_readwrite("sample_options", &scheduler::Settings::sample_options)
-      .def_readwrite("sched_metrics_port", &scheduler::Settings::sched_metrics_port)
+      .def_readwrite("sched_metrics_port",
+                     &scheduler::Settings::sched_metrics_port)
       .def_readwrite("gpu_only", &scheduler::Settings::gpu_only)
-      .def_readwrite("use_self_defined_head_dim", &scheduler::Settings::use_self_defined_head_dim)
-      .def_readwrite("self_defined_head_dim", &scheduler::Settings::self_defined_head_dim)
-      .def_readwrite("full_kv_cache_on_each_gpu", &scheduler::Settings::full_kv_cache_on_each_gpu)
+      .def_readwrite("use_self_defined_head_dim",
+                     &scheduler::Settings::use_self_defined_head_dim)
+      .def_readwrite("self_defined_head_dim",
+                     &scheduler::Settings::self_defined_head_dim)
+      .def_readwrite("full_kv_cache_on_each_gpu",
+                     &scheduler::Settings::full_kv_cache_on_each_gpu)
       .def_readwrite("k_cache_on", &scheduler::Settings::k_cache_on)
       .def_readwrite("v_cache_on", &scheduler::Settings::v_cache_on)
       .def_readwrite("kvc2_config_path", &scheduler::Settings::kvc2_config_path)
       .def_readwrite("kvc2_root_path", &scheduler::Settings::kvc2_root_path)
-      .def_readwrite("memory_pool_size_GB", &scheduler::Settings::memory_pool_size_GB)
+      .def_readwrite("memory_pool_size_GB",
+                     &scheduler::Settings::memory_pool_size_GB)
       .def_readwrite("evict_count", &scheduler::Settings::evict_count)
       .def_readwrite("strategy_name", &scheduler::Settings::strategy_name)
-      .def_readwrite("kvc2_metrics_port", &scheduler::Settings::kvc2_metrics_port)
+      .def_readwrite("kvc2_metrics_port",
+                     &scheduler::Settings::kvc2_metrics_port)
       .def_readwrite("load_from_disk", &scheduler::Settings::load_from_disk)
       .def_readwrite("save_to_disk", &scheduler::Settings::save_to_disk)
       // derived
       .def_readwrite("gpu_device_count", &scheduler::Settings::gpu_device_count)
-      .def_readwrite("total_kvcache_pages", &scheduler::Settings::total_kvcache_pages)
+      .def_readwrite("total_kvcache_pages",
+                     &scheduler::Settings::total_kvcache_pages)
       .def_readwrite("devices", &scheduler::Settings::devices)
       .def("auto_derive", &scheduler::Settings::auto_derive);

-  py::class_<scheduler::BatchQueryTodo, std::shared_ptr<scheduler::BatchQueryTodo>>(m, "BatchQueryTodo")
+  py::class_<scheduler::BatchQueryTodo,
+             std::shared_ptr<scheduler::BatchQueryTodo>>(m, "BatchQueryTodo")
       .def(py::init<>())
       .def_readwrite("query_ids", &scheduler::BatchQueryTodo::query_ids)
       .def_readwrite("query_tokens", &scheduler::BatchQueryTodo::query_tokens)
@@ -99,31 +117,42 @@ PYBIND11_MODULE(sched_ext, m) {
       .def_readwrite("block_indexes", &scheduler::BatchQueryTodo::block_indexes)
       .def_readwrite("attn_masks", &scheduler::BatchQueryTodo::attn_masks)
       .def_readwrite("rope_ranges", &scheduler::BatchQueryTodo::rope_ranges)
-      .def_readwrite("sample_options", &scheduler::BatchQueryTodo::sample_options)
-      .def_readwrite("prefill_mini_batches", &scheduler::BatchQueryTodo::prefill_mini_batches)
-      .def_readwrite("decode_mini_batches", &scheduler::BatchQueryTodo::decode_mini_batches)
+      .def_readwrite("sample_options",
+                     &scheduler::BatchQueryTodo::sample_options)
+      .def_readwrite("prefill_mini_batches",
+                     &scheduler::BatchQueryTodo::prefill_mini_batches)
+      .def_readwrite("decode_mini_batches",
+                     &scheduler::BatchQueryTodo::decode_mini_batches)
       .def_readwrite("stop_criteria", &scheduler::BatchQueryTodo::stop_criteria)
       .def("debug", &scheduler::BatchQueryTodo::debug)
       .def(py::pickle(
-          [](const scheduler::BatchQueryTodo& self) {
-            return py::make_tuple(self.query_ids, self.query_tokens, self.query_lengths, self.block_indexes,
-                                  self.attn_masks, self.rope_ranges, self.sample_options, self.prefill_mini_batches,
-                                  self.decode_mini_batches, self.stop_criteria);
+          [](const scheduler::BatchQueryTodo &self) {
+            return py::make_tuple(
+                self.query_ids, self.query_tokens, self.query_lengths,
+                self.block_indexes, self.attn_masks, self.rope_ranges,
+                self.sample_options, self.prefill_mini_batches,
+                self.decode_mini_batches, self.stop_criteria);
           },
           [](py::tuple t) {
             if (t.size() != 10)
-              throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
+              throw std::runtime_error("Invalid state! t.size() = " +
+                                       std::to_string(t.size()));
             scheduler::BatchQueryTodo bqt;
             bqt.query_ids = t[0].cast<std::vector<scheduler::QueryID>>();
             bqt.query_tokens = t[1].cast<std::vector<torch::Tensor>>();
-            bqt.query_lengths = t[2].cast<std::vector<scheduler::TokenLength>>();
+            bqt.query_lengths =
+                t[2].cast<std::vector<scheduler::TokenLength>>();
             bqt.block_indexes = t[3].cast<std::vector<torch::Tensor>>();
             bqt.attn_masks = t[4].cast<std::optional<torch::Tensor>>();
             bqt.rope_ranges = t[5].cast<std::optional<torch::Tensor>>();
-            bqt.sample_options = t[6].cast<std::vector<scheduler::SampleOptions>>();
-            bqt.prefill_mini_batches = t[7].cast<std::vector<scheduler::PrefillTask>>();
-            bqt.decode_mini_batches = t[8].cast<std::vector<std::vector<scheduler::QueryID>>>();
-            bqt.stop_criteria = t[9].cast<std::vector<std::vector<std::vector<int>>>>();
+            bqt.sample_options =
+                t[6].cast<std::vector<scheduler::SampleOptions>>();
+            bqt.prefill_mini_batches =
+                t[7].cast<std::vector<scheduler::PrefillTask>>();
+            bqt.decode_mini_batches =
+                t[8].cast<std::vector<std::vector<scheduler::QueryID>>>();
+            bqt.stop_criteria =
+                t[9].cast<std::vector<std::vector<std::vector<int>>>>();
             return bqt;
           }));
@@ -133,16 +162,20 @@ PYBIND11_MODULE(sched_ext, m) {
       .def_readwrite("ok", &scheduler::QueryUpdate::ok)
       .def_readwrite("is_prefill", &scheduler::QueryUpdate::is_prefill)
       .def_readwrite("decode_done", &scheduler::QueryUpdate::decode_done)
-      .def_readwrite("active_position", &scheduler::QueryUpdate::active_position)
-      .def_readwrite("generated_token", &scheduler::QueryUpdate::generated_token)
+      .def_readwrite("active_position",
+                     &scheduler::QueryUpdate::active_position)
+      .def_readwrite("generated_token",
+                     &scheduler::QueryUpdate::generated_token)
       .def(py::pickle(
-          [](const scheduler::QueryUpdate& self) {
-            return py::make_tuple(self.id, self.ok, self.is_prefill, self.decode_done, self.active_position,
+          [](const scheduler::QueryUpdate &self) {
+            return py::make_tuple(self.id, self.ok, self.is_prefill,
+                                  self.decode_done, self.active_position,
                                   self.generated_token);
           },
           [](py::tuple t) {
             if (t.size() != 6)
-              throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
+              throw std::runtime_error("Invalid state! t.size() = " +
+                                       std::to_string(t.size()));
             scheduler::QueryUpdate qu;
             qu.id = t[0].cast<scheduler::QueryID>();
             qu.ok = t[1].cast<bool>();
@@ -156,8 +189,7 @@ PYBIND11_MODULE(sched_ext, m) {
   py::class_<scheduler::InferenceContext>(m, "InferenceContext")
       .def(py::init<>())
       .def_readwrite("k_cache", &scheduler::InferenceContext::k_cache)
-      .def_readwrite("v_cache", &scheduler::InferenceContext::v_cache)
-      ;
+      .def_readwrite("v_cache", &scheduler::InferenceContext::v_cache);

   py::class_<scheduler::QueryAdd>(m, "QueryAdd")
       .def(py::init<>())
@@ -173,15 +205,18 @@ PYBIND11_MODULE(sched_ext, m) {
       .def("serialize", &scheduler::QueryAdd::serialize)
       .def_static("deserialize", &scheduler::QueryAdd::deserialize)
       .def(py::pickle(
-          [](const scheduler::QueryAdd& self) {
+          [](const scheduler::QueryAdd &self) {
             return py::make_tuple(self.query_token,
                                   // self.attn_mask,
-                                  self.query_length, self.estimated_length, self.sample_options, self.user_id,
-                                  self.SLO_TTFT_ms, self.SLO_TBT_ms, self.stop_criteria);
+                                  self.query_length, self.estimated_length,
+                                  self.sample_options, self.user_id,
+                                  self.SLO_TTFT_ms, self.SLO_TBT_ms,
+                                  self.stop_criteria);
           },
           [](py::tuple t) {
             if (t.size() != 8)
-              throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
+              throw std::runtime_error("Invalid state! t.size() = " +
+                                       std::to_string(t.size()));
             scheduler::QueryAdd qa;
             qa.query_token = t[0].cast<std::vector<scheduler::Token>>();
             // qa.attn_mask = t[1].cast<torch::Tensor>();
@@ -195,14 +230,20 @@ PYBIND11_MODULE(sched_ext, m) {
             return qa;
           }));

-  py::class_<scheduler::Scheduler, std::shared_ptr<scheduler::Scheduler>>(m, "Scheduler")
+  py::class_<scheduler::Scheduler, std::shared_ptr<scheduler::Scheduler>>(
+      m, "Scheduler")
       .def("init", &scheduler::Scheduler::init)
       .def("run", &scheduler::Scheduler::run)
       .def("stop", &scheduler::Scheduler::stop)
-      .def("add_query", &scheduler::Scheduler::add_query, py::call_guard<py::gil_scoped_release>())
-      .def("cancel_query", &scheduler::Scheduler::cancel_query, py::call_guard<py::gil_scoped_release>())
-      .def("update_last_batch", &scheduler::Scheduler::update_last_batch, py::call_guard<py::gil_scoped_release>())
-      .def("get_inference_context", &scheduler::Scheduler::get_inference_context);
+      .def("add_query", &scheduler::Scheduler::add_query,
+           py::call_guard<py::gil_scoped_release>())
+      .def("cancel_query", &scheduler::Scheduler::cancel_query,
+           py::call_guard<py::gil_scoped_release>())
+      .def("update_last_batch", &scheduler::Scheduler::update_last_batch,
+           py::call_guard<py::gil_scoped_release>())
+      .def("get_inference_context",
+           &scheduler::Scheduler::get_inference_context);

-  m.def("create_scheduler", &scheduler::create_scheduler, "Create a new Scheduler instance");
+  m.def("create_scheduler", &scheduler::create_scheduler,
+        "Create a new Scheduler instance");
 }
@@ -2,89 +2,101 @@
 #include <iostream>

 // Constructor
-Metrics::Metrics(const MetricsConfig& config)
+Metrics::Metrics(const MetricsConfig &config)
     : registry_(std::make_shared<prometheus::Registry>()),
-      exposer_(config.endpoint),
-      stop_uptime_thread_(false),
+      exposer_(config.endpoint), stop_uptime_thread_(false),
       start_time_(std::chrono::steady_clock::now()) {
   // Define uniform bucket sizes, capped at 10000 ms (10 s)
-  std::vector<double> common_buckets = {0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0, 5.0,
-                                        10.0, 50.0, 100.0, 500.0, 1000.0, 5000.0, 10000.0};  // milliseconds
+  std::vector<double> common_buckets = {
+      0.001, 0.005, 0.01,  0.05,  0.1,    0.5,    1.0,    5.0,
+      10.0,  50.0,  100.0, 500.0, 1000.0, 5000.0, 10000.0}; // milliseconds

   // Register the TTFT_ms histogram
-  auto& TTFT_family = prometheus::BuildHistogram()
+  auto &TTFT_family = prometheus::BuildHistogram()
                           .Name(std::string(METRIC_PREFIX) + "_TTFT_ms")
                           .Help("Time to first token in milliseconds")
                           .Register(*registry_);
   TTFT_ms = &TTFT_family.Add({{"model", config.model_name}}, common_buckets);

   // Register the TBT_ms histogram
-  auto& TBT_family = prometheus::BuildHistogram()
+  auto &TBT_family = prometheus::BuildHistogram()
                          .Name(std::string(METRIC_PREFIX) + "_TBT_ms")
                          .Help("Time between tokens in milliseconds")
                          .Register(*registry_);
   TBT_ms = &TBT_family.Add({{"model", config.model_name}}, common_buckets);

   // Register the schedule_time histogram
-  auto& schedule_time_family = prometheus::BuildHistogram()
-                                   .Name(std::string(METRIC_PREFIX) + "_schedule_time_ms")
-                                   .Help("Time to generate schedule in milliseconds")
-                                   .Register(*registry_);
-  schedule_time = &schedule_time_family.Add({{"model", config.model_name}}, common_buckets);
+  auto &schedule_time_family =
+      prometheus::BuildHistogram()
+          .Name(std::string(METRIC_PREFIX) + "_schedule_time_ms")
+          .Help("Time to generate schedule in milliseconds")
+          .Register(*registry_);
+  schedule_time =
+      &schedule_time_family.Add({{"model", config.model_name}}, common_buckets);

   // Register the generated_tokens counter
-  auto& generated_tokens_family = prometheus::BuildCounter()
-                                      .Name(std::string(METRIC_PREFIX) + "_generated_tokens_total")
-                                      .Help("Total generated tokens")
-                                      .Register(*registry_);
-  generated_tokens = &generated_tokens_family.Add({{"model", config.model_name}});
+  auto &generated_tokens_family =
+      prometheus::BuildCounter()
+          .Name(std::string(METRIC_PREFIX) + "_generated_tokens_total")
+          .Help("Total generated tokens")
+          .Register(*registry_);
+  generated_tokens =
+      &generated_tokens_family.Add({{"model", config.model_name}});

   // Register the throughput_query gauge
-  auto& throughput_query_family = prometheus::BuildGauge()
-                                      .Name(std::string(METRIC_PREFIX) + "_throughput_query")
-                                      .Help("Throughput per second based on queries")
-                                      .Register(*registry_);
-  throughput_query = &throughput_query_family.Add({{"model", config.model_name}});
+  auto &throughput_query_family =
+      prometheus::BuildGauge()
+          .Name(std::string(METRIC_PREFIX) + "_throughput_query")
+          .Help("Throughput per second based on queries")
+          .Register(*registry_);
+  throughput_query =
+      &throughput_query_family.Add({{"model", config.model_name}});

   // Register the throughput_generated_tokens gauge
-  auto& throughput_generated_tokens_family = prometheus::BuildGauge()
-                                                 .Name(std::string(METRIC_PREFIX) + "_throughput_generated_tokens")
-                                                 .Help("Throughput per second based on generated tokens")
-                                                 .Register(*registry_);
-  throughput_generated_tokens = &throughput_generated_tokens_family.Add({{"model", config.model_name}});
+  auto &throughput_generated_tokens_family =
+      prometheus::BuildGauge()
+          .Name(std::string(METRIC_PREFIX) + "_throughput_generated_tokens")
+          .Help("Throughput per second based on generated tokens")
+          .Register(*registry_);
+  throughput_generated_tokens =
+      &throughput_generated_tokens_family.Add({{"model", config.model_name}});

   // Register the event_count counter family
-  event_count_family_ = &prometheus::BuildCounter()
-                             .Name(std::string(METRIC_PREFIX) + "_event_count_total")
-                             .Help("Count of various events")
-                             .Register(*registry_);
+  event_count_family_ =
+      &prometheus::BuildCounter()
+           .Name(std::string(METRIC_PREFIX) + "_event_count_total")
+           .Help("Count of various events")
+           .Register(*registry_);

-  batch_count_family_ = &prometheus::BuildCounter()
-                             .Name(std::string(METRIC_PREFIX) + "_batch_count_total")
-                             .Help("Count of various batch by status")
-                             .Register(*registry_);
+  batch_count_family_ =
+      &prometheus::BuildCounter()
+           .Name(std::string(METRIC_PREFIX) + "_batch_count_total")
+           .Help("Count of various batch by status")
+           .Register(*registry_);

   // Register the query_count counter family
-  query_count_family_ = &prometheus::BuildCounter()
-                             .Name(std::string(METRIC_PREFIX) + "_query_count_total")
-                             .Help("Count of queries by status")
-                             .Register(*registry_);
+  query_count_family_ =
+      &prometheus::BuildCounter()
+           .Name(std::string(METRIC_PREFIX) + "_query_count_total")
+           .Help("Count of queries by status")
+           .Register(*registry_);

   // Register the uptime_ms gauge
-  auto& uptime_family = prometheus::BuildGauge()
+  auto &uptime_family = prometheus::BuildGauge()
                             .Name(std::string(METRIC_PREFIX) + "_uptime_ms")
                             .Help("Uptime of the scheduler in milliseconds")
                             .Register(*registry_);
   uptime_ms = &uptime_family.Add({{"model", config.model_name}});

   // Register GPU utilization gauges
-  auto& gpu_util_family = prometheus::BuildGauge()
-                              .Name(std::string(METRIC_PREFIX) + "_gpu_utilization_ratio")
-                              .Help("Current GPU utilization ratio (0 to 1)")
-                              .Register(*registry_);
+  auto &gpu_util_family =
+      prometheus::BuildGauge()
+          .Name(std::string(METRIC_PREFIX) + "_gpu_utilization_ratio")
+          .Help("Current GPU utilization ratio (0 to 1)")
+          .Register(*registry_);
   for (size_t i = 0; i < config.gpu_count; ++i) {
-    gpu_utilization_gauges.push_back(
-        &gpu_util_family.Add({{"gpu_id", std::to_string(i)}, {"model", config.model_name}}));
+    gpu_utilization_gauges.push_back(&gpu_util_family.Add(
+        {{"gpu_id", std::to_string(i)}, {"model", config.model_name}}));
   }

   // Register the Registry with the Exposer
@@ -95,16 +107,15 @@ Metrics::Metrics(const MetricsConfig& config)
 }

 // Destructor
-Metrics::~Metrics() {
-  StopUptimeUpdater();
-}
+Metrics::~Metrics() { StopUptimeUpdater(); }

 // Start the uptime update thread
 void Metrics::StartUptimeUpdater() {
   uptime_thread_ = std::thread([this]() {
     while (!stop_uptime_thread_) {
       auto now = std::chrono::steady_clock::now();
-      std::chrono::duration<double, std::milli> uptime_duration = now - start_time_;
+      std::chrono::duration<double, std::milli> uptime_duration =
+          now - start_time_;
       uptime_ms->Set(uptime_duration.count());
       // fn_every_sec(this);
       std::this_thread::sleep_for(std::chrono::seconds(1));
@@ -121,15 +132,16 @@ void Metrics::StopUptimeUpdater() {
 }

 // Get an event_count metric
-prometheus::Counter* Metrics::event_count(const std::string& type) {
-  return &event_count_family_->Add({{"type", type}});  // more labels can be added as needed
+prometheus::Counter *Metrics::event_count(const std::string &type) {
+  return &event_count_family_->Add({{"type", type}}); // more labels can be added as needed
 }

 // Get a query_count metric
-prometheus::Counter* Metrics::query_count(const std::string& status) {
-  return &query_count_family_->Add({{"status", status}});  // more labels can be added as needed
+prometheus::Counter *Metrics::query_count(const std::string &status) {
+  return &query_count_family_->Add(
+      {{"status", status}}); // more labels can be added as needed
 }

-prometheus::Counter* Metrics::batch_count(const std::string& type) {
+prometheus::Counter *Metrics::batch_count(const std::string &type) {
   return &batch_count_family_->Add({{"type", type}});
 }
@@ -1,14 +1,14 @@
 #ifndef Metrics_H
 #define Metrics_H

-#include <atomic>
-#include <chrono>
-#include <memory>
 #include <prometheus/counter.h>
 #include <prometheus/exposer.h>
 #include <prometheus/gauge.h>
 #include <prometheus/histogram.h>
 #include <prometheus/registry.h>
+#include <atomic>
+#include <chrono>
+#include <memory>
 #include <string>
 #include <thread>
 #include <vector>
@@ -21,46 +21,46 @@ class Metrics;
 // Configuration struct
 struct MetricsConfig {
   std::string endpoint;
-  std::string model_name;  // model name, e.g. "gpt-4"
-  size_t gpu_count;        // number of GPUs
+  std::string model_name; // model name, e.g. "gpt-4"
+  size_t gpu_count;       // number of GPUs
 };

 // Metrics class; initializes Prometheus metrics from the configuration
 class Metrics {
- public:
+public:
   // The constructor takes a MetricsConfig
-  Metrics(const MetricsConfig& config);
+  Metrics(const MetricsConfig &config);
   ~Metrics();

   // Disallow copy and assignment
-  Metrics(const Metrics&) = delete;
-  Metrics& operator=(const Metrics&) = delete;
+  Metrics(const Metrics &) = delete;
+  Metrics &operator=(const Metrics &) = delete;

-  std::function<void(Metrics*)> fn_every_sec;
+  std::function<void(Metrics *)> fn_every_sec;

   // Metric pointers
-  prometheus::Gauge* uptime_ms;
-  prometheus::Histogram* TTFT_ms;
-  prometheus::Histogram* TBT_ms;
-  prometheus::Histogram* schedule_time;
-  prometheus::Gauge* throughput_query;
-  prometheus::Gauge* throughput_generated_tokens;
-  prometheus::Counter* generated_tokens;
-  std::vector<prometheus::Gauge*> gpu_utilization_gauges;
+  prometheus::Gauge *uptime_ms;
+  prometheus::Histogram *TTFT_ms;
+  prometheus::Histogram *TBT_ms;
+  prometheus::Histogram *schedule_time;
+  prometheus::Gauge *throughput_query;
+  prometheus::Gauge *throughput_generated_tokens;
+  prometheus::Counter *generated_tokens;
+  std::vector<prometheus::Gauge *> gpu_utilization_gauges;

   // Counter families
-  prometheus::Counter* event_count(const std::string& type);
-  prometheus::Counter* query_count(const std::string& status);
-  prometheus::Counter* batch_count(const std::string& type);
+  prometheus::Counter *event_count(const std::string &type);
+  prometheus::Counter *query_count(const std::string &status);
+  prometheus::Counter *batch_count(const std::string &type);

- private:
+private:
   std::shared_ptr<prometheus::Registry> registry_;
   prometheus::Exposer exposer_;

   // Counter families
-  prometheus::Family<prometheus::Counter>* event_count_family_;
-  prometheus::Family<prometheus::Counter>* batch_count_family_;
-  prometheus::Family<prometheus::Counter>* query_count_family_;
+  prometheus::Family<prometheus::Counter> *event_count_family_;
+  prometheus::Family<prometheus::Counter> *batch_count_family_;
+  prometheus::Family<prometheus::Counter> *query_count_family_;

   // Thread and control variable for updating uptime_ms
   std::thread uptime_thread_;
@@ -76,10 +76,13 @@ class Metrics {
 };

 struct HistogramTimerWrapper {
-  prometheus::Histogram* histogram;
+  prometheus::Histogram *histogram;
   Timer timer;
-  inline HistogramTimerWrapper(prometheus::Histogram* histogram) : histogram(histogram), timer() { timer.start(); }
+  inline HistogramTimerWrapper(prometheus::Histogram *histogram)
+      : histogram(histogram), timer() {
+    timer.start();
+  }
   inline ~HistogramTimerWrapper() { histogram->Observe(timer.elapsedMs()); }
 };

-#endif  // Metrics_H
+#endif // Metrics_H
@@ -1,8 +1,8 @@
 #ifndef __MODEL_CONFIG_HPP_
 #define __MODEL_CONFIG_HPP_

-#include <iostream>
 #include "nlohmann/json.hpp"
+#include <iostream>

 #include <filesystem>
 #include <fstream>
@ -13,7 +13,7 @@ using ModelName = std::string;
|
|||
|
||||
// We must ensure this can be loaded from config.json
|
||||
class ModelConfig {
|
||||
public:
|
||||
public:
|
||||
DimSize hidden_size;
|
||||
DimSize intermediate_size;
|
||||
size_t max_position_embeddings;
|
||||
|
@ -23,10 +23,13 @@ class ModelConfig {
|
|||
size_t num_key_value_heads;
|
||||
size_t vocab_size;
|
||||
|
||||
NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size, max_position_embeddings, model_type,
|
||||
num_attention_heads, num_hidden_layers, num_key_value_heads, vocab_size);
|
||||
NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size,
|
||||
max_position_embeddings, model_type,
|
||||
num_attention_heads, num_hidden_layers,
|
||||
num_key_value_heads, vocab_size);
|
||||
|
||||
void load_from(std::filesystem::path path) {
|
||||
std::cout << "Load from " << path << std::endl;
|
||||
std::ifstream i(path);
|
||||
nlohmann::json j;
|
||||
i >> j;
|
||||
|
@ -38,12 +41,14 @@ using QuantType = std::string;
|
|||
static const QuantType NoQuantType = "";
|
||||
|
||||
class QuantConfig {
|
||||
public:
|
||||
public:
|
||||
QuantType name;
|
||||
|
||||
// For GEMV
|
||||
QuantType type_of_dot_vector = NoQuantType;
|
||||
inline bool can_be_used_as_matrix() { return type_of_dot_vector != NoQuantType; }
|
||||
inline bool can_be_used_as_matrix() {
|
||||
return type_of_dot_vector != NoQuantType;
|
||||
}
|
||||
|
||||
bool can_be_used_as_vector;
|
||||
|
||||
|
@ -56,8 +61,11 @@ class QuantConfig {
|
|||
|
||||
URL reference = "";
|
||||
|
||||
NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(QuantConfig, name, type_of_dot_vector, can_be_used_as_vector,
|
||||
bytes_per_element, has_scale, has_min, block_element_count,
|
||||
NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(QuantConfig, name,
|
||||
type_of_dot_vector,
|
||||
can_be_used_as_vector,
|
||||
bytes_per_element, has_scale,
|
||||
has_min, block_element_count,
|
||||
block_element_size, reference);
|
||||
};
|
||||
|
||||
|
@ -70,14 +78,13 @@ inline void load_quant_configs(std::filesystem::path path) {
|
|||
std::cout << __FUNCTION__ << " from " << path << std::endl;
|
||||
std::ifstream i(path);
|
||||
i >> j;
|
||||
    quant_configs = j.get<std::map<QuantType, QuantConfig>>();
    std::cout << "Loaded Quant Configs" << std::endl;
    for (auto &[k, v] : quant_configs) {
      std::cout << " - " << k << std::endl;
    }
  } else {
    std::cout << __FUNCTION__ << " create new at " << path << std::endl;
    std::cout << __FUNCTION__ << " no file at " << path << std::endl;
  }

  quant_configs = j.get<std::map<QuantType, QuantConfig>>();
  std::cout << "Loaded Quant Configs" << std::endl;
  for (auto& [k, v] : quant_configs) {
    std::cout << " - " << k << std::endl;
  }
}
|
||||
|
||||
|
@ -93,14 +100,13 @@ inline void load_model_configs(std::filesystem::path path) {
|
|||
std::cout << __FUNCTION__ << " from " << path << std::endl;
|
||||
std::ifstream i(path);
|
||||
i >> j;
|
||||
    model_configs = j.get<std::map<ModelName, ModelConfig>>();
    std::cout << "Loaded Model Configs" << std::endl;
    for (auto &[k, v] : model_configs) {
      std::cout << " - " << k << std::endl;
    }
  } else {
    std::cout << __FUNCTION__ << " create new at " << path << std::endl;
    std::cout << __FUNCTION__ << " no file at " << path << std::endl;
  }

  model_configs = j.get<std::map<ModelName, ModelConfig>>();
  std::cout << "Loaded Model Configs" << std::endl;
  for (auto& [k, v] : model_configs) {
    std::cout << " - " << k << std::endl;
  }
}
|
||||
|
||||
|
|
|
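Aside (illustrative, not from this commit): how the helpers above compose. The model name and paths are invented; load_from, the global model_configs map, and dump_model_configs are as defined in this header.

ModelConfig cfg;
cfg.load_from("/models/my-model/config.json");  // HF-style config.json
model_configs["my-model"] = cfg;                // ModelName -> ModelConfig
dump_model_configs("/home/user/.ktransformers/model_configs.json");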
@ -3,20 +3,20 @@
|
|||
#include "nlohmann/json.hpp"
|
||||
#include "spdlog/spdlog.h"
|
||||
|
||||
#include <optional>
|
||||
#include "scheduler.h"
|
||||
#include <optional>
|
||||
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include "arithmetic.hpp"
|
||||
#include "atomic_ptr_with_flags.hpp"
|
||||
#include "easy_format.hpp"
|
||||
#include "metrics.h"
|
||||
#include "mpsc.hpp"
|
||||
#include "timer.hpp"
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
|
||||
#include "kvc2.h"
|
||||
|
||||
|
@ -28,7 +28,8 @@ void Settings::auto_derive() {
|
|||
gpu_device_count = gpu_device_id.size();
|
||||
if (torch::cuda::is_available()) {
|
||||
size_t gpu_count = torch::cuda::device_count();
|
||||
SPDLOG_INFO("Number of available GPUs: {}, want {}", gpu_count, gpu_device_count);
|
||||
SPDLOG_INFO("Number of available GPUs: {}, want {}", gpu_count,
|
||||
gpu_device_count);
|
||||
if (gpu_count < gpu_device_count) {
|
||||
SPDLOG_ERROR("Not enough GPUs available.");
|
||||
exit(0);
|
||||
|
@ -42,37 +43,49 @@ void Settings::auto_derive() {
|
|||
}
|
||||
|
||||
if (model_settings.num_k_heads % gpu_device_count != 0) {
|
||||
SPDLOG_ERROR("num_k_heads {} is not divisible by gpu_device_count {}", model_settings.num_k_heads,
|
||||
gpu_device_count);
|
||||
SPDLOG_ERROR("num_k_heads {} is not divisible by gpu_device_count {}",
|
||||
model_settings.num_k_heads, gpu_device_count);
|
||||
assert(false);
|
||||
}
|
||||
|
||||
size_t gpu_memory_available = gpu_memory_size * memory_utilization_percentage;
|
||||
if (gpu_memory_available * gpu_device_count < model_settings.params_nbytes()) {
|
||||
SPDLOG_ERROR("GPU memory size {}G is smaller than {}G", gpu_memory_available * gpu_device_count / 1e9,
|
||||
if (gpu_memory_available * gpu_device_count <
|
||||
model_settings.params_nbytes()) {
|
||||
SPDLOG_ERROR("GPU memory size {}G is smaller than {}G",
|
||||
gpu_memory_available * gpu_device_count / 1e9,
|
||||
model_settings.params_nbytes() / 1e9);
|
||||
assert(false);
|
||||
}
|
||||
|
||||
assert(model_settings.k_head_dim % model_settings.num_k_heads == 0);
|
||||
size_t head_per_gpu = model_settings.num_k_heads / gpu_device_count;
|
||||
size_t gpu_memory_for_kv_cache = gpu_memory_available /*- model_settings.params_nbytes() / gpu_device_count*/;
|
||||
SPDLOG_INFO("Each GPU Total: {}MiB, Model Params: {}MiB, KVCache: {}MiB, Left: {}MiB", gpu_memory_size / (1 << 20),
|
||||
model_settings.params_nbytes() / gpu_device_count / (1 << 20), gpu_memory_for_kv_cache / (1 << 20),
|
||||
(gpu_memory_size - gpu_memory_available) / (1 << 20));
|
||||
size_t gpu_memory_for_kv_cache =
|
||||
gpu_memory_available /*- model_settings.params_nbytes() /
|
||||
gpu_device_count*/
|
||||
;
|
||||
SPDLOG_INFO(
|
||||
"Each GPU Total: {}MiB, Model Params: {}MiB, KVCache: {}MiB, Left: {}MiB",
|
||||
gpu_memory_size / (1 << 20),
|
||||
model_settings.params_nbytes() / gpu_device_count / (1 << 20),
|
||||
gpu_memory_for_kv_cache / (1 << 20),
|
||||
(gpu_memory_size - gpu_memory_available) / (1 << 20));
|
||||
size_t kv_cache_on_cnt = (size_t)(k_cache_on) + (size_t)(v_cache_on);
|
||||
size_t max_total_kvcache_pages =
|
||||
gpu_memory_for_kv_cache / (kv_cache_on_cnt * head_per_gpu * model_settings.k_head_dim *
|
||||
model_settings.bytes_per_kv_cache_element * page_size * model_settings.layer_count);
|
||||
gpu_memory_for_kv_cache /
|
||||
(kv_cache_on_cnt * head_per_gpu * model_settings.k_head_dim *
|
||||
model_settings.bytes_per_kv_cache_element * page_size *
|
||||
model_settings.layer_count);
|
||||
if (total_kvcache_pages.has_value()) {
|
||||
if (total_kvcache_pages.value() > max_total_kvcache_pages) {
|
||||
SPDLOG_ERROR("total_kvcache_pages {} is larger than max_total_kvcache_pages {}", total_kvcache_pages.value(),
|
||||
max_total_kvcache_pages);
|
||||
SPDLOG_ERROR(
|
||||
"total_kvcache_pages {} is larger than max_total_kvcache_pages {}",
|
||||
total_kvcache_pages.value(), max_total_kvcache_pages);
|
||||
assert(false);
|
||||
}
|
||||
} else {
|
||||
total_kvcache_pages = max_total_kvcache_pages;
|
||||
SPDLOG_INFO("total_kvcache_pages is auto derived as {}", max_total_kvcache_pages);
|
||||
SPDLOG_INFO("total_kvcache_pages is auto derived as {}",
|
||||
max_total_kvcache_pages);
|
||||
}
|
||||
|
||||
if (page_size % 256 != 0) {
|
||||
|
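Aside: to make the page-count derivation above concrete, a worked example with invented numbers; the formula mirrors auto_derive().

// All values are hypothetical.
size_t gpu_memory_for_kv_cache = 40ull << 30;  // 40 GiB usable per GPU
size_t kv_cache_on_cnt = 2;                    // K and V caches both on
size_t head_per_gpu = 8;                       // num_k_heads / gpu_device_count
size_t k_head_dim = 128;
size_t bytes_per_element = 2;                  // e.g. bf16
size_t page_size = 256;                        // tokens per page
size_t layer_count = 60;
// Per-page bytes across all layers: 2 * 8 * 128 * 2 * 256 * 60 = 60 MiB,
// so max_total_kvcache_pages = 40 GiB / 60 MiB = 682 pages.
size_t max_pages = gpu_memory_for_kv_cache /
                   (kv_cache_on_cnt * head_per_gpu * k_head_dim *
                    bytes_per_element * page_size * layer_count);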
@ -88,7 +101,7 @@ void Settings::auto_derive() {
|
|||
std::string BatchQueryTodo::debug() {
|
||||
std::string re = "BatchQueryTodo: ";
|
||||
re += "QueryIDs: ";
|
||||
for (auto& id : query_ids) {
|
||||
for (auto &id : query_ids) {
|
||||
re += std::to_string(id) + " ";
|
||||
}
|
||||
return re;
|
||||
|
@ -119,59 +132,61 @@ struct Query {
|
|||
// Query status changes in this order
|
||||
enum Status { Received, Preparing, Ready, Prefill, Decode, Done };
|
||||
Status plan_status = Received;
|
||||
TokenLength active_position;  // the first position with no kvcache yet
|
||||
TokenLength plan_position;    // the first position with no kvcache yet, in the plan
|
||||
TokenLength active_position;  // the first position with no kvcache yet
|
||||
TokenLength plan_position;    // the first position with no kvcache yet, in the plan
|
||||
size_t prepare_try_count = 0;
|
||||
std::shared_ptr<kvc2::DoubleCacheHandleInterface> kvc2_handle = nullptr;
|
||||
|
||||
// derived from kvc2_handle
|
||||
torch::Tensor block_index; // block indexes
|
||||
torch::Tensor block_index; // block indexes
|
||||
|
||||
struct QueryContext {
|
||||
ModelName model_name;
|
||||
QuantType quant_type;
|
||||
kvc2::KVC2Interface* kvc2_interface;
|
||||
QueryMaintainer* query_maintainer;
|
||||
Metrics* met;
|
||||
kvc2::KVC2Interface *kvc2_interface;
|
||||
QueryMaintainer *query_maintainer;
|
||||
Metrics *met;
|
||||
} ctx;
|
||||
|
||||
void after_load(bool ok);
|
||||
|
||||
void to_status(Status to);
|
||||
|
||||
void export_metrics() { ctx.met->query_count(status_to_string(plan_status))->Increment(1); }
|
||||
void export_metrics() {
|
||||
ctx.met->query_count(status_to_string(plan_status))->Increment(1);
|
||||
}
|
||||
|
||||
Query(QueryID id, QueryAdd query_add, QueryContext context)
|
||||
: id(id),
|
||||
prompt_length(query_add.query_length),
|
||||
no_kvcache_from(0),
|
||||
: id(id), prompt_length(query_add.query_length), no_kvcache_from(0),
|
||||
estimated_length(query_add.estimated_length),
|
||||
sample_options(query_add.sample_options),
|
||||
user_id(query_add.user_id),
|
||||
SLO_TTFT_ms(query_add.SLO_TTFT_ms),
|
||||
SLO_TBT_ms(query_add.SLO_TBT_ms),
|
||||
stop_criteria(query_add.stop_criteria),
|
||||
ctx(context) {
|
||||
sample_options(query_add.sample_options), user_id(query_add.user_id),
|
||||
SLO_TTFT_ms(query_add.SLO_TTFT_ms), SLO_TBT_ms(query_add.SLO_TBT_ms),
|
||||
stop_criteria(query_add.stop_criteria), ctx(context) {
|
||||
std::vector<int64_t> shape = {int64_t(query_add.estimated_length)};
|
||||
query_token = torch::zeros(shape, torch::TensorOptions().dtype(torch::kInt32));
|
||||
query_token =
|
||||
torch::zeros(shape, torch::TensorOptions().dtype(torch::kInt32));
|
||||
assert(query_token.is_contiguous());
|
||||
if (query_token.is_contiguous() == false) {
|
||||
SPDLOG_ERROR("Query Token must be contiguous!");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
memcpy(query_token.data_ptr(), query_add.query_token.data(), query_add.query_length * sizeof(Token));
|
||||
memcpy(query_token.data_ptr(), query_add.query_token.data(),
|
||||
query_add.query_length * sizeof(Token));
|
||||
|
||||
no_kvcache_from = 0; // maybe match prefix later
|
||||
no_kvcache_from = 0; // maybe match prefix later
|
||||
export_metrics();
|
||||
}
|
||||
|
||||
Token& token_at(size_t idx) { return reinterpret_cast<Token*>(query_token.data_ptr())[idx]; }
|
||||
Token &token_at(size_t idx) {
|
||||
return reinterpret_cast<Token *>(query_token.data_ptr())[idx];
|
||||
}
|
||||
|
||||
void absorb_update(const QueryUpdate& update) {
|
||||
void absorb_update(const QueryUpdate &update) {
|
||||
SPDLOG_DEBUG("{}", update.debug());
|
||||
active_position = update.active_position;
|
||||
kvc2_handle->append_tokens(&token_at(0), active_position); // active_position is length -1
|
||||
kvc2_handle->append_tokens(&token_at(0),
|
||||
active_position); // active_position is length -1
|
||||
if (update.is_prefill) {
|
||||
if (active_position == prompt_length) {
|
||||
token_at(active_position) = update.generated_token;
|
||||
|
@ -187,15 +202,17 @@ struct Query {
|
|||
}
|
||||
}
|
||||
|
||||
void absorb_prefill_task(const PrefillTask& task) {
|
||||
auto& [id, start, length] = task;
|
||||
void absorb_prefill_task(const PrefillTask &task) {
|
||||
auto &[id, start, length] = task;
|
||||
this->plan_position = start + length;
|
||||
if (this->plan_position == prompt_length) {
|
||||
to_status(Decode);
|
||||
}
|
||||
}
|
||||
|
||||
void absorb_decode_task([[maybe_unused]] const QueryID& task) { this->plan_position += 1; }
|
||||
void absorb_decode_task([[maybe_unused]] const QueryID &task) {
|
||||
this->plan_position += 1;
|
||||
}
|
||||
|
||||
PrefillTask get_prefill_task(size_t prefill_length) {
|
||||
if (prefill_length + plan_position > prompt_length) {
|
||||
|
@ -206,18 +223,18 @@ struct Query {
|
|||
|
||||
static std::string status_to_string(Status status) {
|
||||
switch (status) {
|
||||
case Received:
|
||||
return "Received";
|
||||
case Preparing:
|
||||
return "Preparing";
|
||||
case Ready:
|
||||
return "Ready";
|
||||
case Prefill:
|
||||
return "Prefill";
|
||||
case Decode:
|
||||
return "Decode";
|
||||
case Done:
|
||||
return "Done";
|
||||
case Received:
|
||||
return "Received";
|
||||
case Preparing:
|
||||
return "Preparing";
|
||||
case Ready:
|
||||
return "Ready";
|
||||
case Prefill:
|
||||
return "Prefill";
|
||||
case Decode:
|
||||
return "Decode";
|
||||
case Done:
|
||||
return "Done";
|
||||
}
|
||||
assert(false);
|
||||
}
|
||||
|
@ -225,16 +242,19 @@ struct Query {
|
|||
void debug() {
|
||||
std::string status_string = status_to_string(plan_status);
|
||||
|
||||
SPDLOG_DEBUG(
|
||||
"Query {}, prompt_length {}, estimated_length {}, plan status {}, plan position {} "
|
||||
"active position {}",
|
||||
id, prompt_length, estimated_length, status_string, plan_position, active_position);
|
||||
SPDLOG_DEBUG("Query {}, prompt_length {}, estimated_length {}, plan status "
|
||||
"{}, plan position {} "
|
||||
"active position {}",
|
||||
id, prompt_length, estimated_length, status_string,
|
||||
plan_position, active_position);
|
||||
}
|
||||
};
|
||||
|
||||
std::string QueryUpdate::debug() const {
|
||||
return fmt::format("Query {}, ok {}, is_prefill {}, done {}, active_position {}, gen token {}", id, ok, is_prefill,
|
||||
decode_done, active_position, generated_token);
|
||||
return fmt::format("Query {}, ok {}, is_prefill {}, done {}, active_position "
|
||||
"{}, gen token {}",
|
||||
id, ok, is_prefill, decode_done, active_position,
|
||||
generated_token);
|
||||
}
|
||||
|
||||
using Q = std::shared_ptr<Query>;
|
||||
|
@ -258,18 +278,22 @@ struct KVC2_Maintainer {
|
|||
.total_kvcache_pages = settings.total_kvcache_pages.value(),
|
||||
.num_token_per_page = settings.page_size,
|
||||
.num_k_heads = settings.model_settings.num_k_heads,
|
||||
.k_head_dim =
|
||||
settings.use_self_defined_head_dim ? settings.self_defined_head_dim : settings.model_settings.k_head_dim,
|
||||
.k_head_dim = settings.use_self_defined_head_dim
|
||||
? settings.self_defined_head_dim
|
||||
: settings.model_settings.k_head_dim,
|
||||
.full_kv_cache_on_each_gpu = settings.full_kv_cache_on_each_gpu,
|
||||
.k_cache_on = settings.k_cache_on,
|
||||
.v_cache_on = settings.v_cache_on,
|
||||
.tensor_type = torch::kBFloat16,
|
||||
};
|
||||
|
||||
auto model_configs_path = std::filesystem::path(settings.kvc2_config_path) / "model_configs.json";
|
||||
auto model_configs_path =
|
||||
std::filesystem::path(settings.kvc2_config_path) / "model_configs.json";
|
||||
load_model_configs(model_configs_path);
|
||||
auto my_model_config = ModelConfig();
|
||||
my_model_config.load_from(std::filesystem::path(settings.model_settings.model_path) / "config.json");
|
||||
my_model_config.load_from(
|
||||
std::filesystem::path(settings.model_settings.model_path) /
|
||||
"config.json");
|
||||
model_configs[settings.model_name] = my_model_config;
|
||||
dump_model_configs(model_configs_path);
|
||||
|
||||
|
@ -299,7 +323,7 @@ struct KVC2_Maintainer {
|
|||
}
|
||||
};
|
||||
|
||||
using EventAddQuery = std::pair<QueryAdd, std::promise<QueryID>*>;
|
||||
using EventAddQuery = std::pair<QueryAdd, std::promise<QueryID> *>;
|
||||
using EventUpdateQuery = BatchQueryUpdate;
|
||||
using EventTakenBatch = std::shared_ptr<BatchQueryTodo>;
|
||||
struct EventPrepare {
|
||||
|
@ -311,55 +335,48 @@ struct EventPrepared {
|
|||
bool ok;
|
||||
};
|
||||
|
||||
struct EventQueryStatus{
|
||||
struct EventQueryStatus {
|
||||
QueryID query_id;
|
||||
Query::Status now_status;
|
||||
};
|
||||
struct EventSchedule {};
|
||||
|
||||
using Event = std::variant<EventAddQuery, EventUpdateQuery, EventTakenBatch, EventPrepare, EventPrepared,
|
||||
EventQueryStatus, EventSchedule>;
|
||||
using Event =
|
||||
std::variant<EventAddQuery, EventUpdateQuery, EventTakenBatch, EventPrepare,
|
||||
EventPrepared, EventQueryStatus, EventSchedule>;
|
||||
|
||||
template <typename T>
|
||||
std::string event_name(const T& event);
|
||||
template <typename T> std::string event_name(const T &event);
|
||||
|
||||
template <>
|
||||
std::string event_name(const EventAddQuery&) {
|
||||
template <> std::string event_name(const EventAddQuery &) {
|
||||
return "EventAddQuery";
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string event_name(const EventUpdateQuery&) {
|
||||
template <> std::string event_name(const EventUpdateQuery &) {
|
||||
return "EventUpdateQuery";
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string event_name(const EventTakenBatch&) {
|
||||
template <> std::string event_name(const EventTakenBatch &) {
|
||||
return "EventTakenBatch";
|
||||
}
|
||||
template <>
|
||||
std::string event_name(const EventPrepare&) {
|
||||
template <> std::string event_name(const EventPrepare &) {
|
||||
return "EventPrepare";
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string event_name(const EventPrepared&) {
|
||||
template <> std::string event_name(const EventPrepared &) {
|
||||
return "EventPrepared";
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string event_name(const EventQueryStatus&) {
|
||||
template <> std::string event_name(const EventQueryStatus &) {
|
||||
return "EventQueryStatus";
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string event_name(const EventSchedule&) {
|
||||
template <> std::string event_name(const EventSchedule &) {
|
||||
return "EventSchedule";
|
||||
}
|
||||
|
||||
// event_name for the variant, implemented with std::visit
|
||||
std::string event_name(const Event& event) {
|
||||
return std::visit([](const auto& e) { return event_name(e); }, event);
|
||||
std::string event_name(const Event &event) {
|
||||
return std::visit([](const auto &e) { return event_name(e); }, event);
|
||||
}
|
||||
|
||||
static_assert(std::is_copy_constructible<Event>::value);
|
||||
|
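Aside: the per-alternative event_name specializations plus the std::visit overload above are a standard variant-dispatch idiom. A self-contained toy version of the same pattern (types invented):

#include <iostream>
#include <string>
#include <variant>

struct Ping {};
struct Pong {};
using Msg = std::variant<Ping, Pong>;

std::string msg_name(const Ping &) { return "Ping"; }
std::string msg_name(const Pong &) { return "Pong"; }

// Dispatches to the overload matching the active alternative.
std::string msg_name(const Msg &m) {
  return std::visit([](const auto &e) { return msg_name(e); }, m);
}

int main() { std::cout << msg_name(Msg{Pong{}}) << "\n"; }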
@ -383,13 +400,13 @@ struct QueryMaintainer : public Scheduler {
|
|||
|
||||
QueryMaintainer() = default;
|
||||
|
||||
void gen_batch_query_todo(BatchQueryTodo* re, const std::set<Q>& queries) {
|
||||
void gen_batch_query_todo(BatchQueryTodo *re, const std::set<Q> &queries) {
|
||||
std::vector<std::vector<QueryID>> d_batch(2);
|
||||
size_t last_decode_batch = 0;
|
||||
size_t prefill_num = 0;
|
||||
size_t decode_num = 0;
|
||||
size_t prefill_length = 0;
|
||||
for (auto& q : queries) {
|
||||
for (auto &q : queries) {
|
||||
if (q->plan_status == Query::Prefill) {
|
||||
prefill_num += 1;
|
||||
}
|
||||
|
@ -397,13 +414,13 @@ struct QueryMaintainer : public Scheduler {
|
|||
decode_num += 1;
|
||||
}
|
||||
}
|
||||
if (prefill_num >= 2 || (prefill_num ==1 && settings.max_batch_size - 2 < decode_num)) {
|
||||
prefill_length = settings.recommended_chunk_prefill_token_count;
|
||||
}
|
||||
else {
|
||||
if (prefill_num >= 2 ||
|
||||
(prefill_num == 1 && settings.max_batch_size - 2 < decode_num)) {
|
||||
prefill_length = settings.recommended_chunk_prefill_token_count;
|
||||
} else {
|
||||
prefill_length = settings.recommended_chunk_prefill_token_count * 2;
|
||||
}
|
||||
for (auto& q : queries) {
|
||||
for (auto &q : queries) {
|
||||
re->query_ids.push_back(q->id);
|
||||
re->query_tokens.push_back(q->query_token);
|
||||
re->query_lengths.push_back(q->prompt_length);
|
||||
|
@ -427,7 +444,7 @@ struct QueryMaintainer : public Scheduler {
|
|||
re->attn_masks = std::nullopt;
|
||||
re->rope_ranges = std::nullopt;
|
||||
|
||||
for (auto& b : d_batch) {
|
||||
for (auto &b : d_batch) {
|
||||
if (b.empty())
|
||||
continue;
|
||||
re->decode_mini_batches.push_back(b);
|
||||
|
@ -439,46 +456,54 @@ struct QueryMaintainer : public Scheduler {
|
|||
// Interface
|
||||
|
||||
void init(Settings settings) override {
|
||||
SPDLOG_INFO(
|
||||
"\nScheduler Settings:\n"
|
||||
" model_name: {}\n"
|
||||
" quant_type: {}\n"
|
||||
" model_path: {}\n"
|
||||
" params_count: {}\n"
|
||||
" layer_count: {}\n"
|
||||
" num_k_heads: {}\n"
|
||||
" k_head_dim: {}\n"
|
||||
" bytes_per_params: {}\n"
|
||||
" bytes_per_kv_cache_element: {}\n"
|
||||
" page_size: {}\n"
|
||||
" gpu_device_id: {}\n"
|
||||
" gpu_memory_size: {}\n"
|
||||
" memory_utilization_percentage: {}\n"
|
||||
" max_batch_size: {}\n"
|
||||
" recommended_chunk_prefill_token_count: {}\n"
|
||||
" sched_metrics_port: {}\n"
|
||||
" kvc2_config_path: {}\n"
|
||||
" kvc2_root_path: {}\n"
|
||||
" memory_pool_size_GB: {}\n"
|
||||
" evict_count: {}\n"
|
||||
" kvc2_metrics_port: {}\n"
|
||||
" load_from_disk: {}\n"
|
||||
" save_to_disk: {}\n"
|
||||
" strategy_name: {}\n"
|
||||
" gpu_device_count: {}\n",
|
||||
settings.model_name, settings.quant_type, settings.model_settings.model_path,
|
||||
settings.model_settings.params_count, settings.model_settings.layer_count, settings.model_settings.num_k_heads,
|
||||
settings.model_settings.k_head_dim, settings.model_settings.bytes_per_params,
|
||||
settings.model_settings.bytes_per_kv_cache_element,
|
||||
SPDLOG_INFO("\nScheduler Settings:\n"
|
||||
" model_name: {}\n"
|
||||
" quant_type: {}\n"
|
||||
" model_path: {}\n"
|
||||
" params_count: {}\n"
|
||||
" layer_count: {}\n"
|
||||
" num_k_heads: {}\n"
|
||||
" k_head_dim: {}\n"
|
||||
" bytes_per_params: {}\n"
|
||||
" bytes_per_kv_cache_element: {}\n"
|
||||
" page_size: {}\n"
|
||||
" gpu_device_id: {}\n"
|
||||
" gpu_memory_size: {}\n"
|
||||
" memory_utilization_percentage: {}\n"
|
||||
" max_batch_size: {}\n"
|
||||
" recommended_chunk_prefill_token_count: {}\n"
|
||||
" sched_metrics_port: {}\n"
|
||||
" kvc2_config_path: {}\n"
|
||||
" kvc2_root_path: {}\n"
|
||||
" memory_pool_size_GB: {}\n"
|
||||
" evict_count: {}\n"
|
||||
" kvc2_metrics_port: {}\n"
|
||||
" load_from_disk: {}\n"
|
||||
" save_to_disk: {}\n"
|
||||
" strategy_name: {}\n"
|
||||
" gpu_device_count: {}\n",
|
||||
settings.model_name, settings.quant_type,
|
||||
settings.model_settings.model_path,
|
||||
settings.model_settings.params_count,
|
||||
settings.model_settings.layer_count,
|
||||
settings.model_settings.num_k_heads,
|
||||
settings.model_settings.k_head_dim,
|
||||
settings.model_settings.bytes_per_params,
|
||||
settings.model_settings.bytes_per_kv_cache_element,
|
||||
|
||||
settings.page_size, format_vector(settings.gpu_device_id), readable_number(settings.gpu_memory_size),
|
||||
settings.memory_utilization_percentage, settings.max_batch_size, settings.recommended_chunk_prefill_token_count,
|
||||
settings.sched_metrics_port, settings.kvc2_config_path, settings.kvc2_root_path, settings.memory_pool_size_GB,
|
||||
settings.evict_count, settings.kvc2_metrics_port, settings.load_from_disk, settings.save_to_disk,
|
||||
settings.strategy_name, settings.gpu_device_count);
|
||||
settings.page_size, format_vector(settings.gpu_device_id),
|
||||
readable_number(settings.gpu_memory_size),
|
||||
settings.memory_utilization_percentage, settings.max_batch_size,
|
||||
settings.recommended_chunk_prefill_token_count,
|
||||
settings.sched_metrics_port, settings.kvc2_config_path,
|
||||
settings.kvc2_root_path, settings.memory_pool_size_GB,
|
||||
settings.evict_count, settings.kvc2_metrics_port,
|
||||
settings.load_from_disk, settings.save_to_disk,
|
||||
settings.strategy_name, settings.gpu_device_count);
|
||||
|
||||
this->settings = settings;
|
||||
kvc2_maintainer = std::shared_ptr<KVC2_Maintainer>(new KVC2_Maintainer(settings));
|
||||
kvc2_maintainer =
|
||||
std::shared_ptr<KVC2_Maintainer>(new KVC2_Maintainer(settings));
|
||||
MetricsConfig met_conf = {
|
||||
.endpoint = "0.0.0.0:" + std::to_string(settings.sched_metrics_port),
|
||||
.model_name = settings.model_name,
|
||||
|
@ -487,7 +512,7 @@ struct QueryMaintainer : public Scheduler {
|
|||
|
||||
SPDLOG_INFO("Creating scheduler metrics exporter on {}", met_conf.endpoint);
|
||||
met = std::make_shared<Metrics>(met_conf);
|
||||
met->fn_every_sec = [](Metrics* met) {
|
||||
met->fn_every_sec = [](Metrics *met) {
|
||||
auto generated_tokens = met->generated_tokens->Collect().counter.value;
|
||||
SPDLOG_INFO("Last Sec Generated Tokens {}", generated_tokens);
|
||||
};
|
||||
|
@ -522,7 +547,8 @@ struct QueryMaintainer : public Scheduler {
|
|||
// This function updates the last batch's results and gets the next batch
|
||||
// in most cases, the batch is ready,
|
||||
// if not, busy-wait until it is ready
|
||||
std::shared_ptr<BatchQueryTodo> update_last_batch(BatchQueryUpdate updates) override {
|
||||
std::shared_ptr<BatchQueryTodo>
|
||||
update_last_batch(BatchQueryUpdate updates) override {
|
||||
event_loop_queue.enqueue(updates);
|
||||
|
||||
// Busy Wait
|
||||
|
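Aside (sketch only): the calling convention this implies for the inference loop, using the Scheduler interface defined later in this commit; the loop body is invented.

// Hypothetical driver: report last results, receive the next batch.
void inference_loop(scheduler::Scheduler &sched) {
  scheduler::BatchQueryUpdate updates;  // empty on the first call
  while (true) {
    std::shared_ptr<scheduler::BatchQueryTodo> batch =
        sched.update_last_batch(updates);
    if (batch == nullptr || batch->empty())
      continue;  // scheduler had nothing ready yet
    updates.clear();
    // ... execute batch->prefill_mini_batches / decode_mini_batches,
    // appending one QueryUpdate per executed query to `updates` ...
  }
}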
@ -547,20 +573,22 @@ struct QueryMaintainer : public Scheduler {
|
|||
InferenceContext re;
|
||||
re.k_cache = kvc2_maintainer->k_cache;
|
||||
re.v_cache = kvc2_maintainer->v_cache;
|
||||
// kvc2_maintainer->k_cache[0][0][0][0][0][0] = 42; // test whether we pass this to inference loop
|
||||
// kvc2_maintainer->k_cache[0][0][0][0][0][0] = 42; // test whether we pass
|
||||
// this to inference loop
|
||||
return re;
|
||||
}
|
||||
|
||||
virtual void strategy_add_query(Q new_query) = 0;
|
||||
virtual void strategy_update_query(const EventUpdateQuery& update) = 0;
|
||||
virtual void strategy_taken_batch(const EventTakenBatch& batch) = 0;
|
||||
virtual void strategy_prepare(const EventPrepare& prepare) = 0;
|
||||
virtual void strategy_prepared(const EventPrepared& prepared) = 0;
|
||||
virtual void strategy_query_status(const EventQueryStatus& query_status) = 0;
|
||||
virtual void strategy_schedule(const EventSchedule& event, BatchQueryTodo* new_batch) = 0;
|
||||
virtual void strategy_update_query(const EventUpdateQuery &update) = 0;
|
||||
virtual void strategy_taken_batch(const EventTakenBatch &batch) = 0;
|
||||
virtual void strategy_prepare(const EventPrepare &prepare) = 0;
|
||||
virtual void strategy_prepared(const EventPrepared &prepared) = 0;
|
||||
virtual void strategy_query_status(const EventQueryStatus &query_status) = 0;
|
||||
virtual void strategy_schedule(const EventSchedule &event,
|
||||
BatchQueryTodo *new_batch) = 0;
|
||||
|
||||
void tackle_event(EventAddQuery& event) {
|
||||
auto& query_add = event.first;
|
||||
void tackle_event(EventAddQuery &event) {
|
||||
auto &query_add = event.first;
|
||||
QueryID id = query_id_counter;
|
||||
event.second->set_value(id);
|
||||
query_id_counter += 1;
|
||||
|
@ -570,33 +598,36 @@ struct QueryMaintainer : public Scheduler {
|
|||
strategy_add_query(new_query);
|
||||
}
|
||||
|
||||
void tackle_event(const EventUpdateQuery& update) {
|
||||
void tackle_event(const EventUpdateQuery &update) {
|
||||
// SPDLOG_INFO("Tackle Update Query");
|
||||
for (auto& u : update) {
|
||||
for (auto &u : update) {
|
||||
if (u.ok == false) {
|
||||
SPDLOG_ERROR("Query {} is not exectued OK", u.id);
|
||||
exit(1);
|
||||
}
|
||||
auto q = query_map[u.id];
|
||||
if (q->plan_status == Query::Status::Prefill || q->plan_status == Query::Status::Decode) {
|
||||
if (q->plan_status == Query::Status::Prefill ||
|
||||
q->plan_status == Query::Status::Decode) {
|
||||
q->absorb_update(u);
|
||||
} else {
|
||||
SPDLOG_DEBUG("Query {} is not in Prefill or Decode status, do not update it", u.id);
|
||||
SPDLOG_DEBUG(
|
||||
"Query {} is not in Prefill or Decode status, do not update it",
|
||||
u.id);
|
||||
}
|
||||
}
|
||||
strategy_update_query(update);
|
||||
}
|
||||
|
||||
void tackle_event(const EventTakenBatch& batch) {
|
||||
void tackle_event(const EventTakenBatch &batch) {
|
||||
met->batch_count("Taken")->Increment(1);
|
||||
for (auto& task : batch->prefill_mini_batches) {
|
||||
for (auto &task : batch->prefill_mini_batches) {
|
||||
auto [id, s, l] = task;
|
||||
if (l == 0)
|
||||
continue;
|
||||
query_map.at(id)->absorb_prefill_task(task);
|
||||
}
|
||||
for (auto& mini_batch : batch->decode_mini_batches) {
|
||||
for (auto& id : mini_batch) {
|
||||
for (auto &mini_batch : batch->decode_mini_batches) {
|
||||
for (auto &id : mini_batch) {
|
||||
query_map.at(id)->absorb_decode_task(id);
|
||||
}
|
||||
}
|
||||
|
@ -604,16 +635,18 @@ struct QueryMaintainer : public Scheduler {
|
|||
strategy_taken_batch(batch);
|
||||
}
|
||||
|
||||
void tackle_event(const EventPrepare& event) { strategy_prepare(event); }
|
||||
void tackle_event(const EventPrepared& event) { strategy_prepared(event); }
|
||||
void tackle_event(const EventQueryStatus& event) { strategy_query_status(event); }
|
||||
void tackle_event(const EventPrepare &event) { strategy_prepare(event); }
|
||||
void tackle_event(const EventPrepared &event) { strategy_prepared(event); }
|
||||
void tackle_event(const EventQueryStatus &event) {
|
||||
strategy_query_status(event);
|
||||
}
|
||||
|
||||
void tackle_event(const EventSchedule& event) {
|
||||
void tackle_event(const EventSchedule &event) {
|
||||
// SPDLOG_INFO("Tackle Schedule Event");
|
||||
|
||||
HistogramTimerWrapper t(met->schedule_time);
|
||||
|
||||
BatchQueryTodo* new_batch = new BatchQueryTodo;
|
||||
BatchQueryTodo *new_batch = new BatchQueryTodo;
|
||||
strategy_schedule(event, new_batch);
|
||||
// if (new_batch->query_ids.empty()) {
|
||||
// SPDLOG_INFO("Nothing todo");
|
||||
|
@ -660,7 +693,8 @@ struct QueryMaintainer : public Scheduler {
|
|||
}
|
||||
},
|
||||
event);
|
||||
if (event_loop_queue.size() == 0 && std::holds_alternative<EventSchedule>(event) == false) {
|
||||
if (event_loop_queue.size() == 0 &&
|
||||
std::holds_alternative<EventSchedule>(event) == false) {
|
||||
// if this is not a schedule event, we need to schedule one
|
||||
event_loop_queue.enqueue(EventSchedule());
|
||||
}
|
||||
|
@ -679,54 +713,58 @@ struct QueryMaintainer : public Scheduler {
|
|||
void Query::to_status(Status to) {
|
||||
SPDLOG_DEBUG("Calling to status query {}, to {}", id, status_to_string(to));
|
||||
switch (to) {
|
||||
case Received:
|
||||
assert(false);
|
||||
break;
|
||||
case Preparing:
|
||||
SPDLOG_INFO("Preparing Query {} {}", id,
|
||||
prepare_try_count > 0 ? (std::to_string(prepare_try_count) + " Try") : "");
|
||||
prepare_try_count += 1;
|
||||
case Received:
|
||||
assert(false);
|
||||
break;
|
||||
case Preparing:
|
||||
SPDLOG_INFO("Preparing Query {} {}", id,
|
||||
prepare_try_count > 0
|
||||
? (std::to_string(prepare_try_count) + " Try")
|
||||
: "");
|
||||
prepare_try_count += 1;
|
||||
|
||||
ctx.kvc2_interface->lookup_to_gpu_async(
|
||||
ctx.model_name, ctx.quant_type, static_cast<kvc2::Token*>(query_token.data_ptr()), prompt_length,
|
||||
estimated_length, [this](std::shared_ptr<kvc2::DoubleCacheHandleInterface> handle) {
|
||||
if (handle == nullptr) {
|
||||
SPDLOG_INFO("Get handle from kvc2 Failed.");
|
||||
this->after_load(false);
|
||||
} else {
|
||||
SPDLOG_INFO("Get handle from kvc2 Success.");
|
||||
this->kvc2_handle = handle;
|
||||
this->to_status(Ready);
|
||||
this->after_load(true);
|
||||
}
|
||||
});
|
||||
break;
|
||||
case Ready:
|
||||
SPDLOG_INFO("Ready Query {}", id);
|
||||
break;
|
||||
case Prefill:
|
||||
SPDLOG_INFO("Prefilling Query {}", id);
|
||||
// assert(plan_status == Received);
|
||||
plan_position = kvc2_handle->matched_length();
|
||||
ctx.kvc2_interface->lookup_to_gpu_async(
|
||||
ctx.model_name, ctx.quant_type,
|
||||
static_cast<kvc2::Token *>(query_token.data_ptr()), prompt_length,
|
||||
estimated_length,
|
||||
[this](std::shared_ptr<kvc2::DoubleCacheHandleInterface> handle) {
|
||||
if (handle == nullptr) {
|
||||
SPDLOG_INFO("Get handle from kvc2 Failed.");
|
||||
this->after_load(false);
|
||||
} else {
|
||||
SPDLOG_INFO("Get handle from kvc2 Success.");
|
||||
this->kvc2_handle = handle;
|
||||
this->to_status(Ready);
|
||||
this->after_load(true);
|
||||
}
|
||||
});
|
||||
break;
|
||||
case Ready:
|
||||
SPDLOG_INFO("Ready Query {}", id);
|
||||
break;
|
||||
case Prefill:
|
||||
SPDLOG_INFO("Prefilling Query {}", id);
|
||||
// assert(plan_status == Received);
|
||||
plan_position = kvc2_handle->matched_length();
|
||||
|
||||
if (prompt_length - plan_position == 0) {
|
||||
assert(prompt_length > 0);
|
||||
plan_position -= 1;
|
||||
}
|
||||
break;
|
||||
case Decode:
|
||||
SPDLOG_INFO("Decoding Query {}", id);
|
||||
// assert(plan_status == Prefill);
|
||||
break;
|
||||
case Done:
|
||||
SPDLOG_INFO("Finish Query {}", id);
|
||||
kvc2_handle = nullptr;
|
||||
ctx.query_maintainer->event_loop_queue.enqueue(EventQueryStatus{
|
||||
if (prompt_length - plan_position == 0) {
|
||||
assert(prompt_length > 0);
|
||||
plan_position -= 1;
|
||||
}
|
||||
break;
|
||||
case Decode:
|
||||
SPDLOG_INFO("Decoding Query {}", id);
|
||||
// assert(plan_status == Prefill);
|
||||
break;
|
||||
case Done:
|
||||
SPDLOG_INFO("Finish Query {}", id);
|
||||
kvc2_handle = nullptr;
|
||||
ctx.query_maintainer->event_loop_queue.enqueue(EventQueryStatus{
|
||||
.query_id = id,
|
||||
.now_status = to,
|
||||
});
|
||||
// assert(plan_status == Decode);
|
||||
break;
|
||||
});
|
||||
// assert(plan_status == Decode);
|
||||
break;
|
||||
}
|
||||
plan_status = to;
|
||||
export_metrics();
|
||||
|
@ -734,11 +772,14 @@ void Query::to_status(Status to) {
|
|||
|
||||
void Query::after_load(bool ok) {
|
||||
if (ok) {
|
||||
size_t page_count = div_up(estimated_length, ctx.query_maintainer->settings.page_size);
|
||||
size_t page_count =
|
||||
div_up(estimated_length, ctx.query_maintainer->settings.page_size);
|
||||
std::vector<int64_t> shape;
|
||||
shape.push_back(page_count);
|
||||
block_index = torch::zeros(shape, torch::TensorOptions().dtype(torch::kInt32)).contiguous();
|
||||
auto ptr = reinterpret_cast<int32_t*>(block_index.data_ptr());
|
||||
block_index =
|
||||
torch::zeros(shape, torch::TensorOptions().dtype(torch::kInt32))
|
||||
.contiguous();
|
||||
auto ptr = reinterpret_cast<int32_t *>(block_index.data_ptr());
|
||||
auto vec_idx = kvc2_handle->get_gpu_block_idx();
|
||||
for (size_t i = 0; i < vec_idx.size(); i++) {
|
||||
ptr[i] = vec_idx[i];
|
||||
|
@ -765,7 +806,7 @@ struct FCFS_single_prefill : public QueryMaintainer {
|
|||
bool has_query_preparing = false;
|
||||
std::optional<EventPrepare> wait_done_prepare = std::nullopt;
|
||||
|
||||
std::set<Q> active_query; // on going queries for LLMs
|
||||
std::set<Q> active_query; // on going queries for LLMs
|
||||
|
||||
// interface all these are executed in a single thread
|
||||
void strategy_add_query(Q new_query) override {
|
||||
|
@ -774,71 +815,72 @@ struct FCFS_single_prefill : public QueryMaintainer {
|
|||
has_query_preparing = true;
|
||||
auto next_q = queue.front();
|
||||
queue.pop();
|
||||
event_loop_queue.enqueue(EventPrepare{next_q->id,true});
|
||||
event_loop_queue.enqueue(EventPrepare{next_q->id, true});
|
||||
}
|
||||
}
|
||||
|
||||
void strategy_update_query(const EventUpdateQuery& update) override {
|
||||
void strategy_update_query(const EventUpdateQuery &update) override {
|
||||
for (auto u : update) {
|
||||
auto& q = query_map[u.id];
|
||||
auto &q = query_map[u.id];
|
||||
if (q->plan_status == Query::Done) {
|
||||
active_query.erase(q);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void strategy_taken_batch(const EventTakenBatch& batch) override {
|
||||
for (auto& q : batch->query_ids) {
|
||||
void strategy_taken_batch(const EventTakenBatch &batch) override {
|
||||
for (auto &q : batch->query_ids) {
|
||||
if (query_map[q]->plan_status != Query::Done) {
|
||||
active_query.insert(query_map[q]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void strategy_prepare(const EventPrepare& prepare) override {
|
||||
if(prepare.first_try){
|
||||
auto& q = query_map[prepare.query_id];
|
||||
void strategy_prepare(const EventPrepare &prepare) override {
|
||||
if (prepare.first_try) {
|
||||
auto &q = query_map[prepare.query_id];
|
||||
q->to_status(Query::Preparing);
|
||||
}else{
|
||||
assert(wait_done_prepare.has_value()==false);
|
||||
} else {
|
||||
assert(wait_done_prepare.has_value() == false);
|
||||
wait_done_prepare = prepare;
|
||||
wait_done_prepare->first_try = true;
|
||||
}
|
||||
}
|
||||
|
||||
void strategy_prepared(const EventPrepared& prepared) override {
|
||||
void strategy_prepared(const EventPrepared &prepared) override {
|
||||
assert(prepared.ok);
|
||||
ready_queue.push(query_map[prepared.query_id]);
|
||||
if (queue.empty() == false) {
|
||||
auto next_q_prepare = queue.front();
|
||||
queue.pop();
|
||||
event_loop_queue.enqueue(EventPrepare{next_q_prepare->id,true});
|
||||
event_loop_queue.enqueue(EventPrepare{next_q_prepare->id, true});
|
||||
|
||||
} else {
|
||||
has_query_preparing = false;
|
||||
}
|
||||
}
|
||||
|
||||
void strategy_query_status(const EventQueryStatus& query_status) override{
|
||||
if(query_status.now_status==Query::Done){
|
||||
if(wait_done_prepare.has_value()){
|
||||
void strategy_query_status(const EventQueryStatus &query_status) override {
|
||||
if (query_status.now_status == Query::Done) {
|
||||
if (wait_done_prepare.has_value()) {
|
||||
event_loop_queue.enqueue(wait_done_prepare.value());
|
||||
wait_done_prepare = std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void strategy_schedule([[maybe_unused]] const EventSchedule& event, BatchQueryTodo* new_batch) override {
|
||||
void strategy_schedule([[maybe_unused]] const EventSchedule &event,
|
||||
BatchQueryTodo *new_batch) override {
|
||||
bool have_prefill = false;
|
||||
for (auto& q : active_query) {
|
||||
for (auto &q : active_query) {
|
||||
if (q->plan_status == Query::Prefill) {
|
||||
have_prefill = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (have_prefill == false && ready_queue.empty() == false && active_query.size() < settings.max_batch_size) {
|
||||
auto& next_q = ready_queue.front();
|
||||
if (have_prefill == false && ready_queue.empty() == false &&
|
||||
active_query.size() < settings.max_batch_size) {
|
||||
auto &next_q = ready_queue.front();
|
||||
ready_queue.pop();
|
||||
|
||||
SPDLOG_INFO("Active query {}", next_q->id);
|
||||
|
@ -847,7 +889,7 @@ struct FCFS_single_prefill : public QueryMaintainer {
|
|||
}
|
||||
if (active_query.empty() == false)
|
||||
SPDLOG_INFO("Active Query Size {}", active_query.size());
|
||||
for (auto& q : active_query) {
|
||||
for (auto &q : active_query) {
|
||||
q->debug();
|
||||
}
|
||||
gen_batch_query_todo(new_batch, active_query);
|
||||
|
@ -855,10 +897,11 @@ struct FCFS_single_prefill : public QueryMaintainer {
|
|||
};
|
||||
|
||||
struct FCFS : public FCFS_single_prefill {
|
||||
void strategy_schedule([[maybe_unused]] const EventSchedule& event, BatchQueryTodo* new_batch) override {
|
||||
void strategy_schedule([[maybe_unused]] const EventSchedule &event,
|
||||
BatchQueryTodo *new_batch) override {
|
||||
int prefill_count = 0;
|
||||
const int max_prefill_count = 2;
|
||||
for (auto& q : active_query) {
|
||||
for (auto &q : active_query) {
|
||||
if (q->plan_status == Query::Prefill) {
|
||||
prefill_count += 1;
|
||||
}
|
||||
|
@ -877,7 +920,7 @@ struct FCFS : public FCFS_single_prefill {
|
|||
if (active_query.empty() == false) {
|
||||
SPDLOG_DEBUG("Active Query Size {}", active_query.size());
|
||||
}
|
||||
for (auto& q : active_query) {
|
||||
for (auto &q : active_query) {
|
||||
q->debug();
|
||||
}
|
||||
gen_batch_query_todo(new_batch, active_query);
|
||||
|
@ -900,7 +943,8 @@ std::shared_ptr<Scheduler> create_scheduler(Settings settings) {
|
|||
}
|
||||
|
||||
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(SampleOptions, temperature, top_p);
|
||||
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(QueryAdd, query_token, query_length, estimated_length, sample_options, user_id,
|
||||
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(QueryAdd, query_token, query_length,
|
||||
estimated_length, sample_options, user_id,
|
||||
SLO_TTFT_ms, SLO_TBT_ms);
|
||||
|
||||
std::string QueryAdd::serialize() {
|
||||
|
@ -908,9 +952,9 @@ std::string QueryAdd::serialize() {
|
|||
return j.dump();
|
||||
}
|
||||
|
||||
QueryAdd QueryAdd::deserialize(const std::string& input) {
|
||||
QueryAdd QueryAdd::deserialize(const std::string &input) {
|
||||
json j = json::parse(input);
|
||||
return j.get<QueryAdd>();
|
||||
}
|
||||
|
||||
}; // namespace scheduler
|
||||
}; // namespace scheduler
|
||||
|
|
|
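Aside: since QueryAdd (de)serialization above is plain nlohmann round-tripping, usage reduces to the following (values invented):

#include <cassert>

scheduler::QueryAdd q;
q.query_token = {1, 2, 3};  // token ids
q.query_length = 3;
q.estimated_length = 64;
std::string wire = q.serialize();                    // JSON text
auto back = scheduler::QueryAdd::deserialize(wire);  // parse it back
assert(back.query_length == q.query_length);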
@ -1,10 +1,10 @@
|
|||
#pragma once
|
||||
#include <torch/torch.h>
|
||||
#include "model_config.h"
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
#include "model_config.h"
|
||||
|
||||
namespace scheduler {
|
||||
|
||||
|
@ -28,7 +28,9 @@ struct ModelSettings {
|
|||
double bytes_per_kv_cache_element;
|
||||
|
||||
inline size_t params_nbytes() { return params_count * bytes_per_params; }
|
||||
inline size_t bytes_per_token_kv_cache() { return bytes_per_kv_cache_element * num_k_heads * k_head_dim; }
|
||||
inline size_t bytes_per_token_kv_cache() {
|
||||
return bytes_per_kv_cache_element * num_k_heads * k_head_dim;
|
||||
}
|
||||
};
|
||||
|
||||
struct SampleOptions {
|
||||
|
@ -37,15 +39,16 @@ struct SampleOptions {
|
|||
};
|
||||
|
||||
struct Settings {
|
||||
// something is awkward here: kvc2 only uses model_name and quant_type to get model info.
|
||||
// something is awkward here: kvc2 only uses model_name and quant_type to get
|
||||
// model info.
|
||||
ModelName model_name;
|
||||
QuantType quant_type;
|
||||
// model_settings is ignored by kvc2
|
||||
ModelSettings model_settings;
|
||||
|
||||
size_t page_size = 256;  // how many tokens in a page
|
||||
std::vector<size_t> gpu_device_id; //
|
||||
size_t gpu_memory_size;  // memory size in bytes of each GPU
|
||||
size_t page_size = 256;            // how many tokens in a page
|
||||
std::vector<size_t> gpu_device_id; //
|
||||
size_t gpu_memory_size;            // memory size in bytes of each GPU
|
||||
double memory_utilization_percentage;
|
||||
|
||||
size_t max_batch_size = 256;
|
||||
|
@ -79,14 +82,16 @@ struct Settings {
|
|||
void auto_derive();
|
||||
};
|
||||
|
||||
using PrefillTask = std::tuple<QueryID, TokenLength, TokenLength>; // id, start, length
|
||||
using PrefillTask =
|
||||
std::tuple<QueryID, TokenLength, TokenLength>; // id, start, length
|
||||
|
||||
struct BatchQueryTodo {
|
||||
// query
|
||||
std::vector<QueryID> query_ids;
|
||||
std::vector<torch::Tensor> query_tokens;
|
||||
std::vector<TokenLength> query_lengths;
|
||||
std::vector<torch::Tensor> block_indexes; // (max_num_blocks_per_seq), dtype torch.int32.
|
||||
std::vector<torch::Tensor>
|
||||
block_indexes; // (max_num_blocks_per_seq), dtype torch.int32.
|
||||
std::optional<torch::Tensor> attn_masks;
|
||||
std::optional<torch::Tensor> rope_ranges;
|
||||
std::vector<SampleOptions> sample_options;
|
||||
|
@ -94,8 +99,10 @@ struct BatchQueryTodo {
|
|||
|
||||
// mini batches, adjacent two mini batches are executed together
|
||||
// tasks count must be <=2, because of flash infer attention
|
||||
std::vector<PrefillTask> prefill_mini_batches; // prefill minibatch only has 1 prefill
|
||||
std::vector<std::vector<QueryID>> decode_mini_batches; // decode minibatch has multiple decode
|
||||
std::vector<PrefillTask>
|
||||
prefill_mini_batches; // prefill minibatch only has 1 prefill
|
||||
std::vector<std::vector<QueryID>>
|
||||
decode_mini_batches; // decode minibatch has multiple decode
|
||||
|
||||
std::string debug();
|
||||
bool empty();
|
||||
|
@ -105,9 +112,9 @@ struct QueryUpdate {
|
|||
QueryID id;
|
||||
bool ok;
|
||||
bool is_prefill;
|
||||
bool decode_done; // no use for now
|
||||
TokenLength active_position; // the position where no kvcache now,
|
||||
// kvcache[active_position] == None
|
||||
bool decode_done; // no use for now
|
||||
TokenLength active_position; // the position where no kvcache now,
|
||||
// kvcache[active_position] == None
|
||||
|
||||
Token generated_token;
|
||||
|
||||
|
@ -117,8 +124,8 @@ struct QueryUpdate {
|
|||
using BatchQueryUpdate = std::vector<QueryUpdate>;
|
||||
|
||||
struct InferenceContext {
|
||||
std::vector<torch::Tensor> k_cache; // [gpu num] (layer_count, num blocks,
|
||||
// page size, kheadnum, head_dim)
|
||||
std::vector<torch::Tensor> k_cache; // [gpu num] (layer_count, num blocks,
|
||||
// page size, kheadnum, head_dim)
|
||||
std::vector<torch::Tensor> v_cache;
|
||||
};
|
||||
|
||||
|
@ -127,7 +134,7 @@ constexpr UserID NoUser = -1;
|
|||
const int MAX_SLO_TIME = 1e9;
|
||||
|
||||
struct QueryAdd {
|
||||
std::vector<Token> query_token; // int here
|
||||
std::vector<Token> query_token; // int here
|
||||
// torch::Tensor attn_mask;
|
||||
TokenLength query_length;
|
||||
TokenLength estimated_length;
|
||||
|
@ -141,11 +148,11 @@ struct QueryAdd {
|
|||
int SLO_TBT_ms = MAX_SLO_TIME;
|
||||
|
||||
std::string serialize();
|
||||
static QueryAdd deserialize(const std::string& input);
|
||||
static QueryAdd deserialize(const std::string &input);
|
||||
};
|
||||
|
||||
class Scheduler {
|
||||
public:
|
||||
public:
|
||||
virtual void init(Settings settings) = 0;
|
||||
|
||||
virtual void run() = 0;
|
||||
|
@ -156,7 +163,8 @@ class Scheduler {
|
|||
virtual void cancel_query(QueryID id) = 0;
|
||||
|
||||
// the inference loop calls this
|
||||
virtual std::shared_ptr<BatchQueryTodo> update_last_batch(BatchQueryUpdate updates) = 0;
|
||||
virtual std::shared_ptr<BatchQueryTodo>
|
||||
update_last_batch(BatchQueryUpdate updates) = 0;
|
||||
virtual InferenceContext get_inference_context() = 0;
|
||||
|
||||
virtual ~Scheduler() = default;
|
||||
|
@ -164,4 +172,4 @@ class Scheduler {
|
|||
|
||||
std::shared_ptr<Scheduler> create_scheduler(Settings settings);
|
||||
|
||||
}; // namespace scheduler
|
||||
}; // namespace scheduler
|
|
@ -1,7 +1,6 @@
|
|||
#include <type_traits>
|
||||
|
||||
template <typename T, typename U>
|
||||
T div_up(T x, U by) {
|
||||
template <typename T, typename U> T div_up(T x, U by) {
|
||||
static_assert(std::is_integral_v<T>);
|
||||
static_assert(std::is_integral_v<U>);
|
||||
return (x + by - 1) / by;
|
||||
|
|
|
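Aside: div_up above is ordinary ceiling division; Query::after_load uses it to turn an estimated token length into a page count. A quick check (values invented):

#include <cassert>

int main() {
  assert(div_up(10, 4) == 3);     // ceil(10 / 4)
  assert(div_up(512, 256) == 2);  // exact fit
  assert(div_up(513, 256) == 3);  // one extra page for the spilled token
}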
@ -1,28 +1,35 @@
|
|||
#include <atomic>
|
||||
|
||||
template <typename T>
|
||||
struct AtomicPtrWithFlag {
|
||||
template <typename T> struct AtomicPtrWithFlag {
|
||||
constexpr static uint64_t mask = 1ull << 63;
|
||||
std::atomic_uint64_t ptr = 0;
|
||||
|
||||
std::pair<T*, bool> load(std::memory_order order = std::memory_order_seq_cst) {
|
||||
std::pair<T *, bool>
|
||||
load(std::memory_order order = std::memory_order_seq_cst) {
|
||||
uint64_t val = ptr.load(order);
|
||||
return {reinterpret_cast<T*>(val & (~mask)), val & mask};
|
||||
return {reinterpret_cast<T *>(val & (~mask)), val & mask};
|
||||
}
|
||||
|
||||
void store(T* p, bool flag, std::memory_order order = std::memory_order_seq_cst) {
|
||||
void store(T *p, bool flag,
|
||||
std::memory_order order = std::memory_order_seq_cst) {
|
||||
ptr.store(reinterpret_cast<uint64_t>(p) | (flag ? mask : 0), order);
|
||||
}
|
||||
|
||||
std::pair<T*, bool> exchange(T* p, bool flag, std::memory_order order = std::memory_order_seq_cst) {
|
||||
uint64_t val = ptr.exchange(reinterpret_cast<uint64_t>(p) | (flag ? mask : 0), order);
|
||||
return {reinterpret_cast<T*>(val & (~mask)), val & mask};
|
||||
std::pair<T *, bool>
|
||||
exchange(T *p, bool flag,
|
||||
std::memory_order order = std::memory_order_seq_cst) {
|
||||
uint64_t val =
|
||||
ptr.exchange(reinterpret_cast<uint64_t>(p) | (flag ? mask : 0), order);
|
||||
return {reinterpret_cast<T *>(val & (~mask)), val & mask};
|
||||
}
|
||||
|
||||
std::pair<T*, bool> touch_load(std::memory_order order = std::memory_order_seq_cst) {
|
||||
std::pair<T *, bool>
|
||||
touch_load(std::memory_order order = std::memory_order_seq_cst) {
|
||||
uint64_t val = ptr.fetch_and(~mask, order);
|
||||
return {reinterpret_cast<T*>(val & (~mask)), val & mask};
|
||||
return {reinterpret_cast<T *>(val & (~mask)), val & mask};
|
||||
}
|
||||
|
||||
bool check_flag(std::memory_order order = std::memory_order_seq_cst) { return ptr.load(order) & mask; }
|
||||
bool check_flag(std::memory_order order = std::memory_order_seq_cst) {
|
||||
return ptr.load(order) & mask;
|
||||
}
|
||||
};
|
||||
|
|
|
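Aside: the point of the tag bit above is that a single 64-bit word atomically carries both a pointer and a flag. A toy usage (invented; relies on user-space addresses leaving bit 63 clear, as the class assumes):

AtomicPtrWithFlag<int> slot;
int value = 42;
slot.store(&value, /*flag=*/true);      // publish pointer and raise flag
auto [p, flagged] = slot.touch_load();  // read and atomically clear the flag
// Here p == &value and flagged == true; a second touch_load reports false.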
@ -19,7 +19,7 @@ namespace csv {
|
|||
* @param line The CSV line to parse.
|
||||
* @return A vector of strings, each representing a field in the CSV line.
|
||||
*/
|
||||
inline std::vector<std::string> parse_csv_line(const std::string& line) {
|
||||
inline std::vector<std::string> parse_csv_line(const std::string &line) {
|
||||
std::vector<std::string> result;
|
||||
std::string field;
|
||||
bool in_quotes = false;
|
||||
|
@ -57,7 +57,8 @@ inline std::vector<std::string> parse_csv_line(const std::string& line) {
|
|||
* @return A vector of pairs, each containing a column name and a vector of data
|
||||
* for that column.
|
||||
*/
|
||||
inline std::vector<std::pair<std::string, std::vector<std::string>>> read_csv(const std::string& filename) {
|
||||
inline std::vector<std::pair<std::string, std::vector<std::string>>>
|
||||
read_csv(const std::string &filename) {
|
||||
std::cout << "Reading CSV file: " << filename << std::endl;
|
||||
// Open the file
|
||||
std::ifstream file(filename);
|
||||
|
@ -72,7 +73,7 @@ inline std::vector<std::pair<std::string, std::vector<std::string>>> read_csv(co
|
|||
|
||||
// Prepare the result vector with column names
|
||||
std::vector<std::pair<std::string, std::vector<std::string>>> result;
|
||||
for (const auto& name : column_names) {
|
||||
for (const auto &name : column_names) {
|
||||
result.emplace_back(name, std::vector<std::string>());
|
||||
}
|
||||
|
||||
|
@ -84,7 +85,7 @@ inline std::vector<std::pair<std::string, std::vector<std::string>>> read_csv(co
|
|||
// Determine the number of threads to use
|
||||
unsigned int num_threads = std::thread::hardware_concurrency();
|
||||
if (num_threads == 0)
|
||||
num_threads = 4; // Default to 4 threads if hardware_concurrency returns 0
|
||||
num_threads = 4; // Default to 4 threads if hardware_concurrency returns 0
|
||||
|
||||
// Calculate chunk start positions based on content size
|
||||
std::vector<size_t> chunk_starts;
|
||||
|
@ -100,14 +101,15 @@ inline std::vector<std::pair<std::string, std::vector<std::string>>> read_csv(co
|
|||
++pos;
|
||||
}
|
||||
if (pos < content_size) {
|
||||
++pos; // Skip the newline character
|
||||
++pos; // Skip the newline character
|
||||
}
|
||||
chunk_starts.push_back(pos);
|
||||
}
|
||||
chunk_starts.push_back(content_size);
|
||||
|
||||
// Create threads to parse each chunk
|
||||
std::vector<std::vector<std::vector<std::string>>> thread_results(num_threads);
|
||||
std::vector<std::vector<std::vector<std::string>>> thread_results(
|
||||
num_threads);
|
||||
std::vector<std::thread> threads;
|
||||
|
||||
for (unsigned int i = 0; i < num_threads; ++i) {
|
||||
|
@ -133,13 +135,13 @@ inline std::vector<std::pair<std::string, std::vector<std::string>>> read_csv(co
|
|||
}
|
||||
|
||||
// Wait for all threads to finish
|
||||
for (auto& t : threads) {
|
||||
for (auto &t : threads) {
|
||||
t.join();
|
||||
}
|
||||
|
||||
// Combine the results from all threads into the final result
|
||||
for (const auto& local_result : thread_results) {
|
||||
for (const auto& row : local_result) {
|
||||
for (const auto &local_result : thread_results) {
|
||||
for (const auto &row : local_result) {
|
||||
for (size_t i = 0; i < row.size(); ++i) {
|
||||
if (i < result.size()) {
|
||||
result[i].second.push_back(row[i]);
|
||||
|
@ -158,8 +160,9 @@ inline std::vector<std::pair<std::string, std::vector<std::string>>> read_csv(co
|
|||
* @param data A vector of pairs, each containing a column name and a vector of
|
||||
* data for that column.
|
||||
*/
|
||||
inline void write_csv(const std::string& filename,
|
||||
const std::vector<std::pair<std::string, std::vector<std::string>>>& data) {
|
||||
inline void write_csv(
|
||||
const std::string &filename,
|
||||
const std::vector<std::pair<std::string, std::vector<std::string>>> &data) {
|
||||
std::cout << "Writing CSV file: " << filename << std::endl;
|
||||
|
||||
// Open the file for writing
|
||||
|
@ -170,10 +173,10 @@ inline void write_csv(const std::string& filename,
|
|||
|
||||
// Check that all columns have the same number of rows
|
||||
if (data.empty()) {
|
||||
return; // Nothing to write
|
||||
return; // Nothing to write
|
||||
}
|
||||
size_t num_rows = data[0].second.size();
|
||||
for (const auto& column : data) {
|
||||
for (const auto &column : data) {
|
||||
if (column.second.size() != num_rows) {
|
||||
throw std::runtime_error("All columns must have the same number of rows");
|
||||
}
|
||||
|
@ -191,7 +194,7 @@ inline void write_csv(const std::string& filename,
|
|||
// Write the data rows
|
||||
for (size_t row = 0; row < num_rows; ++row) {
|
||||
for (size_t col = 0; col < data.size(); ++col) {
|
||||
const std::string& field = data[col].second[row];
|
||||
const std::string &field = data[col].second[row];
|
||||
// Handle CSV escaping
|
||||
std::string escaped_field = field;
|
||||
bool needs_quotes = false;
|
||||
|
@ -204,7 +207,8 @@ inline void write_csv(const std::string& filename,
|
|||
pos += 2;
|
||||
}
|
||||
}
|
||||
if (escaped_field.find(',') != std::string::npos || escaped_field.find('\n') != std::string::npos) {
|
||||
if (escaped_field.find(',') != std::string::npos ||
|
||||
escaped_field.find('\n') != std::string::npos) {
|
||||
needs_quotes = true;
|
||||
}
|
||||
if (needs_quotes) {
|
||||
|
@ -220,6 +224,6 @@ inline void write_csv(const std::string& filename,
|
|||
}
|
||||
}
|
||||
|
||||
} // namespace csv
|
||||
} // namespace csv
|
||||
|
||||
#endif // CSV_READER_HPP
|
||||
#endif // CSV_READER_HPP
|
||||
|
|
|
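Aside (sketch; file names invented, and the include guard suggests the header is csv_reader.hpp):

#include <iostream>
#include "csv_reader.hpp"

int main() {
  // Columns come back as (name, rows) pairs of strings.
  auto data = csv::read_csv("stats.csv");
  for (auto &[name, column] : data)
    std::cout << name << ": " << column.size() << " rows\n";
  csv::write_csv("stats_copy.csv", data);  // re-quotes/escapes on write
}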
@ -2,15 +2,14 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
template <typename T>
|
||||
std::string format_vector(const std::vector<T>& v) {
|
||||
template <typename T> std::string format_vector(const std::vector<T> &v) {
|
||||
std::ostringstream oss;
|
||||
if (v.empty())
|
||||
return "[]";
|
||||
for (size_t i = 0; i < v.size(); ++i) {
|
||||
oss << v[i];
|
||||
if (i < v.size() - 1)
|
||||
oss << ", "; // 逗号分隔
|
||||
oss << ", "; // 逗号分隔
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
|
|
@ -4,32 +4,31 @@
|
|||
#include <optional>
|
||||
#include <semaphore>
|
||||
|
||||
template <typename T>
|
||||
class MPSCQueue {
|
||||
template <typename T> class MPSCQueue {
|
||||
struct Node {
|
||||
T data;
|
||||
std::atomic<Node*> next;
|
||||
std::atomic<Node *> next;
|
||||
|
||||
Node() : next(nullptr) {}
|
||||
Node(T data_) : data(std::move(data_)), next(nullptr) {}
|
||||
};
|
||||
|
||||
std::atomic<Node*> head;
|
||||
Node* tail;
|
||||
std::atomic<Node *> head;
|
||||
Node *tail;
|
||||
|
||||
public:
|
||||
public:
|
||||
std::atomic_size_t enqueue_count = 0;
|
||||
size_t dequeue_count = 0;
|
||||
MPSCQueue() {
|
||||
Node* dummy = new Node();
|
||||
Node *dummy = new Node();
|
||||
head.store(dummy, std::memory_order_seq_cst);
|
||||
tail = dummy;
|
||||
}
|
||||
|
||||
~MPSCQueue() {
|
||||
Node* node = tail;
|
||||
Node *node = tail;
|
||||
while (node) {
|
||||
Node* next = node->next.load(std::memory_order_seq_cst);
|
||||
Node *next = node->next.load(std::memory_order_seq_cst);
|
||||
delete node;
|
||||
node = next;
|
||||
}
|
||||
|
@ -38,14 +37,14 @@ class MPSCQueue {
|
|||
// Called by producers
|
||||
void enqueue(T data) {
|
||||
enqueue_count.fetch_add(1);
|
||||
Node* node = new Node(std::move(data));
|
||||
Node* prev_head = head.exchange(node, std::memory_order_seq_cst);
|
||||
Node *node = new Node(std::move(data));
|
||||
Node *prev_head = head.exchange(node, std::memory_order_seq_cst);
|
||||
prev_head->next.store(node, std::memory_order_seq_cst);
|
||||
}
|
||||
|
||||
// Called by the consumer
|
||||
std::optional<T> dequeue() {
|
||||
Node* next = tail->next.load(std::memory_order_seq_cst);
|
||||
Node *next = tail->next.load(std::memory_order_seq_cst);
|
||||
if (next) {
|
||||
T res = std::move(next->data);
|
||||
delete tail;
|
||||
|
@@ -59,16 +58,16 @@ class MPSCQueue {
   size_t size() { return enqueue_count.load() - dequeue_count; }
 };
 
-template <typename T>
-class MPSCQueueConsumerLock {
+template <typename T> class MPSCQueueConsumerLock {
   MPSCQueue<T> queue;
   std::counting_semaphore<> sema{0};
 
- public:
+public:
   void enqueue(T data) {
     queue.enqueue(std::move(data));
-    // std::atomic_thread_fence(std::memory_order_seq_cst);// Inserting this because the memory order might be wrong, I
-    // am also not that sure about this.
+    // std::atomic_thread_fence(std::memory_order_seq_cst);// Inserting this
+    // because the memory order might be wrong, I am also not that sure about
+    // this.
     sema.release();
   }
 
@@ -76,8 +75,10 @@ class MPSCQueueConsumerLock {
     auto re = queue.dequeue();
     if (re.has_value()) {
       while (sema.try_acquire() == false) {
-        std::cerr << __FILE__ << ":" << __FUNCTION__ << " sema try acquire should be success, retrying, please check"
-                  << std::endl;
+        std::cerr
+            << __FILE__ << ":" << __FUNCTION__
+            << " sema try acquire should be success, retrying, please check"
+            << std::endl;
         // assert(false);
       }
       return re.value();
@@ -91,8 +92,10 @@ class MPSCQueueConsumerLock {
     auto re = queue.dequeue();
     if (re.has_value()) {
       while (sema.try_acquire() == false) {
-        std::cerr << __FILE__ << ":" << __FUNCTION__ << " sema try acquire should be success, retrying, please check"
-                  << std::endl;
+        std::cerr
+            << __FILE__ << ":" << __FUNCTION__
+            << " sema try acquire should be success, retrying, please check"
+            << std::endl;
         // assert(false);
       }
       return re.value();

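The hunks above rewrap the hot path of the queue, so a note on the shape of the code: MPSCQueue is the classic multi-producer single-consumer linked list (producers swing the atomic head, the single consumer walks the plain tail), and MPSCQueueConsumerLock adds a counting semaphore so the consumer can block instead of spin. A minimal driver sketch, assuming the header is reachable as mpsc.hpp and that the blocking consumer entry point is the dequeue() whose body appears in the hunks:

#include "mpsc.hpp" // assumed header name for the queue above

#include <cstdio>
#include <thread>
#include <vector>

int main() {
  MPSCQueueConsumerLock<int> q;

  // Any number of producers is fine: enqueue() only touches the atomic head.
  std::vector<std::thread> producers;
  for (int t = 0; t < 4; ++t)
    producers.emplace_back([&q, t] {
      for (int i = 0; i < 1000; ++i)
        q.enqueue(t * 1000 + i);
    });

  // Exactly one thread may consume: dequeue() frees the old tail node.
  long sum = 0;
  for (int i = 0; i < 4000; ++i)
    sum += q.dequeue();

  for (auto &p : producers)
    p.join();
  std::printf("drained, sum = %ld\n", sum);
  return 0;
}
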
@@ -7,59 +7,71 @@
 #include <unordered_map>
 
 class Statistics {
- public:
+public:
   // Increment the counter for a given key by a specified value (default is 1)
-  void increment_counter(const std::string& key, int64_t value = 1) { counters_[key] += value; }
+  void increment_counter(const std::string &key, int64_t value = 1) {
+    counters_[key] += value;
+  }
 
-  int64_t& get_counter(const std::string& key) { return counters_[key]; }
+  int64_t &get_counter(const std::string &key) { return counters_[key]; }
 
   // Start the timer for a given key
-  void start_timer(const std::string& key) { active_timers_[key] = std::chrono::high_resolution_clock::now(); }
+  void start_timer(const std::string &key) {
+    active_timers_[key] = std::chrono::high_resolution_clock::now();
+  }
 
   // Stop the timer for a given key and update the total time and count
-  void stop_timer(const std::string& key) {
+  void stop_timer(const std::string &key) {
     auto start_it = active_timers_.find(key);
     if (start_it != active_timers_.end()) {
-      auto duration = std::chrono::high_resolution_clock::now() - start_it->second;
+      auto duration =
+          std::chrono::high_resolution_clock::now() - start_it->second;
       timings_[key].total_time += duration;
       timings_[key].count += 1;
       active_timers_.erase(start_it);
     } else {
       // Handle error: stop_timer called without a matching start_timer
-      std::cerr << "Warning: stop_timer called for key '" << key << "' without a matching start_timer.\n";
+      std::cerr << "Warning: stop_timer called for key '" << key
+                << "' without a matching start_timer.\n";
     }
   }
 
   // Print out the collected statistical information
   void report() const {
     std::cout << "Counters:\n";
-    for (const auto& kv : counters_) {
+    for (const auto &kv : counters_) {
       std::cout << "  " << kv.first << ": " << kv.second << "\n";
     }
     std::cout << "\nTimers:\n";
-    for (const auto& kv : timings_) {
+    for (const auto &kv : timings_) {
       std::cout << "  " << kv.first << ": count = " << kv.second.count
                 << ", total_time = " << kv.second.total_time.count() << "s"
-                << ", average_time = " << (kv.second.count > 0 ? kv.second.total_time.count() / kv.second.count : 0)
+                << ", average_time = "
+                << (kv.second.count > 0
+                        ? kv.second.total_time.count() / kv.second.count
+                        : 0)
                 << "s\n";
     }
   }
 
- private:
+private:
   // Mapping from key to counter
   std::unordered_map<std::string, int64_t> counters_;
 
   // Struct to hold timing information for a key
   struct TimingInfo {
     int64_t count = 0;
-    std::chrono::duration<double> total_time = std::chrono::duration<double>::zero();
+    std::chrono::duration<double> total_time =
+        std::chrono::duration<double>::zero();
   };
 
   // Mapping from key to timing information
   std::unordered_map<std::string, TimingInfo> timings_;
 
   // Mapping from key to the start time of active timers
-  std::unordered_map<std::string, std::chrono::high_resolution_clock::time_point> active_timers_;
+  std::unordered_map<std::string,
+                     std::chrono::high_resolution_clock::time_point>
+      active_timers_;
 };
 
-#endif  // STATISTICS_HPP
+#endif // STATISTICS_HPP

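Since the reflow above touches every member of Statistics, a compact usage sketch for orientation; it assumes the header is named statistics.hpp to match its include guard:

#include "statistics.hpp" // assumed file name, taken from the include guard

int main() {
  Statistics stats;

  stats.start_timer("prefill");
  stats.increment_counter("tokens", 128);
  stats.stop_timer("prefill");

  // A stop_timer without a matching start_timer only prints the warning
  // shown in the hunk above; it does not throw.
  stats.stop_timer("decode");

  stats.report(); // dumps counters, then per-key count/total/average times
  return 0;
}
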
@@ -1,4 +1,5 @@
 #pragma once
+#include "readable_number.hpp"
 #include <cassert>
 #include <chrono>
 #include <iomanip>

@@ -6,7 +7,6 @@
 #include <map>
 #include <sstream>
 #include <string>
-#include "readable_number.hpp"
 
 inline std::string doubleToStringR2(double value) {
   std::stringstream stream;
@@ -15,7 +15,7 @@ inline std::string doubleToStringR2(double value) {
 }
 
 class Timer {
- public:
+public:
   std::string name;
   bool tmp_timer = false;
 
@@ -49,10 +49,14 @@ class Timer {
       endTime = m_endTime;
     }
 
-    return std::chrono::duration_cast<std::chrono::nanoseconds>(endTime - m_startTime).count();
+    return std::chrono::duration_cast<std::chrono::nanoseconds>(endTime -
+                                                                m_startTime)
+        .count();
   }
 
-  void printElapsedMilliseconds() { std::cout << elapsedNs() / 1e6 << " ms" << std::endl; }
+  void printElapsedMilliseconds() {
+    std::cout << elapsedNs() / 1e6 << " ms" << std::endl;
+  }
 
   static std::string ns_to_string(double duration) {
     auto nano_sec = duration;
@@ -100,13 +104,13 @@ class Timer {
     return readable_number(ops) + "op/s";
   }
 
-  void merge(Timer& other) {
+  void merge(Timer &other) {
     assert(m_isRunning == false);
     assert(other.m_isRunning == false);
     m_runningNs += other.runningTimeNs();
   }
 
- private:
+private:
   std::chrono::time_point<std::chrono::high_resolution_clock> m_startTime;
   std::chrono::time_point<std::chrono::high_resolution_clock> m_endTime;
   bool m_isRunning = false;
@@ -114,14 +118,14 @@ class Timer {
 };
 
 class Counter {
- public:
+public:
   Counter() {}
 
   std::map<std::string, size_t> counters;
 
-  void inc(const char* name, size_t num) { counters[name] += num; };
+  void inc(const char *name, size_t num) { counters[name] += num; };
   void print() {
-    for (auto& p : counters) {
+    for (auto &p : counters) {
       std::cout << p.first << " : " << p.second << std::endl;
     }
   };

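A short driver for the two helpers above, as a sketch only: Counter's whole interface is visible in the hunk, while Timer's start/stop entry points sit outside this diff, so only the members shown above are exercised:

#include "timer.hpp" // assumed header name for the Timer/Counter utilities

int main() {
  // Counter: plain map-backed event counting, interface shown above.
  Counter c;
  c.inc("io_read", 4096);
  c.inc("io_read", 4096);
  c.print(); // prints "io_read : 8192"

  // Timer: merge() asserts that neither timer is running, so both are
  // left un-started here; how a timer is armed is not part of this diff.
  Timer a, b;
  a.name = "io";
  a.merge(b); // accumulates b's running time (zero here) into a
  return 0;
}
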
@@ -1,122 +0,0 @@
-{
-  "DeepSeek-Coder-V2-Instruct": {
-    "hidden_size": 5120,
-    "intermediate_size": 12288,
-    "max_position_embeddings": 163840,
-    "model_type": "deepseek_v2",
-    "num_attention_heads": 128,
-    "num_hidden_layers": 60,
-    "num_key_value_heads": 128,
-    "vocab_size": 102400
-  },
-  "DeepSeek-R1": {
-    "hidden_size": 7168,
-    "intermediate_size": 18432,
-    "max_position_embeddings": 163840,
-    "model_type": "deepseek_v3",
-    "num_attention_heads": 128,
-    "num_hidden_layers": 61,
-    "num_key_value_heads": 128,
-    "vocab_size": 129280
-  },
-  "DeepSeek-V2-Lite-Chat": {
-    "hidden_size": 2048,
-    "intermediate_size": 10944,
-    "max_position_embeddings": 163840,
-    "model_type": "deepseek_v2",
-    "num_attention_heads": 16,
-    "num_hidden_layers": 27,
-    "num_key_value_heads": 16,
-    "vocab_size": 102400
-  },
-  "DeepSeek-V3": {
-    "hidden_size": 7168,
-    "intermediate_size": 18432,
-    "max_position_embeddings": 163840,
-    "model_type": "deepseek_v3",
-    "num_attention_heads": 128,
-    "num_hidden_layers": 3,
-    "num_key_value_heads": 128,
-    "vocab_size": 129280
-  },
-  "DeepSeek-V3-bf16": {
-    "hidden_size": 7168,
-    "intermediate_size": 18432,
-    "max_position_embeddings": 163840,
-    "model_type": "deepseek_v3",
-    "num_attention_heads": 128,
-    "num_hidden_layers": 61,
-    "num_key_value_heads": 128,
-    "vocab_size": 129280
-  },
-  "LLaMA-2-7B-32K": {
-    "hidden_size": 4096,
-    "intermediate_size": 11008,
-    "max_position_embeddings": 32768,
-    "model_type": "llama",
-    "num_attention_heads": 32,
-    "num_hidden_layers": 32,
-    "num_key_value_heads": 32,
-    "vocab_size": 32000
-  },
-  "Moonlight-16B-A3B-Instruct": {
-    "hidden_size": 2048,
-    "intermediate_size": 11264,
-    "max_position_embeddings": 8192,
-    "model_type": "deepseek_v3",
-    "num_attention_heads": 16,
-    "num_hidden_layers": 27,
-    "num_key_value_heads": 16,
-    "vocab_size": 163840
-  },
-  "Qwen2.5-32B-Instruct": {
-    "hidden_size": 5120,
-    "intermediate_size": 27648,
-    "max_position_embeddings": 32768,
-    "model_type": "qwen2",
-    "num_attention_heads": 40,
-    "num_hidden_layers": 64,
-    "num_key_value_heads": 8,
-    "vocab_size": 152064
-  },
-  "Qwen2.5-32B-Instruct-GPTQ-Int4": {
-    "hidden_size": 5120,
-    "intermediate_size": 27648,
-    "max_position_embeddings": 32768,
-    "model_type": "qwen2",
-    "num_attention_heads": 40,
-    "num_hidden_layers": 64,
-    "num_key_value_heads": 8,
-    "vocab_size": 152064
-  },
-  "Qwen2.5-7B-Instruct": {
-    "hidden_size": 3584,
-    "intermediate_size": 18944,
-    "max_position_embeddings": 32768,
-    "model_type": "qwen2",
-    "num_attention_heads": 28,
-    "num_hidden_layers": 28,
-    "num_key_value_heads": 4,
-    "vocab_size": 152064
-  },
-  "Qwen2.5-7B-Instruct-GPTQ-Int4": {
-    "hidden_size": 3584,
-    "intermediate_size": 18944,
-    "max_position_embeddings": 32768,
-    "model_type": "qwen2",
-    "num_attention_heads": 28,
-    "num_hidden_layers": 28,
-    "num_key_value_heads": 4,
-    "vocab_size": 152064
-  },
-  "qwen2-72b-instruct": {
-    "hidden_size": 8192,
-    "intermediate_size": 29568,
-    "max_position_embeddings": 32768,
-    "model_type": "qwen2",
-    "num_attention_heads": 64,
-    "num_hidden_layers": 80,
-    "num_key_value_heads": 8,
-    "vocab_size": 152064
-  }
-}

@@ -1,57 +0,0 @@
-{
-  "BF16": {
-    "block_element_count": 1,
-    "block_element_size": 2,
-    "bytes_per_element": 2.0,
-    "can_be_used_as_vector": true,
-    "has_min": false,
-    "has_scale": false,
-    "name": "BF16",
-    "reference": "",
-    "type_of_dot_vector": "BF16"
-  },
-  "FP16": {
-    "block_element_count": 1,
-    "block_element_size": 2,
-    "bytes_per_element": 2.0,
-    "can_be_used_as_vector": true,
-    "has_min": false,
-    "has_scale": false,
-    "name": "FP16",
-    "reference": "",
-    "type_of_dot_vector": "FP16"
-  },
-  "FP32": {
-    "block_element_count": 1,
-    "block_element_size": 4,
-    "bytes_per_element": 4.0,
-    "can_be_used_as_vector": true,
-    "has_min": false,
-    "has_scale": false,
-    "name": "FP32",
-    "reference": "",
-    "type_of_dot_vector": "FP32"
-  },
-  "Q4_0": {
-    "block_element_count": 32,
-    "block_element_size": 18,
-    "bytes_per_element": 0.5625,
-    "can_be_used_as_vector": false,
-    "has_min": false,
-    "has_scale": true,
-    "name": "Q4_0",
-    "reference": "https://huggingface.co/docs/hub/gguf",
-    "type_of_dot_vector": "Q8_0"
-  },
-  "Q8_0": {
-    "block_element_count": 32,
-    "block_element_size": 34,
-    "bytes_per_element": 1.0625,
-    "can_be_used_as_vector": true,
-    "has_min": false,
-    "has_scale": true,
-    "name": "Q8_0",
-    "reference": "https://huggingface.co/docs/hub/gguf",
-    "type_of_dot_vector": "Q8_0"
-  }
-}

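The deleted quantization table was internally consistent: bytes_per_element always equals block_element_size / block_element_count. For example, Q4_0 packs 32 four-bit weights into 16 bytes plus a 2-byte fp16 scale, giving 18 / 32 = 0.5625. A small sanity check over the removed values:

#include <cassert>

int main() {
  // Values copied from the removed quant_configs entries.
  struct QuantType {
    int block_element_count;
    int block_element_size;
    double bytes_per_element;
  };
  const QuantType types[] = {
      {1, 2, 2.0},      // BF16 / FP16
      {1, 4, 4.0},      // FP32
      {32, 18, 0.5625}, // Q4_0: 32 x 4-bit (16 B) + 2 B fp16 scale
      {32, 34, 1.0625}, // Q8_0: 32 x 8-bit (32 B) + 2 B fp16 scale
  };
  for (const auto &t : types)
    assert(static_cast<double>(t.block_element_size) / t.block_element_count ==
           t.bytes_per_element);
  return 0;
}
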
@@ -70,6 +70,9 @@ class ArgumentParser:
         parser.add_argument("--batch_size", type=int, default=self.cfg.batch_size)
         parser.add_argument("--cache_lens", type=int, default=self.cfg.cache_lens)
 
+        # kvc2 config
+        parser.add_argument("--kvc2_config_dir", type=str, default=self.cfg.kvc2_config_dir)
+
         # log configs
         # log level: debug, info, warn, error, crit
         parser.add_argument("--log_dir", type=str, default=self.cfg.log_dir)
@@ -7,9 +7,7 @@ import sys, os
 import yaml, json
 from time import sleep
 
 current_dir = os.path.dirname(__file__)
-# sched_path = os.path.abspath(os.path.join(current_dir, '../../../build/balance_serve/sched'))
-# sys.path.insert(0, sched_path)
 
 import sched_ext
 from transformers import AutoConfig

@@ -52,8 +50,7 @@ def create_sched_settings(args):
     settings.v_cache_on = False
 
     settings.kvc2_root_path = '/mnt/data/persist-kvc'
-    settings.kvc2_config_path = os.path.join(current_dir, "..", "..", "configs")
-    print(os.path.join(current_dir, "..", "..", "configs"))
+    settings.kvc2_config_path = args.kvc2_config_dir
     settings.memory_pool_size_GB = args.cpu_memory_size_GB
     settings.evict_count = 40
     settings.kvc2_metrics_port = args.kvc2_metrics_port
@@ -34,12 +34,15 @@ class Config(metaclass=Singleton):
 
     user_path: str = os.path.expanduser("~")
     localstore_path: str = os.path.join(user_path, ".ktransformers")
+    kvc2_config_dir = os.path.join(localstore_path, "kvc2")
     config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME)
     if not os.path.exists(config_yaml):
         print(f"Can't find config file, {config_yaml}")
         exit(-1)
     if not os.path.exists(localstore_path):
         os.mkdir(localstore_path)
+    if not os.path.exists(kvc2_config_dir):
+        os.mkdir(kvc2_config_dir)
     if not os.path.exists(config_path):
         shutil.copyfile(config_yaml, config_path)
     with open(config_path, "r", encoding="utf-8") as fp:

@@ -62,10 +65,13 @@ class Config(metaclass=Singleton):
         self.localstore_path: str = os.path.join(self.user_path, ".ktransformers")
         # log configs
         self.log_dir = os.path.join(self.localstore_path, cfg["log"]["dir"])
+        if not os.path.exists(self.log_dir):
+            os.mkdir(self.log_dir)
         self.log_file = cfg["log"]["file"]
         self.log_level = cfg["log"]["level"]
         self.backup_count = cfg["log"]["backup_count"]
 
+        self.kvc2_config_dir = os.path.join(self.localstore_path, "kvc2")
         # server configs
         self.server: dict = cfg.get("server", {})
         self.server_ip = self.server.get("ip", "0.0.0.0")