Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-07 21:19:51 +00:00.

Commit d3fae09252: 13 changed files with 465 additions and 55 deletions.
.gitmodules (vendored, 3 changes)
@@ -17,3 +17,6 @@
[submodule "third_party/prometheus-cpp"]
	path = third_party/prometheus-cpp
	url = https://github.com/jupp0r/prometheus-cpp
[submodule "third_party/PhotonLibOS"]
	path = third_party/PhotonLibOS
	url = https://github.com/alibaba/PhotonLibOS.git
@@ -13,9 +13,27 @@ set(CMAKE_CXX_STANDARD 20)
# set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fPIC")
# set(CMAKE_BUILD_TYPE "Debug")
set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -fPIC")
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI})
set(CMAKE_BUILD_TYPE "Release")

if(NOT DEFINED _GLIBCXX_USE_CXX11_ABI)
    find_package(Python3 REQUIRED COMPONENTS Interpreter)

    execute_process(
        COMMAND ${Python3_EXECUTABLE} -c
                "import torch; print('1' if torch.compiled_with_cxx11_abi() else '0')"
        OUTPUT_VARIABLE ABI_FLAG
        OUTPUT_STRIP_TRAILING_WHITESPACE
    )

    set(_GLIBCXX_USE_CXX11_ABI ${ABI_FLAG} CACHE STRING "C++11 ABI setting from PyTorch" FORCE)
endif()

# Pass the value to the compiler whether it was auto-detected or not
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI})

message(STATUS "_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI}")

file(GLOB_RECURSE FMT_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.h")

add_custom_target(
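This ABI-detection block (repeated in the next CMakeLists hunk) asks the installed torch which libstdc++ ABI it was built with and forwards that via `_GLIBCXX_USE_CXX11_ABI`. A minimal C++ sketch of why the two values must agree; the demo `main` is illustrative, only the macro and its effect are real:

```cpp
// _GLIBCXX_USE_CXX11_ABI must be defined before the first libstdc++ header
// and must match the value PyTorch was compiled with: it selects between the
// pre-C++11 std::string and the std::__cxx11::basic_string layout, so a
// mismatch across the extension/libtorch boundary shows up as undefined
// symbol errors at link or load time.
#define _GLIBCXX_USE_CXX11_ABI 0  // example value; in the build it is the detected ABI_FLAG
#include <iostream>
#include <string>

int main() {
  std::string s = "abi check";
  std::cout << s << " built with _GLIBCXX_USE_CXX11_ABI="
            << _GLIBCXX_USE_CXX11_ABI << "\n";
  return 0;
}
```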

@@ -11,13 +11,30 @@ set(CMAKE_CXX_STANDARD 20)
# set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -Wpedantic -g -fsanitize=address")
# set(CMAKE_CXX_FLAGS "-march=native -Wall -Wextra -Wpedantic -g")
# set(CMAKE_CXX_FLAGS "-fPIC -O3 -ffast-math -march=native -Wall -Wextra -g")
set(CMAKE_BUILD_TYPE "Release")
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI})
# set(CMAKE_BUILD_TYPE "Debug")
# set(CMAKE_BUILD_TYPE "Release")
set(CMAKE_BUILD_TYPE "Debug")
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(BUILD_TEST OFF)
set(BUILD_PYTHON_EXT OFF)

if(NOT DEFINED _GLIBCXX_USE_CXX11_ABI)
    find_package(Python3 REQUIRED COMPONENTS Interpreter)

    execute_process(
        COMMAND ${Python3_EXECUTABLE} -c
                "import torch; print('1' if torch.compiled_with_cxx11_abi() else '0')"
        OUTPUT_VARIABLE ABI_FLAG
        OUTPUT_STRIP_TRAILING_WHITESPACE
    )

    set(_GLIBCXX_USE_CXX11_ABI ${ABI_FLAG} CACHE STRING "C++11 ABI setting from PyTorch" FORCE)
endif()

# Pass the value to the compiler whether it was auto-detected or not
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI})

message(STATUS "_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI}")

# set(USE_IO_URING ON)
if(USE_IO_URING)
    message(STATUS "Using io_uring")
@@ -54,6 +71,8 @@ find_package(Torch REQUIRED PATHS "${TORCH_INSTALL_PREFIX}/share/cmake/Torch" NO

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)

# include_directories(/usr/include/tbb)
# link_directories(/usr/lib64)
find_package(TBB REQUIRED)
find_package(CUDA REQUIRED)
@@ -101,3 +120,13 @@ if(BUILD_PYTHON_EXT)
        DESTINATION ${CMAKE_BINARY_DIR}/output)
endif()

if(USE_IO_URING)
    set(PHOTON_ENABLE_URING ON CACHE BOOL "Enable io_uring")
endif()

set(PHOTON_CXX_STANDARD 14 CACHE INTERNAL "C++ standard")

set(CMAKE_CXX_FLAGS "-O3 -march=native")
message(STATUS "CMAKE_CXX_FLAGS of PhotonLibOS: ${CMAKE_CXX_FLAGS}")

add_subdirectory(${THIRD_PARTY_DIR}/PhotonLibOS ${THIRD_PARTY_BUILD_DIR}/PhotonLibOS)
@@ -33,8 +33,9 @@ target_link_libraries(kvc2 PUBLIC TBB::tbb xxHash::xxhash cache_entry cuda_strea
message(STATUS "CMAKE_SOURCE_DIR: " ${CMAKE_SOURCE_DIR})
add_library(async_store async_store.cpp)
target_include_directories(async_store PRIVATE ${THIRD_PARTY_DIR}/nlohmann/single_include)
target_include_directories(async_store PRIVATE ${THIRD_PARTY_DIR}/PhotonLibOS/include)
target_include_directories(async_store PRIVATE ${THIRD_PARTY_DIR}/spdlog/include)
target_link_libraries(async_store PUBLIC pthread)
target_link_libraries(async_store PUBLIC photon_static pthread)
@@ -13,13 +13,25 @@
#include <thread>
#include <unordered_map>

#include <photon/common/alog.h>
#include <photon/common/io-alloc.h>
#include <photon/fs/localfs.h>
#include <photon/photon.h>
#include <photon/thread/thread11.h>
#include "utils/lock_free_queue.hpp"

#include "async_store.hh"

namespace async_store {

#ifdef USE_IO_URING
static int io_engine_type = photon::fs::ioengine_iouring;
#else
static int io_engine_type = photon::fs::ioengine_libaio;
#endif

struct ArrayStore {
  photon::mutex lock;
  static const size_t DeviceBlockSize = 512;

  const size_t element_size;
@@ -30,28 +42,62 @@ struct ArrayStore {
  size_t size_in_bytes() { return size * element_size_aligned; }

  std::filesystem::path data_path;
  std::unique_ptr<photon::fs::IFile> file;

  void extend(size_t to) {
    if (to <= size) {
      return;
    }
    // TODO: extend file
    file->ftruncate(to * element_size_aligned);
    size = to;
    // LOG_INFO("Extend file to `, size `", to, size_in_bytes());
    LOG_INFO("Extend file to `, size `", to, size_in_bytes());
  }

  ArrayStore(size_t element_size, size_t size, std::filesystem::path data_path)
      : element_size(element_size),
        element_size_aligned((element_size + DeviceBlockSize - 1) / DeviceBlockSize),
        element_size_aligned(align_up(element_size, DeviceBlockSize)),
        data_path(data_path) {
    // TODO: prefix cache
    double write_amplification = element_size_aligned * 1.0 / element_size;
    if (write_amplification > 1.1) {
      LOG_WARN("Warning: write amplification is ` for `", write_amplification, data_path.c_str());
    }

    if (std::filesystem::exists(data_path)) {
      LOG_INFO("Opening `", data_path.c_str());
      this->file = std::unique_ptr<photon::fs::IFile>(
          photon::fs::open_localfile_adaptor(data_path.c_str(), O_RDWR | O_DIRECT, 0664, io_engine_type));
    } else {
      LOG_INFO("Creating `", data_path.c_str());
      this->file = std::unique_ptr<photon::fs::IFile>(
          photon::fs::open_localfile_adaptor(data_path.c_str(), O_RDWR | O_CREAT | O_DIRECT, 0664, io_engine_type));
    }
    if (file.get() == nullptr) {
      LOG_ERROR("Error opening file");
    }
    struct stat buf;
    file->fstat(&buf);
    this->size = buf.st_size / element_size_aligned;

    extend(size);
  }

  void read(size_t index, void* buffer) {
    // TODO: read from file
    size_t ret = file->pread(buffer, element_size, index * element_size_aligned);
    if (ret != element_size) {
      perror("Error reading from file");
      LOG_ERROR("Error reading from file ` ` `, ret `", buffer, element_size, index * element_size_aligned, ret);
    }
    file->fdatasync();
    file->fsync();
  }
  void write(size_t index, void* buffer) {
    // TODO: write to file
    size_t ret = file->pwrite(buffer, element_size, index * element_size_aligned);
    if (ret != element_size) {
      perror("Error writing to file");

      LOG_ERROR("Error writing to file ` ` ` `, ret `", file.get(), buffer, element_size, index * element_size_aligned,
                ret);
    }
  }
};
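The constructor change above replaces `(element_size + DeviceBlockSize - 1) / DeviceBlockSize`, which yields the number of 512-byte blocks, with `align_up(element_size, DeviceBlockSize)`, which yields a byte size rounded up to the O_DIRECT block boundary. A compilable sketch of the intended semantics; this `align_up` is an assumed equivalent of the one used in the repo:

```cpp
#include <cstddef>

// Assumed semantics of align_up: round x up to the next multiple of align.
constexpr std::size_t align_up(std::size_t x, std::size_t align) {
  return (x + align - 1) / align * align;
}

int main() {
  constexpr std::size_t DeviceBlockSize = 512;
  static_assert(align_up(1000, DeviceBlockSize) == 1024);
  static_assert(align_up(512, DeviceBlockSize) == 512);
  // The old expression stopped one step short: it computed a block *count*,
  // so a 1000-byte element got an "aligned size" of 2 bytes.
  static_assert((1000 + DeviceBlockSize - 1) / DeviceBlockSize == 2);
  return 0;
}
```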
@@ -98,15 +144,66 @@ struct IODealerImpl {
  IODealerImpl(bool use_io_uring, int IO_DEPTH) : use_io_uring(use_io_uring), IO_DEPTH(IO_DEPTH) {}

  void queue_consumer() {
    // TODO:
    while (stop == false) {
      if (auto request = ioQueue.dequeue(); request) {
        if (request->write) {
          request->store->write(request->index, request->data);
        } else {
          request->store->read(request->index, request->data);
        }
        io_cnt += 1;
        io_amount += request->store->element_size_aligned;

        if (request->need_promise) {
          // LOG_INFO("Set Promise `", request->promise);
          request->promise->set();
        }
        // photon::thread_yield();
      } else {
        // queue is empty; sleep briefly to avoid busy-waiting
        photon::thread_usleep(10);
        // photon::thread_yield();
      }
    }
  }

  void io_perf() {
    // TODO:
    LOG_INFO("IO Depth `", IO_DEPTH);
    while (stop == false) {
      photon::thread_sleep(1);
      if (io_cnt == 0) {
        continue;
      }
      LOG_INFO("IO queue remaining: ` , processed ` M. IO count: ` Kops, ` M/s",
               (ioQueue.enqueue_count - ioQueue.dequeue_count), ioQueue.dequeue_count / 1e6, io_cnt / 1e3,
               io_amount / 1e6);
      io_cnt = 0;
      io_amount = 0;
    }
  }

  void io_dealer() {
    // TODO:
    int ev_engine = use_io_uring ? photon::INIT_EVENT_IOURING : photon::INIT_EVENT_EPOLL;
    int io_engine = use_io_uring ? photon::INIT_IO_NONE : photon::INIT_IO_LIBAIO;
    int fs_io_engine = use_io_uring ? photon::fs::ioengine_iouring : photon::fs::ioengine_libaio;
    io_engine_type = fs_io_engine;
    int ret = photon::init(ev_engine, io_engine, photon::PhotonOptions{.libaio_queue_depth = 512});
    if (ret != 0) {
      LOG_ERROR("PHOTON INIT FAILED");
      exit(1);
    }
    DEFER(photon::fini());
    std::vector<photon::join_handle*> handles;

    handles.push_back(photon::thread_enable_join(photon::thread_create11([this]() { io_perf(); })));

    LOG_INFO("Initializing IO Dealer");
    for (int i = 0; i < IO_DEPTH; i++) {
      handles.push_back(photon::thread_enable_join(photon::thread_create11([this]() { queue_consumer(); })));
    }
    for (auto& handle : handles) {
      photon::thread_join(handle);
    }
  }
};
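For orientation, here is the Photon bootstrap pattern that `io_dealer()` follows, reduced to a standalone sketch. Every call is taken from the hunk above; only the trivial worker body is invented:

```cpp
#include <photon/common/alog.h>
#include <photon/photon.h>
#include <photon/thread/thread11.h>

int main() {
  bool use_io_uring = false;  // io_uring needs kernel support; libaio is the fallback
  int ev_engine = use_io_uring ? photon::INIT_EVENT_IOURING : photon::INIT_EVENT_EPOLL;
  int io_engine = use_io_uring ? photon::INIT_IO_NONE : photon::INIT_IO_LIBAIO;
  if (photon::init(ev_engine, io_engine, photon::PhotonOptions{.libaio_queue_depth = 512}) != 0) {
    LOG_ERROR("PHOTON INIT FAILED");
    return 1;
  }
  DEFER(photon::fini());  // torn down when main returns
  // Coroutine-style threads: create, make joinable, then join.
  auto* handle = photon::thread_enable_join(photon::thread_create11([] { LOG_INFO("worker ran"); }));
  photon::thread_join(handle);
  return 0;
}
```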
@ -130,7 +227,7 @@ void IODealer::stop() {
|
|||
if (io_impl->stop) {
|
||||
return;
|
||||
}
|
||||
// LOG_INFO("Stopping IO Dealer");
|
||||
LOG_INFO("Stopping IO Dealer");
|
||||
io_impl->stop = true;
|
||||
}
|
||||
|
||||
|
|
|
@@ -1,6 +1,7 @@
#ifndef __DEFS_H_
#define __DEFS_H_

#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>
@@ -27,7 +28,7 @@ struct CacheInfo {
  size_t hidden_layer_count();
  std::filesystem::path path(std::optional<size_t> which_layer = std::nullopt);
  bool operator==(const CacheInfo& other) const;
  size_t element_size(size_t block_length);
  size_t element_size(size_t block_length, size_t head_dim);
  size_t hash_value() const;
};
@@ -62,10 +62,13 @@ bool CacheInfo::operator==(const CacheInfo& other) const {
  return model_name == other.model_name && is_key_cache == other.is_key_cache && quant_type == other.quant_type;
}

size_t CacheInfo::element_size(size_t block_length) {
  size_t count = model_configs[model_name].hidden_size * block_length;
size_t CacheInfo::element_size(size_t block_length, size_t head_dim) {
  size_t count = head_dim * block_length;
  auto& q = quant_configs[quant_type];
  return count / q.block_element_count * q.block_element_size;
  if (q.block_element_count == 0 || q.block_element_size == 0)
    return count * 4;  // default to FP32
  else
    return count / q.block_element_count * q.block_element_size;
}

size_t CacheInfo::hash_value() const {
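The new signature sizes a cache block from the head dimension instead of the model's hidden size, and guards against a zeroed quant config (which previously divided by zero) by falling back to FP32. As a worked example of the formula; all numbers here are hypothetical, real values come from `quant_configs` and the model's GPU cache config:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  // element_size(block_length, head_dim) with made-up but plausible inputs.
  std::size_t head_dim = 576, block_length = 256;
  std::size_t count = head_dim * block_length;  // 147456 elements per block
  // A hypothetical block quant: 32 elements packed into 18 bytes.
  std::size_t block_element_count = 32, block_element_size = 18;
  std::printf("quantized: %zu bytes\n", count / block_element_count * block_element_size);  // 82944
  // With a zeroed quant config the new code falls back to 4 bytes per element:
  std::printf("fp32 fallback: %zu bytes\n", count * 4);  // 589824
  return 0;
}
```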
@@ -220,7 +223,7 @@ struct DiskCacheAllocator {
    return re;
  }

  DiskCacheAllocator(std::filesystem::path path, CacheInfo info) : path(path), info(info) {
  DiskCacheAllocator(std::filesystem::path path, CacheInfo info, size_t head_dim) : path(path), info(info) {
    // SPDLOG_DEBUG("Create DiskCacheAllocator {}", path.c_str());
    auto allocator_path = path / info.path();
    if (std::filesystem::exists(allocator_path) == false) {
@@ -231,7 +234,7 @@ struct DiskCacheAllocator {

    for (size_t i = 0; i < info.hidden_layer_count(); i++) {
      // SPDLOG_DEBUG("Create store {} for {}", (path / info.path(i)).c_str(), i);
      auto store = async_store::create_or_open_store(info.element_size(NumTokenPerBlock), 1000, path / info.path(i));
      auto store = async_store::create_or_open_store(info.element_size(NumTokenPerBlock, head_dim), 1000, path / info.path(i));
      stores.push_back(store);
    }
    update_capacity();
@@ -263,7 +266,9 @@ struct DiskCacheManager {
    // SPDLOG_DEBUG("Make Allocator {}", allocator_json.dump());
    CacheInfo info;
    allocator_json.at("info").get_to(info);
    auto allocator = std::make_shared<DiskCacheAllocator>(nlohmann_json_t.config.path, info);
    assert(nlohmann_json_t.config.gpu_cache_config.has_value());
    size_t head_dim = nlohmann_json_t.config.gpu_cache_config.value().k_head_dim;
    auto allocator = std::make_shared<DiskCacheAllocator>(nlohmann_json_t.config.path, info, head_dim);
    allocator_json.at("allocator").get_to(*allocator);
    nlohmann_json_t.allocators[info] = allocator;
  }
@@ -280,7 +285,9 @@ struct DiskCacheManager {
  {
    std::lock_guard<std::mutex> lg(lock);
    if (allocators.count(info) == 0) {
      allocators.emplace(info, std::make_shared<DiskCacheAllocator>(config.path, info));
      assert(config.gpu_cache_config.has_value());
      size_t head_dim = config.gpu_cache_config.value().k_head_dim;
      allocators.emplace(info, std::make_shared<DiskCacheAllocator>(config.path, info, head_dim));
    }
  }
  return allocators.at(info);
@@ -561,7 +568,39 @@ struct PrefixTree {
    if (need_lock) {
      sl = std::shared_lock<std::shared_mutex>(rw_lock);
    }
    // TODO: prefix cache

    // We disable full-sequence match because it is awful when we try to maintain the flush back.
    // auto full_seq_hash = TokensHasher::hash(data, length);
    // SPDLOG_DEBUG("Look up prefix with hash {:016x} length: {}", full_seq_hash, length);
    // // debug();
    // if (prefix_map.count(full_seq_hash)) {
    //   SPDLOG_DEBUG("Full Match", full_seq_hash);
    //   return {prefix_map.at(full_seq_hash), length};
    // }

    TokenLength block_length = length / NumTokenPerBlock;  // do not need the tail
    TokenLength l = 0, r = block_length + 1;
    while (l + 1 < r) {
      TokenLength mid = (l + r) / 2;  // [1, block_length]
      auto hash = TokensHasher::hash(data, mid * NumTokenPerBlock);
      if (prefix_map.count(hash)) {
        SPDLOG_DEBUG("Binary Prefix Search: Found prefix with hash {:016x}", hash);
        l = mid;
      } else {
        SPDLOG_DEBUG("Binary Prefix Search: Not Found prefix with hash {:016x}", hash);
        r = mid;
      }
    }

    met->lookup_prefixmatch_length->Observe(l * NumTokenPerBlock);
    met->matched_length_percentage->Observe(l * NumTokenPerBlock * 100.0 / length);

    if (l == 0)
      return {nullptr, 0};

    auto hash = TokensHasher::hash(data, l * NumTokenPerBlock);

    return {prefix_map.at(hash).first.get(), l * NumTokenPerBlock};
  }

  PrefixMatch look_up_or_insert(Token* data, TokenLength length) {
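The loop above is a binary search over block-aligned prefix lengths. It relies on prefix presence being monotone: if a k-block prefix was inserted, every shorter block prefix was inserted too, so `prefix_map` membership partitions [0, block_length] and O(log n) hash probes find the longest match. A self-contained sketch with stand-in types; `TokensHasher`, the real map values, metrics, and locking are omitted:

```cpp
#include <cstddef>
#include <cstdint>
#include <unordered_set>
#include <vector>

using Token = int;
constexpr std::size_t NumTokenPerBlock = 256;

// Stand-in for TokensHasher::hash (FNV-1a over the first n tokens).
std::uint64_t hash_prefix(const std::vector<Token>& toks, std::size_t n) {
  std::uint64_t h = 14695981039346656037ull;
  for (std::size_t i = 0; i < n; i++) {
    h ^= static_cast<std::uint64_t>(toks[i]);
    h *= 1099511628211ull;
  }
  return h;
}

std::size_t longest_match_blocks(const std::vector<Token>& toks,
                                 const std::unordered_set<std::uint64_t>& prefix_map) {
  std::size_t block_length = toks.size() / NumTokenPerBlock;  // tail ignored, as above
  std::size_t l = 0, r = block_length + 1;  // invariant: l matches, r does not
  while (l + 1 < r) {
    std::size_t mid = (l + r) / 2;
    if (prefix_map.count(hash_prefix(toks, mid * NumTokenPerBlock)))
      l = mid;
    else
      r = mid;
  }
  return l;  // longest matched prefix, in blocks
}

int main() {
  std::vector<Token> toks(4 * NumTokenPerBlock, 7);
  std::unordered_set<std::uint64_t> prefix_map = {hash_prefix(toks, 1 * NumTokenPerBlock),
                                                  hash_prefix(toks, 2 * NumTokenPerBlock)};
  return longest_match_blocks(toks, prefix_map) == 2 ? 0 : 1;
}
```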
@@ -699,7 +738,18 @@ struct DoubleCacheHandle : public DoubleCacheHandleInterface {
      }
    }
  }
  std::vector<MatchStatus> matched_status() override { assert(false); }
  std::vector<MatchStatus> matched_status() override {
    // if (enable_alt == false) {
    //   SPDLOG_ERROR("Matched Status is not available when enable_alt is false");
    //   assert(0);
    // }
    // std::vector<MatchStatus> re;
    // for (auto& [p, idx, status] : match_by_blocks.matches) {
    //   re.push_back(status);
    // }
    // return re;
    assert(false);
  }

  bool any_match() {
    if (enable_alt) {
@@ -988,30 +1038,7 @@ struct DoubleCacheHandle : public DoubleCacheHandleInterface {
  //   set_raw_handles(true, k);
  //   set_raw_handles(false, v);
  // }
  void set_raw_handles(bool is_key_cache, const std::vector<layer_data>& layer_data) {
    auto single_set_raw_handles = [layer_data](CacheInfo info,
                                               std::vector<std::vector<std::shared_ptr<CacheBlockEntry>>>& handles) {
      handles.resize(layer_data.size());
      for (size_t i = 0; i < info.hidden_layer_count(); i++) {
        auto& layer = layer_data[i];
        handles[i].clear();
        for (auto& block_data : layer) {
          auto handle = std::make_shared<CacheBlockEntry>();
          handle->data = reinterpret_cast<void*>(block_data);
          handle->size = info.element_size(NumTokenPerBlock);
          handles[i].push_back(handle);
        }
      }
    };

    if (is_key_cache) {
      is_k_cache_on = true;
      single_set_raw_handles(k_info(), k_cache_handles);
    } else {
      is_v_cache_on = true;
      single_set_raw_handles(v_info(), v_cache_handles);
    }
  }

  std::vector<layer_data> export_raw_pointers(bool is_key_cache) {
    std::vector<layer_data> re;
@@ -1042,6 +1069,7 @@ struct DoubleCacheHandle : public DoubleCacheHandleInterface {
    return re;
  }

  void set_raw_handles(bool is_key_cache, const std::vector<layer_data>& layer_data);
  void get_handles();
  void get_empty_handles();
@@ -1257,7 +1285,23 @@ struct KVC2 : KVC2Interface {
  re->kvc2_top = this;
  SPDLOG_DEBUG("Lookup TokenLength {}", length);
  if (config.gpu_only == false) {
    // TODO:
    re->match = tree->look_up(id, length);
    re->get_handles();
    if (re->alloc_on_cpu() == false) {
      return nullptr;
    }

    SPDLOG_DEBUG("Found {}, Prompt Length {}, Estimated Length {}", re->match.match_length, length, estimated_length);
    if (re->match.prefix) {
      re->collect_locations();
      auto disk_io_helper = re->segment_io(io_dealer.get(), disk_cache.get(), 0,
                                           div_up(re->match.match_length, NumTokenPerBlock), IO_Read);
      // TODO: async is broken here, do something later
      disk_io_helper->wait();
      SPDLOG_INFO("Loaded to mem");
    } else {
      SPDLOG_INFO("No Match, No need to load");
    }
  }
  return re;
};
@@ -1408,11 +1452,37 @@ DoubleCacheHandle::~DoubleCacheHandle() {
  }
};

void DoubleCacheHandle::set_raw_handles(bool is_key_cache, const std::vector<layer_data>& layer_data) {
  auto single_set_raw_handles = [layer_data](CacheInfo info, size_t head_dim,
                                             std::vector<std::vector<std::shared_ptr<CacheBlockEntry>>>& handles) {
    handles.resize(layer_data.size());
    for (size_t i = 0; i < info.hidden_layer_count(); i++) {
      auto& layer = layer_data[i];
      handles[i].clear();
      for (auto& block_data : layer) {
        auto handle = std::make_shared<CacheBlockEntry>();
        handle->data = reinterpret_cast<void*>(block_data);
        handle->size = info.element_size(NumTokenPerBlock, head_dim);
        handles[i].push_back(handle);
      }
    }
  };

  if (is_key_cache) {
    is_k_cache_on = true;
    single_set_raw_handles(k_info(), kvc2_top->config.gpu_cache_config.value().k_head_dim, k_cache_handles);
  } else {
    is_v_cache_on = true;
    single_set_raw_handles(v_info(), kvc2_top->config.gpu_cache_config.value().k_head_dim, v_cache_handles);
  }
}

void DoubleCacheHandle::get_handles() {
  size_t new_count = 0, total_count = 0;
  auto get_info_handles = [this, &new_count, &total_count](
                              CacheInfo info, std::vector<std::vector<std::shared_ptr<CacheBlockEntry>>>& layers) {
    auto total_block_count = div_up(estimated_length, NumTokenPerBlock);
    size_t head_dim = kvc2_top->config.gpu_cache_config.value().k_head_dim;
    for (size_t l = 0; l < info.hidden_layer_count(); l++) {
      auto hashes = match.matched_hashes(info, l);
      layers[l].resize(total_block_count, nullptr);
@@ -1422,7 +1492,7 @@ void DoubleCacheHandle::get_handles() {
      key = hashes[i];
      bool is_new;
      total_count += 1;
      layers[l][i] = this->kvc2_top->cache_manager->get(is_new, info.element_size(NumTokenPerBlock), key);
      layers[l][i] = this->kvc2_top->cache_manager->get(is_new, info.element_size(NumTokenPerBlock, head_dim), key);
      if (is_new)
        new_count += 1;
      layers[l][i]->cache_info = info;
@@ -56,6 +56,18 @@ class MPSCQueue {
    return std::nullopt;
  }

  // busy-wait dequeue
  std::optional<T> busy_wait_dequeue() {
    while (true) {
      std::optional<T> re = dequeue();
      if (re.has_value()) {
        return re;
      }
      // std::this_thread::yield();
    }
    throw std::runtime_error("Should not be here");
  }

  size_t size() { return enqueue_count.load() - dequeue_count; }
};
@@ -83,7 +95,7 @@ class MPSCQueueConsumerLock {
    return re.value();
  }
  sema.acquire();
  return queue.dequeue().value();
  return queue.busy_wait_dequeue().value();
}

size_t size() { return queue.size(); }
@@ -55,6 +55,18 @@ public:
    return std::nullopt;
  }

  // busy-wait dequeue
  std::optional<T> busy_wait_dequeue() {
    while (true) {
      std::optional<T> re = dequeue();
      if (re.has_value()) {
        return re;
      }
      // std::this_thread::yield();
    }
    throw std::runtime_error("Should not be here");
  }

  size_t size() { return enqueue_count.load() - dequeue_count; }
};
@@ -84,7 +96,7 @@ public:
    return re.value();
  }
  sema.acquire();
  return queue.dequeue().value();
  return queue.busy_wait_dequeue().value();
}

template <typename Rep, typename Period>
@@ -102,7 +114,7 @@ public:
  }

  if (sema.try_acquire_for(dur)) {
    return queue.dequeue().value();
    return queue.busy_wait_dequeue().value();
  } else {
    return std::nullopt;
  }
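Both queue headers get the same two changes: a `busy_wait_dequeue()` helper and its use after the semaphore acquire. The likely reason, inferred from the code rather than stated in the commit: a producer releases the semaphore only after enqueuing, but in a lock-free MPSC queue the new node may not yet be visible to the consumer, so a single `dequeue()` can transiently return `std::nullopt` and the old `queue.dequeue().value()` could throw `std::bad_optional_access`. A minimal consumer-side sketch of the corrected pattern, with a hypothetical `Q` that has the same `dequeue()` contract:

```cpp
#include <optional>
#include <semaphore>
#include <thread>

template <typename T, typename Q>
T acquire_then_dequeue(Q& queue, std::counting_semaphore<>& sema) {
  sema.acquire();  // an element has been counted in...
  while (true) {   // ...but may not be linked yet, so retry instead of calling .value() once
    if (std::optional<T> re = queue.dequeue(); re.has_value()) {
      return *re;
    }
    std::this_thread::yield();  // brief backoff between probes
  }
}
```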
doc/en/prefix_cache.md (new file, 34 lines)
@@ -0,0 +1,34 @@
## Enabling Prefix Cache Mode in KTransformers

To enable **Prefix Cache Mode** in KTransformers, you need to modify the configuration file and recompile the project.

### Step 1: Modify the Configuration File

Edit the `./ktransformers/configs/config.yaml` file with the following content (you can adjust the values according to your needs):

```yaml
attn:
  page_size: 16 # Size of a page in the KV cache.
  chunk_size: 256
kvc2:
  gpu_only: false # Set to false to enable prefix cache mode (Disk + CPU + GPU KV storage)
  utilization_percentage: 1.0
  cpu_memory_size_GB: 500 # Amount of CPU memory allocated for the KV cache
```

### Step 2: Update Submodules and Recompile

If this is your first time using prefix cache mode, update the submodules first:

```bash
git submodule update --init --recursive # Update the PhotonLibOS submodule
```

Then recompile the project:

```bash
# Install single-NUMA dependencies
USE_BALANCE_SERVE=1 bash ./install.sh
# For machines with two CPU sockets and 1 TB of RAM (dual NUMA):
USE_BALANCE_SERVE=1 USE_NUMA=1 bash ./install.sh
```

@@ -67,6 +67,6 @@ attn:
  page_size: 256
  chunk_size: 256
kvc2:
  gpu_only: true
  gpu_only: false
  utilization_percentage: 1.0
  cpu_memory_size_GB: 500
ktransformers/tests/test_prefix.py (new file, 132 lines)
@@ -0,0 +1,132 @@
import asyncio
import json
import sys
import aiohttp
import random
import argparse
import yaml
import os
import time
from time import sleep

decodesz = 128
# Server URL (replace with your server URL)
decodesz_list = [128]
prefill_speeds = []
decode_speeds = []

async def fetch_message_once(session, request_id, messages, max_tokens, model):
    try:
        payload = {
            "messages": messages,
            "model": model,
            "temperature": 0.3,
            "top_p": 1.0,
            "stream": True,
            "return_speed": True,
            "max_tokens": max_tokens,
        }

        headers = {
            'accept': 'application/json',
            'Content-Type': 'application/json'
        }

        async with session.post(SERVER_URL, json=payload, headers=headers, timeout=500000) as response:
            if response.status != 200:
                print(f"[Request {request_id}] Error: Status {response.status}")
                return None, None, None

            buffer = ""
            usage_info = None
            answer = ""

            async for line in response.content:
                decoded_line = line.decode("utf-8").strip()
                if not decoded_line or not decoded_line.startswith("data: "):
                    continue

                decoded_line = decoded_line[6:].strip()
                if not decoded_line:
                    continue

                response_data = json.loads(decoded_line)

                if "usage" in response_data:
                    usage_info = response_data["usage"]

                choices = response_data.get("choices", [])
                if not choices:
                    continue

                delta = choices[0].get("delta", {})
                token = delta.get("content", "")
                if token:
                    buffer += token
                    answer += token

                finish_reason = choices[0].get("finish_reason", None)
                if finish_reason:
                    break

            return answer.strip(), usage_info, buffer.strip()

    except Exception as e:
        print(f"[Request {request_id}] Exception: {e}")
        return None, None, None


async def multi_turn_conversation(session, request_id, rounds, max_tokens, model):
    prompt = ["介绍一下秦始皇", "秦始皇的成就有哪些", "秦始皇的历史影响", "介绍一下秦始皇的陵墓", "秦始皇的统一措施", "秦始皇的政治制度", "秦始皇的文化政策", "秦始皇的军事行动"]

    messages = [{"role": "system", "content": ""}]
    global prefill_speeds, decode_speeds

    for i in range(rounds):
        user_msg = f"这是第{i + 1}轮对话,请回答以下问题:{prompt[i % len(prompt)]}"
        messages.append({"role": "user", "content": user_msg})
        print(f"\n[Request {request_id}] >> User: {user_msg}")

        answer, usage_info, _ = await fetch_message_once(session, request_id, messages, max_tokens, model)
        if answer:
            messages.append({"role": "assistant", "content": answer})
            print(f"[Request {request_id}] << Assistant: {answer}")

        if usage_info:
            prefill_speed = usage_info["prompt_tokens"] / usage_info["prefill_time"]
            decode_speed = usage_info["completion_tokens"] / usage_info["decode_time"]
            prefill_speeds.append(prefill_speed)
            decode_speeds.append(decode_speed)
            print(f'[Request {request_id}] prefill speed: {prefill_speed}')
            print(f'[Request {request_id}] decode speed: {decode_speed}')


async def main(concurrent_requests, rounds, max_tokens, model):
    async with aiohttp.ClientSession() as session:
        tasks = [multi_turn_conversation(session, i, rounds, max_tokens, model) for i in range(concurrent_requests)]
        await asyncio.gather(*tasks)

    if prefill_speeds:
        import numpy as np
        print(f"\n=== Summary ===")
        print(f"Total concurrency: {concurrent_requests}")
        print(f"Avg prefill speed: {np.mean(prefill_speeds)}")
        print(f"Avg decode speed: {np.mean(decode_speeds)}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Event Stream Request Tester")
    parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
    parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name")
    parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
    parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
    parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens")
    parser.add_argument("--rounds", type=int, default=8, help="Number of multi-turn rounds (before final query)")

    args = parser.parse_args()
    SERVER_URL = args.api_url
    max_tokens = args.max_tokens
    model = args.model

    asyncio.run(main(args.concurrent, args.rounds, max_tokens, model))
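A typical invocation against a server started with the configuration above would be `python ktransformers/tests/test_prefix.py --concurrent 2 --rounds 4`. Because each round resends the whole growing conversation, prefill speed should climb in later rounds once the shared prefix is served from the cache.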
third_party/PhotonLibOS (vendored submodule, 1 change)

@@ -0,0 +1 @@
Subproject commit 92f56d4527c24aafcee75d87fd72fce25680266f