diff --git a/.gitmodules b/.gitmodules index 68242f3..0716ec2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,19 @@ [submodule "third_party/pybind11"] path = third_party/pybind11 url = https://github.com/pybind/pybind11.git + +[submodule "third_party/spdlog"] + path = third_party/spdlog + url = https://github.com/gabime/spdlog.git +[submodule "third_party/custom_flashinfer"] + path = third_party/custom_flashinfer + url = https://github.com/kvcache-ai/custom_flashinfer.git + branch = fix-precision-mla-merge-main + +[submodule "third_party/xxHash"] + path = third_party/xxHash + url = https://github.com/Cyan4973/xxHash.git + +[submodule "third_party/prometheus-cpp"] + path = third_party/prometheus-cpp + url = https://github.com/jupp0r/prometheus-cpp diff --git a/MANIFEST.in b/MANIFEST.in index dac9b32..4097ce6 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,7 @@ graft third_party graft ktransformers graft local_chat.py +graft csrc include LICENSE README.md prune ktransformers/website prune ktransformers/logs @@ -9,3 +10,4 @@ prune third_party/llama.cpp/models graft ktransformers/website/dist global-exclude __pycache__ include KTransformersOps.*.so +include cpuinfer_ext.*.so diff --git a/Makefile b/Makefile index 8349809..74cb3c9 100644 --- a/Makefile +++ b/Makefile @@ -29,4 +29,4 @@ clean: install_numa: USE_NUMA=1 make dev_install install_no_numa: - env -u USE_NUMA make dev_install + env -u USE_NUMA make dev_install \ No newline at end of file diff --git a/README.md b/README.md index 63728d2..d90caf2 100644 --- a/README.md +++ b/README.md @@ -23,17 +23,20 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

🔥 Updates

+* **Mar 27, 2025**: Support Multi-concurrency. * **Mar 15, 2025**: Support ROCm on AMD GPU ([Tutorial](./doc/en/ROCm.md)). * **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and [IQ1_S/FP8 hybrid](./doc/en/fp8_kernel.md) weights. Support 139K [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context) for DeepSeek-V3 and R1 in 24GB VRAM. * **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context). * **Feb 15, 2025**: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 Tokens/s), update [docs](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/). * **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~28x speedup. For detailed show case and reproduction tutorial, see [here](./doc/en/DeepseekR1_V3_tutorial.md). * **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G. -* **Aug 15, 2024**: Update detailed [tutorial](doc/en/injection_tutorial.md) for injection and multi-GPU. -* **Aug 14, 2024**: Support llamfile as linear backend. +* **Aug 15, 2024**: Update detailed [tutorial](doc/en/injection_tutorial.md) for injection and multi-GPU. +* **Aug 14, 2024**: Support llamfile as linear backend. * **Aug 12, 2024**: Support multiple GPU; Support new model: mixtral 8\*7B and 8\*22B; Support q2k, q3k, q5k dequant on gpu. * **Aug 9, 2024**: Support windows native. + +

🌟 Show Cases

@@ -45,16 +48,16 @@ https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285

- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 14GB VRAM and 382GB DRAM([Tutorial](./doc/en/DeepseekR1_V3_tutorial.md)). - - Prefill Speed (tokens/s): - - KTransformers: 54.21 (32 cores) → 74.362 (dual-socket, 2×32 cores) → 255.26 (optimized AMX-based MoE kernel, V0.3 only) → 286.55 (selectively using 6 experts, V0.3 only) - - Compared to 10.31 tokens/s in llama.cpp with 2×32 cores, achieving up to **27.79× speedup**. - - Decode Speed (tokens/s): - - KTransformers: 8.73 (32 cores) → 11.26 (dual-socket, 2×32 cores) → 13.69 (selectively using 6 experts, V0.3 only) - - Compared to 4.51 tokens/s in llama.cpp with 2×32 cores, achieving up to **3.03× speedup**. - - Upcoming Open Source Release: - - AMX optimizations and selective expert activation will be open-sourced in V0.3. - - Currently available only in preview binary distribution, which can be downloaded [here](./doc/en/DeepseekR1_V3_tutorial.md). + - Prefill Speed (tokens/s): + - KTransformers: 54.21 (32 cores) → 74.362 (dual-socket, 2×32 cores) → 255.26 (optimized AMX-based MoE kernel, V0.3 only) → 286.55 (selectively using 6 experts, V0.3 only) + - Compared to 10.31 tokens/s in llama.cpp with 2×32 cores, achieving up to **27.79× speedup**. + - Decode Speed (tokens/s): + - KTransformers: 8.73 (32 cores) → 11.26 (dual-socket, 2×32 cores) → 13.69 (selectively using 6 experts, V0.3 only) + - Compared to 4.51 tokens/s in llama.cpp with 2×32 cores, achieving up to **3.03× speedup**. + - Upcoming Open Source Release: + - AMX optimizations and selective expert activation will be open-sourced in V0.3. + - Currently available only in preview binary distribution, which can be downloaded [here](./doc/en/DeepseekR1_V3_tutorial.md). - **Local 236B DeepSeek-Coder-V2:** Running its Q4_K_M version using only 21GB VRAM and 136GB DRAM, attainable on a local desktop machine, which scores even better than GPT4-0613 in [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench).

@@ -96,19 +99,16 @@ https://github.com/user-attachments/assets/a865e5e4-bca3-401e-94b8-af3c080e6c12 * **Flexible Sparse Attention Framework**: Offers a flexible block sparse attention framework for CPU offloaded decoding. Compatible with SnapKV, Quest, and InfLLm. Further information is available [here](./doc/en/long_context_introduction.md). --> - More advanced features will coming soon, so stay tuned!

🚀 Quick Start

- Getting started with KTransformers is simple! Follow the steps below to set up and start using it. ### 📥 Installation To install KTransformers, follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html). -

📃 Brief Injection Tutorial

At the heart of KTransformers is a user-friendly, template-based injection framework. This allows researchers to easily replace original torch modules with optimized variants. It also simplifies the process of combining multiple optimizations, allowing the exploration of their synergistic effects. @@ -167,7 +167,6 @@ The development of KTransformer is based on the flexible and versatile framework KTransformer is actively maintained and developed by contributors from the MADSys group at Tsinghua University and members from Approaching.AI. We welcome new contributors to join us in making KTransformer faster and easier to use. -

Discussion

If you have any questions, feel free to open an issue. Alternatively, you can join our WeChat group for further discussion. QR Code: [WeChat Group](WeChatGroup.png) diff --git a/csrc/balance_serve/CMakeLists.txt b/csrc/balance_serve/CMakeLists.txt new file mode 100644 index 0000000..e9b6072 --- /dev/null +++ b/csrc/balance_serve/CMakeLists.txt @@ -0,0 +1,69 @@ + +cmake_minimum_required(VERSION 3.21) +find_program(GCC_COMPILER NAMES g++-13 g++-12 g++-11 REQUIRED) +set(CMAKE_CXX_COMPILER ${GCC_COMPILER}) + +# 显示选定的编译器 +message(STATUS "Using compiler: ${CMAKE_CXX_COMPILER}") + + +project(balance_serve VERSION 0.1.0) + +set(CMAKE_CXX_STANDARD 20) +# set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fPIC") +# set(CMAKE_BUILD_TYPE "Debug") +set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -fPIC") +set(CMAKE_BUILD_TYPE "Release") + +file(GLOB_RECURSE FMT_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.h") + +add_custom_target( + format + COMMAND clang-format + -i + -style=file + ${FMT_SOURCES} + COMMENT "Running clang-format on all source files" +) + + + +add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) +set(BUILD_SHARED_LIBS ON) +set(ENABLE_PUSH OFF) +set(ENABLE_COMPRESSION OFF) + +# set(CMAKE_BUILD_TYPE "Release") +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +set(THIRD_PARTY_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party) +set(THIRD_PARTY_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/third_party) +add_subdirectory(${THIRD_PARTY_DIR}/prometheus-cpp ${THIRD_PARTY_BUILD_DIR}/prometheus-cpp EXCLUDE_FROM_ALL) +add_subdirectory(${THIRD_PARTY_DIR}/xxHash/cmake_unofficial ${THIRD_PARTY_BUILD_DIR}/xxHash EXCLUDE_FROM_ALL) + +# add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third_party/prometheus-cpp ${CMAKE_CURRENT_BINARY_DIR}/third_party/prometheus-cpp) +set(SPDLOG_DIR ${THIRD_PARTY_DIR}/spdlog) +set(FMT_DIR ${THIRD_PARTY_DIR}/fmt) + +set(KVC2_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/kvc2/src) + +include_directories(${THIRD_PARTY_DIR}) + +add_subdirectory(${THIRD_PARTY_DIR}/pybind11 ${THIRD_PARTY_BUILD_DIR}/pybind11) + +execute_process( + COMMAND python3 -c "import torch; print(torch.__path__[0])" + OUTPUT_VARIABLE TORCH_INSTALL_PREFIX + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +message(STATUS "Found PyTorch at: ${TORCH_INSTALL_PREFIX}") + +# set(TORCH_INSTALL_PREFIX "/home/xwy/.conda/envs/kvc/lib/python3.12/site-packages/torch") +find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib") +find_package(Torch REQUIRED PATHS "${TORCH_INSTALL_PREFIX}/share/cmake/Torch" NO_DEFAULT_PATH) + +add_subdirectory(kvc2) +add_subdirectory(sched) + +# add_subdirectory(test) diff --git a/csrc/balance_serve/kvc2/.clang-format b/csrc/balance_serve/kvc2/.clang-format new file mode 100644 index 0000000..752070f --- /dev/null +++ b/csrc/balance_serve/kvc2/.clang-format @@ -0,0 +1,25 @@ +Language: Cpp +# 格式化风格,可以是LLVM, Google, Chromium, Mozilla, WebKit等,或者自定义 +BasedOnStyle: Google + +# 缩进设置 +IndentWidth: 2 +TabWidth: 2 +UseTab: Never + +# 换行相关设置 +BreakBeforeBraces: Attach +AllowShortIfStatementsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Inline +AllowShortLoopsOnASingleLine: false + +# 类与结构体 +DerivePointerAlignment: false +PointerAlignment: Left + +# 包含文件的排序和分组 +IncludeBlocks: Preserve +SortIncludes: true + +# 控制最大行宽 +ColumnLimit: 120 diff --git a/csrc/balance_serve/kvc2/CMakeLists.txt b/csrc/balance_serve/kvc2/CMakeLists.txt new file mode 100644 index 0000000..4238f15 --- /dev/null +++ b/csrc/balance_serve/kvc2/CMakeLists.txt @@ -0,0 +1,104 @@ +cmake_minimum_required(VERSION 3.21) + +find_program(GCC_COMPILER NAMES g++-13 g++-12 g++-11 REQUIRED) +set(CMAKE_CXX_COMPILER ${GCC_COMPILER}) + +project(kvcache-manager VERSION 0.1.0) + +set(CMAKE_CXX_STANDARD 20) + +# set(CMAKE_CXX_FLAGS "-fPIC -O3 -ffast-math -march=native -Wall -Wextra -Wpedantic -fvisibility=hidden -s") +# set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -Wpedantic -g -fsanitize=address") +# set(CMAKE_CXX_FLAGS "-march=native -Wall -Wextra -Wpedantic -g") +# set(CMAKE_CXX_FLAGS "-fPIC -O3 -ffast-math -march=native -Wall -Wextra -g") +# set(CMAKE_BUILD_TYPE "Release") +set(CMAKE_BUILD_TYPE "Debug") +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(BUILD_TEST OFF) +set(BUILD_PYTHON_EXT OFF) + +add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) + +# set(USE_IO_URING ON) +if(USE_IO_URING) + message(STATUS "Using io_uring") + add_compile_definitions(USE_IO_URING) +else() + message(STATUS "Using aio") +endif() + +file(GLOB_RECURSE ALL_SOURCE_FILES src/*.cpp src/*.h test/*.cpp test/*.h test/*.hpp) + +# 添加一个自定义目标来格式化所有代码 +if(NOT TARGET format) + add_custom_target( + format + COMMAND clang-format + -i + -style=file + ${ALL_SOURCE_FILES} + COMMENT "Running clang-format on all source files" + ) +endif() + +execute_process( + COMMAND python3 -c "import torch; print(torch.__path__[0])" + OUTPUT_VARIABLE TORCH_INSTALL_PREFIX + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +message(STATUS "Found PyTorch at: ${TORCH_INSTALL_PREFIX}") + +# set(TORCH_INSTALL_PREFIX "/home/xwy/.conda/envs/kvc/lib/python3.12/site-packages/torch") +find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib") +find_package(Torch REQUIRED PATHS "${TORCH_INSTALL_PREFIX}/share/cmake/Torch" NO_DEFAULT_PATH) + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party) + +find_package(TBB REQUIRED) +find_package(CUDA REQUIRED) + +# find_package(prometheus-cpp CONFIG REQUIRED) +if(NOT TARGET prometheus-cpp::pull) + message(FATAL_ERROR "prometheus-cpp::pull not found") +else() + message(STATUS "prometheus Found!") +endif() + +if(CUDA_FOUND) + message(STATUS "CUDA Found!") + message(STATUS "CUDA Version: ${CUDA_VERSION_STRING}") + message(STATUS "CUDA Toolkit Root: ${CUDA_TOOLKIT_ROOT_DIR}") +else() + message(FATAL_ERROR "CUDA not found!") +endif() + +add_subdirectory(src) + +if(BUILD_TEST) + add_subdirectory(test) +endif() + +message(STATUS "BUILD_PYTHON_EXT: ${BUILD_PYTHON_EXT}") + +if(BUILD_PYTHON_EXT) + if(NOT TARGET pybind11::pybind11) + add_subdirectory(${THIRD_PARTY_DIR}/pybind11 ${THIRD_PARTY_BUILD_DIR}/pybind11) + endif() + + pybind11_add_module(kvc2_ext src/bind.cpp) + + # EXAMPLE_VERSION_INFO is defined by setup.py and passed into the C++ code as a + # define (VERSION_INFO) here. + target_compile_definitions(kvc2_ext PRIVATE VERSION_INFO=${EXAMPLE_VERSION_INFO}) + message(STATUS "CMAKE_CURRENT_SOURCE_DIR: ${CMAKE_CURRENT_SOURCE_DIR}") + target_include_directories(kvc2_ext PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/spdlog/include) + + target_link_libraries(kvc2_ext PUBLIC kvc2 async_store) + + install(TARGETS kvc2_ext LIBRARY + DESTINATION ${CMAKE_BINARY_DIR}/output) + install(FILES src/kvc2_utils.py + DESTINATION ${CMAKE_BINARY_DIR}/output) +endif() + diff --git a/csrc/balance_serve/kvc2/README.md b/csrc/balance_serve/kvc2/README.md new file mode 100644 index 0000000..e4c4745 --- /dev/null +++ b/csrc/balance_serve/kvc2/README.md @@ -0,0 +1,38 @@ +# KVC2 + +# Build +运行以下命令编译kvc2,注意可能需要 sudo 权限安装一些依赖 +```shell +git clone https://github.com/kvcache-ai/kvc2 +cd kvc2 +./install_deps.sh +mkdir build +cd build +cmake .. +make -j && make install +``` +编译完成后会生成`build/output`,包含`kvc2_ext.cpython-312-x86_64-linux-gnu.so`和`kvc2_utils.py`方便调用。 + + + +# Troubleshooting +在 Python 环境运行时,可以需要在 conda 中安装相关的依赖。 +```shell +conda install -c conda-forge gcc_linux-64 gxx_linux-64 +``` + +也可以尝试设置一下环境变量,然后再运行。 +```shell +export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH +export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libffi.so.7 +``` + + diff --git a/csrc/balance_serve/kvc2/config/model_configs.json b/csrc/balance_serve/kvc2/config/model_configs.json new file mode 100644 index 0000000..4d8195a --- /dev/null +++ b/csrc/balance_serve/kvc2/config/model_configs.json @@ -0,0 +1,42 @@ +{ + "DeepSeek-Coder-V2-Instruct": { + "hidden_size": 5120, + "intermediate_size": 12288, + "max_position_embeddings": 163840, + "model_type": "deepseek_v2", + "num_attention_heads": 128, + "num_hidden_layers": 60, + "num_key_value_heads": 128, + "vocab_size": 102400 + }, + "LLaMA-2-7B-32K": { + "hidden_size": 4096, + "intermediate_size": 11008, + "max_position_embeddings": 32768, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "vocab_size": 32000 + }, + "Qwen2.5-7B-Instruct": { + "hidden_size": 3584, + "intermediate_size": 18944, + "max_position_embeddings": 32768, + "model_type": "qwen2", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "vocab_size": 152064 + }, + "qwen2-72b-instruct": { + "hidden_size": 8192, + "intermediate_size": 29568, + "max_position_embeddings": 32768, + "model_type": "qwen2", + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8, + "vocab_size": 152064 + } +} \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/config/quant_configs.json b/csrc/balance_serve/kvc2/config/quant_configs.json new file mode 100644 index 0000000..191df5a --- /dev/null +++ b/csrc/balance_serve/kvc2/config/quant_configs.json @@ -0,0 +1,57 @@ +{ + "BF16": { + "block_element_count": 1, + "block_element_size": 2, + "bytes_per_element": 2.0, + "can_be_used_as_vector": true, + "has_min": false, + "has_scale": false, + "name": "BF16", + "reference": "", + "type_of_dot_vector": "BF16" + }, + "FP16": { + "block_element_count": 1, + "block_element_size": 2, + "bytes_per_element": 2.0, + "can_be_used_as_vector": true, + "has_min": false, + "has_scale": false, + "name": "FP16", + "reference": "", + "type_of_dot_vector": "FP16" + }, + "FP32": { + "block_element_count": 1, + "block_element_size": 4, + "bytes_per_element": 4.0, + "can_be_used_as_vector": true, + "has_min": false, + "has_scale": false, + "name": "FP32", + "reference": "", + "type_of_dot_vector": "FP32" + }, + "Q4_0": { + "block_element_count": 32, + "block_element_size": 18, + "bytes_per_element": 0.5625, + "can_be_used_as_vector": false, + "has_min": false, + "has_scale": true, + "name": "Q4_0", + "reference": "https://huggingface.co/docs/hub/gguf", + "type_of_dot_vector": "Q8_0" + }, + "Q8_0": { + "block_element_count": 32, + "block_element_size": 34, + "bytes_per_element": 1.0625, + "can_be_used_as_vector": true, + "has_min": false, + "has_scale": true, + "name": "Q8_0", + "reference": "https://huggingface.co/docs/hub/gguf", + "type_of_dot_vector": "Q8_0" + } +} \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/export_envs_before_run.sh b/csrc/balance_serve/kvc2/export_envs_before_run.sh new file mode 100755 index 0000000..b2fcd9b --- /dev/null +++ b/csrc/balance_serve/kvc2/export_envs_before_run.sh @@ -0,0 +1,2 @@ +export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH +export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libffi.so.7 diff --git a/csrc/balance_serve/kvc2/install_deps.sh b/csrc/balance_serve/kvc2/install_deps.sh new file mode 100755 index 0000000..336a32a --- /dev/null +++ b/csrc/balance_serve/kvc2/install_deps.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +cd "${0%/*}" +git submodule update --init --recursive + +sudo apt update +sudo apt install libtbb-dev +sudo apt install libcurl4-openssl-dev +sudo apt install libaio-dev + +cd third_party/xxHash/ +make -j +sudo make install +cd ../.. + diff --git a/csrc/balance_serve/kvc2/mkfs.sh b/csrc/balance_serve/kvc2/mkfs.sh new file mode 100755 index 0000000..aadcb02 --- /dev/null +++ b/csrc/balance_serve/kvc2/mkfs.sh @@ -0,0 +1,4 @@ +sudo umount /mnt/xwy +sudo mkfs.xfs /dev/nvme0n1 -f +sudo mount /dev/nvme0n1 /mnt/xwy +sudo chown -R xwy /mnt/xwy/ \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/src/CMakeLists.txt b/csrc/balance_serve/kvc2/src/CMakeLists.txt new file mode 100644 index 0000000..98ea626 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/CMakeLists.txt @@ -0,0 +1,45 @@ +include_directories(${THIRD_PARTY_DIR}/asyncio/include) + +add_library(kvc2_metrics STATIC metrics.cpp) +target_link_libraries(kvc2_metrics PUBLIC prometheus-cpp::pull) + +add_library(page_aligned_memory_pool page_aligned_memory_pool.cpp) +target_include_directories(page_aligned_memory_pool PRIVATE ${THIRD_PARTY_DIR}/spdlog/include) + +function(add_third_party_includes TARGET_NAME) + target_include_directories(${TARGET_NAME} PRIVATE + ${THIRD_PARTY_BUILD_DIR}/prometheus-cpp/core/include + ${THIRD_PARTY_BUILD_DIR}/prometheus-cpp/pull/include + ${THIRD_PARTY_DIR}/prometheus-cpp/core/include + ${THIRD_PARTY_DIR}/prometheus-cpp/pull/include + ${THIRD_PARTY_DIR}/spdlog/include + ) +endfunction() + + +add_library(cache_entry cache_entry.cpp) +add_third_party_includes(cache_entry) +target_link_libraries(cache_entry PUBLIC gpu_cache) + +add_library(gpu_cache gpu_cache.cpp) +add_third_party_includes(gpu_cache) +target_link_libraries(gpu_cache PUBLIC xxHash::xxhash ${TORCH_LIBRARIES} cuda_stream_manager) + +add_library(kvc2 prefix.cpp) +target_include_directories(kvc2 PRIVATE ${THIRD_PARTY_DIR}/nlohmann/single_include) +add_third_party_includes(kvc2) +target_link_libraries(kvc2 PUBLIC TBB::tbb xxHash::xxhash cache_entry cuda_stream_manager page_aligned_memory_pool ${TORCH_LIBRARIES} prometheus-cpp::pull kvc2_metrics) + +message(STATUS "CMAKE_SOURCE_DIR: " ${CMAKE_SOURCE_DIR}) +add_library(async_store async_store.cpp) +target_include_directories(async_store PRIVATE ${THIRD_PARTY_DIR}/nlohmann/single_include) +target_include_directories(async_store PRIVATE ${THIRD_PARTY_DIR}/spdlog/include) +target_link_libraries(async_store PUBLIC pthread) + + + +add_library(cuda_stream_manager cuda_stream_manager.cpp) +target_include_directories(cuda_stream_manager PUBLIC ${THIRD_PARTY_DIR}/nlohmann/single_include) +target_include_directories(cuda_stream_manager PUBLIC ${THIRD_PARTY_DIR}/spdlog/include) +target_include_directories(cuda_stream_manager PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) +target_link_libraries(cuda_stream_manager PUBLIC CUDA::cudart) diff --git a/csrc/balance_serve/kvc2/src/async_store.cpp b/csrc/balance_serve/kvc2/src/async_store.cpp new file mode 100644 index 0000000..b5400c9 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/async_store.cpp @@ -0,0 +1,137 @@ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils/lock_free_queue.hpp" + +#include "async_store.hh" + +namespace async_store { + +struct ArrayStore { + static const size_t DeviceBlockSize = 512; + + const size_t element_size; + const size_t element_size_aligned; + + size_t size; + + size_t size_in_bytes() { return size * element_size_aligned; } + + std::filesystem::path data_path; + + void extend(size_t to) { + if (to <= size) { + return; + } + //TODO: extend file + size = to; + //LOG_INFO("Extend file to `, size `", to, size_in_bytes()); + } + + ArrayStore(size_t element_size, size_t size, std::filesystem::path data_path) + : element_size(element_size), + element_size_aligned((element_size + DeviceBlockSize - 1) / DeviceBlockSize), + data_path(data_path) { + //TODO: prefix cache + } + + void read(size_t index, void* buffer) { + //TODO: read from file + } + void write(size_t index, void* buffer) { + //TODO: write to file + } +}; + +ArrayStore* create_or_open_store(size_t element_size, size_t size, std::filesystem::path data_path) { + return new ArrayStore(element_size, size, data_path); +} + +void close_store(ArrayStore* store) { + delete store; +} + +size_t capacity(ArrayStore* store) { + return store->size; +} + +void extend(ArrayStore* store, size_t to) { + store->extend(to); +} + +template +struct ArrayStoreT { + ArrayStore store; + ArrayStoreT(size_t element_count, std::filesystem::path data_path) : store(sizeof(T), element_count, data_path) {} + + void read(size_t index, void* output) { store.read(index, output); } + + void write(size_t index, T& value) { store.write(index, &value); } + void write(size_t index, void* value) { store.write(index, value); } +}; + +std::string request_to_string(IORequest* req) { + return fmt::format("IOReqeust {} {} to {}[{}]", req->write ? "Write" : "Read ", req->data, + req->store->data_path.c_str(), req->index); +} + +struct IODealerImpl { + MPSCQueue ioQueue; + uint64_t io_cnt = 0; + size_t io_amount = 0; + bool use_io_uring; + int IO_DEPTH; + + bool stop = false; + IODealerImpl(bool use_io_uring, int IO_DEPTH) : use_io_uring(use_io_uring), IO_DEPTH(IO_DEPTH) {} + + void queue_consumer() { + //TODO: + } + + void io_perf() { + //TODO: + } + + void io_dealer() { + //TODO: + } +}; + +IODealer::IODealer(bool use_io_uring, int IO_DEPTH) { + io_impl = new IODealerImpl(use_io_uring, IO_DEPTH); +} + +IODealer::~IODealer() { + stop(); + delete io_impl; +} + +void IODealer::enqueue(std::shared_ptr req) { + io_impl->ioQueue.enqueue(req); +} + +std::thread IODealer::start_io_thread() { + return std::thread([this]() { io_impl->io_dealer(); }); +} +void IODealer::stop() { + if (io_impl->stop) { + return; + } + //LOG_INFO("Stopping IO Dealer"); + io_impl->stop = true; +} + +} // namespace async_store diff --git a/csrc/balance_serve/kvc2/src/async_store.hh b/csrc/balance_serve/kvc2/src/async_store.hh new file mode 100644 index 0000000..046e990 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/async_store.hh @@ -0,0 +1,51 @@ +#pragma once +#include +#include + +#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG +#define FMT_HEADER_ONLY +#include "spdlog/spdlog.h" + +#include "io_helper.hpp" + +namespace async_store { + +struct ArrayStore; + +ArrayStore* create_or_open_store(size_t element_size, size_t size, std::filesystem::path data_path); +void close_store(ArrayStore* store); +size_t capacity(ArrayStore* store); +void extend(ArrayStore* store, size_t to); + + + +struct IORequest { + ArrayStore* store; + bool write; + void* data; + size_t index; + + // for sync + bool need_promise = false; + BatchPromise* promise; +}; + +std::string request_to_string(IORequest* req); + +struct IODealerImpl; +struct IODealer { + IODealerImpl* io_impl; + + IODealer(bool use_io_uring = false, int IO_DEPTH = 128); + ~IODealer(); + IODealer(const IODealer&) = delete; + IODealer& operator=(const IODealer&) = delete; + IODealer(IODealer&&) = default; + IODealer& operator=(IODealer&&) = default; + + void enqueue(std::shared_ptr req); + std::thread start_io_thread(); + void stop(); +}; + +} // namespace async_store diff --git a/csrc/balance_serve/kvc2/src/bind.cpp b/csrc/balance_serve/kvc2/src/bind.cpp new file mode 100644 index 0000000..c76bb26 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/bind.cpp @@ -0,0 +1,53 @@ +// #include +// #include +// #include +// #include +// #include +// #include +// #include "kvc2.h" +// #define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG +// #define FMT_HEADER_ONLY +// #include "spdlog/spdlog.h" +// #include "utils/arithmetic.hpp" + +// namespace py = pybind11; + +// PYBIND11_MODULE(kvc2_ext, m) { +// // Bind KVC2Config struct +// py::class_(m, "KVC2Config") +// .def(py::init<>()) +// .def_readwrite("path", &kvc2::KVC2Config::path) +// .def_readwrite("block_length", &kvc2::KVC2Config::num_token_per_page) +// .def_readwrite("memory_pool_size", &kvc2::KVC2Config::memory_pool_size) +// .def_readwrite("evict_count", &kvc2::KVC2Config::evict_count); + +// // Bind CacheInfo struct +// py::class_(m, "CacheInfo") +// .def(py::init<>()) +// .def_readwrite("model_name", &kvc2::CacheInfo::model_name) +// .def_readwrite("is_key_cache", &kvc2::CacheInfo::is_key_cache) +// .def_readwrite("quant_type", &kvc2::CacheInfo::quant_type) +// .def("hidden_layer_count", &kvc2::CacheInfo::hidden_layer_count) +// .def("path", &kvc2::CacheInfo::path, py::arg("which_layer") = std::nullopt) +// .def("__eq__", &kvc2::CacheInfo::operator==) +// .def("element_size", &kvc2::CacheInfo::element_size) +// .def("hash_value", &kvc2::CacheInfo::hash_value); + +// // Bind KVC2HandleInterface class +// py::class_>(m, "KVC2HandleInterface") +// .def("matched_length", &kvc2::SingleCacheHandleInterface::matched_length) +// .def("handle_data", &kvc2::KVC2HandleInterface::handle_data); + +// // Bind KVC2Interface class +// py::class_>(m, "KVC2Interface") +// .def("start_io_thread", [](kvc2::KVC2Interface& self) { self.start_io_thread(); }) +// .def("stop_io_thread", &kvc2::KVC2Interface::stop_io_thread) +// .def("load", &kvc2::KVC2Interface::load) +// .def("save", &kvc2::KVC2Interface::save) +// .def("raw_insert", &kvc2::KVC2Interface::raw_insert) +// .def("raw_read", &kvc2::KVC2Interface::raw_read) +// .def("lookup", &kvc2::KVC2Interface::lookup); + +// // Bind create_kvc2 function +// m.def("create_kvc2", &kvc2::create_kvc2, py::arg("config")); +// } \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/src/cache_entry.cpp b/csrc/balance_serve/kvc2/src/cache_entry.cpp new file mode 100644 index 0000000..3fe6b0a --- /dev/null +++ b/csrc/balance_serve/kvc2/src/cache_entry.cpp @@ -0,0 +1,263 @@ +#include "cache_entry.hh" +#include + +#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG +#define FMT_HEADER_ONLY +#include "spdlog/spdlog.h" + +#include "gpu_cache.hh" + +namespace kvc2 { + +bool ConcurrentControlUnit::can_desert() { + if (ref_count.load() == 0 && dirty.load() == false) { + tc.reset(); + return true; + } else { + return false; + } +} +void ConcurrentControlUnit::debug() { + SPDLOG_DEBUG("ref count {}, dirty {}, {}", ref_count.load(), dirty.load(), tc.debug()); +} + +CacheBlockEntry::~CacheBlockEntry() { + if (data != nullptr && manager && manager->pool) { + SPDLOG_WARN("Free {} when destruct", data); + free_on_cpu(); + } +} + +bool CacheBlockEntry::alloc_on_cpu() { + assert(data == nullptr); + data = manager->pool->alloc(size); + if (data == nullptr) { + manager->evict_for_cpu_cache(); + data = manager->pool->alloc(size); + if (data == nullptr) { + SPDLOG_ERROR("Not enough memory for Block Cache"); + return false; + } + } + return true; +} + +void CacheBlockEntry::free_on_cpu() { + manager->pool->free(data, size); + data = nullptr; +} + +bool CacheBlockEntry::alloc_on_cpu_no_lock() { + if (data == nullptr) { + if (alloc_on_cpu() == false) { + return false; + } + } + return true; +} + +bool CacheBlockEntry::inc_ref_or_alloc_on_cpu() { + std::lock_guard lg(lock); + if (data == nullptr) { + if (alloc_on_cpu()) { + cpu_cc.ref_count.fetch_add(1); + return true; + } else { + return false; + } + } else { + cpu_cc.ref_count.fetch_add(1); + return true; + } +} + +std::unique_lock CacheBlockEntry::try_lock() { + return std::unique_lock(lock, std::try_to_lock); +} + +std::lock_guard CacheBlockEntry::lock_guard() { + return std::lock_guard(lock); +} + +void CacheBlockEntry::debug() { + SPDLOG_DEBUG( + "CacheBlockEntry: disk[{:4},{:7}], with key {}, hash {:016x}, data: {}, ref_count: {}, size: {}, cpu tc: {}, " + "in page cache: {}, gpu ref count:{}, gpu tc: {}", + layer, idx, with_key, hash, data, cpu_cc.ref_count.load(), size, cpu_cc.tc.debug(), manager != nullptr, + gpu_cc.ref_count.load(), gpu_cc.tc.debug()); +} + +CacheBlockEntryCollector::CacheBlockEntryCollector(std::function exit_fn) : exit_fn(exit_fn) {} + +CacheBlockEntryCollector::~CacheBlockEntryCollector() { + // SPDLOG_DEBUG("Collector Destruct"); + for (auto& e : entries) { + exit_fn(e); + } +} + +void CacheBlockEntry::io_with(async_store::IODealer* dealer, IO_Helper& io_helper, + async_store::ArrayStore* store, size_t layer, size_t index, IOOption option) { + bool write; + + auto& batch_promise = io_helper.batch_promise; + + switch (option) { + case IO_Read: { + write = false; + if (io_helper.absorb_tc(this, cpu_cc.tc)) { + // need read + } else { + return; + } + break; + } + case IO_ForceRead: { + // Not change + write = false; + break; + } + case IO_ForceWrite: { + // Not change + write = true; + break; + } + case IO_Write: { + write = true; + break; + } + default: { + assert(0); + } + } + io_helper.new_task(); + this->layer = layer; + this->idx = index; + + auto req = std::make_shared(); + req->store = store; + req->data = data; + req->index = index; + req->write = write; + req->need_promise = true; + req->promise = &batch_promise; + + SPDLOG_TRACE("Submitting {}", async_store::request_to_string(req.get())); + dealer->enqueue(std::move(req)); +} + +CacheEntryManager::CacheEntryManager(CacheEntryManagerConfig config) : config(config) {} + +void CacheEntryManager::evict_for_cpu_cache() { + size_t count = 0; + evict( + [&count](const BlockPtr& block) { + // here we assume each with gpu must resides on cpu + if (block->data != nullptr && block->cpu_cc.can_desert() && + block->gpu_cc.can_desert() /*For now If A Cache Entry Block is on GPU, it must on cpu. */) { + block->free_on_cpu(); + count += 1; + return true; + } else { + return false; + } + }, + [&count, this]() { + return false; + // return count == this->config.evict_count; + }); +} + +void CacheEntryManager::insert(BlockPtr entry) { + assert(entry->with_key); + assert(key_entry_map.count(entry->hash) == 0); + usage_list.push_front(entry); + key_entry_map[entry->hash] = usage_list.begin(); +} + +CacheEntryManager::BlockPtr CacheEntryManager::access(const Key& key) { + auto it = key_entry_map.at(key); + auto entry = *it; + usage_list.erase(it); + usage_list.push_front(entry); + key_entry_map[key] = usage_list.begin(); + return entry; +} + +// void CacheEntryManager::remove(const Key& key) { +// auto it = key_entry_map[key]; +// usage_list.erase(it); +// key_entry_map.erase(key); +// } + +void CacheEntryManager::evict(std::function filter, std::function stop_condition) { + auto evict_count = 0; + auto inspect_count = 0; + + std::lock_guard lg(lock); + for (auto it = usage_list.rbegin(); it != usage_list.rend();) { + inspect_count += 1; + // SPDLOG_DEBUG("Map Size {}, List Size {}, Evicted {} blocks, Inspected {}, {}", key_entry_map.size(), + // usage_list.size(), evict_count, inspect_count, pool->debug()); + // (*it)->debug(); + if (stop_condition()) + break; + auto entry_ul = (*it)->try_lock(); + if (entry_ul.owns_lock() == false) { + ++it; // Ensure iterator advances when locking fails + continue; + } + if (filter(*it)) { + // SPDLOG_DEBUG("Evicting {}", fmt::ptr(it->get())); + evict_count++; + if ((*it)->with_key) + key_entry_map.erase((*it)->hash); + it = decltype(it)(usage_list.erase(std::next(it).base())); // Use base() to adjust for reverse iterator + } else { + ++it; // Ensure iterator advances when filter fails + } + } + + if (evict_count > 0) { + SPDLOG_DEBUG("Map Size {}, List Size {}, Evicted {} blocks, Inspected {}, {}", key_entry_map.size(), + usage_list.size(), evict_count, inspect_count, pool->debug()); + } +} + +CacheEntryManager::BlockPtr CacheEntryManager::get(bool& is_new, size_t size, std::optional key) { + std::unique_lock ul(lock); + if (key.has_value()) { + if (key_entry_map.count(key.value())) { + is_new = false; + return access(key.value()); + } else { + auto entry = std::make_shared(); + entry->with_key = true; + entry->hash = key.value(); + entry->size = size; + entry->manager = this; + insert(entry); + is_new = true; + return entry; + } + } else { + auto entry = std::make_shared(); + entry->with_key = false; + entry->size = size; + entry->manager = this; + is_new = true; + return entry; + } +} + +void CacheEntryManager::debug() { + fmt::print("Cache Manager: {} entries\n", key_entry_map.size()); + pool->debug(); + fmt::print("Layer 0 Entries in Order\n", key_entry_map.size()); + for (auto& it : usage_list) { + if (it->layer == 0) + it->debug(); + } +} + +}; // namespace kvc2 diff --git a/csrc/balance_serve/kvc2/src/cache_entry.hh b/csrc/balance_serve/kvc2/src/cache_entry.hh new file mode 100644 index 0000000..16f9b84 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/cache_entry.hh @@ -0,0 +1,182 @@ +#ifndef __CACHE_ENTRY_HH_ +#define __CACHE_ENTRY_HH_ +#include "async_store.hh" +#include "cuda_stream_manager.hh" +#include "defs.h" +#include "hasher.hpp" +#include "io_helper.hpp" +#include "page_aligned_memory_pool.h" +#include "utils/periodic_task.hpp" + +#include +#include +#include +#include "utils/mutex_extend.hpp" + +namespace kvc2 { +using CacheBlockKey = TokensHash; + +class CacheEntryManager; +struct DoubleVerticalBlocksHandle; +class GPUPageCache; + +struct ConcurrentControlUnit { + std::atomic_size_t ref_count = 0; + std::atomic_bool dirty = false; + TransferControl tc; + + bool can_desert(); + void debug(); +}; + +enum IOOption { + IO_ForceRead, + IO_ForceWrite, + IO_Read, + IO_Write, +}; + +inline std::string to_string(IOOption op) { + switch (op) { + case IO_ForceRead: + return "IO_ForceRead"; + case IO_ForceWrite: + return "IO_ForceWrite"; + case IO_Read: + return "IO_Read"; + case IO_Write: + return "IO_Write"; + default: + return "Unknown"; + } +} + +struct CacheBlockEntry { + friend CacheEntryManager; + using MutexT = non_recursive_mutex; + // using MutexT = std::mutex; + MutexT lock; + + // for cache + bool with_key = true; + CacheBlockKey hash = 0; + CacheBlockKey hash_check = 0; + + CacheInfo cache_info; + CacheEntryManager* manager = nullptr; + + // for memory pool + void* data = nullptr; + size_t size = 0; + + ConcurrentControlUnit cpu_cc; + + // for disk + size_t layer = -1; + size_t idx = -1; + + // for gpu + + std::optional gpu_block_idx = std::nullopt; + ConcurrentControlUnit gpu_cc; + + CacheBlockEntry() =default; + CacheBlockEntry(const CacheBlockEntry& other) = delete; + CacheBlockEntry& operator=(const CacheBlockEntry& other) = delete; + CacheBlockEntry(CacheBlockEntry&& other) = delete; + CacheBlockEntry& operator=(CacheBlockEntry&& other) = delete; + ~CacheBlockEntry(); + + private: + bool alloc_on_cpu(); + + + public: + void free_on_cpu(); + bool alloc_on_cpu_no_lock(); + + bool inc_ref_or_alloc_on_cpu(); + void set_key(TokensHash key, std::shared_ptr me); + + std::unique_lock try_lock(); + std::lock_guard lock_guard(); + + // will not get lock + void io_with(async_store::IODealer* dealer, IO_Helper& io_helper, async_store::ArrayStore* store, + size_t layer, size_t index, IOOption option); + void flush_back_async(IO_Helper& helper, std::vector& dirty_flags); + + void debug(); +}; + +struct CacheBlockEntryCollector{ + + std::vector entries; + std::function exit_fn; + + CacheBlockEntryCollector(std::function exit_fn); + ~CacheBlockEntryCollector(); + + CacheBlockEntryCollector(const CacheBlockEntryCollector& other) = delete; + CacheBlockEntryCollector(CacheBlockEntryCollector&& other) = delete; + CacheBlockEntryCollector& operator=(const CacheBlockEntryCollector& other) = delete; + CacheBlockEntryCollector& operator=(CacheBlockEntryCollector&& other) = delete; + + + +}; + + +struct KVC2; +struct CacheEntryManagerConfig { + size_t evict_count = 100; + KVC2* kvc2_top = nullptr; +}; + +class CacheEntryManager { + public: + using Key = CacheBlockKey; + using BlockPtr = std::shared_ptr; + + private: + friend CacheBlockEntry; + + CacheEntryManagerConfig config; + + std::mutex lock; + std::list usage_list; + std::unordered_map::iterator> key_entry_map; + + void insert(BlockPtr entry); + BlockPtr access(const Key& key); + + // void remove(const Key& key); + void evict(std::function filter, std::function stop_condition); + + + public: + std::unique_ptr background_flush_back=nullptr; + std::shared_ptr pool; + std::shared_ptr gpu_cache; + + CacheEntryManager(CacheEntryManagerConfig config); + + // disable all move and copy + CacheEntryManager(const CacheEntryManager& other) = delete; + CacheEntryManager& operator=(const CacheEntryManager& other) = delete; + CacheEntryManager(CacheEntryManager&& other) = delete; + CacheEntryManager& operator=(CacheEntryManager&& other) = delete; + + void cpu_background_flush(); + + void evict_for_cpu_cache(); + + // just get block pointers, not allocate them, will not return nullptr + BlockPtr get(bool& is_new,size_t size, std::optional key = std::nullopt); + + void debug(); +}; + +} // namespace kvc2 + +#endif \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/src/common.h b/csrc/balance_serve/kvc2/src/common.h new file mode 100644 index 0000000..e69de29 diff --git a/csrc/balance_serve/kvc2/src/cuda_stream_manager.cpp b/csrc/balance_serve/kvc2/src/cuda_stream_manager.cpp new file mode 100644 index 0000000..c696bf4 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/cuda_stream_manager.cpp @@ -0,0 +1,135 @@ +#include "cuda_stream_manager.hh" +#include +#include +#include +#include +#include +#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO +// #define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG +#define FMT_HEADER_ONLY +#include "spdlog/spdlog.h" + +CudaStreamManager::CudaStreamManager(const std::vector& device_ids, int num_streams_per_device) { + for (int device_id : device_ids) { + auto x = std::unique_ptr(new DeviceInfo); + DeviceInfo& device_info = *x; + device_info.device_id = device_id; + device_info.next_stream_index = 0; + device_info.stop_flag = false; + + // 设置设备 + cudaError_t err = cudaSetDevice(device_id); + if (err != cudaSuccess) { + SPDLOG_WARN("cudaSetDevice failed on device {}: {}", device_id, cudaGetErrorString(err)); + throw std::runtime_error("cudaSetDevice failed"); + } + + // 创建 CUDA 流 + device_info.streams.resize(num_streams_per_device); + for (int i = 0; i < num_streams_per_device; ++i) { + err = cudaStreamCreate(&device_info.streams[i]); + if (err != cudaSuccess) { + SPDLOG_WARN("Failed to create CUDA stream on device {}: {}", device_id, cudaGetErrorString(err)); + throw std::runtime_error("Failed to create CUDA stream"); + } + } + + // 启动设备工作线程 + device_info.worker_thread = std::thread(&CudaStreamManager::deviceWorker, this, std::ref(device_info)); + + devices_.push_back(std::move(x)); + } +} + +CudaStreamManager::~CudaStreamManager() { + // 通知所有设备线程停止 + for (auto& device_info : devices_) { + device_info->stop_flag.store(true); + auto request = std::shared_ptr(new Request); + request->should_exit = true; + device_info->request_queue.enqueue(std::move(request)); + } + + // 等待所有线程结束 + for (auto& device_info : devices_) { + if (device_info->worker_thread.joinable()) { + device_info->worker_thread.join(); + } + + // 销毁 CUDA 流 + cudaSetDevice(device_info->device_id); + for (auto& stream : device_info->streams) { + cudaStreamDestroy(stream); + } + } +} + +void CudaStreamManager::submitRequest(std::shared_ptr request) { + // 找到对应的设备 + for (auto& device_info : devices_) { + if (device_info->device_id == request->device_id) { + device_info->request_queue.enqueue(request); + return; + } + } + throw std::runtime_error("Invalid device ID in request"); +} + +void CudaStreamManager::deviceWorker(DeviceInfo& device_info) { + // 设置设备 + cudaError_t err = cudaSetDevice(device_info.device_id); + if (err != cudaSuccess) { + SPDLOG_WARN("cudaSetDevice failed in worker thread for device {}: {}", device_info.device_id, + cudaGetErrorString(err)); + return; + } + + while (device_info.stop_flag.load() == false) { + auto request = device_info.request_queue.dequeue(); + if (request->should_exit) { + return; + } + // 处理请求 + SPDLOG_DEBUG("Getting request on device {}, count {}", device_info.device_id, request->host_mem_addresses.size()); + int stream_index = device_info.next_stream_index; + cudaStream_t stream = device_info.streams[stream_index]; + device_info.next_stream_index = (device_info.next_stream_index + 1) % device_info.streams.size(); + + size_t num_transfers = request->host_mem_addresses.size(); + for (size_t i = 0; i < num_transfers; ++i) { + void* dst = request->device_mem_addresses[i]; + void* src = request->host_mem_addresses[i]; + if (request->direction == cudaMemcpyDeviceToHost) { + std::swap(dst, src); + } + + cudaError_t err = cudaMemcpyAsync(dst, src, request->sizes[i], request->direction, stream); + if (err != cudaSuccess) { + SPDLOG_WARN("cudaMemcpyAsync failed on device {}: {}", device_info.device_id, cudaGetErrorString(err)); + // 可以根据需要处理错误,这里简单地继续 + continue; + } + } + + // 添加回调函数,因为是异步,所以需要包起来 + struct CallbackData { + std::function callback; + }; + CallbackData* cb_data = new CallbackData{request->callback}; + + err = cudaLaunchHostFunc( + stream, + [](void* data) { + // SPDLOG_DEBUG("Callback function called"); + CallbackData* cb_data = static_cast(data); + cb_data->callback(); + delete cb_data; + }, + cb_data); + + if (err != cudaSuccess) { + SPDLOG_WARN("cudaLaunchHostFunc failed on device {}: {}", device_info.device_id, cudaGetErrorString(err)); + // 根据需要处理错误 + } + } +} diff --git a/csrc/balance_serve/kvc2/src/cuda_stream_manager.hh b/csrc/balance_serve/kvc2/src/cuda_stream_manager.hh new file mode 100644 index 0000000..d4fe215 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/cuda_stream_manager.hh @@ -0,0 +1,54 @@ +/* + * @Author: Xie Weiyu ervinxie@qq.com + * @Date: 2024-11-19 09:24:47 + * @LastEditors: Xie Weiyu ervinxie@qq.com + * @LastEditTime: 2024-11-20 02:55:49 + * @FilePath: /kvc2/src/cuda_stream_manager.hh + * @Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include "utils/mpsc.hpp" + +class CudaStreamManager { + public: + // 构造函数,接受要使用的设备 ID 列表和每个设备的流数量 + CudaStreamManager(const std::vector& device_ids, int num_streams_per_device); + ~CudaStreamManager(); + + // 请求结构体 + struct Request { + bool should_exit = false; + int device_id; + std::vector host_mem_addresses; + std::vector device_mem_addresses; + std::vector sizes; + cudaMemcpyKind direction; + std::function callback; + }; + + void submitRequest(std::shared_ptr request); + + private: + // 每个设备的信息 + struct DeviceInfo { + int device_id; + std::thread worker_thread; + std::vector streams; + int next_stream_index; + MPSCQueueConsumerLock> request_queue; + std::atomic_bool stop_flag; + }; + + // 设备 ID 到 DeviceInfo 的映射 + std::vector> devices_; + + // 私有方法 + void deviceWorker(DeviceInfo& device_info); +}; diff --git a/csrc/balance_serve/kvc2/src/defs.h b/csrc/balance_serve/kvc2/src/defs.h new file mode 100644 index 0000000..b21f4e2 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/defs.h @@ -0,0 +1,35 @@ +#ifndef __DEFS_H_ +#define __DEFS_H_ + +#include +#include +#include +#include "model_config.h" + +namespace kvc2 { +using kvc2_ptr = void*; +// using data_block_ptr = std::intptr_t; +using data_block_ptr = void*; +using layer_data = std::vector; +using kvc2_handle = void*; + +using Token = uint32_t; +using Tokens = std::vector; +using TokenPtr = std::intptr_t; +using TokenLength = size_t; +using BlockLength = size_t; + +struct CacheInfo { + ModelName model_name; + bool is_key_cache; + QuantType quant_type; + + size_t hidden_layer_count(); + std::filesystem::path path(std::optional which_layer = std::nullopt); + bool operator==(const CacheInfo& other) const; + size_t element_size(size_t block_length); + size_t hash_value() const; +}; + +}; // namespace kvc2 +#endif diff --git a/csrc/balance_serve/kvc2/src/gpu_cache.cpp b/csrc/balance_serve/kvc2/src/gpu_cache.cpp new file mode 100644 index 0000000..2bfe945 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/gpu_cache.cpp @@ -0,0 +1,282 @@ +#include "gpu_cache.hh" + +#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG +#define FMT_HEADER_ONLY +#include "spdlog/spdlog.h" + +#include "cache_entry.hh" +#include "utils/arithmetic.hpp" + +namespace kvc2 { + +GPUPageCache::GPUPageCache(GPUPageCacheConfig& config) : config(config) { + if (torch::cuda::is_available()) { + size_t gpu_count = torch::cuda::device_count(); + SPDLOG_INFO("Number of available GPUs: {}, want {}", gpu_count, config.gpu_devices_id.size()); + if (gpu_count < config.gpu_devices_id.size()) { + SPDLOG_ERROR("Not enough GPUs available."); + exit(0); + } + for (auto x : config.gpu_devices_id) { + gpu_devices.push_back(torch::Device(torch::kCUDA, x)); + } + } else { + SPDLOG_ERROR("CUDA is not available on this system."); + exit(0); + } + + SPDLOG_WARN("Creating GPU Cache"); + shape.push_back(config.layer_count); + shape.push_back(config.total_kvcache_pages); + shape.push_back(config.num_token_per_page); + if (config.full_kv_cache_on_each_gpu) { + if (config.gpu_devices_id.size() > 1) { + SPDLOG_WARN("Replicated KVCache on multiple gpu"); + } + shape.push_back(config.num_k_heads); + } else { + shape.push_back(config.num_k_heads / config.gpu_devices_id.size()); + } + shape.push_back(config.k_head_dim); + tensor_size = torch::elementSize(config.tensor_type); + for (auto& s : shape) { + tensor_size *= s; + } + SPDLOG_INFO("Creating KV Page Cache, Shape ({},{},{},{},{}), Size {} MiB", shape[0], shape[1], shape[2], shape[3], + shape[4], tensor_size / (1 << 20)); + if (config.k_cache_on) { + for (size_t i = 0; i < config.gpu_devices_id.size(); i++) { + auto k = torch::zeros(shape, torch::TensorOptions().dtype(config.tensor_type)); + k = k.to(gpu_devices[i]); + + k_cache.push_back(k); + + SPDLOG_INFO("K Page Cache of GPU {} is created", config.gpu_devices_id[i]); + } + occupations.resize(config.layer_count); + } else { + SPDLOG_WARN("Disalbe K Cache"); + assert(config.gpu_only); + } + + if (config.v_cache_on) { + for (size_t i = 0; i < config.gpu_devices_id.size(); i++) { + auto v = torch::zeros(shape, torch::TensorOptions().dtype(config.tensor_type)); + v = v.to(gpu_devices[i]); + v_cache.push_back(v); + + SPDLOG_INFO("V Page Cache of GPU {} is created", config.gpu_devices_id[i]); + } + v_occupations.resize(config.layer_count); + } else { + SPDLOG_WARN("Disalbe V Cache"); + // assert(config.gpu_only); // should not assert + } + + if (config.gpu_only) { + gpu_only_occupations.resize(config.total_kvcache_pages, false); + } + + + num_free_pages = config.total_kvcache_pages; + for (size_t i = 0; i < config.layer_count; i++) { + if (config.k_cache_on) + occupations[i].resize(config.total_kvcache_pages, nullptr); + if (config.v_cache_on) + v_occupations[i].resize(config.total_kvcache_pages, nullptr); + } + + tp_size.resize(config.gpu_devices_id.size(), shape[2] * shape[3] * shape[4] * c10::elementSize(config.tensor_type)); + tp_offset.resize(config.gpu_devices_id.size(), 0); + for (size_t i = 1; i < tp_offset.size(); i++) { + tp_offset[i] = tp_offset[i - 1] + tp_size[i - 1]; + } + + stream_manager = + std::unique_ptr(new CudaStreamManager(config.gpu_devices_id, config.num_streams_per_device)); +} + +bool GPUPageCache::alloc_col(std::vector>>& k_entries, + std::vector>>& v_entries, size_t at) { + std::lock_guard lg(lock); + auto idx = next_empty_col(); + if (idx.has_value()) { + // must have entry lock + auto& k0_entry = k_entries[0][at]; + k0_entry->gpu_block_idx = idx; + + for (size_t l = 0; l < config.layer_count; l++) { + if (config.k_cache_on) { + assert(k_entries[l][at]->data != nullptr); + occupations[l][idx.value()] = k_entries[l][at]; + } + if (config.v_cache_on) { + assert(v_entries[l][at]->data != nullptr); + v_occupations[l][idx.value()] = v_entries[l][at]; + } + } + return true; + } else { + return false; + } +} + +std::vector GPUPageCache::gpu_only_alloc_col(size_t count) { + assert(config.gpu_only); + std::lock_guard lg(lock); + std::vector re; + + for (size_t i = 0; i < config.total_kvcache_pages; i++) { + if (gpu_only_occupations[i] == false) { + re.push_back(i); + if (re.size() == count) { + break; + } + } + } + + if (re.size() == count) { + for (auto at : re) { + gpu_only_occupations[at] = true; + } + } else { + SPDLOG_WARN("GPU ONLY: Cannot allocate {} cols", count); + re.clear(); + } + return re; +} + +void GPUPageCache::gpu_only_free_cols(std::vector cols) { + assert(config.gpu_only); + std::lock_guard lg(lock); + for (auto at : cols) { + assert(gpu_only_occupations[at]); + gpu_only_occupations[at] = false; + } +} + +std::optional GPUPageCache::next_empty_col() { + if (num_free_pages == 0) { + evict_cols(); + if (num_free_pages == 0) { + return std::nullopt; + } + } + while (occupations[0][_col_idx] != nullptr) { + _col_idx = (_col_idx + 1) % config.total_kvcache_pages; + } + num_free_pages -= 1; + return _col_idx; +} + +void GPUPageCache::evict_cols() { + auto evicted_count = 0; + for (size_t i = 0; i < config.total_kvcache_pages; i++) { + auto& h = occupations[0][i]; + if (h == nullptr) { + continue; + } + auto lg = h->lock_guard(); + if (h->gpu_cc.can_desert()) { + h->gpu_cc.tc.reset(); + h = nullptr; + num_free_pages += 1; + evicted_count += 1; + } + } + if (evicted_count > 0) + SPDLOG_INFO("GPU: Evicted {} GPU pages", evicted_count); +} + +std::vector> GPUPageCache::try_lock_col(size_t at) { + std::vector> re; + if (config.k_cache_on) { + for (size_t l = 0; l < config.layer_count; l++) { + if (occupations[l][at] == nullptr) { + return {}; + } + auto ul = occupations[l][at]->try_lock(); + if (ul.owns_lock()) { + re.push_back(std::move(ul)); + } else { + return {}; + } + } + } + if (config.v_cache_on) { + for (size_t l = 0; l < config.layer_count; l++) { + if (v_occupations[l][at] == nullptr) { + return {}; + } + auto ul = v_occupations[l][at]->try_lock(); + if (ul.owns_lock()) { + re.push_back(std::move(ul)); + } else { + return {}; + } + } + } + return re; +} + +std::vector> GPUPageCache::basic_request(cudaMemcpyKind direction, + std::function callback) { + std::vector> re; + re.resize(config.gpu_devices_id.size(), nullptr); + for (size_t i = 0; i < re.size(); i++) { + re[i] = std::shared_ptr(new CudaStreamManager::Request); + re[i]->direction = direction; + re[i]->device_id = config.gpu_devices_id[i]; + re[i]->callback = callback; + } + return re; +} + +void GPUPageCache::submit_requests(std::vector> reqs) { + for (auto& r : reqs) { + stream_manager->submitRequest(r); + } +} + +void GPUPageCache::append_col_to_request(std::vector>& reqs, + std::vector>>& k_handles, + std::vector>>& v_handles, + size_t at) { + if (config.k_cache_on == false && config.v_cache_on == false) { + return; + } + auto gpu_block_idx = k_handles[0][at]->gpu_block_idx.value(); + for (size_t layer = 0; layer < config.layer_count; layer++) { + for (size_t which_gpu = 0; which_gpu < config.gpu_devices_id.size(); which_gpu++) { + + if (config.k_cache_on) { + assert(k_handles[layer][at]->data != nullptr); + reqs[which_gpu]->sizes.push_back(tp_size[which_gpu]); + reqs[which_gpu]->host_mem_addresses.push_back(offset_by_bytes(k_handles[layer][at]->data, tp_offset[which_gpu])); + reqs[which_gpu]->device_mem_addresses.push_back(k_cache[which_gpu][layer][gpu_block_idx].data_ptr()); + } + + if (config.v_cache_on) { + assert(v_handles[layer][at]->data != nullptr); + reqs[which_gpu]->sizes.push_back(tp_size[which_gpu]); + reqs[which_gpu]->host_mem_addresses.push_back(offset_by_bytes(v_handles[layer][at]->data, tp_offset[which_gpu])); + reqs[which_gpu]->device_mem_addresses.push_back(v_cache[which_gpu][layer][gpu_block_idx].data_ptr()); + } + } + } + // SPDLOG_DEBUG("GPU: Appended Vertical Handle to Request, count {}", reqs[0]->sizes.size()); +} + +void GPUPageCache::debug() { + size_t count = 0; + for (size_t i = 0; i < config.total_kvcache_pages; i++) { + if (occupations[0][i] == nullptr) { + count += 1; + } else { + // occupations[0][i]->gpu_cc.debug(); + } + } + SPDLOG_DEBUG("Free Page: {}/{}", count, config.total_kvcache_pages); +} + +} // namespace kvc2 diff --git a/csrc/balance_serve/kvc2/src/gpu_cache.hh b/csrc/balance_serve/kvc2/src/gpu_cache.hh new file mode 100644 index 0000000..0621056 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/gpu_cache.hh @@ -0,0 +1,74 @@ +#ifndef __GPU_CACHE_HH_ +#define __GPU_CACHE_HH_ + +#include +#include "cache_entry.hh" +#include "cuda_stream_manager.hh" +#include "defs.h" +#include "kvc2.h" +#include "metrics.h" +#include "utils/periodic_task.hpp" + +namespace kvc2 { + +class GPUPageCache { + std::vector gpu_devices; + + std::vector shape; + size_t tensor_size; + std::vector tp_offset; + std::vector tp_size; + + + + // met + std::shared_ptr met; + + // states + std::mutex lock; + size_t num_free_pages; + std::vector gpu_only_occupations; + std::vector>> occupations,v_occupations; + size_t _col_idx = 0; + + + // cuda stream manager + std::optional next_empty_col(); + + public: + GPUPageCacheConfig config; + std::unique_ptr stream_manager; + std::vector k_cache; + std::vector v_cache; + std::unique_ptr background_flush_back =nullptr; + + GPUPageCache(GPUPageCacheConfig& config); + + std::vector gpu_only_alloc_col(size_t count); + void gpu_only_free_cols(std::vector cols); + + + void gpu_background_flush(); + + + bool alloc_col(std::vector>>& k_entries, + std::vector>>& v_entries, size_t at); + void evict_cols(); + void flush_col(size_t at); + std::vector> try_lock_col(size_t at); + + void free_col(size_t at); + + std::vector> basic_request(cudaMemcpyKind direction, + std::function callback); + + void submit_requests(std::vector> reqs); + + void append_col_to_request(std::vector>& reqs, + std::vector>>& k_handles, + std::vector>>& v_handles, size_t at); + + void debug(); +}; +} // namespace kvc2 +#endif \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/src/hasher.hpp b/csrc/balance_serve/kvc2/src/hasher.hpp new file mode 100644 index 0000000..7b328ae --- /dev/null +++ b/csrc/balance_serve/kvc2/src/hasher.hpp @@ -0,0 +1,40 @@ +#ifndef __HASHER_HPP_ +#define __HASHER_HPP_ + +#include "defs.h" +#include "xxhash.h" + +namespace kvc2 { + +const uint64_t hash_seed = 4123512; +const uint64_t check_hash_seed = 1025753; + +using TokensHash = XXH64_hash_t; +struct TokensHasher { + XXH64_state_t* state; + TokensHasher() { + state = XXH64_createState(); + reset(); + } + ~TokensHasher() { XXH64_freeState(state); } + + TokensHasher(TokensHasher& other) = delete; + TokensHasher& operator=(TokensHasher& other) = delete; + TokensHasher(TokensHasher&& other) = delete; + TokensHasher& operator=(TokensHasher&& other) = delete; + TokensHash get() { return XXH64_digest(state); } + void reset(size_t seed = hash_seed) { XXH64_reset(state, seed); } + TokensHash update(Token* data, TokenLength length) { + XXH64_update(state, data, length * sizeof(Token)); + return get(); + } + + TokensHash update_raw(void* data, size_t size) { + XXH64_update(state, data, size); + return get(); + } + + static TokensHash hash(Token* data, TokenLength length) { return XXH64(data, length * sizeof(Token), hash_seed); } +}; +} // namespace kvc2 +#endif \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/src/io_helper.hpp b/csrc/balance_serve/kvc2/src/io_helper.hpp new file mode 100644 index 0000000..b9f3021 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/io_helper.hpp @@ -0,0 +1,155 @@ +/** + * @Description : + * @Author : Xie Weiyu + * @Date : 2024-12-11 06:35:31 + * @Version : 1.0.0 + * @LastEditors : Xie Weiyu + * @LastEditTime : 2024-12-11 06:50:55 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + **/ + +#pragma once +#include +#include +#include +#include +#include +#include +#include + +struct BatchPromise { + std::promise promise; + std::shared_future fut; + std::atomic_size_t count; + + inline BatchPromise(size_t count) : count(count) { fut = promise.get_future().share(); } + + inline void inc(size_t count = 1) { this->count.fetch_add(count, std::memory_order_seq_cst); } + + inline void set() { + if (count.fetch_sub(1, std::memory_order_seq_cst) == 1) { + promise.set_value(); + } + } + inline std::shared_future get_shared_fut() { return fut; } +}; + +template +struct TransferControl { + Lock lock; + + std::optional> transfer_ok = std::nullopt; + bool has_data = false; + + TransferControl() {} + + /* + true, std::nullopt : Already has data + false, shared_future : Transfer already started, should wait for the future + false, std::nullopt : should transfer by you + true, shared_future: Should not appear + */ + std::pair>> has_data_or_transfer(std::shared_future shared_fut) { + std::lock_guard lg(lock); + if (has_data) { + return {true, std::nullopt}; + } else { + if (transfer_ok.has_value()) { + return {false, transfer_ok}; + } else { + transfer_ok = shared_fut; + return {false, std::nullopt}; + } + } + } + + void set_has_data() { + std::lock_guard lg(lock); + has_data = true; + transfer_ok = std::nullopt; + } + + bool get_has_data() { + std::lock_guard lg(lock); + if (has_data) { + return true; + } else { + return false; + } + } + + void reset() { + std::lock_guard lg(lock); + transfer_ok = std::nullopt; + has_data = false; + } + + std::string debug() { + std::lock_guard lg(lock); + return std::string("") + (has_data ? "has data" : "no data") + " " + + (transfer_ok.has_value() ? "transfer " : "no transfer"); + } +}; + +struct ConcurrentController { + std::atomic_bool dirty = false; + std::atomic_size_t ref_count = 0; + TransferControl tc; +}; + +template +struct IO_Helper { + BatchPromise batch_promise; + std::function call_back_on_unit = nullptr; + std::function call_back = nullptr; + + std::vector> futs; + std::vector units_by_myself; + + IO_Helper(std::function call_back_on_unit, std::function call_back = nullptr) + : batch_promise(1), call_back_on_unit(call_back_on_unit), call_back(call_back) {} + + IO_Helper(const IO_Helper& other) = delete; + IO_Helper& operator=(const IO_Helper& other) = delete; + IO_Helper(IO_Helper&& other) = delete; + IO_Helper& operator=(IO_Helper&& other) = delete; + ~IO_Helper() { + // std::cout<<"Destory IO helper"<& tc) { + auto [ok, fut] = tc.has_data_or_transfer(batch_promise.get_shared_fut()); + if (ok) { + return false; + } else { + if (fut.has_value()) { + futs.push_back(fut.value()); + // printf("Transfer started\n"); + return false; + } else { + units_by_myself.push_back(unit); + // printf("Not Transfer\n"); + return true; + } + } + } + + void wait() { + for (auto& fut : futs) { + fut.wait(); + } + batch_promise.get_shared_fut().wait(); + for (auto& b : units_by_myself) { + call_back_on_unit(b); + } + if (call_back) + call_back(); + } +}; diff --git a/csrc/balance_serve/kvc2/src/kvc2.h b/csrc/balance_serve/kvc2/src/kvc2.h new file mode 100644 index 0000000..e93a6cf --- /dev/null +++ b/csrc/balance_serve/kvc2/src/kvc2.h @@ -0,0 +1,138 @@ +#pragma once +#include +#include +#include +#include +#include "defs.h" +#include "model_config.h" + +namespace kvc2 { +struct GPUPageCacheConfig { + bool gpu_only; + std::vector gpu_devices_id; + + size_t layer_count; + size_t total_kvcache_pages; + size_t num_token_per_page; + size_t num_k_heads; + size_t k_head_dim; + + bool full_kv_cache_on_each_gpu = false; + bool k_cache_on = true; + bool v_cache_on = true; + torch::ScalarType tensor_type; + + // for cuda stream manager + size_t num_streams_per_device = 4; +}; + +struct KVC2Config { + bool k_cache_on = true; + bool v_cache_on = true; + bool gpu_only = false; + bool load_from_disk = true; + bool save_to_disk = true; + std::string path; + std::string config_path; + TokenLength num_token_per_page = 256; + size_t memory_pool_size = 10e9; + size_t evict_count = 20; + std::optional gpu_cache_config = std::nullopt; + size_t metrics_port; + double recompute_ratio = 0.2; +}; + +class DoubleCacheHandleInterface; +class KVC2Interface { + public: + virtual ~KVC2Interface() = default; + + virtual void load() = 0; + virtual void save() = 0; + /* +Raw Insert +Insert kvcache from kvcache_data to disk. + +info: cache info +id: start pointer of token array +length: length of token array +kvcache_data: data of kvcache + +This will firstly match the ID array with the existing kvcache, and then insert the unmatched kvcache to disk. +*/ + virtual void raw_insert(ModelName model_name, QuantType quant_type, Token* id, TokenLength length, + const std::vector& k_cache, const std::vector& v_cache) = 0; + + /* +Raw Read +Read kvcache from disk to user specified pointers. + +info: cache info +id: start pointer of token array +length: length of token array +kvcache_data: data of kvcache +Return: matched length of prefix, in tokens + +This will not read from memory pool, it directly read from disk. +*/ + virtual TokenLength raw_read(ModelName model_name, QuantType quant_type, Token* id, TokenLength length, + const std::vector& k_cache, const std::vector& v_cache) = 0; + + /* + Lookup + Lookup kvcache and load it from disk to memory pool if needed. + + info: cache info + id: start pointer of token array + length: length of token array + + Return: kvc2_handle, holds kvcache until being released. + if not found, matched_length will return 0. + if memory pool is full, return nullptr + */ + virtual std::shared_ptr lookup(ModelName model_name, QuantType quant_type, Token* id, + TokenLength length, TokenLength estimated_length) = 0; + + /* + Lookup and allocate to gpu + info.is_k_cache does not matter here + */ + virtual std::shared_ptr lookup_to_gpu(ModelName model_name, QuantType quant_type, + Token* id, TokenLength length, + TokenLength estimated_length) = 0; + + virtual void lookup_to_gpu_async(ModelName model_name, QuantType quant_type, Token* id, TokenLength length, + TokenLength estimated_length, + std::function)> call_back) = 0; + + virtual std::pair, std::vector> get_kvcache() = 0; + + virtual void debug() = 0; +}; + +std::shared_ptr create_kvc2(KVC2Config config); + +enum MatchStatus { + Exact, + Partial, + NotMatchExact, + NotMatchPartial, +}; + +class DoubleCacheHandleInterface { + public: + virtual ~DoubleCacheHandleInterface() = default; + virtual TokenLength matched_length() = 0; + virtual std::vector matched_status() = 0; + virtual std::vector handle_data(bool is_key_cache) = 0; + virtual bool to_gpu() = 0; + virtual void to_gpu_async(std::function call_back) = 0; + virtual std::vector get_gpu_block_idx() = 0; + virtual std::vector get_gpu_attached_block_idx() = 0; + + virtual void append_tokens(Token* tokens, TokenLength length) = 0; // update generated tokens + + virtual void debug() = 0; +}; + +}; // namespace kvc2 diff --git a/csrc/balance_serve/kvc2/src/kvc2_utils.py b/csrc/balance_serve/kvc2/src/kvc2_utils.py new file mode 100644 index 0000000..954f63c --- /dev/null +++ b/csrc/balance_serve/kvc2/src/kvc2_utils.py @@ -0,0 +1,64 @@ +import torch +import ctypes + +def aligned_tensor(size, alignment=4096): + num_bytes = size + mem = ctypes.c_void_p() + error_code = ctypes.CDLL(None).posix_memalign( + ctypes.byref(mem), ctypes.c_size_t(alignment), ctypes.c_size_t(num_bytes) + ) + + if error_code != 0: + raise MemoryError(f"posix_memalign failed with error code {error_code}") + + array_type = (ctypes.c_int8 * size) + raw_array = array_type.from_address(mem.value) + + tensor = torch.frombuffer(raw_array, dtype=torch.int8) + + if tensor.data_ptr() % alignment != 0: + raise ValueError(f"Tensor data_ptr {tensor.data_ptr()} is not aligned to {alignment} bytes") + + return tensor, mem + +def alloc_aligned_cache(layer_count,block_count,element_size): + cache = [] + cache_mem = [] + for i in range(layer_count): + layer_data = [] + layer_mem = [] + for j in range(block_count): + tensor, mem_ptr = aligned_tensor(element_size, alignment=4096) + layer_data.append(tensor) + layer_mem.append(mem_ptr) + cache.append(layer_data) + cache_mem.append(layer_mem) + return cache,cache_mem + +def dealloc_aligned_cache(cache_mem): + for layer_mem in cache_mem: + for mem_ptr in layer_mem: + ctypes.CDLL(None).free(mem_ptr) + +def get_tensor_ptr(tensors): + tensor_ptr = [] + for layer in tensors: + layer_ptr = [] + for data in layer: + layer_ptr.append(data.data_ptr()) + tensor_ptr.append(layer_ptr) + return tensor_ptr + +def get_tensor_from_data_ptr(matched_data,element_size): + re = [] + for layer in matched_data: + re_layer = [] + for data_ptr in layer: + array_type = (ctypes.c_int8 * element_size) + raw_array = array_type.from_address(data_ptr) + tensor = torch.frombuffer(raw_array, dtype=torch.int8) + re_layer.append(tensor) + re.append(re_layer) + return re +if __name__ == "__main__": + pass \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/src/metrics.cpp b/csrc/balance_serve/kvc2/src/metrics.cpp new file mode 100644 index 0000000..9dd2c9e --- /dev/null +++ b/csrc/balance_serve/kvc2/src/metrics.cpp @@ -0,0 +1,141 @@ +#include "metrics.h" + +namespace kvc2 { + +Metrics::Metrics(const MetricsConfig& config) + : registry_(std::make_shared()), exposer_(config.endpoint) { + // 注册 prefix_nodes Counter + auto& prefix_nodes_family = prometheus::BuildCounter() + .Name(std::string(METRIC_PREFIX) + "_prefix_nodes") + .Help("Number of prefix nodes") + .Register(*registry_); + prefix_nodes = &prefix_nodes_family.Add({}); + + // 注册 prefix_block_count Counter + auto& prefix_block_count_family = prometheus::BuildCounter() + .Name(std::string(METRIC_PREFIX) + "_prefix_block_count") + .Help("Number of prefix blocks") + .Register(*registry_); + prefix_block_count = &prefix_block_count_family.Add({}); + + // 定义统一的桶大小,最大为 10000 ms (10 s) + std::vector common_buckets = {1.0, 5.0, 10.0, 50.0, 100.0, 500.0, 1000.0, 5000.0, 10000.0}; + + // 注册 raw_insert_time_ms Histogram + auto& raw_insert_time_ms_family = prometheus::BuildHistogram() + .Name(std::string(METRIC_PREFIX) + "_raw_insert_time_ms") + .Help("function raw insert's time in milliseconds") + .Register(*registry_); + raw_insert_time_ms = &raw_insert_time_ms_family.Add({}, common_buckets); + + // 注册 lookup_time_ms Histogram + auto& lookup_time_ms_family = prometheus::BuildHistogram() + .Name(std::string(METRIC_PREFIX) + "_lookup_time_ms") + .Help("function lookup's time in milliseconds") + .Register(*registry_); + lookup_time_ms = &lookup_time_ms_family.Add({}, common_buckets); + + // 注册 lookup_prefixmatch_length Histogram + auto& lookup_prefixmatch_length_family = prometheus::BuildHistogram() + .Name(std::string(METRIC_PREFIX) + "_lookup_prefixmatch_length") + .Help("function lookup's prefix match length") + .Register(*registry_); + lookup_prefixmatch_length = &lookup_prefixmatch_length_family.Add({}, common_buckets); + + // 注册 matched_length_percentage Histogram + auto& matched_length_percentage_family = prometheus::BuildHistogram() + .Name(std::string(METRIC_PREFIX) + "_matched_length_percentage") + .Help("function matched length percentage") + .Register(*registry_); + matched_length_percentage = &matched_length_percentage_family.Add({}, common_buckets); + + // 注册 disk_usage Gauge + auto& disk_usage_family = + prometheus::BuildGauge().Name(std::string(METRIC_PREFIX) + "_disk_usage").Help("disk usage").Register(*registry_); + disk_usage = &disk_usage_family.Add({}); + + // 注册 memory_pool_size Gauge + memory_pool_size_family_ = &prometheus::BuildGauge() + .Name(std::string(METRIC_PREFIX) + "_memory_pool_size") + .Help("memory pool size") + .Register(*registry_); + + // 注册 memory_pool_node_count Gauge + memory_pool_node_count_family_ = &prometheus::BuildGauge() + .Name(std::string(METRIC_PREFIX) + "_memory_pool_node_count") + .Help("memory pool node count") + .Register(*registry_); + + // 注册 lru_entry_count Gauge + lru_entry_count_family_ = &prometheus::BuildGauge() + .Name(std::string(METRIC_PREFIX) + "_lru_entry_count") + .Help("lru entry count") + .Register(*registry_); + + // 注册 gpu_page_count Gauge + gpu_page_count_family_ = &prometheus::BuildGauge() + .Name(std::string(METRIC_PREFIX) + "_gpu_page_count") + .Help("gpu page count") + .Register(*registry_); + + // 注册 append_tokens_time_ms Histogram + auto& append_tokens_time_ms_family = prometheus::BuildHistogram() + .Name(std::string(METRIC_PREFIX) + "_append_tokens_time_ms") + .Help("append tokens time in milliseconds") + .Register(*registry_); + append_tokens_time_ms = &append_tokens_time_ms_family.Add({}, common_buckets); + + // 注册 gpu_flush_back_time_ms Histogram + auto& gpu_flush_back_time_ms_family = prometheus::BuildHistogram() + .Name(std::string(METRIC_PREFIX) + "_gpu_flush_back_time_ms") + .Help("gpu flush back time in milliseconds") + .Register(*registry_); + gpu_flush_back_time_ms = &gpu_flush_back_time_ms_family.Add({}, common_buckets); + + // 注册 cpu_flush_back_time_ms Histogram + auto& cpu_flush_back_time_ms_family = prometheus::BuildHistogram() + .Name(std::string(METRIC_PREFIX) + "_cpu_flush_back_time_ms") + .Help("cpu flush back time in milliseconds") + .Register(*registry_); + cpu_flush_back_time_ms = &cpu_flush_back_time_ms_family.Add({}, common_buckets); + + exposer_.RegisterCollectable(registry_); +} + +// 析构函数 +Metrics::~Metrics() { + // 停止指标暴露 + // exposer_.Stop(); +} + +// 获取 memory_pool_size 指标 +prometheus::Gauge* Metrics::memory_pool_size(const std::string& type) { + return &memory_pool_size_family_->Add({{"type", type}}); +} + +// 获取 memory_pool_node_count 指标 +prometheus::Gauge* Metrics::memory_pool_node_count(const std::string& type) { + return &memory_pool_node_count_family_->Add({{"type", type}}); +} + +// 获取 lru_entry_count 指标 +prometheus::Gauge* Metrics::lru_entry_count(const std::string& type) { + return &lru_entry_count_family_->Add({{"type", type}}); +} + +// 获取 gpu_page_count 指标 +prometheus::Gauge* Metrics::gpu_page_count(std::string type) { + return &gpu_page_count_family_->Add({{"type", type}}); +} + +TimeObserver::TimeObserver(prometheus::Histogram* h) { + histogram_ = h; + timer_.start(); +} + +TimeObserver::~TimeObserver() { + timer_.stop(); + histogram_->Observe(timer_.elapsedNs() / 1e6); // ns -> ms +} + +} // namespace kvc2 \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/src/metrics.h b/csrc/balance_serve/kvc2/src/metrics.h new file mode 100644 index 0000000..fa88785 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/metrics.h @@ -0,0 +1,77 @@ +#pragma once + +#include "prometheus/counter.h" +#include "prometheus/exposer.h" +#include "prometheus/gauge.h" +#include "prometheus/histogram.h" +#include "prometheus/registry.h" +#include +#include +#include +#include +#include +#include + +#include "utils/timer.hpp" + +namespace kvc2 { + +// 指标前缀宏定义 +#define METRIC_PREFIX "kvc2" + +struct MetricsConfig { + std::string endpoint; // 监听端点,如 "0.0.0.0:8080" +}; + +class Metrics { + public: + // 构造函数传入 MetricsConfig + Metrics(const MetricsConfig& config); + ~Metrics(); + + // 禁止拷贝和赋值 + Metrics(const Metrics&) = delete; + Metrics& operator=(const Metrics&) = delete; + + // 指标指针 + prometheus::Counter* prefix_nodes; + prometheus::Counter* prefix_block_count; + + prometheus::Histogram* raw_insert_time_ms; + prometheus::Histogram* lookup_time_ms; + prometheus::Histogram* lookup_prefixmatch_length; + prometheus::Histogram* matched_length_percentage; + + prometheus::Gauge* disk_usage; + + prometheus::Gauge* memory_pool_size(const std::string& type); + prometheus::Gauge* memory_pool_node_count(const std::string& type); + + prometheus::Gauge* lru_entry_count(const std::string& type); + prometheus::Gauge* gpu_page_count(std::string type); + + prometheus::Histogram* append_tokens_time_ms; + prometheus::Histogram* gpu_flush_back_time_ms; + prometheus::Histogram* cpu_flush_back_time_ms; + + private: + std::shared_ptr registry_; + prometheus::Exposer exposer_; + + prometheus::Family* memory_pool_size_family_; + prometheus::Family* memory_pool_node_count_family_; + prometheus::Family* lru_entry_count_family_; + prometheus::Family* gpu_page_count_family_; +}; + +class TimeObserver { + public: + TimeObserver(prometheus::Histogram* h); + ~TimeObserver(); + + private: + Timer timer_; + prometheus::Histogram* histogram_; +}; + +} // namespace kvc2 \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/src/model_config.h b/csrc/balance_serve/kvc2/src/model_config.h new file mode 100644 index 0000000..7ad1d90 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/model_config.h @@ -0,0 +1,103 @@ +#ifndef __MODEL_CONFIG_HPP_ +#define __MODEL_CONFIG_HPP_ + +#include +#include "nlohmann/json.hpp" + +#include +#include + +using DimSize = size_t; +using URL = std::string; +using ModelName = std::string; + +// We must assure this can be load by config.json +class ModelConfig { + public: + DimSize hidden_size; + DimSize intermediate_size; + size_t max_position_embeddings; + std::string model_type; + size_t num_attention_heads; + size_t num_hidden_layers; + size_t num_key_value_heads; + size_t vocab_size; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size, max_position_embeddings, model_type, + num_attention_heads, num_hidden_layers, num_key_value_heads, vocab_size); + + void load_from(std::filesystem::path path) { + std::ifstream i(path); + nlohmann::json j; + i >> j; + *this = j.get(); + } +}; + +using QuantType = std::string; +static const QuantType NoQuantType = ""; + +class QuantConfig { + public: + QuantType name; + + // For GEMV + QuantType type_of_dot_vector = NoQuantType; + inline bool can_be_used_as_matrix() { return type_of_dot_vector != NoQuantType; } + + bool can_be_used_as_vector; + + double bytes_per_element; + bool has_scale; + bool has_min; + + size_t block_element_count; + size_t block_element_size; + + URL reference = ""; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(QuantConfig, name, type_of_dot_vector, can_be_used_as_vector, + bytes_per_element, has_scale, has_min, block_element_count, + block_element_size, reference); +}; + +inline std::map quant_configs; +inline std::map model_configs; + +inline void load_quant_configs(std::filesystem::path path) { + std::cout << __FUNCTION__ << " from " << path << std::endl; + std::ifstream i(path); + nlohmann::json j; + i >> j; + quant_configs = j.get>(); + std::cout << "Loaded Quant Configs" << std::endl; + for (auto& [k, v] : quant_configs) { + std::cout << " - " << k << std::endl; + } +} + +inline void dump_quant_configs(std::filesystem::path path) { + std::ofstream o(path); + nlohmann::json j = quant_configs; + o << j.dump(4); +} + +inline void load_model_configs(std::filesystem::path path) { + std::cout << __FUNCTION__ << " from " << path << std::endl; + std::ifstream i(path); + nlohmann::json j; + i >> j; + model_configs = j.get>(); + std::cout << "Loaded Model Configs" << std::endl; + for (auto& [k, v] : model_configs) { + std::cout << " - " << k << std::endl; + } +} + +inline void dump_model_configs(std::filesystem::path path) { + std::ofstream o(path); + nlohmann::json j = model_configs; + o << j.dump(4); +} + +#endif \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/src/page_aligned_memory_pool.cpp b/csrc/balance_serve/kvc2/src/page_aligned_memory_pool.cpp new file mode 100644 index 0000000..d70ed2e --- /dev/null +++ b/csrc/balance_serve/kvc2/src/page_aligned_memory_pool.cpp @@ -0,0 +1,123 @@ +#include "page_aligned_memory_pool.h" + +#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG +#define FMT_HEADER_ONLY +#include "spdlog/spdlog.h" + +#include "utils/arithmetic.hpp" +#include "utils/easy_format.hpp" + +/// 构造函数 +PageAlignedMemoryPool::PageAlignedMemoryPool(size_t size_in_bytes) { + total_size = (size_in_bytes / PageSize) * PageSize; + // 对齐分配。C++17 对齐方式写法,如果编译器不支持可以改用其它方法 + data = ::operator new[](total_size, std::align_val_t(PageSize)); + total_pages = total_size / PageSize; + + assert(total_pages >= Blocks); + page_per_block = total_pages / Blocks; + + for (size_t block_index = 0; block_index < Blocks; block_index ++) { + first_page[block_index] = reinterpret_cast(reinterpret_cast(data) + static_cast(block_index) * page_per_block * PageSize); + count_page[block_index] = + block_index == Blocks - 1 ? (total_pages - page_per_block * (Blocks - 1)) : page_per_block; + SPDLOG_DEBUG("first_page[{}] = {}, count_page[{}] = {}", + block_index, reinterpret_cast(first_page[block_index]) - reinterpret_cast(data), + block_index, count_page[block_index]); + bitmap[block_index].resize(count_page[block_index], 0); + } + SPDLOG_INFO("PageAlignedMemoryPool with size {} Mbytes, {} pages", total_size / (1 << 20), page_count()); +} + +/// 析构函数 +PageAlignedMemoryPool::~PageAlignedMemoryPool() { + if (data) { + // 注意:需要与分配时的对齐方式对应 + ::operator delete[](data, std::align_val_t(PageSize)); + data = nullptr; + } +} + +/// 返回总页数 +size_t PageAlignedMemoryPool::page_count() { + return total_size / PageSize; +} + +/// 返回按整页对齐后的字节数 +size_t PageAlignedMemoryPool::page_padded_size(size_t size) { + return div_up(size, PageSize) * PageSize; +} + +void* PageAlignedMemoryPool::alloc_in_block(size_t block_index, size_t alloc_size) { + std::lock_guard guard(lock[block_index]); + size_t free_pages = 0; + for (size_t i = 0; i < count_page[block_index]; i++) { + if (bitmap[block_index][i] == 0) { + free_pages ++; + if (free_pages == alloc_size) { + size_t page_index = i + 1 - free_pages; + for (size_t page = page_index; page < page_index + alloc_size; page++) { + bitmap[block_index][page] = 1; + // SPDLOG_DEBUG("alloc page {} in block {}", page, block_index); + } + return reinterpret_cast(reinterpret_cast(first_page[block_index]) + page_index * PageSize); + } + } else { + free_pages = 0; + } + } + return nullptr; +} + +/// 分配函数 +void* PageAlignedMemoryPool::alloc(size_t size) { + size_t alloc_size = div_up(size, PageSize); + auto cnt = now_block.fetch_add(1, std::memory_order_relaxed); + for (size_t i = 0; i < Blocks; i ++) { + auto result = alloc_in_block((i + cnt) % Blocks, alloc_size); + if (result != nullptr) { + allocated.fetch_add(alloc_size * PageSize, std::memory_order_relaxed); + alloc_count.fetch_add(1, std::memory_order_relaxed); + return result; + } + } + return nullptr; +} + +/// 释放函数 +void PageAlignedMemoryPool::free(void* p, size_t size) { + auto alloc_size = div_up(size, PageSize); + size_t block_index = (reinterpret_cast(p) - reinterpret_cast(data)) / page_per_block / PageSize; + size_t page_index = (reinterpret_cast(p) - reinterpret_cast(first_page[block_index])) / PageSize; + + std::lock_guard guard(lock[block_index]); + + for (size_t page = page_index; page < page_index + alloc_size; page++) + bitmap[block_index][page] = 0; + + allocated.fetch_sub(alloc_size * PageSize, std::memory_order_relaxed); + free_count.fetch_add(1, std::memory_order_relaxed); +} +// TODO: too slow +std::vector PageAlignedMemoryPool::alloc_multiple(size_t size, size_t count) { + std::vector result; + for (size_t i = 0; i < count; i++) { + auto p = alloc(size); + if (p == nullptr) { + for (auto ptr : result) { + free(ptr, size); + } + return {}; + } + result.push_back(p); + } + return result; +} + +void PageAlignedMemoryPool::defragment() {} + +/// 调试打印 +std::string PageAlignedMemoryPool::debug() { + return fmt::format("PageAlignedMemoryPool: total_size: {}MB, allocated: {}, alloc/free count: {}/{}\n", + readable_number(total_size), readable_number(size_t(allocated)), size_t(alloc_count), size_t(free_count)); +} diff --git a/csrc/balance_serve/kvc2/src/page_aligned_memory_pool.h b/csrc/balance_serve/kvc2/src/page_aligned_memory_pool.h new file mode 100644 index 0000000..c65a740 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/page_aligned_memory_pool.h @@ -0,0 +1,53 @@ +#pragma once + +#include // std::sort +#include // size_t +#include // std::mutex +#include +#include +#include +#include + +constexpr size_t PageSize = 4096; + +/// PageAlignedMemoryPool 类的声明 +struct PageAlignedMemoryPool { + private: + constexpr static size_t Blocks = 16; + + void* data = nullptr; + + size_t total_size = 0, total_pages = 0; + + std::atomic_size_t now_block = 0; + std::atomic_size_t allocated = 0; // allocated_size + std::atomic_size_t alloc_count = 0; + std::atomic_size_t free_count = 0; + + std::mutex lock[Blocks]; + size_t page_per_block = 0; + void *first_page[Blocks]; + size_t count_page[Blocks]; + std::vector bitmap[Blocks]; + void* alloc_in_block(size_t block_index, size_t alloc_size); + public: + /// 构造函数和析构函数 + explicit PageAlignedMemoryPool(size_t size_in_bytes); + ~PageAlignedMemoryPool(); + + /// 禁用拷贝和移动 + PageAlignedMemoryPool(PageAlignedMemoryPool&& other) = delete; + PageAlignedMemoryPool& operator=(PageAlignedMemoryPool&& other) = delete; + PageAlignedMemoryPool(const PageAlignedMemoryPool& other) = delete; + PageAlignedMemoryPool& operator=(const PageAlignedMemoryPool& other) = delete; + + /// 成员函数 + size_t page_count(); + size_t page_padded_size(size_t size); + + void* alloc(size_t size); + std::vector alloc_multiple(size_t size, size_t count); + void free(void* data, size_t size); + void defragment(); + std::string debug(); +}; diff --git a/csrc/balance_serve/kvc2/src/prefix.cpp b/csrc/balance_serve/kvc2/src/prefix.cpp new file mode 100644 index 0000000..add1cd4 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/prefix.cpp @@ -0,0 +1,1746 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG +#define FMT_HEADER_ONLY +#include "spdlog/spdlog.h" + +#include "async_store.hh" +#include "cuda_stream_manager.hh" +#include "kvc2.h" +#include "metrics.h" + +#include "cache_entry.hh" +#include "gpu_cache.hh" +#include "hasher.hpp" +#include "io_helper.hpp" +#include "page_aligned_memory_pool.h" + +#include "utils/arithmetic.hpp" +#include "utils/easy_format.hpp" +#include "utils/periodic_task.hpp" +namespace kvc2 { +struct KVC2; + +// will be set when init +TokenLength NumTokenPerBlock; +int EvictCount; + +using Layer = size_t; + +NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(CacheInfo, model_name, is_key_cache, quant_type); +NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(KVC2Config, gpu_only, load_from_disk, save_to_disk, path, config_path, + num_token_per_page, memory_pool_size, evict_count, metrics_port, recompute_ratio); + +size_t CacheInfo::hidden_layer_count() { + return model_configs.at(model_name).num_hidden_layers; +} + +std::filesystem::path CacheInfo::path(std::optional which_layer) { + auto folder = std::filesystem::path(model_name) / quant_type / (is_key_cache ? "key" : "value"); + if (which_layer.has_value()) { + folder /= fmt::format("layer-{}.kvc", which_layer.value()); + } + return folder; +} + +bool CacheInfo::operator==(const CacheInfo& other) const { + return model_name == other.model_name && is_key_cache == other.is_key_cache && quant_type == other.quant_type; +} + +size_t CacheInfo::element_size(size_t block_length) { + size_t count = model_configs[model_name].hidden_size * block_length; + auto& q = quant_configs[quant_type]; + return count / q.block_element_count * q.block_element_size; +} + +size_t CacheInfo::hash_value() const { + size_t x = hash_seed; + x = XXH64(model_name.data(), model_name.size(), x); + x = XXH64("quant_type", 10, x); + x = XXH64(quant_type.data(), quant_type.size(), x); + if (is_key_cache) { + x = XXH64("key", 3, x); + } else { + x = XXH64("value", 5, x); + } + return x; +} + +} // namespace kvc2 + +template <> +struct std::hash { + std::size_t operator()(const kvc2::CacheInfo& s) const noexcept { return s.hash_value(); } +}; +namespace kvc2 { +struct Location { + size_t start_idx; // start block index + size_t length; // length of blocks + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Location, start_idx, length); + + Location cut_tail(size_t offset_from_tail) { + Location re; + size_t offset = length - offset_from_tail; + re.start_idx = start_idx + offset; + re.length = offset_from_tail; + length = offset; + return re; + } +}; + +struct SegmentLocations { + std::vector> offsets; + + void add_location(size_t start_block, Location location) { + if (location.length + start_block > offsets.size()) { + offsets.resize(location.length + start_block, std::nullopt); + } + + for (size_t i = start_block; i < start_block + location.length; i++) { + offsets[i] = location.start_idx + i - start_block; + } + } + + void set_location(size_t start_block, size_t disk_location) { + if (start_block >= offsets.size()) { + offsets.resize(start_block + 1, std::nullopt); + } + offsets[start_block] = disk_location; + } + + std::optional get_idx(size_t block_idx) const { + if (block_idx >= offsets.size()) { + return std::nullopt; + } else { + return offsets[block_idx]; + } + } + + bool has_location(size_t block_idx, size_t length) { + for (size_t i = block_idx; i < block_idx + length; i++) { + if (get_idx(i).has_value() == false) { + return false; + } + } + return true; + } + + void debug() { + for (size_t i = 0; i < offsets.size(); ++i) { + if (offsets[i].has_value()) { + SPDLOG_DEBUG("Block {} -> Disk Location {}", i, offsets[i].value()); + } else { + SPDLOG_DEBUG("Block {} -> No Disk Location", i); + } + } + } +}; + +struct CacheDiskLocations { + std::unordered_map location_map; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(CacheDiskLocations, location_map); + + std::optional get_location(CacheInfo cache_info, TokenLength local_ids_length) { + size_t blocks_length = div_up(local_ids_length, NumTokenPerBlock); + if (location_map.count(cache_info) == 0) { + return std::nullopt; + } + Location re = location_map[cache_info]; + re.length = blocks_length; + return re; + } + + std::optional get_location_of_a_block(CacheInfo info, size_t local_at) { + if (location_map.count(info) == 0) { + return std::nullopt; + } + auto loc = location_map[info]; + if (local_at >= loc.length) { + return std::nullopt; + } + return loc.start_idx + local_at; + } +}; + +struct DiskCacheAllocator { + private: + // metadata + std::filesystem::path path; + CacheInfo info; + std::mutex lock; + size_t now_idx; + + // store + size_t capacity; + std::vector stores; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(DiskCacheAllocator, now_idx); + + void update_capacity() { + capacity = std::numeric_limits::max(); + for (auto& store : stores) { + capacity = std::min(capacity, async_store::capacity(store)); + } + } + + void extend(size_t to) { + for (size_t i = 0; i < info.hidden_layer_count(); i++) { + async_store::extend(stores[i], to); + } + update_capacity(); + } + + public: + async_store::ArrayStore* get_store(int i) { return stores[i]; } + Location alloc(size_t block_count) { + std::lock_guard lg(lock); + Location re; + re.start_idx = now_idx; + re.length = block_count; + now_idx += block_count; + if (now_idx >= capacity) { + extend(capacity * 2); + } + return re; + } + + DiskCacheAllocator(std::filesystem::path path, CacheInfo info) : path(path), info(info) { + // SPDLOG_DEBUG("Create DiskCacheAllocator {}", path.c_str()); + auto allocator_path = path / info.path(); + if (std::filesystem::exists(allocator_path) == false) { + std::filesystem::create_directories(allocator_path); + } + // restore metadata later in json load + now_idx = 0; + + for (size_t i = 0; i < info.hidden_layer_count(); i++) { + // SPDLOG_DEBUG("Create store {} for {}", (path / info.path(i)).c_str(),i); + auto store = async_store::create_or_open_store(info.element_size(NumTokenPerBlock), 1000, path / info.path(i)); + stores.push_back(store); + } + update_capacity(); + } + + ~DiskCacheAllocator() { + for (auto store : stores) { + async_store::close_store(store); + } + } +}; + +struct DiskCacheManager { + KVC2Config config; + std::mutex lock; + std::unordered_map> allocators; + + friend void to_json(nlohmann ::json& nlohmann_json_j, const DiskCacheManager& nlohmann_json_t) { + nlohmann_json_j["config"] = nlohmann_json_t.config; + nlohmann_json_j["allocators"] = nlohmann::json::array(); + for (auto& [info, allocator] : nlohmann_json_t.allocators) { + nlohmann_json_j["allocators"].push_back({{"info", info}, {"allocator", *allocator}}); + } + } + friend void from_json(const nlohmann ::json& nlohmann_json_j, DiskCacheManager& nlohmann_json_t) { + // SPDLOG_DEBUG("Load DiskCacheManager Json"); + nlohmann_json_j.at("config").get_to(nlohmann_json_t.config); + for (const auto& allocator_json : nlohmann_json_j.at("allocators")) { + // SPDLOG_DEBUG("Make Allocator {}",allocator_json.dump()); + CacheInfo info; + allocator_json.at("info").get_to(info); + auto allocator = std::make_shared(nlohmann_json_t.config.path, info); + allocator_json.at("allocator").get_to(*allocator); + nlohmann_json_t.allocators[info] = allocator; + } + }; + + DiskCacheManager(KVC2Config config) : config(config) { + SPDLOG_INFO("DiskCacheManager root path: {}", config.path.c_str()); + if (!std::filesystem::exists(config.path)) { + std::filesystem::create_directories(config.path); + } + } + + std::shared_ptr get_allocator(CacheInfo info) { + { + std::lock_guard lg(lock); + if (allocators.count(info) == 0) { + allocators.emplace(info, std::make_shared(config.path, info)); + } + } + return allocators.at(info); + } + + Location allocate(CacheInfo info, size_t cache_block_count) { + auto allocator = get_allocator(info); + return allocator->alloc(cache_block_count); + } +}; + +struct Prefix { + uint64_t prefix_id; // 0 for nullptr, started from 1 + TokenLength start_length; + Tokens ids; + CacheDiskLocations locations; + Prefix* prev = nullptr; + + // No serialization + bool prev_set = false; + + friend void to_json(nlohmann ::json& nlohmann_json_j, const Prefix& nlohmann_json_t) { + nlohmann_json_j["prefix_id"] = nlohmann_json_t.prefix_id; + nlohmann_json_j["start_length"] = nlohmann_json_t.start_length; + nlohmann_json_j["ids"] = nlohmann_json_t.ids; + if (nlohmann_json_t.prev) { + nlohmann_json_j["prev"] = nlohmann_json_t.prev->prefix_id; + } else { + nlohmann_json_j["prev"] = 0; + } + nlohmann_json_j["locations"] = nlohmann_json_t.locations; + } + friend void from_json(const nlohmann ::json& nlohmann_json_j, Prefix& nlohmann_json_t) { + nlohmann_json_j.at("prefix_id").get_to(nlohmann_json_t.prefix_id); + nlohmann_json_j.at("start_length").get_to(nlohmann_json_t.start_length); + nlohmann_json_j.at("ids").get_to(nlohmann_json_t.ids); + nlohmann_json_j.at("locations").get_to(nlohmann_json_t.locations); + + auto prev_id = nlohmann_json_j.at("prev").get(); + nlohmann_json_t.prev = reinterpret_cast(prev_id); + nlohmann_json_t.prev_set = false; + }; + + TokenLength local_length() { return ids.size(); } + TokenLength length() { return start_length + local_length(); } + Tokens prefix_to(TokenLength length) { + TokenLength local_length = length - start_length; + Tokens re; + if (prev) { + re = prev->prefix_to(start_length); + } + re.insert(re.end(), ids.begin(), ids.begin() + local_length); + return re; + } + Tokens full() { return prefix_to(length()); } + + void update_location(CacheInfo info, Location location) { locations.location_map[info] = location; } + + Prefix* to_first_prefix_without_disk_locations(CacheInfo k_info/*, CacheInfo v_info*/) { // just k_info + auto now_prefix = this; + while (now_prefix->prev != nullptr) { + auto& prev = now_prefix->prev; + auto k_location = prev->locations.get_location(k_info, prev->local_length()); + // auto v_location = prev->locations.get_location(v_info, prev->local_length()); + if (k_location.has_value()) { + // assert(v_location.has_value()); + // after now_prefix, we need to insert new kv cache. + break; + } + now_prefix = prev; + } + return now_prefix; + } + + void hash_to_with(TokenLength length, TokensHasher& hasher) { + TokenLength local_length = length - start_length; + if (prev) { + prev->hash_to_with(start_length, hasher); + } + hasher.update(ids.data(), local_length); + } + + void debug() { + fmt::print("Prefix {}, start_length: {}, local_length: {}, prev: {}, \n", prefix_id, start_length, local_length(), + (void*)prev); + } +}; +struct PrefixMatch { + Prefix* prefix; + TokenLength match_length; + + std::vector matched_hashes(CacheInfo info, Layer layer) { + std::vector re; + if (prefix == nullptr) + return re; + TokensHasher hasher; + hasher.reset(info.hash_value()); + hasher.update_raw(&layer, sizeof(layer)); + auto ids = prefix->prefix_to(match_length); + for (TokenLength i = 0; i < ids.size(); i += NumTokenPerBlock) { + TokenLength len = std::min(NumTokenPerBlock, ids.size() - i); + re.push_back(hasher.update(ids.data() + i, len)); + } + return re; + } + + void collect_locations(CacheInfo info, SegmentLocations& seg_locs) { + auto now_prefix = prefix; + size_t length = match_length; + while (now_prefix != nullptr) { + TokenLength local_length = length - now_prefix->start_length; + auto loc = now_prefix->locations.get_location(info, local_length); + if (loc.has_value()) { + seg_locs.add_location(now_prefix->start_length / NumTokenPerBlock, loc.value()); + } + length = now_prefix->start_length; + now_prefix = now_prefix->prev; + } + } +}; + +std::string to_string(const MatchStatus& status) { + switch (status) { + case Exact: + return "Exact"; + case Partial: + return "Partial"; + case NotMatchExact: + return "NotMatchExact"; + case NotMatchPartial: + return "NotMatchPartial"; + default: + return "Unknown"; + } +} + +struct MatchByBlock { + // prefix, block idx at prefix, status + std::vector> matches; + + bool any_match() { + for (auto& [p, l, m] : matches) { + if (p) { + return true; + } + } + return false; + } + + size_t partial_count() { + size_t re = 0; + for (auto& [p, l, m] : matches) { + if (m == Partial) { + re++; + } + } + return re; + } + + bool has_partial() { return partial_count() > 0; } + + std::vector> matched_hashes(CacheInfo info, Layer layer) { + // TODO: This function might be slow + std::vector> re(matches.size(), std::nullopt); + + for (size_t i = 0; i < matches.size(); i++) { + TokensHasher hasher; + hasher.reset(info.hash_value()); + hasher.update_raw(&layer, sizeof(layer)); + auto& [p, idx, status] = matches[i]; + if (p) { + p->hash_to_with((idx + 1) * NumTokenPerBlock, hasher); + re[i] = hasher.get(); + } + } + return re; + } + + void collect_locations(CacheInfo info, SegmentLocations& seg_locs) { + for (size_t i = 0; i < matches.size(); i++) { + auto& [p, idx, status] = matches[i]; + if (p) { + auto local_at = idx - p->start_length / NumTokenPerBlock; + seg_locs.set_location(i, p->locations.get_location_of_a_block(info, local_at).value()); + } + } + } + + std::string debug_string() { + std::string re = fmt::format("{} Match: ", matches.size()); + for (auto& [p, idx, status] : matches) { + switch (status) { + case Exact: + re += "E"; + break; + case Partial: + re += "P"; + break; + case NotMatchExact: + re += "N"; + break; + case NotMatchPartial: + re += "n"; + break; + default: + assert(0); + } + } + return re; + } +}; + +struct PrefixTree { + std::shared_mutex rw_lock; + + std::atomic_uint64_t prefix_id_counter = 1; + using MapT = + std::unordered_map, BlockLength>>; // Prefix, start_block_idx + MapT prefix_map; + + std::shared_ptr met; + + std::vector> prefix_refs = {nullptr}; // 0 is nullptr + + friend void to_json(nlohmann ::json& nlohmann_json_j, const PrefixTree& nlohmann_json_t) { + nlohmann_json_j["prefix_id_counter"] = nlohmann_json_t.prefix_id_counter.load(); + nlohmann_json_j["prefix_refs"] = nlohmann::json::array(); + for (auto prefix : nlohmann_json_t.prefix_refs) { + if (prefix == nullptr) + continue; + nlohmann_json_j["prefix_refs"].push_back(*prefix); + } + } + friend void from_json(const nlohmann ::json& nlohmann_json_j, PrefixTree& nlohmann_json_t) { + nlohmann_json_t.prefix_id_counter = nlohmann_json_j.at("prefix_id_counter").get(); + + nlohmann_json_t.prefix_refs.resize(nlohmann_json_t.prefix_id_counter); + for (size_t i = 1; i < nlohmann_json_t.prefix_id_counter; ++i) { + auto prefix = std::make_shared(); + nlohmann_json_j.at("prefix_refs")[i - 1].get_to(*prefix); + nlohmann_json_t.prefix_refs[i] = prefix; + } + nlohmann_json_t.init_prevs(); + nlohmann_json_t.init_map(); + }; + + void init_prevs() { + for (auto p : prefix_refs) { + if (p) { + if (p->prev_set == false) { + p->prev = prefix_refs[reinterpret_cast(p->prev)].get(); + p->prev_set = true; + } + } + } + } + + void init_map() { + assert(prefix_map.empty()); + for (auto p : prefix_refs) { + if (p == nullptr) + continue; + + auto ids = p->full(); + for (TokenLength i = p->start_length; i < p->length(); i += NumTokenPerBlock) { + TokenLength end = std::min(i + NumTokenPerBlock, p->length()); + assert(end % NumTokenPerBlock == 0); + auto hash = TokensHasher::hash(ids.data(), end); + prefix_map[hash] = {p, end / NumTokenPerBlock - 1}; + } + } + } + + // Look up prefix from the map, return the matched prefix and length. + // If the prefix is not found, match contains nullptr and 0. + PrefixMatch look_up(Token* data, TokenLength length, bool need_lock = true) { + std::shared_lock sl; + if (need_lock) { + sl = std::shared_lock(rw_lock); + } + //TODO: prefix cache + } + + PrefixMatch look_up_or_insert(Token* data, TokenLength length) { + std::unique_lock ul(rw_lock); + + auto match = look_up(data, length, false); + if (match.match_length == length) { + return match; + } + auto new_prefix = new_prefix_node(match.prefix, match.match_length, data, length, false); + + PrefixMatch re; + re.prefix = new_prefix.get(); + re.match_length = length; + return re; + } + + + std::shared_ptr new_prefix_node(Prefix* prev, TokenLength prev_match_length, Token* data, TokenLength length, + bool need_lock = true) { + std::unique_lock ul; + if (need_lock) + ul = std::unique_lock(rw_lock); + auto new_prefix = std::make_shared(); + new_prefix->prefix_id = prefix_id_counter.fetch_add(1); + new_prefix->start_length = prev_match_length; + new_prefix->ids = Tokens(data + prev_match_length, data + length); + new_prefix->prev = prev; + new_prefix->prev_set = true; + prefix_refs.push_back(new_prefix); + met->prefix_nodes->Increment(); + met->prefix_block_count->Increment(div_up(length - prev_match_length, NumTokenPerBlock)); + + assert(prefix_refs.size() == prefix_id_counter.load()); + + TokensHasher hasher; + hasher.update(data, prev_match_length); + + for (TokenLength i = prev_match_length; i < length; i += NumTokenPerBlock) { + TokenLength len = std::min(NumTokenPerBlock, length - i); + auto hash = hasher.update(data + i, len); + prefix_map[hash] = {new_prefix, i / NumTokenPerBlock}; + } + + return new_prefix; + } + + void debug() { + fmt::print("PrefixTree with {} prefixes, prefix counter: {}\n", prefix_map.size(), prefix_id_counter.load()); + for (auto& [hash, prefix] : prefix_map) { + fmt::print("Hash: {:016x}, start block {}\n", hash, prefix.second); + prefix.first->debug(); + } + } +}; + +size_t locations_blocks_count(const std::vector& locations) { + auto re = 0; + for (auto& loc : locations) { + re += loc.length; + } + return re; +} + +struct DoubleCacheHandle : public DoubleCacheHandleInterface { + ModelName model_name; + QuantType quant_type; + bool is_k_cache_on; + bool is_v_cache_on; + CacheInfo k_info() { + if (is_k_cache_on == false) { + SPDLOG_WARN("Get K CacheInfo, but K Cache is off"); + } + return CacheInfo{ + .model_name = model_name, + .is_key_cache = true, + .quant_type = quant_type, + }; + }; + + CacheInfo v_info() { + if (is_v_cache_on == false) { + SPDLOG_WARN("Get V CacheInfo, but K Cache is off"); + } + return CacheInfo{ + .model_name = model_name, + .is_key_cache = false, + .quant_type = quant_type, + }; + }; + + Tokens ids; + TokenLength estimated_length; + + bool enable_alt = false; + PrefixMatch match; + // MatchByBlock match_by_blocks; + + std::vector>> k_cache_handles; + std::vector>> v_cache_handles; + + SegmentLocations k_seg_locs; + SegmentLocations v_seg_locs; + + KVC2* kvc2_top; + + // for Cache Fusion + std::vector>> attatched_cache_handles; + + std::unique_ptr cpu_releaser = nullptr, gpu_releaser = nullptr; + + std::vector gpu_only_block_idx; + + virtual ~DoubleCacheHandle(); + // interface + TokenLength matched_length() override { + if (enable_alt) { + assert(0); + } else { + return match.match_length; + } + } + MatchStatus status_at(BlockLength i) { + assert(i < div_up(estimated_length, NumTokenPerBlock)); + if (enable_alt) { + assert(false); + // if (i >= match_by_blocks.matches.size()) { + // return match_by_blocks.has_partial() ? MatchStatus::NotMatchPartial : MatchStatus::NotMatchExact; + // } + // return std::get<2>(match_by_blocks.matches[i]); + } else { + if (i < match.match_length / NumTokenPerBlock) { + return MatchStatus::Exact; + } else { + return MatchStatus::NotMatchExact; + } + } + } + std::vector matched_status() override { + assert(false); + } + + bool any_match() { + if (enable_alt) { + assert(false); + // return match_by_blocks.any_match(); + } else { + return match.prefix != nullptr; + } + } + + BlockLength match_range_length() { + if (enable_alt) { + assert(false); + // return match_by_blocks.matches.size(); + } else { + return div_up(match.match_length, NumTokenPerBlock); + } + } + + std::vector handle_data(bool is_key_cache) override { return export_raw_pointers(is_key_cache); } + bool to_gpu() override; + void to_gpu_async(std::function call_back) override; + + std::vector get_gpu_block_idx() override; + + bool alloc_attached_blocks(BlockLength count); + std::vector get_gpu_attached_block_idx() override; + + void append_tokens(Token* tokens, TokenLength length) override; + + void debug() override {} + + void set_cache_info(ModelName model_name, QuantType quant_type, bool turn_on_k_cache, bool turn_on_v_cache) { + this->model_name = model_name; + this->quant_type = quant_type; + if (turn_on_k_cache) { + is_k_cache_on = true; + k_cache_handles.resize(k_info().hidden_layer_count()); + } else { + is_k_cache_on = false; + k_cache_handles.clear(); + } + if (turn_on_v_cache) { + is_v_cache_on = true; + v_cache_handles.resize(v_info().hidden_layer_count()); + } else { + is_v_cache_on = false; + v_cache_handles.clear(); + } + } + + void check_before_insert() { + std::optional blocks_count = std::nullopt; + + auto check_single_cache = [&blocks_count](CacheInfo cache_info, + std::vector>>& layers, + Tokens& ids) { + for (size_t i = 0; i < cache_info.hidden_layer_count(); i++) { + auto& layer = layers[i]; + if (blocks_count.has_value() == false) { + blocks_count = layer.size(); + } else { + if (blocks_count.value() != layer.size()) { + SPDLOG_ERROR("Layer {} has different block count", i); + throw std::runtime_error("Layer has different block count"); + } + } + } + if (blocks_count.has_value()) { + if (blocks_count.value() != div_up(ids.size(), NumTokenPerBlock)) { + SPDLOG_ERROR("Block count not match, ids: {}, blocks: {}", ids.size(), blocks_count.value()); + throw std::runtime_error("Block count not match"); + } + } + }; + + if (is_k_cache_on) + check_single_cache(k_info(), k_cache_handles, ids); + if (is_v_cache_on) + check_single_cache(v_info(), v_cache_handles, ids); + } + + template + void for_all_cache_block_entry(Fn f) { + if (is_k_cache_on) { + for (auto& layer : k_cache_handles) { + for (auto& block : layer) { + if (f(block) == false) + return; + } + } + } + if (is_v_cache_on) { + for (auto& layer : v_cache_handles) { + for (auto& block : layer) { + if (f(block) == false) + return; + } + } + } + } + + // concurrent check ok + bool alloc_on_cpu() { + assert(cpu_releaser == nullptr); + std::unique_ptr releaser = + std::make_unique([](CacheBlockEntry* entry) { + auto lg = entry->lock_guard(); + entry->cpu_cc.ref_count.fetch_sub(1); + }); + bool ok = true; + + for_all_cache_block_entry([&ok, &releaser](std::shared_ptr& block_entry) { + if (block_entry->inc_ref_or_alloc_on_cpu() == false) { + ok = false; + return false; + } else { + releaser->entries.push_back(block_entry.get()); + } + return true; + }); + + if (ok) { + cpu_releaser = std::move(releaser); + } + return ok; + } + + bool alloc_on_gpu_cols() { + assert(is_k_cache_on); + assert(gpu_releaser == nullptr); + std::unique_ptr releaser = + std::make_unique([](CacheBlockEntry* entry) { + auto lg = entry->lock_guard(); + entry->gpu_cc.ref_count.fetch_sub(1); + }); + + GPUPageCache* gpu_cache = k_cache_handles[0][0]->manager->gpu_cache.get(); + gpu_cache->background_flush_back->wakeUpWait(); + + bool ok = true; + size_t want_count = 0; + for (size_t i = 0; i < k_cache_handles[0].size(); i++) { + auto lg = k_cache_handles[0][i]->lock_guard(); + if (k_cache_handles[0][i]->gpu_block_idx.has_value() == false) { + want_count += 1; + if (gpu_cache->alloc_col(k_cache_handles, v_cache_handles, i) == false) { + ok = false; + break; + } + } + k_cache_handles[0][i]->gpu_cc.ref_count.fetch_add(1); + releaser->entries.push_back(k_cache_handles[0][i].get()); + } + if (ok == false) { + SPDLOG_WARN("Handle cannot allocate {} gpu pages", want_count); + } else { + gpu_releaser = std::move(releaser); + } + return ok; + } + + static void segment_io_layer(async_store::IODealer* dealer, IO_Helper& io_helper, + async_store::ArrayStore* store, + std::vector>& layer_entries, size_t block_start, + size_t length, Layer layer, const SegmentLocations& locations, IOOption option) { + SPDLOG_TRACE("{} [{}:{}) blocks to/from disk", to_string(option), block_start, block_start + length); + for (size_t i = block_start; i < block_start + length; i++) { + if (locations.get_idx(i).has_value()) { + SPDLOG_TRACE("Location for block {}, {}", i, locations.get_idx(i).value()); + layer_entries[i]->io_with(dealer, io_helper, store, layer, locations.get_idx(i).value(), option); + } + } + } + + std::shared_ptr> segment_io(async_store::IODealer* dealer, DiskCacheManager* manager, + BlockLength block_start, BlockLength length, IOOption option) { + auto io_helper = std::make_shared>([option](CacheBlockEntry* b) { + switch (option) { + case IO_ForceRead: + break; + case IO_ForceWrite: + break; + case IO_Read: { + b->cpu_cc.tc.set_has_data(); + break; + } + case IO_Write: + break; + default: + assert(0); + } + }); + + auto single_segment_io = [dealer, manager, block_start, length, option, io_helper]( + CacheInfo info, SegmentLocations& seg_locs, + std::vector>>& layers) { + assert(layers[0].size() >= block_start + length); + + auto allocator = manager->get_allocator(info); + + for (size_t l = 0; l < info.hidden_layer_count(); l++) { + segment_io_layer(dealer, *io_helper, allocator->get_store(l), layers[l], block_start, length, l, seg_locs, + option); + } + }; + + if (is_k_cache_on) + single_segment_io(k_info(), k_seg_locs, k_cache_handles); + if (is_v_cache_on) + single_segment_io(v_info(), v_seg_locs, v_cache_handles); + + io_helper->finish_add_taks(); + SPDLOG_DEBUG("Segment IO Submitted, total task count {}", io_helper->total_task_count); + return io_helper; + } + + std::shared_ptr> gpu_io(GPUPageCache* gpu_cache, BlockLength block_start, + BlockLength length, IOOption option) { + auto io_helper = std::make_shared>([option](CacheBlockEntry* b) { + switch (option) { + case IO_ForceRead: + break; + case IO_ForceWrite: + break; + case IO_Read: { + b->gpu_cc.tc.set_has_data(); + break; + } + case IO_Write: + break; + default: + assert(0); + } + }); + + cudaMemcpyKind direction; + if (option == IO_Read || option == IO_ForceRead) { + direction = cudaMemcpyHostToDevice; + } + if (option == IO_Write || option == IO_ForceWrite) { + direction = cudaMemcpyDeviceToHost; + } + + auto reqs = gpu_cache->basic_request(direction, [io_helper]() { io_helper->batch_promise.set(); }); + + for (size_t i = block_start; i < length; i++) { + auto status = status_at(i); + if (status == NotMatchExact || status == NotMatchPartial) { + SPDLOG_DEBUG("GPU: Col Handle not match (Skipped by Alt Match)"); + continue; + } + auto ptr = k_cache_handles[0][i].get(); + + switch (option) { + case IO_Read: { + if (io_helper->absorb_tc(ptr, ptr->gpu_cc.tc) == false) { + // SPDLOG_DEBUG("GPU: Col Handle need me to wait"); + continue; + } + break; + } + case IO_ForceRead: { + break; + } + case IO_ForceWrite: { + break; + } + case IO_Write: { + break; + } + default: { + assert(0); + } + } + SPDLOG_DEBUG("GPU: Col Handle needs me to transfer"); + gpu_cache->append_col_to_request(reqs, k_cache_handles, v_cache_handles, i); + } + io_helper->new_task(reqs.size()); + gpu_cache->submit_requests(reqs); + io_helper->finish_add_taks(); + return io_helper; + } + + // void set_raw_handles(const std::vector& k, const std::vector& v) { + // set_raw_handles(true, k); + // set_raw_handles(false, v); + // } + void set_raw_handles(bool is_key_cache, const std::vector& layer_data) { + auto single_set_raw_handles = [layer_data](CacheInfo info, + std::vector>>& handles) { + handles.resize(layer_data.size()); + for (size_t i = 0; i < info.hidden_layer_count(); i++) { + auto& layer = layer_data[i]; + handles[i].clear(); + for (auto& block_data : layer) { + auto handle = std::make_shared(); + handle->data = reinterpret_cast(block_data); + handle->size = info.element_size(NumTokenPerBlock); + handles[i].push_back(handle); + } + } + }; + + if (is_key_cache) { + is_k_cache_on = true; + single_set_raw_handles(k_info(), k_cache_handles); + } else { + is_v_cache_on = true; + single_set_raw_handles(v_info(), v_cache_handles); + } + } + + std::vector export_raw_pointers(bool is_key_cache) { + std::vector re; + + auto single_export_raw_pointers = [&re](std::vector>>& layers) { + for (auto& layer_handle : layers) { + layer_data layer; + for (size_t i = 0; i < layer_handle.size(); i++) { + auto block = layer_handle.at(i); + layer.push_back(reinterpret_cast(block->data)); + } + re.push_back(layer); + } + }; + + if (is_key_cache) { + if (is_k_cache_on == false) { + SPDLOG_WARN("Export K Cache, but K Cache is off"); + } + single_export_raw_pointers(k_cache_handles); + } else { + if (is_v_cache_on == false) { + SPDLOG_WARN("Export V Cache, but V Cache is off"); + } + single_export_raw_pointers(v_cache_handles); + } + + return re; + } + + void get_handles(); + void get_empty_handles(); + + void collect_locations() { + if (enable_alt) { + assert(false); + // match_by_blocks.collect_locations(k_info(), k_seg_locs); + // match_by_blocks.collect_locations(v_info(), v_seg_locs); + } else { + if (is_k_cache_on) + match.collect_locations(k_info(), k_seg_locs); + if (is_v_cache_on) + match.collect_locations(v_info(), v_seg_locs); + } + if (is_k_cache_on) + k_seg_locs.debug(); + // v_seg_locs.debug(); + } +}; + +struct KVC2 : KVC2Interface { + + KVC2Config config; + std::shared_ptr met; + + std::filesystem::path root; + std::unique_ptr tree; + std::unique_ptr disk_cache; + std::shared_ptr memory_pool; + std::unique_ptr cache_manager; + std::unique_ptr io_dealer; + + std::shared_ptr gpu_cache; + + public: + void load() override { + load_quant_configs(root / "quant_configs.json"); + load_model_configs(root / "model_configs.json"); + { + auto where = root / "tree.json"; + if (std::filesystem::exists(where)) { + nlohmann::json j; + std::ifstream i(where); + i >> j; + j.get_to(*tree); + SPDLOG_WARN("Loaded from {}", where.c_str()); + } + } + { + auto where = root / "disk_cache.json"; + if (std::filesystem::exists(where)) { + nlohmann::json j; + std::ifstream i(where); + i >> j; + j.get_to(*disk_cache); + SPDLOG_WARN("Loaded from {}", where.c_str()); + } + } + { + auto where = root / "config.json"; + if (std::filesystem::exists(where)) { + nlohmann::json j; + std::ifstream i(where); + i >> j; + j.get_to(config); + SPDLOG_WARN("Loaded from {}", where.c_str()); + } + } + } + + void save() override { + if (config.save_to_disk == false) { + return; + } + flush_back(); + { + nlohmann::json j; + j = *tree; + auto where = root / "tree.json"; + std::ofstream o(where); + o << j; + SPDLOG_WARN("Serialized to {}", where.c_str()); + } + { + nlohmann::json j; + j = *disk_cache; + auto where = root / "disk_cache.json"; + std::ofstream o(where); + o << j; + SPDLOG_WARN("Serialized to {}", where.c_str()); + } + { + nlohmann::json j; + j = config; + auto where = root / "config.json"; + std::ofstream o(where); + o << j; + SPDLOG_WARN("Serialized to {}", where.c_str()); + } + dump_quant_configs(root / "quant_configs.json"); + dump_model_configs(root / "model_configs.json"); + } + + void raw_insert(ModelName model_name, QuantType quant_type, Token* id, TokenLength length, + const std::vector& k_cache, const std::vector& v_cache) override { + TimeObserver time_observer(met->raw_insert_time_ms); + + SPDLOG_INFO("Raw Insert"); + if (length % NumTokenPerBlock != 0) { + SPDLOG_WARN("Try to insert tokens with length {}, which is not a multiple of NumTokenPerBlock({}), getting floor", + length, NumTokenPerBlock); + length = length / NumTokenPerBlock * NumTokenPerBlock; + } + + auto h = std::make_shared(); + h->kvc2_top = this; + h->set_cache_info(model_name, quant_type, config.k_cache_on, config.v_cache_on); + h->ids = Tokens(id, id + length); + + if (config.k_cache_on) + h->set_raw_handles(true, k_cache); + if (config.v_cache_on) + h->set_raw_handles(false, v_cache); + + h->check_before_insert(); + + h->match = tree->look_up_or_insert(id, length); + + auto now_prefix = h->match.prefix; + assert(config.k_cache_on); + + if (now_prefix->locations.get_location(h->k_info(), length - now_prefix->start_length).has_value()) { + assert(now_prefix->locations.get_location(h->v_info(), length - now_prefix->start_length).has_value()); + SPDLOG_INFO("KV Cache Already on disk"); + // already on disk + } else { + now_prefix = now_prefix->to_first_prefix_without_disk_locations(h->k_info()); + + // insert new kv cache locations + TokenLength new_length = length - now_prefix->start_length; + SPDLOG_DEBUG("Inserting new kv cache, length: {}", new_length); + assert(new_length > 0); + + if (config.v_cache_on) { + // allocate a big space on disk + auto k_loc = disk_cache->allocate(h->k_info(), div_up(new_length, NumTokenPerBlock)); + auto v_loc = disk_cache->allocate(h->v_info(), div_up(new_length, NumTokenPerBlock)); + h->k_seg_locs.add_location(now_prefix->start_length / NumTokenPerBlock, k_loc); + h->v_seg_locs.add_location(now_prefix->start_length / NumTokenPerBlock, v_loc); + + // split it to prefix trees + for (auto tail = h->match.prefix; tail != now_prefix->prev; tail = tail->prev) { + TokenLength local_ids_length = tail->local_length(); + tail->update_location(h->k_info(), k_loc.cut_tail(div_up(local_ids_length, NumTokenPerBlock))); + tail->update_location(h->v_info(), v_loc.cut_tail(div_up(local_ids_length, NumTokenPerBlock))); + } + assert(k_loc.length == 0); + assert(v_loc.length == 0); + } else { + // allocate a big space on disk + auto k_loc = disk_cache->allocate(h->k_info(), div_up(new_length, NumTokenPerBlock)); + h->k_seg_locs.add_location(now_prefix->start_length / NumTokenPerBlock, k_loc); + + // split it to prefix trees + for (auto tail = h->match.prefix; tail != now_prefix->prev; tail = tail->prev) { + TokenLength local_ids_length = tail->local_length(); + tail->update_location(h->k_info(), k_loc.cut_tail(div_up(local_ids_length, NumTokenPerBlock))); + } + assert(k_loc.length == 0); + } + + // write new kv cache + auto disk_io_helper = + h->segment_io(io_dealer.get(), disk_cache.get(), now_prefix->start_length / NumTokenPerBlock, + div_up(new_length, NumTokenPerBlock), IO_ForceWrite); + disk_io_helper->wait(); + } + } + + TokenLength raw_read(ModelName model_name, QuantType quant_type, Token* id, TokenLength length, + const std::vector& k_cache, const std::vector& v_cache) override { + SPDLOG_INFO("Raw Read"); + auto h = std::make_shared(); + h->kvc2_top = this; + h->set_cache_info(model_name, quant_type, config.k_cache_on, config.v_cache_on); + h->ids = Tokens(id, id + length); + + if (config.k_cache_on) + h->set_raw_handles(true, k_cache); + if (config.v_cache_on) + h->set_raw_handles(false, v_cache); + + h->match = tree->look_up(id, length); + if (h->match.prefix == nullptr) { + SPDLOG_INFO("Not Found"); + return 0; + } + SPDLOG_DEBUG("Found {}", h->match.match_length); + h->collect_locations(); + auto disk_io_helper = h->segment_io(io_dealer.get(), disk_cache.get(), 0, + div_up(h->match.match_length, NumTokenPerBlock), IO_ForceRead); + + disk_io_helper->wait(); + return h->match.match_length; + } + + std::shared_ptr lookup(ModelName model_name, QuantType quant_type, Token* id, + TokenLength length, TokenLength estimated_length) override { + TimeObserver time_observer(met->lookup_time_ms); + auto re = std::make_shared(); + re->set_cache_info(model_name, quant_type, config.k_cache_on, config.v_cache_on); + re->ids = Tokens(id, id + length); + re->estimated_length = estimated_length; + re->kvc2_top = this; + SPDLOG_DEBUG("Lookup TokenLength {}", length); + if (config.gpu_only == false) { + //TODO: + } + return re; + }; + + std::shared_ptr lookup_to_gpu(ModelName model_name, QuantType quant_type, Token* id, + size_t length, size_t estimated_length) override { + std::promise> p; + lookup_to_gpu_async(model_name, quant_type, id, length, estimated_length, [&p](auto re) { p.set_value(re); }); + return p.get_future().get(); + } + + void lookup_to_gpu_async(ModelName model_name, QuantType quant_type, Token* id, TokenLength length, + TokenLength estimated_length, + std::function)> call_back) override { + auto re = lookup(model_name, quant_type, id, length, estimated_length); + if (re == nullptr) { + call_back(nullptr); + return; + } + auto h = static_cast(re.get()); + if (config.gpu_only) { + auto total_block_count = div_up(estimated_length, NumTokenPerBlock); + h->gpu_only_block_idx = gpu_cache->gpu_only_alloc_col(total_block_count); + if (h->gpu_only_block_idx.empty()) { + call_back(nullptr); + } else { + call_back(re); + } + + } else { + if (h->k_info().hidden_layer_count() != gpu_cache->config.layer_count) { + SPDLOG_ERROR("GPU Cache Layer Count not match"); + assert(false); + } + + if (h->alloc_on_gpu_cols() == false) { + call_back(nullptr); + return; + } + + h->to_gpu_async([call_back, re](bool ok) { + if (ok) { + call_back(re); + } else { + call_back(nullptr); + } + }); + } + } + + std::pair, std::vector> get_kvcache() override { + return {gpu_cache->k_cache, gpu_cache->v_cache}; + } + + void flush_back() { + gpu_cache->background_flush_back->wakeUpWait(); + cache_manager->background_flush_back->wakeUpWait(); + } + + void debug() override { + cache_manager->debug(); + tree->debug(); + } + + virtual ~KVC2() { flush_back(); }; + + KVC2(KVC2Config config) : config(config) { + SPDLOG_INFO("Creating KVC2 using these config"); + SPDLOG_INFO(" GPU Only: {}", config.gpu_only); + SPDLOG_INFO(" Load: {}, Save: {}", config.load_from_disk, config.save_to_disk); + SPDLOG_INFO(" Path: {}", config.path); + SPDLOG_INFO(" Config Path: {}", config.config_path); + SPDLOG_INFO(" Num Token/Page: {}, Memory Pool Size: {}", config.num_token_per_page, + readable_number(config.memory_pool_size)); + SPDLOG_INFO(" Evict Count: {}, Metrics Port: {}", config.evict_count, config.metrics_port); + SPDLOG_INFO(" Recompute Ratio: {:.2f}", config.recompute_ratio); + + if (config.gpu_cache_config) { + const auto& gpu_config = *config.gpu_cache_config; + SPDLOG_INFO(" GPU Devices: {}", format_vector(gpu_config.gpu_devices_id)); + SPDLOG_INFO(" Layer Count: {}, Total KVCache Pages: {}", gpu_config.layer_count, + gpu_config.total_kvcache_pages); + SPDLOG_INFO(" Num Token/Page: {}, Num K Heads: {}", gpu_config.num_token_per_page, gpu_config.num_k_heads); + SPDLOG_INFO(" K Head Dim: {}, Tensor Type: {}", gpu_config.k_head_dim, + static_cast(gpu_config.tensor_type)); + SPDLOG_INFO(" MemcpyCudaStreams/Device: {}", gpu_config.num_streams_per_device); + } else { + SPDLOG_INFO(" GPU Cache Config: None"); + } + + load_model_configs(config.config_path + "/model_configs.json"); + load_quant_configs(config.config_path + "/quant_configs.json"); + + // met + MetricsConfig met_conf; + met_conf.endpoint = "0.0.0.0:" + std::to_string(config.metrics_port); + SPDLOG_INFO("Creating kvc2 metrics exporter on {}", met_conf.endpoint); + met = std::make_shared(met_conf); + + if (config.gpu_only == false) { + if (config.k_cache_on == false) { + SPDLOG_ERROR("if k_cache_on is false, gpu_only must be true"); + assert(false); + } + root = config.path; + tree = std::make_unique(); + disk_cache = std::make_unique(config); + memory_pool = std::make_shared(config.memory_pool_size); + cache_manager = std::unique_ptr( + new CacheEntryManager(CacheEntryManagerConfig{.evict_count = config.evict_count, .kvc2_top = this})); + cache_manager->pool = memory_pool; + + io_dealer = std::make_unique(); + io_dealer->start_io_thread().detach(); + + tree->met = met; + if (config.gpu_cache_config.has_value()) { + gpu_cache = std::make_shared(config.gpu_cache_config.value()); + cache_manager->gpu_cache = gpu_cache; + } + cache_manager->cpu_background_flush(); + gpu_cache->gpu_background_flush(); + } else { + SPDLOG_CRITICAL("GPU ONLY MODE, NO PREFIX CACHE"); + gpu_cache = std::make_shared(config.gpu_cache_config.value()); + } + } +}; + +std::shared_ptr create_kvc2(KVC2Config config) { + NumTokenPerBlock = config.num_token_per_page; + EvictCount = config.evict_count; + // SPDLOG_WARN("Sizeof KVC2Config {} here", sizeof(KVC2Config)); + return std::make_shared(config); +} + +DoubleCacheHandle::~DoubleCacheHandle() { + if (kvc2_top->config.gpu_only) { + kvc2_top->gpu_cache->gpu_only_free_cols(gpu_only_block_idx); + } else { + for_all_cache_block_entry([](std::shared_ptr& block_entry) { + block_entry->lock_guard(); + if (block_entry->with_key == false && block_entry->data != nullptr) { + block_entry->free_on_cpu(); + } + return true; + }); + } +}; + +void DoubleCacheHandle::get_handles() { + size_t new_count = 0, total_count = 0; + auto get_info_handles = [this, &new_count, &total_count]( + CacheInfo info, std::vector>>& layers) { + auto total_block_count = div_up(estimated_length, NumTokenPerBlock); + for (size_t l = 0; l < info.hidden_layer_count(); l++) { + auto hashes = match.matched_hashes(info, l); + layers[l].resize(total_block_count, nullptr); + for (size_t i = 0; i < total_block_count; i++) { + std::optional key = std::nullopt; + if (i < hashes.size()) + key = hashes[i]; + bool is_new; + total_count += 1; + layers[l][i] = this->kvc2_top->cache_manager->get(is_new, info.element_size(NumTokenPerBlock), key); + if (is_new) + new_count += 1; + layers[l][i]->cache_info = info; + layers[l][i]->layer = l; + } + } + }; + + if (kvc2_top->config.k_cache_on) + get_info_handles(k_info(), k_cache_handles); + if (kvc2_top->config.v_cache_on) + get_info_handles(v_info(), v_cache_handles); + SPDLOG_INFO("New Handles: {}/{}", new_count, total_count); +} + +bool DoubleCacheHandle::to_gpu() { + std::promise p; + to_gpu_async([&p](bool ok) { p.set_value(ok); }); + return p.get_future().get(); +} + +void DoubleCacheHandle::to_gpu_async(std::function call_back) { + if (enable_alt) { + assert(false); + // size_t page_size = kvc2_top->config.num_token_per_page; + // BlockLength count = + // div_up(TokenLength(std::ceil(match_by_blocks.partial_count() * page_size * + // kvc2_top->config.recompute_ratio)), + // page_size); + // if (alloc_attached_blocks(count) == false) { + // SPDLOG_WARN("Cannot allocate attached GPU block"); + // call_back(false); + // return; + // } else { + // SPDLOG_INFO("Allocated {} attached GPU blocks", count); + // } + } + + // don't wait here + if (any_match() == false) { + SPDLOG_INFO("No match, No need to load to gpu"); + call_back(true); + return; + } + + auto gpu_io_helper = gpu_io(kvc2_top->gpu_cache.get(), 0, match_range_length(), IO_Read); + gpu_io_helper->call_back = [call_back]() { call_back(true); }; + + // Ok this is very stupid, but I have to do this for now + std::thread([gpu_io_helper]() { gpu_io_helper->wait(); }).detach(); +} + +bool DoubleCacheHandle::alloc_attached_blocks(BlockLength count) { + // attached_vertical_handles.resize(count); + // for (size_t i = 0; i < count; i++) { + // attached_vertical_handles[i] = std::shared_ptr(new DoubleVerticalBlocksHandle); + // attached_vertical_handles[i]->gpu_only = true; + // } + // return kvc2_top->gpu_cache->alloc_pages(attached_vertical_handles); + return true; +} + +std::vector DoubleCacheHandle::get_gpu_attached_block_idx() { + std::vector re; + // for (auto& h : attached_vertical_handles) { + // re.push_back(h->gpu_block_idx.value()); + // } + return re; +} + +void CacheBlockEntry::set_key(TokensHash key, std::shared_ptr me) { + assert(with_key == false); + with_key = true; + hash = key; + // SPDLOG_DEBUG("Insert New Gen KVCache, key {}", key); + std::lock_guard manager_lg(manager->lock); + if (manager->key_entry_map.contains(me->hash)) { + SPDLOG_WARN("Duplicate key {}", me->hash); + } else { + manager->insert(me); + } +} + +std::vector DoubleCacheHandle::get_gpu_block_idx() { + if (kvc2_top->config.gpu_only) { + return gpu_only_block_idx; + } else { + std::vector re; + for (auto& handle : k_cache_handles[0]) { + re.push_back(handle->gpu_block_idx.value()); + } + return re; + } +} + +/* +length : total length of tokens (including matched tokens) + 1. update key, insert CacheBlock hash to lru + 2. set dirty flag + 3. update prefix tree, allocate new disk location +*/ +void DoubleCacheHandle::append_tokens(Token* all_tokens, TokenLength length) { + if (kvc2_top->config.gpu_only) { + return; + } + TimeObserver time_observer(kvc2_top->met->append_tokens_time_ms); + if (enable_alt) { + SPDLOG_WARN("Append Tokens Not Implemented for Alternative Path"); + return; + } + if (length > estimated_length) { + SPDLOG_ERROR("Length {} exceed estimated length {}", length, estimated_length); + assert(false); + } + size_t match_length = matched_length(); + + if (length < match_length) { + SPDLOG_WARN("Length {} less than match length {}", length, match_length); + assert(false); + } + + if (length > ids.size()) { + ids.insert(ids.end(), all_tokens + ids.size(), all_tokens + length); + } + + static const auto num_token_per_page = kvc2_top->config.num_token_per_page; + + if (match_length % num_token_per_page != 0) { + SPDLOG_ERROR("Match length {} is not multiple of num_token_per_page {}", match_length, num_token_per_page); + assert(false); + } + + if (match_length + num_token_per_page > length) { + // SPDLOG_DEBUG("append_tokens No need to update"); + return; + } + SPDLOG_DEBUG("Append Tokens to {}", length); + auto pre_match_length = match_length; + // set gpu dirty flag + size_t new_added_block_count = 0; + while (match_length + num_token_per_page <= length) { + match_length += num_token_per_page; + new_added_block_count += 1; + } + + // update prefix tree + match.prefix = kvc2_top->tree->new_prefix_node(match.prefix, pre_match_length, ids.data(), match_length).get(); + match.match_length = match_length; + + // alloc disk location for new added prefix + auto disk_cache = kvc2_top->disk_cache.get(); + Location k_loc{0, 0}, v_loc{0, 0}; + if (is_k_cache_on) { + k_loc = disk_cache->allocate(k_info(), new_added_block_count); + k_seg_locs.add_location(match.prefix->start_length / NumTokenPerBlock, k_loc); + match.prefix->update_location(k_info(), k_loc); + } + if (is_v_cache_on) { + v_loc = disk_cache->allocate(v_info(), new_added_block_count); + v_seg_locs.add_location(match.prefix->start_length / NumTokenPerBlock, v_loc); + match.prefix->update_location(v_info(), v_loc); + } + + // update cache handles + auto update_cache_handles = [this, pre_match_length, length]( + CacheInfo info, std::vector>>& layers, + Location loc) { + TokensHasher hasher; + for (Layer l = 0; l < info.hidden_layer_count(); l++) { + hasher.reset(info.hash_value()); + hasher.update_raw(&l, sizeof(l)); + hasher.update(ids.data(), pre_match_length); + auto page_count_start = pre_match_length / num_token_per_page; + for (size_t i = pre_match_length; i + num_token_per_page <= length; i += num_token_per_page) { + auto page_count = i / num_token_per_page; + hasher.update(ids.data() + i, num_token_per_page); + auto block = layers[l][page_count]; + { + auto lg = block->lock_guard(); + block->idx = loc.start_idx + page_count - page_count_start; + block->set_key(hasher.get(), block); + if (l == 0 && info.is_key_cache) { + block->gpu_cc.tc.set_has_data(); + } + block->gpu_cc.dirty.store(true); + } + } + } + }; + + if (is_k_cache_on) { + update_cache_handles(k_info(), k_cache_handles, k_loc); + } + if (is_v_cache_on) { + update_cache_handles(v_info(), v_cache_handles, v_loc); + } + + // kvc2_top->block_cache->debug(); +} + +void CacheBlockEntry::flush_back_async(IO_Helper& helper, + std::vector& dirty_flags) { + auto kvc2_top = manager->config.kvc2_top; + auto allocator = kvc2_top->disk_cache->get_allocator(cache_info); + // if (layer == 0) { + // SPDLOG_DEBUG("Flush {} to {}", fmt::ptr(this), idx); + // } + io_with(kvc2_top->io_dealer.get(), helper, allocator->get_store(layer), layer, idx, IOOption::IO_Write); + dirty_flags.push_back(&cpu_cc.dirty); +} + +void CacheEntryManager::cpu_background_flush() { + if (background_flush_back.get() == nullptr) { + SPDLOG_INFO("Starting CPU Background flush"); + background_flush_back = std::unique_ptr(new periodic::PeriodicTask([this]() { + // Timer t("CPU Flush"); + std::vector dirty_cpus; + std::vector> entry_uls; + IO_Helper io_helper(nullptr, [&dirty_cpus]() { + for (auto& flag : dirty_cpus) { + flag->store(false); + } + if (dirty_cpus.size() > 0) + SPDLOG_DEBUG("{} dirty CPU pages flushed.", dirty_cpus.size()); + }); + { + std::lock_guard ul(lock); + for (auto& e : usage_list) { + auto ul = e->try_lock(); + if (ul.owns_lock()) { + if (e->cpu_cc.dirty.load()) { + entry_uls.push_back(std::move(ul)); + e->flush_back_async(io_helper, dirty_cpus); + } + } + // if (dirty_cpus.size() == 100) { + // break; + // } + } + } + + io_helper.finish_add_taks(); + io_helper.wait(); + })); + } else { + SPDLOG_ERROR("Flush Thread Already Started"); + } +} + +void GPUPageCache::gpu_background_flush() { + if (background_flush_back.get() == nullptr) { + SPDLOG_INFO("Starting GPU Background flush"); + background_flush_back = std::unique_ptr(new periodic::PeriodicTask([this]() { + // Timer t("GPU Flush"); + + std::vector dirty_cols; + std::vector entries; + std::vector> uls; + BatchPromise promise(config.gpu_devices_id.size()); + auto reqs = basic_request(cudaMemcpyDeviceToHost, [&promise]() { promise.set(); }); + + for (size_t i = 0; i < config.total_kvcache_pages; i++) { + std::lock_guard lg(this->lock); + auto col_uls = try_lock_col(i); + if (col_uls.empty()) + continue; + for (size_t l = 0; l < config.layer_count; l++) { + if (config.k_cache_on && (occupations[l][i]->gpu_cc.dirty.load() == false || occupations[l][i]->cpu_cc.dirty.load())) + goto next_gpu_page; + if (config.v_cache_on && (v_occupations[l][i]->gpu_cc.dirty.load() == false || v_occupations[l][i]->cpu_cc.dirty.load())) + goto next_gpu_page; + } + + dirty_cols.push_back(i); + for (size_t l = 0; l < config.layer_count; l++) { + // occupations[l][i]->alloc_on_cpu_no_lock(); + if (config.k_cache_on) + entries.push_back(occupations[l][i].get()); + if (config.v_cache_on) + entries.push_back(v_occupations[l][i].get()); + } + append_col_to_request(reqs, occupations, v_occupations, i); + for (auto& ul : col_uls) { + uls.push_back(std::move(ul)); + } + next_gpu_page: + continue; + } + + submit_requests(reqs); + promise.get_shared_fut().wait(); + if (dirty_cols.empty() == false) + SPDLOG_INFO("GPU Flushed Back {} cols", dirty_cols.size()); + + for (auto& entry : entries) { + entry->cpu_cc.tc.set_has_data(); + // we have locks here + entry->cpu_cc.dirty.store(true); + } + for (auto& col : dirty_cols) { + for (size_t l = 0; l < config.layer_count; l++) { + if (config.k_cache_on) + occupations[l][col]->gpu_cc.dirty.store(false); + if (config.v_cache_on) + v_occupations[l][col]->gpu_cc.dirty.store(false); + } + } + if (dirty_cols.empty() == false) { + debug(); + } + })); + } else { + SPDLOG_ERROR("Flush Thread Already Started"); + } +} + +} // namespace kvc2 diff --git a/csrc/balance_serve/kvc2/src/utils/all.hpp b/csrc/balance_serve/kvc2/src/utils/all.hpp new file mode 100644 index 0000000..18a6ee9 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/utils/all.hpp @@ -0,0 +1,3 @@ +#pragma once +#include "easy_format.hpp" +#include "timer.hpp" \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/src/utils/arithmetic.hpp b/csrc/balance_serve/kvc2/src/utils/arithmetic.hpp new file mode 100644 index 0000000..d1cb472 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/utils/arithmetic.hpp @@ -0,0 +1,14 @@ +#include +#include + +template +T div_up(T x, U by) { + static_assert(std::is_integral_v); + static_assert(std::is_integral_v); + return (x + by - 1) / by; +} + +template +T* offset_by_bytes(T* t, size_t n) { + return reinterpret_cast(reinterpret_cast(t) + n); +} diff --git a/csrc/balance_serve/kvc2/src/utils/easy_format.hpp b/csrc/balance_serve/kvc2/src/utils/easy_format.hpp new file mode 100644 index 0000000..4b1b63f --- /dev/null +++ b/csrc/balance_serve/kvc2/src/utils/easy_format.hpp @@ -0,0 +1,37 @@ +#ifndef __EASY_FORMAT_HPP_ +#define __EASY_FORMAT_HPP_ +#include +#include +#include +#include + +#include + +template +inline std::string format_vector(const std::vector& v) { + std::ostringstream oss; + if (v.empty()) + return "[]"; + for (size_t i = 0; i < v.size(); ++i) { + oss << v[i]; + if (i < v.size() - 1) + oss << ", "; // 逗号分隔 + } + return oss.str(); +} + +inline std::array units = {"", "K", "M", "G", "T", "P", "E"}; + +inline std::string readable_number(size_t size) { + size_t unit_index = 0; + double readable_size = size; + while (readable_size >= 1000 && unit_index < units.size() - 1) { + readable_size /= 1000; + unit_index++; + } + std::ostringstream ss; + ss << std::fixed << std::setprecision(2) << readable_size; + std::string str = ss.str(); + return str + "" + units[unit_index]; +} +#endif \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/src/utils/lock_free_queue.hpp b/csrc/balance_serve/kvc2/src/utils/lock_free_queue.hpp new file mode 100644 index 0000000..457bb82 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/utils/lock_free_queue.hpp @@ -0,0 +1,60 @@ +#include +#include +#include +#include +#include +#include + +template +class MPSCQueue { + struct Node { + std::shared_ptr data; + std::atomic next; + + Node() : next(nullptr) {} + Node(std::shared_ptr data_) : data(std::move(data_)), next(nullptr) {} + }; + + std::atomic head; + Node* tail; + + public: + std::atomic_size_t enqueue_count = 0; + size_t dequeue_count = 0; + MPSCQueue() { + Node* dummy = new Node(); + head.store(dummy, std::memory_order_relaxed); + tail = dummy; + } + + ~MPSCQueue() { + // 清理剩余的节点 + Node* node = tail; + while (node) { + Node* next = node->next.load(std::memory_order_relaxed); + delete node; + node = next; + } + } + + // 生产者调用 + void enqueue(std::shared_ptr data) { + enqueue_count.fetch_add(1); + Node* node = new Node(std::move(data)); + Node* prev_head = head.exchange(node, std::memory_order_acq_rel); + prev_head->next.store(node, std::memory_order_release); + } + + // 消费者调用 + std::shared_ptr dequeue() { + Node* next = tail->next.load(std::memory_order_acquire); + if (next) { + std::shared_ptr res = std::move(next->data); + delete tail; + tail = next; + dequeue_count += 1; + return res; + } + return nullptr; + } +}; \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/src/utils/mpsc.hpp b/csrc/balance_serve/kvc2/src/utils/mpsc.hpp new file mode 100644 index 0000000..ed44e63 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/utils/mpsc.hpp @@ -0,0 +1,90 @@ +#include +#include +#include +#include +#include + +template +class MPSCQueue { + struct Node { + T data; + std::atomic next; + + Node() : next(nullptr) {} + Node(T data_) : data(std::move(data_)), next(nullptr) {} + }; + + std::atomic head; + Node* tail; + + public: + std::atomic_size_t enqueue_count = 0; + size_t dequeue_count = 0; + MPSCQueue() { + Node* dummy = new Node(); + head.store(dummy, std::memory_order_seq_cst); + tail = dummy; + } + + ~MPSCQueue() { + Node* node = tail; + while (node) { + Node* next = node->next.load(std::memory_order_seq_cst); + delete node; + node = next; + } + } + + // 生产者调用 + void enqueue(T data) { + enqueue_count.fetch_add(1); + Node* node = new Node(std::move(data)); + Node* prev_head = head.exchange(node, std::memory_order_seq_cst); + prev_head->next.store(node, std::memory_order_seq_cst); + } + + // 消费者调用 + std::optional dequeue() { + Node* next = tail->next.load(std::memory_order_seq_cst); + if (next) { + T res = std::move(next->data); + delete tail; + tail = next; + dequeue_count += 1; + return res; + } + return std::nullopt; + } + + size_t size() { return enqueue_count.load() - dequeue_count; } +}; + +template +class MPSCQueueConsumerLock { + MPSCQueue queue; + std::counting_semaphore<> sema{0}; + + public: + void enqueue(T data) { + queue.enqueue(std::move(data)); + // std::atomic_thread_fence(std::memory_order_seq_cst);// Inserting this because the memory order might be wrong, I + // am also not that sure about this. + sema.release(); + } + + T dequeue() { + auto re = queue.dequeue(); + if (re.has_value()) { + while (sema.try_acquire() == false) { + std::cerr << __FILE__ << ":" << __FUNCTION__ << " sema try acquire should be success, retrying, please check" + << std::endl; + // assert(false); + } + return re.value(); + } + sema.acquire(); + return queue.dequeue().value(); + } + + size_t size() { return queue.size(); } +}; diff --git a/csrc/balance_serve/kvc2/src/utils/mutex_extend.hpp b/csrc/balance_serve/kvc2/src/utils/mutex_extend.hpp new file mode 100644 index 0000000..fb71f9a --- /dev/null +++ b/csrc/balance_serve/kvc2/src/utils/mutex_extend.hpp @@ -0,0 +1,70 @@ +#ifndef __MUTEX_EXTEND_HPP_ +#define __MUTEX_EXTEND_HPP_ + +#include +#include +#include +#include + +class non_recursive_mutex { + public: + non_recursive_mutex() = default; + + // 使用 try_lock 实现非递归锁 + bool try_lock() { + std::thread::id this_id = std::this_thread::get_id(); + + // 检查当前线程是否已经持有该锁 + if (owner.load(std::memory_order_acquire) == this_id) { + return false; // 如果是当前线程,返回失败 + } + + // 尝试加锁 + if (mtx.try_lock()) { + owner.store(this_id, std::memory_order_release); // 设置锁的拥有者 + return true; + } + + return false; + } + + // lock 会阻塞,直到获得锁 + void lock() { + std::thread::id this_id = std::this_thread::get_id(); + + while (true) { + // 检查当前线程是否已经持有该锁 + if (owner.load(std::memory_order_acquire) == this_id) { + throw std::runtime_error("Thread is trying to lock a mutex it already holds"); + } + + // 尝试加锁 + if (mtx.try_lock()) { + owner.store(this_id, std::memory_order_release); // 设置锁的拥有者 + return; + } + + // 如果锁未获得,则稍微等待,防止忙等 + std::this_thread::yield(); + } + } + + // 解锁 + void unlock() { + std::thread::id this_id = std::this_thread::get_id(); + + // 确保只有持有锁的线程可以解锁 + if (owner.load(std::memory_order_acquire) == this_id) { + owner.store(std::thread::id(), std::memory_order_release); // 清除锁的拥有者 + mtx.unlock(); + } else { + throw std::runtime_error("Thread attempting to unlock a mutex it doesn't own"); + } + } + + private: + std::mutex mtx; // 实际的互斥量 + std::atomic owner; // 原子变量,记录当前锁的拥有者 +}; + +#endif diff --git a/csrc/balance_serve/kvc2/src/utils/periodic_task.hpp b/csrc/balance_serve/kvc2/src/utils/periodic_task.hpp new file mode 100644 index 0000000..b846bb3 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/utils/periodic_task.hpp @@ -0,0 +1,102 @@ +#ifndef PERIODIC_TASK_HPP +#define PERIODIC_TASK_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace periodic { + +class PeriodicTask { + public: + explicit PeriodicTask(std::function func, + std::chrono::milliseconds interval_ms = std::chrono::milliseconds(100)) + : func_(std::move(func)), interval_(interval_ms), worker_([this](std::stop_token stoken) { this->run(stoken); }) { + // std::cout << "PeriodicTask created with interval: " << interval_.count() << " ms" << std::endl; + } + + ~PeriodicTask() { + worker_.request_stop(); + cv_.notify_one(); // Ensure worker wakes up when destroyed + // std::cout << "PeriodicTask destructor called, stopping worker." << std::endl; + } + + void wakeUp() { + { + std::lock_guard lock(wakeup_mutex_); + wake_up_requested_ = true; + } + cv_.notify_one(); // Notify worker thread to wake up immediately + // std::cout << "wakeUp() called: worker thread will wake up." << std::endl; + } + + std::future wakeUpWait() { + std::promise promise; + std::future future = promise.get_future(); + { + std::lock_guard lock(promise_mutex_); + wakeup_promises_.push_back(std::move(promise)); + } + wakeUp(); + return future; + } + + private: + void run(std::stop_token stoken) { + while (!stoken.stop_requested()) { + std::unique_lock lock(mutex_); + // Wait for either the time interval or a wake-up signal + cv_.wait_for(lock, interval_, [this] { return wake_up_requested_.load(); }); + + if (stoken.stop_requested()) + break; + + // If the wake-up was triggered, reset the flag and process the task + { + std::lock_guard lock(wakeup_mutex_); + wake_up_requested_ = false; + } + + try { + // std::cout << "Running task function." << std::endl; + func_(); + } catch (...) { + std::cerr << "Error in task function." << std::endl; + } + + notifyPromises(); + } + } + + void notifyPromises() { + std::lock_guard lock(promise_mutex_); + // std::cout << "Notifying all waiting promises." << std::endl; + for (auto& promise : wakeup_promises_) { + promise.set_value(); + } + wakeup_promises_.clear(); + } + + std::function func_; + std::chrono::milliseconds interval_; + std::mutex mutex_; + std::condition_variable cv_; + std::vector> wakeup_promises_; + std::mutex promise_mutex_; + std::mutex wakeup_mutex_; + std::atomic wake_up_requested_ = false; + std::jthread worker_; +}; + +} // namespace periodic + +#endif // PERIODIC_TASK_HPP diff --git a/csrc/balance_serve/kvc2/src/utils/spin_lock.hpp b/csrc/balance_serve/kvc2/src/utils/spin_lock.hpp new file mode 100644 index 0000000..82d35e3 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/utils/spin_lock.hpp @@ -0,0 +1,36 @@ +/* + * @Author: Xie Weiyu ervinxie@qq.com + * @Date: 2024-11-21 06:35:47 + * @LastEditors: Xie Weiyu ervinxie@qq.com + * @LastEditTime: 2024-11-21 06:35:50 + * @FilePath: /kvc2/src/utils/spin_lock.hpp + * @Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: + * https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE + */ + +#include +#include +#include + +class SpinLock { + public: + SpinLock() { flag.clear(); } + + void lock() { + const int max_delay = 1024; // Maximum delay in microseconds + int delay = 1; // Initial delay in microseconds + + while (flag.test_and_set(std::memory_order_acquire)) { + std::this_thread::sleep_for(std::chrono::microseconds(delay)); + delay *= 2; + if (delay > max_delay) { + delay = max_delay; + } + } + } + + void unlock() { flag.clear(std::memory_order_release); } + + private: + std::atomic_flag flag = ATOMIC_FLAG_INIT; +}; diff --git a/csrc/balance_serve/kvc2/src/utils/timer.hpp b/csrc/balance_serve/kvc2/src/utils/timer.hpp new file mode 100644 index 0000000..bf53f22 --- /dev/null +++ b/csrc/balance_serve/kvc2/src/utils/timer.hpp @@ -0,0 +1,128 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "easy_format.hpp" + +inline std::string doubleToStringR2(double value) { + std::stringstream stream; + stream << std::fixed << std::setprecision(2) << value; + return stream.str(); +} + +class Timer { + public: + std::string name; + bool tmp_timer = false; + + Timer() {} + Timer(std::string name) : name(name), tmp_timer(true) { start(); } + ~Timer() { + if (tmp_timer) { + std::cout << name << " " << elapsedMs() << " ms" << std::endl; + } + } + + void start() { + m_startTime = std::chrono::high_resolution_clock::now(); + assert(m_isRunning == false); + m_isRunning = true; + } + + void stop() { + m_endTime = std::chrono::high_resolution_clock::now(); + assert(m_isRunning == true); + m_isRunning = false; + m_runningNs += elapsedNs(); + } + + double elapsedNs() { + std::chrono::time_point endTime; + + if (m_isRunning) { + endTime = std::chrono::high_resolution_clock::now(); + } else { + endTime = m_endTime; + } + + return std::chrono::duration_cast(endTime - m_startTime).count(); + } + + void printElapsedMilliseconds() { std::cout << elapsedNs() / 1e6 << " ms" << std::endl; } + + static std::string ns_to_string(double duration) { + auto nano_sec = duration; + if (nano_sec >= 1000) { + auto mirco_sec = nano_sec / 1000.0; + if (mirco_sec >= 1000) { + auto milli_sec = mirco_sec / 1000.0; + if (milli_sec >= 1000) { + auto seconds = milli_sec / 1000.0; + + if (seconds >= 60.0) { + auto minutes = seconds / 60.0; + + if (minutes >= 60.0) { + auto hours = minutes / 60.0; + return doubleToStringR2(hours) + " h"; + } else { + return doubleToStringR2(minutes) + " min"; + } + } else { + return doubleToStringR2(seconds) + " sec"; + } + } else { + return doubleToStringR2(milli_sec) + " ms"; + } + } else { + return doubleToStringR2(mirco_sec) + " us"; + } + } else { + return doubleToStringR2(nano_sec) + " ns"; + } + } + + double runningTimeNs() { return m_runningNs; } + + std::string runningTime() { + auto duration = m_runningNs; + return ns_to_string(duration); + } + + std::string elapsedTime() { return ns_to_string(elapsedNs()); } + double elapsedMs() { return elapsedNs() / 1e6; } + std::string report_throughput(size_t op_cnt) { + double ops = op_cnt / elapsedMs() * 1000; + return readable_number(ops) + "op/s"; + } + + void merge(Timer& other) { + assert(m_isRunning == false); + assert(other.m_isRunning == false); + m_runningNs += other.runningTimeNs(); + } + + private: + std::chrono::time_point m_startTime; + std::chrono::time_point m_endTime; + bool m_isRunning = false; + double m_runningNs = 0.0; +}; + +class Counter { + public: + Counter() {} + + std::map counters; + + void inc(const char* name, size_t num) { counters[name] += num; }; + void print() { + for (auto& p : counters) { + std::cout << p.first << " : " << p.second << std::endl; + } + }; +}; diff --git a/csrc/balance_serve/kvc2/test/CMakeLists.txt b/csrc/balance_serve/kvc2/test/CMakeLists.txt new file mode 100644 index 0000000..f1b093a --- /dev/null +++ b/csrc/balance_serve/kvc2/test/CMakeLists.txt @@ -0,0 +1,78 @@ + +set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fopenmp") +# set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -pthread") + +add_subdirectory(kvc2test) + + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../src) + +add_executable(hashmap_test hashmap_test.cpp) +target_link_libraries(hashmap_test PRIVATE TBB::tbb) + + +add_executable(xxHash_test xxHash_test.cpp) +target_link_libraries(xxHash_test PRIVATE xxhash) + +function(add_async_store_executable source_file) + get_filename_component(target_name ${source_file} NAME_WE) # 获取不带扩展名的文件名作为目标名 + add_executable(${target_name} ${source_file}) + target_include_directories(${target_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../src) + target_include_directories(${target_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/nlohmann/single_include) + target_include_directories(${target_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/spdlog/include) + target_link_libraries(${target_name} PRIVATE async_store gflags) +endfunction() + +add_async_store_executable(async_store_test.cpp) + + +function(add_kvc2_executable source_file) + get_filename_component(target_name ${source_file} NAME_WE) # 获取不带扩展名的文件名作为目标名 + add_executable(${target_name} ${source_file}) + # target_compile_options(${target_name} PRIVATE -fopenmp -fno-strict-aliasing) + target_include_directories(${target_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../src) + target_include_directories(${target_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/nlohmann/single_include) + target_include_directories(${target_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/spdlog/include) + target_link_libraries(${target_name} PRIVATE kvc2 async_store gflags) +endfunction() + + + + +add_kvc2_executable(test_lock_free_queue.cpp) +add_kvc2_executable(test_queue_perf.cpp) + +# Disable deprecated test +# add_kvc2_executable(prefix_test.cpp) +# add_kvc2_executable(kvcache_disk_insert_read_test.cpp) +# add_kvc2_executable(kvcache_mem_eviction_test.cpp) +# add_kvc2_executable(kvcache_mem_insert_read_test.cpp) +# add_kvc2_executable(kvcache_save_load_test.cpp) +# add_kvc2_executable(kvc2_export_header_test.cpp) +# add_kvc2_executable(kvc2_export_load_test.cpp) + + + + + +target_include_directories(async_store_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..//third_party/nlohmann/single_include) +target_include_directories(async_store_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..//third_party/spdlog/include) +target_link_libraries(async_store_test PRIVATE xxhash) + +add_executable(test_std_list test_std_list.cpp) + + +add_executable(test_cuda_stream test_cuda_stream.cpp) +target_include_directories(test_cuda_stream PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) +target_link_libraries(test_cuda_stream PRIVATE CUDA::cudart) + +add_executable(test_cuda_stream_manager test_cuda_stream_manager.cpp) +target_include_directories(test_cuda_stream_manager PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../src) +target_link_libraries(test_cuda_stream_manager PRIVATE cuda_stream_manager) + +add_executable(test_periodic_task test_periodic_task.cpp) +target_include_directories(test_periodic_task PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../src) + +add_executable(test_page_pool page_pool_test.cpp) +target_include_directories(test_page_pool PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../src) +target_include_directories(test_page_pool PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/spdlog/include) \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/test/hashmap_test.cpp b/csrc/balance_serve/kvc2/test/hashmap_test.cpp new file mode 100644 index 0000000..cc5b7e4 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/hashmap_test.cpp @@ -0,0 +1,11 @@ +#include +#include + +int main() { + tbb::concurrent_hash_map map; + map.insert({1, 2}); + decltype(map)::accessor a; + std::cout << map.find(a, 1) << std::endl; + + return 0; +} diff --git a/csrc/balance_serve/kvc2/test/kvc2_export_header_test.cpp b/csrc/balance_serve/kvc2/test/kvc2_export_header_test.cpp new file mode 100644 index 0000000..e43d456 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvc2_export_header_test.cpp @@ -0,0 +1,87 @@ +#include "kvc2.h" +#include "kvc2_test_utils.cpp" + +int main(int argc, char* argv[]) { + init(argc, argv); + spdlog::set_level(spdlog::level::debug); + std::mt19937 gen(123); + + KVC2Config config = { + .path = FLAGS_disk_cache_path, + .config_path = std::string("/home/xwy/conifg"), + .block_length = BlockLength, + .memory_pool_size = size_t(10e9), + .evict_count = 20, + }; + auto kvcc = create_kvc2(config); + + auto io = kvcc->start_io_thread(); + + SPDLOG_INFO("Disk Test"); + auto ids = random_ids(10 * BlockLength, gen); + auto h1 = random_kvcache(qwen_cache_info, 10, gen); + kvcc->raw_insert(qwen_cache_info, reinterpret_cast(ids.data()), ids.size(), h1); + + // complete same + { + auto h2 = empty_kvcache(qwen_cache_info, 10); + kvcc->raw_read(qwen_cache_info, reinterpret_cast(ids.data()), ids.size(), h2); + cmp_handle_data(qwen_cache_info, h1, h2); + } + + // complete prefix + { + auto h2 = empty_kvcache(qwen_cache_info, 10); + auto ids2 = std::vector(ids.begin(), ids.begin() + 3 * BlockLength); + kvcc->raw_read(qwen_cache_info, reinterpret_cast(ids2.data()), ids2.size(), h2); + cmp_handle_data(qwen_cache_info, h1, h2, 3); + } + + // common prefix + { + auto h2 = empty_kvcache(qwen_cache_info, 10); + auto ids2 = std::vector(ids.begin(), ids.begin() + 5 * BlockLength); + auto rids = random_ids(BlockLength * 2 + BlockLength / 2, gen); + ids2.insert(ids2.end(), rids.begin(), rids.end()); + + kvcc->raw_read(qwen_cache_info, reinterpret_cast(ids2.data()), ids2.size(), h2); + + cmp_handle_data(qwen_cache_info, h1, h2, 5); + } + + // no prefix + { + auto h2 = empty_kvcache(qwen_cache_info, 10); + + auto ids2 = random_ids(10 * BlockLength, gen); + + kvcc->raw_read(qwen_cache_info, reinterpret_cast(ids2.data()), ids2.size(), h2); + } + + // insert partly new + auto h2 = random_kvcache(qwen_cache_info, 10, gen); + copy_kvcache(h1, h2, 0, 5); + auto ids2 = random_ids(10 * BlockLength, gen); + for (size_t i = 0; i < 5 * BlockLength; i++) { + ids2[i] = ids[i]; + } + + kvcc->raw_insert(qwen_cache_info, reinterpret_cast(ids2.data()), ids2.size(), h2); + + // read new part + { + auto h3 = empty_kvcache(qwen_cache_info, 10); + auto ids3 = std::vector(ids2.begin(), ids2.begin() + 7 * BlockLength); + ids3.push_back(123); + + kvcc->raw_read(qwen_cache_info, reinterpret_cast(ids3.data()), ids3.size(), h3); + cmp_handle_data(qwen_cache_info, h3, h2, 7); + } + kvcc->save(); + kvcc->stop_io_thread(); + io.join(); + + SPDLOG_WARN("{} Test Passed", __FILE__); + + return 0; +} \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/test/kvc2_export_load_test.cpp b/csrc/balance_serve/kvc2/test/kvc2_export_load_test.cpp new file mode 100644 index 0000000..5493b81 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvc2_export_load_test.cpp @@ -0,0 +1,87 @@ +#include "kvc2.h" +#include "kvc2_test_utils.cpp" + +int main(int argc, char* argv[]) { + init(argc, argv); + spdlog::set_level(spdlog::level::debug); + std::mt19937 gen(123); + + KVC2Config config = { + .path = FLAGS_disk_cache_path, + .block_length = BlockLength, + .memory_pool_size = size_t(10e9), + .evict_count = 20, + }; + auto kvcc = create_kvc2(config); + kvcc->load(); + + auto io = kvcc->start_io_thread(); + + SPDLOG_INFO("Disk Test"); + auto ids = random_ids(10 * BlockLength, gen); + auto h1 = empty_kvcache(qwen_cache_info, 10); + // kvcc->raw_insert(qwen_cache_info, reinterpret_cast(ids.data()), ids.size(), h1); + + // complete same + { + // auto h2 = empty_kvcache(qwen_cache_info, 10); + kvcc->raw_read(qwen_cache_info, reinterpret_cast(ids.data()), ids.size(), h1); + // cmp_handle_data(qwen_cache_info, h1, h2); + } + + // complete prefix + { + auto h2 = empty_kvcache(qwen_cache_info, 10); + auto ids2 = std::vector(ids.begin(), ids.begin() + 3 * BlockLength); + kvcc->raw_read(qwen_cache_info, reinterpret_cast(ids2.data()), ids2.size(), h2); + cmp_handle_data(qwen_cache_info, h1, h2, 3); + } + + // common prefix + { + auto h2 = empty_kvcache(qwen_cache_info, 10); + auto ids2 = std::vector(ids.begin(), ids.begin() + 5 * BlockLength); + auto rids = random_ids(BlockLength * 2 + BlockLength / 2, gen); + ids2.insert(ids2.end(), rids.begin(), rids.end()); + + kvcc->raw_read(qwen_cache_info, reinterpret_cast(ids2.data()), ids2.size(), h2); + + cmp_handle_data(qwen_cache_info, h1, h2, 5); + } + + // no prefix + { + auto h2 = empty_kvcache(qwen_cache_info, 10); + + auto ids2 = random_ids(10 * BlockLength, gen); + + kvcc->raw_read(qwen_cache_info, reinterpret_cast(ids2.data()), ids2.size(), h2); + } + + // insert partly new + auto h2 = random_kvcache(qwen_cache_info, 10, gen); + copy_kvcache(h1, h2, 0, 5); + auto ids2 = random_ids(10 * BlockLength, gen); + for (size_t i = 0; i < 5 * BlockLength; i++) { + ids2[i] = ids[i]; + } + + kvcc->raw_insert(qwen_cache_info, reinterpret_cast(ids2.data()), ids2.size(), h2); + + // read new part + { + auto h3 = empty_kvcache(qwen_cache_info, 10); + auto ids3 = std::vector(ids2.begin(), ids2.begin() + 7 * BlockLength); + ids3.push_back(123); + + kvcc->raw_read(qwen_cache_info, reinterpret_cast(ids3.data()), ids3.size(), h3); + cmp_handle_data(qwen_cache_info, h3, h2, 7); + } + + kvcc->stop_io_thread(); + io.join(); + + SPDLOG_WARN("{} Test Passed", __FILE__); + + return 0; +} \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/test/kvc2_test_utils.cpp b/csrc/balance_serve/kvc2/test/kvc2_test_utils.cpp new file mode 100644 index 0000000..d20ee85 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvc2_test_utils.cpp @@ -0,0 +1,117 @@ +#include +#include +#include "kvc2.h" +#define FMT_HEADER_ONLY +#include + +const int BlockLength = 256; + +std::string FLAGS_disk_cache_path; + +void init(int argc, char* argv[]) { + if (argc != 2) { + fmt::print("Usage: {} --disk_cache_path=xxx\n", argv[0]); + exit(1); + } + FLAGS_disk_cache_path = argv[1]; + if (FLAGS_disk_cache_path.empty()) { + fmt::print("disk_cache_path is empty"); + exit(1); + } +} + +using namespace kvc2; + +data_block_ptr empty_block(CacheInfo info) { + auto re = new (std::align_val_t(4096)) std::byte[info.element_size(BlockLength)]; + return reinterpret_cast(re); +} + +data_block_ptr random_block(CacheInfo info, std::mt19937& gen) { + auto re = empty_block(info); + uint64_t* d = (uint64_t*)re; + for (size_t i = 0; i < info.element_size(BlockLength) / 8; i++) { + d[i] = gen(); + } + return re; +} +layer_data random_blocks(CacheInfo info, size_t block_count, size_t seed) { + std::mt19937 gen(seed); + layer_data re; + for (size_t i = 0; i < block_count; i++) { + re.push_back(random_block(info, gen)); + } + return re; +} + +layer_data empty_blocks(CacheInfo info, size_t block_count) { + layer_data re; + for (size_t i = 0; i < block_count; i++) { + re.push_back(empty_block(info)); + } + return re; +} + +void copy_kvcache(std::vector& from, std::vector& to, size_t block_start, size_t length) { + for (size_t i = 0; i < from.size(); i++) { + for (size_t j = 0; j < length; j++) { + to[i][block_start + j] = from[i][block_start + j]; + } + } +} + +std::vector random_kvcache(CacheInfo info, size_t block_count, std::mt19937& gen) { + std::vector re; + re.resize(info.hidden_layer_count()); + fmt::print("Generating random kvcache, layer {}\n", info.hidden_layer_count()); +#pragma omp parallel for + for (size_t i = 0; i < info.hidden_layer_count(); i++) { + re[i] = random_blocks(info, block_count, gen()); + } + return re; +} + +std::vector empty_kvcache(CacheInfo info, size_t block_count) { + std::vector re; + re.resize(info.hidden_layer_count()); + fmt::print("Generating empty kvcache, layer {}\n", info.hidden_layer_count()); +#pragma omp parallel for + for (size_t i = 0; i < info.hidden_layer_count(); i++) { + re[i] = empty_blocks(info, block_count); + } + return re; +} + +std::vector random_ids(size_t length, std::mt19937& gen) { + std::vector re; + for (size_t i = 0; i < length; i++) { + re.push_back(gen()); + } + return re; +} + +CacheInfo qwen_cache_info = { + .model_name = "qwen2-72b-instruct", + .is_key_cache = true, + .quant_type = "BF16", +}; + +void cmp_handle_data(CacheInfo info, std::vector& h1, std::vector& h2, + std::optional blocks = std::nullopt) { + assert(h1.size() == h2.size()); + + for (size_t i = 0; i < h1.size(); i++) { + auto& b1 = h1[i]; + auto& b2 = h2[i]; + if (blocks.has_value() == false) { + assert(b1.size() == b2.size()); + } + int cmp_to = blocks.has_value() ? blocks.value() : b1.size(); + for (int j = 0; j < cmp_to; j++) { + auto e1 = reinterpret_cast(b1[j]); + auto e2 = reinterpret_cast(b2[j]); + assert(memcmp(e1, e2, info.element_size(BlockLength)) == 0); + } + } + fmt::print("KVCacheHandle cmp ok\n"); +} diff --git a/csrc/balance_serve/kvc2/test/kvc2test/CMakeLists.txt b/csrc/balance_serve/kvc2/test/kvc2test/CMakeLists.txt new file mode 100644 index 0000000..b9c40a9 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvc2test/CMakeLists.txt @@ -0,0 +1,26 @@ + +set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fopenmp") + +function(add_kvc2_test source_file) + get_filename_component(target_name ${source_file} NAME_WE) # 获取不带扩展名的文件名作为目标名 + add_executable(${target_name} ${source_file}) + # target_compile_options(${target_name} PRIVATE -fopenmp -fno-strict-aliasing) + target_include_directories(${target_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src) + target_include_directories(${target_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/nlohmann/single_include) + target_include_directories(${target_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/spdlog/include) + target_link_libraries(${target_name} PRIVATE kvc2 async_store) +endfunction() + +add_kvc2_test(raw_insert_read.cpp) +add_kvc2_test(lookup.cpp) +add_kvc2_test(lookup-alt.cpp) +add_kvc2_test(lookup-alt-gpu.cpp) +add_kvc2_test(lookup-mt.cpp) +add_kvc2_test(lookup-gpu.cpp) +add_kvc2_test(lookup-gpu-mt.cpp) +add_kvc2_test(lookup-gpu-async.cpp) +add_kvc2_test(append-tokens.cpp) +add_kvc2_test(flush-back.cpp) +add_kvc2_test(check-flush-back.cpp) +add_kvc2_test(lookup-without-vcache.cpp) +add_kvc2_test(lookup-gpu-mt-without-vcache.cpp) diff --git a/csrc/balance_serve/kvc2/test/kvc2test/append-tokens.cpp b/csrc/balance_serve/kvc2/test/kvc2test/append-tokens.cpp new file mode 100644 index 0000000..3356857 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvc2test/append-tokens.cpp @@ -0,0 +1,52 @@ +#include +#include "common.hpp" + +int main(int argc, char* argv[]) { + init(argc, argv); + spdlog::set_level(spdlog::level::debug); + auto kvc2 = kvc2::create_kvc2(config); + +#pragma omp parallel for + for (size_t ti = 0; ti < 3; ti++) { + auto [kcache, vcache] = kvc2->get_kvcache(); + std::mt19937 gen(ti + 123); + size_t total_page = 10; + TokenLength total_length = total_page * config.num_token_per_page; + auto tokens = random_ids(total_length, gen); + TokenLength prompt_length = 3 * config.num_token_per_page; + auto k1 = random_kvcache(total_page, gen); + auto v1 = random_kvcache(total_page, gen); + { + std::promise> p; + kvc2->lookup_to_gpu_async(test_model_name, test_quant_type, tokens.data(), prompt_length, total_length, + [&p](std::shared_ptr h) { p.set_value(h); }); + auto fut = p.get_future(); + fut.wait(); + auto h = fut.get(); + assert(h->matched_length() % config.num_token_per_page == 0); + size_t matched_block = h->matched_length() / config.num_token_per_page; + auto block_idx = h->get_gpu_block_idx(); + cmp_handle_gpu(block_idx, kcache, vcache, k1, v1, matched_block); + for (size_t at = matched_block; at < block_idx.size(); at++) { + copy_cpu_gpu(block_idx, kcache, vcache, k1, v1, at); + } + h->append_tokens(tokens.data(), total_length); + cmp_handle_gpu(block_idx, kcache, vcache, k1, v1, total_page); + } + + { + std::promise> p; + kvc2->lookup_to_gpu_async(test_model_name, test_quant_type, tokens.data(), total_length, total_length, + [&p](std::shared_ptr h) { p.set_value(h); }); + auto fut = p.get_future(); + fut.wait(); + auto h = fut.get(); + assert(h->matched_length() == total_length); + size_t matched_block = h->matched_length() / config.num_token_per_page; + auto block_idx = h->get_gpu_block_idx(); + cmp_handle_gpu(block_idx, kcache, vcache, k1, v1, matched_block); + } + } + SPDLOG_CRITICAL("All Test Passed: {}", argv[0]); + return 0; +} diff --git a/csrc/balance_serve/kvc2/test/kvc2test/check-flush-back.cpp b/csrc/balance_serve/kvc2/test/kvc2test/check-flush-back.cpp new file mode 100644 index 0000000..831f95a --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvc2test/check-flush-back.cpp @@ -0,0 +1,36 @@ +#include +#include "common.hpp" + +int main(int argc, char* argv[]) { + init(argc, argv); + spdlog::set_level(spdlog::level::debug); + config.gpu_cache_config->total_kvcache_pages = 12; + auto kvc2 = kvc2::create_kvc2(config); + kvc2->load(); + // #pragma omp parallel for + for (size_t ti = 0; ti < 2; ti++) { + SPDLOG_WARN("Test {}", ti); + auto [kcache, vcache] = kvc2->get_kvcache(); + std::mt19937 gen(ti + 123); + size_t total_page = 10; + TokenLength total_length = total_page * config.num_token_per_page; + auto tokens = random_ids(total_length, gen); + auto k1 = random_kvcache(total_page, gen); + auto v1 = random_kvcache(total_page, gen); + + { + std::promise> p; + kvc2->lookup_to_gpu_async(test_model_name, test_quant_type, tokens.data(), total_length, total_length, + [&p](std::shared_ptr h) { p.set_value(h); }); + auto fut = p.get_future(); + fut.wait(); + auto h = fut.get(); + assert(h->matched_length() == total_length); + size_t matched_block = h->matched_length() / config.num_token_per_page; + auto block_idx = h->get_gpu_block_idx(); + cmp_handle_gpu(block_idx, kcache, vcache, k1, v1, matched_block); + } + } + SPDLOG_CRITICAL("All Test Passed: {}", argv[0]); + return 0; +} diff --git a/csrc/balance_serve/kvc2/test/kvc2test/common.hpp b/csrc/balance_serve/kvc2/test/kvc2test/common.hpp new file mode 100644 index 0000000..29e37dd --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvc2test/common.hpp @@ -0,0 +1,233 @@ +/** + * @Description : + * @Author : Xie Weiyu + * @Date : 2024-11-22 06:02:41 + * @Version : 1.0.0 + * @LastEditors : Xie Weiyu + * @LastEditTime : 2024-12-11 07:34:10 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + **/ +#pragma once +#include +#include +#include "kvc2.h" +#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG +#define FMT_HEADER_ONLY +#include "spdlog/spdlog.h" + +using namespace kvc2; + +template +T* offset_by_bytes(T* t, size_t n) { + return reinterpret_cast(reinterpret_cast(t) + n); +} + +std::string FLAGS_disk_cache_path; + +kvc2::KVC2Config config; +kvc2::GPUPageCacheConfig qw25_7B_gpu_config{ + .gpu_only = false, + .gpu_devices_id = {0, 1}, + .layer_count = 28, + .total_kvcache_pages = 40, + .num_token_per_page = 256, + .num_k_heads = 4, + .k_head_dim = 896, + .full_kv_cache_on_each_gpu = false, + .k_cache_on = true, + .v_cache_on = true, + .tensor_type = torch::kBFloat16, + .num_streams_per_device = 4, +}; + +ModelName test_model_name = "Qwen2.5-7B-Instruct"; +QuantType test_quant_type = "FP16"; +CacheInfo test_cache_info{ + .model_name = test_model_name, + .is_key_cache = true, + .quant_type = test_quant_type, +}; + +void init(int argc, char* argv[]) { + if (argc != 2) { + fmt::print("Usage: {} \n", argv[0]); + exit(1); + } + load_quant_configs("./config/quant_configs.json"); + load_model_configs("./config/model_configs.json"); + + FLAGS_disk_cache_path = argv[1]; + if (FLAGS_disk_cache_path.empty()) { + fmt::print("disk_cache_path is empty\n"); + exit(1); + } + config.path = FLAGS_disk_cache_path; + config.config_path = "./config"; + config.gpu_cache_config = qw25_7B_gpu_config; +} + +data_block_ptr empty_block() { + auto re = new (std::align_val_t(4096)) std::byte[test_cache_info.element_size(config.num_token_per_page)]; + memset(re, 0, test_cache_info.element_size(config.num_token_per_page)); + return reinterpret_cast(re); +} + +data_block_ptr random_block(std::mt19937& gen) { + auto re = empty_block(); + uint64_t* d = (uint64_t*)re; + for (size_t i = 0; i < test_cache_info.element_size(config.num_token_per_page) / 8; i++) { + d[i] = gen(); + } + return re; +} +layer_data random_blocks(size_t block_count, size_t seed) { + std::mt19937 gen(seed); + layer_data re; + for (size_t i = 0; i < block_count; i++) { + re.push_back(random_block(gen)); + } + return re; +} + +layer_data empty_blocks(size_t block_count) { + layer_data re; + for (size_t i = 0; i < block_count; i++) { + re.push_back(empty_block()); + } + return re; +} + +void copy_kvcache(std::vector& from, std::vector& to, size_t block_start, size_t length) { + for (size_t i = 0; i < from.size(); i++) { + for (size_t j = 0; j < length; j++) { + to[i][block_start + j] = from[i][block_start + j]; + } + } +} + +std::vector random_kvcache(size_t block_count, std::mt19937& gen) { + std::vector re; + re.resize(test_cache_info.hidden_layer_count()); + fmt::print("Generating random kvcache, layer {}\n", test_cache_info.hidden_layer_count()); + std::vector gens; + for (size_t i = 0; i < test_cache_info.hidden_layer_count(); i++) { + gens.push_back(std::mt19937(gen())); + } +#pragma omp parallel for + for (size_t i = 0; i < test_cache_info.hidden_layer_count(); i++) { + re[i] = random_blocks(block_count, gens[i]()); + } + return re; +} + +std::vector empty_kvcache(size_t block_count) { + std::vector re; + re.resize(test_cache_info.hidden_layer_count()); + fmt::print("Generating empty kvcache, layer {}\n", test_cache_info.hidden_layer_count()); +#pragma omp parallel for + for (size_t i = 0; i < test_cache_info.hidden_layer_count(); i++) { + re[i] = empty_blocks(block_count); + } + return re; +} + +std::vector random_ids(size_t length, std::mt19937& gen) { + std::vector re; + for (size_t i = 0; i < length; i++) { + re.push_back(gen()); + } + return re; +} + +std::vector slice(std::vector& h1,size_t start,size_t end){ + std::vector re; + for(auto&l:h1){ + layer_data new_layer; + new_layer.insert(new_layer.end(),l.begin()+start,l.begin()+end); + re.push_back(new_layer); + } + return re; +} + +void cmp_handle_data(std::vector h1, std::vector h2, + std::optional blocks = std::nullopt) { + assert(h1.size() == h2.size()); + + for (size_t i = 0; i < h1.size(); i++) { + auto& b1 = h1[i]; + auto& b2 = h2[i]; + if (blocks.has_value() == false) { + assert(b1.size() == b2.size()); + } + int cmp_to = blocks.has_value() ? blocks.value() : b1.size(); + for (int j = 0; j < cmp_to; j++) { + auto e1 = reinterpret_cast(b1[j]); + auto e2 = reinterpret_cast(b2[j]); + assert(memcmp(e1, e2, test_cache_info.element_size(config.num_token_per_page)) == 0); + } + } + fmt::print("KVCacheHandle cmp ok\n"); +} + +void copy_gpu_cpu(std::vector& block_idx, std::vector& kcache, + std::vector& vcache, std::vector& k_cpu, std::vector& v_cpu, + size_t at) { + size_t gpu_count = config.gpu_cache_config->gpu_devices_id.size(); + size_t element_size_per_gpu = test_cache_info.element_size(config.num_token_per_page) / gpu_count; + + for (size_t layer = 0; layer < test_cache_info.hidden_layer_count(); layer++) { + for (size_t gpu_idx = 0; gpu_idx < gpu_count; gpu_idx++) { + { + auto kt = kcache[gpu_idx][layer][block_idx[at]].to(torch::kCPU); + void* src = kt.data_ptr(); + void* dst = offset_by_bytes(k_cpu[layer][at], gpu_idx * element_size_per_gpu); + memcpy(dst, src, element_size_per_gpu); + } + { + auto vt = vcache[gpu_idx][layer][block_idx[at]].to(torch::kCPU); + void* src = vt.data_ptr(); + void* dst = offset_by_bytes(v_cpu[layer][at], gpu_idx * element_size_per_gpu); + memcpy(dst, src, element_size_per_gpu); + } + } + } +} + +void copy_cpu_gpu(std::vector& block_idx, std::vector& kcache, + std::vector& vcache, std::vector& k_cpu, std::vector& v_cpu, + size_t at) { + size_t gpu_count = config.gpu_cache_config->gpu_devices_id.size(); + size_t element_size_per_gpu = test_cache_info.element_size(config.num_token_per_page) / gpu_count; + + for (size_t layer = 0; layer < test_cache_info.hidden_layer_count(); layer++) { + for (size_t gpu_idx = 0; gpu_idx < gpu_count; gpu_idx++) { + { + auto kt = kcache[gpu_idx][layer][block_idx[at]].to(torch::kCPU); + void* dst = kt.data_ptr(); + void* src = offset_by_bytes(k_cpu[layer][at], gpu_idx * element_size_per_gpu); + memcpy(dst, src, element_size_per_gpu); + kcache[gpu_idx][layer][block_idx[at]].copy_(kt); + } + { + auto vt = vcache[gpu_idx][layer][block_idx[at]].to(torch::kCPU); + void* dst = vt.data_ptr(); + void* src = offset_by_bytes(v_cpu[layer][at], gpu_idx * element_size_per_gpu); + memcpy(dst, src, element_size_per_gpu); + vcache[gpu_idx][layer][block_idx[at]].copy_(vt); + } + } + } +} + +void cmp_handle_gpu(std::vector& block_idx, std::vector& kcache, + std::vector& vcache, std::vector& k1, std::vector& v1, + size_t num_blocks) { + auto k_from_gpu = empty_kvcache(num_blocks); + auto v_from_gpu = empty_kvcache(num_blocks); + + for (size_t j = 0; j < std::min(block_idx.size(), num_blocks); j++) { + copy_gpu_cpu(block_idx, kcache, vcache, k_from_gpu, v_from_gpu, j); + } + cmp_handle_data(k1, k_from_gpu, num_blocks); + cmp_handle_data(v1, v_from_gpu, num_blocks); +} diff --git a/csrc/balance_serve/kvc2/test/kvc2test/flush-back.cpp b/csrc/balance_serve/kvc2/test/kvc2test/flush-back.cpp new file mode 100644 index 0000000..cd94b11 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvc2test/flush-back.cpp @@ -0,0 +1,57 @@ +#include +#include "common.hpp" + +int main(int argc, char* argv[]) { + init(argc, argv); + spdlog::set_level(spdlog::level::debug); + config.gpu_cache_config->total_kvcache_pages = 12; + auto kvc2 = kvc2::create_kvc2(config); + +// #pragma omp parallel for + for (size_t ti = 0; ti < 2; ti++) { + SPDLOG_WARN("Test {}",ti); + auto [kcache, vcache] = kvc2->get_kvcache(); + std::mt19937 gen(ti + 123); + size_t total_page = 10; + TokenLength total_length = total_page * config.num_token_per_page; + auto tokens = random_ids(total_length, gen); + TokenLength prompt_length = 3 * config.num_token_per_page; + auto k1 = random_kvcache(total_page, gen); + auto v1 = random_kvcache(total_page, gen); + + { + std::promise> p; + kvc2->lookup_to_gpu_async(test_model_name, test_quant_type, tokens.data(), prompt_length, total_length, + [&p](std::shared_ptr h) { p.set_value(h); }); + auto fut = p.get_future(); + fut.wait(); + auto h = fut.get(); + assert(h->matched_length() % config.num_token_per_page == 0); + size_t matched_block = h->matched_length() / config.num_token_per_page; + auto block_idx = h->get_gpu_block_idx(); + cmp_handle_gpu(block_idx, kcache, vcache, k1, v1, matched_block); + for (size_t at = matched_block; at < block_idx.size(); at++) { + copy_cpu_gpu(block_idx, kcache, vcache, k1, v1, at); + } + h->append_tokens(tokens.data(), total_length); + cmp_handle_gpu(block_idx, kcache, vcache, k1, v1, total_page); + } + + { + std::promise> p; + kvc2->lookup_to_gpu_async(test_model_name, test_quant_type, tokens.data(), total_length, total_length, + [&p](std::shared_ptr h) { p.set_value(h); }); + auto fut = p.get_future(); + fut.wait(); + auto h = fut.get(); + assert(h->matched_length() == total_length); + size_t matched_block = h->matched_length() / config.num_token_per_page; + auto block_idx = h->get_gpu_block_idx(); + cmp_handle_gpu(block_idx, kcache, vcache, k1, v1, matched_block); + } + } + kvc2->save(); + SPDLOG_CRITICAL("All Test Passed: {}", argv[0]); + + return 0; +} diff --git a/csrc/balance_serve/kvc2/test/kvc2test/lookup-alt-gpu.cpp b/csrc/balance_serve/kvc2/test/kvc2test/lookup-alt-gpu.cpp new file mode 100644 index 0000000..0ba7e75 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvc2test/lookup-alt-gpu.cpp @@ -0,0 +1,125 @@ +/** + * @Description : + * @Author : Xie Weiyu + * @Date : 2024-11-22 08:29:45 + * @Version : 1.0.0 + * @LastEditors : Xie Weiyu + * @LastEditTime : 2024-11-22 09:56:12 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + **/ +#include +#include "common.hpp" + +int main(int argc, char* argv[]) { + init(argc, argv); + spdlog::set_level(spdlog::level::trace); + auto kvc2 = kvc2::create_kvc2(config); + + std::mt19937 gen(123); + + std::vector> ids; + + std::vector> k, v; + for (size_t i = 0; i < 10; i++) { + ids.push_back(random_ids(1 * config.num_token_per_page, gen)); + k.push_back(random_kvcache(1, gen)); + v.push_back(random_kvcache(1, gen)); + kvc2->raw_insert(test_model_name, test_quant_type, ids[i].data(), ids[i].size(), k[i], v[i]); + } + + kvc2->debug(); + { + // all match + std::vector chunks; + std::vector lengths; + for (size_t i = 0; i < 10; i++) { + chunks.push_back(ids[i].data()); + lengths.push_back(ids[i].size()); + } + std::promise> p; + kvc2->lookup_alt_to_gpu_async(test_model_name, test_quant_type, chunks, lengths, 15 * config.num_token_per_page, + [&p](std::shared_ptr h) { p.set_value(h); }); + + auto fut = p.get_future(); + fut.wait(); + auto h = fut.get(); + auto hk = h->handle_data(true); + auto hv = h->handle_data(false); + + for (size_t i = 0; i < 10; i++) { + cmp_handle_data(slice(hk, i, i + 1), k[i], 1); + cmp_handle_data(slice(hv, i, i + 1), v[i], 1); + } + + auto block_idx = h->get_gpu_block_idx(); + auto [kcache, vcache] = kvc2->get_kvcache(); + for (size_t i = 0; i < 10; i++) { + std::vector blocks = {block_idx[i]}; + cmp_handle_gpu(blocks, kcache, vcache, k[i], v[i], 1); + } + } + + { + // no match in the middle + std::vector chunks; + std::vector lengths; + + std::vector> new_ids; + for (size_t i = 0; i < 10; i++) { + new_ids.push_back(random_ids(1 * config.num_token_per_page, gen)); + } + + for (size_t i = 0; i < 10; i++) { + if (i == 1 || i == 5 || i == 6) { + chunks.push_back(new_ids[i].data()); + } else { + chunks.push_back(ids[i].data()); + } + lengths.push_back(ids[i].size()); + } + + std::promise> p; + kvc2->lookup_alt_to_gpu_async(test_model_name, test_quant_type, chunks, lengths, 15 * config.num_token_per_page, + [&p](std::shared_ptr h) { p.set_value(h); }); + + auto fut = p.get_future(); + fut.wait(); + auto h = fut.get(); + auto statuses = h->matched_status(); + for (size_t i = 0; i < 10; i++) { + if (i == 1) { + assert(statuses[i] == MatchStatus::NotMatchExact); + } else if (i == 5 || i == 6) { + assert(statuses[i] == MatchStatus::NotMatchPartial); + } else if (i == 0) { + assert(statuses[i] == MatchStatus::Exact); + } else { + assert(statuses[i] == MatchStatus::Partial); + } + } + + auto hk = h->handle_data(true); + auto hv = h->handle_data(false); + + for (size_t i = 0; i < 10; i++) { + if (i == 1 || i == 5 || i == 6) { + } else { + cmp_handle_data(slice(hk, i, i + 1), k[i], 1); + cmp_handle_data(slice(hv, i, i + 1), v[i], 1); + } + } + + auto block_idx = h->get_gpu_block_idx(); + auto [kcache, vcache] = kvc2->get_kvcache(); + for (size_t i = 0; i < 10; i++) { + if (i == 1 || i == 5 || i == 6) { + } else { + std::vector blocks = {block_idx[i]}; + cmp_handle_gpu(blocks, kcache, vcache, k[i], v[i], 1); + } + } + } + + SPDLOG_CRITICAL("All Test Passed: {}", argv[0]); + return 0; +} diff --git a/csrc/balance_serve/kvc2/test/kvc2test/lookup-alt.cpp b/csrc/balance_serve/kvc2/test/kvc2test/lookup-alt.cpp new file mode 100644 index 0000000..0f60891 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvc2test/lookup-alt.cpp @@ -0,0 +1,97 @@ +/** + * @Description : + * @Author : Xie Weiyu + * @Date : 2024-11-22 08:29:45 + * @Version : 1.0.0 + * @LastEditors : Xie Weiyu + * @LastEditTime : 2024-11-22 09:56:12 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + **/ + +#include "common.hpp" + +int main(int argc, char* argv[]) { + init(argc, argv); + spdlog::set_level(spdlog::level::trace); + auto kvc2 = kvc2::create_kvc2(config); + + std::mt19937 gen(123); + + std::vector> ids; + + std::vector> k, v; + for (size_t i = 0; i < 10; i++) { + ids.push_back(random_ids(1 * config.num_token_per_page, gen)); + k.push_back(random_kvcache(1, gen)); + v.push_back(random_kvcache(1, gen)); + kvc2->raw_insert(test_model_name, test_quant_type, ids[i].data(), ids[i].size(), k[i], v[i]); + } + + kvc2->debug(); + { + // all match + std::vector chunks; + std::vector lengths; + for (size_t i = 0; i < 10; i++) { + chunks.push_back(ids[i].data()); + lengths.push_back(ids[i].size()); + } + + auto h = kvc2->lookup_alt(test_model_name, test_quant_type, chunks, lengths, 15 * config.num_token_per_page); + auto hk = h->handle_data(true); + auto hv = h->handle_data(false); + + for (size_t i = 0; i < 10; i++) { + cmp_handle_data(slice(hk, i, i + 1), k[i], 1); + cmp_handle_data(slice(hv, i, i + 1), v[i], 1); + } + } + + { + // no match in the middle + std::vector chunks; + std::vector lengths; + + std::vector> new_ids; + for (size_t i = 0; i < 10; i++) { + new_ids.push_back(random_ids(1 * config.num_token_per_page, gen)); + } + + for (size_t i = 0; i < 10; i++) { + if (i == 1 || i == 5 || i == 6) { + chunks.push_back(new_ids[i].data()); + } else { + chunks.push_back(ids[i].data()); + } + lengths.push_back(ids[i].size()); + } + + auto h = kvc2->lookup_alt(test_model_name, test_quant_type, chunks, lengths, 15 * config.num_token_per_page); + auto statuses = h->matched_status(); + for (size_t i = 0; i < 10; i++) { + if (i == 1) { + assert(statuses[i] == MatchStatus::NotMatchExact); + } else if (i == 5 || i == 6) { + assert(statuses[i] == MatchStatus::NotMatchPartial); + } else if (i == 0) { + assert(statuses[i] == MatchStatus::Exact); + } else { + assert(statuses[i] == MatchStatus::Partial); + } + } + + auto hk = h->handle_data(true); + auto hv = h->handle_data(false); + + for (size_t i = 0; i < 10; i++) { + if (i == 1 || i == 5 || i == 6) { + } else { + cmp_handle_data(slice(hk, i, i + 1), k[i], 1); + cmp_handle_data(slice(hv, i, i + 1), v[i], 1); + } + } + } + + SPDLOG_CRITICAL("All Test Passed: {}", argv[0]); + return 0; +} diff --git a/csrc/balance_serve/kvc2/test/kvc2test/lookup-gpu-async.cpp b/csrc/balance_serve/kvc2/test/kvc2test/lookup-gpu-async.cpp new file mode 100644 index 0000000..1c641c7 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvc2test/lookup-gpu-async.cpp @@ -0,0 +1,49 @@ +/** + * @Description : + * @Author : Xie Weiyu + * @Date : 2024-11-22 09:52:48 + * @Version : 1.0.0 + * @LastEditors : Xie Weiyu + * @LastEditTime : 2024-11-25 07:51:09 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + **/ + +#include +#include "common.hpp" + +int main(int argc, char* argv[]) { + init(argc, argv); + spdlog::set_level(spdlog::level::debug); + auto kvc2 = kvc2::create_kvc2(config); + + std::mt19937 gen(123); + auto ids1 = random_ids(10 * config.num_token_per_page, gen); + auto k1 = random_kvcache(10, gen); + auto v1 = random_kvcache(10, gen); + + kvc2->raw_insert(test_model_name, test_quant_type, ids1.data(), ids1.size(), k1, v1); + +// complete same +#pragma omp parallel for + for (size_t ti = 0; ti < 3; ti++) { + std::promise> p; + kvc2->lookup_to_gpu_async(test_model_name, test_quant_type, ids1.data(), ids1.size(), + ids1.size() + 2 * config.num_token_per_page, + [&p](std::shared_ptr h) { p.set_value(h); }); + auto fut = p.get_future(); + fut.wait(); + auto h = fut.get(); + auto k = h->handle_data(true); + auto v = h->handle_data(false); + cmp_handle_data(k1, k, 10); + cmp_handle_data(v1, v, 10); + + auto block_idx = h->get_gpu_block_idx(); + auto [kcache, vcache] = kvc2->get_kvcache(); + + cmp_handle_gpu(block_idx, kcache, vcache, k1, v1, 10); + } + + SPDLOG_CRITICAL("All Test Passed: {}", argv[0]); + return 0; +} diff --git a/csrc/balance_serve/kvc2/test/kvc2test/lookup-gpu-mt-without-vcache.cpp b/csrc/balance_serve/kvc2/test/kvc2test/lookup-gpu-mt-without-vcache.cpp new file mode 100644 index 0000000..a60de53 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvc2test/lookup-gpu-mt-without-vcache.cpp @@ -0,0 +1,61 @@ +/** + * @Description : + * @Author : Xie Weiyu + * @Date : 2024-11-22 09:52:48 + * @Version : 1.0.0 + * @LastEditors : Xie Weiyu + * @LastEditTime : 2024-11-25 07:51:09 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + **/ + +#include "common.hpp" + +int main(int argc, char* argv[]) { + qw25_7B_gpu_config.v_cache_on = false; + config.gpu_cache_config = qw25_7B_gpu_config; + config.v_cache_on = false; + + init(argc, argv); + spdlog::set_level(spdlog::level::debug); + auto kvc2 = kvc2::create_kvc2(config); + + std::mt19937 gen(123); + auto ids1 = random_ids(10 * config.num_token_per_page, gen); + auto k1 = random_kvcache(10, gen); + + kvc2->raw_insert(test_model_name, test_quant_type, ids1.data(), ids1.size(), k1, {}); + +// complete same +#pragma omp parallel for + for (size_t ti = 0; ti < 3; ti++) { + auto h = kvc2->lookup_to_gpu(test_model_name, test_quant_type, ids1.data(), ids1.size(), + ids1.size() + 2 * config.num_token_per_page); + auto k = h->handle_data(true); + cmp_handle_data(k1, k, 10); + + auto block_idx = h->get_gpu_block_idx(); + auto [kcache, vcache] = kvc2->get_kvcache(); + + auto k_from_gpu = empty_kvcache(15); + + size_t gpu_count = config.gpu_cache_config->gpu_devices_id.size(); + size_t element_size_per_gpu = test_cache_info.element_size(config.num_token_per_page) / gpu_count; + for (size_t i = 0; i < k_from_gpu.size(); i++) { + for (size_t j = 0; j < block_idx.size(); j++) { + size_t b_idx = block_idx[j]; + for (size_t gpu_idx = 0; gpu_idx < gpu_count; gpu_idx++) { + { + auto kt = kcache[gpu_idx][i][b_idx].to(torch::kCPU); + void* src = kt.data_ptr(); + void* dst = offset_by_bytes(k_from_gpu[i][j], gpu_idx * element_size_per_gpu); + memcpy(dst, src, element_size_per_gpu); + } + } + } + } + cmp_handle_data(k1, k_from_gpu, 10); + } + + SPDLOG_CRITICAL("All Test Passed: {}", argv[0]); + return 0; +} diff --git a/csrc/balance_serve/kvc2/test/kvc2test/lookup-gpu-mt.cpp b/csrc/balance_serve/kvc2/test/kvc2test/lookup-gpu-mt.cpp new file mode 100644 index 0000000..4179ccc --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvc2test/lookup-gpu-mt.cpp @@ -0,0 +1,68 @@ +/** + * @Description : + * @Author : Xie Weiyu + * @Date : 2024-11-22 09:52:48 + * @Version : 1.0.0 + * @LastEditors : Xie Weiyu + * @LastEditTime : 2024-11-25 07:51:09 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + **/ + +#include "common.hpp" + +int main(int argc, char* argv[]) { + init(argc, argv); + spdlog::set_level(spdlog::level::debug); + auto kvc2 = kvc2::create_kvc2(config); + + std::mt19937 gen(123); + auto ids1 = random_ids(10 * config.num_token_per_page, gen); + auto k1 = random_kvcache(10, gen); + auto v1 = random_kvcache(10, gen); + + kvc2->raw_insert(test_model_name, test_quant_type, ids1.data(), ids1.size(), k1, v1); + +// complete same +#pragma omp parallel for + for (size_t ti = 0; ti < 3; ti++) { + auto h = kvc2->lookup_to_gpu(test_model_name, test_quant_type, ids1.data(), ids1.size(), + ids1.size() + 2 * config.num_token_per_page); + auto k = h->handle_data(true); + auto v = h->handle_data(false); + cmp_handle_data(k1, k, 10); + cmp_handle_data(v1, v, 10); + + auto block_idx = h->get_gpu_block_idx(); + auto [kcache, vcache] = kvc2->get_kvcache(); + + auto k_from_gpu = empty_kvcache(15); + auto v_from_gpu = empty_kvcache(15); + + size_t gpu_count = config.gpu_cache_config->gpu_devices_id.size(); + size_t element_size_per_gpu = test_cache_info.element_size(config.num_token_per_page) / gpu_count; + for (size_t i = 0; i < k_from_gpu.size(); i++) { + for (size_t j = 0; j < block_idx.size(); j++) { + size_t b_idx = block_idx[j]; + for (size_t gpu_idx = 0; gpu_idx < gpu_count; gpu_idx++) { + { + auto kt = kcache[gpu_idx][i][b_idx].to(torch::kCPU); + void* src = kt.data_ptr(); + void* dst = offset_by_bytes(k_from_gpu[i][j], gpu_idx * element_size_per_gpu); + memcpy(dst, src, element_size_per_gpu); + } + { + auto vt = vcache[gpu_idx][i][b_idx].to(torch::kCPU); + void* src = vt.data_ptr(); + void* dst = offset_by_bytes(v_from_gpu[i][j], gpu_idx * element_size_per_gpu); + memcpy(dst, src, element_size_per_gpu); + } + } + } + } + cmp_handle_data(k1, k_from_gpu, 10); + cmp_handle_data(v1, v_from_gpu, 10); + } + + SPDLOG_CRITICAL("All Test Passed: {}", argv[0]); + return 0; +} diff --git a/csrc/balance_serve/kvc2/test/kvc2test/lookup-gpu.cpp b/csrc/balance_serve/kvc2/test/kvc2test/lookup-gpu.cpp new file mode 100644 index 0000000..36ba80f --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvc2test/lookup-gpu.cpp @@ -0,0 +1,160 @@ +/** + * @Description : + * @Author : Xie Weiyu + * @Date : 2024-11-22 09:52:48 + * @Version : 1.0.0 + * @LastEditors : Xie Weiyu + * @LastEditTime : 2024-11-25 08:38:33 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + **/ + +#include "common.hpp" + +int main(int argc, char* argv[]) { + init(argc, argv); + spdlog::set_level(spdlog::level::debug); + auto kvc2 = kvc2::create_kvc2(config); + + std::mt19937 gen(123); + auto ids1 = random_ids(10 * config.num_token_per_page, gen); + auto k1 = random_kvcache(10, gen); + auto v1 = random_kvcache(10, gen); + + kvc2->raw_insert(test_model_name, test_quant_type, ids1.data(), ids1.size(), k1, v1); + + // complete same + { + auto h = kvc2->lookup_to_gpu(test_model_name, test_quant_type, ids1.data(), ids1.size(), + ids1.size() + 5 * config.num_token_per_page); + auto k = h->handle_data(true); + auto v = h->handle_data(false); + cmp_handle_data(k1, k, 10); + cmp_handle_data(v1, v, 10); + + auto block_idx = h->get_gpu_block_idx(); + auto [kcache, vcache] = kvc2->get_kvcache(); + + auto k_from_gpu = empty_kvcache(15); + auto v_from_gpu = empty_kvcache(15); + + size_t gpu_count = config.gpu_cache_config->gpu_devices_id.size(); + size_t element_size_per_gpu = test_cache_info.element_size(config.num_token_per_page) / gpu_count; + for (size_t i = 0; i < k_from_gpu.size(); i++) { + for (size_t j = 0; j < block_idx.size(); j++) { + size_t b_idx = block_idx[j]; + for (size_t gpu_idx = 0; gpu_idx < gpu_count; gpu_idx++) { + { + auto kt = kcache[gpu_idx][i][b_idx].to(torch::kCPU); + void* src = kt.data_ptr(); + void* dst = offset_by_bytes(k_from_gpu[i][j], gpu_idx * element_size_per_gpu); + memcpy(dst, src, element_size_per_gpu); + } + { + auto vt = vcache[gpu_idx][i][b_idx].to(torch::kCPU); + void* src = vt.data_ptr(); + void* dst = offset_by_bytes(v_from_gpu[i][j], gpu_idx * element_size_per_gpu); + memcpy(dst, src, element_size_per_gpu); + } + } + } + } + cmp_handle_data(k1, k_from_gpu, 10); + cmp_handle_data(v1, v_from_gpu, 10); + } + + // prefix and evict + { + auto h = kvc2->lookup_to_gpu(test_model_name, test_quant_type, ids1.data(), config.num_token_per_page * 3, + config.gpu_cache_config->total_kvcache_pages * config.num_token_per_page); + auto k = h->handle_data(true); + auto v = h->handle_data(false); + cmp_handle_data(k1, k, 3); + cmp_handle_data(v1, v, 3); + + auto block_idx = h->get_gpu_block_idx(); + auto [kcache, vcache] = kvc2->get_kvcache(); + + auto k_from_gpu = empty_kvcache(3); + auto v_from_gpu = empty_kvcache(3); + + size_t gpu_count = config.gpu_cache_config->gpu_devices_id.size(); + size_t element_size_per_gpu = test_cache_info.element_size(config.num_token_per_page) / gpu_count; + for (size_t i = 0; i < k_from_gpu.size(); i++) { + for (size_t j = 0; j < 3; j++) { + size_t b_idx = block_idx[j]; + for (size_t gpu_idx = 0; gpu_idx < gpu_count; gpu_idx++) { + { + auto kt = kcache[gpu_idx][i][b_idx].to(torch::kCPU); + void* src = kt.data_ptr(); + void* dst = offset_by_bytes(k_from_gpu[i][j], gpu_idx * element_size_per_gpu); + memcpy(dst, src, element_size_per_gpu); + } + { + auto vt = vcache[gpu_idx][i][b_idx].to(torch::kCPU); + void* src = vt.data_ptr(); + void* dst = offset_by_bytes(v_from_gpu[i][j], gpu_idx * element_size_per_gpu); + memcpy(dst, src, element_size_per_gpu); + } + } + } + } + cmp_handle_data(k1, k_from_gpu, 3); + cmp_handle_data(v1, v_from_gpu, 3); + } + + // // complete prefix + // { + // std::vector ids2(ids1.begin(), ids1.begin() + 3 * config.num_token_per_page); + // auto h = kvc2->lookup(test_model_name, test_quant_type, ids2.data(), ids2.size(), + // ids2.size() + 3 * config.num_token_per_page); + // auto k = h->handle_data(true); + // auto v = h->handle_data(false); + // cmp_handle_data(k1, k, 3); + // cmp_handle_data(v1, v, 3); + // } + + // // common prefix + // { + // std::vector ids2(ids1.begin(), ids1.begin() + 3 * config.num_token_per_page); + // auto rids = random_ids(config.num_token_per_page * 2 + config.num_token_per_page / 2, gen); + // ids2.insert(ids2.end(), rids.begin(), rids.end()); + + // auto h = kvc2->lookup(test_model_name, test_quant_type, ids2.data(), ids2.size(), ids2.size()); + // auto k = h->handle_data(true); + // auto v = h->handle_data(false); + // cmp_handle_data(k1, k, 3); + // cmp_handle_data(v1, v, 3); + // } + + // // no prefix + // { + // std::vector ids2 = random_ids(config.num_token_per_page, gen); + // auto h = kvc2->lookup(test_model_name, test_quant_type, ids2.data(), ids2.size(), ids2.size()); + // assert(h->matched_length() == 0); + // } + + // // insert partly new + // auto k2 = random_kvcache(10, gen); + // auto v2 = random_kvcache(10, gen); + // copy_kvcache(k1, k2, 0, 5); + // copy_kvcache(v1, v2, 0, 5); + // auto ids2 = random_ids(10 * config.num_token_per_page, gen); + // for (size_t i = 0; i < 5 * config.num_token_per_page; i++) { + // ids2[i] = ids1[i]; + // } + // kvc2->raw_insert(test_model_name, test_quant_type, ids2.data(), ids2.size(), k2, v2); + + // // read new part + // { + // std::vector ids(ids2.begin(), ids2.begin() + 7 * config.num_token_per_page); + // auto h = kvc2->lookup(test_model_name, test_quant_type, ids.data(), ids.size(), + // ids.size() + 7 * config.num_token_per_page); + // auto k = h->handle_data(true); + // auto v = h->handle_data(false); + // cmp_handle_data(k, k2, 7); + // cmp_handle_data(v, v2, 7); + // } + + SPDLOG_CRITICAL("All Test Passed: {}", argv[0]); + return 0; +} diff --git a/csrc/balance_serve/kvc2/test/kvc2test/lookup-mt.cpp b/csrc/balance_serve/kvc2/test/kvc2test/lookup-mt.cpp new file mode 100644 index 0000000..323485b --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvc2test/lookup-mt.cpp @@ -0,0 +1,103 @@ +/** + * @Description : + * @Author : Xie Weiyu + * @Date : 2024-11-22 08:48:40 + * @Version : 1.0.0 + * @LastEditors : Xie Weiyu + * @LastEditTime : 2024-11-22 09:53:06 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + **/ + +#include "common.hpp" + +template +void test_multi(F f) { + std::vector threads; + for (size_t i = 0; i < 10; i++) { + threads.push_back([f]() { f(); }); + } + for (auto& t : threads) { + t.join(); + } +} + +int main(int argc, char* argv[]) { + init(argc, argv); + spdlog::set_level(spdlog::level::debug); + auto kvc2 = kvc2::create_kvc2(config); + + std::mt19937 gen(123); + auto ids1 = random_ids(3 * config.num_token_per_page, gen); + auto k1 = random_kvcache(3, gen); + auto v1 = random_kvcache(3, gen); + + kvc2->raw_insert(test_model_name, test_quant_type, ids1.data(), ids1.size(), k1, v1); + + // complete same + { +#pragma omp parallel for + for (size_t i = 0; i < 10; i++) { + auto h = kvc2->lookup(test_model_name, test_quant_type, ids1.data(), ids1.size(), + ids1.size() + 10 * config.num_token_per_page); + if (h == nullptr) { + SPDLOG_WARN("Thread[{}]: h is nullptr", i); + } else { + auto k = h->handle_data(true); + auto v = h->handle_data(false); + cmp_handle_data(k1, k, 3); + cmp_handle_data(v1, v, 3); + } + } + } + + // // complete prefix + // { + // std::vector ids2(ids1.begin(), ids1.begin() + 3 * config.num_token_per_page); + // auto h = kvc2->lookup(test_model_name, test_quant_type, ids2.data(), ids2.size(), ids2.size() + 3 * + // config.num_token_per_page); auto k = h->handle_data(true); auto v = h->handle_data(false); cmp_handle_data(k1, + // k, 3); cmp_handle_data(v1, v, 3); + // } + + // // common prefix + // { + // std::vector ids2(ids1.begin(), ids1.begin() + 3 * config.num_token_per_page); + // auto rids = random_ids(config.num_token_per_page * 2 + config.num_token_per_page / 2, gen); + // ids2.insert(ids2.end(), rids.begin(), rids.end()); + + // auto h = kvc2->lookup(test_model_name, test_quant_type, ids2.data(), ids2.size(), ids2.size()); + // auto k = h->handle_data(true); + // auto v = h->handle_data(false); + // cmp_handle_data(k1, k, 3); + // cmp_handle_data(v1, v, 3); + // } + + // // no prefix + // { + // std::vector ids2 = random_ids(config.num_token_per_page, gen); + // auto h = kvc2->lookup(test_model_name, test_quant_type, ids2.data(), ids2.size(), ids2.size()); + // assert(h->matched_length() == 0); + // } + + // // insert partly new + // auto k2 = random_kvcache(10, gen); + // auto v2 = random_kvcache(10, gen); + // copy_kvcache(k1, k2, 0, 5); + // copy_kvcache(v1, v2, 0, 5); + // auto ids2 = random_ids(10 * config.num_token_per_page, gen); + // for (size_t i = 0; i < 5 * config.num_token_per_page; i++) { + // ids2[i] = ids1[i]; + // } + // kvc2->raw_insert(test_model_name, test_quant_type, ids2.data(), ids2.size(), k2, v2); + + // // read new part + // { + // std::vector ids(ids2.begin(), ids2.begin() + 7 * config.num_token_per_page); + // auto h = kvc2->lookup(test_model_name, test_quant_type, ids.data(), ids.size(), ids.size() + 7 * + // config.num_token_per_page); auto k = h->handle_data(true); auto v = h->handle_data(false); cmp_handle_data(k, + // k2, 7); cmp_handle_data(v, v2, 7); + // } + kvc2->debug(); + + SPDLOG_CRITICAL("All Test Passed: {}", argv[0]); + return 0; +} diff --git a/csrc/balance_serve/kvc2/test/kvc2test/lookup-without-vcache.cpp b/csrc/balance_serve/kvc2/test/kvc2test/lookup-without-vcache.cpp new file mode 100644 index 0000000..febf5a8 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvc2test/lookup-without-vcache.cpp @@ -0,0 +1,85 @@ +/** + * @Description : + * @Author : Xie Weiyu + * @Date : 2024-11-22 08:29:45 + * @Version : 1.0.0 + * @LastEditors : Xie Weiyu + * @LastEditTime : 2024-11-22 09:56:12 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + **/ + +#include "common.hpp" + +int main(int argc, char* argv[]) { + + qw25_7B_gpu_config.v_cache_on = false; + config.gpu_cache_config = qw25_7B_gpu_config; + config.v_cache_on = false; + + init(argc, argv); + spdlog::set_level(spdlog::level::debug); + auto kvc2 = kvc2::create_kvc2(config); + + std::mt19937 gen(123); + auto ids1 = random_ids(10 * config.num_token_per_page, gen); + auto k1 = random_kvcache(10, gen); + // auto v1 = random_kvcache(10, gen); + + kvc2->raw_insert(test_model_name, test_quant_type, ids1.data(), ids1.size(), k1, {}); + + // complete same + { + auto h = kvc2->lookup(test_model_name, test_quant_type, ids1.data(), ids1.size(), + ids1.size() + 10 * config.num_token_per_page); + auto k = h->handle_data(true); + cmp_handle_data(k1, k, 10); + } + + // complete prefix + { + std::vector ids2(ids1.begin(), ids1.begin() + 3 * config.num_token_per_page); + auto h = kvc2->lookup(test_model_name, test_quant_type, ids2.data(), ids2.size(), + ids2.size() + 3 * config.num_token_per_page); + auto k = h->handle_data(true); + cmp_handle_data(k1, k, 3); + } + + // common prefix + { + std::vector ids2(ids1.begin(), ids1.begin() + 3 * config.num_token_per_page); + auto rids = random_ids(config.num_token_per_page * 2 + config.num_token_per_page / 2, gen); + ids2.insert(ids2.end(), rids.begin(), rids.end()); + + auto h = kvc2->lookup(test_model_name, test_quant_type, ids2.data(), ids2.size(), ids2.size()); + auto k = h->handle_data(true); + cmp_handle_data(k1, k, 3); + } + + // no prefix + { + std::vector ids2 = random_ids(config.num_token_per_page, gen); + auto h = kvc2->lookup(test_model_name, test_quant_type, ids2.data(), ids2.size(), ids2.size()); + assert(h->matched_length() == 0); + } + + // insert partly new + auto k2 = random_kvcache(10, gen); + copy_kvcache(k1, k2, 0, 5); + auto ids2 = random_ids(10 * config.num_token_per_page, gen); + for (size_t i = 0; i < 5 * config.num_token_per_page; i++) { + ids2[i] = ids1[i]; + } + kvc2->raw_insert(test_model_name, test_quant_type, ids2.data(), ids2.size(), k2, {}); + + // read new part + { + std::vector ids(ids2.begin(), ids2.begin() + 7 * config.num_token_per_page); + auto h = kvc2->lookup(test_model_name, test_quant_type, ids.data(), ids.size(), + ids.size() + 7 * config.num_token_per_page); + auto k = h->handle_data(true); + cmp_handle_data(k, k2, 7); + } + + SPDLOG_CRITICAL("All Test Passed: {}", argv[0]); + return 0; +} diff --git a/csrc/balance_serve/kvc2/test/kvc2test/lookup.cpp b/csrc/balance_serve/kvc2/test/kvc2test/lookup.cpp new file mode 100644 index 0000000..ec470f2 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvc2test/lookup.cpp @@ -0,0 +1,90 @@ +/** + * @Description : + * @Author : Xie Weiyu + * @Date : 2024-11-22 08:29:45 + * @Version : 1.0.0 + * @LastEditors : Xie Weiyu + * @LastEditTime : 2024-11-22 09:56:12 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + **/ + +#include "common.hpp" + +int main(int argc, char* argv[]) { + init(argc, argv); + spdlog::set_level(spdlog::level::debug); + auto kvc2 = kvc2::create_kvc2(config); + + std::mt19937 gen(123); + auto ids1 = random_ids(10 * config.num_token_per_page, gen); + auto k1 = random_kvcache(10, gen); + auto v1 = random_kvcache(10, gen); + + kvc2->raw_insert(test_model_name, test_quant_type, ids1.data(), ids1.size(), k1, v1); + + // complete same + { + auto h = kvc2->lookup(test_model_name, test_quant_type, ids1.data(), ids1.size(), + ids1.size() + 10 * config.num_token_per_page); + auto k = h->handle_data(true); + auto v = h->handle_data(false); + cmp_handle_data(k1, k, 10); + cmp_handle_data(v1, v, 10); + } + + // complete prefix + { + std::vector ids2(ids1.begin(), ids1.begin() + 3 * config.num_token_per_page); + auto h = kvc2->lookup(test_model_name, test_quant_type, ids2.data(), ids2.size(), + ids2.size() + 3 * config.num_token_per_page); + auto k = h->handle_data(true); + auto v = h->handle_data(false); + cmp_handle_data(k1, k, 3); + cmp_handle_data(v1, v, 3); + } + + // common prefix + { + std::vector ids2(ids1.begin(), ids1.begin() + 3 * config.num_token_per_page); + auto rids = random_ids(config.num_token_per_page * 2 + config.num_token_per_page / 2, gen); + ids2.insert(ids2.end(), rids.begin(), rids.end()); + + auto h = kvc2->lookup(test_model_name, test_quant_type, ids2.data(), ids2.size(), ids2.size()); + auto k = h->handle_data(true); + auto v = h->handle_data(false); + cmp_handle_data(k1, k, 3); + cmp_handle_data(v1, v, 3); + } + + // no prefix + { + std::vector ids2 = random_ids(config.num_token_per_page, gen); + auto h = kvc2->lookup(test_model_name, test_quant_type, ids2.data(), ids2.size(), ids2.size()); + assert(h->matched_length() == 0); + } + + // insert partly new + auto k2 = random_kvcache(10, gen); + auto v2 = random_kvcache(10, gen); + copy_kvcache(k1, k2, 0, 5); + copy_kvcache(v1, v2, 0, 5); + auto ids2 = random_ids(10 * config.num_token_per_page, gen); + for (size_t i = 0; i < 5 * config.num_token_per_page; i++) { + ids2[i] = ids1[i]; + } + kvc2->raw_insert(test_model_name, test_quant_type, ids2.data(), ids2.size(), k2, v2); + + // read new part + { + std::vector ids(ids2.begin(), ids2.begin() + 7 * config.num_token_per_page); + auto h = kvc2->lookup(test_model_name, test_quant_type, ids.data(), ids.size(), + ids.size() + 7 * config.num_token_per_page); + auto k = h->handle_data(true); + auto v = h->handle_data(false); + cmp_handle_data(k, k2, 7); + cmp_handle_data(v, v2, 7); + } + + SPDLOG_CRITICAL("All Test Passed: {}", argv[0]); + return 0; +} diff --git a/csrc/balance_serve/kvc2/test/kvc2test/raw_insert_read.cpp b/csrc/balance_serve/kvc2/test/kvc2test/raw_insert_read.cpp new file mode 100644 index 0000000..9191fd1 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvc2test/raw_insert_read.cpp @@ -0,0 +1,99 @@ +/** + * @Description : + * @Author : Xie Weiyu + * @Date : 2024-11-22 06:00:16 + * @Version : 1.0.0 + * @LastEditors : Xie Weiyu + * @LastEditTime : 2024-11-22 07:30:46 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + **/ + +#include "common.hpp" + +int main(int argc, char* argv[]) { + init(argc, argv); + spdlog::set_level(spdlog::level::debug); + auto kvc2 = kvc2::create_kvc2(config); + + std::mt19937 gen(123); + auto ids1 = random_ids(10 * config.num_token_per_page, gen); + auto k1 = random_kvcache(10, gen); + auto v1 = random_kvcache(10, gen); + + kvc2->raw_insert(test_model_name, test_quant_type, ids1.data(), ids1.size(), k1, v1); + + // complete same + { + auto k2 = empty_kvcache(10); + auto v2 = empty_kvcache(10); + auto l2 = kvc2->raw_read(test_model_name, test_quant_type, ids1.data(), ids1.size(), k2, v2); + assert(l2 == ids1.size()); + + cmp_handle_data(k1, k2); + cmp_handle_data(v1, v2); + } + + // complete prefix + { + auto k2 = empty_kvcache(10); + auto v2 = empty_kvcache(10); + std::vector ids2 = std::vector(ids1.begin(), ids1.begin() + 3 * config.num_token_per_page); + auto l2 = kvc2->raw_read(test_model_name, test_quant_type, ids2.data(), ids2.size(), k2, v2); + assert(l2 == 3 * config.num_token_per_page); + + cmp_handle_data(k1, k2, 3); + cmp_handle_data(v1, v2, 3); + } + + // common prefix + { + auto k2 = empty_kvcache(10); + auto v2 = empty_kvcache(10); + std::vector ids2 = std::vector(ids1.begin(), ids1.begin() + 3 * config.num_token_per_page); + auto rids = random_ids(config.num_token_per_page * 2 + config.num_token_per_page / 2, gen); + ids2.insert(ids2.end(), rids.begin(), rids.end()); + + auto l2 = kvc2->raw_read(test_model_name, test_quant_type, ids2.data(), ids2.size(), k2, v2); + assert(l2 == 3 * config.num_token_per_page); + + cmp_handle_data(k1, k2, 3); + cmp_handle_data(v1, v2, 3); + } + + // no prefix + { + auto k2 = empty_kvcache(1); + auto v2 = empty_kvcache(1); + std::vector ids2 = random_ids(config.num_token_per_page, gen); + auto l2 = kvc2->raw_read(test_model_name, test_quant_type, ids2.data(), ids2.size(), k2, v2); + assert(l2 == 0); + } + + // insert partly new + auto k2 = random_kvcache(10, gen); + auto v2 = random_kvcache(10, gen); + copy_kvcache(k1, k2, 0, 5); + copy_kvcache(v1, v2, 0, 5); + auto ids2 = random_ids(10 * config.num_token_per_page, gen); + for (size_t i = 0; i < 5 * config.num_token_per_page; i++) { + ids2[i] = ids1[i]; + } + kvc2->raw_insert(test_model_name, test_quant_type, ids2.data(), ids2.size(), k2, v2); + + // read new part + { + auto k = empty_kvcache(10); + auto v = empty_kvcache(10); + std::vector ids = std::vector(ids2.begin(), ids2.begin() + 7 * config.num_token_per_page); + + auto l = kvc2->raw_read(test_model_name, test_quant_type, ids.data(), ids.size(), k, v); + assert(l == 7 * config.num_token_per_page); + + cmp_handle_data(k, k2, 7); + cmp_handle_data(v, v2, 7); + } + + SPDLOG_CRITICAL("All Test Passed: {}", argv[0]); + + return 0; +} diff --git a/csrc/balance_serve/kvc2/test/kvcache_disk_insert_read_test.cpp b/csrc/balance_serve/kvc2/test/kvcache_disk_insert_read_test.cpp new file mode 100644 index 0000000..a3a38ab --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvcache_disk_insert_read_test.cpp @@ -0,0 +1,87 @@ +#include "kvcache_test_utils.cpp" + +int main(int argc, char* argv[]) { + parse_and_check(argc, argv); + spdlog::set_level(spdlog::level::debug); + std::mt19937 gen(123); + + KVC2 kvc2(FLAGS_disk_cache_path); + // auto io = kvc2.io_dealer->start_io_thread(); + kvc2.io_dealer->start_io_thread().detach(); + + auto h1 = random_kvcache(qwen_cache_info, 10, gen); + h1.ids = random_ids(10 * BlockLength, gen); + kvc2.raw_insert(h1); + + // complete same + { + auto h2 = empty_kvcache(qwen_cache_info, 10); + h2.ids = h1.ids; + kvc2.raw_read(h2); + assert(static_cast(h2.match.match_length) == h1.ids.size()); + + cmp_handle_data(h1, h2); + } + + // complete prefix + { + auto h2 = empty_kvcache(qwen_cache_info, 10); + + h2.ids = std::vector(h1.ids.begin(), h1.ids.begin() + 3 * BlockLength); + kvc2.raw_read(h2); + assert(h2.match.match_length == 3 * BlockLength); + + cmp_handle_data(h1, h2, 3); + } + + // common prefix + { + auto h2 = empty_kvcache(qwen_cache_info, 10); + + h2.ids = std::vector(h1.ids.begin(), h1.ids.begin() + 5 * BlockLength); + auto rids = random_ids(BlockLength * 2 + BlockLength / 2, gen); + h2.ids.insert(h2.ids.end(), rids.begin(), rids.end()); + + kvc2.raw_read(h2); + assert(h2.match.match_length == 5 * BlockLength); + + cmp_handle_data(h1, h2, 5); + } + + // no prefix + { + auto h2 = empty_kvcache(qwen_cache_info, 10); + + h2.ids = random_ids(10 * BlockLength, gen); + + kvc2.raw_read(h2); + assert(h2.match.match_length == 0); + } + + // insert partly new + auto h2 = random_kvcache(qwen_cache_info, 10, gen); + copy_kvcache(h1, h2, 0, 5); + h2.ids = random_ids(10 * BlockLength, gen); + for (size_t i = 0; i < 5 * BlockLength; i++) { + h2.ids[i] = h1.ids[i]; + } + kvc2.raw_insert(h2); + + // read new part + { + auto h = empty_kvcache(qwen_cache_info, 10); + h.ids = std::vector(h2.ids.begin(), h2.ids.begin() + 7 * BlockLength); + h.ids.push_back(123); + + kvc2.raw_read(h); + assert(h.match.match_length == 7 * BlockLength); + cmp_handle_data(h, h2, 7); + } + + kvc2.tree->debug(); + kvc2.io_dealer->stop(); + // io.join(); + + SPDLOG_WARN("{} Test Passed", __FILE__); + return 0; +} \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/test/kvcache_mem_eviction_test.cpp b/csrc/balance_serve/kvc2/test/kvcache_mem_eviction_test.cpp new file mode 100644 index 0000000..70f6987 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvcache_mem_eviction_test.cpp @@ -0,0 +1,52 @@ +#include "kvcache_test_utils.cpp" + +int main(int argc, char* argv[]) { + parse_and_check(argc, argv); + spdlog::set_level(spdlog::level::debug); + std::mt19937 gen(123); + + KVC2 kvc2(FLAGS_disk_cache_path); + auto io = kvc2.io_dealer->start_io_thread(); + + SPDLOG_WARN("Insert 10 x 10 KVCache"); + std::vector handles(10); + for (int i = 0; i < 10; i++) { + handles[i] = random_kvcache(qwen_cache_info, 10, gen); + auto& h1 = handles[i]; + h1.ids = random_ids(10 * BlockLength, gen); + kvc2.raw_insert(h1); + } + + SPDLOG_WARN("Cache Eviction Test"); + { + for (int i = 0; i < 10; i++) { + auto& h = handles[i]; + SPDLOG_WARN("Lookup {}", i); + auto x = kvc2.lookup(qwen_cache_info, h.ids.data(), h.ids.size()); + cmp_handle_data(h, *x); + } + SPDLOG_WARN("Simple Eviction OK"); + } + + { + std::vector> lookup_handles; + for (int i = 0; i < 10; i++) { + auto& h = handles[i]; + SPDLOG_WARN("Lookup {}", i); + auto x = kvc2.lookup(qwen_cache_info, h.ids.data(), h.ids.size()); + if (i >= 5) { + assert(x == nullptr); + continue; + } + lookup_handles.push_back(x); + cmp_handle_data(h, *x); + } + SPDLOG_WARN("Cannot Eviction OK"); + } + + kvc2.io_dealer->stop(); + io.join(); + + SPDLOG_WARN("{} Test Passed", __FILE__); + return 0; +} \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/test/kvcache_mem_insert_read_test.cpp b/csrc/balance_serve/kvc2/test/kvcache_mem_insert_read_test.cpp new file mode 100644 index 0000000..e92d3fb --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvcache_mem_insert_read_test.cpp @@ -0,0 +1,104 @@ +#include "kvcache_test_utils.cpp" + +int main(int argc, char* argv[]) { + parse_and_check(argc, argv); + spdlog::set_level(spdlog::level::debug); + std::mt19937 gen(123); + + KVC2 kvc2(FLAGS_disk_cache_path); + auto io = kvc2.io_dealer->start_io_thread(); + + SPDLOG_INFO("Disk Test"); + auto h1 = random_kvcache(qwen_cache_info, 10, gen); + h1.ids = random_ids(10 * BlockLength, gen); + kvc2.raw_insert(h1); + + // complete same + { + auto h2 = empty_kvcache(qwen_cache_info, 10); + h2.ids = h1.ids; + kvc2.raw_read(h2); + assert(static_cast(h2.match.match_length) == h1.ids.size()); + + cmp_handle_data(h1, h2); + } + + // complete prefix + { + auto h2 = empty_kvcache(qwen_cache_info, 10); + + h2.ids = std::vector(h1.ids.begin(), h1.ids.begin() + 3 * BlockLength); + kvc2.raw_read(h2); + assert(h2.match.match_length == 3 * BlockLength); + + cmp_handle_data(h1, h2, 3); + } + + // common prefix + { + auto h2 = empty_kvcache(qwen_cache_info, 10); + + h2.ids = std::vector(h1.ids.begin(), h1.ids.begin() + 5 * BlockLength); + auto rids = random_ids(BlockLength * 2 + BlockLength / 2, gen); + h2.ids.insert(h2.ids.end(), rids.begin(), rids.end()); + + kvc2.raw_read(h2); + assert(h2.match.match_length == 5 * BlockLength); + + cmp_handle_data(h1, h2, 5); + } + + // no prefix + { + auto h2 = empty_kvcache(qwen_cache_info, 10); + + h2.ids = random_ids(10 * BlockLength, gen); + + kvc2.raw_read(h2); + assert(h2.match.match_length == 0); + } + + // insert partly new + auto h2 = random_kvcache(qwen_cache_info, 10, gen); + copy_kvcache(h1, h2, 0, 5); + h2.ids = random_ids(10 * BlockLength, gen); + for (size_t i = 0; i < 5 * BlockLength; i++) { + h2.ids[i] = h1.ids[i]; + } + kvc2.raw_insert(h2); + + // read new part + { + auto h = empty_kvcache(qwen_cache_info, 10); + h.ids = std::vector(h2.ids.begin(), h2.ids.begin() + 7 * BlockLength); + h.ids.push_back(123); + + kvc2.raw_read(h); + assert(h.match.match_length == 7 * BlockLength); + cmp_handle_data(h, h2, 7); + } + + SPDLOG_WARN("Memory Test"); + + { + auto h = kvc2.lookup(qwen_cache_info, h1.ids.data(), h1.ids.size()); + assert(h); + cmp_handle_data(h1, *h); + kvc2.block_cache->debug(); + } + kvc2.block_cache->debug(); + + { + auto h = kvc2.lookup(qwen_cache_info, h1.ids.data(), 5 * BlockLength); + assert(h); + cmp_handle_data(h1, *h, 5); + kvc2.block_cache->debug(); + } + kvc2.block_cache->debug(); + + kvc2.io_dealer->stop(); + io.join(); + + SPDLOG_WARN("{} Test Passed", __FILE__); + return 0; +} \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/test/kvcache_save_load_test.cpp b/csrc/balance_serve/kvc2/test/kvcache_save_load_test.cpp new file mode 100644 index 0000000..bf289b0 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/kvcache_save_load_test.cpp @@ -0,0 +1,102 @@ +#include "kvcache_test_utils.cpp" + +int main(int argc, char* argv[]) { + parse_and_check(argc, argv); + spdlog::set_level(spdlog::level::debug); + std::mt19937 gen(123); + std::vector handles(10); + + { + KVC2 kvc2(FLAGS_disk_cache_path); + auto io = kvc2.io_dealer->start_io_thread(); + SPDLOG_WARN("Insert 10 x 10 KVCache"); + for (int i = 0; i < 10; i++) { + handles[i] = random_kvcache(qwen_cache_info, 10, gen); + auto& h1 = handles[i]; + h1.ids = random_ids(10 * BlockLength, gen); + kvc2.raw_insert(h1); + } + + kvc2.save(); + kvc2.tree->debug(); + + kvc2.io_dealer->stop(); + io.join(); + } + { + KVC2 kvc2(FLAGS_disk_cache_path); + auto io = kvc2.io_dealer->start_io_thread(); + kvc2.load(); + kvc2.tree->debug(); + auto& h1 = handles[0]; + // complete same + { + auto h2 = empty_kvcache(qwen_cache_info, 10); + h2.ids = h1.ids; + kvc2.raw_read(h2); + assert(static_cast(h2.match.match_length) == h1.ids.size()); + + cmp_handle_data(h1, h2); + } + + // complete prefix + { + auto h2 = empty_kvcache(qwen_cache_info, 10); + + h2.ids = std::vector(h1.ids.begin(), h1.ids.begin() + 3 * BlockLength); + kvc2.raw_read(h2); + assert(h2.match.match_length == 3 * BlockLength); + + cmp_handle_data(h1, h2, 3); + } + + // common prefix + { + auto h2 = empty_kvcache(qwen_cache_info, 10); + + h2.ids = std::vector(h1.ids.begin(), h1.ids.begin() + 5 * BlockLength); + auto rids = random_ids(BlockLength * 2 + BlockLength / 2, gen); + h2.ids.insert(h2.ids.end(), rids.begin(), rids.end()); + + kvc2.raw_read(h2); + assert(h2.match.match_length == 5 * BlockLength); + + cmp_handle_data(h1, h2, 5); + } + + // no prefix + { + auto h2 = empty_kvcache(qwen_cache_info, 10); + + h2.ids = random_ids(10 * BlockLength, gen); + + kvc2.raw_read(h2); + assert(h2.match.match_length == 0); + } + + // insert partly new + auto h2 = random_kvcache(qwen_cache_info, 10, gen); + copy_kvcache(h1, h2, 0, 5); + h2.ids = random_ids(10 * BlockLength, gen); + for (size_t i = 0; i < 5 * BlockLength; i++) { + h2.ids[i] = h1.ids[i]; + } + kvc2.raw_insert(h2); + + // read new part + { + auto h = empty_kvcache(qwen_cache_info, 10); + h.ids = std::vector(h2.ids.begin(), h2.ids.begin() + 7 * BlockLength); + h.ids.push_back(123); + + kvc2.raw_read(h); + assert(h.match.match_length == 7 * BlockLength); + cmp_handle_data(h, h2, 7); + } + + kvc2.io_dealer->stop(); + io.join(); + } + SPDLOG_WARN("{} Test Passed", __FILE__); + return 0; +} \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/test/kvcache_test_utils.cpp b/csrc/balance_serve/kvc2/test/kvcache_test_utils.cpp new file mode 100644 index 0000000..e69de29 diff --git a/csrc/balance_serve/kvc2/test/page_pool_test.cpp b/csrc/balance_serve/kvc2/test/page_pool_test.cpp new file mode 100644 index 0000000..5bccbca --- /dev/null +++ b/csrc/balance_serve/kvc2/test/page_pool_test.cpp @@ -0,0 +1,59 @@ + +#include +#include +#include +#include +#include +#include "page_aligned_memory_pool.cpp" + +#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG +#define FMT_HEADER_ONLY +#include "spdlog/spdlog.h" + + +// 每个线程执行的任务 +void thread_task(PageAlignedMemoryPool& pool) { + std::mt19937 gen(123); + std::vector> allocated; + size_t cnt = 40000; + for (size_t i = 0; i < cnt; ++i) { + // 随机分配一个大小 + size_t size = (gen() % 100 + 1) * 4096 * 4; + void* ptr = pool.alloc(size); + // SPDLOG_DEBUG(pool.debug()); + if (ptr) { + pool.free(ptr, size); + // allocated.push_back({ptr, size}); + } + // sleep((int)(gen() % 1000) / 1000.0); + } + // free all memory + for (auto& p : allocated) { + pool.free(p.first, p.second); + } +} + +int main(int argc, char* argv[]) { + spdlog::set_level(spdlog::level::debug); + + + // 创建一个内存池 + PageAlignedMemoryPool pool(40ll * 1024 * 1024 * 1024); // 40 G + + // 创建线程 + const int num_threads = 32; + std::vector threads; + for (int i = 0; i < num_threads; ++i) { + threads.emplace_back(thread_task, std::ref(pool)); + } + + // 等待所有线程完成 + for (auto& t : threads) { + t.join(); + } + + // 输出调试信息 + std::cout << pool.debug() << std::endl; + + return 0; +} \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/test/prefix_test.cpp b/csrc/balance_serve/kvc2/test/prefix_test.cpp new file mode 100644 index 0000000..e69de29 diff --git a/csrc/balance_serve/kvc2/test/pytest_load.py b/csrc/balance_serve/kvc2/test/pytest_load.py new file mode 100644 index 0000000..2cd32ca --- /dev/null +++ b/csrc/balance_serve/kvc2/test/pytest_load.py @@ -0,0 +1,61 @@ +import sys +sys.path.append('./build') +sys.path.append('./src') +import torch +import kvc2_ext +from kvc2_utils import get_tensor_from_data_ptr + +# Create a kvc2 instance +path = "/mnt/data/kvc2" +kvc2_instance = kvc2_ext.create_kvc2(path,int(10e9)) # 10 G memory pool +kvc2_ext.load(kvc2_instance) + +# Start IO thread +print("Start IO thread") +kvc2_ext.start_io_thread(kvc2_instance) +print("IO thread started") + +# Create CacheInfoInput +test_info = kvc2_ext.CacheInfoInput() +test_info.model_type = kvc2_ext.ModelType.MT_DeepseekV2 +test_info.cache_type = kvc2_ext.CacheType.CT_KeyCache +test_info.quant_type = kvc2_ext.QuantType.QT_F32 + +print("Element size: ", test_info.element_size()) + +# Generate random test IDs (length = 2560) +torch.manual_seed(123) +length = 2560 +test_id = torch.randint(0, 65536, (length,), dtype=torch.uint16).contiguous() +block_count = (length+255) // 256 +# print("Test ID: ", test_id) + +# Generate test data based on element size and hidden layer count +element_size = test_info.element_size() +hidden_layer_count = test_info.hidden_layer_count() + +def read_cmp_and_release(kvc2_instance,cache_info,ids,length): + handle = kvc2_ext.lookup(kvc2_instance, cache_info, ids, length) + if kvc2_ext.is_nullptr(handle): + print("Handle is nullptr.") + exit() + matched_length = kvc2_ext.matched_length(handle) + matched_data = kvc2_ext.handle_data(handle) + print('Matched length: ', matched_length) + if matched_length >0: + print(f'First layer address {[hex(x) for x in matched_data[0]]}') + read_data = get_tensor_from_data_ptr(matched_data,element_size) + + print("Just read check ok.") + kvc2_ext.release(handle) + + +l = 128 +while l<=length: + read_cmp_and_release(kvc2_instance,test_info,test_id.data_ptr(),l) + l+=128 + +kvc2_ext.destroy_kvc2(kvc2_instance) + + +print("Test completed successfully.") diff --git a/csrc/balance_serve/kvc2/test/pytest_mem_prefix_test.py b/csrc/balance_serve/kvc2/test/pytest_mem_prefix_test.py new file mode 100644 index 0000000..b6ed649 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/pytest_mem_prefix_test.py @@ -0,0 +1,83 @@ +import sys +sys.path.append('./build') +sys.path.append('./src') +import torch +import kvc2_ext +from kvc2_utils import alloc_aligned_cache,dealloc_aligned_cache,get_tensor_ptr,get_tensor_from_data_ptr + +# Create a kvc2 instance +path = "/mnt/data/kvc2" +kvc2_instance = kvc2_ext.create_kvc2(path,int(10e9)) # 10 G memory pool + +# Start IO thread +print("Start IO thread") +kvc2_ext.start_io_thread(kvc2_instance) +print("IO thread started") + +# Create CacheInfoInput +test_info = kvc2_ext.CacheInfoInput() +test_info.model_type = kvc2_ext.ModelType.MT_DeepseekV2 +test_info.cache_type = kvc2_ext.CacheType.CT_KeyCache +test_info.quant_type = kvc2_ext.QuantType.QT_F32 + +print("Element size: ", test_info.element_size()) + +# Generate random test IDs (length = 2560) +torch.manual_seed(123) +length = 2560 +test_id = torch.randint(0, 65536, (length,), dtype=torch.uint16).contiguous() +block_count = (length+255) // 256 +# print("Test ID: ", test_id) + +# Generate test data based on element size and hidden layer count +element_size = test_info.element_size() +hidden_layer_count = test_info.hidden_layer_count() + +write_data,write_data_mem = alloc_aligned_cache(hidden_layer_count,block_count,element_size) +# print(test_data,test_data_mem) +print('Generate Insert Data') +for layer in write_data: + for data in layer: + random_values = torch.randint(0, 256, (element_size,), dtype=torch.uint8) + data.copy_(random_values) + +print('Insert New data') +# Insert raw data +kvc2_ext.raw_insert(kvc2_instance, test_info, test_id.data_ptr(), length, get_tensor_ptr(write_data)) + + +def read_cmp_and_release(kvc2_instance,cache_info,ids,length): + handle = kvc2_ext.lookup(kvc2_instance, cache_info, ids, length) + if kvc2_ext.is_nullptr(handle): + print("Handle is nullptr.") + exit() + matched_length = kvc2_ext.matched_length(handle) + matched_data = kvc2_ext.handle_data(handle) + print('Matched length: ', matched_length) + if matched_length >0: + print(f'First layer address {[hex(x) for x in matched_data[0]]}') + read_data = get_tensor_from_data_ptr(matched_data,element_size) + + for layer_w,layer_r in zip(write_data,read_data): + for data_w,data_r in zip(layer_w,layer_r): + # print(data_w,data_r) + assert torch.equal(data_w,data_r) + print("Lookup read check ok.") + kvc2_ext.release(handle) + + +l = 128 +while l<=length: + read_cmp_and_release(kvc2_instance,test_info,test_id.data_ptr(),l) + l+=128 + + + +dealloc_aligned_cache(write_data_mem) + + +kvc2_ext.save(kvc2_instance) +kvc2_ext.destroy_kvc2(kvc2_instance) + + +print("Test completed successfully.") diff --git a/csrc/balance_serve/kvc2/test/pytest_mem_read.py b/csrc/balance_serve/kvc2/test/pytest_mem_read.py new file mode 100644 index 0000000..c96042b --- /dev/null +++ b/csrc/balance_serve/kvc2/test/pytest_mem_read.py @@ -0,0 +1,72 @@ +import sys +sys.path.append('./build') +sys.path.append('./src') +import torch +import kvc2_ext +from kvc2_utils import alloc_aligned_cache,dealloc_aligned_cache,get_tensor_ptr,get_tensor_from_data_ptr + +# Create a kvc2 instance +path = "/mnt/data/kvc2" +kvc2_instance = kvc2_ext.create_kvc2(path,int(10e9)) # 10 G memory pool + +# Start IO thread +print("Start IO thread") +kvc2_ext.start_io_thread(kvc2_instance) +print("IO thread started") + +# Create CacheInfoInput +test_info = kvc2_ext.CacheInfoInput() +test_info.model_type = kvc2_ext.ModelType.MT_DeepseekV2 +test_info.cache_type = kvc2_ext.CacheType.CT_KeyCache +test_info.quant_type = kvc2_ext.QuantType.QT_F32 + +print("Element size: ", test_info.element_size()) + +# Generate random test IDs (length = 2560) +length = 2560 +test_id = torch.randint(0, 65536, (length,), dtype=torch.uint16).contiguous() +block_count = (length+255) // 256 +# print("Test ID: ", test_id) + +# Generate test data based on element size and hidden layer count +element_size = test_info.element_size() +hidden_layer_count = test_info.hidden_layer_count() + +write_data,write_data_mem = alloc_aligned_cache(hidden_layer_count,block_count,element_size) +# print(test_data,test_data_mem) +print('Generate Insert Data') +for layer in write_data: + for data in layer: + random_values = torch.randint(0, 256, (element_size,), dtype=torch.uint8) + data.copy_(random_values) + +print('Insert New data') +# Insert raw data +kvc2_ext.raw_insert(kvc2_instance, test_info, test_id.data_ptr(), length, get_tensor_ptr(write_data)) + + +handle = kvc2_ext.lookup(kvc2_instance, test_info, test_id.data_ptr(), length) +matched_length = kvc2_ext.matched_length(handle) +matched_data = kvc2_ext.handle_data(handle) + +print('Matched length: ', matched_length) +print(f'Match data layer {len(matched_data)}') +print(f'Match layer block count {len(matched_data[0])}') +read_data = get_tensor_from_data_ptr(matched_data,element_size) + + +for layer_w,layer_r in zip(write_data,read_data): + for data_w,data_r in zip(layer_w,layer_r): + # print(data_w,data_r) + assert torch.equal(data_w,data_r) +print("Lookup read check ok.") + +dealloc_aligned_cache(write_data_mem) + + +kvc2_ext.save(kvc2_instance) + + + + +print("Test completed successfully.") diff --git a/csrc/balance_serve/kvc2/test/pytest_raw_insert_and_read.py b/csrc/balance_serve/kvc2/test/pytest_raw_insert_and_read.py new file mode 100644 index 0000000..3ccdf96 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/pytest_raw_insert_and_read.py @@ -0,0 +1,69 @@ +import sys +sys.path.append('./build') +sys.path.append('./src') +import torch +import kvc2_ext +from kvc2_utils import alloc_aligned_cache,dealloc_aligned_cache,get_tensor_ptr + +# Create a kvc2 instance +path = "/mnt/data/kvc2" +kvc2_instance = kvc2_ext.create_kvc2(path,int(10e9)) # 10 G memory pool + +# Start IO thread +print("Start IO thread") +kvc2_ext.start_io_thread(kvc2_instance) +print("IO thread started") + +# Create CacheInfoInput +test_info = kvc2_ext.CacheInfoInput() +test_info.model_type = kvc2_ext.ModelType.MT_DeepseekV2 +test_info.cache_type = kvc2_ext.CacheType.CT_KeyCache +test_info.quant_type = kvc2_ext.QuantType.QT_F32 + +print("Element size: ", test_info.element_size()) + +# Generate random test IDs (length = 2560) +length = 2560 +test_id = torch.randint(0, 65536, (length,), dtype=torch.uint16).contiguous() +block_count = (length+255) // 256 +# print("Test ID: ", test_id) + +# Generate test data based on element size and hidden layer count +element_size = test_info.element_size() +hidden_layer_count = test_info.hidden_layer_count() + +write_data,write_data_mem = alloc_aligned_cache(hidden_layer_count,block_count,element_size) +# print(test_data,test_data_mem) +print('Generate Insert Data') +for layer in write_data: + for data in layer: + random_values = torch.randint(0, 256, (element_size,), dtype=torch.uint8) + data.copy_(random_values) + +print('Insert New data') +# Insert raw data +kvc2_ext.raw_insert(kvc2_instance, test_info, test_id.data_ptr(), length, get_tensor_ptr(write_data)) + + +read_data,read_data_mem = alloc_aligned_cache(hidden_layer_count,block_count,element_size) + +print('Raw read') +matched_length = kvc2_ext.raw_read(kvc2_instance, test_info, test_id.data_ptr(), length,get_tensor_ptr(read_data)) + +print('Matched length: ', matched_length) +for layer_w,layer_r in zip(write_data,read_data): + for data_w,data_r in zip(layer_w,layer_r): + # print(data_w,data_r) + assert torch.equal(data_w,data_r) +print("Raw read check ok.") + +dealloc_aligned_cache(write_data_mem) +dealloc_aligned_cache(read_data_mem) + + +kvc2_ext.save(kvc2_instance) + + + + +print("Test completed successfully.") diff --git a/csrc/balance_serve/kvc2/test/test_align.py b/csrc/balance_serve/kvc2/test/test_align.py new file mode 100644 index 0000000..72ce165 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/test_align.py @@ -0,0 +1,32 @@ +import ctypes +import torch + +def aligned_tensor(size, alignment=4096): + num_bytes = size + mem = ctypes.c_void_p() + error_code = ctypes.CDLL(None).posix_memalign( + ctypes.byref(mem), ctypes.c_size_t(alignment), ctypes.c_size_t(num_bytes) + ) + + if error_code != 0: + raise MemoryError(f"posix_memalign failed with error code {error_code}") + + array_type = (ctypes.c_int8 * size) + raw_array = array_type.from_address(mem.value) + + tensor = torch.frombuffer(raw_array, dtype=torch.int8) + + if tensor.data_ptr() % alignment != 0: + raise ValueError(f"Tensor data_ptr {tensor.data_ptr()} is not aligned to {alignment} bytes") + + return tensor, mem + + +size = 5124380 +tensor, mem_ptr = aligned_tensor(size, alignment=4096) + +print(f"Tensor: {tensor}, size: {tensor.size()}, dataptr: {tensor.data_ptr()}") +print(f"Tensor memory alignment: {tensor.data_ptr() % 4096 == 0}") +print(f"Allocated memory address: {mem_ptr.value}") + +ctypes.CDLL(None).free(mem_ptr) diff --git a/csrc/balance_serve/kvc2/test/test_cuda_stream.cpp b/csrc/balance_serve/kvc2/test/test_cuda_stream.cpp new file mode 100644 index 0000000..1f8f0ae --- /dev/null +++ b/csrc/balance_serve/kvc2/test/test_cuda_stream.cpp @@ -0,0 +1,145 @@ +#include +#include +#include +#include +#include + +class CudaStreamManager { + public: + CudaStreamManager(int num_streams); + ~CudaStreamManager(); + + // Request structure + struct Request { + std::vector host_mem_addresses; + std::vector device_mem_addresses; + std::vector sizes; + cudaMemcpyKind direction; + std::function callback; + }; + + void submitRequest(const Request& request); + + private: + int num_streams_; + std::vector streams_; + int next_stream_index_; +}; + +CudaStreamManager::CudaStreamManager(int num_streams) : num_streams_(num_streams), next_stream_index_(0) { + streams_.resize(num_streams_); + for (int i = 0; i < num_streams_; ++i) { + cudaError_t err = cudaStreamCreate(&streams_[i]); + if (err != cudaSuccess) { + std::cerr << "Failed to create CUDA stream: " << cudaGetErrorString(err) << std::endl; + for (int j = 0; j < i; ++j) { + cudaStreamDestroy(streams_[j]); + } + throw std::runtime_error("Failed to create CUDA stream"); + } + } +} + +CudaStreamManager::~CudaStreamManager() { + for (int i = 0; i < num_streams_; ++i) { + cudaStreamDestroy(streams_[i]); + } +} + +void CudaStreamManager::submitRequest(const Request& request) { + int stream_index = next_stream_index_; + cudaStream_t stream = streams_[stream_index]; + next_stream_index_ = (next_stream_index_ + 1) % num_streams_; + + size_t num_transfers = request.host_mem_addresses.size(); + for (size_t i = 0; i < num_transfers; ++i) { + cudaError_t err = cudaMemcpyAsync(request.device_mem_addresses[i], request.host_mem_addresses[i], request.sizes[i], + request.direction, stream); + if (err != cudaSuccess) { + std::cerr << "cudaMemcpyAsync failed: " << cudaGetErrorString(err) << std::endl; + throw std::runtime_error("cudaMemcpyAsync failed"); + } + } + + // Enqueue the callback function + struct CallbackData { + std::function callback; + }; + + CallbackData* cb_data = new CallbackData{request.callback}; + + cudaError_t err = cudaLaunchHostFunc( + stream, + [](void* data) { + CallbackData* cb_data = static_cast(data); + cb_data->callback(); + delete cb_data; + }, + cb_data); + + if (err != cudaSuccess) { + std::cerr << "cudaLaunchHostFunc failed: " << cudaGetErrorString(err) << std::endl; + throw std::runtime_error("cudaLaunchHostFunc failed"); + } +} + +// Example usage +int main() { + try { + CudaStreamManager stream_manager(4); // Create a manager with 4 streams + + // Prepare host and device memory + const size_t num_pages = 10; + std::vector host_mem_addresses(num_pages); + std::vector device_mem_addresses(num_pages); + std::vector sizes(num_pages, 4096); // 4KB pages + + // Allocate host memory + for (size_t i = 0; i < num_pages; ++i) { + host_mem_addresses[i] = malloc(4096); + if (!host_mem_addresses[i]) { + throw std::runtime_error("Failed to allocate host memory"); + } + // Initialize data if necessary + } + + // Allocate device memory + for (size_t i = 0; i < num_pages; ++i) { + cudaError_t err = cudaMalloc(&device_mem_addresses[i], 4096); + if (err != cudaSuccess) { + std::cerr << "cudaMalloc failed: " << cudaGetErrorString(err) << std::endl; + throw std::runtime_error("cudaMalloc failed"); + } + } + + // Create a request + CudaStreamManager::Request request; + request.host_mem_addresses = host_mem_addresses; + request.device_mem_addresses = device_mem_addresses; + request.sizes = sizes; + request.direction = cudaMemcpyHostToDevice; + request.callback = []() { std::cout << "Data transfer completed!" << std::endl; }; + + // Submit the request + stream_manager.submitRequest(request); + + // Wait for all streams to complete + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + std::cerr << "cudaDeviceSynchronize failed: " << cudaGetErrorString(err) << std::endl; + throw std::runtime_error("cudaDeviceSynchronize failed"); + } + + // Clean up + for (size_t i = 0; i < num_pages; ++i) { + free(host_mem_addresses[i]); + cudaFree(device_mem_addresses[i]); + } + + } catch (const std::exception& e) { + std::cerr << "Exception: " << e.what() << std::endl; + return 1; + } + + return 0; +} diff --git a/csrc/balance_serve/kvc2/test/test_cuda_stream_manager.cpp b/csrc/balance_serve/kvc2/test/test_cuda_stream_manager.cpp new file mode 100644 index 0000000..464af82 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/test_cuda_stream_manager.cpp @@ -0,0 +1,113 @@ +#include "cuda_stream_manager.hh" + +#include +#include +#include +#include +#include + +int main() { + try { + int num_devices = 0; + cudaError_t err = cudaGetDeviceCount(&num_devices); + if (err != cudaSuccess) { + std::cerr << "cudaGetDeviceCount failed: " << cudaGetErrorString(err) << std::endl; + return 1; + } + + if (num_devices < 1) { + std::cerr << "未找到 CUDA 设备。" << std::endl; + return 1; + } + + std::vector device_ids; + for (int i = 0; i < num_devices; ++i) { + device_ids.push_back(i); + } + + const size_t num_pages = 10; + const size_t page_size = 4096; // 每页 4KB + + // 创建 CudaStreamManager 实例,管理所有设备 + CudaStreamManager stream_manager(device_ids, 4); + + // 准备主机内存和设备内存映射 + std::vector> host_mem_addresses(num_devices); + std::vector> device_mem_addresses(num_devices); + + // 分配主机内存 + for (size_t i = 0; i < num_pages; ++i) { + void* host_ptr = malloc(page_size); + if (!host_ptr) { + throw std::runtime_error("Failed to allocate host memory"); + } + // 如果需要,初始化数据 + + // 将相同的主机内存添加到每个设备的列表中 + for (int device_id = 0; device_id < num_devices; ++device_id) { + host_mem_addresses[device_id].push_back(host_ptr); + } + } + + // 为每个设备分配设备内存 + for (int device_id = 0; device_id < num_devices; ++device_id) { + err = cudaSetDevice(device_id); + if (err != cudaSuccess) { + std::cerr << "cudaSetDevice failed: " << cudaGetErrorString(err) << std::endl; + throw std::runtime_error("cudaSetDevice failed"); + } + + for (size_t i = 0; i < num_pages; ++i) { + void* device_ptr; + err = cudaMalloc(&device_ptr, page_size); + if (err != cudaSuccess) { + std::cerr << "cudaMalloc failed on device " << device_id << ": " << cudaGetErrorString(err) << std::endl; + throw std::runtime_error("cudaMalloc failed"); + } + device_mem_addresses[device_id].push_back(device_ptr); + } + } + + // 为每个设备创建并提交请求 + for (int device_id = 0; device_id < num_devices; ++device_id) { + auto request = std::shared_ptr(new CudaStreamManager::Request); + request->device_id = device_id; + request->host_mem_addresses = host_mem_addresses[device_id]; + request->device_mem_addresses = device_mem_addresses[device_id]; + request->sizes = std::vector(num_pages, page_size); + request->direction = cudaMemcpyHostToDevice; + request->callback = [device_id]() { + std::cout << "Device " << device_id << " data transfer completed!" << std::endl; + }; + + stream_manager.submitRequest(request); + } + + // 等待一段时间,确保所有请求都被处理 + // 在实际应用中,可以使用更好的同步机制 + std::this_thread::sleep_for(std::chrono::seconds(5)); + + // 清理主机内存 + for (size_t i = 0; i < num_pages; ++i) { + free(host_mem_addresses[0][i]); // 所有设备共享相同的主机内存,只需释放一次 + } + + // 清理设备内存 + for (int device_id = 0; device_id < num_devices; ++device_id) { + err = cudaSetDevice(device_id); + if (err != cudaSuccess) { + std::cerr << "cudaSetDevice failed during cleanup: " << cudaGetErrorString(err) << std::endl; + continue; + } + for (void* ptr : device_mem_addresses[device_id]) { + cudaFree(ptr); + } + } + + } catch (const std::exception& e) { + std::cerr << "异常: " << e.what() << std::endl; + return 1; + } + + return 0; +} diff --git a/csrc/balance_serve/kvc2/test/test_lock_free_queue.cpp b/csrc/balance_serve/kvc2/test/test_lock_free_queue.cpp new file mode 100644 index 0000000..5e11ba3 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/test_lock_free_queue.cpp @@ -0,0 +1,56 @@ +#include +#include +#include +#include +#include "utils/lock_free_queue.hpp" + +struct Item { + int value; + std::promise promise; +}; + +int main() { + MPSCQueue queue; + + std::vector producers; + const int num_producers = 4; + const int items_per_producer = 5; + + // 启动生产者线程 + for (int i = 0; i < num_producers; ++i) { + producers.emplace_back([&queue, i]() { + for (int j = 0; j < items_per_producer; ++j) { + auto item = std::make_shared(); + item->value = i * items_per_producer + j; + std::future future = item->promise.get_future(); + queue.enqueue(item); + future.wait(); // 等待消费者处理完成 + } + }); + } + + // 启动消费者线程 + std::thread consumer([&queue, num_producers, items_per_producer]() { + int total_items = num_producers * items_per_producer; + int processed = 0; + while (processed < total_items) { + std::shared_ptr item = queue.dequeue(); + if (item) { + std::cout << "Consumed item with value: " << item->value << std::endl; + item->promise.set_value(); // 通知生产者 + ++processed; + } else { + // 如果队列为空,可以选择休眠或让出线程 + std::this_thread::yield(); + } + } + }); + + // 等待所有线程完成 + for (auto& producer : producers) { + producer.join(); + } + consumer.join(); + + return 0; +} \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/test/test_periodic_task.cpp b/csrc/balance_serve/kvc2/test/test_periodic_task.cpp new file mode 100644 index 0000000..a2f8f89 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/test_periodic_task.cpp @@ -0,0 +1,171 @@ +#include "utils/periodic_task.hpp" +#include +#include +#include +#include +#include +#include +#include + +// 1. 任务是否按预期执行 +void testPeriodicTaskExecution() { + std::atomic execution_count{0}; + auto task = [&execution_count]() { + execution_count++; + }; + + periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(50)); + + std::this_thread::sleep_for(std::chrono::seconds(2)); + + assert(execution_count >= 20); // 确保任务执行了至少 20 次 + std::cout << "Test 1 passed: Task executed periodically." << std::endl; + std::cout << "Task executed " << execution_count.load() << " times." << std::endl; +} + +// 2. 提前唤醒任务的功能 +void testWakeUpImmediately() { + std::atomic execution_count{0}; + auto task = [&execution_count]() { + execution_count++; + }; + + periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200)); + + // 提前唤醒任务 + periodic_task.wakeUp(); + std::this_thread::sleep_for(std::chrono::milliseconds(50)); // 等待任务执行 + + std::cout << "Execution count after wakeUp: " << execution_count.load() << std::endl; + assert(execution_count == 1); // 确保任务立即执行 + std::cout << "Test 2 passed: Task woke up immediately." << std::endl; +} + +// 3. wakeUpWait() 的等待功能 +void testWakeUpWait() { + std::promise promise; + std::future future = promise.get_future(); + auto task = [&promise]() { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // 模拟任务执行 + promise.set_value(); // 任务完成时设置 promise + }; + + periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200)); + + // 调用 wakeUpWait 并等待任务完成 + std::future wakeup_future = periodic_task.wakeUpWait(); + wakeup_future.wait(); // 等待任务完成 + + assert(wakeup_future.valid()); // 确保 future 是有效的 + std::cout << "Test 3 passed: wakeUpWait() works correctly." << std::endl; + std::cout << "wakeUpWait() future is valid." << std::endl; +} + +// 4. 任务抛出异常的处理 +void testTaskExceptionHandling() { + auto task = []() { + throw std::runtime_error("Test exception"); + }; + + periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200)); + + std::this_thread::sleep_for(std::chrono::milliseconds(300)); // 等待一段时间 + + std::cout << "Test 4 passed: Task exception is handled correctly." << std::endl; + std::cout << "Exception handled and task did not crash." << std::endl; +} + +// 5. 线程是否能正确停止 +void testTaskStop() { + std::atomic stopped{false}; + auto task = [&stopped]() { + while (!stopped) { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + }; + + periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(100)); + + std::this_thread::sleep_for(std::chrono::seconds(1)); // 运行一段时间 + + stopped = true; // 请求停止 + std::this_thread::sleep_for(std::chrono::milliseconds(50)); // 等待线程停止 + + std::cout << "Test 5 passed: Task thread stops correctly." << std::endl; + std::cout << "Task has been stopped successfully." << std::endl; +} + +// 6. 高频唤醒的情况下任务执行是否正常 +void testHighFrequencyWakeUp() { + std::atomic execution_count{0}; + auto task = [&execution_count]() { + execution_count++; + }; + + periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200)); + + for (int i = 0; i < 100; ++i) { + periodic_task.wakeUp(); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); // 每 10 毫秒唤醒一次 + } + + std::this_thread::sleep_for(std::chrono::seconds(1)); // 等待任务执行完成 + + assert(execution_count > 50); // 确保任务至少执行了 50 次 + std::cout << "Test 6 passed: Task handles frequent wake ups correctly." << std::endl; + std::cout << "Task executed " << execution_count.load() << " times." << std::endl; +} + +// 7. 多个 wakeUpWait() 调用的处理 +void testMultipleWakeUpWait() { + std::atomic execution_count{0}; + auto task = [&execution_count]() { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // 模拟任务执行 + execution_count++; + }; + + periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200)); + + // 同时调用两个 wakeUpWait + std::future future1 = periodic_task.wakeUpWait(); + std::future future2 = periodic_task.wakeUpWait(); + + future1.wait(); + future2.wait(); + + assert(execution_count == 1); // 确保任务只执行了一次 + std::cout << "Test 7 passed: Multiple wakeUpWait() calls are handled correctly." << std::endl; + std::cout << "Task executed " << execution_count.load() << " times." << std::endl; +} + +// 8. 任务函数为空的边界情况 +void testEmptyTaskFunction() { + auto task = []() { + // 空任务函数 + }; + + periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(100)); + + std::this_thread::sleep_for(std::chrono::seconds(1)); // 等待一段时间 + + std::cout << "Test 8 passed: Empty task function works correctly." << std::endl; + std::cout << "Empty task function executed without issues." << std::endl; +} + +int main() { + std::cout << "Starting tests..." << std::endl; + + // testWakeUpImmediately(); + testPeriodicTaskExecution(); + testWakeUpImmediately(); + testWakeUpWait(); + testTaskExceptionHandling(); + testTaskStop(); + testHighFrequencyWakeUp(); + testMultipleWakeUpWait(); + testEmptyTaskFunction(); + + std::cout << "All tests passed!" << std::endl; + + return 0; +} diff --git a/csrc/balance_serve/kvc2/test/test_queue_perf.cpp b/csrc/balance_serve/kvc2/test/test_queue_perf.cpp new file mode 100644 index 0000000..5dea851 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/test_queue_perf.cpp @@ -0,0 +1,84 @@ +#include +#include +#include "utils/lock_free_queue.hpp" + +#define STDQ + +int main() { + const int num_producers = 48; + const int num_items = 1e6; + +#ifdef STDQ + std::mutex lock; + std::queue queue; +#else + MPSCQueue queue; +#endif + + auto start_time = std::chrono::high_resolution_clock::now(); + + // Launch multiple producer threads + std::vector producers; + for (int i = 0; i < num_producers; ++i) { + producers.emplace_back([&queue, i +#ifdef STDQ + , + &lock +#endif + ]() { + for (int j = 0; j < num_items; ++j) { +#ifdef STDQ + std::lock_guard guard(lock); + queue.push(i * num_items + j); +#else + queue.enqueue(std::make_shared(i * num_items + j)); +#endif + } + }); + } + + // Consumer thread + std::thread consumer([&queue, num_producers +#ifdef STDQ + , + &lock +#endif + ]() { + int count = 0; + while (count < num_producers * num_items) { +#ifdef STDQ + std::lock_guard guard(lock); + if (!queue.empty()) { + queue.pop(); + count++; + } +#else + if (auto item = queue.dequeue()) { + count++; + } +#endif + } + }); + + // Wait for all producers to finish + for (auto& producer : producers) { + producer.join(); + } + + // Wait for the consumer to finish + consumer.join(); + + auto end_time = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(end_time - start_time).count(); + +#ifdef STDQ + std::cout << "std::queue with mutex "; +#else + std::cout << "lock free queue "; +#endif + + std::cout << "Processed " << num_producers * num_items / 1e6 << "M items in " << duration << " milliseconds " + << num_producers * num_items / 1e3 / duration << " MOps." << std::endl; + + return 0; +} \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/test/test_std_list.cpp b/csrc/balance_serve/kvc2/test/test_std_list.cpp new file mode 100644 index 0000000..1ed5f51 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/test_std_list.cpp @@ -0,0 +1,38 @@ +#include +#include +#include + +int main() { + std::vector v = {0, 1, 2, 3, 4, 5}; + + using RevIt = std::reverse_iterator::iterator>; + + const auto it = v.begin() + 3; + RevIt r_it{it}; + + std::cout << "*it == " << *it << '\n' + << "*r_it == " << *r_it << '\n' + << "*r_it.base() == " << *r_it.base() << '\n' + << "*(r_it.base()-1) == " << *(r_it.base() - 1) << '\n'; + + RevIt r_end{v.begin()}; + RevIt r_begin{v.end()}; + + for (auto it = r_end.base(); it != r_begin.base(); ++it) + std::cout << *it << ' '; + std::cout << '\n'; + + for (auto it = r_begin; it != r_end; ++it) + std::cout << *it << ' '; + std::cout << '\n'; + + for (auto it = r_begin; it != r_end; ++it) { + if (*it == 3) { + v.erase(std::next(it).base()); + } + } + + for (auto it : v) + std::cout << it << ' '; + std::cout << '\n'; +} \ No newline at end of file diff --git a/csrc/balance_serve/kvc2/test/xxHash_test.cpp b/csrc/balance_serve/kvc2/test/xxHash_test.cpp new file mode 100644 index 0000000..7a782d0 --- /dev/null +++ b/csrc/balance_serve/kvc2/test/xxHash_test.cpp @@ -0,0 +1,31 @@ +#include "xxhash.h" +#include + +int main() { + std::string t = "hello world"; + XXH64_hash_t hash = XXH64(t.data(), t.size(), 123); + std::cout << hash << std::endl; + { + /* create a hash state */ + XXH64_state_t* const state = XXH64_createState(); + if (state == NULL) + abort(); + + if (XXH64_reset(state, 123) == XXH_ERROR) + abort(); + + if (XXH64_update(state, t.data(), 5) == XXH_ERROR) + abort(); + + if (XXH64_update(state, t.data() + 5, t.size() - 5) == XXH_ERROR) + abort(); + /* Produce the final hash value */ + XXH64_hash_t const hash = XXH64_digest(state); + + /* State could be re-used; but in this example, it is simply freed */ + XXH64_freeState(state); + std::cout << hash << std::endl; + } + + return 0; +} diff --git a/csrc/balance_serve/kvc2/unit_test.sh b/csrc/balance_serve/kvc2/unit_test.sh new file mode 100755 index 0000000..495a2c4 --- /dev/null +++ b/csrc/balance_serve/kvc2/unit_test.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +# 检查是否提供了 disk_cache_path 参数 +if [ -z "$1" ]; then + echo "Usage: $0 " + exit 1 +fi + +# 将 disk_cache_path 参数赋值给变量 +disk_cache_path=$1 + +# 定义测试命令数组,并使用变量替换 disk_cache_path +tests=( + "./build/test/kvc2_export_header_test --disk_cache_path=$disk_cache_path" + "./build/test/kvcache_disk_insert_read_test --disk_cache_path=$disk_cache_path" + "./build/test/kvcache_mem_eviction_test --disk_cache_path=$disk_cache_path" + "./build/test/kvcache_mem_insert_read_test --disk_cache_path=$disk_cache_path" + "./build/test/kvcache_save_load_test --disk_cache_path=$disk_cache_path" +) + + +# 遍历每个测试命令 +for test in "${tests[@]}"; do + echo "Running: $test" + # 运行测试并捕获输出 + output=$($test) + + # 检查测试输出中是否包含 "Test Passed" + if echo "$output" | grep -q "Test Passed"; then + echo " Test Passed" + else + echo " Test Failed" + fi + + sleep 1 +done \ No newline at end of file diff --git a/csrc/balance_serve/sched/CMakeLists.txt b/csrc/balance_serve/sched/CMakeLists.txt new file mode 100644 index 0000000..6b5dcda --- /dev/null +++ b/csrc/balance_serve/sched/CMakeLists.txt @@ -0,0 +1,19 @@ +set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fPIC") +# set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -fPIC") + +set(UTILS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/utils) + +add_library(sched_metrics metrics.cpp) +target_include_directories(sched_metrics PRIVATE ${UTILS_DIR}) +target_link_libraries(sched_metrics PUBLIC prometheus-cpp::pull) + + +add_library(sched scheduler.cpp) +target_include_directories(sched PRIVATE ${SPDLOG_DIR}/include ${FMT_DIR}/include ${UTILS_DIR} ${KVC2_INCLUDE_DIR}) +target_link_libraries(sched PUBLIC pthread ${TORCH_LIBRARIES} kvc2 async_store sched_metrics) + +pybind11_add_module(sched_ext bind.cpp) +target_link_libraries(sched_ext PUBLIC sched ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY}) + + + diff --git a/csrc/balance_serve/sched/bind.cpp b/csrc/balance_serve/sched/bind.cpp new file mode 100644 index 0000000..d280e44 --- /dev/null +++ b/csrc/balance_serve/sched/bind.cpp @@ -0,0 +1,208 @@ +#include +#include +#include +#include +#include "scheduler.h" + +#include + +namespace py = pybind11; + +PYBIND11_MODULE(sched_ext, m) { + py::class_(m, "ModelSettings") + .def(py::init<>()) + .def_readwrite("model_path", &scheduler::ModelSettings::model_path) + .def_readwrite("params_count", &scheduler::ModelSettings::params_count) + .def_readwrite("layer_count", &scheduler::ModelSettings::layer_count) + .def_readwrite("num_k_heads", &scheduler::ModelSettings::num_k_heads) + .def_readwrite("k_head_dim", &scheduler::ModelSettings::k_head_dim) + .def_readwrite("bytes_per_params", &scheduler::ModelSettings::bytes_per_params) + .def_readwrite("bytes_per_kv_cache_element", &scheduler::ModelSettings::bytes_per_kv_cache_element) + .def("params_size", &scheduler::ModelSettings::params_nbytes) + .def("bytes_per_token_kv_cache", &scheduler::ModelSettings::bytes_per_token_kv_cache) + // 添加 pickle 支持 + .def(py::pickle( + [](const scheduler::ModelSettings& self) { // __getstate__ + return py::make_tuple(self.params_count, self.layer_count, self.num_k_heads, self.k_head_dim, + self.bytes_per_params, self.bytes_per_kv_cache_element); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 6) + throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size())); + scheduler::ModelSettings ms; + ms.params_count = t[0].cast(); + ms.layer_count = t[1].cast(); + ms.num_k_heads = t[2].cast(); + ms.k_head_dim = t[3].cast(); + ms.bytes_per_params = t[4].cast(); + ms.bytes_per_kv_cache_element = t[5].cast(); + return ms; + })); + + py::class_(m, "SampleOptions") + .def(py::init<>()) + .def_readwrite("temperature", &scheduler::SampleOptions::temperature) + .def_readwrite("top_p", &scheduler::SampleOptions::top_p) // 确保 top_p 也能被访问 + .def(py::pickle( + [](const scheduler::SampleOptions& self) { + return py::make_tuple(self.temperature, self.top_p); // 序列化 temperature 和 top_p + }, + [](py::tuple t) { + if (t.size() != 2) // 确保解包时参数数量匹配 + throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size())); + scheduler::SampleOptions so; + so.temperature = t[0].cast(); + so.top_p = t[1].cast(); // 反序列化 top_p + return so; + } + )); + + py::class_(m, "Settings") + .def(py::init<>()) + .def_readwrite("model_name", &scheduler::Settings::model_name) + .def_readwrite("quant_type", &scheduler::Settings::quant_type) + .def_readwrite("model_settings", &scheduler::Settings::model_settings) + .def_readwrite("page_size", &scheduler::Settings::page_size) + .def_readwrite("gpu_device_id", &scheduler::Settings::gpu_device_id) + .def_readwrite("gpu_memory_size", &scheduler::Settings::gpu_memory_size) + .def_readwrite("memory_utilization_percentage", &scheduler::Settings::memory_utilization_percentage) + .def_readwrite("max_batch_size", &scheduler::Settings::max_batch_size) + .def_readwrite("recommended_chunk_prefill_token_count", + &scheduler::Settings::recommended_chunk_prefill_token_count) + .def_readwrite("sample_options", &scheduler::Settings::sample_options) + .def_readwrite("sched_metrics_port", &scheduler::Settings::sched_metrics_port) + .def_readwrite("gpu_only", &scheduler::Settings::gpu_only) + .def_readwrite("use_self_defined_head_dim", &scheduler::Settings::use_self_defined_head_dim) + .def_readwrite("self_defined_head_dim", &scheduler::Settings::self_defined_head_dim) + .def_readwrite("full_kv_cache_on_each_gpu", &scheduler::Settings::full_kv_cache_on_each_gpu) + .def_readwrite("k_cache_on", &scheduler::Settings::k_cache_on) + .def_readwrite("v_cache_on", &scheduler::Settings::v_cache_on) + .def_readwrite("kvc2_config_path", &scheduler::Settings::kvc2_config_path) + .def_readwrite("kvc2_root_path", &scheduler::Settings::kvc2_root_path) + .def_readwrite("memory_pool_size_GB", &scheduler::Settings::memory_pool_size_GB) + .def_readwrite("evict_count", &scheduler::Settings::evict_count) + .def_readwrite("strategy_name", &scheduler::Settings::strategy_name) + .def_readwrite("kvc2_metrics_port", &scheduler::Settings::kvc2_metrics_port) + .def_readwrite("load_from_disk", &scheduler::Settings::load_from_disk) + .def_readwrite("save_to_disk", &scheduler::Settings::save_to_disk) + // derived + .def_readwrite("gpu_device_count", &scheduler::Settings::gpu_device_count) + .def_readwrite("total_kvcache_pages", &scheduler::Settings::total_kvcache_pages) + .def_readwrite("devices", &scheduler::Settings::devices) + .def("auto_derive", &scheduler::Settings::auto_derive); + + py::class_>(m, "BatchQueryTodo") + .def(py::init<>()) + .def_readwrite("query_ids", &scheduler::BatchQueryTodo::query_ids) + .def_readwrite("query_tokens", &scheduler::BatchQueryTodo::query_tokens) + .def_readwrite("query_lengths", &scheduler::BatchQueryTodo::query_lengths) + .def_readwrite("block_indexes", &scheduler::BatchQueryTodo::block_indexes) + .def_readwrite("attn_masks", &scheduler::BatchQueryTodo::attn_masks) + .def_readwrite("rope_ranges", &scheduler::BatchQueryTodo::rope_ranges) + .def_readwrite("sample_options", &scheduler::BatchQueryTodo::sample_options) + .def_readwrite("prefill_mini_batches", &scheduler::BatchQueryTodo::prefill_mini_batches) + .def_readwrite("decode_mini_batches", &scheduler::BatchQueryTodo::decode_mini_batches) + .def_readwrite("stop_criteria", &scheduler::BatchQueryTodo::stop_criteria) + .def("debug", &scheduler::BatchQueryTodo::debug) + .def(py::pickle( + [](const scheduler::BatchQueryTodo& self) { + return py::make_tuple(self.query_ids, self.query_tokens, self.query_lengths, self.block_indexes, + self.attn_masks, self.rope_ranges, self.sample_options, self.prefill_mini_batches, + self.decode_mini_batches, self.stop_criteria); + }, + [](py::tuple t) { + if (t.size() != 10) + throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size())); + scheduler::BatchQueryTodo bqt; + bqt.query_ids = t[0].cast>(); + bqt.query_tokens = t[1].cast>(); + bqt.query_lengths = t[2].cast>(); + bqt.block_indexes = t[3].cast>(); + bqt.attn_masks = t[4].cast>(); + bqt.rope_ranges = t[5].cast>(); + bqt.sample_options = t[6].cast>(); + bqt.prefill_mini_batches = t[7].cast>(); + bqt.decode_mini_batches = t[8].cast>>(); + bqt.stop_criteria = t[9].cast>>>(); + return bqt; + })); + + py::class_(m, "QueryUpdate") + .def(py::init<>()) + .def_readwrite("id", &scheduler::QueryUpdate::id) + .def_readwrite("ok", &scheduler::QueryUpdate::ok) + .def_readwrite("is_prefill", &scheduler::QueryUpdate::is_prefill) + .def_readwrite("decode_done", &scheduler::QueryUpdate::decode_done) + .def_readwrite("active_position", &scheduler::QueryUpdate::active_position) + .def_readwrite("generated_token", &scheduler::QueryUpdate::generated_token) + .def(py::pickle( + [](const scheduler::QueryUpdate& self) { + return py::make_tuple(self.id, self.ok, self.is_prefill, self.decode_done, self.active_position, + self.generated_token); + }, + [](py::tuple t) { + if (t.size() != 6) + throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size())); + scheduler::QueryUpdate qu; + qu.id = t[0].cast(); + qu.ok = t[1].cast(); + qu.is_prefill = t[2].cast(); + qu.decode_done = t[3].cast(); + qu.active_position = t[4].cast(); + qu.generated_token = t[5].cast(); + return qu; + })); + + py::class_(m, "InferenceContext") + .def(py::init<>()) + .def_readwrite("k_cache", &scheduler::InferenceContext::k_cache) + .def_readwrite("v_cache", &scheduler::InferenceContext::v_cache) + ; + + py::class_(m, "QueryAdd") + .def(py::init<>()) + .def_readwrite("query_token", &scheduler::QueryAdd::query_token) + // .def_readwrite("attn_mask", &scheduler::QueryAdd::attn_mask) + .def_readwrite("query_length", &scheduler::QueryAdd::query_length) + .def_readwrite("estimated_length", &scheduler::QueryAdd::estimated_length) + .def_readwrite("sample_options", &scheduler::QueryAdd::sample_options) + .def_readwrite("user_id", &scheduler::QueryAdd::user_id) + .def_readwrite("SLO_TTFT_ms", &scheduler::QueryAdd::SLO_TTFT_ms) + .def_readwrite("SLO_TBT_ms", &scheduler::QueryAdd::SLO_TBT_ms) + .def_readwrite("stop_criteria", &scheduler::QueryAdd::stop_criteria) + .def("serialize", &scheduler::QueryAdd::serialize) + .def_static("deserialize", &scheduler::QueryAdd::deserialize) + .def(py::pickle( + [](const scheduler::QueryAdd& self) { + return py::make_tuple(self.query_token, + // self.attn_mask, + self.query_length, self.estimated_length, self.sample_options, self.user_id, + self.SLO_TTFT_ms, self.SLO_TBT_ms, self.stop_criteria); + }, + [](py::tuple t) { + if (t.size() != 8) + throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size())); + scheduler::QueryAdd qa; + qa.query_token = t[0].cast>(); + // qa.attn_mask = t[1].cast(); + qa.query_length = t[1].cast(); + qa.estimated_length = t[2].cast(); + qa.sample_options = t[3].cast(); + qa.user_id = t[4].cast(); + qa.SLO_TTFT_ms = t[5].cast(); + qa.SLO_TBT_ms = t[6].cast(); + qa.stop_criteria = t[7].cast>>(); + return qa; + })); + + py::class_>(m, "Scheduler") + .def("init", &scheduler::Scheduler::init) + .def("run", &scheduler::Scheduler::run) + .def("stop", &scheduler::Scheduler::stop) + .def("add_query", &scheduler::Scheduler::add_query, py::call_guard()) + .def("cancel_query", &scheduler::Scheduler::cancel_query, py::call_guard()) + .def("update_last_batch", &scheduler::Scheduler::update_last_batch, py::call_guard()) + .def("get_inference_context", &scheduler::Scheduler::get_inference_context); + + m.def("create_scheduler", &scheduler::create_scheduler, "Create a new Scheduler instance"); +} diff --git a/csrc/balance_serve/sched/metrics.cpp b/csrc/balance_serve/sched/metrics.cpp new file mode 100644 index 0000000..8a6f259 --- /dev/null +++ b/csrc/balance_serve/sched/metrics.cpp @@ -0,0 +1,135 @@ +#include "metrics.h" +#include + +// 构造函数 +Metrics::Metrics(const MetricsConfig& config) + : registry_(std::make_shared()), + exposer_(config.endpoint), + stop_uptime_thread_(false), + start_time_(std::chrono::steady_clock::now()) { + // 定义统一的桶大小,最大为 10000 ms (10 s) + std::vector common_buckets = {0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0, 5.0, + 10.0, 50.0, 100.0, 500.0, 1000.0, 5000.0, 10000.0}; // 毫秒 + + // 注册 TTFT_ms Histogram + auto& TTFT_family = prometheus::BuildHistogram() + .Name(std::string(METRIC_PREFIX) + "_TTFT_ms") + .Help("Time to first token in milliseconds") + .Register(*registry_); + TTFT_ms = &TTFT_family.Add({{"model", config.model_name}}, common_buckets); + + // 注册 TBT_ms Histogram + auto& TBT_family = prometheus::BuildHistogram() + .Name(std::string(METRIC_PREFIX) + "_TBT_ms") + .Help("Time between tokens in milliseconds") + .Register(*registry_); + TBT_ms = &TBT_family.Add({{"model", config.model_name}}, common_buckets); + + // 注册 schedule_time Histogram + auto& schedule_time_family = prometheus::BuildHistogram() + .Name(std::string(METRIC_PREFIX) + "_schedule_time_ms") + .Help("Time to generate schedule in milliseconds") + .Register(*registry_); + schedule_time = &schedule_time_family.Add({{"model", config.model_name}}, common_buckets); + + // 注册 generated_tokens Counter + auto& generated_tokens_family = prometheus::BuildCounter() + .Name(std::string(METRIC_PREFIX) + "_generated_tokens_total") + .Help("Total generated tokens") + .Register(*registry_); + generated_tokens = &generated_tokens_family.Add({{"model", config.model_name}}); + + // 注册 throughput_query Gauge + auto& throughput_query_family = prometheus::BuildGauge() + .Name(std::string(METRIC_PREFIX) + "_throughput_query") + .Help("Throughput per second based on queries") + .Register(*registry_); + throughput_query = &throughput_query_family.Add({{"model", config.model_name}}); + + // 注册 throughput_generated_tokens Gauge + auto& throughput_generated_tokens_family = prometheus::BuildGauge() + .Name(std::string(METRIC_PREFIX) + "_throughput_generated_tokens") + .Help("Throughput per second based on generated tokens") + .Register(*registry_); + throughput_generated_tokens = &throughput_generated_tokens_family.Add({{"model", config.model_name}}); + + // 注册 event_count Counter family + event_count_family_ = &prometheus::BuildCounter() + .Name(std::string(METRIC_PREFIX) + "_event_count_total") + .Help("Count of various events") + .Register(*registry_); + + batch_count_family_ = &prometheus::BuildCounter() + .Name(std::string(METRIC_PREFIX) + "_batch_count_total") + .Help("Count of various batch by status") + .Register(*registry_); + + // 注册 query_count Counter family + query_count_family_ = &prometheus::BuildCounter() + .Name(std::string(METRIC_PREFIX) + "_query_count_total") + .Help("Count of queries by status") + .Register(*registry_); + + // 注册 uptime_ms Gauge + auto& uptime_family = prometheus::BuildGauge() + .Name(std::string(METRIC_PREFIX) + "_uptime_ms") + .Help("Uptime of the scheduler in milliseconds") + .Register(*registry_); + uptime_ms = &uptime_family.Add({{"model", config.model_name}}); + + // 注册 GPU 利用率 Gauges + auto& gpu_util_family = prometheus::BuildGauge() + .Name(std::string(METRIC_PREFIX) + "_gpu_utilization_ratio") + .Help("Current GPU utilization ratio (0 to 1)") + .Register(*registry_); + for (size_t i = 0; i < config.gpu_count; ++i) { + gpu_utilization_gauges.push_back( + &gpu_util_family.Add({{"gpu_id", std::to_string(i)}, {"model", config.model_name}})); + } + + // 将 Registry 注册到 Exposer 中 + exposer_.RegisterCollectable(registry_); + + // 启动 uptime 更新线程 + StartUptimeUpdater(); +} + +// 析构函数 +Metrics::~Metrics() { + StopUptimeUpdater(); +} + +// 启动 uptime 更新线程 +void Metrics::StartUptimeUpdater() { + uptime_thread_ = std::thread([this]() { + while (!stop_uptime_thread_) { + auto now = std::chrono::steady_clock::now(); + std::chrono::duration uptime_duration = now - start_time_; + uptime_ms->Set(uptime_duration.count()); + // fn_every_sec(this); + std::this_thread::sleep_for(std::chrono::seconds(1)); + } + }); +} + +// 停止 uptime 更新线程 +void Metrics::StopUptimeUpdater() { + stop_uptime_thread_ = true; + if (uptime_thread_.joinable()) { + uptime_thread_.join(); + } +} + +// 获取 event_count 指标 +prometheus::Counter* Metrics::event_count(const std::string& type) { + return &event_count_family_->Add({{"type", type}}); // 可根据需要添加更多标签 +} + +// 获取 query_count 指标 +prometheus::Counter* Metrics::query_count(const std::string& status) { + return &query_count_family_->Add({{"status", status}}); // 可根据需要添加更多标签 +} + +prometheus::Counter* Metrics::batch_count(const std::string& type) { + return &batch_count_family_->Add({{"type", type}}); +} diff --git a/csrc/balance_serve/sched/metrics.h b/csrc/balance_serve/sched/metrics.h new file mode 100644 index 0000000..226684c --- /dev/null +++ b/csrc/balance_serve/sched/metrics.h @@ -0,0 +1,85 @@ +#ifndef Metrics_H +#define Metrics_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "timer.hpp" +// 指标前缀宏定义 +#define METRIC_PREFIX "scheduler" +class Metrics; + +// 配置结构体 +struct MetricsConfig { + std::string endpoint; + std::string model_name; // 模型名称,如 "gpt-4" + size_t gpu_count; // GPU数量 +}; + +// Metrics 类,根据配置初始化 Prometheus 指标 +class Metrics { + public: + // 构造函数传入 MetricsConfig + Metrics(const MetricsConfig& config); + ~Metrics(); + + // 禁止拷贝和赋值 + Metrics(const Metrics&) = delete; + Metrics& operator=(const Metrics&) = delete; + + std::function fn_every_sec; + + // 指标指针 + prometheus::Gauge* uptime_ms; + prometheus::Histogram* TTFT_ms; + prometheus::Histogram* TBT_ms; + prometheus::Histogram* schedule_time; + prometheus::Gauge* throughput_query; + prometheus::Gauge* throughput_generated_tokens; + prometheus::Counter* generated_tokens; + std::vector gpu_utilization_gauges; + + // 计数器家族 + prometheus::Counter* event_count(const std::string& type); + prometheus::Counter* query_count(const std::string& status); + prometheus::Counter* batch_count(const std::string& type); + + private: + std::shared_ptr registry_; + prometheus::Exposer exposer_; + + // 计数器家族 + prometheus::Family* event_count_family_; + prometheus::Family* batch_count_family_; + prometheus::Family* query_count_family_; + + // 线程和控制变量用于更新 uptime_ms + std::thread uptime_thread_; + std::atomic stop_uptime_thread_; + + // 启动 uptime 更新线程 + void StartUptimeUpdater(); + // 停止 uptime 更新线程 + void StopUptimeUpdater(); + + // 记录程序启动时间 + std::chrono::steady_clock::time_point start_time_; +}; + +struct HistogramTimerWrapper { + prometheus::Histogram* histogram; + Timer timer; + inline HistogramTimerWrapper(prometheus::Histogram* histogram) : histogram(histogram), timer() { timer.start(); } + inline ~HistogramTimerWrapper() { histogram->Observe(timer.elapsedMs()); } +}; + +#endif // Metrics_H diff --git a/csrc/balance_serve/sched/model_config.h b/csrc/balance_serve/sched/model_config.h new file mode 100644 index 0000000..ff6e915 --- /dev/null +++ b/csrc/balance_serve/sched/model_config.h @@ -0,0 +1,113 @@ +#ifndef __MODEL_CONFIG_HPP_ +#define __MODEL_CONFIG_HPP_ + +#include +#include "nlohmann/json.hpp" + +#include +#include + +using DimSize = size_t; +using URL = std::string; +using ModelName = std::string; + +// We must assure this can be load by config.json +class ModelConfig { + public: + DimSize hidden_size; + DimSize intermediate_size; + size_t max_position_embeddings; + std::string model_type; + size_t num_attention_heads; + size_t num_hidden_layers; + size_t num_key_value_heads; + size_t vocab_size; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size, max_position_embeddings, model_type, + num_attention_heads, num_hidden_layers, num_key_value_heads, vocab_size); + + void load_from(std::filesystem::path path) { + std::ifstream i(path); + nlohmann::json j; + i >> j; + *this = j.get(); + } +}; + +using QuantType = std::string; +static const QuantType NoQuantType = ""; + +class QuantConfig { + public: + QuantType name; + + // For GEMV + QuantType type_of_dot_vector = NoQuantType; + inline bool can_be_used_as_matrix() { return type_of_dot_vector != NoQuantType; } + + bool can_be_used_as_vector; + + double bytes_per_element; + bool has_scale; + bool has_min; + + size_t block_element_count; + size_t block_element_size; + + URL reference = ""; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(QuantConfig, name, type_of_dot_vector, can_be_used_as_vector, + bytes_per_element, has_scale, has_min, block_element_count, + block_element_size, reference); +}; + +inline std::map quant_configs; +inline std::map model_configs; + +inline void load_quant_configs(std::filesystem::path path) { + nlohmann::json j; + if (std::filesystem::exists(path)) { + std::cout << __FUNCTION__ << " from " << path << std::endl; + std::ifstream i(path); + i >> j; + } else { + std::cout << __FUNCTION__ << " create new at " << path << std::endl; + } + + quant_configs = j.get>(); + std::cout << "Loaded Quant Configs" << std::endl; + for (auto& [k, v] : quant_configs) { + std::cout << " - " << k << std::endl; + } +} + +inline void dump_quant_configs(std::filesystem::path path) { + std::ofstream o(path); + nlohmann::json j = quant_configs; + o << j.dump(4); +} + +inline void load_model_configs(std::filesystem::path path) { + nlohmann::json j; + if (std::filesystem::exists(path)) { + std::cout << __FUNCTION__ << " from " << path << std::endl; + std::ifstream i(path); + i >> j; + } else { + std::cout << __FUNCTION__ << " create new at " << path << std::endl; + } + + model_configs = j.get>(); + std::cout << "Loaded Model Configs" << std::endl; + for (auto& [k, v] : model_configs) { + std::cout << " - " << k << std::endl; + } +} + +inline void dump_model_configs(std::filesystem::path path) { + std::ofstream o(path); + nlohmann::json j = model_configs; + o << j.dump(4); +} + +#endif \ No newline at end of file diff --git a/csrc/balance_serve/sched/scheduler.cpp b/csrc/balance_serve/sched/scheduler.cpp new file mode 100644 index 0000000..266157e --- /dev/null +++ b/csrc/balance_serve/sched/scheduler.cpp @@ -0,0 +1,916 @@ +#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO +#define FMT_HEADER_ONLY +#include "nlohmann/json.hpp" +#include "spdlog/spdlog.h" + +#include +#include "scheduler.h" + +#include +#include +#include +#include +#include +#include "arithmetic.hpp" +#include "atomic_ptr_with_flags.hpp" +#include "easy_format.hpp" +#include "metrics.h" +#include "mpsc.hpp" +#include "timer.hpp" + +#include "kvc2.h" + +using json = nlohmann::json; + +namespace scheduler { + +void Settings::auto_derive() { + gpu_device_count = gpu_device_id.size(); + if (torch::cuda::is_available()) { + size_t gpu_count = torch::cuda::device_count(); + SPDLOG_INFO("Number of available GPUs: {}, want {}", gpu_count, gpu_device_count); + if (gpu_count < gpu_device_count) { + SPDLOG_ERROR("Not enough GPUs available."); + exit(0); + } + for (size_t i = 0; i < gpu_device_count; i++) { + devices.push_back(torch::Device(torch::kCUDA, gpu_device_id[i])); + } + } else { + SPDLOG_ERROR("CUDA is not available on this system."); + exit(0); + } + + if (model_settings.num_k_heads % gpu_device_count != 0) { + SPDLOG_ERROR("num_k_heads {} is not divisible by gpu_device_count {}", model_settings.num_k_heads, + gpu_device_count); + assert(false); + } + + size_t gpu_memory_available = gpu_memory_size * memory_utilization_percentage; + if (gpu_memory_available * gpu_device_count < model_settings.params_nbytes()) { + SPDLOG_ERROR("GPU memory size {}G is smaller than {}G", gpu_memory_available * gpu_device_count / 1e9, + model_settings.params_nbytes() / 1e9); + assert(false); + } + + assert(model_settings.k_head_dim % model_settings.num_k_heads == 0); + size_t head_per_gpu = model_settings.num_k_heads / gpu_device_count; + size_t gpu_memory_for_kv_cache = gpu_memory_available /*- model_settings.params_nbytes() / gpu_device_count*/; + SPDLOG_INFO("Each GPU Total: {}MiB, Model Params: {}MiB, KVCache: {}MiB, Left: {}MiB", gpu_memory_size / (1 << 20), + model_settings.params_nbytes() / gpu_device_count / (1 << 20), gpu_memory_for_kv_cache / (1 << 20), + (gpu_memory_size - gpu_memory_available) / (1 << 20)); + size_t kv_cache_on_cnt = (size_t)(k_cache_on) + (size_t)(v_cache_on); + size_t max_total_kvcache_pages = + gpu_memory_for_kv_cache / (kv_cache_on_cnt * head_per_gpu * model_settings.k_head_dim * + model_settings.bytes_per_kv_cache_element * page_size * model_settings.layer_count); + if (total_kvcache_pages.has_value()) { + if (total_kvcache_pages.value() > max_total_kvcache_pages) { + SPDLOG_ERROR("total_kvcache_pages {} is larger than max_total_kvcache_pages {}", total_kvcache_pages.value(), + max_total_kvcache_pages); + assert(false); + } + } else { + total_kvcache_pages = max_total_kvcache_pages; + SPDLOG_INFO("total_kvcache_pages is auto derived as {}", max_total_kvcache_pages); + } + + if (page_size % 256 != 0) { + SPDLOG_ERROR("page_size {} is not divisible by 256", page_size); + assert(false); + } + if (page_size < 256) { + SPDLOG_ERROR("page_size {} is smaller than 256", page_size); + assert(false); + } +} + +std::string BatchQueryTodo::debug() { + std::string re = "BatchQueryTodo: "; + re += "QueryIDs: "; + for (auto& id : query_ids) { + re += std::to_string(id) + " "; + } + return re; +} + +bool BatchQueryTodo::empty() { + return prefill_mini_batches.empty() && decode_mini_batches.empty(); +} + +struct QueryMaintainer; + +struct Query { + QueryID id; + torch::Tensor query_token; + TokenLength prompt_length; + TokenLength no_kvcache_from; + TokenLength estimated_length; + + SampleOptions sample_options; + + UserID user_id; + std::optional SLO_TTFT_ms; + std::optional SLO_TBT_ms; + + std::vector> stop_criteria; + + // status + // Query status changed by this order + enum Status { Received, Preparing, Ready, Prefill, Decode, Done }; + Status plan_status = Received; + TokenLength active_position; // the position where no kvcache now + TokenLength plan_position; // the position where no kvcache now, in plan + size_t prepare_try_count = 0; + std::shared_ptr kvc2_handle = nullptr; + + // derived from kvc2_handle + torch::Tensor block_index; // block indexes + + struct QueryContext { + ModelName model_name; + QuantType quant_type; + kvc2::KVC2Interface* kvc2_interface; + QueryMaintainer* query_maintainer; + Metrics* met; + } ctx; + + void after_load(bool ok); + + void to_status(Status to); + + void export_metrics() { ctx.met->query_count(status_to_string(plan_status))->Increment(1); } + + Query(QueryID id, QueryAdd query_add, QueryContext context) + : id(id), + prompt_length(query_add.query_length), + no_kvcache_from(0), + estimated_length(query_add.estimated_length), + sample_options(query_add.sample_options), + user_id(query_add.user_id), + SLO_TTFT_ms(query_add.SLO_TTFT_ms), + SLO_TBT_ms(query_add.SLO_TBT_ms), + stop_criteria(query_add.stop_criteria), + ctx(context) { + std::vector shape = {int64_t(query_add.estimated_length)}; + query_token = torch::zeros(shape, torch::TensorOptions().dtype(torch::kInt32)); + assert(query_token.is_contiguous()); + if (query_token.is_contiguous() == false) { + SPDLOG_ERROR("Query Token must be contiguous!"); + exit(1); + } + + memcpy(query_token.data_ptr(), query_add.query_token.data(), query_add.query_length * sizeof(Token)); + + no_kvcache_from = 0; // maybe match prefix later + export_metrics(); + } + + Token& token_at(size_t idx) { return reinterpret_cast(query_token.data_ptr())[idx]; } + + void absorb_update(const QueryUpdate& update) { + SPDLOG_DEBUG("{}", update.debug()); + active_position = update.active_position; + kvc2_handle->append_tokens(&token_at(0), active_position); // active_position is length -1 + if (update.is_prefill) { + if (active_position == prompt_length) { + token_at(active_position) = update.generated_token; + ctx.met->generated_tokens->Increment(1); + } + } else { + token_at(active_position) = update.generated_token; + ctx.met->generated_tokens->Increment(1); + } + + if (update.decode_done || active_position == estimated_length - 1) { + to_status(Done); + } + } + + void absorb_prefill_task(const PrefillTask& task) { + auto& [id, start, length] = task; + this->plan_position = start + length; + if (this->plan_position == prompt_length) { + to_status(Decode); + } + } + + void absorb_decode_task([[maybe_unused]] const QueryID& task) { this->plan_position += 1; } + + PrefillTask get_prefill_task(size_t prefill_length) { + if (prefill_length + plan_position > prompt_length) { + prefill_length = prompt_length - plan_position; + } + return {id, plan_position, prefill_length}; + } + + static std::string status_to_string(Status status) { + switch (status) { + case Received: + return "Received"; + case Preparing: + return "Preparing"; + case Ready: + return "Ready"; + case Prefill: + return "Prefill"; + case Decode: + return "Decode"; + case Done: + return "Done"; + } + assert(false); + } + + void debug() { + std::string status_string = status_to_string(plan_status); + + SPDLOG_DEBUG( + "Query {}, prompt_length {}, estimated_length {}, plan status {}, plan position {} " + "active position {}", + id, prompt_length, estimated_length, status_string, plan_position, active_position); + } +}; + +std::string QueryUpdate::debug() const { + return fmt::format("Query {}, ok {}, is_prefill {}, done {}, active_position {}, gen token {}", id, ok, is_prefill, + decode_done, active_position, generated_token); +} + +using Q = std::shared_ptr; + +struct KVC2_Maintainer { + Settings settings; + + std::vector k_cache; + std::vector v_cache; + std::shared_ptr kvc2_interface; + + KVC2_Maintainer(Settings settings) : settings(settings) { + // SPDLOG_WARN("Creating KVC2 Instance {}", settings.kvc2_root_path); + assert(settings.kvc2_root_path.size() > 0); + + // SPDLOG_WARN("Sizeof KVC2Config {} upper", sizeof(kvc2::KVC2Config)); + kvc2::GPUPageCacheConfig gpu_cache_config{ + .gpu_only = settings.gpu_only, + .gpu_devices_id = settings.gpu_device_id, + .layer_count = settings.model_settings.layer_count, + .total_kvcache_pages = settings.total_kvcache_pages.value(), + .num_token_per_page = settings.page_size, + .num_k_heads = settings.model_settings.num_k_heads, + .k_head_dim = + settings.use_self_defined_head_dim ? settings.self_defined_head_dim : settings.model_settings.k_head_dim, + .full_kv_cache_on_each_gpu = settings.full_kv_cache_on_each_gpu, + .k_cache_on = settings.k_cache_on, + .v_cache_on = settings.v_cache_on, + .tensor_type = torch::kBFloat16, + }; + + auto model_configs_path = std::filesystem::path(settings.kvc2_config_path) / "model_configs.json"; + load_model_configs(model_configs_path); + auto my_model_config = ModelConfig(); + my_model_config.load_from(std::filesystem::path(settings.model_settings.model_path) / "config.json"); + model_configs[settings.model_name] = my_model_config; + dump_model_configs(model_configs_path); + + kvc2::KVC2Config kvc2_config = { + .k_cache_on = settings.k_cache_on, + .v_cache_on = settings.v_cache_on, + .gpu_only = settings.gpu_only, + .load_from_disk = settings.load_from_disk, + .save_to_disk = settings.save_to_disk, + .path = settings.kvc2_root_path, + .config_path = settings.kvc2_config_path, + .num_token_per_page = settings.page_size, + .memory_pool_size = size_t(settings.memory_pool_size_GB * 1e9), + .evict_count = settings.evict_count, + .gpu_cache_config = gpu_cache_config, + .metrics_port = settings.kvc2_metrics_port, + }; + kvc2_interface = kvc2::create_kvc2(kvc2_config); + if (settings.load_from_disk) + kvc2_interface->load(); + + SPDLOG_DEBUG("KVC2 created ok"); + + auto [k_cache, v_cache] = kvc2_interface->get_kvcache(); + this->k_cache = k_cache; + this->v_cache = v_cache; + } +}; + +using EventAddQuery = std::pair*>; +using EventUpdateQuery = BatchQueryUpdate; +using EventTakenBatch = std::shared_ptr; +struct EventPrepare { + QueryID query_id; + bool first_try; +}; +struct EventPrepared { + QueryID query_id; + bool ok; +}; + +struct EventQueryStatus{ + QueryID query_id; + Query::Status now_status; +}; +struct EventSchedule {}; + +using Event = std::variant; + +template +std::string event_name(const T& event); + +template <> +std::string event_name(const EventAddQuery&) { + return "EventAddQuery"; +} + +template <> +std::string event_name(const EventUpdateQuery&) { + return "EventUpdateQuery"; +} + +template <> +std::string event_name(const EventTakenBatch&) { + return "EventTakenBatch"; +} +template <> +std::string event_name(const EventPrepare&) { + return "EventPrepare"; +} + +template <> +std::string event_name(const EventPrepared&) { + return "EventPrepared"; +} + +template <> +std::string event_name(const EventQueryStatus&) { + return "EventQueryStatus"; +} + +template <> +std::string event_name(const EventSchedule&) { + return "EventSchedule"; +} + +// 用 std::visit 实现对 variant 的 event_name +std::string event_name(const Event& event) { + return std::visit([](const auto& e) { return event_name(e); }, event); +} + +static_assert(std::is_copy_constructible::value); +static_assert(std::is_move_constructible::value); + +struct QueryMaintainer : public Scheduler { + // only get access by event loop + Settings settings; + QueryID query_id_counter = NoQueryID + 1; + std::map query_map; + std::shared_ptr kvc2_maintainer; + + std::shared_ptr met; + // multi-thread visit + std::atomic_bool stop_flag = false; + // TODO consider correctness of event loop + MPSCQueueConsumerLock event_loop_queue; + + // std::binary_semaphore batch_ready{0}; + AtomicPtrWithFlag next_batch; + + QueryMaintainer() = default; + + void gen_batch_query_todo(BatchQueryTodo* re, const std::set& queries) { + std::vector> d_batch(2); + size_t last_decode_batch = 0; + size_t prefill_num = 0; + size_t decode_num = 0; + size_t preill_length = 0; + for (auto& q : queries) { + if (q->plan_status == Query::Prefill) { + prefill_num += 1; + } + if (q->plan_status == Query::Decode) { + decode_num += 1; + } + } + if (prefill_num >= 2 || (prefill_num ==1 && settings.max_batch_size - 2 < decode_num)) { + preill_length = settings.recommended_chunk_prefill_token_count; + } + else { + preill_length = settings.recommended_chunk_prefill_token_count * 2; + } + for (auto& q : queries) { + re->query_ids.push_back(q->id); + re->query_tokens.push_back(q->query_token); + re->query_lengths.push_back(q->prompt_length); + if (q->plan_status == Query::Prefill) { + re->prefill_mini_batches.push_back(q->get_prefill_task(preill_length)); + assert(re->prefill_mini_batches.size() <= 2); + } + if (q->plan_status == Query::Decode) { + d_batch[last_decode_batch].push_back(q->id); + // last_decode_batch = 1 - last_decode_batch; + if (d_batch[last_decode_batch].size() == settings.max_batch_size - 1) { + last_decode_batch += 1; + assert(last_decode_batch < 2); + } + } + re->block_indexes.push_back(q->block_index); + re->sample_options.push_back(q->sample_options); + re->stop_criteria.push_back(q->stop_criteria); + } + + re->attn_masks = std::nullopt; + re->rope_ranges = std::nullopt; + + for (auto& b : d_batch) { + if (b.empty()) + continue; + re->decode_mini_batches.push_back(b); + } + + met->batch_count("Generated")->Increment(1); + } + + // Interface + + void init(Settings settings) override { + SPDLOG_INFO( + "\nScheduler Settings:\n" + " model_name: {}\n" + " quant_type: {}\n" + " model_path: {}\n" + " params_count: {}\n" + " layer_count: {}\n" + " num_k_heads: {}\n" + " k_head_dim: {}\n" + " bytes_per_params: {}\n" + " bytes_per_kv_cache_element: {}\n" + " page_size: {}\n" + " gpu_device_id: {}\n" + " gpu_memory_size: {}\n" + " memory_utilization_percentage: {}\n" + " max_batch_size: {}\n" + " recommended_chunk_prefill_token_count: {}\n" + " sched_metrics_port: {}\n" + " kvc2_config_path: {}\n" + " kvc2_root_path: {}\n" + " memory_pool_size_GB: {}\n" + " evict_count: {}\n" + " kvc2_metrics_port: {}\n" + " load_from_disk: {}\n" + " save_to_disk: {}\n" + " strategy_name: {}\n" + " gpu_device_count: {}\n", + settings.model_name, settings.quant_type, settings.model_settings.model_path, + settings.model_settings.params_count, settings.model_settings.layer_count, settings.model_settings.num_k_heads, + settings.model_settings.k_head_dim, settings.model_settings.bytes_per_params, + settings.model_settings.bytes_per_kv_cache_element, + + settings.page_size, format_vector(settings.gpu_device_id), readable_number(settings.gpu_memory_size), + settings.memory_utilization_percentage, settings.max_batch_size, settings.recommended_chunk_prefill_token_count, + settings.sched_metrics_port, settings.kvc2_config_path, settings.kvc2_root_path, settings.memory_pool_size_GB, + settings.evict_count, settings.kvc2_metrics_port, settings.load_from_disk, settings.save_to_disk, + settings.strategy_name, settings.gpu_device_count); + + this->settings = settings; + kvc2_maintainer = std::shared_ptr(new KVC2_Maintainer(settings)); + MetricsConfig met_conf = { + .endpoint = "0.0.0.0:" + std::to_string(settings.sched_metrics_port), + .model_name = settings.model_name, + .gpu_count = settings.gpu_device_count, + }; + + SPDLOG_INFO("Creating scheduler metrics exporter on {}", met_conf.endpoint); + met = std::make_shared(met_conf); + met->fn_every_sec = [](Metrics* met) { + auto generated_tokens = met->generated_tokens->Collect().counter.value; + SPDLOG_INFO("Last Sec Generated Tokens {}", generated_tokens); + }; + } + Query::QueryContext get_query_context() { + return Query::QueryContext{ + .model_name = settings.model_name, + .quant_type = settings.quant_type, + .kvc2_interface = kvc2_maintainer->kvc2_interface.get(), + .query_maintainer = this, + .met = met.get(), + }; + } + + QueryID add_query(QueryAdd query_add) override { + std::promise p; + event_loop_queue.enqueue(EventAddQuery(query_add, &p)); + return p.get_future().get(); + } + + void cancel_query(QueryID id) override { + SPDLOG_INFO("Cancel Query"); + SPDLOG_INFO("sched:{} Cancel Query", fmt::ptr(this)); + auto it = query_map.find(id); + if (it == query_map.end()) { + SPDLOG_ERROR("Query {} is not found", id); + return; + } + query_map.erase(it); + } + + // Here this function update last batch results and get the next batch + // in most cases, the batch is ready, + // if not, busy wait to get it + std::shared_ptr update_last_batch(BatchQueryUpdate updates) override { + event_loop_queue.enqueue(updates); + + // Busy Wait + while (true) { + auto [ptr, is_new] = next_batch.touch_load(); + // SPDLOG_INFO("ptr {} is_new {}", fmt::ptr(ptr), is_new); + if (is_new) { + // SPDLOG_DEBUG("New Batch {}", fmt::ptr(ptr)); + auto re = std::shared_ptr(ptr); + event_loop_queue.enqueue(re); + return re; + } else { + // // here to busy wait + // SPDLOG_INFO("Not New"); + // using namespace std::chrono_literals; + // std::this_thread::sleep_for(1s); + } + } + } + + InferenceContext get_inference_context() override { + InferenceContext re; + re.k_cache = kvc2_maintainer->k_cache; + re.v_cache = kvc2_maintainer->v_cache; + // kvc2_maintainer->k_cache[0][0][0][0][0][0] = 42; // test whether we pass this to inference loop + return re; + } + + virtual void strategy_add_query(Q new_query) = 0; + virtual void strategy_update_query(const EventUpdateQuery& update) = 0; + virtual void strategy_taken_batch(const EventTakenBatch& batch) = 0; + virtual void strategy_prepare(const EventPrepare& prepare) = 0; + virtual void strategy_prepared(const EventPrepared& prepared) = 0; + virtual void strategy_query_status(const EventQueryStatus& query_status) = 0; + virtual void strategy_schedule(const EventSchedule& event, BatchQueryTodo* new_batch) = 0; + + void tackle_event(EventAddQuery& event) { + auto& query_add = event.first; + QueryID id = query_id_counter; + event.second->set_value(id); + query_id_counter += 1; + Q new_query(new Query(id, query_add, get_query_context())); + query_map[id] = new_query; + SPDLOG_INFO("New Query {} is added", id); + strategy_add_query(new_query); + } + + void tackle_event(const EventUpdateQuery& update) { + // SPDLOG_INFO("Tackle Update Query"); + for (auto& u : update) { + if (u.ok == false) { + SPDLOG_ERROR("Query {} is not exectued OK", u.id); + exit(1); + } + auto q = query_map[u.id]; + if (q->plan_status == Query::Status::Prefill || q->plan_status == Query::Status::Decode) { + q->absorb_update(u); + } else { + SPDLOG_DEBUG("Query {} is not in Prefill or Decode status, do not update it", u.id); + } + } + strategy_update_query(update); + } + + void tackle_event(const EventTakenBatch& batch) { + met->batch_count("Taken")->Increment(1); + for (auto& task : batch->prefill_mini_batches) { + auto [id, s, l] = task; + if (l == 0) + continue; + query_map.at(id)->absorb_prefill_task(task); + } + for (auto& mini_batch : batch->decode_mini_batches) { + for (auto& id : mini_batch) { + query_map.at(id)->absorb_decode_task(id); + } + } + + strategy_taken_batch(batch); + } + + void tackle_event(const EventPrepare& event) { strategy_prepare(event); } + void tackle_event(const EventPrepared& event) { strategy_prepared(event); } + void tackle_event(const EventQueryStatus& event) { strategy_query_status(event); } + + void tackle_event(const EventSchedule& event) { + // SPDLOG_INFO("Tackle Schedule Event"); + + HistogramTimerWrapper t(met->schedule_time); + + BatchQueryTodo* new_batch = new BatchQueryTodo; + strategy_schedule(event, new_batch); + // if (new_batch->query_ids.empty()) { + // SPDLOG_INFO("Nothing todo"); + // delete new_batch; + // return; + // } + auto [old_batch, flag] = next_batch.exchange(new_batch, true); + if (new_batch->empty() == false) { + SPDLOG_DEBUG("set new batch {}", fmt::ptr(new_batch)); + } + if (flag) { + SPDLOG_INFO("Batch {} is not consumed", fmt::ptr(old_batch)); + delete old_batch; + } + } + + void run() override { + std::thread([this]() { + SPDLOG_WARN("Starting Scheduler Event Loop"); + while (stop_flag.load() == false) { + auto event = event_loop_queue.dequeue(); + met->event_count(event_name(event))->Increment(1); + std::visit( + [this](auto event) { + using T = std::decay_t; + // SPDLOG_INFO("Event Loop: {}", typeid(T).name()); + if constexpr (std::is_same_v) { + tackle_event(event); + } else if constexpr (std::is_same_v) { + tackle_event(event); + } else if constexpr (std::is_same_v) { + tackle_event(event); + } else if constexpr (std::is_same_v) { + tackle_event(event); + } else if constexpr (std::is_same_v) { + tackle_event(event); + } else if constexpr (std::is_same_v) { + tackle_event(event); + } else if constexpr (std::is_same_v) { + tackle_event(event); + } else { + SPDLOG_ERROR("Should not be here"); + assert(false); + } + }, + event); + if (event_loop_queue.size() == 0 && std::holds_alternative(event) == false) { + // if this is not a schedule event, we need to schedule one + event_loop_queue.enqueue(EventSchedule()); + } + } + }).detach(); + } + + void stop() override { stop_flag.store(true); } + + ~QueryMaintainer() { + kvc2_maintainer->kvc2_interface->save(); + stop(); + } +}; + +void Query::to_status(Status to) { + SPDLOG_DEBUG("Calling to status query {}, to {}", id, status_to_string(to)); + switch (to) { + case Received: + assert(false); + break; + case Preparing: + SPDLOG_INFO("Preparing Query {} {}", id, + prepare_try_count > 0 ? (std::to_string(prepare_try_count) + " Try") : ""); + prepare_try_count += 1; + + ctx.kvc2_interface->lookup_to_gpu_async( + ctx.model_name, ctx.quant_type, static_cast(query_token.data_ptr()), prompt_length, + estimated_length, [this](std::shared_ptr handle) { + if (handle == nullptr) { + SPDLOG_INFO("Get handle from kvc2 Failed."); + this->after_load(false); + } else { + SPDLOG_INFO("Get handle from kvc2 Success."); + this->kvc2_handle = handle; + this->to_status(Ready); + this->after_load(true); + } + }); + break; + case Ready: + SPDLOG_INFO("Ready Query {}", id); + break; + case Prefill: + SPDLOG_INFO("Prefilling Query {}", id); + // assert(plan_status == Received); + plan_position = kvc2_handle->matched_length(); + + if (prompt_length - plan_position == 0) { + assert(prompt_length > 0); + plan_position -= 1; + } + break; + case Decode: + SPDLOG_INFO("Decoding Query {}", id); + // assert(plan_status == Prefill); + break; + case Done: + SPDLOG_INFO("Finish Query {}", id); + kvc2_handle = nullptr; + ctx.query_maintainer->event_loop_queue.enqueue(EventQueryStatus{ + .query_id = id, + .now_status = to, + }); + // assert(plan_status == Decode); + break; + } + plan_status = to; + export_metrics(); +} + +void Query::after_load(bool ok) { + if (ok) { + size_t page_count = div_up(estimated_length, ctx.query_maintainer->settings.page_size); + std::vector shape; + shape.push_back(page_count); + block_index = torch::zeros(shape, torch::TensorOptions().dtype(torch::kInt32)).contiguous(); + auto ptr = reinterpret_cast(block_index.data_ptr()); + auto vec_idx = kvc2_handle->get_gpu_block_idx(); + for (size_t i = 0; i < vec_idx.size(); i++) { + ptr[i] = vec_idx[i]; + } + no_kvcache_from = kvc2_handle->matched_length(); + } + if (ok) { + ctx.query_maintainer->event_loop_queue.enqueue(EventPrepared{ + .query_id = id, + .ok = ok, + }); + } else { + ctx.query_maintainer->event_loop_queue.enqueue(EventPrepare{ + .query_id = id, + .first_try = false, + }); + } +} + +struct FCFS_single_prefill : public QueryMaintainer { + std::queue queue; + std::queue ready_queue; + + bool has_query_preparing = false; + std::optional wait_done_prepare = std::nullopt; + + std::set active_query; // on going queries for LLMs + + // interface all these are executed in a single thread + void strategy_add_query(Q new_query) override { + queue.push(new_query); + if (has_query_preparing == false) { + has_query_preparing = true; + auto next_q = queue.front(); + queue.pop(); + event_loop_queue.enqueue(EventPrepare{next_q->id,true}); + } + } + + void strategy_update_query(const EventUpdateQuery& update) override { + for (auto u : update) { + auto& q = query_map[u.id]; + if (q->plan_status == Query::Done) { + active_query.erase(q); + } + } + } + + void strategy_taken_batch(const EventTakenBatch& batch) override { + for (auto& q : batch->query_ids) { + if (query_map[q]->plan_status != Query::Done) { + active_query.insert(query_map[q]); + } + } + } + + void strategy_prepare(const EventPrepare& prepare) override { + if(prepare.first_try){ + auto& q = query_map[prepare.query_id]; + q->to_status(Query::Preparing); + }else{ + assert(wait_done_prepare.has_value()==false); + wait_done_prepare = prepare; + wait_done_prepare->first_try = true; + } + } + + void strategy_prepared(const EventPrepared& prepared) override { + assert(prepared.ok); + ready_queue.push(query_map[prepared.query_id]); + if (queue.empty() == false) { + auto next_q_prepare = queue.front(); + queue.pop(); + event_loop_queue.enqueue(EventPrepare{next_q_prepare->id,true}); + + } else { + has_query_preparing = false; + } + } + + void strategy_query_status(const EventQueryStatus& query_status) override{ + if(query_status.now_status==Query::Done){ + if(wait_done_prepare.has_value()){ + event_loop_queue.enqueue(wait_done_prepare.value()); + wait_done_prepare = std::nullopt; + } + } + + } + + void strategy_schedule([[maybe_unused]] const EventSchedule& event, BatchQueryTodo* new_batch) override { + bool have_prefill = false; + for (auto& q : active_query) { + if (q->plan_status == Query::Prefill) { + have_prefill = true; + } + } + + if (have_prefill == false && ready_queue.empty() == false && active_query.size() < settings.max_batch_size) { + auto& next_q = ready_queue.front(); + ready_queue.pop(); + + SPDLOG_INFO("Active query {}", next_q->id); + active_query.insert(next_q); + next_q->to_status(Query::Prefill); + } + if (active_query.empty() == false) + SPDLOG_INFO("Active Query Size {}", active_query.size()); + for (auto& q : active_query) { + q->debug(); + } + gen_batch_query_todo(new_batch, active_query); + } +}; + +struct FCFS : public FCFS_single_prefill { + void strategy_schedule([[maybe_unused]] const EventSchedule& event, BatchQueryTodo* new_batch) override { + int prefill_count = 0; + const int max_prefill_count = 2; + for (auto& q : active_query) { + if (q->plan_status == Query::Prefill) { + prefill_count += 1; + } + } + + while (prefill_count < max_prefill_count && ready_queue.empty() == false && + active_query.size() < settings.max_batch_size) { + auto next_q = ready_queue.front(); + ready_queue.pop(); + + SPDLOG_INFO("Active query {}", next_q->id); + active_query.insert(next_q); + next_q->to_status(Query::Prefill); + prefill_count += 1; + } + if (active_query.empty() == false) { + SPDLOG_DEBUG("Active Query Size {}", active_query.size()); + } + for (auto& q : active_query) { + q->debug(); + } + gen_batch_query_todo(new_batch, active_query); + } +}; + +std::shared_ptr create_scheduler(Settings settings) { + spdlog::set_level(spdlog::level::debug); + std::shared_ptr re; + SPDLOG_INFO("Using Strategy {}", settings.strategy_name); + if (settings.strategy_name == "FCFS-single-prefill") { + re = std::shared_ptr(new FCFS_single_prefill()); + } else if (settings.strategy_name == "FCFS") { + re = std::shared_ptr(new FCFS()); + } else { + SPDLOG_ERROR("Unknown strategy {}", settings.strategy_name); + } + re->init(settings); + return re; +} + +NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(SampleOptions, temperature, top_p); +NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(QueryAdd, query_token, query_length, estimated_length, sample_options, user_id, + SLO_TTFT_ms, SLO_TBT_ms); + +std::string QueryAdd::serialize() { + json j = *this; + return j.dump(); +} + +QueryAdd QueryAdd::deserialize(const std::string& input) { + json j = json::parse(input); + return j.get(); +} + +}; // namespace scheduler diff --git a/csrc/balance_serve/sched/scheduler.h b/csrc/balance_serve/sched/scheduler.h new file mode 100644 index 0000000..0889c08 --- /dev/null +++ b/csrc/balance_serve/sched/scheduler.h @@ -0,0 +1,167 @@ +#pragma once +#include +#include +#include +#include +#include +#include "model_config.h" + +namespace scheduler { + +using Token = uint32_t; +using QueryID = uint64_t; +constexpr QueryID NoQueryID = 0; + +using TokenLength = size_t; +using BatchID = uint64_t; + +using PageCount = size_t; + +struct ModelSettings { + std::string model_path; + size_t params_count; + size_t layer_count; + size_t num_k_heads; + size_t k_head_dim; + + double bytes_per_params; + double bytes_per_kv_cache_element; + + inline size_t params_nbytes() { return params_count * bytes_per_params; } + inline size_t bytes_per_token_kv_cache() { return bytes_per_kv_cache_element * num_k_heads * k_head_dim; } +}; + +struct SampleOptions { + double temperature = 1.0; + double top_p = 1.0; +}; + +struct Settings { + // something is aukward here, kvc2 only use model_name and quant_type to get model infos. + ModelName model_name; + QuantType quant_type; + // model_setting is ignore by kvc2 + ModelSettings model_settings; + + size_t page_size = 256; // how many token in a page + std::vector gpu_device_id; // + size_t gpu_memory_size; // memory size in bytes of each GPU, each + double memory_utilization_percentage; + + size_t max_batch_size = 256; + + size_t recommended_chunk_prefill_token_count; + SampleOptions sample_options; + size_t sched_metrics_port; + + // for kvc2 + bool gpu_only; + bool use_self_defined_head_dim = false; + size_t self_defined_head_dim; + bool full_kv_cache_on_each_gpu = false; + bool k_cache_on = true; + bool v_cache_on = true; + std::string kvc2_config_path; + std::string kvc2_root_path; + double memory_pool_size_GB = 100; + size_t evict_count = 20; + size_t kvc2_metrics_port; + bool load_from_disk = false; + bool save_to_disk = false; + + // for strategy + std::string strategy_name; + + // derived + size_t gpu_device_count; + std::optional total_kvcache_pages; + std::vector devices; + void auto_derive(); +}; + +using PrefillTask = std::tuple; // id, start, length + +struct BatchQueryTodo { + // query + std::vector query_ids; + std::vector query_tokens; + std::vector query_lengths; + std::vector block_indexes; // (max_num_blocks_per_seq), dtype torch.int32. + std::optional attn_masks; + std::optional rope_ranges; + std::vector sample_options; + std::vector>> stop_criteria; + + // mini batches, adjacent two mini batches are executed together + // tasks count must be <=2, because of flash infer attention + std::vector prefill_mini_batches; // prefill minibatch only has 1 prefill + std::vector> decode_mini_batches; // decode minibatch has multiple decode + + std::string debug(); + bool empty(); +}; + +struct QueryUpdate { + QueryID id; + bool ok; + bool is_prefill; + bool decode_done; // no use for now + TokenLength active_position; // the position where no kvcache now, + // kvcache[active_position] == None + + Token generated_token; + + std::string debug() const; +}; + +using BatchQueryUpdate = std::vector; + +struct InferenceContext { + std::vector k_cache; // [gpu num] (layer_count, num blocks, + // page size, kheadnum, head_dim) + std::vector v_cache; +}; + +using UserID = int64_t; +constexpr UserID NoUser = -1; +const int MAX_SLO_TIME = 1e9; + +struct QueryAdd { + std::vector query_token; // int here + // torch::Tensor attn_mask; + TokenLength query_length; + TokenLength estimated_length; + + std::vector> stop_criteria; + + SampleOptions sample_options; + + UserID user_id; + int SLO_TTFT_ms = MAX_SLO_TIME; + int SLO_TBT_ms = MAX_SLO_TIME; + + std::string serialize(); + static QueryAdd deserialize(const std::string& input); +}; + +class Scheduler { + public: + virtual void init(Settings settings) = 0; + + virtual void run() = 0; + virtual void stop() = 0; + + // webserver call this + virtual QueryID add_query(QueryAdd query) = 0; + virtual void cancel_query(QueryID id) = 0; + + // inference loop call this + virtual std::shared_ptr update_last_batch(BatchQueryUpdate updates) = 0; + virtual InferenceContext get_inference_context() = 0; + + virtual ~Scheduler() = default; +}; + +std::shared_ptr create_scheduler(Settings settings); + +}; // namespace scheduler \ No newline at end of file diff --git a/csrc/balance_serve/sched/utils/all.hpp b/csrc/balance_serve/sched/utils/all.hpp new file mode 100644 index 0000000..903bc40 --- /dev/null +++ b/csrc/balance_serve/sched/utils/all.hpp @@ -0,0 +1,3 @@ +#pragma once +#include "readable_number.hpp" +#include "timer.hpp" \ No newline at end of file diff --git a/csrc/balance_serve/sched/utils/arithmetic.hpp b/csrc/balance_serve/sched/utils/arithmetic.hpp new file mode 100644 index 0000000..7562f56 --- /dev/null +++ b/csrc/balance_serve/sched/utils/arithmetic.hpp @@ -0,0 +1,8 @@ +#include + +template +T div_up(T x, U by) { + static_assert(std::is_integral_v); + static_assert(std::is_integral_v); + return (x + by - 1) / by; +} \ No newline at end of file diff --git a/csrc/balance_serve/sched/utils/atomic_ptr_with_flags.hpp b/csrc/balance_serve/sched/utils/atomic_ptr_with_flags.hpp new file mode 100644 index 0000000..f0c98bf --- /dev/null +++ b/csrc/balance_serve/sched/utils/atomic_ptr_with_flags.hpp @@ -0,0 +1,28 @@ +#include + +template +struct AtomicPtrWithFlag { + constexpr static uint64_t mask = 1ull << 63; + std::atomic_uint64_t ptr = 0; + + std::pair load(std::memory_order order = std::memory_order_seq_cst) { + uint64_t val = ptr.load(order); + return {reinterpret_cast(val & (~mask)), val & mask}; + } + + void store(T* p, bool flag, std::memory_order order = std::memory_order_seq_cst) { + ptr.store(reinterpret_cast(p) | (flag ? mask : 0), order); + } + + std::pair exchange(T* p, bool flag, std::memory_order order = std::memory_order_seq_cst) { + uint64_t val = ptr.exchange(reinterpret_cast(p) | (flag ? mask : 0), order); + return {reinterpret_cast(val & (~mask)), val & mask}; + } + + std::pair touch_load(std::memory_order order = std::memory_order_seq_cst) { + uint64_t val = ptr.fetch_and(~mask, order); + return {reinterpret_cast(val & (~mask)), val & mask}; + } + + bool check_flag(std::memory_order order = std::memory_order_seq_cst) { return ptr.load(order) & mask; } +}; diff --git a/csrc/balance_serve/sched/utils/csv.hpp b/csrc/balance_serve/sched/utils/csv.hpp new file mode 100644 index 0000000..d9dd558 --- /dev/null +++ b/csrc/balance_serve/sched/utils/csv.hpp @@ -0,0 +1,225 @@ +#ifndef CSV_READER_HPP +#define CSV_READER_HPP + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace csv { + +/** + * @brief Parses a CSV line into individual fields, handling quoted fields with + * commas and newlines. + * + * @param line The CSV line to parse. + * @return A vector of strings, each representing a field in the CSV line. + */ +inline std::vector parse_csv_line(const std::string& line) { + std::vector result; + std::string field; + bool in_quotes = false; + + for (size_t i = 0; i < line.length(); ++i) { + char c = line[i]; + + if (c == '"') { + // Handle double quotes inside quoted fields + if (in_quotes && i + 1 < line.length() && line[i + 1] == '"') { + field += '"'; + ++i; + } else { + in_quotes = !in_quotes; + } + } else if (c == ',' && !in_quotes) { + result.push_back(field); + field.clear(); + } else { + field += c; + } + } + result.push_back(field); + return result; +} + +/** + * @brief Reads a CSV file and returns a vector of pairs containing column names + * and their corresponding data vectors. + * + * This function reads the header to obtain column names and uses multithreading + * to read and parse the CSV file in chunks. + * + * @param filename The path to the CSV file. + * @return A vector of pairs, each containing a column name and a vector of data + * for that column. + */ +inline std::vector>> read_csv(const std::string& filename) { + std::cout << "Reading CSV file: " << filename << std::endl; + // Open the file + std::ifstream file(filename); + if (!file) { + throw std::runtime_error("Cannot open file"); + } + + // Read the header line and parse column names + std::string header_line; + std::getline(file, header_line); + std::vector column_names = parse_csv_line(header_line); + + // Prepare the result vector with column names + std::vector>> result; + for (const auto& name : column_names) { + result.emplace_back(name, std::vector()); + } + + // Read the rest of the file into a string buffer + std::stringstream buffer; + buffer << file.rdbuf(); + std::string content = buffer.str(); + + // Determine the number of threads to use + unsigned int num_threads = std::thread::hardware_concurrency(); + if (num_threads == 0) + num_threads = 4; // Default to 4 threads if hardware_concurrency returns 0 + + // Calculate chunk start positions based on content size + std::vector chunk_starts; + size_t content_size = content.size(); + size_t chunk_size = content_size / num_threads; + + chunk_starts.push_back(0); + for (unsigned int i = 1; i < num_threads; ++i) { + size_t pos = i * chunk_size; + // Adjust position to the next newline character to ensure we start at the + // beginning of a line + while (pos < content_size && content[pos] != '\n') { + ++pos; + } + if (pos < content_size) { + ++pos; // Skip the newline character + } + chunk_starts.push_back(pos); + } + chunk_starts.push_back(content_size); + + // Create threads to parse each chunk + std::vector>> thread_results(num_threads); + std::vector threads; + + for (unsigned int i = 0; i < num_threads; ++i) { + size_t start = chunk_starts[i]; + size_t end = chunk_starts[i + 1]; + + threads.emplace_back([&content, start, end, &thread_results, i]() { + std::vector> local_result; + size_t pos = start; + while (pos < end) { + size_t next_pos = content.find('\n', pos); + if (next_pos == std::string::npos || next_pos > end) { + next_pos = end; + } + std::string line = content.substr(pos, next_pos - pos); + if (!line.empty()) { + local_result.push_back(parse_csv_line(line)); + } + pos = next_pos + 1; + } + thread_results[i] = std::move(local_result); + }); + } + + // Wait for all threads to finish + for (auto& t : threads) { + t.join(); + } + + // Combine the results from all threads into the final result + for (const auto& local_result : thread_results) { + for (const auto& row : local_result) { + for (size_t i = 0; i < row.size(); ++i) { + if (i < result.size()) { + result[i].second.push_back(row[i]); + } + } + } + } + + return result; +} + +/** + * @brief Writes the CSV data into a file. + * + * @param filename The path to the output CSV file. + * @param data A vector of pairs, each containing a column name and a vector of + * data for that column. + */ +inline void write_csv(const std::string& filename, + const std::vector>>& data) { + std::cout << "Writing CSV file: " << filename << std::endl; + + // Open the file for writing + std::ofstream file(filename); + if (!file) { + throw std::runtime_error("Cannot open file for writing"); + } + + // Check that all columns have the same number of rows + if (data.empty()) { + return; // Nothing to write + } + size_t num_rows = data[0].second.size(); + for (const auto& column : data) { + if (column.second.size() != num_rows) { + throw std::runtime_error("All columns must have the same number of rows"); + } + } + + // Write the header + for (size_t i = 0; i < data.size(); ++i) { + file << data[i].first; + if (i != data.size() - 1) { + file << ','; + } + } + file << '\n'; + + // Write the data rows + for (size_t row = 0; row < num_rows; ++row) { + for (size_t col = 0; col < data.size(); ++col) { + const std::string& field = data[col].second[row]; + // Handle CSV escaping + std::string escaped_field = field; + bool needs_quotes = false; + if (escaped_field.find('"') != std::string::npos) { + needs_quotes = true; + // Escape double quotes + size_t pos = 0; + while ((pos = escaped_field.find('"', pos)) != std::string::npos) { + escaped_field.insert(pos, "\""); + pos += 2; + } + } + if (escaped_field.find(',') != std::string::npos || escaped_field.find('\n') != std::string::npos) { + needs_quotes = true; + } + if (needs_quotes) { + file << '"' << escaped_field << '"'; + } else { + file << escaped_field; + } + if (col != data.size() - 1) { + file << ','; + } + } + file << '\n'; + } +} + +} // namespace csv + +#endif // CSV_READER_HPP diff --git a/csrc/balance_serve/sched/utils/easy_format.hpp b/csrc/balance_serve/sched/utils/easy_format.hpp new file mode 100644 index 0000000..d541410 --- /dev/null +++ b/csrc/balance_serve/sched/utils/easy_format.hpp @@ -0,0 +1,16 @@ +#include +#include +#include + +template +std::string format_vector(const std::vector& v) { + std::ostringstream oss; + if (v.empty()) + return "[]"; + for (size_t i = 0; i < v.size(); ++i) { + oss << v[i]; + if (i < v.size() - 1) + oss << ", "; // 逗号分隔 + } + return oss.str(); +} diff --git a/csrc/balance_serve/sched/utils/mpsc.hpp b/csrc/balance_serve/sched/utils/mpsc.hpp new file mode 100644 index 0000000..4a3b476 --- /dev/null +++ b/csrc/balance_serve/sched/utils/mpsc.hpp @@ -0,0 +1,109 @@ +#include +#include +#include +#include +#include + +template +class MPSCQueue { + struct Node { + T data; + std::atomic next; + + Node() : next(nullptr) {} + Node(T data_) : data(std::move(data_)), next(nullptr) {} + }; + + std::atomic head; + Node* tail; + + public: + std::atomic_size_t enqueue_count = 0; + size_t dequeue_count = 0; + MPSCQueue() { + Node* dummy = new Node(); + head.store(dummy, std::memory_order_seq_cst); + tail = dummy; + } + + ~MPSCQueue() { + Node* node = tail; + while (node) { + Node* next = node->next.load(std::memory_order_seq_cst); + delete node; + node = next; + } + } + + // 生产者调用 + void enqueue(T data) { + enqueue_count.fetch_add(1); + Node* node = new Node(std::move(data)); + Node* prev_head = head.exchange(node, std::memory_order_seq_cst); + prev_head->next.store(node, std::memory_order_seq_cst); + } + + // 消费者调用 + std::optional dequeue() { + Node* next = tail->next.load(std::memory_order_seq_cst); + if (next) { + T res = std::move(next->data); + delete tail; + tail = next; + dequeue_count += 1; + return res; + } + return std::nullopt; + } + + size_t size() { return enqueue_count.load() - dequeue_count; } +}; + +template +class MPSCQueueConsumerLock { + MPSCQueue queue; + std::counting_semaphore<> sema{0}; + + public: + void enqueue(T data) { + queue.enqueue(std::move(data)); + // std::atomic_thread_fence(std::memory_order_seq_cst);// Inserting this because the memory order might be wrong, I + // am also not that sure about this. + sema.release(); + } + + T dequeue() { + auto re = queue.dequeue(); + if (re.has_value()) { + while (sema.try_acquire() == false) { + std::cerr << __FILE__ << ":" << __FUNCTION__ << " sema try acquire should be success, retrying, please check" + << std::endl; + // assert(false); + } + return re.value(); + } + sema.acquire(); + return queue.dequeue().value(); + } + + template + std::optional try_dequeue_for(std::chrono::duration dur) { + auto re = queue.dequeue(); + if (re.has_value()) { + while (sema.try_acquire() == false) { + std::cerr << __FILE__ << ":" << __FUNCTION__ << " sema try acquire should be success, retrying, please check" + << std::endl; + // assert(false); + } + return re.value(); + } + + if (sema.try_acquire_for(dur)) { + return queue.dequeue().value(); + } else { + return std::nullopt; + } + } + + size_t size() { return queue.size(); } +}; diff --git a/csrc/balance_serve/sched/utils/readable_number.hpp b/csrc/balance_serve/sched/utils/readable_number.hpp new file mode 100644 index 0000000..94de923 --- /dev/null +++ b/csrc/balance_serve/sched/utils/readable_number.hpp @@ -0,0 +1,20 @@ +#pragma once +#include +#include +#include +#include + +inline std::array units = {"", "K", "M", "G", "T", "P", "E"}; + +inline std::string readable_number(size_t size) { + size_t unit_index = 0; + double readable_size = size; + while (readable_size >= 1000 && unit_index < units.size() - 1) { + readable_size /= 1000; + unit_index++; + } + std::ostringstream ss; + ss << std::fixed << std::setprecision(2) << readable_size; + std::string str = ss.str(); + return str + "" + units[unit_index]; +} \ No newline at end of file diff --git a/csrc/balance_serve/sched/utils/statistics.hpp b/csrc/balance_serve/sched/utils/statistics.hpp new file mode 100644 index 0000000..98e82a7 --- /dev/null +++ b/csrc/balance_serve/sched/utils/statistics.hpp @@ -0,0 +1,65 @@ +#ifndef STATISTICS_HPP +#define STATISTICS_HPP + +#include +#include +#include +#include + +class Statistics { + public: + // Increment the counter for a given key by a specified value (default is 1) + void increment_counter(const std::string& key, int64_t value = 1) { counters_[key] += value; } + + int64_t& get_counter(const std::string& key) { return counters_[key]; } + + // Start the timer for a given key + void start_timer(const std::string& key) { active_timers_[key] = std::chrono::high_resolution_clock::now(); } + + // Stop the timer for a given key and update the total time and count + void stop_timer(const std::string& key) { + auto start_it = active_timers_.find(key); + if (start_it != active_timers_.end()) { + auto duration = std::chrono::high_resolution_clock::now() - start_it->second; + timings_[key].total_time += duration; + timings_[key].count += 1; + active_timers_.erase(start_it); + } else { + // Handle error: stop_timer called without a matching start_timer + std::cerr << "Warning: stop_timer called for key '" << key << "' without a matching start_timer.\n"; + } + } + + // Print out the collected statistical information + void report() const { + std::cout << "Counters:\n"; + for (const auto& kv : counters_) { + std::cout << " " << kv.first << ": " << kv.second << "\n"; + } + std::cout << "\nTimers:\n"; + for (const auto& kv : timings_) { + std::cout << " " << kv.first << ": count = " << kv.second.count + << ", total_time = " << kv.second.total_time.count() << "s" + << ", average_time = " << (kv.second.count > 0 ? kv.second.total_time.count() / kv.second.count : 0) + << "s\n"; + } + } + + private: + // Mapping from key to counter + std::unordered_map counters_; + + // Struct to hold timing information for a key + struct TimingInfo { + int64_t count = 0; + std::chrono::duration total_time = std::chrono::duration::zero(); + }; + + // Mapping from key to timing information + std::unordered_map timings_; + + // Mapping from key to the start time of active timers + std::unordered_map active_timers_; +}; + +#endif // STATISTICS_HPP diff --git a/csrc/balance_serve/sched/utils/timer.hpp b/csrc/balance_serve/sched/utils/timer.hpp new file mode 100644 index 0000000..7ec4fc5 --- /dev/null +++ b/csrc/balance_serve/sched/utils/timer.hpp @@ -0,0 +1,128 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "readable_number.hpp" + +inline std::string doubleToStringR2(double value) { + std::stringstream stream; + stream << std::fixed << std::setprecision(2) << value; + return stream.str(); +} + +class Timer { + public: + std::string name; + bool tmp_timer = false; + + Timer() {} + Timer(std::string name) : name(name), tmp_timer(true) { start(); } + ~Timer() { + if (tmp_timer) { + std::cout << name << " " << elapsedMs() << " ms" << std::endl; + } + } + + void start() { + m_startTime = std::chrono::high_resolution_clock::now(); + assert(m_isRunning == false); + m_isRunning = true; + } + + void stop() { + m_endTime = std::chrono::high_resolution_clock::now(); + assert(m_isRunning == true); + m_isRunning = false; + m_runningNs += elapsedNs(); + } + + double elapsedNs() { + std::chrono::time_point endTime; + + if (m_isRunning) { + endTime = std::chrono::high_resolution_clock::now(); + } else { + endTime = m_endTime; + } + + return std::chrono::duration_cast(endTime - m_startTime).count(); + } + + void printElapsedMilliseconds() { std::cout << elapsedNs() / 1e6 << " ms" << std::endl; } + + static std::string ns_to_string(double duration) { + auto nano_sec = duration; + if (nano_sec >= 1000) { + auto mirco_sec = nano_sec / 1000.0; + if (mirco_sec >= 1000) { + auto milli_sec = mirco_sec / 1000.0; + if (milli_sec >= 1000) { + auto seconds = milli_sec / 1000.0; + + if (seconds >= 60.0) { + auto minutes = seconds / 60.0; + + if (minutes >= 60.0) { + auto hours = minutes / 60.0; + return doubleToStringR2(hours) + " h"; + } else { + return doubleToStringR2(minutes) + " min"; + } + } else { + return doubleToStringR2(seconds) + " sec"; + } + } else { + return doubleToStringR2(milli_sec) + " ms"; + } + } else { + return doubleToStringR2(mirco_sec) + " us"; + } + } else { + return doubleToStringR2(nano_sec) + " ns"; + } + } + + double runningTimeNs() { return m_runningNs; } + + std::string runningTime() { + auto duration = m_runningNs; + return ns_to_string(duration); + } + + std::string elapsedTime() { return ns_to_string(elapsedNs()); } + double elapsedMs() { return elapsedNs() / 1e6; } + std::string report_throughput(size_t op_cnt) { + double ops = op_cnt / elapsedMs() * 1000; + return readable_number(ops) + "op/s"; + } + + void merge(Timer& other) { + assert(m_isRunning == false); + assert(other.m_isRunning == false); + m_runningNs += other.runningTimeNs(); + } + + private: + std::chrono::time_point m_startTime; + std::chrono::time_point m_endTime; + bool m_isRunning = false; + double m_runningNs = 0.0; +}; + +class Counter { + public: + Counter() {} + + std::map counters; + + void inc(const char* name, size_t num) { counters[name] += num; }; + void print() { + for (auto& p : counters) { + std::cout << p.first << " : " << p.second << std::endl; + } + }; +}; diff --git a/csrc/custom_marlin/__init__.py b/csrc/custom_marlin/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/csrc/custom_marlin/binding.cpp b/csrc/custom_marlin/binding.cpp new file mode 100644 index 0000000..184f3e2 --- /dev/null +++ b/csrc/custom_marlin/binding.cpp @@ -0,0 +1,44 @@ +/** + * @Description : + * @Author : Azure-Tang + * @Date : 2024-07-25 13:38:30 + * @Version : 1.0.0 + * @LastEditors : kkk1nak0 + * @LastEditTime : 2024-08-12 03:05:04 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + **/ + +#include "gptq_marlin/ops.h" +// Python bindings +#include +#include +#include +#include +#include +// namespace py = pybind11; + +PYBIND11_MODULE(vLLMMarlin, m) { + + /*m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 + data.", py::arg("data"), py::arg("blk_size"), py::arg("device")); + m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k + data.", py::arg("data"), py::arg("blk_size"), py::arg("device")); + m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k + data.", py::arg("data"), py::arg("blk_size"), py::arg("device")); + m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k + data.", py::arg("data"), py::arg("blk_size"), py::arg("device")); + m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k + data.", py::arg("data"), py::arg("blk_size"), py::arg("device")); + m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k + data.", py::arg("data"), py::arg("blk_size"), py::arg("device")); + m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize + iq4_xs data.", py::arg("data"), py::arg("blk_size"), py::arg("device"));*/ + m.def("gptq_marlin_gemm", &gptq_marlin_gemm, + "Function to perform GEMM using Marlin quantization.", py::arg("a"), + py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"), + py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m_tensor"), + py::arg("size_m"), py::arg("size_n"), py::arg("size_k"), + py::arg("sms"), py::arg("is_k_full")); + m.def("gptq_marlin_repack", &gptq_marlin_repack, + "gptq_marlin repack from GPTQ"); +} \ No newline at end of file diff --git a/csrc/custom_marlin/gptq_marlin/gptq_marlin.cu b/csrc/custom_marlin/gptq_marlin/gptq_marlin.cu new file mode 100644 index 0000000..3ecaeb0 --- /dev/null +++ b/csrc/custom_marlin/gptq_marlin/gptq_marlin.cu @@ -0,0 +1,2034 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + /* + * Adapted from https://github.com/IST-DASLab/marlin + */ + /* + * Adapted from + * https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin + */ +#include "gptq_marlin.cuh" +#include "gptq_marlin_dtypes.cuh" +#include +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || \ + std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +template inline std::string str(T x) { return std::to_string(x); } + +namespace gptq_marlin { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + + __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) {} + + template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > + __global__ void + Marlin(const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization + ) {} + +} // namespace gptq_marlin + +torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& g_idx, + torch::Tensor& perm, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full) { + TORCH_CHECK_NOT_IMPLEMENTED(false, + "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({ 1, 1 }); +} + +#else + + // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 + // output/accumulation. + template + __device__ inline void mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } + else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } + else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } + } + + // Instruction for loading a full 16x16 matrix fragment of operand A from shared + // memory, directly in tensor core layout. + template + __device__ inline void ldsm4(typename ScalarType::FragA& frag_a, + const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); + } + + // Lookup-table based 3-input logical operation; explicitly used for + // dequantization as the compiler does not seem to automatically recognize it in + // all cases. + template __device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; + } + + // Constructs destination register by taking bytes from 2 sources (based on + // mask) + template + __device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; + } + + // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 + // values. We mostly follow the strategy in the link below, with some small + // changes: + // - FP16: + // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 + // - BF16: + // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 + template + __device__ inline typename ScalarType::FragB dequant_4bit(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } + + template <> + __device__ inline typename ScalarType::FragB dequant_4bit(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + typename ScalarType::FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; + } + + template <> + __device__ inline typename ScalarType::FragB + dequant_4bit(int q) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + typename ScalarType::FragB frag_b; + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC308C308; + + frag_b[0] = __hfma2(*reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; + } + + // Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or + // bf16 Reference: + // - FP16: + // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 + // - BF16: + // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 + template + __device__ inline typename ScalarType::FragB dequant_8bit(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } + + template <> + __device__ inline typename ScalarType::FragB dequant_8bit(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + typename ScalarType::FragB frag_b; + frag_b[0] = + __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = + __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; + } + + template <> + __device__ inline typename ScalarType::FragB + dequant_8bit(int q) { + typename ScalarType::FragB frag_b; + + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); + + return frag_b; + } + + // Multiply dequantized values by the corresponding quantization scale; used + // only for grouped quantization. + template + __device__ inline void scale(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = ScalarType::num2num2( + reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); + } + + // Same as above, but for act_order (each K is multiplied individually) + template + __device__ inline void scale4(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s_1, + typename ScalarType::FragS& frag_s_2, + typename ScalarType::FragS& frag_s_3, + typename ScalarType::FragS& frag_s_4, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s_val_1_2; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + + scalar_t2 s_val_3_4; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); + } + + // Given 2 floats multiply by 2 scales (halves) + template + __device__ inline void scale_float(float* c, + typename ScalarType::FragS& s) { + scalar_t* s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); + } + + // Wait until barrier reaches `count`, then lock for current threadblock. + __device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be + // visible globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); + } + + // Release barrier and increment visitation count. + __device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } + } + + // For a given "a" of size [M,K] performs a permutation of the K columns based + // on the given "perm" indices. + __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + int start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int offset = row * row_stride; + + half const* a_row_half = + reinterpret_cast(a_int4_ptr + offset); + half* out_half = reinterpret_cast(out_int4_ptr + offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } + } + + template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > + __device__ void + Marlin(const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m, should be divisible by (16 * thread_m_blocks) if bigger than that + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization + ) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + + constexpr int pack_factor = 32 / num_bits; + + // int prob_m = *prob_m_ptr; + // const int thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks); + // constexpr int thread_m_blocks = template_thread_m_blocks; + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } + else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + } + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + { + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_s = sh_g_idx + (stages * g_idx_stage); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + + // Zero accumulators. + auto zero_accums = [&]() { +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + { + reinterpret_cast(frag_c)[i] = 0; + } + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } + else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } + else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + { + ldsm4(frag_a[k % 2][i], + &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + } + + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + +#pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } + else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } + else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = { 0, 1, 8, + 9 }; // Tensor core offsets per thread + +#pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. +#pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + if constexpr (num_bits == 4) { + int b_quant = frag_b_quant[k % 2][0][j]; + int b_quant_shift = b_quant >> 8; + + frag_b0 = dequant_4bit(b_quant); + frag_b1 = dequant_4bit(b_quant_shift); + + } + else { + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); + } + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + scale4(frag_b0, act_frag_s[k % 2][0][j], + act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], + act_frag_s[k % 2][3][j], 0); + } + else { + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + } + + // Apply scale to frag_b1 + if constexpr (has_act_order) { + scale4(frag_b1, act_frag_s[k % 2][0][j], + act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], + act_frag_s[k % 2][3][j], 1); + + } + else { + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + +#pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { +#pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { +#pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); +#pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { +#pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); +#pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = + Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) { + res = __hmul2(res, s[0]); + } + + ((scalar_t2*)sh)[idx] = res; + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + +#pragma unroll + for (int i = 0; + i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + +#pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + +#pragma unroll + for (int pipe = 0; pipe < stages;) { +#pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (num_bits == 8) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + else { + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (num_bits == 8) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + + } + else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + scale_float( + reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float( + reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float( + reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } + else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } + + start_pipes(); + } + } + } + } + + template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > + __global__ void + Marlin_wrapper(const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + const int* __restrict__ prob_m_ptr, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization + ) { + int prob_m = *prob_m_ptr; + const int thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks); + if(prob_m > 16 * thread_m_blocks) + prob_m = (16 * thread_m_blocks) * div_ceil(prob_m, (16 * thread_m_blocks)); + /*if (blockIdx.x == 0 && threadIdx.x == 0) + printf("marlin prob_m %d\n", prob_m);*/ + if (thread_m_blocks == 1) { + Marlin( + A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n, + prob_k, locks); + } + else if (thread_m_blocks == 2) { + Marlin( + A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n, + prob_k, locks); + } + else if (thread_m_blocks == 3) { + Marlin( + A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n, + prob_k, locks); + } + else if (thread_m_blocks == 4) { + Marlin( + A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n, + prob_k, locks); + } + } + +#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + Marlin_wrapper, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin_wrapper<<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m_ptr, prob_n, \ + prob_k, locks); \ + } + + typedef struct { + int thread_k; + int thread_n; + int num_threads; + } thread_config_t; + + typedef struct { + int max_m_blocks; + thread_config_t tb_cfg; + } exec_config_t; + + thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}, + }; + + thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + // {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}, + + }; + + int get_scales_cache_size(thread_config_t const& th_config, int prob_m, + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } + else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } + else { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = + tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + + } + else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; + } + } + + bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int scales_cache_size, int max_shared_mem) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + + int b_size = (tb_k * tb_n / pack_factor) * 4; + + // Get A size + int m_blocks = div_ceil(prob_m, 16); + int tb_max_m = 16; + + // zbx: too ugly + // origin + /*while (true) { + if (m_blocks >= max_m_blocks) { + tb_max_m *= max_m_blocks; + break; + } + + max_m_blocks--; + if (max_m_blocks == 0) { + TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + } + }*/ + // refactor + tb_max_m *= std::min(m_blocks, max_m_blocks); + + int a_size = (tb_max_m * tb_k) * 2; + + float pipe_size = (a_size + b_size) * pipe_stages; + + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); + } + + bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || + th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Determine cache for scales + int scales_cache_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); + + // Check that pipeline fits into cache + if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, scales_cache_size, max_shared_mem)) { + return false; + } + + return true; + } + + exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, + bool has_act_order, bool is_k_full, + int max_shared_mem) { + int max_m_blocks = 4; + while (max_m_blocks > 0) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, + prob_k, num_bits, group_size, has_act_order, + is_k_full, max_shared_mem)) { + return exec_config_t{ max_m_blocks, th_config }; + } + } + } + else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, + prob_k, num_bits, group_size, has_act_order, + is_k_full, max_shared_mem)) { + return exec_config_t{ max_m_blocks, th_config }; + } + } + } + + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage + } + + return exec_config_t{ 0, {-1, -1, -1} }; + } + +#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + + template + void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, + void* g_idx, void* perm, void* a_tmp, int* prob_m_ptr, int prob_m, + int prob_n, int prob_k, void* workspace, int num_bits, + bool has_act_order, bool is_k_full, int num_groups, + int group_size, int dev, cudaStream_t stream, int thread_k, + int thread_n, int sms, int max_par) { + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", + prob_m, ", ", prob_n, ", ", prob_k, "]"); + + int tot_m = prob_m; + int tot_m_blocks = div_ceil(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + // Set thread config + exec_config_t exec_cfg; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + exec_cfg = exec_config_t{ + 4, thread_config_t{thread_k, thread_n, default_threads} }; + } + else { + // Auto config + exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full, + max_shared_mem); + } + + TORCH_CHECK( + exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, has_act_order, + is_k_full, max_shared_mem), + "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m, + ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, ", has_act_order = ", has_act_order, + ", is_k_full = ", is_k_full, ", max_shared_mem = ", max_shared_mem); + + int num_threads = exec_cfg.tb_cfg.num_threads; + thread_k = exec_cfg.tb_cfg.thread_k; + thread_n = exec_cfg.tb_cfg.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + int blocks = sms; + + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(group_size != -1); + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } + else { + TORCH_CHECK(group_size == 0); + group_blocks = 0; + } + + } + else { + if (group_size == -1) { + group_blocks = -1; + } + else { + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } + } + + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; + const int* g_idx_ptr = (const int*)g_idx; + const int* perm_ptr = (const int*)perm; + int4* a_tmp_ptr = (int4*)a_tmp; + + int* locks = (int*)workspace; + + if (has_act_order) { + // Permute A columns + int block_rows = div_ceil(prob_m, blocks); + permute_cols_kernel << > > ( + A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows); + A_ptr = a_tmp_ptr; + } + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) { + has_act_order = false; + } + + // Main loop + for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > exec_cfg.max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without + // any padding + par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); + if (par > max_par) + par = max_par; + prob_m = (16 * exec_cfg.max_m_blocks) * par; + i += exec_cfg.max_m_blocks * (par - 1); + thread_m_blocks = exec_cfg.max_m_blocks; + } + + // Define kernel configurations +#define undefined_error \ + TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + \ + str(prob_n) + ", " + str(prob_k) + "]" + \ + ", has_act_order = " + str(has_act_order) + \ + ", num_groups = " + str(num_groups) + \ + ", group_size = " + str(group_size) + \ + ", thread_m_blocks = " + str(thread_m_blocks) + \ + ", thread_n_blocks = " + str(thread_n_blocks) + \ + ", thread_k_blocks = " + str(thread_k_blocks)); + + /* std::cout << "MNK = [" + str(prob_m) + ", " + \ + str(prob_n) + ", " + str(prob_k) + "]" + \ + ", has_act_order = " + str(has_act_order) + \ + ", num_groups = " + str(num_groups) + \ + ", group_size = " + str(group_size) + \ + ", thread_m_blocks = " + str(thread_m_blocks) + \ + ", thread_n_blocks = " + str(thread_n_blocks) + \ + ", thread_k_blocks = " + str(thread_k_blocks) << std::endl;*/ + + /*if (false) { + } + // CALL_IF(4, 32, 2, 256) + // CALL_IF(4, 16, 4, 256) + __CALL_IF(4, 1, 16, 4, false, 4, 256) + __CALL_IF(4, 2, 16, 4, false, 4, 256) + // CALL_IF(4, 8, 8, 256) + __CALL_IF(4, 1, 8, 8, false, 4, 256) + __CALL_IF(4, 2, 8, 8, false, 4, 256) + // CALL_IF(4, 16, 4, 128) + __CALL_IF(4, 1, 16, 4, false, 4, 128) + __CALL_IF(4, 2, 16, 4, false, 4, 128) + // CALL_IF(4, 8, 8, 128) + __CALL_IF(4, 1, 8, 8, false, 4, 128) + __CALL_IF(4, 2, 8, 8, false, 4, 128) + else {undefined_error}*/ + + if (num_bits == 4 && num_threads == 256) + { + if (false) { + } + CALL_IF(4, 32, 2, 256) + CALL_IF(4, 16, 4, 256) + CALL_IF(4, 8, 8, 256) + else { + undefined_error + } + } + else if (num_bits == 4 && num_threads == 128) + { + if (false) { + } + CALL_IF(4, 8, 4, 128) + CALL_IF(4, 16, 4, 128) + CALL_IF(4, 4, 8, 128) + else { + undefined_error + } + } + // else if (num_bits == 8 && num_threads == 256) + // { + // if (false) { + // } + // CALL_IF(8, 32, 2, 256) + // CALL_IF(8, 16, 4, 256) + // CALL_IF(8, 8, 8, 256) + // else { + // undefined_error + // } + // } + // else if (num_bits == 8 && num_threads == 128) + // { + // if (false) { + // } + // CALL_IF(8, 8, 4, 128) + // CALL_IF(8, 16, 4, 128) + // CALL_IF(8, 4, 8, 128) + // else { + // undefined_error + // } + // } + else { + undefined_error + } + + A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + } + } + +} // namespace gptq_marlin + +torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& g_idx, + torch::Tensor& perm, torch::Tensor& workspace, + int64_t num_bits, torch::Tensor size_m_tensor, int64_t size_m, int64_t size_n, + int64_t size_k, int sms, bool is_k_full) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + // Verify num_bits + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + int pack_factor = 32 / num_bits; + + // Verify A + TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), + ", size_m = ", size_m); + TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), + ", size_k = ", size_k); + + // Verify B + TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", gptq_marlin::tile_size); + TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), + ", size_k = ", size_k, + ", tile_size = ", gptq_marlin::tile_size); + TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, + "b_q_weight.size(1) = ", b_q_weight.size(1), + " is not divisible by tile_size = ", gptq_marlin::tile_size); + int actual_size_n = + (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor; + TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, + ", actual_size_n = ", actual_size_n); + + // Verify device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); + TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); + + TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); + TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + + // Alloc buffers + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c = torch::empty({ size_m, size_n }, options); + torch::Tensor a_tmp = torch::empty({ size_m, size_k }, options); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel (can usually be left as auto -1) + // int sms = -1; //zbx + + // Verify g_idx and perm + TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) || + (g_idx.size(0) == size_k && perm.size(0) == size_k), + "Unexpected g_idx.size(0) = ", g_idx.size(0), + " and perm.size(0) = ", perm.size(0), + ", where size_k = ", size_k); + + // Detect groupsize and act_order + int num_groups = -1; + int group_size = -1; + bool has_act_order = g_idx.size(0) != 0; + + int b_rank = b_scales.sizes().size(); + TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); + TORCH_CHECK(b_scales.size(1) == size_n, + "b_scales dim 1 = ", b_scales.size(1), + " is not size_n = ", size_n); + num_groups = b_scales.size(0); + + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(num_groups > 1, + "For act_order, num_groups must be > 1"); + TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by num_groups = ", num_groups); + group_size = size_k / num_groups; + } + else { + group_size = 0; + } + + } + else { + if (num_groups > 1) { + TORCH_CHECK( + size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); + group_size = size_k / num_groups; + } + else { + group_size = -1; + } + } + + // Verify workspace size + TORCH_CHECK( + size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, + ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); + int min_workspace_size = + (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = ", workspace.numel(), + " is below min_workspace_size = ", min_workspace_size); + + int dev = a.get_device(); + if (a.scalar_type() == at::ScalarType::Half) { + gptq_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), + c.data_ptr(), b_scales.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), + size_m_tensor.data_ptr(), + size_m, size_n, size_k, workspace.data_ptr(), num_bits, + has_act_order, is_k_full, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + gptq_marlin::max_par); + } + else if (a.scalar_type() == at::ScalarType::BFloat16) { + gptq_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), + c.data_ptr(), b_scales.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), + size_m_tensor.data_ptr(), + size_m, size_n, size_k, workspace.data_ptr(), num_bits, + has_act_order, is_k_full, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + gptq_marlin::max_par); + } + else { + TORCH_CHECK(false, + "gpt_marlin_gemm only supports bfloat16 and float16"); + } + + return c; +} + +#endif \ No newline at end of file diff --git a/csrc/custom_marlin/gptq_marlin/gptq_marlin.cuh b/csrc/custom_marlin/gptq_marlin/gptq_marlin.cuh new file mode 100644 index 0000000..5b4b059 --- /dev/null +++ b/csrc/custom_marlin/gptq_marlin/gptq_marlin.cuh @@ -0,0 +1,76 @@ +// Adapted from +// https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin +// Copyrigth 2024 The vLLM team. +// Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace gptq_marlin { + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +static constexpr int default_threads = 256; + +static constexpr int pipe_stages = + 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +template struct Vec { + T elems[n]; + __device__ T &operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +// No support for async +#else + +__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +template __device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +#endif + +} // namespace gptq_marlin \ No newline at end of file diff --git a/csrc/custom_marlin/gptq_marlin/gptq_marlin_dtypes.cuh b/csrc/custom_marlin/gptq_marlin/gptq_marlin_dtypes.cuh new file mode 100644 index 0000000..3e8c3ca --- /dev/null +++ b/csrc/custom_marlin/gptq_marlin/gptq_marlin_dtypes.cuh @@ -0,0 +1,77 @@ +// Adapted from +// https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin +// Copyrigth 2024 The vLLM team. +// Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +#ifndef _data_types_cuh +#define _data_types_cuh +#include "gptq_marlin.cuh" +#include +#include + +namespace gptq_marlin { + +template class ScalarType {}; + +template <> class ScalarType { + public: + using scalar_t = half; + using scalar_t2 = half2; + + // Matrix fragments for tensor core instructions; their precise layout is + // documented here: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + + static __device__ float inline num2float(const half x) { + return __half2float(x); + } + + static __device__ half2 inline num2num2(const half x) { + return __half2half2(x); + } + + static __device__ half2 inline nums2num2(const half x1, const half x2) { + return __halves2half2(x1, x2); + } + + static __host__ __device__ half inline float2num(const float x) { + return __float2half(x); + } +}; + +template <> class ScalarType { + public: + using scalar_t = nv_bfloat16; + using scalar_t2 = nv_bfloat162; + + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static __device__ float inline num2float(const nv_bfloat16 x) { + return __bfloat162float(x); + } + + static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { + return __bfloat162bfloat162(x); + } + + static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, + const nv_bfloat16 x2) { + return __halves2bfloat162(x1, x2); + } + + static __host__ __device__ nv_bfloat16 inline float2num(const float x) { + return __float2bfloat16(x); + } +#endif +}; + +} // namespace gptq_marlin + +#endif \ No newline at end of file diff --git a/csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu b/csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu new file mode 100644 index 0000000..4adcbd5 --- /dev/null +++ b/csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu @@ -0,0 +1,350 @@ +#include "gptq_marlin.cuh" + +namespace gptq_marlin { + +static constexpr int repack_stages = 8; + +static constexpr int repack_threads = 256; + +static constexpr int tile_k_size = tile_size; +static constexpr int tile_n_size = tile_k_size * 4; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template +__global__ void marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, + uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, + int size_k, int size_n) {} + +} // namespace gptq_marlin + +torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, + int64_t size_k, int64_t size_n, + int64_t num_bits) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +template +__global__ void marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, + uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, + int size_k, int size_n) { + constexpr int pack_factor = 32 / num_bits; + + int k_tiles = size_k / tile_k_size; + int n_tiles = size_n / tile_n_size; + int block_k_tiles = div_ceil(k_tiles, gridDim.x); + + int start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int perm_size = tile_k_size / 4; + + int4* sh_perm_ptr = sh; + int4* sh_pipe_ptr = sh_perm_ptr; + if constexpr (has_perm) { + sh_pipe_ptr += perm_size; + } + + constexpr int tile_ints = tile_k_size / pack_factor; + + constexpr int stage_n_threads = tile_n_size / 4; + constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto load_perm_to_shared = [&](int k_tile_id) { + int first_k_int4 = (k_tile_id * tile_k_size) / 4; + + int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr); + + if (threadIdx.x < perm_size) { + sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; + } + __syncthreads(); + }; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * tile_n_size; + + int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; + + if constexpr (has_perm) { + if (threadIdx.x < stage_size) { + int k_id = threadIdx.x / stage_n_threads; + int n_id = threadIdx.x % stage_n_threads; + + uint32_t const* sh_perm_int_ptr = + reinterpret_cast(sh_perm_ptr); + + int src_k = sh_perm_int_ptr[k_id]; + int src_k_packed = src_k / pack_factor; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&( + b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); + } + + } else { + if (threadIdx.x < stage_size) { + int k_id = threadIdx.x / stage_n_threads; + int n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + int first_k_packed = first_k / pack_factor; + + cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + + first_n + (n_id * 4)]))); + } + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } + + int warp_id = threadIdx.x / 32; + int th_id = threadIdx.x % 32; + + if (warp_id >= 4) { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = warp_id * 16 + tc_col; + + constexpr int sh_stride = 64; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; + uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + uint32_t* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + + uint32_t vals[8]; + + if constexpr (has_perm) { + for (int i = 0; i < 4; i++) { + int k_idx = tc_row + tc_offsets[i]; + + uint32_t src_k = sh_perm_int_ptr[k_idx]; + uint32_t src_k_pos = src_k % pack_factor; + + uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; + uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; + + uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; + uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; + + vals[i] = b1_cur_val; + vals[4 + i] = b2_cur_val; + } + + } else { + uint32_t b1_vals[tile_ints]; + uint32_t b2_vals[tile_ints]; + + #pragma unroll + for (int i = 0; i < tile_ints; i++) { + b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; + b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; + } + + #pragma unroll + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; + int cur_int = cur_elem / pack_factor; + int cur_pos = cur_elem % pack_factor; + + vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; + vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; + } + } + + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (num_bits == 4) { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; + #pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; + #pragma unroll + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) { + #pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; + #pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; + + if constexpr (has_perm) { + load_perm_to_shared(k_tile_id); + } + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) { + #pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, + n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } +} + +} // namespace gptq_marlin + + #define CALL_IF(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + cudaFuncSetAttribute( \ + gptq_marlin::marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + gptq_marlin::marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ + } + +torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, + int64_t size_k, int64_t size_n, + int64_t num_bits) { + // Verify compatibility with marlin tile of 16x64 + TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k, + " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size); + TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n, + " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size); + + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + int const pack_factor = 32 / num_bits; + + // Verify B + TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), + ", size_k = ", size_k, ", pack_factor = ", pack_factor); + TORCH_CHECK(b_q_weight.size(1) == size_n, + "b_q_weight.size(1) = ", b_q_weight.size(1), + " is not size_n = ", size_n); + + // Verify device and strides + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt"); + + TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); + TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt"); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); + auto options = torch::TensorOptions() + .dtype(b_q_weight.dtype()) + .device(b_q_weight.device()); + torch::Tensor out = + torch::empty({size_k / gptq_marlin::tile_size, + size_n * gptq_marlin::tile_size / pack_factor}, + options); + + // Detect if there is act_order + bool has_perm = perm.size(0) != 0; + + // Get ptrs + uint32_t const* b_q_weight_ptr = + reinterpret_cast(b_q_weight.data_ptr()); + uint32_t const* perm_ptr = reinterpret_cast(perm.data_ptr()); + uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); + + // Get dev info + int dev = b_q_weight.get_device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + int blocks; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + if (false) { + } + CALL_IF(4, false) + CALL_IF(4, true) + CALL_IF(8, false) + CALL_IF(8, true) + else { + TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, + ", has_perm = ", has_perm); + } + + return out; +} + +#endif \ No newline at end of file diff --git a/csrc/custom_marlin/gptq_marlin/ops.h b/csrc/custom_marlin/gptq_marlin/ops.h new file mode 100644 index 0000000..b3327d3 --- /dev/null +++ b/csrc/custom_marlin/gptq_marlin/ops.h @@ -0,0 +1,24 @@ +/** + * @Description : + * @Author : Azure + * @Date : 2024-07-22 09:27:55 + * @Version : 1.0.0 + * @LastEditors : Azure + * @LastEditTime : 2024-07-26 08:35:00 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + **/ +#pragma once + +#include +#include +#include + +torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_scales, torch::Tensor &g_idx, + torch::Tensor &perm, torch::Tensor &workspace, + int64_t num_bits, torch::Tensor size_m_tensor, int64_t size_m, int64_t size_n, + int64_t size_k, int sms, bool is_k_full); + +torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor&perm, + int64_t size_k, int64_t size_n, + int64_t num_bits); \ No newline at end of file diff --git a/csrc/custom_marlin/setup.py b/csrc/custom_marlin/setup.py new file mode 100644 index 0000000..ec710ca --- /dev/null +++ b/csrc/custom_marlin/setup.py @@ -0,0 +1,25 @@ +from setuptools import setup, Extension +from torch.utils import cpp_extension +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +setup( + name='vLLMMarlin', + ext_modules=[ + CUDAExtension( + 'vLLMMarlin', [ + #'custom_gguf/dequant.cu', + 'binding.cpp', + 'gptq_marlin/gptq_marlin.cu', + 'gptq_marlin/gptq_marlin_repack.cu', + ], + extra_compile_args={ + 'cxx': ['-O3'], + 'nvcc': [ + '-O3', + '--use_fast_math', + '-Xcompiler', '-fPIC', + ] + }, + ) + ], + cmdclass={'build_ext': BuildExtension} +) \ No newline at end of file diff --git a/csrc/custom_marlin/test_cuda_graph.py b/csrc/custom_marlin/test_cuda_graph.py new file mode 100644 index 0000000..0024082 --- /dev/null +++ b/csrc/custom_marlin/test_cuda_graph.py @@ -0,0 +1,335 @@ +import csv +import torch +import torch.nn as nn +import vLLMMarlin +torch.set_grad_enabled(False) +from utils.marlin_utils import ( + MarlinWorkspace, + marlin_quantize, + GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MIN_THREAD_K, + GPTQ_MARLIN_MAX_PARALLEL, +) + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +setup_seed(20241223) + +torch.set_grad_enabled(False) +torch.set_default_dtype(torch.bfloat16) +global_dtype=torch.bfloat16 +global_device=torch.device("cuda",0) +global_num_cases:int=int(50) +torch.cuda.set_device(0) +torch.backends.cudnn.enabled =True +torch.backends.cudnn.benchmark = True + +max_batch_size = 512 +max_tp = 8 +L2_size = 73728 * 1024 + +def get_usable_mem(): + properties = torch.cuda.get_device_properties(global_device) + #print(f"Total memory: {properties.total_memory / (1024 ** 3):.2f} GB") + allocated_memory = torch.cuda.memory_allocated(global_device) + #print(f"Currently allocated memory: {allocated_memory / (1024 ** 2):.2f} MB") + reserved_memory = torch.cuda.memory_reserved(global_device) + #print(f"Currently reserved memory: {reserved_memory / (1024 ** 2):.2f} MB") + return properties.total_memory - 512 * 1024 ** 2 - allocated_memory# - reserved_memory + +def exp_range(start, stop, step = 2): + now = start + while now <= stop: + yield now + now *= step + +def timing(func, iters, epochs=100): + #warmup + for idx in range(iters): + func(idx) + + torch.cuda.synchronize() + cuda_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cuda_graph): + for idx in range(iters): + func(idx) + + for _ in range(2000): + cuda_graph.replay() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + stream = torch.cuda.Stream() + torch.cuda.synchronize() + #with torch.cuda.stream(stream): + start_event.record() + for _ in range(10): + cuda_graph.replay() + end_event.record() + torch.cuda.synchronize() + elapsed_time_ms0 = start_event.elapsed_time(end_event) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + #with torch.cuda.stream(stream): + start_event.record() + for _ in range(epochs+10): + cuda_graph.replay() + end_event.record() + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) - elapsed_time_ms0 + + #print(elapsed_time_ms0, elapsed_time_ms) + return elapsed_time_ms/iters/epochs + +class LinearMarlin(nn.Linear): + marlin_q_w: torch.Tensor + marlin_s: torch.Tensor + g_idx: torch.Tensor + sort_indices: torch.Tensor + has_bias: bool + def __init__( + self, + in_features, + out_features, + bias = False, + device: str = "cuda", + num_bits: int = 4, # 4-bit/8-bit is supported + group_size: int = 64, # -1, 32, 64, 128 + act_order: bool = False, + is_k_full=True, + sms = -1, # sms in GPU + **kwargs, + ): + self.padding = False + assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device" + if in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or out_features%GPTQ_MARLIN_MIN_THREAD_K!=0: + #print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding") + self.padding = True + self.orin_in_features = in_features + self.orin_out_features = out_features + in_features = (in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K + out_features = (out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N + #print(f"After padding: in_features={in_features}, out_features={out_features}") + + + super().__init__(in_features, out_features, bias, device) + self.has_bias = bias + self.device = device + self.num_bits = num_bits + self.group_size = group_size + self.act_order = act_order + # TODO: optimize every shape GEMM + + blocks_k, blocks_n = in_features//128, out_features//128 + + self.sms = sms + + self.is_k_full = is_k_full + + self.weight.requires_grad = False + self.weight.t_() + # Pack Marlin linear + #w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( + # self.weight, self.num_bits, self.group_size, self.act_order + #) + marlin_q_w = torch.randint(int(-1e9), int(1e9), (in_features//16, out_features*2), device=device, dtype=torch.int) + marlin_s = torch.randn((in_features//64, out_features), device=device) + self.workspace = MarlinWorkspace( + self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL, self.device + ) + self.marlin_q_w = marlin_q_w + self.marlin_s = marlin_s + self.g_idx = torch.empty((0), dtype=torch.int32, device=self.device) + self.sort_indices = torch.empty((0), dtype=torch.int32, device=self.device) + self.k = self.weight.shape[0] + self.n = self.weight.shape[1] + self.weight = None + """ + print(in_features, out_features) + print(marlin_q_w.shape) + print(marlin_q_w.dtype) + print(marlin_s.shape) + print(marlin_s.dtype) + print(self.workspace.scratch.shape) + print(self.workspace.scratch.dtype) + print(self.g_idx.shape) + print(self.g_idx.dtype) + print(self.sort_indices.shape) + print(self.sort_indices.dtype) + #print(w_ref.shape) + #print(w_ref.dtype) + """ + #w_ref = None + + def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor) -> torch.Tensor: + # Only support input x as BF16 and FP16 + x = x.to(self.device) + orig_shape = list(x.shape) + orig_dtype = x.dtype + x = x.reshape(-1, x.shape[-1]) + if self.padding: + padding_input=torch.empty(x.shape[0], self.in_features, device=x.device, dtype=x.dtype) + padding_input[:,:self.orin_in_features] = x + x = padding_input + marlin_s = self.marlin_s.to(x.dtype) + #print(self.sms * ((orig_shape[0]+63)//64)) + + sms = self.sms + + x = vLLMMarlin.gptq_marlin_gemm( + x, + self.marlin_q_w, + marlin_s, + self.g_idx, + self.sort_indices, + self.workspace.scratch, + self.num_bits, + bsz_tensor, + x.shape[0], + self.n, + x.shape[-1], + sms, + self.is_k_full, + ) + # TODO: don't padding bias + if self.has_bias: + x = x + self.bias + if self.padding: + x = x[:,:self.orin_out_features] + orig_shape[-1] = self.orin_out_features + else: + orig_shape[-1] = self.out_features + return x.reshape(orig_shape).to(orig_dtype) + +def benchLinearMarlin(input_dim, output_dim):#, out_file + print("benchmarking MLP Marlin") + print("-----------------------------------------------------------") + headers = ["batch_size", "tp", "used_time", "bandwidth GB/s", "TFLOPS", "cases", "padding", "sms"] + print(" | ".join(headers) + "\n") + rows = [] + for batch_size in exp_range(1, 64): + for tp in exp_range(1, max_tp): + torch.cuda.empty_cache() + if output_dim % tp != 0: + continue + cur_output_dim = output_dim // tp + modules = [] + inputs = [] + data_size = int(0.53125*input_dim*cur_output_dim) + input_size = int(2*batch_size*input_dim) + output_size = int(2*batch_size*cur_output_dim) + usable_mem = get_usable_mem() - 2 * input_dim * cur_output_dim + min_cases = max(global_num_cases, (2*L2_size) // (data_size+input_size)) + cases = int(min(min_cases, (usable_mem * 0.8) // (data_size+input_size))) + #print(usable_mem, data_size, input_size, cases) + + bsz_tensor = torch.tensor([batch_size], device=global_device, dtype=torch.int32) + + if cases == 0: + row = [f"{batch_size}", "OOM", "OOM", "OOM", "0", "False"] + rows.append(row) + break + for _ in range(cases): + modules.append(LinearMarlin(input_dim, cur_output_dim, sms=56, non_equal_division=False).to(device=global_device).eval()) + inputs.append(torch.randn(batch_size, 1, input_dim, device=global_device)) + + def forward(case_id): + modules[case_id](inputs[case_id], bsz_tensor) + + used_time = timing(forward, iters=cases) + bandwidth = (data_size+input_size+output_size)/used_time/1e6 + flops = 2*batch_size*input_dim*cur_output_dim + tflops = flops/used_time/1e9 + cur_sms = modules[0].sms + row = [f"{batch_size}", f"{tp}", f"{used_time}", f"{bandwidth}", f"{tflops}", f"{cases}", modules[0].padding, cur_sms] + rows.append(row) + print(f"{batch_size}", f"{tp}", f"{used_time}", f"{bandwidth}", f"{tflops}", f"{cases}", modules[0].padding, cur_sms) + + """ + with open(out_file, 'w', newline='') as csvfile: + csvwriter = csv.writer(csvfile) + csvwriter.writerow(headers) + for row in rows: + csvwriter.writerow(row) + """ + + """ + markdown_table = " | ".join(headers) + "\n" + markdown_table += " | ".join(["---"] * len(headers)) + "\n" + for row in rows: + markdown_table += " | ".join(row) + "\n" + + print(markdown_table) + """ + #print("finish write file", out_file) + #print("-------------------------------------------------------------") + +if __name__ == "__main__": + + benchLinearMarlin(5120, 3584) + exit(0) + + max_batch = 1 + cur_batch = 1 + + + marlin_linear = LinearMarlin(5120, 3584) + + input_tensor = torch.randn(max_batch, 1, 5120, device="cuda", dtype=torch.bfloat16) + bsz_tensor = torch.tensor([max_batch], device="cuda", dtype=torch.int32) + + out_truth = marlin_linear(input_tensor, bsz_tensor) + + print(out_truth) + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + out_buf = marlin_linear(input_tensor, bsz_tensor) + + for i in range(10000): + g.replay() + + #torch.testing.assert_close(out_buf, out_truth, rtol=1e-3, atol=1e-3) + + marlin_linear = LinearMarlin(5120, 3584) + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + out_buf = marlin_linear(input_tensor, bsz_tensor) + + new_input = torch.randn(cur_batch, 1, 5120, device="cuda", dtype=torch.bfloat16) + bsz_tensor.copy_(torch.tensor([cur_batch], device="cuda", dtype=torch.int32)) + + new_out_truth = marlin_linear(new_input, bsz_tensor) + input_tensor[:cur_batch].copy_(new_input) + input_tensor[cur_batch:] = 0 + + g.replay() + + torch.cuda.synchronize() + + def printMinMax(tensor): + abs_tensor = torch.abs(tensor) + + min_val = torch.min(abs_tensor) + max_val = torch.max(abs_tensor) + + min_indices = (abs_tensor == min_val).nonzero(as_tuple=True) + max_indices = (abs_tensor == max_val).nonzero(as_tuple=True) + + print(f"min: {min_val.item()}") + print(f"min idx: {min_indices}") + print(f"max: {max_val.item()}") + print(f"max idx: {max_indices}") + + print(out_buf[:cur_batch].shape) + print(new_out_truth.shape) + + + printMinMax(out_buf[:cur_batch]) + printMinMax(new_out_truth) + + #torch.testing.assert_close(out_buf[:cur_batch, 0, :], new_out_truth[:cur_batch, 0, :], rtol=1e-3, atol=1e-3) diff --git a/csrc/custom_marlin/utils/__init__.py b/csrc/custom_marlin/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/csrc/custom_marlin/utils/format24.py b/csrc/custom_marlin/utils/format24.py new file mode 100644 index 0000000..2434e79 --- /dev/null +++ b/csrc/custom_marlin/utils/format24.py @@ -0,0 +1,308 @@ +# +# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es). +# + +import torch + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, + device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. +def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError( + "Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16") + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32") + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. + + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, + idxs0.unsqueeze(-1) // 2).view( + m, + k // 2) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view( + (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12)) + elif quadbits_per_meta_elem == 8: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28)) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols, )) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix") + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta = torch.gather(meta_reordered.view(-1), 0, + meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view( + -1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_(0, dense_offsets, + sparse.view(torch.half).view(-1)) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. + + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " + f"{M} groups") + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask \ No newline at end of file diff --git a/csrc/custom_marlin/utils/marlin_24_perms.py b/csrc/custom_marlin/utils/marlin_24_perms.py new file mode 100644 index 0000000..d79926f --- /dev/null +++ b/csrc/custom_marlin/utils/marlin_24_perms.py @@ -0,0 +1,65 @@ +''' +Date: 2024-11-08 02:46:07 +LastEditors: djw +LastEditTime: 2024-11-08 02:46:41 +''' +"""This file is used for /tests and /benchmarks""" +from typing import Dict, List + +import numpy +import torch + + +# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501 +# +# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 +# with the tensor-core format that is described here: +# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 +# +# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 +# (without the need to use ldmatrix instructions) # noqa: E501 +def get_perms_24(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return perm, scale_perm, scale_perm_single + + +marlin_24_perm: Dict[int, torch.Tensor] = {} +marlin_24_scale_perm: Dict[int, List[int]] = {} +marlin_24_scale_perm_single: Dict[int, List[int]] = {} +for num_bits in [4, 8]: + perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits) + marlin_24_perm[num_bits] = perm_24 + marlin_24_scale_perm[num_bits] = scale_perm_24 + marlin_24_scale_perm_single[num_bits] = scale_perm_single_24 \ No newline at end of file diff --git a/csrc/custom_marlin/utils/marlin_perms.py b/csrc/custom_marlin/utils/marlin_perms.py new file mode 100644 index 0000000..62255ec --- /dev/null +++ b/csrc/custom_marlin/utils/marlin_perms.py @@ -0,0 +1,65 @@ +''' +Date: 2024-11-08 02:46:47 +LastEditors: djw +LastEditTime: 2024-11-08 02:46:55 +''' +"""This file is used for /tests and /benchmarks""" +from typing import Dict, List + +import numpy +import torch + + +# Precompute permutations for Marlin weight and scale shuffling # noqa: E501 +# +# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 +# with the tensor-core format that is described here: +# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 +# +# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 +# (without the need to use ldmatrix instructions) # noqa: E501 +def get_perms(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return perm, scale_perm, scale_perm_single + + +marlin_perm: Dict[int, torch.Tensor] = {} +marlin_scale_perm: Dict[int, List[int]] = {} +marlin_scale_perm_single: Dict[int, List[int]] = {} +for num_bits in [4, 8]: + perm, scale_perm, scale_perm_single = get_perms(num_bits) + marlin_perm[num_bits] = perm + marlin_scale_perm[num_bits] = scale_perm + marlin_scale_perm_single[num_bits] = scale_perm_single \ No newline at end of file diff --git a/csrc/custom_marlin/utils/marlin_utils.py b/csrc/custom_marlin/utils/marlin_utils.py new file mode 100644 index 0000000..ccecdc3 --- /dev/null +++ b/csrc/custom_marlin/utils/marlin_utils.py @@ -0,0 +1,234 @@ +"""This file is used for /tests and /benchmarks""" +import random + +import numpy +import torch + +from .format24 import ( + mask_creator, sparse_semi_structured_from_dense_cutlass) +from .marlin_24_perms import ( + marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single) +from .marlin_perms import ( + marlin_perm, marlin_scale_perm, marlin_scale_perm_single) +from .quant_utils import ( + get_pack_factor, quantize_weights, sort_weights, dequantize_weights) + + + +__cuda_arch = torch.cuda.get_device_capability() + +MARLIN_TILE = 16 + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8] +GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] +GPTQ_MARLIN_SUPPORTED_SYM = [True] + +def is_marlin_supported(): + return __cuda_arch[0] >= 8 + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm, + scale_perm_single): + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, + act_order: bool, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, + act_order) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, + marlin_perm[num_bits]) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, + marlin_scale_perm[num_bits], + marlin_scale_perm_single[num_bits]) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j:j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, num_bits): + assert q_24.shape == (size_k, size_n) + + # Remove zp to normalize over 0 + max_q_val = (1 << num_bits) - 1 + zp = (max_q_val + 1) // 2 + q_24_no_zp = q_24 - zp + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass( + q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore zp + q_24_comp = q_24_no_zp_comp + zp + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def marlin_24_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24, + num_bits, + group_size, + act_order=False) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, + num_bits) + size_k_comp = size_k // 2 + + # Reformat to marlin + marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, + num_bits, marlin_24_perm[num_bits]) + marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size, + marlin_24_scale_perm[num_bits], + marlin_24_scale_perm_single[num_bits]) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def compute_max_diff(output, output_ref): + return torch.mean(torch.abs(output - output_ref)) / torch.mean( + torch.abs(output_ref)) + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel, device): + assert (out_features % min_thread_n == 0), ( + "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n)) + + max_workspace_size = ((out_features // min_thread_n) * max_parallel) + + self.scratch = torch.zeros(max_workspace_size, + dtype=torch.int, + device=device) \ No newline at end of file diff --git a/csrc/custom_marlin/utils/quant_utils.py b/csrc/custom_marlin/utils/quant_utils.py new file mode 100644 index 0000000..077d22f --- /dev/null +++ b/csrc/custom_marlin/utils/quant_utils.py @@ -0,0 +1,195 @@ +"""This file is used for /tests and /benchmarks""" +import numpy +import torch + +SUPPORTED_NUM_BITS = [4, 8] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + +def get_pack_factor(num_bits): + assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size, ), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +# Function: Dequantize quantized weights +def dequantize_weights(qweight, qzeros, scales, g_idx, bits=4, group_size=128, device='cuda:0'): + # Create a tensor for bitwise right shift operation + wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32, device=device).unsqueeze(0) + + # Apply bitwise right shift and convert qzeros to the appropriate type + zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8) + torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros) + + # Reshape the zeros tensor + zeros = zeros + 1 + zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2]) + + # Reshape the scales tensor + scales = scales.reshape(-1, 1, scales.shape[-1]) + + # Similar bitwise right shift operation for qweight and reshape + weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1)).to(torch.int16 if bits == 8 else torch.int8) + torch.bitwise_and(weight, (2 ** bits) - 1, out=weight) + weight = weight.reshape(-1, group_size, weight.shape[2]) + + # Apply dequantization formula and reshape the final weight + weight = (scales * (weight - zeros)) + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) + + # Return the transposed weight + return weight.transpose(0, 1) + +def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, + act_order: bool): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Reshape to [groupsize, -1] + if group_size < size_k: + w = w.view((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + s = torch.max(torch.abs(w), 0, keepdim=True)[0] + s *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s + + # Restore original shapes + if group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + s = s.reshape((-1, size_n)).contiguous() + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k) + + w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to( + dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res + +def gptq_unpack( + q_res: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = 32 // num_bits + assert size_k % pack_factor == 0 + + orig_device = q_res.device + + q_res = q_res.cpu().numpy() + + q_w = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_w[i::pack_factor, :] = (q_res >> (num_bits * i)) & ((1 << num_bits) - 1) + + q_w = torch.from_numpy(q_w.astype(numpy.int32)).to(orig_device) + return q_w \ No newline at end of file diff --git a/csrc/ktransformers_ext/CMakeLists.txt b/csrc/ktransformers_ext/CMakeLists.txt index eefcadf..b273dd3 100644 --- a/csrc/ktransformers_ext/CMakeLists.txt +++ b/csrc/ktransformers_ext/CMakeLists.txt @@ -3,8 +3,15 @@ project(cpuinfer_ext VERSION 0.1.0) set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -ffast-math") + + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -ffast-math -fopenmp") set(CMAKE_BUILD_TYPE "Release") +# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -ffast-math -fopenmp") +# set(CMAKE_BUILD_TYPE "Debug") +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + + include(CheckCXXCompilerFlag) set(CMAKE_POSITION_INDEPENDENT_CODE ON) @@ -30,7 +37,7 @@ if (NOT MSVC) option(LLAMA_F16C "llama: enable F16C" OFF) endif() option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF) -option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF) +option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" ON) option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF) option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF) @@ -147,6 +154,7 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW endif() else() if (LLAMA_NATIVE) + list(APPEND ARCH_FLAGS -mfma -mavx -mavx2) list(APPEND ARCH_FLAGS -march=native) endif() if (LLAMA_F16C) @@ -172,6 +180,7 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW list(APPEND ARCH_FLAGS -mavx512vnni) endif() if (LLAMA_AVX512_FANCY_SIMD) + message(STATUS "AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI enabled") list(APPEND ARCH_FLAGS -mavx512vl) list(APPEND ARCH_FLAGS -mavx512bw) list(APPEND ARCH_FLAGS -mavx512dq) @@ -238,9 +247,18 @@ if (WIN32) include_directories("$ENV{CUDA_PATH}/include") add_compile_definitions(KTRANSFORMERS_USE_CUDA=1) elseif (UNIX) - if (KTRANSFORMERS_USE_CUDA) - find_package(CUDA REQUIRED) - include_directories("${CUDA_INCLUDE_DIRS}") + if (NOT KTRANSFORMERS_USE_MUSA) + # find_package(CUDA REQUIRED) + # include_directories("${CUDA_INCLUDE_DIRS}") + include(CheckLanguage) + check_language(CUDA) + if(CMAKE_CUDA_COMPILER) + message(STATUS "CUDA detected") + find_package(CUDAToolkit REQUIRED) + include_directories(${CUDAToolkit_INCLUDE_DIRS}) + endif() + message(STATUS "enabling CUDA") + enable_language(CUDA) add_compile_definitions(KTRANSFORMERS_USE_CUDA=1) endif() @@ -278,19 +296,35 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2) aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/llamafile SOURCE_DIR3) aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile SOURCE_DIR4) aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache SOURCE_DIR5) -set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5}) -message(STATUS "ALL_SOURCES: ${ALL_SOURCES}") + +set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5}) + +file(GLOB_RECURSE FMT_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.h") + +add_custom_target( + format + COMMAND clang-format + -i + -style=file + ${FMT_SOURCES} + COMMENT "Running clang-format on all source files" +) + + +add_library(llamafile STATIC ${SOURCE_DIR4}) + +message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") +message(STATUS "ARCH_FLAGS: ${ARCH_FLAGS}") pybind11_add_module(${PROJECT_NAME} MODULE ${ALL_SOURCES}) target_link_libraries(${PROJECT_NAME} PRIVATE llama) + + if(WIN32) target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart elseif(UNIX) - if(KTRANSFORMERS_USE_CUDA) - if(NOT DEFINED ENV{CUDA_HOME} OR "$ENV{CUDA_HOME}" STREQUAL "") - set(ENV{CUDA_HOME} "/usr/local/cuda") - endif() - target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so") + if(NOT KTRANSFORMERS_USE_MUSA) + target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so") endif() if (KTRANSFORMERS_USE_ROCM) add_compile_definitions(USE_HIP=1) @@ -304,21 +338,28 @@ endif() # Define the USE_NUMA option option(USE_NUMA "Disable NUMA support" OFF) + # Check if the USE_NUMA environment variable is set if(DEFINED ENV{USE_NUMA}) set(USE_NUMA ON) endif() -if (USE_NUMA) + +if(USE_NUMA) message(STATUS "NUMA support is enabled") else() message(STATUS "NUMA support is disabled") endif() find_library(NUMA_LIBRARY NAMES numa) -if (NUMA_LIBRARY AND USE_NUMA) + +if(NUMA_LIBRARY AND USE_NUMA) message(STATUS "NUMA library found: ${NUMA_LIBRARY} - enabling NUMA support") target_link_libraries(${PROJECT_NAME} PRIVATE ${NUMA_LIBRARY}) target_compile_definitions(${PROJECT_NAME} PRIVATE USE_NUMA) else() - message(STATUS "NUMA library not found or user not set USE_NUMA - disabling NUMA support") -endif() + if(USE_NUMA) + message(FATAL_ERROR "NUMA library not found - maybe sudo apt install libnuma-dev") + else() + message(STATUS "NUMA library not found or user not set USE_NUMA - disabling NUMA support") + endif() +endif() \ No newline at end of file diff --git a/csrc/ktransformers_ext/cpu_backend/backend.cpp b/csrc/ktransformers_ext/cpu_backend/backend.cpp index a254db9..7478d5c 100644 --- a/csrc/ktransformers_ext/cpu_backend/backend.cpp +++ b/csrc/ktransformers_ext/cpu_backend/backend.cpp @@ -151,4 +151,4 @@ void Backend::worker_thread(int thread_id) { return; } } -} +} \ No newline at end of file diff --git a/csrc/ktransformers_ext/cpu_backend/cpuinfer.h b/csrc/ktransformers_ext/cpu_backend/cpuinfer.h index d0f7b11..9c7e781 100644 --- a/csrc/ktransformers_ext/cpu_backend/cpuinfer.h +++ b/csrc/ktransformers_ext/cpu_backend/cpuinfer.h @@ -28,7 +28,7 @@ #include "backend.h" #include "task_queue.h" - #include "../vendors/vendor.h" + #include "./vendors/vendor.h" #include "llama.cpp/ggml-impl.h" diff --git a/csrc/ktransformers_ext/cuda/binding.cpp b/csrc/ktransformers_ext/cuda/binding.cpp index 5bba873..4aa00c6 100644 --- a/csrc/ktransformers_ext/cuda/binding.cpp +++ b/csrc/ktransformers_ext/cuda/binding.cpp @@ -68,4 +68,4 @@ PYBIND11_MODULE(KTransformersOps, m) { py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"), py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full")); #endif -} +} \ No newline at end of file diff --git a/csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu b/csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu index 3a6151b..c579469 100644 --- a/csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu +++ b/csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu @@ -879,4 +879,4 @@ torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const i } cudaDeviceSynchronize(); return output; -} +} \ No newline at end of file diff --git a/csrc/ktransformers_ext/cuda/custom_gguf/ops.h b/csrc/ktransformers_ext/cuda/custom_gguf/ops.h index 1740cbf..fe9161a 100644 --- a/csrc/ktransformers_ext/cuda/custom_gguf/ops.h +++ b/csrc/ktransformers_ext/cuda/custom_gguf/ops.h @@ -19,4 +19,4 @@ torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); -torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); +torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); \ No newline at end of file diff --git a/csrc/ktransformers_ext/ext_bindings.cpp b/csrc/ktransformers_ext/ext_bindings.cpp index 0078a79..d8d13e8 100644 --- a/csrc/ktransformers_ext/ext_bindings.cpp +++ b/csrc/ktransformers_ext/ext_bindings.cpp @@ -9,6 +9,7 @@ **/ // Python bindings #include "cpu_backend/cpuinfer.h" +#include "device_launch_parameters.h" #include "llamafile/flags.h" #include "operators/kvcache/kvcache.h" #include "operators/llamafile/linear.h" @@ -535,16 +536,17 @@ class MOEBindings { const float *weights; const void *input; void *output; + int *batch_size_tensor; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue( &MOE::forward, args_->moe, args_->qlen, args_->k, - args_->expert_ids, args_->weights, args_->input, args_->output); + args_->expert_ids, args_->weights, args_->input, args_->output, args_->batch_size_tensor); } static std::pair cpuinfer_interface(MOE &moe, int qlen, int k, intptr_t expert_ids, - intptr_t weights, intptr_t input, intptr_t output) { + intptr_t weights, intptr_t input, intptr_t output, intptr_t batch_size_tensor) { Args *args = new Args{nullptr, &moe, qlen, @@ -552,7 +554,8 @@ class MOEBindings { (const uint64_t *)expert_ids, (const float *)weights, (const void *)input, - (void *)output}; + (void *)output, + (int *)batch_size_tensor}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; @@ -679,4 +682,4 @@ PYBIND11_MODULE(cpuinfer_ext, m) { cpuinfer_interface) .def("calc_anchor_all_layers", &KVCacheBindings::CalcAnchorAllLayersBindinds::cpuinfer_interface); -} +} \ No newline at end of file diff --git a/csrc/ktransformers_ext/operators/llamafile/moe.cpp b/csrc/ktransformers_ext/operators/llamafile/moe.cpp index 35c144f..cd42691 100644 --- a/csrc/ktransformers_ext/operators/llamafile/moe.cpp +++ b/csrc/ktransformers_ext/operators/llamafile/moe.cpp @@ -341,7 +341,8 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* }, nullptr); } -void MOE::forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) { +void MOE::forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, int* batch_size_tensor, Backend* backend) { + qlen = batch_size_tensor[0]; if (qlen < config_.group_min_len) { for (int i = 0; i < qlen; i++) { forward_one(k, expert_ids + i * k, weights + i * k, (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend); @@ -350,5 +351,7 @@ void MOE::forward(int qlen, int k, const uint64_t* expert_ids, const float* weig } int forward_len = std::min(config_.group_max_len, qlen); forward_many(forward_len, k, expert_ids, weights, input, output, backend); - forward(qlen - forward_len, k, expert_ids + forward_len * k, weights + forward_len * k, (uint8_t*)input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend); + + batch_size_tensor[0] -= forward_len; + forward(qlen - forward_len, k, expert_ids + forward_len * k, weights + forward_len * k, (uint8_t*)input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), batch_size_tensor, backend); } \ No newline at end of file diff --git a/csrc/ktransformers_ext/operators/llamafile/moe.h b/csrc/ktransformers_ext/operators/llamafile/moe.h index a39e21d..9a8b6cd 100644 --- a/csrc/ktransformers_ext/operators/llamafile/moe.h +++ b/csrc/ktransformers_ext/operators/llamafile/moe.h @@ -53,7 +53,7 @@ class MOE { void warm_up(Backend* backend); void forward_one(int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend); void forward_many(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend); - void forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend); + void forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, int* batch_size_tensor, Backend* backend); private: MOEConfig config_; diff --git a/doc/README.md b/doc/README.md index b4f9e3a..c09eb98 100644 --- a/doc/README.md +++ b/doc/README.md @@ -22,13 +22,14 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

🔥 Updates

+* **Mar 27, 2025**: Support Multi-concurrency. * **Mar 15, 2025**: Support ROCm on AMD GPU ([Tutorial](./en/ROCm.md)). * **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and [IQ1_S/FP8 hybrid](./en/fp8_kernel.md) weights. Support 139K [Longer Context](./en/DeepseekR1_V3_tutorial.md#v022-longer-context) for DeepSeek-V3 and R1 in 24GB VRAM. * **Feb 25, 2025**: Support [FP8 GPU kernel](./en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./en/DeepseekR1_V3_tutorial.md#v022-longer-context). * **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~28x speedup. The detailed tutorial is [here](./en/DeepseekR1_V3_tutorial.md). * **Aug 28, 2024**: Support 1M context under the InternLM2.5-7B-Chat-1M model, utilizing 24GB of VRAM and 150GB of DRAM. The detailed tutorial is [here](./en/long_context_tutorial.md). * **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G. -* **Aug 15, 2024**: Update detailed [TUTORIAL](./en/injection_tutorial.md) for injection and multi-GPU. -* **Aug 14, 2024**: Support llamfile as linear backend. +* **Aug 15, 2024**: Update detailed [TUTORIAL](./en/injection_tutorial.md) for injection and multi-GPU. +* **Aug 14, 2024**: Support llamfile as linear backend. * **Aug 12, 2024**: Support multiple GPU; Support new model: mixtral 8\*7B and 8\*22B; Support q2k, q3k, q5k dequant on gpu. * **Aug 9, 2024**: Support windows native. diff --git a/doc/SUMMARY.md b/doc/SUMMARY.md index 854549c..44cd892 100644 --- a/doc/SUMMARY.md +++ b/doc/SUMMARY.md @@ -23,4 +23,28 @@ # V3 Reproduction - [Success List](en/V3-success.md) # Benchmark -- [Benchmark](en/benchmark.md) \ No newline at end of file +- [Benchmark](# Ktransformer + +[Introduction](./README.md) +# Install +- [Installation Guide](en/install.md) + +# Tutorial +- [Deepseek-R1/V3 Show Case/Tutorial](en/DeepseekR1_V3_tutorial.md) +- [Why KTransformers So Fast](en/deepseek-v2-injection.md) +- [Injection Tutorial](en/injection_tutorial.md) +- [Multi-GPU Tutorial](en/multi-gpu-tutorial.md) +- [Use FP8 GPU Kernel](en/fp8_kernel.md) +# Server + - [Server](en/api/server/server.md) + - [Website](en/api/server/website.md) + - [Tabby](en/api/server/tabby.md) +# For Developer +- [Makefile Usage](en/makefile_usage.md) + +# FAQ +- [FAQ](en/FAQ.md) +# V3 Reproduction +- [Success List](en/V3-success.md) +# Benchmark +- [Benchmark]( \ No newline at end of file diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md index 082078d..b2d59d2 100644 --- a/doc/en/DeepseekR1_V3_tutorial.md +++ b/doc/en/DeepseekR1_V3_tutorial.md @@ -1,169 +1,204 @@ + # GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM + - [SUMMARY](#summary) - - [Show Case Environment](#show-case-environment) - - [Bench Result](#bench-result) - - [V0.2.1](#v021) - - [Memory consumption:](#memory-consumption) - - [Change Log](#change-log) - - [Benchmark Results](#benchmark-results) - - [V0.2](#v02) - - [Settings](#settings) - - [Memory consumption:](#memory-consumption-1) - - [Benchmark Results](#benchmark-results-1) - - [V0.3-Preview](#v03-preview) - - [Settings](#settings-1) - - [Memory consumptions:](#memory-consumptions) - - [Benchmark results](#benchmark-results-2) - - [How to Run](#how-to-run) - - [v0.2.2 \& v0.2.3 longer context \& FP8 kernel](#v022--v023-longer-context--fp8-kernel) - - [longer context](#longer-context) - - [FP8 kernel](#fp8-kernel) - - [V0.2 \& V0.2.1 Showcase](#v02--v021-showcase) - - [Single socket version (32 cores)](#single-socket-version-32-cores) - - [Dual socket version (64 cores)](#dual-socket-version-64-cores) - - [V0.3 Showcase](#v03-showcase) - - [Dual socket version (64 cores)](#dual-socket-version-64-cores-1) - - [Some Explanations](#some-explanations) - - [Next](#next) - - [Faster](#faster) - - [Easier](#easier) - - [FAQ](#faq) - - [R1 No Thinking](#r1-no-thinking) - - [More FAQ](#more-faq) + - [Show Case Environment](#show-case-environment) + - [Bench Result](#bench-result) + - [V0.2.1](#v021) + - [Memory consumption:](#memory-consumption) + - [Change Log](#change-log) + - [Benchmark Results](#benchmark-results) + - [V0.2](#v02) + - [Settings](#settings) + - [Memory consumption:](#memory-consumption-1) + - [Benchmark Results](#benchmark-results-1) + - [V0.3-Preview](#v03-preview) + - [Settings](#settings-1) + - [Memory consumptions:](#memory-consumptions) + - [Benchmark results](#benchmark-results-2) + - [How to Run](#how-to-run) + - [v0.2.2 \& v0.2.3 longer context \& FP8 kernel](#v022--v023-longer-context--fp8-kernel) + - [longer context](#longer-context) + - [FP8 kernel](#fp8-kernel) + - [V0.2 \& V0.2.1 Showcase](#v02--v021-showcase) + - [Single socket version (32 cores)](#single-socket-version-32-cores) + - [Dual socket version (64 cores)](#dual-socket-version-64-cores) + - [V0.3 Showcase](#v03-showcase) + - [Dual socket version (64 cores)](#dual-socket-version-64-cores-1) + - [Some Explanations](#some-explanations) + - [Next](#next) + - [Faster](#faster) + - [Easier](#easier) + - [FAQ](#faq) + - [R1 No Thinking](#r1-no-thinking) + - [More FAQ](#more-faq) # SUMMARY > **Feb 10, 2025**: Support DeepseekR1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~28x speedup.
-Hi, we're the KTransformers team (formerly known for our local CPU/GPU hybrid inference open source project with DeepSeek-V2). +Hi, we're the KTransformers team (formerly known for our local CPU/GPU hybrid inference open source project with DeepSeek-V2). -We've heard your requests for DeepSeek-R1/V3 support—and we're excited to finally deliver! +We've heard your requests for DeepSeek-R1/V3 support—and we're excited to finally deliver! Apologies for the wait, but we've been cooking up something truly amazing! -Today, we're proud to announce that we not only support DeepSeek-R1/V3, as showcased in the video below: +Today, we're proud to announce that we not only support DeepSeek-R1/V3, as showcased in the video below: https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285

- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 14GB VRAM and 382GB DRAM. - - Prefill Speed (tokens/s): - - KTransformers: 54.21 (32 cores) → 74.362 (dual-socket, 2×32 cores) → 255.26 (optimized AMX-based MoE kernel, V0.3 only) → 286.55 (selectively using 6 experts, V0.3 only) - - Compared to 10.31 tokens/s in llama.cpp with 2×32 cores, achieving up to **27.79× speedup**. - - Decode Speed (tokens/s): - - KTransformers: 8.73 (32 cores) → 11.26 (dual-socket, 2×32 cores) → 13.69 (selectively using 6 experts, V0.3 only) - - Compared to 4.51 tokens/s in llama.cpp with 2×32 cores, achieving up to **3.03× speedup**. - + - Prefill Speed (tokens/s): + - KTransformers: 54.21 (32 cores) → 74.362 (dual-socket, 2×32 cores) → 255.26 (optimized AMX-based MoE kernel, V0.3 only) → 286.55 (selectively using 6 experts, V0.3 only) + - Compared to 10.31 tokens/s in llama.cpp with 2×32 cores, achieving up to **27.79× speedup**. + - Decode Speed (tokens/s): + - KTransformers: 8.73 (32 cores) → 11.26 (dual-socket, 2×32 cores) → 13.69 (selectively using 6 experts, V0.3 only) + - Compared to 4.51 tokens/s in llama.cpp with 2×32 cores, achieving up to **3.03× speedup**. We also give our upcoming optimizations previews, including an Intel AMX-accelerated kernel and a selective expert activation method, which will significantly enhance performance. With V0.3-preview, we achieve up to 286 tokens/s for prefill, making it up to **28× faster than llama.cpp** for local inference. -The binary distribution is available now and the source code will come ASAP! Check out the wheel package [here](https://github.com/kvcache-ai/ktransformers/releases/download/v0.1.4/ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl) +The binary distribution is available now and the source code will come ASAP! Check out the wheel package [here](https://github.com/kvcache-ai/ktransformers/releases/download/v0.1.4/ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl) > **Feb 15, 2025**: KTransformers V0.2.1: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%) (Up to 16 Tokens/s), update docs [here](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/). We speed up the decode and prefill speed a littlt bit. The reason for the limited performance improvement mainly lies in the fact that the inference process is still constrained by the CPU's computational speed and memory bandwidth. The MLA part handled by the GPU accounts for a relatively small proportion. Besides the improvements in speed, we've also significantly updated the documentation to enhance usability, including:
+ - Added Multi-GPU configuration tutorial. - Consolidated installation guide. - Add a detailed tutorial on registering extra GPU memory with ExpertMarlin; - ## Show Case Environment + We run our best performance tests (V0.2) on
CPU: Intel (R) Xeon (R) Gold 6454S 1T DRAM (2 NUMA nodes)
GPU: 4090D 24G VRAM
Memory: standard DDR5-4800 server DRAM (1 TB), each socket with 8×DDR5-4800 + ## Bench Result + ### V0.2.1 + - Model: DeepseekV3-q4km (int4)
- CPU: cpu_model_name: Intel (R) Xeon (R) Gold 6454S, 32 cores per socket, 2 sockets, 2 numa nodes - GPU: 4090 24G VRAM - We test after enough warm up -#### Memory consumption: - - Single socket: 382G DRAM, at least 14GB VRAM - - Dual socket: 1T DRAM, at least 14GB VRAM -#### Change Log -- Longer Context (from 4K to 8K for 24GB VRAM) and Slightly Faster Speed (+15%):
-Integrated the highly efficient Triton MLA Kernel from the fantastic sglang project, enable much longer context length and slightly faster prefill/decode speed -- We suspect that some of the improvements come from the change of hardware platform (4090D->4090) -#### Benchmark Results +#### Memory consumption: + +- Single socket: 382G DRAM, at least 14GB VRAM +- Dual socket: 1T DRAM, at least 14GB VRAM + +#### Change Log + +- Longer Context (from 4K to 8K for 24GB VRAM) and Slightly Faster Speed (+15%):
+ Integrated the highly efficient Triton MLA Kernel from the fantastic sglang project, enable much longer context length and slightly faster prefill/decode speed +- We suspect that some of the improvements come from the change of hardware platform (4090D->4090) + +#### Benchmark Results "6 experts" case is part of V0.3's preview -| Prompt | hi (2) | 1K (969) | 2K (1930) | 4K (3846) | 8K (7678) | -| --- | --- | --- | --- | --- | --- | -| Output length | 10tokens | 300tokens | 300tokens | 300tokens | 300tokens | -| **6 experts V0.2.0** | | | | | | -| Prefill token/s | 13 | 105 | 102 | 88 | CUDA OOM | -| decode token/s | 16.8 | 15.4 | 14.2 | 13.0 | CUDA OOM | -| **6 experts V0.2.1** | | | | | | -| Prefill token/s | 13 | 111 | 112.5 | 102 **(1.16x speedup)** | 101 | -| decode token/s | 16.8 | 15.9 | 15.4 | 14.9 **(1.15x speedup)** | 13.9 | -| **8 experts V0.2.1** | | | | | | -| Prefill token/s | 12.2 | 88.2 | 88.5 | 81.9 | 80 | -| Decode token/s | 13.4 | 13.5 | 13.4 | 13.2 | 12.4 | - +| Prompt | hi (2) | 1K (969) | 2K (1930) | 4K (3846) | 8K (7678) | +| -------------------- | -------- | --------- | --------- | ----------------------- | --------- | +| Output length | 10tokens | 300tokens | 300tokens | 300tokens | 300tokens | +| **6 experts V0.2.0** | | | | | | +| Prefill token/s | 13 | 105 | 102 | 88 | CUDA OOM | +| decode token/s | 16.8 | 15.4 | 14.2 | 13.0 | CUDA OOM | +| **6 experts V0.2.1** | | | | | | +| Prefill token/s | 13 | 111 | 112.5 | 102**(1.16x speedup)** | 101 | +| decode token/s | 16.8 | 15.9 | 15.4 | 14.9**(1.15x speedup)** | 13.9 | +| **8 experts V0.2.1** | | | | | | +| Prefill token/s | 12.2 | 88.2 | 88.5 | 81.9 | 80 | +| Decode token/s | 13.4 | 13.5 | 13.4 | 13.2 | 12.4 | ### V0.2 + #### Settings + - Model: DeepseekV3-q4km (int4)
- CPU: cpu_model_name: Intel (R) Xeon (R) Gold 6454S, 32 cores per socket, 2 sockets, 2 numa nodes - GPU: 4090D 24G VRAM - We test after enough warm up + #### Memory consumption: - - Single socket: 382G DRAM, at least 14GB VRAM - - Dual socket: 1T DRAM, at least 14GB VRAM + +- Single socket: 382G DRAM, at least 14GB VRAM +- Dual socket: 1T DRAM, at least 14GB VRAM #### Benchmark Results "6 experts" case is part of V0.3's preview -| Prompt
(500 tokens) | Dual socket Ktrans (6 experts) | Dual socket Ktrans (8 experts) | Single socket Ktrans (6 experts) | Single socket Ktrans (8 experts)| llama.cpp (8 experts) | -| --- | --- | --- | --- | --- | --- | -| Prefill token/s | 97.32 | 82.94 | 65.14 | 54.21 | 10.31 | -| Decode token/s | 13.69 | 12.208 | 10.303 | 8.73 |4.51 | + +| Prompt
(500 tokens) | Dual socket Ktrans (6 experts) | Dual socket Ktrans (8 experts) | Single socket Ktrans (6 experts) | Single socket Ktrans (8 experts) | llama.cpp (8 experts) | +| ---------------------- | ------------------------------ | ------------------------------ | -------------------------------- | -------------------------------- | --------------------- | +| Prefill token/s | 97.32 | 82.94 | 65.14 | 54.21 | 10.31 | +| Decode token/s | 13.69 | 12.208 | 10.303 | 8.73 | 4.51 | **The highest speedup reaches up to 3.03x in decoding and 9.44x in prefill.** ### V0.3-Preview + #### Settings + - Model: DeepseekV3-BF16 (online quant into int8 for CPU and int4 for GPU) - CPU: cpu_model_name: Intel (R) Xeon (R) Gold 6454S, 32 cores per socket, 2 socket, 2 numa nodes - GPU: (1~4)x 4090D 24GVRAM (requires more VRAM for longer prompt) #### Memory consumptions: + - 644GB DRAM, at least 14GB VRAM #### Benchmark results -| Prompt length | 1K | 2K | 4K | 8K | -|---------------|-----|-----|-----|-----| -| KTrans (8 experts) Prefill token/s | 185.96 | 255.26 | 252.58 | 195.62 | -| KTrans (6 experts) Prefill token/s | 203.70 | 286.55 | 271.08 | 207.20 | + + +| Prompt length | 1K | 2K | 4K | 8K | +| ---------------------------------- | ------ | ------ | ------ | ------ | +| KTrans (8 experts) Prefill token/s | 185.96 | 255.26 | 252.58 | 195.62 | +| KTrans (6 experts) Prefill token/s | 203.70 | 286.55 | 271.08 | 207.20 | **The prefill of KTrans V0.3 is up to 3.45x times faster than KTrans V0.2, and is up to 27.79x times faster than llama.cpp.** **The decoding speed is the same as KTrans V0.2 (6 experts version) so it is omitted** -The main acceleration comes from +The main acceleration comes from + - Intel AMX instruction set and our specially designed cache friendly memory layout - Expert selection strategy that selects fewer experts based on offline profile results of out of domain data - -*From our research on DeepSeekV2, DeepSeekV3 and DeepSeekR1, -when we slightly decrease the activation experts num in inference, -the output quality doesn't change. But the speed of decoding and prefill +*From our research on DeepSeekV2, DeepSeekV3 and DeepSeekR1, +when we slightly decrease the activation experts num in inference, +the output quality doesn't change. But the speed of decoding and prefill is speed up which is inspiring. So our showcase makes use of this finding* ## How to Run + +### v0.2.4 +We provide a server script, which supports multi-concurrency functionality in version v0.2.4. + +``` +python ktransformers/server/main.py --model_path /mnt/data/models/DeepSeek-V3 --gguf_path /mnt/data/models/DeepSeek-V3-GGUF/DeepSeek-V3-Q4_K_M/ --cpu_infer 62 --optimize_config_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml --port 10002 --chunk_size 256 --max_new_tokens 1024 --max_batch_size 4 --port 10002 --cache_lens 32768 --backend_type balance_serve +``` +It features the following arguments: + +- `--chunk_size`: Maximum number of tokens processed in a single run by the engine. +- `--cache_lens`: Total length of kvcache allocated by the scheduler. All requests share a kvcache space corresponding to 32768 tokens, and the space occupied will be released after the requests are completed. +- `--backend_type`: `balance_serve` is a multi-concurrency backend engine introduced in version v0.2.4. The original single-concurrency engine is `ktransformers`. +- `--max_batch_size`: Maximum number of requests (prefill + decode) processed in a single run by the engine. (Supported only by `balance_serve`) + ### v0.2.2 & v0.2.3 longer context & FP8 kernel + #### longer context + To use this feature, [install flashinfer](https://github.com/flashinfer-ai/flashinfer) first. Note: The latest MLA kernel in FlashInfer still has a few minor issues. They are continuously fixing them on the main branch. If you are using FlashInfer, please install it from the main source code. If you want to use long context(longer than 20K) for prefill, enable the matrix absorption MLA during the prefill phase, which will significantly reduce the size of the kv cache. Modify yaml file like this: + ``` - match: name: "^model\\.layers\\..*\\.self_attn$" @@ -175,10 +210,12 @@ If you want to use long context(longer than 20K) for prefill, enable the matrix absorb_for_prefill: True # change this to True to enable long context(prefill may slower). ``` -If the VRAM is still insufficient, try reducing the `chunk_prefill_size` parameter (default is 8192) to further decrease the intermediate results during chunk prefill. +If the VRAM is still insufficient, try reducing the `chunk_size` parameter (default is 8192) to further decrease the intermediate results during chunk prefill. + #### FP8 kernel The DeepSeek-AI team provides FP8 safetensors for DeepSeek-R1/V3 models. We achieve performance optimization through the following works: + - **FP8 GPU Kernel Integration**: FP8 linear layer acceleration kernels integrated in KTransformers - **Hybrid Quantization Architecture**: - Attention and Shared-Expert modules use FP8 precision (enhances computational accuracy) @@ -189,16 +226,20 @@ So those who are persuing the best performance can use the FP8 linear kernel for The detailed guide is [here](./fp8_kernel.md). ### V0.2 & V0.2.1 Showcase + #### Single socket version (32 cores) + Our local_chat test command is: -``` shell + +```shell numactl -N 1 -m 1 python ./ktransformers/local_chat.py --model_path --gguf_path --prompt_file --cpu_infer 33 --max_new_tokens 1000 ``` -`` can be local or set from online hugging face like deepseek-ai/DeepSeek-V3. If online encounters connection problem, try use mirror (hf-mirror.com)
+ +`` can be local or set from online huggingface like deepseek-ai/DeepSeek-V3. If online encounters connection problem, try use mirror (hf-mirror.com)
`` can also be online, but as its large we recommend you download it and quantize the model to what you want (notice it's the dir path)
`--max_new_tokens 1000` is the max output token length. If you find the answer is truncated, you -can increase the number for longer answer (But be aware of OOM, and increase it will slow down the generation rate.). +can increase the number for longer answer (But be aware of OOM, and increase it will slow down the generation rate.). The command `numactl -N 1 -m 1` aims to advoid data transfer between numa nodes
Attention! If you are testing R1 and it may skip thinking. So you can add arg: `--force_think true`. This is explained in [FAQ](#faq) part @@ -208,7 +249,8 @@ Attention! If you are testing R1 and it may skip thinking. So you can add arg: ` Make sure before you install (use install.sh or `make dev_install`), setting the env var `USE_NUMA=1` by `export USE_NUMA=1` (if already installed, reinstall it with this env var set). You may check the doc [here](./install.md) for install details.
Test Command: -``` shell + +```shell # ---For those who have not installed ktransformers--- # git clone https://github.com/kvcache-ai/ktransformers.git # cd ktransformers @@ -220,53 +262,65 @@ Test Command: python ./ktransformers/local_chat.py --model_path --gguf_path --prompt_file --cpu_infer 65 --max_new_tokens 1000 ``` + The parameters' meaning is the same. But As we use dual socket, we set cpu_infer to 65 ### V0.3 Showcase + #### Dual socket version (64 cores) + Our local_chat test command is: -``` shell + +```shell wget https://github.com/kvcache-ai/ktransformers/releases/download/v0.1.4/ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl pip install ./ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl python -m ktransformers.local_chat --model_path --gguf_path --prompt_file --cpu_infer 65 --max_new_tokens 1000 ``` + The parameters' meaning is the same with V0.2. But As we use dual socket, we set cpu_infer to 65 ## Some Explanations -1. Also we want to make further use of our two NUMA nodes on Xeon Gold cpu. -To avoid the cost of data transfer between nodes, we "copy" the critical matrix on -both nodes which takes more memory consumption but accelerates the prefill and decoding process. -But this method takes huge memory and slow when loading weights, So be patient when loading -and monitor the memory usage. We are going to optimize this huge memory overhead. Stay tuned~
-2. The command args `--cpu_infer 65` specifies how many cores to use (it's ok that it exceeds the physical number, -but it's not the more the better. Adjust it slightly lower to your actual number of cores)
+1. Also we want to make further use of our two NUMA nodes on Xeon Gold cpu. + To avoid the cost of data transfer between nodes, we "copy" the critical matrix on + both nodes which takes more memory consumption but accelerates the prefill and decoding process. + But this method takes huge memory and slow when loading weights, So be patient when loading + and monitor the memory usage. We are going to optimize this huge memory overhead. Stay tuned~
+2. The command args `--cpu_infer 65` specifies how many cores to use (it's ok that it exceeds the physical number, + but it's not the more the better. Adjust it slightly lower to your actual number of cores)
3. Why CPU/GPU Hybrid Inference? -DeepSeek's MLA operators are highly computationally intensive. While running everything on CPU is possible, offloading the heavy computations to the GPU results in a massive performance boost. - + DeepSeek's MLA operators are highly computationally intensive. While running everything on CPU is possible, offloading the heavy computations to the GPU results in a massive performance boost. 4. Where Does the Speedup Come From? - - Expert Offload: Unlike traditional layer-based or KVCache offloading (as seen in llama.cpp), we offload the expert computation to the CPU and MLA/KVCache to GPU, aligning perfectly with DeepSeek’s architecture for optimal efficiency. - - Intel AMX Optimization – Our AMX-accelerated kernel is meticulously tuned, running several times faster than existing llama.cpp implementations. We plan to open-source this kernel after cleansing and are considering upstream contributions to llama.cpp. - + - Expert Offload: Unlike traditional layer-based or KVCache offloading (as seen in llama.cpp), we offload the expert computation to the CPU and MLA/KVCache to GPU, aligning perfectly with DeepSeek’s architecture for optimal efficiency. + - Intel AMX Optimization – Our AMX-accelerated kernel is meticulously tuned, running several times faster than existing llama.cpp implementations. We plan to open-source this kernel after cleansing and are considering upstream contributions to llama.cpp. 5. Why Intel CPUs? -Intel is currently the only CPU vendor that supports AMX-like instructions, which delivers significantly better performance compared to AVX-only alternatives. + Intel is currently the only CPU vendor that supports AMX-like instructions, which delivers significantly better performance compared to AVX-only alternatives. + ## Next + ### Faster + * The FlashInfer (https://github.com/flashinfer-ai/flashinfer) project is releasing an even more efficient fused MLA operator, promising further speedups * vLLM has explored multi-token prediction in DeepSeek-V3, and support is on our roadmap for even better performance * We are collaborating with Intel to enhance the AMX kernel (v0.3) and optimize for Xeon6/MRDIMM + ### Easier + * Official Docker images to simplify installation * Fix the server integration for web API access * Fix the local chat only accepting a single line prompt (currently \n begins generating prompt) * Support for more quantization types, including the highly requested dynamic quantization from unsloth -Stay tuned for more updates! +Stay tuned for more updates! + ## FAQ + ### R1 No Thinking + Attention! If you are testing R1 and it may skip thinking. So you can add arg: `--force_think true`. The detail is in [FAQ](./FAQ.md) part
### More FAQ + [See detail](./FAQ.md) diff --git a/doc/en/install.md b/doc/en/install.md index fc22023..03b14b3 100644 --- a/doc/en/install.md +++ b/doc/en/install.md @@ -1,20 +1,25 @@ + # How to Run DeepSeek-R1 + - [Preparation](#preparation) - [Installation](#installation) - [Attention](#attention) - [Supported models include:](#supported-models-include) - [Support quantize format:](#support-quantize-format) -In this document, we will show you how to install and run KTransformers on your local machine. There are two versions: +In this document, we will show you how to install and run KTransformers on your local machine. There are two versions: + * V0.2 is the current main branch. * V0.3 is a preview version only provides binary distribution for now. * To reproduce our DeepSeek-R1/V3 results, please refer to [Deepseek-R1/V3 Tutorial](./DeepseekR1_V3_tutorial.md) for more detail settings after installation. + ## Preparation + Some preparation: - CUDA 12.1 and above, if you didn't have it yet, you may install from [here](https://developer.nvidia.com/cuda-downloads). - + ```sh # Adding CUDA to PATH if [ -d "/usr/local/cuda/bin" ]; then @@ -32,39 +37,42 @@ Some preparation: export CUDA_PATH=$CUDA_PATH:/usr/local/cuda fi ``` - - Linux-x86_64 with gcc, g++ and cmake (using Ubuntu as an example) - - ```sh - sudo apt-get update - sudo apt-get install build-essential cmake ninja-build - ``` + ```sh + sudo apt-get update + sudo apt-get install build-essential cmake ninja-build patchelf + ``` - We recommend using [Miniconda3](https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh) or [Anaconda3](https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh) to create a virtual environment with Python=3.11 to run our program. Assuming your Anaconda installation directory is `~/anaconda3`, you should ensure that the version identifier of the GNU C++standard library used by Anaconda includes `GLIBCXX-3.4.32` - ```sh conda create --name ktransformers python=3.11 conda activate ktransformers # you may need to run ‘conda init’ and reopen shell first - + conda install -c conda-forge libstdcxx-ng # Anaconda provides a package called `libstdcxx-ng` that includes a newer version of `libstdc++`, which can be installed via `conda-forge`. strings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX ``` - - Make sure that PyTorch, packaging, ninja is installed You can also [install previous versions of PyTorch](https://pytorch.org/get-started/previous-versions/) - + ``` pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126 pip3 install packaging ninja cpufeature numpy ``` - - - At the same time, you should download and install the corresponding version of flash-attention from https://github.com/Dao-AILab/flash-attention/releases. +- At the same time, you should download and install the corresponding version of flash-attention from https://github.com/Dao-AILab/flash-attention/releases. ## Installation + ### Attention + If you want to use numa support, not only do you need to set USE_NUMA=1, but you also need to make sure you have installed the libnuma-dev (`sudo apt-get install libnuma-dev` may help you). +[Optional] If you want to use the multi-concurrent version, please install the following dependencies. + +``` +sudo apt install libtbb-dev libssl-dev libcurl4-openssl-dev libaio1 libaio-dev libgflags-dev zlib1g-dev libfmt-dev +``` + + * Download source code and compile: - - - init source code - - ```sh - git clone https://github.com/kvcache-ai/ktransformers.git - cd ktransformers - git submodule init - git submodule update - ``` - - [Optional] If you want to run with website, please [compile the website](./api/server/website.md) before execute ```bash install.sh``` + - init source code - - For Linux - - For simple install: - - ```shell - bash install.sh - ``` - - For those who have two cpu and 1T RAM: + ```sh + git clone https://github.com/kvcache-ai/ktransformers.git + cd ktransformers + git submodule update --init --recursive + ``` + - [Optional] If you want to run with website, please [compile the website](./api/server/website.md) before execute ``bash install.sh`` + - For Linux - ```shell - # Make sure your system has dual sockets and double size RAM than the model's size (e.g. 1T RAM for 512G model) - apt install libnuma-dev - export USE_NUMA=1 - bash install.sh # or #make dev_install - ``` + - For simple install: - - For Windows - - ```shell - install.bat - ``` + ```shell + bash install.sh + ``` + - For those who have two cpu and 1T RAM: -* If you are developer, you can make use of the makefile to compile and format the code.
the detailed usage of makefile is [here](./makefile_usage.md) + ```shell + # Make sure your system has dual sockets and double size RAM than the model's size (e.g. 1T RAM for 512G model) + apt install libnuma-dev + export USE_NUMA=1 + bash install.sh # or #make dev_install + ``` + - For Multi-concurrency with 500G RAM: + + ```shell + sudo env USE_BALANCE_SERVE=1 PYTHONPATH="\$(which python)" PATH="\$(dirname \$(which python)):\$PATH" bash ./install.sh + ``` + - For Multi-concurrency with two cpu and 1T RAM: + + ```shell + sudo env USE_BALANCE_SERVE=1 USE_NUMA=1 PYTHONPATH="\$(which python)" PATH="\$(dirname \$(which python)):\$PATH" bash ./install.sh + ``` + - For Windows + + ```shell + install.bat + ``` +* If you are developer, you can make use of the makefile to compile and format the code.
the detailed usage of makefile is [here](./makefile_usage.md)

Local Chat

We provide a simple command-line local chat Python script that you can run for testing. -> Note: this is a very simple test tool only support one round chat without any memory about last input, if you want to try full ability of the model, you may go to [RESTful API and Web UI](#id_666). +> Note: this is a very simple test tool only support one round chat without any memory about last input, if you want to try full ability of the model, you may go to [RESTful API and Web UI](#id_666).

Run Example

@@ -141,57 +156,70 @@ python -m ktransformers.local_chat --model_path deepseek-ai/DeepSeek-V2-Lite-Cha # GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite # python ktransformers.local_chat --model_path ./DeepSeek-V2-Lite --gguf_path ./DeepSeek-V2-Lite-Chat-GGUF ``` - It features the following arguments: -- `--model_path` (required): Name of the model (such as "deepseek-ai/DeepSeek-V2-Lite-Chat" which will automatically download configs from [Hugging Face](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite)). Or if you already got local files you may directly use that path to initialize the model. - +- `--model_path` (required): Name of the model (such as "deepseek-ai/DeepSeek-V2-Lite-Chat" which will automatically download configs from [Hugging Face](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite)). Or if you already got local files you may directly use that path to initialize the model. + > Note: .safetensors files are not required in the directory. We only need config files to build model and tokenizer. - + > - `--gguf_path` (required): Path of a directory containing GGUF files which could that can be downloaded from [Hugging Face](https://huggingface.co/mzwing/DeepSeek-V2-Lite-Chat-GGUF/tree/main). Note that the directory should only contains GGUF of current model, which means you need one separate directory for each model. - - `--optimize_config_path` (required except for Qwen2Moe and DeepSeek-V2): Path of YAML file containing optimize rules. There are two rule files pre-written in the [ktransformers/optimize/optimize_rules](ktransformers/optimize/optimize_rules) directory for optimizing DeepSeek-V2 and Qwen2-57B-A14, two SOTA MoE models. - - `--max_new_tokens`: Int (default=1000). Maximum number of new tokens to generate. - - `--cpu_infer`: Int (default=10). The number of CPUs used for inference. Should ideally be set to the (total number of cores - 2). +

Start Server

+We provide a server script, which supports multi-concurrency functionality in version v0.2.4. + +``` +python ktransformers/server/main.py --model_path /mnt/data/models/DeepSeek-V3 --gguf_path /mnt/data/models/DeepSeek-V3-GGUF/DeepSeek-V3-Q4_K_M/ --cpu_infer 62 --optimize_config_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml --port 10002 --chunk_size 256 --max_new_tokens 1024 --max_batch_size 4 --port 10002 --cache_lens 32768 --backend_type balance_serve +``` +It features the following arguments: + +- `--chunk_size`: Maximum number of tokens processed in a single run by the engine. +- `--cache_lens`: Total length of kvcache allocated by the scheduler. All requests share a kvcache space corresponding to 32768 tokens, and the space occupied will be released after the requests are completed. +- `--backend_type`: `balance_serve` is a multi-concurrency backend engine introduced in version v0.2.4. The original single-concurrency engine is `ktransformers`. +- `--max_batch_size`: Maximum number of requests (prefill + decode) processed in a single run by the engine. (Supported only by `balance_serve`) +
Supported Models/quantization ### Supported models include: -| ✅ **Supported Models** | ❌ **Deprecated Models** | -|------------------------|------------------------| -| DeepSeek-R1 | ~~InternLM2.5-7B-Chat-1M~~ | -| DeepSeek-V3 | | -| DeepSeek-V2 | | -| DeepSeek-V2.5 | | -| Qwen2-57B | | -| DeepSeek-V2-Lite | | -| Mixtral-8x7B | | -| Mixtral-8x22B | | + +| ✅**Supported Models** | ❌**Deprecated Models** | +| ---------------------- | -------------------------- | +| DeepSeek-R1 | ~~InternLM2.5-7B-Chat-1M~~ | +| DeepSeek-V3 | | +| DeepSeek-V2 | | +| DeepSeek-V2.5 | | +| Qwen2-57B | | +| DeepSeek-V2-Lite | | +| Mixtral-8x7B | | +| Mixtral-8x22B | | ### Support quantize format: -| ✅ **Supported Formats** | ❌ **Deprecated Formats** | -|--------------------------|--------------------------| -| Q2_K_L | ~~IQ2_XXS~~ | -| Q2_K_XS | | -| Q3_K_M | | -| Q4_K_M | | -| Q5_K_M | | -| Q6_K | | -| Q8_0 | | + +| ✅**Supported Formats** | ❌**Deprecated Formats** | +| ----------------------- | ------------------------ | +| Q2_K_L | ~~IQ2_XXS~~ | +| Q2_K_XS | | +| Q3_K_M | | +| Q4_K_M | | +| Q5_K_M | | +| Q6_K | | +| Q8_0 | | +
Suggested Model + | Model Name | Model Size | VRAM | Minimum DRAM | Recommended DRAM | | ------------------------------ | ---------- | ----- | --------------- | ----------------- | -| DeepSeek-R1-q4_k_m | 377G | 14G | 382G | 512G | -| DeepSeek-V3-q4_k_m | 377G | 14G | 382G | 512G | +| DeepSeek-R1-q4_k_m | 377G | 14G | 382G | 512G | +| DeepSeek-V3-q4_k_m | 377G | 14G | 382G | 512G | | DeepSeek-V2-q4_k_m | 133G | 11G | 136G | 192G | | DeepSeek-V2.5-q4_k_m | 133G | 11G | 136G | 192G | | DeepSeek-V2.5-IQ4_XS | 117G | 10G | 107G | 128G | @@ -201,12 +229,11 @@ It features the following arguments: | Mixtral-8x22B-q4_k_m | 80G | 4G | 86.1G | 96G | | InternLM2.5-7B-Chat-1M | 15.5G | 15.5G | 8G(32K context) | 150G (1M context) | - -More will come soon. Please let us know which models you are most interested in. +More will come soon. Please let us know which models you are most interested in. Be aware that you need to be subject to their corresponding model licenses when using [DeepSeek](https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/LICENSE) and [QWen](https://huggingface.co/Qwen/Qwen2-72B-Instruct/blob/main/LICENSE). -
+
Click To Show how to run other examples @@ -228,9 +255,8 @@ Be aware that you need to be subject to their corresponding model licenses when # GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Qwen/Qwen2-57B-A14B-Instruct # python ktransformers/local_chat.py --model_path ./Qwen2-57B-A14B-Instruct --gguf_path ./DeepSeek-V2-Lite-Chat-GGUF ``` - * Deepseek-V2 - + ```sh mkdir DeepSeek-V2-Chat-0628-GGUF && cd DeepSeek-V2-Chat-0628-GGUF # Download weights @@ -250,40 +276,38 @@ Be aware that you need to be subject to their corresponding model licenses when # python -m ktransformers.local_chat --model_path ./DeepSeek-V2-Chat-0628 --gguf_path ./DeepSeek-V2-Chat-0628-GGUF ``` -| model name | weights download link | -|----------|----------| -| Qwen2-57B | [Qwen2-57B-A14B-gguf-Q4K-M](https://huggingface.co/Qwen/Qwen2-57B-A14B-Instruct-GGUF/tree/main) | -| DeepseekV2-coder |[DeepSeek-Coder-V2-Instruct-gguf-Q4K-M](https://huggingface.co/LoneStriker/DeepSeek-Coder-V2-Instruct-GGUF/tree/main) | -| DeepseekV2-chat |[DeepSeek-V2-Chat-gguf-Q4K-M](https://huggingface.co/bullerwins/DeepSeek-V2-Chat-0628-GGUF/tree/main) | -| DeepseekV2-lite | [DeepSeek-V2-Lite-Chat-GGUF-Q4K-M](https://huggingface.co/mzwing/DeepSeek-V2-Lite-Chat-GGUF/tree/main) | -| DeepSeek-R1 | [DeepSeek-R1-gguf-Q4K-M](https://huggingface.co/unsloth/DeepSeek-R1-GGUF/tree/main/DeepSeek-R1-Q4_K_M) | + +| model name | weights download link | +| ---------------- | --------------------------------------------------------------------------------------------------------------------- | +| Qwen2-57B | [Qwen2-57B-A14B-gguf-Q4K-M](https://huggingface.co/Qwen/Qwen2-57B-A14B-Instruct-GGUF/tree/main) | +| DeepseekV2-coder | [DeepSeek-Coder-V2-Instruct-gguf-Q4K-M](https://huggingface.co/LoneStriker/DeepSeek-Coder-V2-Instruct-GGUF/tree/main) | +| DeepseekV2-chat | [DeepSeek-V2-Chat-gguf-Q4K-M](https://huggingface.co/bullerwins/DeepSeek-V2-Chat-0628-GGUF/tree/main) | +| DeepseekV2-lite | [DeepSeek-V2-Lite-Chat-GGUF-Q4K-M](https://huggingface.co/mzwing/DeepSeek-V2-Lite-Chat-GGUF/tree/main) | +| DeepSeek-R1 | [DeepSeek-R1-gguf-Q4K-M](https://huggingface.co/unsloth/DeepSeek-R1-GGUF/tree/main/DeepSeek-R1-Q4_K_M) |
- + +

RESTful API and Web UI

- Start without website: ```sh ktransformers --model_path deepseek-ai/DeepSeek-V2-Lite-Chat --gguf_path /path/to/DeepSeek-V2-Lite-Chat-GGUF --port 10002 ``` - Start with website: ```sh ktransformers --model_path deepseek-ai/DeepSeek-V2-Lite-Chat --gguf_path /path/to/DeepSeek-V2-Lite-Chat-GGUF --port 10002 --web True ``` - Or you want to start server with transformers, the model_path should include safetensors ```bash ktransformers --type transformers --model_path /mnt/data/model/Qwen2-0.5B-Instruct --port 10002 --web True ``` - Access website with url [http://localhost:10002/web/index.html#/chat](http://localhost:10002/web/index.html#/chat) :

diff --git a/doc/zh/DeepseekR1_V3_tutorial_zh.md b/doc/zh/DeepseekR1_V3_tutorial_zh.md index ba9d7e8..17b51cd 100644 --- a/doc/zh/DeepseekR1_V3_tutorial_zh.md +++ b/doc/zh/DeepseekR1_V3_tutorial_zh.md @@ -1,26 +1,28 @@ + # GPT-4/o1 级别本地 VSCode Copilot 在仅 24GB 显存的台式机上的表现 + - [摘要](#摘要) - - [先决条件](#先决条件) - - [基准测试结果](#基准测试结果) - - [V0.2](#v02) - - [设置](#设置) - - [内存占用](#内存占用) - - [基准测试结果](#基准测试结果) - - [V0.3-Preview](#V0.3-Preview) - - [设置](#设置-1) - - [内存占用](#内存占用-1) - - [基准测试结果](#基准测试结果-1) - - [如何运行](#如何运行) - - [V0.2 展示](#v02-展示) - - [单插槽版本 (32 核心)](#单插槽版本(32 核心)) - - [双插槽版本 (64 核心)](#双插槽版本(64 核心)) - - [V0.3 展示](#v03-展示) - - [双插槽版本 (64 核心)](#双插槽版本(64 核心)-1) - - [一些解释](#一些解释) - - [常见问题解答](#常见问题解答) - - [R1 不思考](#R1 不返回思考过程) - - [更多常见问题解答](#更多常见问题解答) + - [先决条件](#先决条件) + - [基准测试结果](#基准测试结果) + - [V0.2](#v02) + - [设置](#设置) + - [内存占用](#内存占用) + - [基准测试结果](#基准测试结果) + - [V0.3-Preview](#V0.3-Preview) + - [设置](#设置-1) + - [内存占用](#内存占用-1) + - [基准测试结果](#基准测试结果-1) + - [如何运行](#如何运行) + - [V0.2 展示](#v02-展示) + - [单插槽版本 (32 核心)](#单插槽版本(32 核心)) + - [双插槽版本 (64 核心)](#双插槽版本(64 核心)) + - [V0.3 展示](#v03-展示) + - [双插槽版本 (64 核心)](#双插槽版本(64 核心)-1) + - [一些解释](#一些解释) + - [常见问题解答](#常见问题解答) + - [R1 不思考](#R1 不返回思考过程) + - [更多常见问题解答](#更多常见问题解答) # 摘要 @@ -37,74 +39,125 @@ https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285

- **[NEW!!!] 本地 671B DeepSeek-Coder-V3/R1:** 仅使用 14GB 显存和 382GB 内存运行其 Q4_K_M 版本。 - - 预填充(Prefill)速度 (tokens/s): - - KTransformers: 54.21 (32 核心) → 74.362 (双插槽,2×32 核心) → 255.26 (优化的 AMX 基 MoE 内核,仅 V0.3) → 286.55 (选择性使用 6 个专家,仅 V0.3) - - 与 llama.cpp 在 2×32 核心下 10.31 tokens/s 相比,速度提升高达 **27.79 倍** - - 解码(Decode)速度 (tokens/s): - - KTransformers: 8.73 (32 核心) → 11.26 (双插槽, 2×32 核心) → 13.69 (选择性使用 6 个专家,仅 V0.3) - - 与 llama.cpp 在 2×32 核心下 4.51 tokens/s 相比,速度提升高达 **3.03 倍** + - 预填充(Prefill)速度 (tokens/s): + - KTransformers: 54.21 (32 核心) → 74.362 (双插槽,2×32 核心) → 255.26 (优化的 AMX 基 MoE 内核,仅 V0.3) → 286.55 (选择性使用 6 个专家,仅 V0.3) + - 与 llama.cpp 在 2×32 核心下 10.31 tokens/s 相比,速度提升高达 **27.79 倍** + - 解码(Decode)速度 (tokens/s): + - KTransformers: 8.73 (32 核心) → 11.26 (双插槽, 2×32 核心) → 13.69 (选择性使用 6 个专家,仅 V0.3) + - 与 llama.cpp 在 2×32 核心下 4.51 tokens/s 相比,速度提升高达 **3.03 倍** - 我们还提供了即将推出的优化预览,包括英特尔 AMX 加速内核和选择性专家激活方法,这将显著提升性能。通过 V0.3 预览版,我们在预填充方面实现了高达 286 tokens/s 的速度,比本地推理的 llama.cpp **快 28 倍**。二进制发行版现已可用,源代码即将推出!请查看 wheel 包 [此处](https://github.com/kvcache-ai/ktransformers/releases/download/v0.1.4/ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl) 。 - ## 先决条件 + 我们在以下配置下进行了最佳性能测试(V0.2):
CPU: Intel (R) Xeon (R) Gold 6454S 1T 内存 (2 NUMA 节点)
GPU: 4090D 24G 显存
内存: 标准 DDR5-4800 服务器内存 (1 TB) + ## 基准测试结果 + ### V0.2 + #### 设置 + - Model: DeepseekV3-q4km (int4)
- CPU: cpu_model_name: Intel (R) Xeon (R) Gold 6454S,每个插槽 32 核心,2 个插槽,2 个 NUMA 节点 - GPU: 4090D 24G 显存 - 我们在充分预热后进行测试 + #### 内存占用: - - 单插槽: 382G 内存,至少 14GB 显存 - - 双插槽: 1T 内存,至少 14GB 显存 + +- 单插槽: 382G 内存,至少 14GB 显存 +- 双插槽: 1T 内存,至少 14GB 显存 #### 基准测试结果 “6 个专家” 情况是 V0.3 预览版中内容 -| Prompt
(500 tokens) | 双插槽 Ktrans (6 个专家) | 双插槽 Ktrans (8 个专家) | Single socket Ktrans (6 个专家) | Single socket Ktrans (8 个专家)| llama.cpp (8 个专家) | -|------------------------| --- | --- | --- | --- | --- | -| 预填充(Prefill) token/s | 97.32 | 82.94 | 65.14 | 54.21 | 10.31 | -| 解码(Decode) token/s | 13.69 | 12.208 | 10.303 | 8.73 |4.51 | + +| Prompt
(500 tokens) | 双插槽 Ktrans (6 个专家) | 双插槽 Ktrans (8 个专家) | Single socket Ktrans (6 个专家) | Single socket Ktrans (8 个专家) | llama.cpp (8 个专家) | +| ----------------------- | ------------------------ | ------------------------ | ------------------------------- | ------------------------------- | -------------------- | +| 预填充(Prefill) token/s | 97.32 | 82.94 | 65.14 | 54.21 | 10.31 | +| 解码(Decode) token/s | 13.69 | 12.208 | 10.303 | 8.73 | 4.51 | **最高加速比在解码方面达到 3.03x 倍,在预填充方面达到 9.44x 倍。** ### V0.3-Preview + #### 设置 + - Model: DeepseekV3-BF16 (在线量化为 CPU 的 int8 和 GPU 的 int4) - CPU: cpu_model_name: Intel (R) Xeon (R) Gold 6454S,每个插槽 32 核心,2 个插槽,2 个 NUMA 节点 - GPU: (1~4)x 4090D 24G 显存 (更长的 prompt 需要更多显存) #### 内存占用: + - 644GB 内存,至少 14GB 显存 #### 基准测试结果 -| Prompt length | 1K | 2K | 4K | 8K | -|---------------|-----|-----|-----|-----| -| KTrans (8 个专家) Prefill token/s | 185.96 | 255.26 | 252.58 | 195.62 | -| KTrans (6 个专家) Prefill token/s | 203.70 | 286.55 | 271.08 | 207.20 | + + +| Prompt length | 1K | 2K | 4K | 8K | +| --------------------------------- | ------ | ------ | ------ | ------ | +| KTrans (8 个专家) Prefill token/s | 185.96 | 255.26 | 252.58 | 195.62 | +| KTrans (6 个专家) Prefill token/s | 203.70 | 286.55 | 271.08 | 207.20 | **KTrans V0.3 的预填充速度比 KTrans V0.2 快 3.45x 倍,比 llama.cpp 快 27.79x 倍。** **解码速度与 KTrans V0.2(6 个专家版本)相同,因此省略。** -主要加速来自于 +主要加速来自于 + - 英特尔 AMX 指令集和我们专门设计的缓存友好内存布局 - 专家选择策略,根据离线配置文件结果选择更少的专家 - *从我们对 DeepSeekV2、DeepSeekV3 和 DeepSeekR1 的研究中,当我们略微减少推理中的激活专家数量时,输出质量没有变化。但解码和预填充的速度加快了,这令人鼓舞。因此,我们的展示利用了这一发现。* ## 如何运行 + +### 多并发展示 + +多并发需要额外编译调度器 c++ 代码 + +```shell +sudo apt install libtbb-dev libssl-dev libcurl4-openssl-dev libaio1 libaio-dev libfmt-dev +sudo apt-get install libgflags-dev zlib1g-dev patchelf +git clone https://github.com/kvcache-ai/ktransformers.git +cd ktransformers +git submodule update --init --recursive +# 如果使用双 numa 版本 +sudo env USE_BALANCE_SERVE=1 USE_NUMA=1 PYTHONzPATH="$(which python)" PATH="$(dirname $(which python)):$PATH" bash ./install.sh +# 如果使用单 numa 版本 +sudo env USE_BALANCE_SERVE=1 PYTHONzPATH="$(which python)" PATH="$(dirname $(which python)):$PATH" bash ./install.sh +# 启动命令 +python ktransformers/server/main.py --model_path --gguf_path --cpu_infer 62 --optimize_config_path --port 10002 --chunk_size 256 --max_new_tokens 1024 --max_batch_size 4 --port 10002 --cache_lens 32768 --backend_type balance_serve +``` + +`` 可以是本地路径,也可以是在线路径,例如 deepseek-ai/DeepSeek-V3。如果在线连接出现问题,可以尝试使用镜像(hf-mirror.com)
+`` 也可以是在线路径,但由于其体积较大,我们建议您下载并量化模型(注意这是目录路径) + +`` 注入规则 yaml 文件地址,我们在 `ktransformers/optimize/optimize_rules/ ` 目录下提供了 `DeepSeek-V3-Chat-serve.yaml` 和 `DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve.yaml` 分别对应 [`DeepSeek-V3/R1-q4km`](https://huggingface.co/unsloth/DeepSeek-R1-GGUF/tree/main/DeepSeek-R1-Q4_K_M) 和 [`DeepSeek-V3/R1-hybrid`](https://huggingface.co/KVCache-ai/DeepSeek-R1-GGML-FP8-Hybrid/tree/main) + +`--max_new_tokens 1000` 是最大输出 token 长度。如果发现答案被截断,可以增加此数字以获得更长的答案(但要注意内存不足问题,增加此数字会降低生成速度). + +`--chunk_size 256` 引擎单次运行最大 token 个数 + +`--cache_lens 32768` 调度器申请 kvcache 的总长度。所有请求共享 32768 个 tokens 对应 kvcache 空间,请求完成后会释放其所占用的 kvcache 空间。 + +`--backend_type balance_serve` `balance_serve`是 v0.2.4新增的后端引擎,原本的单并发引擎为`ktransformers` + +`--max_batch_size 4` 引擎单次运行最多处理 4 个请求(prefill + decode),(仅用于`balance_serve`) + +
命令 numactl -N 1 -m 1 的目的是避免 NUMA 节点之间的数据传输
+注意!如果测试 R1 可能会跳过思考。因此,可以添加参数:`--force_think`,这在 [常见问题解答](#常见问题解答) 部分中解释。 + ### V0.2 展示 + #### 单插槽版本(32 核心) + 我们的 local_chat 测试命令是: -``` shell + +```shell git clone https://github.com/kvcache-ai/ktransformers.git cd ktransformers git submodule init @@ -112,17 +165,13 @@ git submodule update numactl -N 1 -m 1 python ./ktransformers/local_chat.py --model_path --gguf_path --prompt_file --cpu_infer 33 --max_new_tokens 1000 <当您看到聊天时,按回车键加载文本提示文件> ``` -`` 可以是本地路径,也可以是在线路径,例如 deepseek-ai/DeepSeek-V3。如果在线连接出现问题,可以尝试使用镜像(hf-mirror.com)
-`` 也可以是在线路径,但由于其体积较大,我们建议您下载并量化模型(注意这是目录路径)
-`--max_new_tokens 1000` 是最大输出 token 长度。如果发现答案被截断,可以增加此数字以获得更长的答案(但要注意内存不足问题,增加此数字会降低生成速度). -
-命令 numactl -N 1 -m 1 的目的是避免 NUMA 节点之间的数据传输
-注意!如果测试 R1 可能会跳过思考。因此,可以添加参数:`--force_think true`,这在 [常见问题解答](#常见问题解答) 部分中解释。 #### 双插槽版本(64 核心) + 在安装之前(使用 install.sh 或 `make dev_install`),请确保设置环境变量 `USE_NUMA=1`,方法是 `export USE_NUMA=1`(如果已经安装,请重新安装并设置此环境变量)
我们的 local_chat 测试命令是: -``` shell + +```shell git clone https://github.com/kvcache-ai/ktransformers.git cd ktransformers git submodule init @@ -132,42 +181,48 @@ make dev_install # or sh ./install.sh python ./ktransformers/local_chat.py --model_path --gguf_path --prompt_file --cpu_infer 65 --max_new_tokens 1000 <当您看到聊天时,按回车键加载文本提示文件> ``` + 参数的含义相同。但因为我们使用双插槽,所以将 cpu_infer 设置为 65。 ### V0.3 展示 + #### 双插槽版本(64 核心) + 我们的 local_chat 测试命令是: -``` shell + +```shell wget https://github.com/kvcache-ai/ktransformers/releases/download/v0.1.4/ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl pip install ./ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl python -m ktransformers.local_chat --model_path --gguf_path --prompt_file --cpu_infer 65 --max_new_tokens 1000 <当您看到聊天时,按回车键加载文本提示文件> ``` + 参数的含义与 V0.2 相同。但因为我们使用双插槽,所以将 cpu_infer 设置为 65。 ## 一些解释 + 1. 我们还想进一步利用 Xeon Gold CPU 上的两个 NUMA 节点。为了避免节点之间的数据传输成本,我们在两个节点上 "copy" 了关键矩阵,这会增加内存占用,但会加速预填充和解码过程。但这种方法占用大量内存,加载权重时速度较慢,因此加载时请耐心等待并监控内存使用情况。我们计划优化这一巨大的内存开销。敬请期待。 - 2. 命令参数 `--cpu_infer 65` 指定使用多少核心(超过物理核心数量是可以的,但并不是越多越好。根据实际核心数量适当降低此值)。
- 3. 为什么使用 CPU/GPU 混合推理? -DeepSeek 的 MLA 操作符计算密集。虽然全部在 CPU 上运行是可行的,但将繁重的计算任务卸载到 GPU 上能带来巨大的性能提升。 - + DeepSeek 的 MLA 操作符计算密集。虽然全部在 CPU 上运行是可行的,但将繁重的计算任务卸载到 GPU 上能带来巨大的性能提升。 4. 加速来自哪里? - 专家卸载:与传统的基于层或 KVCache 卸载(如 llama.cpp 中的)不同,我们将专家计算卸载到 CPU,将 MLA/KVCache 卸载到 GPU,与 DeepSeek 的架构完美对齐,实现最佳效率。 - - 英特尔 AMX 优化 – 我们的 AMX 加速内核经过精心调优,运行速度是现有 llama.cpp 实现的数倍。我们计划在清理后开源此内核,并考虑向 llama.cpp 上游贡献代码。 - + - 英特尔 AMX 优化 – 我们的 AMX 加速内核经过精心调优,运行速度是现有 llama.cpp 实现的数倍。我们计划在清理后开源此内核,并考虑向 llama.cpp 上游贡献代码。 5. 为什么选择英特尔 CPU? -英特尔目前是唯一支持 AMX 类似指令的 CPU 供应商,与仅支持 AVX 的替代方案相比,性能显著更好。 + 英特尔目前是唯一支持 AMX 类似指令的 CPU 供应商,与仅支持 AVX 的替代方案相比,性能显著更好。 ## 常见问题解答 + ### R1 不返回思考过程 + 注意!如果测试 R1 可能会跳过思考。因此,可以添加参数:`--force_think true`。详细信息在 [常见问题解答](./FAQ.md) 部分中。
## 问题 + * 修复服务器集成功能以实现网络API访问支持 * 修复本地聊天功能仅支持单行提示输入的问题(目前输入换行符(\n)即开始生成提示) ### 更多常见问题解答 + [详见](./FAQ.md) diff --git a/install.sh b/install.sh index c5773ec..f80e552 100644 --- a/install.sh +++ b/install.sh @@ -4,14 +4,23 @@ set -e # clear build dirs rm -rf build rm -rf *.egg-info -rm -rf ktransformers/ktransformers_ext/build -rm -rf ktransformers/ktransformers_ext/cuda/build -rm -rf ktransformers/ktransformers_ext/cuda/dist -rm -rf ktransformers/ktransformers_ext/cuda/*.egg-info - +rm -rf csrc/build +rm -rf csrc/ktransformers_ext/build +rm -rf csrc/ktransformers_ext/cuda/build +rm -rf csrc/ktransformers_ext/cuda/dist +rm -rf csrc/ktransformers_ext/cuda/*.egg-info +rm -rf ~/.ktransformers echo "Installing python dependencies from requirements.txt" pip install -r requirements-local_chat.txt - +pip install -r ktransformers/server/requirements.txt echo "Installing ktransformers" -KTRANSFORMERS_FORCE_BUILD=TRUE pip install . --no-build-isolation +KTRANSFORMERS_FORCE_BUILD=TRUE pip install -v . --no-build-isolation +pip install third_party/custom_flashinfer/ + +SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])") +echo "Copying thirdparty libs to $SITE_PACKAGES" +cp -a csrc/balance_serve/build/third_party/prometheus-cpp/lib/libprometheus-cpp-*.so* $SITE_PACKAGES/ +patchelf --set-rpath '$ORIGIN' $SITE_PACKAGES/sched_ext.cpython* + + echo "Installation completed successfully" \ No newline at end of file diff --git a/ktransformers/configs/config.yaml b/ktransformers/configs/config.yaml index 80de09a..4144883 100644 --- a/ktransformers/configs/config.yaml +++ b/ktransformers/configs/config.yaml @@ -21,7 +21,8 @@ user: model: # type: transformers - type: ktransformers + type: balance_serve + # type: ktransformers name: DeepSeek-Coder-V2-Instruct path: deepseek-ai/DeepSeek-V2-Lite-Chat @@ -29,7 +30,7 @@ model: device: cuda:0 cache_lens: 8192 - + max_new_tokens: 500 web: mount: False open_cross_domain: True @@ -38,7 +39,6 @@ ext: cpu_infer: 10 long_context: - chunk_size: 4096 max_seq_len: 32000 block_size: 128 local_windows_len: 4096 @@ -54,4 +54,19 @@ long_context: token_step: local_chat: - prompt_file: "" \ No newline at end of file + prompt_file: "" + +async_server: + sched_strategy: "FCFS" + sched_port: 56441 + sched_metrics_port: 54321 + kvc2_metrics_port: 54391 + max_batch_size: 4 # decode count + prefill count, in one mini batch + +attn: + page_size: 256 + chunk_size: 256 +kvc2: + gpu_only: true + utilization_percentage: 1.0 + cpu_memory_size_GB: 500 diff --git a/ktransformers/configs/model_configs.json b/ktransformers/configs/model_configs.json new file mode 100644 index 0000000..6ce80b0 --- /dev/null +++ b/ktransformers/configs/model_configs.json @@ -0,0 +1,122 @@ +{ + "DeepSeek-Coder-V2-Instruct": { + "hidden_size": 5120, + "intermediate_size": 12288, + "max_position_embeddings": 163840, + "model_type": "deepseek_v2", + "num_attention_heads": 128, + "num_hidden_layers": 60, + "num_key_value_heads": 128, + "vocab_size": 102400 + }, + "DeepSeek-R1": { + "hidden_size": 7168, + "intermediate_size": 18432, + "max_position_embeddings": 163840, + "model_type": "deepseek_v3", + "num_attention_heads": 128, + "num_hidden_layers": 61, + "num_key_value_heads": 128, + "vocab_size": 129280 + }, + "DeepSeek-V2-Lite-Chat": { + "hidden_size": 2048, + "intermediate_size": 10944, + "max_position_embeddings": 163840, + "model_type": "deepseek_v2", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "num_key_value_heads": 16, + "vocab_size": 102400 + }, + "DeepSeek-V3": { + "hidden_size": 7168, + "intermediate_size": 18432, + "max_position_embeddings": 163840, + "model_type": "deepseek_v3", + "num_attention_heads": 128, + "num_hidden_layers": 3, + "num_key_value_heads": 128, + "vocab_size": 129280 + }, + "DeepSeek-V3-bf16": { + "hidden_size": 7168, + "intermediate_size": 18432, + "max_position_embeddings": 163840, + "model_type": "deepseek_v3", + "num_attention_heads": 128, + "num_hidden_layers": 61, + "num_key_value_heads": 128, + "vocab_size": 129280 + }, + "LLaMA-2-7B-32K": { + "hidden_size": 4096, + "intermediate_size": 11008, + "max_position_embeddings": 32768, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "vocab_size": 32000 + }, + "Moonlight-16B-A3B-Instruct": { + "hidden_size": 2048, + "intermediate_size": 11264, + "max_position_embeddings": 8192, + "model_type": "deepseek_v3", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "num_key_value_heads": 16, + "vocab_size": 163840 + }, + "Qwen2.5-32B-Instruct": { + "hidden_size": 5120, + "intermediate_size": 27648, + "max_position_embeddings": 32768, + "model_type": "qwen2", + "num_attention_heads": 40, + "num_hidden_layers": 64, + "num_key_value_heads": 8, + "vocab_size": 152064 + }, + "Qwen2.5-32B-Instruct-GPTQ-Int4": { + "hidden_size": 5120, + "intermediate_size": 27648, + "max_position_embeddings": 32768, + "model_type": "qwen2", + "num_attention_heads": 40, + "num_hidden_layers": 64, + "num_key_value_heads": 8, + "vocab_size": 152064 + }, + "Qwen2.5-7B-Instruct": { + "hidden_size": 3584, + "intermediate_size": 18944, + "max_position_embeddings": 32768, + "model_type": "qwen2", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "vocab_size": 152064 + }, + "Qwen2.5-7B-Instruct-GPTQ-Int4": { + "hidden_size": 3584, + "intermediate_size": 18944, + "max_position_embeddings": 32768, + "model_type": "qwen2", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "vocab_size": 152064 + }, + "qwen2-72b-instruct": { + "hidden_size": 8192, + "intermediate_size": 29568, + "max_position_embeddings": 32768, + "model_type": "qwen2", + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8, + "vocab_size": 152064 + } +} \ No newline at end of file diff --git a/ktransformers/configs/quant_configs.json b/ktransformers/configs/quant_configs.json new file mode 100644 index 0000000..191df5a --- /dev/null +++ b/ktransformers/configs/quant_configs.json @@ -0,0 +1,57 @@ +{ + "BF16": { + "block_element_count": 1, + "block_element_size": 2, + "bytes_per_element": 2.0, + "can_be_used_as_vector": true, + "has_min": false, + "has_scale": false, + "name": "BF16", + "reference": "", + "type_of_dot_vector": "BF16" + }, + "FP16": { + "block_element_count": 1, + "block_element_size": 2, + "bytes_per_element": 2.0, + "can_be_used_as_vector": true, + "has_min": false, + "has_scale": false, + "name": "FP16", + "reference": "", + "type_of_dot_vector": "FP16" + }, + "FP32": { + "block_element_count": 1, + "block_element_size": 4, + "bytes_per_element": 4.0, + "can_be_used_as_vector": true, + "has_min": false, + "has_scale": false, + "name": "FP32", + "reference": "", + "type_of_dot_vector": "FP32" + }, + "Q4_0": { + "block_element_count": 32, + "block_element_size": 18, + "bytes_per_element": 0.5625, + "can_be_used_as_vector": false, + "has_min": false, + "has_scale": true, + "name": "Q4_0", + "reference": "https://huggingface.co/docs/hub/gguf", + "type_of_dot_vector": "Q8_0" + }, + "Q8_0": { + "block_element_count": 32, + "block_element_size": 34, + "bytes_per_element": 1.0625, + "can_be_used_as_vector": true, + "has_min": false, + "has_scale": true, + "name": "Q8_0", + "reference": "https://huggingface.co/docs/hub/gguf", + "type_of_dot_vector": "Q8_0" + } +} \ No newline at end of file diff --git a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py b/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py index fadfb11..5a56c13 100644 --- a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py +++ b/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py @@ -114,6 +114,44 @@ def marlin_quantize( return res_list +def vllm_marlin_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, + act_order: bool, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, + act_order) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, + marlin_perm[num_bits]) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, + marlin_scale_perm[num_bits], + marlin_scale_perm_single[num_bits]) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + def inject_24(w, size_k, size_n): assert w.shape == (size_k, size_n) diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index 1a5f55f..928de48 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -63,7 +63,7 @@ def local_chat( prompt_file : str | None = None, mode: str = "normal", force_think: bool = False, - chunk_prefill_size: int = 8192 + chunk_size: int = 8192 ): torch.set_grad_enabled(False) @@ -172,12 +172,12 @@ def local_chat( if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA: generated = prefill_and_generate( - model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size, + model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size, use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim ) else: generated = prefill_and_generate( - model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size, + model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size, ) diff --git a/ktransformers/models/configuration_deepseek_v3.py b/ktransformers/models/configuration_deepseek_v3.py index 6227092..235b7b0 100644 --- a/ktransformers/models/configuration_deepseek_v3.py +++ b/ktransformers/models/configuration_deepseek_v3.py @@ -1,53 +1,60 @@ -# coding=utf-8 -# Copyright 2025 bzantium and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on the DeepSeekV3 implementations from the DeepSeek AI team. (https://huggingface.co/deepseek-ai/DeepSeek-V3) - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""DeepSeekV3 model configuration""" - from transformers.configuration_utils import PretrainedConfig -from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging +logger = logging.get_logger(__name__) DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} - - class DeepseekV3Config(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the DeepSeek-V3. - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. - - Args: vocab_size (`int`, *optional*, defaults to 129280): Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`DeepseekV3Model`] - hidden_size (`int`, *optional*, defaults to 7168): + hidden_size (`int`, *optional*, defaults to 4096): Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 18432): + intermediate_size (`int`, *optional*, defaults to 11008): Dimension of the MLP representations. - moe_intermediate_size (`int`, *optional*, defaults to 2048): + moe_intermediate_size (`int`, *optional*, defaults to 1407): Dimension of the MoE representations. - num_hidden_layers (`int`, *optional*, defaults to 61): + num_hidden_layers (`int`, *optional*, defaults to 32): Number of hidden layers in the Transformer decoder. - num_attention_heads (`int`, *optional*, defaults to 128): + num_nextn_predict_layers (`int`, *optional*, defaults to 1): + Number of nextn predict layers in the DeepSeekV3 Model. + num_attention_heads (`int`, *optional*, defaults to 32): Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*, defaults to 128): + n_shared_experts (`int`, *optional*, defaults to None): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to None): + Number of routed experts, None means dense model. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor or routed experts. + topk_method (`str`, *optional*, defaults to `gready`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to None): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to None): + Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to None): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to False): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to 'softmax'): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + seq_aux = (`bool`, *optional*, defaults to True): + Whether to compute the auxiliary loss for each individual sample. + num_key_value_heads (`int`, *optional*): This is the number of key_value heads that should be used to implement Grouped Query Attention. If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When @@ -55,39 +62,9 @@ class DeepseekV3Config(PretrainedConfig): by meanpooling all the original heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`. - n_shared_experts (`int`, *optional*, defaults to 1): - Number of shared experts. - n_routed_experts (`int`, *optional*, defaults to 256): - Number of routed experts. - routed_scaling_factor (`float`, *optional*, defaults to 2.5): - Scaling factor or routed experts. - kv_lora_rank (`int`, *optional*, defaults to 512): - Rank of the LoRA matrices for key and value projections. - q_lora_rank (`int`, *optional*, defaults to 1536): - Rank of the LoRA matrices for query projections. - qk_rope_head_dim (`int`, *optional*, defaults to 64): - Dimension of the query/key heads that use rotary position embeddings. - v_head_dim (`int`, *optional*, defaults to 128): - Dimension of the value heads. - qk_nope_head_dim (`int`, *optional*, defaults to 128): - Dimension of the query/key heads that don't use rotary position embeddings. - n_group (`int`, *optional*, defaults to 8): - Number of groups for routed experts. - topk_group (`int`, *optional*, defaults to 4): - Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). - num_experts_per_tok (`int`, *optional*, defaults to 8): - Number of selected experts, None means dense model. - first_k_dense_replace (`int`, *optional*, defaults to 3): - Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). - \--k dense layers--/ - norm_topk_prob (`bool`, *optional*, defaults to `True`): - Whether to normalize the weights of the routed experts. - aux_loss_alpha (`float`, *optional*, defaults to 0.001): - Auxiliary loss weight coefficient. - Whether to compute the auxiliary loss for each individual sample. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 4096): + max_position_embeddings (`int`, *optional*, defaults to 2048): The maximum sequence length that this model might ever be used with. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. @@ -98,15 +75,10 @@ class DeepseekV3Config(PretrainedConfig): relevant if `config.is_decoder=True`. pad_token_id (`int`, *optional*): Padding token id. - bos_token_id (`int`, *optional*, defaults to 0): + bos_token_id (`int`, *optional*, defaults to 1): Beginning of stream token id. - eos_token_id (`int`, *optional*, defaults to 1): + eos_token_id (`int`, *optional*, defaults to 2): End of stream token id. - pretraining_tp (`int`, *optional*, defaults to 1): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is - necessary to ensure exact reproducibility of the pretraining results. Please refer to [this - issue](https://github.com/pytorch/pytorch/issues/76232). tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings rope_theta (`float`, *optional*, defaults to 10000.0): @@ -120,49 +92,44 @@ class DeepseekV3Config(PretrainedConfig): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. - ```python >>> from transformers import DeepseekV3Model, DeepseekV3Config - >>> # Initializing a Deepseek-V3 style configuration >>> configuration = DeepseekV3Config() - >>> # Accessing the model configuration >>> configuration = model.config ```""" model_type = "deepseek_v3" keys_to_ignore_at_inference = ["past_key_values"] - # Default tensor parallel plan for base model `DeepseekV3Model` - base_model_tp_plan = { - "layers.*.gate_proj": "colwise", - "layers.*.up_proj": "colwise", - "layers.*.down_proj": "rowwise", - } def __init__( self, vocab_size=129280, hidden_size=7168, intermediate_size=18432, - moe_intermediate_size=2048, + moe_intermediate_size = 2048, num_hidden_layers=61, + num_nextn_predict_layers=1, num_attention_heads=128, num_key_value_heads=128, - n_shared_experts=1, - n_routed_experts=256, - routed_scaling_factor=2.5, - kv_lora_rank=512, - q_lora_rank=1536, - qk_rope_head_dim=64, - v_head_dim=128, - qk_nope_head_dim=128, - n_group=8, - topk_group=4, - num_experts_per_tok=8, - first_k_dense_replace=3, - norm_topk_prob=True, - aux_loss_alpha=0.001, + n_shared_experts = 1, + n_routed_experts = 256, + ep_size = 1, + routed_scaling_factor = 2.5, + kv_lora_rank = 512, + q_lora_rank = 1536, + qk_rope_head_dim = 64, + v_head_dim = 128, + qk_nope_head_dim = 128, + topk_method = 'noaux_tc', + n_group = 8, + topk_group = 4, + num_experts_per_tok = 8, + moe_layer_freq = 1, + first_k_dense_replace = 3, + norm_topk_prob = True, + scoring_func = 'sigmoid', hidden_act="silu", max_position_embeddings=4096, initializer_range=0.02, @@ -171,7 +138,6 @@ class DeepseekV3Config(PretrainedConfig): pad_token_id=None, bos_token_id=0, eos_token_id=1, - pretraining_tp=1, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, @@ -185,24 +151,25 @@ class DeepseekV3Config(PretrainedConfig): self.intermediate_size = intermediate_size self.moe_intermediate_size = moe_intermediate_size self.num_hidden_layers = num_hidden_layers + self.num_nextn_predict_layers = num_nextn_predict_layers self.num_attention_heads = num_attention_heads self.n_shared_experts = n_shared_experts self.n_routed_experts = n_routed_experts + self.ep_size = ep_size self.routed_scaling_factor = routed_scaling_factor self.kv_lora_rank = kv_lora_rank self.q_lora_rank = q_lora_rank self.qk_rope_head_dim = qk_rope_head_dim self.v_head_dim = v_head_dim self.qk_nope_head_dim = qk_nope_head_dim - self.q_head_dim = qk_nope_head_dim + qk_rope_head_dim - self.head_dim = qk_rope_head_dim + self.topk_method = topk_method self.n_group = n_group self.topk_group = topk_group self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq self.first_k_dense_replace = first_k_dense_replace self.norm_topk_prob = norm_topk_prob - self.aux_loss_alpha = aux_loss_alpha - + self.scoring_func = scoring_func # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads @@ -211,17 +178,11 @@ class DeepseekV3Config(PretrainedConfig): self.hidden_act = hidden_act self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps - self.pretraining_tp = pretraining_tp self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.attention_bias = attention_bias self.attention_dropout = attention_dropout - # Validate the correctness of rotary position embeddings parameters - # BC: if there is a 'type' field, copy it it to 'rope_type'. - if self.rope_scaling is not None and "type" in self.rope_scaling: - self.rope_scaling["rope_type"] = self.rope_scaling["type"] - rope_config_validation(self) super().__init__( pad_token_id=pad_token_id, @@ -229,7 +190,4 @@ class DeepseekV3Config(PretrainedConfig): eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs, - ) - - -__all__ = ["DeepseekV3Config"] \ No newline at end of file + ) \ No newline at end of file diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py index 434399f..4e65ce6 100644 --- a/ktransformers/models/custom_cache.py +++ b/ktransformers/models/custom_cache.py @@ -8,9 +8,11 @@ Version : 0.1.0 # Copyright 2018- The Hugging Face team. All rights reserved. # Copyright (c) 2024 by KVCache.AI, All Rights Reserved. import torch +import torch.nn as nn import transformers from transformers import Cache, PretrainedConfig from typing import List, Optional, Dict, Any, Tuple +from ktransformers.server.balance_serve.settings import sched_ext class StaticCache(transformers.StaticCache): """ Static Cache class to be used with `torch.compile(model)`. @@ -188,3 +190,85 @@ class StaticCache(transformers.StaticCache): def get_max_cache_shape(self) -> Tuple[int, int, int, int]: """Returns the maximum shape of the cache.""" return self.max_cache_len + +class KDeepSeekV3Cache(nn.Module): + def __init__( + self, + config: PretrainedConfig, + page_size: int = 256, + dtype=torch.bfloat16, + device=torch.device("cuda:0"), + + ): + super().__init__() + self.config = config + self.dtype = dtype + self.device = device + self.kv_lora_rank = config.kv_lora_rank + self.page_size = page_size + self.k_caches = [] + self.v_caches = [] + + + def load(self, inference_context: sched_ext.InferenceContext): + + for i in range(self.config.num_hidden_layers): + self.k_caches.append( + inference_context.k_cache[0][i] + ) + self.max_cache_len = self.k_caches[0].shape[0]*self.k_caches[0].shape[1] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + + page_idx: torch.Tensor, + page_offset: torch.Tensor, + + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + k_out = self.k_caches[layer_idx] + + k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states.reshape(-1, *key_states.shape[2:]) + k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states.reshape(-1, *value_states.shape[2:]) + return k_out + + + def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.tensor): + page_offset = cache_position % self.page_size + page_idx_local = cache_position // self.page_size + query_ids = torch.zeros_like(cache_position) + for i in range(len(q_indptr) - 1): + start_idx = q_indptr[i] + end_idx = q_indptr[i + 1] + query_ids[start_idx:end_idx] = i + page_idx = torch.zeros_like(page_idx_local) + for i in range(bsz_tensors[0]): + query_id = query_ids[i] + local_block = page_idx_local[i] + start_block = kv_indptr[query_id] + if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]: + page_idx[i] = kv_indices[start_block + local_block] + + return page_idx, page_offset + diff --git a/ktransformers/models/custom_modeling_deepseek_v2.py b/ktransformers/models/custom_modeling_deepseek_v2.py new file mode 100644 index 0000000..b70a457 --- /dev/null +++ b/ktransformers/models/custom_modeling_deepseek_v2.py @@ -0,0 +1,152 @@ +import math +from dataclasses import dataclass +import torch +import torch.nn as nn +from torch.nn import functional as F +import math +from typing import List, Optional, Tuple, Union +import torch +import torch.utils.checkpoint +from torch import nn +from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput +from ktransformers.models.custom_cache import KDeepSeekV3Cache +from ktransformers.models.modeling_deepseek import DeepseekV2Model, DeepseekV2PreTrainedModel +from ktransformers.models.configuration_deepseek import DeepseekV2Config + + +torch.set_grad_enabled(False) +torch.set_default_dtype(torch.bfloat16) +import flashinfer + +class KDeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): + + kv_cache: KDeepSeekV3Cache + use_cuda_graph = False + def __init__( + self, + config, + kv_cache, + + ): + super().__init__(config) + self.model = DeepseekV2Model(config) + self.config = config + self.kv_cache = kv_cache + + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + + def init_wrapper(self, use_cuda_graph, device, max_batch_size, max_pages): + self.use_cuda_graph = use_cuda_graph + self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) + self.qo_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device) + self.paged_kv_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device) + self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device) + self.paged_kv_len_buf = torch.empty((max_batch_size,), dtype=torch.int32, device=device) + + + + self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( + self.workspace_buffer, use_cuda_graph=use_cuda_graph, + qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf, + kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf + ) + + def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"): + features = [] + for i in range(batch.batch_size): + tokens = batch.minibatch.tokens.contiguous() + feature = ( + self.model.embed_tokens(tokens.to(torch.device('cpu'))) + .to(torch.bfloat16) + .to(device=device) + ) + features.append(feature) + + return features + + + def forward( + self, + batch: ForwardBatchInput | None = None, + features: List[torch.Tensor] | None = None, + bsz_tensors: torch.Tensor | None = None, + num_tokens_tensors: torch.Tensor | None = None, + page_idx: torch.Tensor | None = None, + page_offset: torch.Tensor | None = None, + ) -> ForwardBatchOutput: + current_stream = torch.cuda.current_stream() + + forward_batch_output = ForwardBatchOutput() + + + hidden_states = features[0] + + + with torch.cuda.stream(current_stream): + residual = torch.zeros_like(hidden_states) + for i, decode_layer in enumerate(self.model.layers): + if self.model.transfer_map is not None and i in self.model.transfer_map: + prev_stream = torch.cuda.current_stream() + cur_device = self.model.transfer_map[i] + if cur_device not in self.model.stream_device_map: + self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device) + torch.cuda.set_device(cur_device) + self.model.stream_device_map[cur_device].wait_stream(prev_stream) + torch.cuda.set_stream(self.model.stream_device_map[cur_device]) + hidden_states = hidden_states.to( + self.model.transfer_map[i], non_blocking=True + ) + + batch.minibatch.position_ids = ( + batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True) + if batch.minibatch.position_ids is not None + else None + ) + hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual) + hidden_states = decode_layer.self_attn(hidden_states, self.kv_cache, + position_ids=batch.minibatch.position_ids, + wrapper=self.wrapper, bsz_tensors=num_tokens_tensors, + cache_position=batch.minibatch.positions, + batch_indices=batch.minibatch.batch_indices, + kv_indices=batch.minibatch.kv_indices, + kv_indptr=batch.minibatch.kv_indptr, + kv_last_page_len=batch.minibatch.kv_last_page_len, + q_indptr=batch.minibatch.q_indptr, + page_idx=page_idx, + page_offset=page_offset + ) + + hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual) + if i < 3: + hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors) + else: + hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors) + hidden_states = hidden_states.squeeze(0) + forward_batch_output = ForwardBatchOutput() + assert batch.batch_size == 1 + with torch.cuda.stream(current_stream): + + local_logit = self.lm_head(self.model.norm(hidden_states[batch.minibatch.logits_start], num_tokens_tensors, residual[batch.minibatch.logits_start])[0]) + # local_logit = local_logit[batch.minibatch.logits_start] + forward_batch_output.logits.append(local_logit) + + return forward_batch_output + + + + def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors, + num_heads: int, + head_dim_ckv: int, + head_dim_kpe: int, + page_size: int, + causal: bool, + sm_scale: float, + q_data_type: torch.dtype, + kv_data_type: torch.dtype,): + minibatch = batch.minibatch + + self.wrapper.plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, + minibatch.kv_len, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_data_type, kv_data_type) + \ No newline at end of file diff --git a/ktransformers/models/custom_modeling_deepseek_v3.py b/ktransformers/models/custom_modeling_deepseek_v3.py new file mode 100644 index 0000000..205fa86 --- /dev/null +++ b/ktransformers/models/custom_modeling_deepseek_v3.py @@ -0,0 +1,147 @@ +""" +Date: 2024-11-06 10:05:11 +LastEditors: djw +LastEditTime: 2024-11-13 07:50:51 +""" + +import math +from dataclasses import dataclass +import torch +import torch.nn as nn +from torch.nn import functional as F +import math +from typing import List, Optional, Tuple, Union +import torch +import torch.utils.checkpoint +from torch import nn +from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput +from ktransformers.models.custom_cache import KDeepSeekV3Cache +from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Model, DeepseekV3PreTrainedModel +from ktransformers.models.configuration_deepseek_v3 import DeepseekV3Config + + +torch.set_grad_enabled(False) +torch.set_default_dtype(torch.bfloat16) +import flashinfer + +class KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): + + cache: KDeepSeekV3Cache + use_cuda_graph = False + def __init__( + self, + config: DeepseekV3Config, + cache, + ): + super().__init__(config) + self.model = DeepseekV3Model(config) + self.config = config + self.cache = cache + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def init_wrapper(self, use_cuda_graph, device, max_batch_size, max_pages): + self.use_cuda_graph = use_cuda_graph + self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) + self.qo_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device) + self.paged_kv_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device) + self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device) + self.paged_kv_len_buf = torch.empty((max_batch_size,), dtype=torch.int32, device=device) + self.bsz_tensor_buf = torch.empty((1, ), dtype=torch.int32, device=device) + + + self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( + self.workspace_buffer, use_cuda_graph=use_cuda_graph, + qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf, + kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf, + bsz_tensor=self.bsz_tensor_buf + ) + + def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"): + features = [] + for i in range(batch.batch_size): + tokens = batch.minibatch.tokens.contiguous() + feature = ( + self.model.embed_tokens(tokens.to(torch.device('cpu'))) + .to(torch.bfloat16) + .to(device=device) + ) + features.append(feature) + + return features + + + def forward( + self, + batch: ForwardBatchInput | None = None, + features: List[torch.Tensor] | None = None, + bsz_tensors: torch.Tensor | None = None, + num_tokens_tensors: torch.Tensor | None = None, + page_idx: torch.Tensor | None = None, + page_offset: torch.Tensor | None = None, + cuda_graph_idx: int | None = -1 + ) -> ForwardBatchOutput: + current_stream = torch.cuda.current_stream() + + forward_batch_output = ForwardBatchOutput() + + + hidden_states = features[0] + + with torch.cuda.stream(current_stream): + residual = torch.zeros_like(hidden_states) + for i, decode_layer in enumerate(self.model.layers): + # can't use now, only one flashinfer wrapper + if self.model.transfer_map is not None and i in self.model.transfer_map: + prev_stream = torch.cuda.current_stream() + cur_device = self.model.transfer_map[i] + if cur_device not in self.model.stream_device_map: + self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device) + torch.cuda.set_device(cur_device) + self.model.stream_device_map[cur_device].wait_stream(prev_stream) + torch.cuda.set_stream(self.model.stream_device_map[cur_device]) + hidden_states = hidden_states.to( + self.model.transfer_map[i], non_blocking=True + ) + + batch.minibatch.position_ids = ( + batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True) + if batch.minibatch.position_ids is not None + else None + ) + hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual) + hidden_states = decode_layer.self_attn(hidden_states, self.cache, + position_ids=batch.minibatch.position_ids, + wrapper=self.wrapper, num_tokens_tensors=num_tokens_tensors, + page_idx=page_idx, + page_offset=page_offset + ) + + hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual) + if i < self.config.first_k_dense_replace: + hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors) + else: + hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors, cuda_graph_idx) + hidden_states = hidden_states.squeeze(0) + forward_batch_output = ForwardBatchOutput() + with torch.cuda.stream(current_stream): + local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors) + forward_batch_output.logits.append(local_logit) + + return forward_batch_output + + + + def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors, + num_heads: int, + head_dim_ckv: int, + head_dim_kpe: int, + page_size: int, + causal: bool, + sm_scale: float, + q_data_type: torch.dtype, + kv_data_type: torch.dtype,): + minibatch = batch.minibatch + self.wrapper.plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, + minibatch.kv_len, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_data_type, kv_data_type, bsz_tensors) + \ No newline at end of file diff --git a/ktransformers/models/modeling_deepseek_v3.py b/ktransformers/models/modeling_deepseek_v3.py index 952eed7..12294e1 100644 --- a/ktransformers/models/modeling_deepseek_v3.py +++ b/ktransformers/models/modeling_deepseek_v3.py @@ -99,6 +99,7 @@ class DeepseekV3RMSNorm(nn.Module): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps + self.hidden_size = hidden_size def forward(self, hidden_states): input_dtype = hidden_states.dtype @@ -398,7 +399,6 @@ class MoEGate(nn.Module): self.n_routed_experts = config.n_routed_experts self.routed_scaling_factor = config.routed_scaling_factor self.scoring_func = config.scoring_func - self.seq_aux = config.seq_aux self.topk_method = config.topk_method self.n_group = config.n_group self.topk_group = config.topk_group @@ -436,6 +436,7 @@ class MoEGate(nn.Module): ### select top-k experts if self.topk_method == "noaux_tc": + assert not self.training scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) group_scores = ( scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim = -1) @@ -454,7 +455,7 @@ class MoEGate(nn.Module): ) .reshape(bsz * seq_len, -1) ) # [n, e] - tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] + tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] _, topk_idx = torch.topk( tmp_scores, k=self.top_k, dim=-1, sorted=False ) @@ -1933,4 +1934,4 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, - ) \ No newline at end of file + ) diff --git a/ktransformers/operators/RoPE.py b/ktransformers/operators/RoPE.py index adc1c5f..5233fc7 100644 --- a/ktransformers/operators/RoPE.py +++ b/ktransformers/operators/RoPE.py @@ -359,3 +359,56 @@ class DynamicNTKScalingRotaryEmbedding( self.orig_module.rope_type, self.orig_module.config, ) + + + +class RotaryEmbeddingV4(BaseInjectedModule): + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + # device: str = "cuda", + generate_device: str = "cuda", + prefill_device: str = "cuda", + **kwargs, + ): + BaseInjectedModule.__init__( + self, key, gguf_loader, config, orig_module, generate_device, **kwargs + ) + self.generate_device = generate_device + self.prefill_device = prefill_device + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def load(self): + self._init( + dim=self.config.qk_rope_head_dim, + max_position_embeddings=self.config.max_position_embeddings, + base=self.config.rope_theta, + device=self.device, + ) + def _init(self, dim, max_position_embeddings, base, device, scaling_factor=1.0): + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + # self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings \ No newline at end of file diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py index db65f34..d02a505 100644 --- a/ktransformers/operators/attention.py +++ b/ktransformers/operators/attention.py @@ -32,7 +32,8 @@ import os from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled if flashinfer_enabled: from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton - + from flashinfer.mla import BatchMLAPagedAttentionWrapper +from ktransformers.models.custom_cache import KDeepSeekV3Cache logger = logging.getLogger("attention") # Copied from transformers.models.llama.modeling_llama.rotate_half @@ -759,3 +760,92 @@ class KLlamaAttention(BaseInjectedModule): attn_weights = None return attn_output, attn_weights, past_key_value + +class flashinfer_attn(BaseInjectedModule, DeepseekV2Attention): + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + chunck_size: int = 1000, + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.orig_module.__init__(orig_module.config, + orig_module.layer_idx) + self.chunck_size = chunck_size # TODO, generate chunck_size automatically. + + + def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]: + if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')): + kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) + q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank) + out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank) + self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, + bias=False, dtype=q_absorb.dtype, device=q_absorb.device) + self.q_absorb.weight.data = q_absorb + self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim, + bias=False, dtype=out_absorb.dtype, device=out_absorb.device) + self.out_absorb.weight.data = out_absorb + #del self.orig_module.kv_b_proj + q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank) + out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank) + return q_absorb, out_absorb + + + + def forward(self, + hidden_states: torch.Tensor, + kv_cache: KDeepSeekV3Cache, + position_ids: torch.Tensor, + wrapper: BatchMLAPagedAttentionWrapper, + num_tokens_tensors: torch.Tensor, + page_idx: torch.Tensor, + page_offset: torch.Tensor, + ): + q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states, num_tokens_tensors) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states, num_tokens_tensors), num_tokens_tensors), num_tokens_tensors) + q = q.view(q_len, self.num_heads, self.q_head_dim) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states, num_tokens_tensors) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + compressed_kv = compressed_kv.contiguous() + compressed_kv = self.kv_a_layernorm(compressed_kv, num_tokens_tensors) + k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim) + compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank) + + cos, sin = self.rotary_emb(q_pe, position_ids.unsqueeze(0)) + q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2) + q_pe = q_pe.squeeze(0) + if kv_cache is not None: + + # page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices) + cache_kwargs = {"sin": sin, "cos": cos, "page_idx": page_idx, "page_offset": page_offset} # Specific to RoPE models + compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, page_idx, page_offset, cache_kwargs) + compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank) + k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim) + + q_absorb, out_absorb = self.get_absorbed() + q_nope = q_nope.transpose(0, 1) # q_len is 1, no GPU overhead, same below + q_nope = torch.matmul(q_nope, q_absorb) # batched MM + q_nope = q_nope.transpose(0, 1) + # q_nope.squeeze_(1) + # q_pe.squeeze_(1) + + attn_output = wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(q_len, self.num_heads, self.kv_lora_rank) + attn_output = attn_output.transpose(0, 1) + attn_output = torch.matmul(attn_output, out_absorb.mT) # [self.num_heads, q_len, self.v_head_dim] + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim) + attn_output = self.o_proj(attn_output, num_tokens_tensors) + return attn_output \ No newline at end of file diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index c2d5c25..74613b4 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -37,6 +37,10 @@ import time from ktransformers.operators.cpuinfer import CPUInfer +def deduplicate_and_sort(lst): + return sorted(set(lst)) +#cuda_graphs = [Config().chunk_size] +cuda_graphs = deduplicate_and_sort([1, 2, 3, Config().max_batch_size, 64, Config().chunk_size]) # class Base(BaseInjectedModule, ABC): class KExpertsBase(ABC): def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs): @@ -112,6 +116,7 @@ class KExpertsBase(ABC): tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device) return tensors + class KExpertsCPU(KExpertsBase): input_tensor_cpu:Tensor = None expert_ids_cpu:Tensor = None @@ -119,8 +124,8 @@ class KExpertsCPU(KExpertsBase): output_cpu:Tensor = None output_gpu_map:dict = {} # Manage output tensor buffer on different gpu #stream_map:dict = {} # Manage cuda stream on different gpu - #gguf_loader:GGUFLoader = None - CPU_INFER = None + # @TODO add yaml + CPU_INFER = CPUInfer(Config().cpu_infer) def __init__( self, key: str, @@ -133,11 +138,6 @@ class KExpertsCPU(KExpertsBase): **kwargs ): super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) - if KExpertsCPU.CPU_INFER is None: - KExpertsCPU.CPU_INFER = CPUInfer(Config().cpu_infer) - #if KExpertsCPU.gguf_loader is None: - # KExpertsCPU.gguf_loader = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf") - self.gguf_loader = gguf_loader assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU" self.n_routed_experts = n_routed_experts self.out_device = out_device @@ -161,7 +161,7 @@ class KExpertsCPU(KExpertsBase): down_ptr = ctypes.addressof( ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents ) - #print(self.gate_type, self.up_type, self.down_type) + # print(self.gate_qtype, self.up_qtype, self.down_qtype) n_routed_experts = self.n_routed_experts # n_routed_experts = len(self.orig_module) moe_config = MOEConfig( @@ -188,43 +188,83 @@ class KExpertsCPU(KExpertsBase): self.cpu_infer.submit(self.moe.warm_up()) self.cpu_infer.sync() if self.out_device not in KExpertsCPU.output_gpu_map: - KExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((self.config.hidden_size), device=self.out_device) + if isinstance(cuda_graphs, list): + KExpertsCPU.output_gpu_map[self.out_device] = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device=self.out_device) for i in range(len(cuda_graphs))] + else: + KExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((cuda_graphs, self.config.hidden_size), device=self.out_device) if KExpertsCPU.input_tensor_cpu == None: - KExpertsCPU.input_tensor_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True) - KExpertsCPU.expert_ids_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) - KExpertsCPU.weights_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) - KExpertsCPU.output_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) + if isinstance(cuda_graphs, list): + KExpertsCPU.input_tensor_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device="cpu", pin_memory=True) for i in range(len(cuda_graphs))] + KExpertsCPU.expert_ids_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) for i in range(len(cuda_graphs))] + KExpertsCPU.weights_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) for i in range(len(cuda_graphs))] + KExpertsCPU.output_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) for i in range(len(cuda_graphs))] + KExpertsCPU.bsz_tensor_cpu = [torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True) for i in range(len(cuda_graphs))] + else: + KExpertsCPU.input_tensor_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True) + KExpertsCPU.expert_ids_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) + KExpertsCPU.weights_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) + KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) + KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True) - def submit_for_one_decode(self, input_tensor, expert_ids, weights): - KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True) - KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True) - KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True) - self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(0), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr())) - - def sync_for_one_decode(self): - self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream) - KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True) - return KExpertsCPU.output_gpu_map[self.out_device] - - def forward(self, input_tensor, expert_ids, weights): - # generate, capture and run cuda graph - # print(expert_ids) - if input_tensor.size(0)==1 and torch.cuda.is_current_stream_capturing(): - # TODO: this branch is unreachable, but the shape of input_tensor([1,hidden_size]) and input_tensor_cpu([hidden_size]) is not compatible - #print("capturing experts") + def submit_for_one_decode(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0): + if bsz_tensor is None: + bsz_tensor = torch.ones(1, device=input_tensor.device, dtype=torch.int32) + if cuda_graph_idx != -1: + KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True) + KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True) + KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True) + KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True) + self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr())) + else: KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True) KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True) KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True) - self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(1, expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr())) - self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream) + KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True) + self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr())) + + + def sync_for_one_decode(self, cuda_graph_idx=0): + if cuda_graph_idx != -1: + self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream) + KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True) + return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx] + else: + self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream) KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True) return KExpertsCPU.output_gpu_map[self.out_device] + + def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0): + # generate, capture and run cuda graph + # print(expert_ids) + if bsz_tensor is None: + bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32) + if torch.cuda.is_current_stream_capturing(): + if cuda_graph_idx != -1: + KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True) + KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True) + KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True) + KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True) + self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr())) + self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream) + KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True) + return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx] + + else: + KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True) + KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True) + KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True) + KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True) + self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr())) + self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream) + KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True) + return KExpertsCPU.output_gpu_map[self.out_device] else: input_tensor = input_tensor.contiguous().cpu() expert_ids = expert_ids.contiguous().cpu() weights = weights.contiguous().to(torch.float32).cpu() + bsz_tensor = bsz_tensor.contiguous().cpu() output = torch.empty_like(input_tensor).contiguous() - self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr())) + self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr(), bsz_tensor.data_ptr())) self.cpu_infer.sync() return output.to(device=object.__getattribute__(self, "out_device")) @@ -859,6 +899,8 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE): y += y_ return y + + @torch.no_grad() def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: outs = self.experts(x, topk_ids, topk_weight) @@ -1013,4 +1055,178 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock): # the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states_cpu.dtype)) - return final_hidden_states \ No newline at end of file + return final_hidden_states + +class KDeepseekV3MoEV2(BaseInjectedModule, DeepseekV3MoE): + def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0): + identity = hidden_states + orig_shape = hidden_states.shape + sequence_length = orig_shape[1] + topk_idx, topk_weight = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + + # only for generate phase + if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug + self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx) + if self.config.n_shared_experts is not None: + y_ = self.shared_experts(identity, bsz_tensor).squeeze(0) + y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0) + y += y_ + y.resize_(*orig_shape) + return y + + if self.config.n_shared_experts is not None: + y_ = self.shared_experts(identity, bsz_tensor).squeeze(0) + + if isinstance(self.experts, KExpertsBase): + y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device) + elif hidden_states.size(0) > 10: + # TODO may bugs here + y = ( + self.moe_infer(hidden_states, topk_idx, topk_weight) + .view(*orig_shape) + .to(device=hidden_states.device) + ) + else: + # TODO may bugs here + y = ( + self.moe_infer_simple(hidden_states, topk_idx, topk_weight) + .view(*orig_shape) + .to(device=hidden_states.device) + ) + if self.config.n_shared_experts is not None: + y += y_ + return y + + @torch.no_grad() + def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor: + outs = torch.empty_like(x) + outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx) + return outs + + @torch.no_grad() + # TODO may bugs here + def moe_infer_simple( + self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor + ) -> torch.Tensor: + """ + x: [num_tokens, hidden_size] + topk_ids, topk_weight: [num_tokens, num_selected_experts] + """ + outs = torch.zeros_like(x) + for token_idx in range(topk_ids.size(0)): + for expert_idx in range(topk_ids.size(1)): + expert = self.experts[topk_ids[token_idx, expert_idx]] + outs[token_idx] += ( + expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx] + ) + return outs + + @torch.no_grad() + # TODO may bugs here + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert.forward(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out + +class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase): + def __init__(self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + # device: str = "cuda", + prefill_device:str = "cuda", + prefill_op: str | None = "KExpertsTorch", + generate_device: str = "cpu", + generate_op: str | None = "KExpertsCPU", + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + if generate_op is not None: + self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs) + else: + self.generate_experts = None + if prefill_op is not None: + self.prefill_experts = EXPERTS_MAP[prefill_op](key, gguf_loader, config, len(orig_module), device=prefill_device, **kwargs) + else: + self.prefill_experts = None + self.gpu_mlp_type = prefill_op + self.cpu_mlp_type = generate_op + self.mode = InferenceState.UNLOAD + + def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True): + # TODO support w as input + if not mode: mode = InferenceState.GENERATE + if mode == InferenceState.GENERATE: + self.prefill_experts.unload() + self.generate_experts.load(w, warmup=warmup) + self.device = self.generate_experts.device + self.mode = mode + elif mode == InferenceState.PREFILL: + self.generate_experts.unload() + self.prefill_experts.load(w, warmup=warmup) + self.device = self.prefill_experts.device + self.mode = mode + elif mode == InferenceState.UNLOAD: + self.unload() + self.mode = mode + self.device = self.generate_experts.device + else: + raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") + + def unload(self): + if self.generate_experts is not None: + self.generate_experts.unload() + if self.prefill_experts is not None: + self.prefill_experts.unload() + self.device = self.generate_experts.device + + def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0): + if self.mode == InferenceState.GENERATE: + assert self.generate_experts is not None, "generate_experts is None" + return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx) + elif self.mode == InferenceState.PREFILL: + assert self.prefill_experts is not None, "prefill_experts is None" + return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx) + else: + raise ValueError("load or set_inference_mode before forward") + + def set_inference_mode(self, mode: InferenceState): + if mode == InferenceState.GENERATE: + self.load(mode=InferenceState.GENERATE, warmup=False) + elif mode == InferenceState.PREFILL: + self.load(mode=InferenceState.PREFILL, warmup=False) + elif mode == InferenceState.UNLOAD: + self.unload() + else: + raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") diff --git a/ktransformers/operators/flashinfer_wrapper.py b/ktransformers/operators/flashinfer_wrapper.py index a702872..5700d65 100644 --- a/ktransformers/operators/flashinfer_wrapper.py +++ b/ktransformers/operators/flashinfer_wrapper.py @@ -86,6 +86,7 @@ class MLAWrapper(): self.qo_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device) self.kv_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device) self.kv_indices_buf = torch.empty(max_pages, dtype=torch.int32, device=device) + self.batch_size_tensor_buf = torch.tensor([self.max_batch_size], dtype=torch.int32, device=device) self.kv_len_arr_buf = torch.empty(max_batch_size, dtype=torch.int32, device=device) else: self.qo_indptr_buf = None @@ -94,19 +95,22 @@ class MLAWrapper(): self.kv_len_arr_buf = None self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( self.float_workspace_buffer, - use_cuda_graph=False, + use_cuda_graph=use_cuda_graph, qo_indptr=self.qo_indptr_buf, kv_indptr=self.kv_indptr_buf, kv_indices=self.kv_indices_buf, kv_len_arr=self.kv_len_arr_buf, + bsz_tensor=self.batch_size_tensor_buf ) self.need_plan = True + def plan(self, qo_indptr, kv_indptr, kv_indices, kv_len_arr, + bsz_tensor, num_heads, head_dim_ckv, head_dim_kpe, @@ -138,6 +142,7 @@ class MLAWrapper(): sm_scale, q_data_type, kv_data_type, + bsz_tensor ) def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False): @@ -240,16 +245,17 @@ if __name__ == "__main__": #checksame() #exit(0) - max_batch_size = 1 - max_pages = 64 + max_batch_size = 2 + max_batch_tokens = 256 + max_pages = 128 page_size = 64 num_heads = 128 # warm-up kv_len = 4023 q_len = 1 - q_nope_buf = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda") - q_pe_buf = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda") + q_nope_buf = torch.randn((max_batch_tokens, num_heads, 512), dtype=torch.bfloat16, device="cuda") + q_pe_buf = torch.randn((max_batch_tokens, num_heads, 64), dtype=torch.bfloat16, device="cuda") kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda") ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1) @@ -260,13 +266,19 @@ if __name__ == "__main__": max_pages, ) + used_pages = (kv_len + page_size - 1)// page_size kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda") qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda") + kv_indptr = torch.tensor([0, used_pages], dtype=torch.int32, device="cuda") + kv_indices = torch.empty(max_pages, dtype=torch.int32, device="cuda") + kv_indices[:used_pages] = torch.arange(0, used_pages, dtype=torch.int32, device="cuda") + bsz_tensor = torch.tensor([1], dtype=torch.int32, device="cuda") wrapper.plan( qo_indptr, - None, - None, + kv_indptr, + kv_indices, kv_len_arr, + bsz_tensor, 128, 512, 64, @@ -276,14 +288,98 @@ if __name__ == "__main__": torch.bfloat16, ) - attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe) + attn_output = wrapper.run(q_nope_buf[:q_len], q_pe_buf[:q_len], ckv, k_pe) print(attn_output.shape) - graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe) + graph.replay() + + q = torch.cat([q_nope_buf, q_pe_buf], dim=-1) + k = ( + torch.cat([ckv, k_pe], dim=-1) + .view(-1, 1, 512 + 64) + .repeat_interleave(num_heads, dim=1) + ) + v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1) + attn_ref, lse_ref = attention_ref_torch( + 1, + q[:q_len], + k[:kv_len], + v[:kv_len], + True, + 192 ** (-0.5) + ) + torch.testing.assert_close(attn_output[:q_len], attn_ref, rtol=5e-3, atol=5e-3) # warm-up finished + kv_len = 512 + q_len = 128 + pages = max_pages + used_pages = (kv_len + page_size - 1)// page_size + q_nope = torch.randn((q_len*2, num_heads, 512), dtype=torch.bfloat16, device="cuda") + q_nope[q_len:] = q_nope[:q_len] + q_pe = torch.randn((q_len*2, num_heads, 64), dtype=torch.bfloat16, device="cuda") + q_pe[q_len:] = q_pe[:q_len] + kv_cache = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda") + kv_cache[used_pages:2*used_pages] = kv_cache[:used_pages] + ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1) + + kv_len_arr = torch.tensor([kv_len, kv_len], dtype=torch.int32, device="cuda") + qo_indptr = torch.tensor([0, q_len, q_len*2], dtype=torch.int32, device="cuda") + kv_indptr = torch.tensor([0, used_pages, used_pages*2], dtype=torch.int32, device="cuda") + kv_indices = torch.empty(max_pages, dtype=torch.int32, device="cuda") + kv_indices[:2*used_pages] = torch.arange(0, 2*used_pages, dtype=torch.int32, device="cuda") + bsz_tensor = torch.tensor([2], dtype=torch.int32, device="cuda") + wrapper.plan( + qo_indptr, + kv_indptr, + kv_indices, + kv_len_arr, + bsz_tensor, + 128, + 512, + 64, + page_size, + 192 ** (-0.5), + torch.bfloat16, + torch.bfloat16, + ) + + q_nope_buf.copy_(q_nope) + q_pe_buf.copy_(q_pe) + kv_buf[:pages].copy_(kv_cache) + + torch.cuda.synchronize() + graph.replay() + torch.cuda.synchronize() + + # ref_torch + q = torch.cat([q_nope, q_pe], dim=-1) + k = ( + torch.cat([ckv, k_pe], dim=-1) + .view(-1, 1, 512 + 64) + .repeat_interleave(num_heads, dim=1) + ) + v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1) + attn_ref, lse_ref = attention_ref_torch( + max_batch_size, + q, + k[:2*kv_len], + v[:2*kv_len], + True, + 192 ** (-0.5) + ) + + torch.testing.assert_close(attn_ref[:q_len], attn_ref[q_len:q_len*2], rtol=1e-9, atol=1e-9) + torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len*2], rtol=1e-9, atol=1e-9) + torch.testing.assert_close(attn_output[:q_len], attn_ref[:q_len], rtol=5e-3, atol=5e-3) + torch.testing.assert_close(attn_output[q_len:q_len*2], attn_ref[q_len:q_len*2], rtol=5e-3, atol=5e-3) + #torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len*2], rtol=1e-9, atol=1e-9) + #torch.testing.assert_close(attn_output, attn_ref, rtol=5e-3, atol=5e-3) + + exit(0) + for forward_id in range(0, 1): print("forward_id", forward_id) for layer_id in range(1): @@ -376,5 +472,4 @@ if __name__ == "__main__": #file_name = f"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt" #ktrans_output = torch.load(file_name) #torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3) - print("test past") - + print("test past") \ No newline at end of file diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py index d3aa215..a3fb70e 100644 --- a/ktransformers/operators/gate.py +++ b/ktransformers/operators/gate.py @@ -249,4 +249,4 @@ class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase): if self.weight is not None: self.weight = None if self.e_score_correction_bias is not None: - self.e_score_correction_bias = None + self.e_score_correction_bias = None \ No newline at end of file diff --git a/ktransformers/operators/layernorm.py b/ktransformers/operators/layernorm.py new file mode 100644 index 0000000..8e8cbc7 --- /dev/null +++ b/ktransformers/operators/layernorm.py @@ -0,0 +1,78 @@ +''' +Date: 2024-11-13 15:05:52 +LastEditors: Xie Weiyu ervinxie@qq.com +LastEditTime: 2024-11-25 08:59:19 +''' +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Fused operators for normalization layers.""" + +import logging +from typing import Optional, Tuple, Union +from transformers import PretrainedConfig +import torch +import torch.nn as nn +from ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm +from ktransformers.operators.base_operator import BaseInjectedModule +from ktransformers.util.custom_gguf import GGUFLoader +from flashinfer.norm import ( + fused_add_rmsnorm, + rmsnorm, +) + + +logger = logging.getLogger(__name__) + + +class RMSNorm(DeepseekV3RMSNorm, BaseInjectedModule): + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.orig_module.__init__(orig_module.hidden_size, + orig_module.variance_epsilon) + + def forward( + self, + x: torch.Tensor, + batch_size_tensor: torch.Tensor = None, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + #return self.forward_native(x, residual) + if batch_size_tensor is None: + return self.forward_native(x) + if residual is not None: + fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon) + #residual = x + residual + #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon) + return x, residual + # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous()) + out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon) + return out + + def forward_native( + self, hidden_states + ): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) \ No newline at end of file diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index c0dfc85..3174536 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -15,14 +15,16 @@ import ctypes import torch from torch import Tensor, nn import KTransformersOps +import vLLMMarlin from ktransformers.util.custom_gguf import GGUFLoader from ktransformers.util.utils import InferenceState from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import ( MarlinWorkspace, - marlin_quantize, + marlin_quantize, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MIN_THREAD_K, GPTQ_MARLIN_MAX_PARALLEL, + vllm_marlin_quantize ) from ktransformers.operators.base_operator import BaseInjectedModule from transformers.configuration_utils import PretrainedConfig @@ -84,8 +86,10 @@ class KLinearBase(ABC): if self.gguf_loader.safetensor_loader is not None: # using safetensor_loader tensor = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight') - weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv') - return nn.Parameter(tensor), nn.Parameter(weight_scale_inv) + if key+'.weight_scale_inv' in self.gguf_loader.safetensor_loader.tensor_file_map: + weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv') + return nn.Parameter(tensor), nn.Parameter(weight_scale_inv) + return nn.Parameter(tensor) elif key + ".weight" in self.gguf_loader.tensor_file_map: if key + ".bias" in self.gguf_loader.tensor_file_map: @@ -134,7 +138,7 @@ class KLinearTorch(KLinearBase): self.weight = None self.has_bias = False - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: dtype = x.dtype out_device = x.device # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended. @@ -178,7 +182,6 @@ class KLinearTorch(KLinearBase): if self.has_bias: self.bias = None - class KLinearQ8(KLinearBase): def __init__( self, @@ -370,7 +373,7 @@ class KLinearFP8(KLinearBase): self.dtype = torch.get_default_dtype() self.block_size = block_size - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor) -> torch.Tensor: x = x.to(self.device) orig_dtype = x.dtype x_quantized, scale_x = act_quant(x, self.block_size) @@ -397,8 +400,152 @@ class KLinearFP8(KLinearBase): self.weight = None if self.has_bias: self.bias = None + +# TODO: merge two marlin class + +class VLinearMarlin(KLinearBase): + marlin_q_w: torch.Tensor + marlin_s: torch.Tensor + g_idx: torch.Tensor + sort_indices: torch.Tensor + has_bias: bool + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module = None, + device: str = "cuda", + num_bits: int = 4, # 4-bit/8-bit is supported + group_size: int = 64, # -1, 32, 64, 128 + act_order: bool = False, + is_k_full=True, + **kwargs, + ): + assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device" + super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) + self.num_bits = num_bits + self.group_size = group_size + self.act_order = act_order + self.is_k_full = is_k_full + self.padding = False + self.orin_in_features = self.in_features + self.orin_out_features = self.out_features + if self.in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or self.out_features%GPTQ_MARLIN_MIN_THREAD_K!=0: + #print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding") + self.padding = True + self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K + self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N + #print(f"After padding: in_features={in_features}, out_features={out_features}") + self.k = self.in_features + self.n = self.out_features + + def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): + if self.loaded: return + if device is None: device = self.device + assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device" + #if self.in_features * self.out_features: + if w is None: + w = self.load_weight(device=device) + + if isinstance(w, nn.Parameter): + # pad weight + weight = w.view(self.orin_out_features, self.orin_in_features).T + self.has_bias = False + elif isinstance(w, tuple): + w = list(w) + weight = w[0].view(self.orin_out_features, self.orin_in_features).T + self.bias = w[1].view(self.orin_out_features) + self.bias = w[1] + self.has_bias = True + else: + raise ValueError("Invalid weight type") + weight = weight.to(device) + if self.has_bias: + self.bias = self.bias.to(device) + + if self.padding: + padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device) + padded_weight[:self.orin_in_features, :self.orin_out_features] = weight + weight = padded_weight + + # Pack Marlin linear + marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( + weight, self.num_bits, self.group_size, self.act_order + ) + self.workspace = MarlinWorkspace( + self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device + ) + self.weight = marlin_q_w + self.marlin_q_w = marlin_q_w + self.marlin_s = marlin_s + self.g_idx = g_idx + self.sort_indices = sort_indices + self.k = weight.shape[0] + self.n = weight.shape[1] + # self.shape_buffer = torch.tensor([60], dtype=torch.int32, device=self.device) + self.loaded = True + + + def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor: + if bsz_tensor is None: + bsz_tensor = torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device) + + + # Only support input x as BF16 and FP16 + x = x.to(self.device) + orig_shape = list(x.shape) + orig_dtype = x.dtype + x = x.reshape(-1, orig_shape[-1]) + marlin_s = self.marlin_s.to(x.dtype) + sms = -1 + + x = vLLMMarlin.gptq_marlin_gemm( + x, + self.marlin_q_w, + marlin_s, + self.g_idx, + self.sort_indices, + self.workspace.scratch, + self.num_bits, + bsz_tensor, + # torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device), + x.shape[0], + self.n, + x.shape[-1], + sms, + self.is_k_full, + ) + # x = KTransformersOps.gptq_marlin_gemm( + # x, + # self.marlin_q_w, + # marlin_s, + # self.g_idx, + # self.sort_indices, + # self.workspace.scratch, + # self.num_bits, + # x.shape[0], + # self.n, + # x.shape[-1], + # self.is_k_full, + # ) + if self.has_bias: + x = x + self.bias + orig_shape[-1] = self.n + return x.reshape(orig_shape).to(orig_dtype) + + def unload(self): + + if self.has_bias: + self.bias = None + self.marlin_q_w = None + self.marlin_s = None + self.g_idx = None + self.sort_indices = None + self.workspace = None + class KLinearMarlin(KLinearBase): marlin_q_w: torch.Tensor marlin_s: torch.Tensor @@ -483,7 +630,7 @@ class KLinearMarlin(KLinearBase): self.n = weight.shape[1] self.loaded = True - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None, **kwargs) -> torch.Tensor: # Only support input x as BF16 and FP16 x = x.to(self.device) orig_shape = list(x.shape) @@ -629,12 +776,13 @@ class KLinearCPUInfer(KLinearBase): if self.w is not None: self.w = None if self.has_bias: - self.bias = None + self.bias = None LINEAR_MAP = { "KLinearMarlin": KLinearMarlin, "KLinearTorch": KLinearTorch, "KLinearCPUInfer": KLinearCPUInfer, + "VLinearMarlin": VLinearMarlin, "KLinearFP8": KLinearFP8, "KLinearQ8": KLinearQ8, } @@ -668,13 +816,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase): self.generate_linear = None self.mode = InferenceState.UNLOAD - def forward(self, x): + def forward(self, x, bsz_tensor=None): if self.mode == InferenceState.PREFILL: assert self.prefill_linear is not None, "cpu linear is not initialized" - y = self.prefill_linear.forward(x) + y = self.prefill_linear.forward(x, bsz_tensor) else: assert self.generate_linear is not None, "gpu linear is not initialized" - y = self.generate_linear.forward(x) + y = self.generate_linear.forward(x, bsz_tensor) return y def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE): @@ -717,3 +865,5 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase): self.unload() else: raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") + + diff --git a/ktransformers/operators/mlp.py b/ktransformers/operators/mlp.py new file mode 100644 index 0000000..d8c502d --- /dev/null +++ b/ktransformers/operators/mlp.py @@ -0,0 +1,23 @@ + +from ktransformers.operators.base_operator import BaseInjectedModule +from ktransformers.util.custom_gguf import GGUFLoader +from transformers import PretrainedConfig +import torch.nn as nn +from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP + + +class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule): + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.orig_module.__init__(orig_module.config, + orig_module.hidden_size, orig_module.intermediate_size) + def forward(self, x, bsz_tensor): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor) + return down_proj \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml index c20973d..7f3e44e 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml @@ -22,7 +22,7 @@ replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: - generate_device: "cpu" + generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve.yaml new file mode 100644 index 0000000..4d5ecb0 --- /dev/null +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve.yaml @@ -0,0 +1,90 @@ +- match: + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearFP8" + prefill_op: "KLinearTorch" +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE + replace: + class: ktransformers.operators.experts.KDeepseekV3MoEV2 # mlp module with custom forward function + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + class: ktransformers.models.modeling_deepseek_v3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGate + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.attention.flashinfer_attn # optimized MLA implementation + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KDeepseekV2Model" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +- match: + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm + replace: + class: ktransformers.operators.layernorm.RMSNorm + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP + replace: + class: ktransformers.operators.mlp.kDeepseekV3MLP + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + name: "^lm_head$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "VLinearMarlin" + prefill_op: "KLinearTorch" \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml index 18138c9..849439c 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml @@ -10,7 +10,7 @@ name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: - class: ktransformers.operators.RoPE.KMoEGateDeepSeekV3 + class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" @@ -18,7 +18,7 @@ name: "^model\\.layers\\.([3456][0-9])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: - class: ktransformers.operators.RoPE.KMoEGateDeepSeekV3 + class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" @@ -66,7 +66,7 @@ name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: - class: ktransformers.operators.gate.KMoEGate + class: ktransformers.operators.gate.KMoEGateDeepSeekV3 kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" @@ -74,7 +74,7 @@ name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: - class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function + class: ktransformers.operators.gate.KMoEGateDeepSeekV3 # mlp module with custom forward function kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml index 98b5b5e..7e24d44 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml @@ -10,7 +10,7 @@ name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: - class: ktransformers.operators.RoPE.KMoEGateDeepSeekV3 + class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" @@ -66,7 +66,7 @@ name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: - class: ktransformers.operators.gate.KMoEGate + class: ktransformers.operators.gate.KMoEGateDeepSeekV3 kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" @@ -74,7 +74,7 @@ name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: - class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function + class: ktransformers.operators.gate.KMoEGateDeepSeekV3 # mlp module with custom forward function kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml new file mode 100644 index 0000000..622ad21 --- /dev/null +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml @@ -0,0 +1,92 @@ +- match: + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + name: "^lm_head$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "VLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "VLinearMarlin" + prefill_op: "KLinearTorch" +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE + replace: + class: ktransformers.operators.experts.KDeepseekV3MoEV2 # mlp module with custom forward function + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + class: ktransformers.models.modeling_deepseek_v3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGate + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.attention.flashinfer_attn # optimized MLA implementation + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + absorb_for_prefill: False # change this to True to enable long context(prefill may slower). +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KDeepseekV2Model" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +- match: + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm + replace: + class: ktransformers.operators.layernorm.RMSNorm + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP + replace: + class: ktransformers.operators.mlp.kDeepseekV3MLP + kwargs: + generate_device: "cuda" + prefill_device: "cuda" \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml index 3c36073..d28e016 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml @@ -38,7 +38,7 @@ - match: class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: - class: ktransformers.operators.gate.KMoEGateDeepSeekV3 + class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" diff --git a/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B-serve.yaml b/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B-serve.yaml new file mode 100644 index 0000000..68098b7 --- /dev/null +++ b/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B-serve.yaml @@ -0,0 +1,94 @@ + + +- match: + name: "^lm_head$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "VLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "VLinearMarlin" + prefill_op: "KLinearTorch" +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE + replace: + class: ktransformers.operators.experts.KDeepseekV3MoEV2 # mlp module with custom forward function + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + class: ktransformers.models.modeling_deepseek_v3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGateDeepSeekV3 + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.attention.flashinfer_attn # optimized MLA implementation + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + absorb_for_prefill: False # change this to True to enable long context(prefill may slower). +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KDeepseekV2Model" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +- match: + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm + replace: + class: ktransformers.operators.layernorm.RMSNorm + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP + replace: + class: ktransformers.operators.mlp.kDeepseekV3MLP + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding + replace: + class: ktransformers.operators.RoPE.RotaryEmbeddingV4 + kwargs: + generate_device: "cuda" + prefill_device: "cuda" \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml b/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml index 6cea246..e15b880 100644 --- a/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml +++ b/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml @@ -38,7 +38,7 @@ - match: class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: - class: ktransformers.operators.gate.KMoEGate + class: ktransformers.operators.gate.KMoEGateDeepSeekV3 kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py index 1f9af76..e60da10 100644 --- a/ktransformers/server/args.py +++ b/ktransformers/server/args.py @@ -1,6 +1,6 @@ import argparse from ktransformers.server.backend.args import ConfigArgs, default_args - +from ktransformers.util.utils import get_free_ports class ArgumentParser: def __init__(self, cfg): @@ -16,20 +16,18 @@ class ArgumentParser: parser.add_argument("--web", type=bool, default=self.cfg.mount_web) parser.add_argument("--model_name", type=str, default=self.cfg.model_name) parser.add_argument("--model_dir", type=str) - parser.add_argument("--model_path", type=str) + parser.add_argument("--model_path", type=str, default=self.cfg.model_path) parser.add_argument( "--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter" ) parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path) - parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False) + parser.add_argument("--optimize_config_path", default=None, type=str, required=False) parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer) - parser.add_argument("--type", type=str, default=self.cfg.backend_type) - parser.add_argument("--chunk_prefill_size", type=int, default=8192) + parser.add_argument("--backend_type", type=str, default=self.cfg.backend_type) + parser.add_argument("--chunk_size", type=int, default=self.cfg.chunk_size) # model configs # parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int? - parser.add_argument("--paged", type=bool, default=self.cfg.paged) - parser.add_argument("--total_context", type=int, default=self.cfg.total_context) parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size) parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens) parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode) @@ -62,7 +60,6 @@ class ArgumentParser: parser.add_argument("--repetition_penalty", type=float, default=self.cfg.repetition_penalty) parser.add_argument("--frequency_penalty", type=float, default=self.cfg.frequency_penalty) parser.add_argument("--presence_penalty", type=float, default=self.cfg.presence_penalty) - parser.add_argument("--max_response_tokens", type=int, default=self.cfg.max_response_tokens) parser.add_argument("--response_chunk", type=int, default=self.cfg.response_chunk) parser.add_argument("--no_code_formatting", type=bool, default=self.cfg.no_code_formatting) parser.add_argument("--cache_8bit", type=bool, default=self.cfg.cache_8bit) @@ -103,6 +100,18 @@ class ArgumentParser: # local chat parser.add_argument("--prompt_file", type=str, default=self.cfg.prompt_file) + + # async server + parser.add_argument("--sched_strategy", type=str, default=self.cfg.sched_strategy) + # parser.add_argument("--sched_port", type=int, default=self.cfg.sched_port) + # parser.add_argument("--sched_metrics_port", type=int, default=self.cfg.sched_metrics_port) + # parser.add_argument("--kvc2_metrics_port", type=int, default=self.cfg.kvc2_metrics_port) + parser.add_argument("--page_size", type=str, default=self.cfg.page_size) + parser.add_argument("--memory_gpu_only", type=str, default=self.cfg.memory_gpu_only) + parser.add_argument("--utilization_percentage", type=str, default=self.cfg.utilization_percentage) + parser.add_argument("--cpu_memory_size_GB", type=str, default=self.cfg.cpu_memory_size_GB) + + args = parser.parse_args() if (args.model_dir is not None or args.model_path is not None): if (args.model_path is not None): @@ -123,6 +132,15 @@ class ArgumentParser: self.cfg.mount_web = args.web self.cfg.server_ip = args.host self.cfg.server_port = args.port - self.cfg.backend_type = args.type self.cfg.user_force_think = args.force_think + + args.gpu_memory_size = args.cache_lens*2*576*61 + self.cfg.gpu_memory_size = args.gpu_memory_size + free_ports = get_free_ports(3, [args.port]) + args.sched_port = free_ports[0] + args.sched_metrics_port = free_ports[1] + args.kvc2_metrics_port = free_ports[2] + self.cfg.sched_port = free_ports[0] + self.cfg.sched_metrics_port = free_ports[1] + self.cfg.kvc2_metrics_port = free_ports[2] return args diff --git a/ktransformers/server/backend/args.py b/ktransformers/server/backend/args.py index 0f025d4..1c602b1 100644 --- a/ktransformers/server/backend/args.py +++ b/ktransformers/server/backend/args.py @@ -12,18 +12,10 @@ class ConfigArgs(BaseModel): class Config: protected_namespaces = () - paged: bool = Field(None, description="Whether to use paged attention kv cache") - total_context: int = Field( - None, - description=( - "Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the" - " total to distribute dynamically over however many jobs are active at once" - ), - ) max_batch_size: int = Field( None, description="Max number of batches to run at once, assuming the sequences will fit within total_context" ) - chunk_prefill_size: int = Field( + chunk_size: int = Field( None, description=( "Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new" @@ -70,7 +62,6 @@ class ConfigArgs(BaseModel): repetition_penalty: float = Field(None, description="Sampler repetition penalty, default = 1.01 (1 to disable)") frequency_penalty: float = Field(None, description="Sampler frequency penalty, default = 0.0 (0 to disable)") presence_penalty: float = Field(None, description="Sampler presence penalty, default = 0.0 (0 to disable)") - max_response_tokens: int = Field(None, description="Max tokens per response, default = 1000") response_chunk: int = Field(None, description="Space to reserve in context for reply, default = 250") no_code_formatting: bool = Field(None, description="Disable code formatting/syntax highlighting") cache_8bit: bool = Field(None, description="Use 8-bit (FP8) cache") diff --git a/ktransformers/server/backend/context_manager.py b/ktransformers/server/backend/context_manager.py index f18e3cf..e44feaa 100644 --- a/ktransformers/server/backend/context_manager.py +++ b/ktransformers/server/backend/context_manager.py @@ -9,9 +9,11 @@ from ktransformers.server.backend.interfaces.transformers import TransformersThr from ktransformers.server.backend.interfaces.ktransformers import KTransformersThreadContext from ktransformers.server.backend.interfaces.exllamav2 import ExllamaThreadContext + from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface from ktransformers.server.backend.interfaces.transformers import TransformersInterface from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface + class ThreadContextManager: lock: Lock threads_context: Dict[ObjectID, ThreadContext] @@ -36,7 +38,16 @@ class ThreadContextManager: elif isinstance(self.interface, TransformersInterface): new_context = TransformersThreadContext(run, self.interface) else: - raise NotImplementedError + from ktransformers.server.backend.interfaces.balance_serve import BalanceServeThreadContext + from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface + if isinstance(self.interface, BalanceServeInterface): + new_context = BalanceServeThreadContext(run, self.interface) + else: + raise NotImplementedError + # elif isinstance(self.interface, BalanceServeInterface): + # new_context = BalanceServeThreadContext(run, self.interface) + # else: + # raise NotImplementedError self.threads_context[run.thread_id] = new_context # self.threads_context[run.thread_id] = ExllamaInferenceContext(run) re = self.threads_context[run.thread_id] diff --git a/ktransformers/server/backend/interfaces/balance_serve.py b/ktransformers/server/backend/interfaces/balance_serve.py new file mode 100644 index 0000000..be48c92 --- /dev/null +++ b/ktransformers/server/backend/interfaces/balance_serve.py @@ -0,0 +1,406 @@ +from typing import Any, AsyncIterator, List, Optional, Set +from ktransformers.models.custom_cache import KDeepSeekV3Cache +from transformers import ( + AutoTokenizer, + AutoConfig, + GenerationConfig, + StaticCache, + AutoModelForCausalLM, + BitsAndBytesConfig, +) + +from ktransformers.server.config.config import Config +from ..base import ThreadContext, BackendInterfaceBase +import torch +from ktransformers.server.backend.interfaces.transformers import ( + ConfigArgs, + default_args, + TextStreamer, +) +from ktransformers.server.schemas.base import ObjectID +from ktransformers.server.config.log import logger +from ktransformers.optimize.optimize import optimize_and_load_gguf +from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM +from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM +from ktransformers.server.balance_serve.inference.model_runner import ModelRunner +from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions +from ktransformers.server.balance_serve.inference.query_manager import QueryManager +from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput +from ktransformers.server.balance_serve.sched_rpc import SchedulerClient +from ktransformers.server.balance_serve.settings import sched_ext +from torch.multiprocessing import Queue +import torch.multiprocessing as mp +from ktransformers.server.schemas.endpoints.chat import RawUsage +from ktransformers.server.utils.multi_timer import Profiler +import zmq +import time +import queue +import tempfile +import asyncio +import threading +from contextlib import asynccontextmanager +from fastapi import FastAPI, Request +import os + + + +ktransformer_rules_dir = ( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/") +) +default_optimize_rules = { + "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml", + "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml", +} + +async def chat_stream(queue: asyncio.Queue, tokenizer: AutoTokenizer): + streamer = TextStreamer(tokenizer) + while True: + token = await queue.get() + #print(f"Got token: {token}") + if token is None: + # str = f'{token}\n\n' + # str = model.tokenizer.decode(token) + s = streamer.end() + if s is not None: + yield s + break + + # str = model.tokenizer.decode(token) + yield streamer.put(token) + + + +def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_tokens: torch.Tensor, query_manager: QueryManager = None): + #print(len(query_updates), generated_tokens.size(0), generated_tokens) + for i in range(generated_tokens.size(0)): + print(generated_tokens[i].item()) + query_updates[i].generated_token = generated_tokens[i].item() + if not query_manager.query_map[query_updates[i].id].is_prefill: + pos = query_updates[i].active_position + query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i] + +def report_last_time_performance(profiler: Profiler): + try: + tokenize_time = profiler.get_timer_sec('tokenize') + prefill_time = profiler.get_timer_sec('prefill') + decode_time = profiler.get_timer_sec('decode') + prefill_count = profiler.get_counter('prefill') + decode_count = profiler.get_counter('decode') + + logger.info(f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}') + except: + logger.info(f'Performance statistics not recorded') + +class Engine: + sched_client : SchedulerClient + updates : list[sched_ext.QueryUpdate] + batch : sched_ext.BatchQueryTodo + model_runner: ModelRunner + sampler: Sampler + query_manager: QueryManager + cache: KDeepSeekV3Cache + def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None): + self.args = args + + # 子进程和父进程无法共享 config 变量 + for key, value in vars(args).items(): + if value is not None and hasattr(Config(), key): + setattr(Config(), key, value) + + self.device = self.args.device + self.sched_client = SchedulerClient(args.sched_port) + self.updates = [] + config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) + self.cache = KDeepSeekV3Cache(config, self.args.page_size) + + self.gen_queue = generated_token_queue + + print(f"Getting inference context from sched_client.") + inference_context = self.sched_client.get_inference_context_raw() + print(f"Got inference context, sending it to subscribers.") + inference_context = self.sched_client.rebuild_inferece_context(inference_context) + self.cache.load(inference_context) + print(f"kv_cache loaded successfully.") + + self.block_num = inference_context.k_cache[0].size(1) + with torch.device("meta"): + if config.architectures[0] == "DeepseekV3ForCausalLM": + self.model = KDeepseekV3ForCausalLM(config, self.cache) + elif config.architectures[0] == "DeepseekV2ForCausalLM": + self.model = KDeepseekV2ForCausalLM(config, self.cache) + # print(self.block_num) + + context = zmq.Context() + + + self.pub_socket = context.socket(zmq.PUB) + self.pub_socket.bind(f"ipc://{broadcast_endpoint}") + # time.sleep(1) # make sure all subscribers are ready + + + try: + generation_config = GenerationConfig.from_pretrained(args.model_dir) + except: + generation_config = GenerationConfig( + max_length=args.max_new_tokens, + temperature=args.temperature, + top_p=args.top_p, + do_sample=True + ) + + if args.optimize_config_path is None: + optimize_config_path = default_optimize_rules[config.architectures[0]] + + else: + optimize_config_path = args.optimize_config_path + gguf_path = args.gguf_path + if gguf_path is None: + gguf_path = input( + "please input the path of your gguf file(gguf file in the dir containing input gguf file must all" + " belong to current model):" + ) + optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config) + self.model.generation_config = generation_config + if self.model.generation_config.pad_token_id is None: + self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id + + self.model.eval() + #@TODO add config + self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num) + + self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size) + self.sampler = Sampler() + self.query_manager = QueryManager(device = self.device, page_size = args.page_size) + + + def sampling(self, forward_output: ForwardBatchOutput): + generated_tokens = torch.empty(0, device=self.device, dtype=torch.int32) + for i in range(forward_output.num_batchs): + logit = forward_output.logits[i] + if hasattr(forward_output, "temperatures"): + temperatures = forward_output.temperatures[i] + else: + temperatures = None + + if hasattr(forward_output, "top_ps"): + top_ps = forward_output.top_ps[i] + else: + top_ps = None + + sample_options = SamplingOptions(logit.size(0), self.device, pretrained_config=self.model.generation_config, temperatures=temperatures, top_ps=top_ps) + generated_tokens, probs=self.sampler(logit, sample_options) + return generated_tokens, probs + + def loop(self): + + next_batch = None + + while True: + self.batch = next_batch + if self.batch is not None: + self.model_runner.run(self.batch, self.query_manager) + + if len(self.updates) > 0: + for q in self.updates: + if q.is_prefill == True: + continue + # print(f"Putting token {q.generated_token} into queue for query id: {q.id}") + try: + self.gen_queue.put((q.id, q.generated_token if q.decode_done == False else None), timeout=5) + except queue.Full: + pass#print("Queue is full after timeout; unable to put more items.") + + next_batch = self.sched_client.update_last_batch(self.updates) + if next_batch.query_ids == []: + next_batch = None + self.pub_socket.send_pyobj(next_batch) + + if next_batch is not None: + self.query_manager.add_query(next_batch) + + + if self.batch is not None: + self.model_runner.sync() + print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms") + # if self.rank == 0: + + generated_tokens, probs = self.sampling( self.model_runner.output) + + self.updates = self.query_manager.update(self.batch) + fill_generated_tokens(self.updates, generated_tokens, self.query_manager) + else: + self.updates = [] + +class BalanceServeThreadContext(ThreadContext): + def get_local_messages(self): + local_messages = [] + for m in self.messages: + local_messages.append({"role": m.role.value, "content": m.get_text_content()}) + + return local_messages + + +def run_engine(args, token_queue, broadcast_endpoint, event): + engine = Engine(args, token_queue, broadcast_endpoint) + if args.use_cuda_graph: + engine.model_runner.warmup() + + event.set() + engine.loop() + + +class BalanceServeInterface(BackendInterfaceBase): + use_static_cache: bool = True + + model: Any + tokenizer: AutoTokenizer + + cache: StaticCache + generated_ids: torch.Tensor + seq_length: int + + streamer: TextStreamer + + # thread_related + last_request_id: Optional[str] = None + ever_generated_ids: Set[int] = set() + def __init__(self, args: ConfigArgs = default_args): + self.args = args + self.queue_map:dict[int,asyncio.Queue] = {} + self.thread_map: dict[int, int] = {} + processes = [] + self.broadcast_endpoint = tempfile.NamedTemporaryFile(delete=False).name # @TODO add to config + ctx = mp.get_context("spawn") + self.token_queue = ctx.Queue(maxsize=1000) + self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True) + self.sched_client = SchedulerClient(args.sched_port) + self.streamer = TextStreamer(self.tokenizer) + + start_event = ctx.Event() + + p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event)) + p.start() + processes.append(p) + start_event.wait() + + def run_queue_proxy(self): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self.queue_proxy()) + + @asynccontextmanager + async def lifespan(self, app: FastAPI): + asyncio.create_task(self.queue_proxy()) + yield + + async def queue_proxy(self): + print("Queue Proxy Started") + while True: + try: + query_id, token = self.token_queue.get_nowait() + try: + # query id might not be allocated yet + self.queue_map[query_id].put_nowait(token) + #print(f"Proxy Put token: {token} to queue for query id: {query_id}") + except asyncio.QueueFull: + #print(f"Queue for query id: {query_id} is full, waiting to put: {token}") + await self.queue_map[query_id].put(token) + + except queue.Empty: + # print("no new token") + # await asyncio.sleep(1) + await asyncio.sleep(0) + def tokenize_prompt(self, prompt: str): + input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device) + return input_ids + + def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List): + for m in messages: + if m["role"] == "system": + logger.warning(f'change {m["role"]} to user') + m["role"] = "user" + + new_messages = [messages[0]] + for m in messages[1:]: + if m["role"] == "user" and new_messages[-1]["role"] == "user": + logger.warning("merge two adjacent user messages") + new_messages[-1]["content"] += '\n' + m["content"] + else: + new_messages.append(m) + input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True) + # drop token in chat template + if input_str.endswith('\n'): + input_str = input_str[:-len('\n')] + input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device) + logger.debug(f"get input ids of shape {input_ids.shape}") + return input_ids + + async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None): + profiler = Profiler() + profiler.create_and_start_timer("tokenize") + + if isinstance(local_messages, List): + input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages) + elif isinstance(local_messages, str): + #local_messages = local_messages[0]['content'] + input_ids = self.tokenize_prompt(local_messages) + else: + raise ValueError("local_messages should be List or str") + if Config().user_force_think: + token_thinks = torch.tensor([self.tokenizer.encode("\n",add_special_tokens=False)],device=input_ids.device) + input_ids = torch.cat( + [input_ids, token_thinks], dim=1 + ) + + + profiler.pause_timer("tokenize") + + profiler.create_and_start_timer("prefill") + + + + query_add = sched_ext.QueryAdd() + query_add.query_token = input_ids[0].tolist() + query_length = input_ids[0].shape[0] + query_add.query_length = query_length + profiler.set_counter("prefill", query_length) + #@TODO add server + stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")] + query_add.stop_criteria = stop_criteria + query_add.sample_options.temperature = temperature + query_add.sample_options.top_p = top_p + query_add.estimated_length = min(self.args.cache_lens, query_length+self.args.max_new_tokens) + query_id = self.sched_client.add_query(query_add) + queue = asyncio.Queue(maxsize=self.args.max_new_tokens) + self.queue_map[query_id] = queue + self.thread_map[thread_id] = query_id + is_first_token = True + async for token in chat_stream(self.queue_map[query_id], self.tokenizer): + if is_first_token: + is_first_token=False + profiler.pause_timer("prefill") + profiler.create_and_start_timer("decode") + profiler.set_counter("decode", 0) + if Config().user_force_think: + think = '\n' + print(think, end="",flush=True) + yield think, None + else: + profiler.inc("decode") + yield token, None + profiler.pause_timer("decode") + report_last_time_performance(profiler) + yield self.streamer.end(), None + if profiler.get_counter('decode') >= self.args.max_new_tokens - 1: + yield "", "length" + else: + yield "", "stop" + + + yield RawUsage( + tokenize_time = profiler.get_timer_sec('tokenize'), + prefill_time = profiler.get_timer_sec('prefill'), + decode_time = profiler.get_timer_sec('decode'), + prefill_count = profiler.get_counter('prefill'), + decode_count = profiler.get_counter('decode'), + ) diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py index 1752a3c..e9e2533 100644 --- a/ktransformers/server/backend/interfaces/ktransformers.py +++ b/ktransformers/server/backend/interfaces/ktransformers.py @@ -211,11 +211,11 @@ class KTransformersInterface(TransformersInterface): chunk_start = 0 while chunk_start < input_ids_length: - chunk_end = min(chunk_start + self.args.chunk_prefill_size, input_ids_length) + chunk_end = min(chunk_start + self.args.chunk_size, input_ids_length) if self.cache != None: self.cache.cur_idx=cache_position[chunk_start:chunk_end] logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end]) - chunk_start += self.args.chunk_prefill_size + chunk_start += self.args.chunk_size if flashinfer_enabled: MLAWrapperSingleton.reset_buffer() diff --git a/ktransformers/server/balance_serve/inference/__init__.py b/ktransformers/server/balance_serve/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ktransformers/server/balance_serve/inference/config.py b/ktransformers/server/balance_serve/inference/config.py new file mode 100644 index 0000000..6140f24 --- /dev/null +++ b/ktransformers/server/balance_serve/inference/config.py @@ -0,0 +1,142 @@ +''' +Date: 2024-11-07 07:30:16 +LastEditors: djw +LastEditTime: 2024-11-15 14:23:26 +''' +import math +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F +import yaml + +import json +from typing import Optional + +class ModelConfig: + vocab_size: int = 32000 + n_layer: int = 1 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = 18944 + n_local_heads: int = 8 + head_dim: int = 128 + rope_base: float = 1000000.0 + norm_eps: float = 1e-06 + rope_scaling: Optional[dict] = None + rms_norm_eps: float = 1e-6 + hidden_act: str = "silu" + model_path: str + gguf_path: str + optimize_rule_path: str + speculative_rule_path: str + + + # quantize config + quant_algorithm: Optional[str] = None + quant_group_size: Optional[int] = None + quant_num_bits: Optional[int] = None + + json_key_map = { + "vocab_size": "vocab_size", + "n_layer": "num_hidden_layers", + "n_head": "num_attention_heads", + "dim": "hidden_size", + "intermediate_size": "intermediate_size", + "n_local_heads": "num_key_value_heads", + "rope_base": "rope_theta", + "norm_eps": "norm_eps", + "rms_norm_eps": "rms_norm_eps", + "hidden_act": "hidden_act", + } + + def __init__(self, config): + self.model_path = config["model"]["model_path"] + self.gguf_path = config["model"]["gguf_path"] + self.optimize_rule_path = config["model"]["optimize_rule_path"] + if "speculative_rule_path" in config["model"]: + self.speculative_rule_path = config["model"]["speculative_rule_path"] + self.speculative_gguf_path = config["model"]["speculative_gguf_path"] + self.speculative_model_path = config["model"]["speculative_model_path"] + self.quant_algorithm = config["model"]["quant"]["algorithm"] + self.quant_group_size = config["model"]["quant"]["group_size"] + self.quant_num_bits = config["model"]["quant"]["num_bits"] + self.load_config() + self.n_layer = config["model"]["n_layers"] + + def load_config(self): + config_file = f"{self.model_path}/config.json" + try: + with open(config_file, "r") as f: + config_data = json.load(f) + except FileNotFoundError: + raise FileNotFoundError(f"Configuration file not found at {config_file}") + + for attr, json_key in self.json_key_map.items(): + if json_key in config_data: + setattr(self, attr, config_data[json_key]) + else: + setattr(self, attr, getattr(self, attr)) + + + + + +class ParallelConfig: + def __init__( + self, + config, + ) -> None: + self.pipeline_parallel_size = config["parallel"]["pp"] + self.tensor_parallel_size = config["parallel"]["tp"] + self.disable_custom_all_reduce = config["parallel"]["disable_custom_all_reduce"] + self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size + +class AttnConfig: + page_size: int = 256 + block_num: int = 32 + max_batch_token : int = 256 + max_batch_size: int = 32 + + def __init__(self, config): + self.page_size = config["attn"]["page_size"] + self.block_num = config["attn"]["block_num"] + self.max_batch_token = config["attn"]["max_batch_token"] + self.max_batch_size = config["attn"]["max_batch_size"] + + +class SamplerConfig(): + # Batched sampling params + temperatures: float + is_all_greedy: bool + + def __init__(self, config): + self.temperatures = config["sample"]["temperature"] + self.is_all_greedy = True + + +def load_yaml_config(file_path): + with open(file_path, "r") as f: + return yaml.safe_load(f) + + + + +class LLMConfig: + model_config: ModelConfig + parallel_config: ParallelConfig + attn_config: AttnConfig + sample_config: SamplerConfig + config_file: str + + def __init__(self, config_file): + self.config_file = config_file + config = load_yaml_config(config_file) + self.model_config = ModelConfig(config) + self.parallel_config = ParallelConfig(config) + self.attn_config = AttnConfig(config) + self.sample_config = SamplerConfig(config) + diff --git a/ktransformers/server/balance_serve/inference/distributed/__init__.py b/ktransformers/server/balance_serve/inference/distributed/__init__.py new file mode 100644 index 0000000..db325cf --- /dev/null +++ b/ktransformers/server/balance_serve/inference/distributed/__init__.py @@ -0,0 +1,3 @@ +from .communication_op import * +from .parallel_state import * +from .utils import * diff --git a/ktransformers/server/balance_serve/inference/distributed/communication_op.py b/ktransformers/server/balance_serve/inference/distributed/communication_op.py new file mode 100644 index 0000000..37d8dca --- /dev/null +++ b/ktransformers/server/balance_serve/inference/distributed/communication_op.py @@ -0,0 +1,39 @@ +""" +Date: 2024-12-11 06:02:42 +LastEditors: djw +LastEditTime: 2024-12-12 09:52:06 +""" + +from typing import Any, Dict, Optional, Union + +import torch +import torch.distributed + +from .parallel_state import get_tp_group + + +def tensor_model_parallel_all_reduce(input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> torch.Tensor: + """All-reduce the input tensor across model parallel group.""" + return get_tp_group().all_reduce(input_, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap) + + +def tensor_model_parallel_all_gather( + input_: torch.Tensor, dim: int = -1 +) -> torch.Tensor: + """All-gather the input tensor across model parallel group.""" + return get_tp_group().all_gather(input_, dim) + + +def tensor_model_parallel_gather( + input_: torch.Tensor, dst: int = 0, dim: int = -1 +) -> Optional[torch.Tensor]: + """Gather the input tensor across model parallel group.""" + return get_tp_group().gather(input_, dst, dim) + + +def broadcast_tensor_dict( + tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0 +): + if not torch.distributed.is_initialized(): + return tensor_dict + return get_tp_group().broadcast_tensor_dict(tensor_dict, src) diff --git a/ktransformers/server/balance_serve/inference/distributed/cuda_wrapper.py b/ktransformers/server/balance_serve/inference/distributed/cuda_wrapper.py new file mode 100644 index 0000000..31bf415 --- /dev/null +++ b/ktransformers/server/balance_serve/inference/distributed/cuda_wrapper.py @@ -0,0 +1,168 @@ +"""This file is a pure Python wrapper for the cudart library. +It avoids the need to compile a separate shared library, and is +convenient for use when we just need to call a few functions. +""" + +import ctypes +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +# this line makes it possible to directly load `libcudart.so` using `ctypes` +import torch # noqa + +# === export types and functions from cudart to Python === +# for the original cudart definition, please check +# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html + +cudaError_t = ctypes.c_int +cudaMemcpyKind = ctypes.c_int + + +class cudaIpcMemHandle_t(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +def find_loaded_library(lib_name) -> Optional[str]: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. + """ # noqa + found = False + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found = True + break + if not found: + # the library is not loaded in the current process + return None + # if lib_name is libcudart, we need to match a line with: + # address /path/to/libcudart-hash.so.11.0 + start = line.index("/") + path = line[start:].strip() + filename = path.split("/")[-1] + assert filename.rpartition(".so")[0].startswith(lib_name), \ + f"Unexpected filename: {filename} for library {lib_name}" + return path + + +class CudaRTLibrary: + exported_functions = [ + # ​cudaError_t cudaSetDevice ( int device ) + Function("cudaSetDevice", cudaError_t, [ctypes.c_int]), + # cudaError_t cudaDeviceSynchronize ( void ) + Function("cudaDeviceSynchronize", cudaError_t, []), + # ​cudaError_t cudaDeviceReset ( void ) + Function("cudaDeviceReset", cudaError_t, []), + + # const char* cudaGetErrorString ( cudaError_t error ) + Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]), + + # ​cudaError_t cudaMalloc ( void** devPtr, size_t size ) + Function("cudaMalloc", cudaError_t, + [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]), + # ​cudaError_t cudaFree ( void* devPtr ) + Function("cudaFree", cudaError_t, [ctypes.c_void_p]), + # ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count ) + Function("cudaMemset", cudaError_t, + [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]), + # ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa + Function("cudaMemcpy", cudaError_t, [ + ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind + ]), + + # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa + Function("cudaIpcGetMemHandle", cudaError_t, + [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]), + # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa + Function("cudaIpcOpenMemHandle", cudaError_t, [ + ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint + ]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + if so_file is None: + so_file = find_loaded_library("libcudart") + assert so_file is not None, \ + "libcudart is not loaded in the current process" + if so_file not in CudaRTLibrary.path_to_library_cache: + lib = ctypes.CDLL(so_file) + CudaRTLibrary.path_to_library_cache[so_file] = lib + self.lib = CudaRTLibrary.path_to_library_cache[so_file] + + if so_file not in CudaRTLibrary.path_to_dict_mapping: + _funcs = {} + for func in CudaRTLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs + self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file] + + def CUDART_CHECK(self, result: cudaError_t) -> None: + if result != 0: + error_str = self.cudaGetErrorString(result) + raise RuntimeError(f"CUDART error: {error_str}") + + def cudaGetErrorString(self, error: cudaError_t) -> str: + return self.funcs["cudaGetErrorString"](error).decode("utf-8") + + def cudaSetDevice(self, device: int) -> None: + self.CUDART_CHECK(self.funcs["cudaSetDevice"](device)) + + def cudaDeviceSynchronize(self) -> None: + self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]()) + + def cudaDeviceReset(self) -> None: + self.CUDART_CHECK(self.funcs["cudaDeviceReset"]()) + + def cudaMalloc(self, size: int) -> ctypes.c_void_p: + devPtr = ctypes.c_void_p() + self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size)) + return devPtr + + def cudaFree(self, devPtr: ctypes.c_void_p) -> None: + self.CUDART_CHECK(self.funcs["cudaFree"](devPtr)) + + def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, + count: int) -> None: + self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count)) + + def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p, + count: int) -> None: + cudaMemcpyDefault = 4 + kind = cudaMemcpyDefault + self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind)) + + def cudaIpcGetMemHandle(self, + devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t: + handle = cudaIpcMemHandle_t() + self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"]( + ctypes.byref(handle), devPtr)) + return handle + + def cudaIpcOpenMemHandle(self, + handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: + cudaIpcMemLazyEnablePeerAccess = 1 + devPtr = ctypes.c_void_p() + self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"]( + ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess)) + return devPtr diff --git a/ktransformers/server/balance_serve/inference/distributed/custom_all_reduce.py b/ktransformers/server/balance_serve/inference/distributed/custom_all_reduce.py new file mode 100644 index 0000000..e170e43 --- /dev/null +++ b/ktransformers/server/balance_serve/inference/distributed/custom_all_reduce.py @@ -0,0 +1,310 @@ +import ctypes +from contextlib import contextmanager +from typing import List, Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +import server.envs as envs +from server.inference.distributed.cuda_wrapper import CudaRTLibrary +from server.inference.distributed.custom_all_reduce_utils import gpu_p2p_access_check +from server.inference.distributed.parallel_state import in_the_same_node_as +from server.inference.platforms import current_platform +from server.utils import cuda_device_count_stateless +import vLLMCustomAllreduce + +try: + vLLMCustomAllreduce.meta_size() + custom_ar = True +except Exception: + # For AMD GPUs and CPUs + custom_ar = False + + +def _can_p2p(rank: int, world_size: int) -> bool: + for i in range(world_size): + if i == rank: + continue + if envs.VLLM_SKIP_P2P_CHECK: + print("Skipping P2P check and trusting the driver's P2P report.") + return torch.cuda.can_device_access_peer(rank, i) + if not gpu_p2p_access_check(rank, i): + return False + return True + + +def is_weak_contiguous(inp: torch.Tensor): + return inp.is_contiguous() or ( + inp.storage().nbytes() - inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size() + ) + + +class CustomAllreduce: + + _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + + # max_size: max supported allreduce size + def __init__( + self, + group: ProcessGroup, + device: Union[int, str, torch.device], + max_size=8192 * 1024, + ) -> None: + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the CustomAllreduce to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device, and all communicators in this group + are in the same node. + """ + self._IS_CAPTURING = False + self.disabled = True + + if not custom_ar: + # disable because of missing custom allreduce library + # e.g. in a non-cuda environment + return + + self.group = group + + assert ( + dist.get_backend(group) != dist.Backend.NCCL + ), "CustomAllreduce should be attached to a non-NCCL group." + + if not all(in_the_same_node_as(group, source_rank=0)): + # No need to initialize custom allreduce for multi-node case. + print( + "Custom allreduce is disabled because this process group" + " spans across nodes." + ) + return + + rank = dist.get_rank(group=self.group) + world_size = dist.get_world_size(group=self.group) + if world_size == 1: + # No need to initialize custom allreduce for single GPU case. + return + + if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES: + print( + "Custom allreduce is disabled due to an unsupported world" + " size: %d. Supported world sizes: %s. To silence this " + "warning, specify disable_custom_all_reduce=True explicitly.", + world_size, + str(CustomAllreduce._SUPPORTED_WORLD_SIZES), + ) + return + + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices: + device_ids = list(map(int, cuda_visible_devices.split(","))) + else: + device_ids = list(range(cuda_device_count_stateless())) + + physical_device_id = device_ids[device.index] + tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu") + gather_list = [ + torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size) + ] + dist.all_gather(gather_list, tensor, group=self.group) + physical_device_ids = [t.item() for t in gather_list] + + # test nvlink first, this will filter out most of the cases + # where custom allreduce is not supported + # this checks hardware and driver support for NVLink + assert current_platform.is_cuda() + from server.inference.platforms.cuda import CudaPlatform + + cuda_platform: CudaPlatform = current_platform + full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids) + if world_size > 2 and not full_nvlink: + print( + "Custom allreduce is disabled because it's not supported on" + " more than two PCIe-only GPUs. To silence this warning, " + "specify disable_custom_all_reduce=True explicitly." + ) + return + # test P2P capability, this checks software/cudaruntime support + # this is expensive to compute at the first time + # then we cache the result + if not _can_p2p(rank, world_size): + print( + "Custom allreduce is disabled because your platform lacks " + "GPU P2P capability or P2P test failed. To silence this " + "warning, specify disable_custom_all_reduce=True explicitly." + ) + return + + self.disabled = False + # Buffers memory are owned by this Python class and passed to C++. + # Meta data composes of two parts: meta data for synchronization and a + # temporary buffer for storing intermediate allreduce results. + self.meta_ptrs = self.create_shared_buffer( + vLLMCustomAllreduce.meta_size() + max_size, group=group + ) + # This is a pre-registered IPC buffer. In eager mode, input tensors + # are first copied into this buffer before allreduce is performed + self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) + # This is a buffer for storing the tuples of pointers pointing to + # IPC buffers from all ranks. Each registered tuple has size of + # 8*world_size bytes where world_size is at most 8. Allocating 8MB + # is enough for 131072 such tuples. The largest model I've seen only + # needs less than 10000 of registered tuples. + self.rank_data = torch.empty( + 8 * 1024 * 1024, dtype=torch.uint8, device=self.device + ) + self.max_size = max_size + self.rank = rank + self.world_size = world_size + self.full_nvlink = full_nvlink + self._ptr = vLLMCustomAllreduce.init_custom_ar( + self.meta_ptrs, self.rank_data, rank, self.full_nvlink + ) + vLLMCustomAllreduce.register_buffer(self._ptr, self.buffer_ptrs) + + @staticmethod + def create_shared_buffer( + size_in_bytes: int, group: Optional[ProcessGroup] = None + ) -> List[int]: + """ + Creates a shared buffer and returns a list of pointers + representing the buffer on all processes in the group. + """ + lib = CudaRTLibrary() + pointer = lib.cudaMalloc(size_in_bytes) + handle = lib.cudaIpcGetMemHandle(pointer) + world_size = dist.get_world_size(group=group) + rank = dist.get_rank(group=group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=group) + + pointers: List[int] = [] + for i, h in enumerate(handles): + if i == rank: + pointers.append(pointer.value) # type: ignore + else: + pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore + + return pointers + + @staticmethod + def free_shared_buffer( + pointers: List[int], group: Optional[ProcessGroup] = None + ) -> None: + rank = dist.get_rank(group=group) + lib = CudaRTLibrary() + lib.cudaFree(ctypes.c_void_p(pointers[rank])) + + @contextmanager + def capture(self): + """ + The main responsibility of this context manager is the + `register_graph_buffers` call at the end of the context. + It records all the buffer addresses used in the CUDA graph. + """ + try: + self._IS_CAPTURING = True + yield + finally: + self._IS_CAPTURING = False + if not self.disabled: + self.register_graph_buffers() + + def register_graph_buffers(self): + handle, offset = vLLMCustomAllreduce.get_graph_buffer_ipc_meta(self._ptr) + print("Registering %d cuda graph addresses", len(offset)) + # We cannot directly use `dist.all_gather_object` here + # because it is incompatible with `gloo` backend under inference mode. + # see https://github.com/pytorch/pytorch/issues/126032 for details. + all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))] + all_data[self.rank] = [handle, offset] + ranks = sorted(dist.get_process_group_ranks(group=self.group)) + for i, rank in enumerate(ranks): + dist.broadcast_object_list( + all_data[i], src=rank, group=self.group, device="cpu" + ) + # Unpack list of tuples to tuple of lists. + handles = [d[0] for d in all_data] # type: ignore + offsets = [d[1] for d in all_data] # type: ignore + vLLMCustomAllreduce.register_graph_buffers(self._ptr, handles, offsets) + + def should_custom_ar(self, inp: torch.Tensor): + if self.disabled: + return False + inp_size = inp.numel() * inp.element_size() + # custom allreduce requires input byte size to be multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + # for 4 or more non NVLink-capable GPUs, custom allreduce provides + # little performance improvement over NCCL. + if self.world_size == 2 or self.full_nvlink: + return inp_size < self.max_size + return False + + def all_reduce( + self, inp: torch.Tensor, *, out: torch.Tensor = None, bsz_tensor: torch.Tensor = None, registered: bool = False, + is_compute_bound=False, overlap=False + ): + """Performs an out-of-place all reduce. + + If registered is True, this assumes inp's pointer is already + IPC-registered. Otherwise, inp is first copied into a pre-registered + buffer. + """ + if is_compute_bound: + sms = 2 if overlap else 36 + else: + sms = 20 if overlap else 36 + #print("all reduce sms", sms) + if out is None: + out = torch.empty_like(inp) + if registered: + vLLMCustomAllreduce.all_reduce(self._ptr, inp, out, 0, 0, bsz_tensor, block_limit=sms) + else: + vLLMCustomAllreduce.all_reduce( + self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size, bsz_tensor, block_limit=sms + ) + return out + + def custom_all_reduce(self, input: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> Optional[torch.Tensor]: + """The main allreduce API that provides support for cuda graph.""" + # When custom allreduce is disabled, this will be None. + if self.disabled or not self.should_custom_ar(input): + return None + if self._IS_CAPTURING: + if torch.cuda.is_current_stream_capturing(): + return self.all_reduce(input, bsz_tensor=bsz_tensor, registered=True, is_compute_bound=is_compute_bound, overlap=overlap) + else: + # If warm up, mimic the allocation pattern since custom + # allreduce is out-of-place. + return torch.empty_like(input) + else: + # Note: outside of cuda graph context, custom allreduce incurs a + # cost of cudaMemcpy, which should be small (<=1% of overall + # latency) compared to the performance gain of using custom kernels + return self.all_reduce(input, bsz_tensor=bsz_tensor, registered=False, is_compute_bound=is_compute_bound, overlap=overlap) + + def close(self): + if not self.disabled and self._ptr: + vLLMCustomAllreduce.dispose(self._ptr) + self._ptr = 0 + self.free_shared_buffer(self.meta_ptrs) + self.free_shared_buffer(self.buffer_ptrs) + + def __del__(self): + self.close() diff --git a/ktransformers/server/balance_serve/inference/distributed/custom_all_reduce_utils.py b/ktransformers/server/balance_serve/inference/distributed/custom_all_reduce_utils.py new file mode 100644 index 0000000..d94ffe4 --- /dev/null +++ b/ktransformers/server/balance_serve/inference/distributed/custom_all_reduce_utils.py @@ -0,0 +1,272 @@ +import ctypes +import json +import os +import pickle +import subprocess +import sys +import tempfile +from itertools import product +from typing import Dict, List, Optional, Sequence + +import torch.distributed as dist +import torch.multiprocessing as mp + +import server.envs as envs +from server.inference.distributed.cuda_wrapper import CudaRTLibrary +from server.utils import cuda_device_count_stateless, update_environment_variables + + +def producer( + batch_src: Sequence[int], + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices: Optional[str] = None, +): + if cuda_visible_devices is not None: + update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + + lib = CudaRTLibrary() + for i in batch_src: + lib.cudaSetDevice(i) + pointer = lib.cudaMalloc(1024) + lib.cudaMemset(pointer, 1, 1024) + lib.cudaDeviceSynchronize() + handle = lib.cudaIpcGetMemHandle(pointer) + producer_queue.put(handle) + open_success = consumer_queue.get() + if open_success: + # use two queues to simulate barrier + producer_queue.put(0) + consumer_queue.get() + # check if the memory is modified + host_data = (ctypes.c_char * 1024)() + lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore + for i in range(1024): + if ord(host_data[i]) != 2: + open_success = False + break + result_queue.put(open_success) + lib.cudaDeviceReset() + + +def consumer( + batch_tgt: Sequence[int], + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices: Optional[str] = None, +): + if cuda_visible_devices is not None: + update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + + lib = CudaRTLibrary() + for j in batch_tgt: + lib.cudaSetDevice(j) + handle = producer_queue.get() + open_success = False + try: + pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore + open_success = True + except RuntimeError: + # cannot error out here, because the producer process + # is still waiting for the response. + pass + consumer_queue.put(open_success) + if open_success: + # modify the memory + lib.cudaMemset(pointer, 2, 1024) + lib.cudaDeviceSynchronize() + # use two queues to simulate barrier + producer_queue.get() + consumer_queue.put(0) + # check if the memory is modified + host_data = (ctypes.c_char * 1024)() + lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore + for i in range(1024): + if ord(host_data[i]) != 2: + open_success = False + break + result_queue.put(open_success) + lib.cudaDeviceReset() + + +def can_actually_p2p( + batch_src: Sequence[int], + batch_tgt: Sequence[int], +) -> Sequence[bool]: + """ + Usually, checking if P2P access is enabled can be done by + `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes + the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)` + returns `True` even if P2P access is not actually possible. + See https://github.com/vllm-project/vllm/issues/2728 and + https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10 + Therefore, we have to perform a real P2P access to check if it is actually + possible. + + Note on p2p and cuda IPC: + Usually, one process uses one GPU: + GPU src --> cuda context src --> tensor src --> process src + + We need to combine p2p and cuda IPC, so that: + GPU src --> cuda context src --> tensor src --> process src + |shared| + GPU tgt --> cuda context tgt --> tensor tgt --> process tgt + That is to say, process src creates a tensor in GPU src, passes IPC handle to + process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the + tensor in process tgt will be reflected in the tensor in process src, because + they are the same memory segment. + It is important to note that process tgt accesses the tensor in GPU tgt, not + GPU src. That's why we need p2p access. + + The most time-consuming part is the process creation. To avoid creating + processes for every pair of GPUs, we use batched testing. We create two + processes for testing all pairs of GPUs in batch. The trick is to reset + the device after each test (which is not available in PyTorch). + """ # noqa + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + # pass the CUDA_VISIBLE_DEVICES to the child process + # to make sure they see the same set of GPUs + + # make sure the processes are spawned + smp = mp.get_context("spawn") + producer_queue = smp.Queue() + consumer_queue = smp.Queue() + result_queue = smp.Queue() + p_src = smp.Process( + target=producer, + args=( + batch_src, + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices, + ), + ) + p_tgt = smp.Process( + target=consumer, + args=( + batch_tgt, + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices, + ), + ) + p_src.start() + p_tgt.start() + p_src.join() + p_tgt.join() + assert p_src.exitcode == 0 and p_tgt.exitcode == 0 + result: List[bool] = [] + for src, tgt in zip(batch_src, batch_tgt): + a = result_queue.get() + b = result_queue.get() + if a != b: + print( + "Two processes do not agree on the P2P access" + " status on %d -> %d, treat as disabled.", + src, + tgt, + ) + result.append(False) + else: + result.append(a) + return result + + +# why do we need this cache? +# we are testing peer-to-peer (p2p) access between GPUs,across processes. +# if we test it every time, it will be very slow, because we need to create +# N * N * 2 processes, where N is the world size. This is very slow. +# to reduce the time, we use a cache file to store the p2p access status. +# the cache file is generated by the master process if it does not exist. +# then all the processes can read the cache file to check the p2p access status. +# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we +# can have different cache files for different CUDA_VISIBLE_DEVICES settings, +# e.g. used by different vllm engines. The device id in the cache file is a +# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number +# of visible devices in the vllm engine. +_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None + + +def gpu_p2p_access_check(src: int, tgt: int) -> bool: + """Check if GPU src can access GPU tgt.""" + + # if the cache variable is already calculated, + # read from the cache instead of checking it again + global _gpu_p2p_access_cache + if _gpu_p2p_access_cache is not None: + return _gpu_p2p_access_cache[f"{src}->{tgt}"] + + is_distributed = dist.is_initialized() + + num_dev = cuda_device_count_stateless() + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices is None: + cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) + + path = os.path.join( + envs.VLLM_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json" + ) + os.makedirs(os.path.dirname(path), exist_ok=True) + from server.inference.distributed.parallel_state import get_world_group + + if (not is_distributed or get_world_group().local_rank == 0) and ( + not os.path.exists(path) + ): + # only the local master process (with local_rank == 0) can + # enter this block to calculate the cache + print("generating GPU P2P access cache in %s", path) + cache: Dict[str, bool] = {} + ids = list(range(num_dev)) + # batch of all pairs of GPUs + batch_src, batch_tgt = zip(*list(product(ids, ids))) + # NOTE: we use `subprocess` rather than `multiprocessing` here + # because the caller might not have `if __name__ == "__main__":`, + # in that case we cannot use spawn method in multiprocessing. + # However, `can_actually_p2p` requires spawn method. + # The fix is, we use `subprocess` to call the function, + # where we have `if __name__ == "__main__":` in this file. + + # use a temporary file to store the result + # we don't use the output of the subprocess directly, + # because the subprocess might produce logging output + with tempfile.NamedTemporaryFile() as output_file: + input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name)) + returned = subprocess.run( + [sys.executable, __file__], input=input_bytes, capture_output=True + ) + # check if the subprocess is successful + try: + returned.check_returncode() + except Exception as e: + # wrap raised exception to provide more information + raise RuntimeError( + f"Error happened when batch testing " + f"peer-to-peer access from {batch_src} to {batch_tgt}:\n" + f"{returned.stderr.decode()}" + ) from e + with open(output_file.name, "rb") as f: + result = pickle.load(f) + for _i, _j, r in zip(batch_src, batch_tgt, result): + cache[f"{_i}->{_j}"] = r + with open(path, "w") as f: + json.dump(cache, f, indent=4) + if is_distributed: + get_world_group().barrier() + print("reading GPU P2P access cache from %s", path) + with open(path) as f: + cache = json.load(f) + _gpu_p2p_access_cache = cache + return _gpu_p2p_access_cache[f"{src}->{tgt}"] + + +__all__ = ["gpu_p2p_access_check"] + +if __name__ == "__main__": + batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read()) + result = can_actually_p2p(batch_src, batch_tgt) + with open(output_file, "wb") as f: + f.write(pickle.dumps(result)) diff --git a/ktransformers/server/balance_serve/inference/distributed/parallel_state.py b/ktransformers/server/balance_serve/inference/distributed/parallel_state.py new file mode 100644 index 0000000..fc11374 --- /dev/null +++ b/ktransformers/server/balance_serve/inference/distributed/parallel_state.py @@ -0,0 +1,1262 @@ +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +"""vLLM distributed state. +It takes over the control of the distributed environment from PyTorch. +The typical workflow is: + +- call `init_distributed_environment` to initialize the distributed environment. +- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to + initialize the model parallel groups. + +- any code dealing with the distributed stuff + +- call `destroy_model_parallel` to destroy the model parallel groups. +- call `destroy_distributed_environment` to destroy the distributed environment. + +If you only need to use the distributed environment without model/pipeline + parallelism, you can skip the model parallel initialization and destruction + steps. +""" +import contextlib +import gc +import pickle +import weakref +from collections import namedtuple +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass +from multiprocessing import shared_memory +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from unittest.mock import patch + +import torch +import torch.distributed +from torch.distributed import Backend, ProcessGroup + +import server.envs as envs +from server.inference.platforms import current_platform +from server.utils import direct_register_custom_op, supports_custom_op + + +@dataclass +class GraphCaptureContext: + stream: torch.cuda.Stream + + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: Dict[str, Union[torch.Tensor, Any]] +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + """ + metadata_list: List[Tuple[str, Any]] = [] + tensor_list: List[torch.Tensor] = [] + for key, value in tensor_dict.items(): + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. + device = value.device.type + metadata_list.append( + (key, TensorMetadata(device, value.dtype, value.size())) + ) + tensor_list.append(value) + else: + metadata_list.append((key, value)) + return metadata_list, tensor_list + + +_group_name_counter: Dict[str, int] = {} + + +def _get_unique_name(name: str) -> str: + """Get a unique name for the group. + Example: + _get_unique_name("tp") -> "tp:0" + _get_unique_name("tp") -> "tp:1" + """ + if name not in _group_name_counter: + _group_name_counter[name] = 0 + newname = f"{name}:{_group_name_counter[name]}" + _group_name_counter[name] += 1 + return newname + + +_groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {} + + +def _register_group(group: "GroupCoordinator") -> None: + _groups[group.unique_name] = weakref.ref(group) + + +if supports_custom_op(): + + def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + group._all_reduce_in_place(tensor) + + def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None: + return + + direct_register_custom_op( + op_name="inplace_all_reduce", + op_func=inplace_all_reduce, + mutates_args=["tensor"], + fake_impl=inplace_all_reduce_fake, + ) + + def outplace_all_reduce(tensor: torch.Tensor, group_name: str, bsz_tensor: torch.Tensor, is_compute_bound: bool = False, overlap: bool = False) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._all_reduce_out_place(tensor, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap) + + def outplace_all_reduce_fake(tensor: torch.Tensor, group_name: str, bsz_tensor: torch.Tensor, is_compute_bound: bool = False, overlap: bool = False) -> torch.Tensor: + return torch.empty_like(tensor) + + direct_register_custom_op( + op_name="outplace_all_reduce", + op_func=outplace_all_reduce, + mutates_args=[], + fake_impl=outplace_all_reduce_fake, + ) + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and cuda graph mode). + """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + use_pynccl: bool # a hint of whether to use PyNccl + use_custom_allreduce: bool # a hint of whether to use CustomAllreduce + # communicators are only created for world size > 1 + pynccl_comm: Optional[Any] # PyNccl communicator + ca_comm: Optional[Any] # Custom allreduce communicator + mq_broadcaster: Optional[Any] # shared memory broadcaster + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + use_pynccl: bool, + use_custom_allreduce: bool, + use_tpu_communicator: bool, + use_hpu_communicator: bool, + use_xpu_communicator: bool, + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, + ): + group_name = group_name or "anonymous" + self.unique_name = _get_unique_name(group_name) + _register_group(self) + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None + assert self.device_group is not None + assert current_platform.is_cuda_alike() + + if current_platform.is_cuda_alike(): + self.device = torch.device(f"cuda:{local_rank}") + else: + self.device = torch.device("cpu") + + self.use_pynccl = use_pynccl + self.use_custom_allreduce = use_custom_allreduce + self.use_tpu_communicator = use_tpu_communicator + self.use_hpu_communicator = use_hpu_communicator + self.use_xpu_communicator = use_xpu_communicator + + # lazy import to avoid documentation build error + from server.inference.distributed.custom_all_reduce import CustomAllreduce + from server.inference.distributed.pynccl import PyNcclCommunicator + + self.pynccl_comm: Optional[PyNcclCommunicator] = None + # if use_pynccl and self.world_size > 1: + # self.pynccl_comm = PyNcclCommunicator( + # group=self.cpu_group, + # device=self.device, + # ) + + self.ca_comm: Optional[CustomAllreduce] = None + if use_custom_allreduce and self.world_size > 1: + # Initialize a custom fast all-reduce implementation. + self.ca_comm = CustomAllreduce( + group=self.cpu_group, + device=self.device, + ) + + #### we assume we won't use tpu or hpu or xpu or messagequeue broadcast + + # from vllm.distributed.device_communicators.tpu_communicator import ( + # TpuCommunicator) + # self.tpu_communicator: Optional[TpuCommunicator] = None + # if use_tpu_communicator and self.world_size > 1: + # self.tpu_communicator = TpuCommunicator(group=self.cpu_group) + self.tpu_communicator = None + + # from vllm.distributed.device_communicators.hpu_communicator import ( + # HpuCommunicator) + # self.hpu_communicator: Optional[HpuCommunicator] + # if use_hpu_communicator and self.world_size > 1: + # self.hpu_communicator = HpuCommunicator(group=self.device_group) + self.hpu_communicator = None + + # from vllm.distributed.device_communicators.xpu_communicator import ( + # XpuCommunicator) + # self.xpu_communicator: Optional[XpuCommunicator] + # if use_xpu_communicator and self.world_size > 1: + # self.xpu_communicator = XpuCommunicator(group=self.device_group) + self.xpu_communicator = None + + # from vllm.distributed.device_communicators.shm_broadcast import ( + # MessageQueue) + # self.mq_broadcaster: Optional[MessageQueue] = None + # if use_message_queue_broadcaster and self.world_size > 1: + # self.mq_broadcaster = MessageQueue.create_from_process_group( + # self.cpu_group, 1 << 22, 6) + self.mq_broadcaster = None + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @contextmanager + def graph_capture( + self, graph_capture_context: Optional[GraphCaptureContext] = None + ): + if graph_capture_context is None: + stream = torch.cuda.Stream() + graph_capture_context = GraphCaptureContext(stream) + else: + stream = graph_capture_context.stream + + ca_comm = self.ca_comm + maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture() + + # ensure all initialization operations complete before attempting to + # capture the graph on another stream + curr_stream = torch.cuda.current_stream() + if curr_stream != stream: + stream.wait_stream(curr_stream) + + with torch.cuda.stream(stream), maybe_ca_context: + # In graph mode, we have to be very careful about the collective + # operations. The current status is: + # allreduce \ Mode | Eager | Graph | + # -------------------------------------------- + # custom allreduce | enabled | enabled | + # PyNccl | disabled| enabled | + # torch.distributed | enabled | disabled| + # + # Note that custom allreduce will have a runtime check, if the + # tensor size is too large, it will fallback to the next + # available option. + # In summary: When using CUDA graph, we use + # either custom all-reduce kernel or pynccl. When not using + # CUDA graph, we use either custom all-reduce kernel or + # PyTorch NCCL. We always prioritize using custom all-reduce + # kernel but fall back to PyTorch or pynccl if it is + # disabled or not supported. + pynccl_comm = self.pynccl_comm + maybe_pynccl_context: Any + if not pynccl_comm: + maybe_pynccl_context = nullcontext() + else: + maybe_pynccl_context = pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream() + ) + with maybe_pynccl_context: + yield graph_capture_context + + def all_reduce(self, input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> torch.Tensor: + """ + User-facing all-reduce function before we actually call the + all-reduce operation. + + We need this because Dynamo does not support passing an arbitrary + object (`self` in this case) to a custom op. We need to pass the + group name as a string, and then look up the group coordinator from + the group name, dispatch the all-reduce operation to the group + coordinator. + + In addition, PyTorch custom ops do not support mutation or returning + a new tensor in the same op. So we need to figure out if the op is + in-place or out-of-place ahead of time. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + + if input_.is_cpu: + import intel_extension_for_pytorch as ipex + + ipex.distributed.all_reduce(input_, group=self.device_group) + return input_ + + if not supports_custom_op(): + self._all_reduce_in_place(input_) + return input_ + + if self.tpu_communicator is not None and not self.tpu_communicator.disabled: + # TPU handles Dynamo with its own logic. + return self.tpu_communicator.all_reduce(input_) + + if self.hpu_communicator is not None and not self.hpu_communicator.disabled: + return self.hpu_communicator.all_reduce(input_) + + if self.xpu_communicator is not None and not self.xpu_communicator.disabled: + return self.xpu_communicator.all_reduce(input_) + + if ( + self.ca_comm is not None + and not self.ca_comm.disabled + and self.ca_comm.should_custom_ar(input_) + ): + return torch.ops.vllm.outplace_all_reduce( + input_, group_name=self.unique_name, bsz_tensor=bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap + ) + else: + #assert self.ca_comm is not None + #assert not self.ca_comm.disabled + #assert self.ca_comm.should_custom_ar(input_) + torch.ops.vllm.inplace_all_reduce(input_, group_name=self.unique_name) + return input_ + + def _all_reduce_out_place(self, input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> torch.Tensor: + ca_comm = self.ca_comm + assert ca_comm is not None + assert not ca_comm.disabled + out = ca_comm.custom_all_reduce(input_, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap) + assert out is not None + return out + + def _all_reduce_in_place(self, input_: torch.Tensor) -> None: + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.all_reduce(input_) + else: + torch.distributed.all_reduce(input_, group=self.device_group) + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + + # For TPUs, use TPU communicator. + tpu_comm = self.tpu_communicator + if tpu_comm is not None and not tpu_comm.disabled: + return tpu_comm.all_gather(input_, dim) + + # For HPUs, use HPU communicator. + hpu_comm = self.hpu_communicator + if hpu_comm is not None and not hpu_comm.disabled: + return hpu_comm.all_gather(input_, dim) + + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # NOTE: we have to use concat-style all-gather here, + # stack-style all-gather has compatibility issues with + # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 + output_size = (input_size[0] * world_size,) + input_size[1:] + # Allocate output tensor. + output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) + # All-gather. + torch.distributed.all_gather_into_tensor( + output_tensor, input_, group=self.device_group + ) + # Reshape + output_tensor = output_tensor.reshape((world_size,) + input_size) + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape( + input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :] + ) + return output_tensor + + def gather( + self, input_: torch.Tensor, dst: int = 0, dim: int = -1 + ) -> Optional[torch.Tensor]: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + if self.xpu_communicator is not None and not self.xpu_communicator.disabled: + return self.xpu_communicator.gather(input_, self.rank_in_group, dst, dim) + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast( + input_, src=self.ranks[src], group=self.device_group + ) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.mq_broadcaster is not None: + assert src == 0, "Message queue broadcaster only supports src=0" + return self.mq_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list( + [obj], src=self.ranks[src], group=self.cpu_group + ) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=self.ranks[src], group=self.cpu_group + ) + return recv[0] + + def broadcast_object_list( + self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None + ): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list( + obj_list, src=self.ranks[src], group=self.device_group + ) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + assert dst != self.rank_in_group, ( + "Invalid destination rank. Destination rank is the same " + "as the current rank." + ) + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor( + [object_tensor.numel()], dtype=torch.long, device="cpu" + ) + + # Send object size + + torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + assert src < self.world_size, f"Invalid src rank ({src})" + + assert ( + src != self.rank_in_group + ), "Invalid source rank. Source rank is the same as the current rank." + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv( + size_tensor, src=self.ranks[src], group=self.cpu_group + ) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu", + ) + + rank_object = torch.distributed.recv( + object_tensor, src=self.ranks[src], group=self.cpu_group + ) + + assert ( + rank_object == rank_size + ), "Received object sender rank does not match the size sender rank." + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + + rank_in_group = self.rank_in_group + if rank_in_group == src: + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, dict + ), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=self.ranks[src], group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=self.ranks[src], group=group, async_op=True + ) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty( + value.size, dtype=value.dtype, device=value.device + ) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True, + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=self.ranks[src], group=group, async_op=True + ) + async_handles.append(handle) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: Dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size + all_gather_rank = ( + 0 if all_gather_group is None else all_gather_group.rank_in_group + ) + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, dict + ), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. + continue + + # send-allgather: send only a slice, then do allgather. + if all_gather_group is not None and tensor.numel() % all_gather_size == 0: + tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] + + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send( + tensor, dst=self.ranks[dst], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.send(tensor, dst=self.ranks[dst], group=group) + return None + + def recv_tensor_dict( + self, + src: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size + all_gather_rank = ( + 0 if all_gather_group is None else all_gather_group.rank_in_group + ) + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = (self.rank_in_group - 1) % self.world_size + assert src < self.world_size, f"Invalid src rank ({src})" + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: Dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + + # send-allgather: send only a slice, then do allgather. + use_all_gather = ( + all_gather_group is not None + and tensor.numel() % all_gather_size == 0 + ) + + if use_all_gather: + orig_shape = tensor.shape + tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] + + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv( + tensor, src=self.ranks[src], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, src=self.ranks[src], group=group) + if use_all_gather: + # do the allgather + tensor = all_gather_group.all_gather(tensor, dim=0) # type: ignore + tensor = tensor.reshape(orig_shape) + + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. + """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.send(tensor, dst) + else: + torch.distributed.send(tensor, self.ranks[dst], self.device_group) + + def recv( + self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + ) -> torch.Tensor: + """Receives a tensor from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + if src is None: + src = (self.rank_in_group - 1) % self.world_size + + tensor = torch.empty(size, dtype=dtype, device=self.device) + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.recv(tensor, src) + else: + torch.distributed.recv(tensor, self.ranks[src], self.device_group) + return tensor + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + if self.pynccl_comm is not None: + self.pynccl_comm = None + if self.ca_comm is not None: + self.ca_comm = None + if self.mq_broadcaster is not None: + self.mq_broadcaster = None + + +_WORLD: Optional[GroupCoordinator] = None + + +def get_world_group() -> GroupCoordinator: + assert _WORLD is not None, "world group is not initialized" + return _WORLD + + +def init_world_group( + ranks: List[int], local_rank: int, backend: str +) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, + use_pynccl=False, + use_custom_allreduce=False, + use_tpu_communicator=False, + use_hpu_communicator=False, + use_xpu_communicator=False, + group_name="world", + ) + + +def init_model_parallel_group( + group_ranks: List[List[int]], + local_rank: int, + backend: str, + use_custom_allreduce: Optional[bool] = None, + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, +) -> GroupCoordinator: + if use_custom_allreduce is None: + use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE + return GroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_pynccl=True, + use_custom_allreduce=use_custom_allreduce, + use_tpu_communicator=True, + use_hpu_communicator=True, + use_xpu_communicator=True, + use_message_queue_broadcaster=use_message_queue_broadcaster, + group_name=group_name, + ) + + +_TP: Optional[GroupCoordinator] = None + + +def get_tp_group() -> GroupCoordinator: + assert _TP is not None, "tensor model parallel group is not initialized" + return _TP + + +# kept for backward compatibility +get_tensor_model_parallel_group = get_tp_group + +_PP: Optional[GroupCoordinator] = None + + +def get_pp_group() -> GroupCoordinator: + assert _PP is not None, "pipeline model parallel group is not initialized" + return _PP + + +# kept for backward compatibility +get_pipeline_model_parallel_group = get_pp_group + + +@contextmanager +def graph_capture(): + """ + `graph_capture` is a context manager which should surround the code that + is capturing the CUDA graph. Its main purpose is to ensure that the + some operations will be run after the graph is captured, before the graph + is replayed. It returns a `GraphCaptureContext` object which contains the + necessary data for the graph capture. Currently, it only contains the + stream that the graph capture is running on. This stream is set to the + current CUDA stream when the context manager is entered and reset to the + default stream when the context manager is exited. This is to ensure that + the graph capture is running on a separate stream from the default stream, + in order to explicitly distinguish the kernels to capture + from other kernels possibly launched on background in the default stream. + """ + with get_tp_group().graph_capture() as context, get_pp_group().graph_capture( + context + ): + yield context + + +_ENABLE_CUSTOM_ALL_REDUCE = True + + +def set_custom_all_reduce(enable: bool): + global _ENABLE_CUSTOM_ALL_REDUCE + _ENABLE_CUSTOM_ALL_REDUCE = enable + + +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str = "nccl", +): + print( + "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", + world_size, + rank, + local_rank, + distributed_init_method, + backend, + ) + if not torch.distributed.is_initialized(): + assert distributed_init_method is not None, ( + "distributed_init_method must be provided when initializing " + "distributed environment" + ) + # this backend is used for WORLD + torch.distributed.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + ) + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = envs.LOCAL_RANK + else: + local_rank = rank + global _WORLD + if _WORLD is None: + ranks = list(range(torch.distributed.get_world_size())) + _WORLD = init_world_group(ranks, local_rank, backend) + else: + assert ( + _WORLD.world_size == torch.distributed.get_world_size() + ), "world group already initialized with a different world size" + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """ + Initialize model parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model + parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model + parallelism. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: + 4 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 pipeline model-parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + + if world_size != tensor_model_parallel_size * pipeline_model_parallel_size: + raise RuntimeError( + f"world_size ({world_size}) is not equal to " + f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " + f"pipeline_model_parallel_size ({pipeline_model_parallel_size})" + ) + + # Build the tensor model-parallel groups. + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + global _TP + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = list( + range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + ) + group_ranks.append(ranks) + + # message queue broadcaster is only used in tensor model parallel group + _TP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True, + group_name="tp", + ) + + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + global _PP + assert _PP is None, "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + _PP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_custom_allreduce=False, + group_name="pp", + ) + + +def ensure_model_parallel_initialized( + tensor_model_parallel_size: int, + pipeline_model_parallel_size: int, + backend: Optional[str] = None, +) -> None: + """Helper to initialize model parallel groups if they are not initialized, + or ensure tensor-parallel and pipeline-parallel sizes are equal to expected + values if the model parallel groups are initialized. + """ + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + if not model_parallel_is_initialized(): + initialize_model_parallel( + tensor_model_parallel_size, pipeline_model_parallel_size, backend + ) + return + + assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, ( + "tensor parallel group already initialized, but of unexpected size: " + f"{get_tensor_model_parallel_world_size()=} vs. " + f"{tensor_model_parallel_size=}" + ) + pp_world_size = get_pp_group().world_size + assert pp_world_size == pipeline_model_parallel_size, ( + "pipeline parallel group already initialized, but of unexpected size: " + f"{pp_world_size=} vs. " + f"{pipeline_model_parallel_size=}" + ) + + +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return _TP is not None and _PP is not None + + +_TP_STATE_PATCHED = False + + +@contextmanager +def patch_tensor_parallel_group(tp_group: GroupCoordinator): + """Patch the tp group temporarily until this function ends. + + This method is for draft workers of speculative decoding to run draft model + with different tp degree from that of target model workers. + + Args: + tp_group (GroupCoordinator): the tp group coordinator + """ + global _TP_STATE_PATCHED + assert not _TP_STATE_PATCHED, "Should not call when it's already patched" + + _TP_STATE_PATCHED = True + old_tp_group = get_tp_group() + global _TP + _TP = tp_group + try: + yield + finally: + # restore the original state + _TP_STATE_PATCHED = False + _TP = old_tp_group + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return get_tp_group().world_size + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return get_tp_group().rank_in_group + + +def destroy_model_parallel(): + """Set the groups to none and destroy them.""" + global _TP + if _TP: + _TP.destroy() + _TP = None + + global _PP + if _PP: + _PP.destroy() + _PP = None + + +def destroy_distributed_environment(): + global _WORLD + if _WORLD: + _WORLD.destroy() + _WORLD = None + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + +def cleanup_dist_env_and_memory(shutdown_ray: bool = False): + destroy_model_parallel() + destroy_distributed_environment() + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() + if shutdown_ray: + import ray # Lazy import Ray + + ray.shutdown() + gc.collect() + if not current_platform.is_cpu(): + torch.cuda.empty_cache() + + +def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]: + """ + This is a collective operation that returns if each rank is in the same node + as the source rank. It tests if processes are attached to the same + memory system (shared access to shared memory). + """ + assert ( + torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL + ), "in_the_same_node_as should be tested with a non-NCCL group." + # local rank inside the group + rank = torch.distributed.get_rank(group=pg) + world_size = torch.distributed.get_world_size(group=pg) + + # local tensor in each process to store the result + is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32) + + # global ranks of the processes in the group + ranks = torch.distributed.get_process_group_ranks(pg) + + magic_message = b"magic_message" + shm = None + + try: + with contextlib.suppress(OSError): + if rank == source_rank: + # create a shared memory segment + shm = shared_memory.SharedMemory(create=True, size=128) + shm.buf[: len(magic_message)] = magic_message + torch.distributed.broadcast_object_list( + [shm.name], src=ranks[source_rank], group=pg + ) + is_in_the_same_node[rank] = 1 + else: + # try to open the shared memory segment + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=ranks[source_rank], group=pg + ) + name = recv[0] + # fix to https://stackoverflow.com/q/62748654/9191338 + # Python incorrectly tracks shared memory even if it is not + # created by the process. The following patch is a workaround. + with patch( + "multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None, + ): + shm = shared_memory.SharedMemory(name=name) + if shm.buf[: len(magic_message)] == magic_message: + is_in_the_same_node[rank] = 1 + except Exception as e: + print("Error ignored in is_in_the_same_node: %s", e) + finally: + if shm: + shm.close() + + torch.distributed.barrier(group=pg) + + # clean up the shared memory segment + with contextlib.suppress(OSError): + if rank == source_rank and shm: + shm.unlink() + torch.distributed.all_reduce(is_in_the_same_node, group=pg) + + return [x == 1 for x in is_in_the_same_node.tolist()] diff --git a/ktransformers/server/balance_serve/inference/distributed/pynccl.py b/ktransformers/server/balance_serve/inference/distributed/pynccl.py new file mode 100644 index 0000000..98be81d --- /dev/null +++ b/ktransformers/server/balance_serve/inference/distributed/pynccl.py @@ -0,0 +1,201 @@ +from contextlib import contextmanager +from typing import Optional, Union + +# ===================== import region ===================== +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup, ReduceOp + +from server.inference.distributed.pynccl_wrapper import ( + NCCLLibrary, + buffer_type, + cudaStream_t, + ncclComm_t, + ncclDataTypeEnum, + ncclRedOpTypeEnum, + ncclUniqueId, +) +from server.inference.distributed.utils import StatelessProcessGroup + + +class PyNcclCommunicator: + + def __init__( + self, + group: Union[ProcessGroup, StatelessProcessGroup], + device: Union[int, str, torch.device], + library_path: Optional[str] = None, + ): + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the PyNcclCommunicator to. If None, + it will be bind to f"cuda:{local_rank}". + library_path: the path to the NCCL library. If None, it will + use the default library path. + It is the caller's responsibility to make sure each communicator + is bind to a unique device. + """ + if not isinstance(group, StatelessProcessGroup): + assert dist.is_initialized() + assert ( + dist.get_backend(group) != dist.Backend.NCCL + ), "PyNcclCommunicator should be attached to a non-NCCL group." + # note: this rank is the rank in the group + self.rank = dist.get_rank(group) + self.world_size = dist.get_world_size(group) + else: + self.rank = group.rank + self.world_size = group.world_size + + self.group = group + + # if world_size == 1, no need to create communicator + if self.world_size == 1: + self.available = False + self.disabled = True + self.stream = None + return + try: + self.nccl = NCCLLibrary(library_path) + except Exception: + # disable because of missing NCCL library + # e.g. in a non-GPU environment + self.available = False + self.disabled = True + self.stream = None + return + + self.available = True + self.disabled = False + + print("vLLM is using nccl==%s", self.nccl.ncclGetVersion()) + + if self.rank == 0: + # get the unique id from NCCL + self.unique_id = self.nccl.ncclGetUniqueId() + else: + # construct an empty unique id + self.unique_id = ncclUniqueId() + + if not isinstance(group, StatelessProcessGroup): + tensor = torch.ByteTensor(list(self.unique_id.internal)) + ranks = dist.get_process_group_ranks(group) + # arg `src` in `broadcast` is the global rank + dist.broadcast(tensor, src=ranks[0], group=group) + byte_list = tensor.tolist() + for i, byte in enumerate(byte_list): + self.unique_id.internal[i] = byte + else: + self.unique_id = group.broadcast_obj(self.unique_id, src=0) + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + # nccl communicator and stream will use this device + # `torch.cuda.device` is a context manager that changes the + # current cuda device to the specified one + with torch.cuda.device(device): + self.comm: ncclComm_t = self.nccl.ncclCommInitRank( + self.world_size, self.unique_id, self.rank + ) + self.stream = torch.cuda.Stream() + + # A small all_reduce for warmup. + data = torch.zeros(1, device=device) + self.all_reduce(data) + self.stream.synchronize() + del data + + # by default it is disabled, e.g. in profiling models and prefill phase. + # to use it, use under `with obj.change_state(enable=True)`, usually + # when we are using CUDA graph. + self.disabled = True + + def all_reduce( + self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}" + ) + if stream is None: + stream = self.stream + self.nccl.ncclAllReduce( + buffer_type(tensor.data_ptr()), + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def send(self, tensor: torch.Tensor, dst: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}" + ) + if stream is None: + stream = self.stream + self.nccl.ncclSend( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + dst, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def recv(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}" + ) + if stream is None: + stream = self.stream + self.nccl.ncclRecv( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + @contextmanager + def change_state( + self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None + ): + """ + A context manager to change the state of the communicator. + """ + if enable is None: + # guess a default value when not specified + enable = self.available + + if stream is None: + stream = self.stream + + old_disable = self.disabled + old_stream = self.stream + + self.stream = stream + self.disabled = not enable + yield + + self.disabled = old_disable + self.stream = old_stream diff --git a/ktransformers/server/balance_serve/inference/distributed/pynccl_wrapper.py b/ktransformers/server/balance_serve/inference/distributed/pynccl_wrapper.py new file mode 100644 index 0000000..aa7d4a7 --- /dev/null +++ b/ktransformers/server/balance_serve/inference/distributed/pynccl_wrapper.py @@ -0,0 +1,276 @@ +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually +# doable, but we often encounter issues related with nccl versions, and need +# to switch between different versions of NCCL. See +# https://github.com/NVIDIA/nccl/issues/1234 for more details. +# A C/C++ binding is not flexible enough to handle this. It requires +# recompilation of the code every time we want to switch between different +# versions. This current implementation, with a **pure** Python wrapper, is +# more flexible. We can easily switch between different versions of NCCL by +# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` +# variable in the code. + +import ctypes +import platform +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +from torch.distributed import ReduceOp + +from server.utils import find_nccl_library + + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int +ncclComm_t = ctypes.c_void_p + + +class ncclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + +ncclDataType_t = ctypes.c_int + + +class ncclDataTypeEnum: + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +ncclRedOp_t = ctypes.c_int + + +class ncclRedOpTypeEnum: + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +class NCCLLibrary: + exported_functions = [ + # const char* ncclGetErrorString(ncclResult_t result) + Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), + # ncclResult_t ncclGetVersion(int *version); + Function("ncclGetVersion", ncclResult_t, + [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("ncclGetUniqueId", ncclResult_t, + [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function("ncclCommInitRank", ncclResult_t, [ + ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, + ctypes.c_int + ]), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllReduce", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclSend( + # const void* sendbuff, size_t count, ncclDataType_t datatype, + # int dest, ncclComm_t comm, cudaStream_t stream); + Function("ncclSend", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclRecv( + # void* recvbuff, size_t count, ncclDataType_t datatype, + # int src, ncclComm_t comm, cudaStream_t stream); + Function("ncclRecv", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. + # ncclResult_t ncclCommDestroy(ncclComm_t comm); + Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + + so_file = so_file or find_nccl_library() + + try: + if so_file not in NCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + NCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = NCCLLibrary.path_to_library_cache[so_file] + except Exception as e: + print( + "Failed to load NCCL library from %s ." + "It is expected if you are not running on NVIDIA/AMD GPUs." + "Otherwise, the nccl library might not exist, be corrupted " + "or it does not support the current platform %s." + "If you already have the library, please set the " + "environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.", so_file, + platform.platform()) + raise e + + if so_file not in NCCLLibrary.path_to_dict_mapping: + _funcs: Dict[str, Any] = {} + for func in NCCLLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + NCCLLibrary.path_to_dict_mapping[so_file] = _funcs + self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] + + def ncclGetErrorString(self, result: ncclResult_t) -> str: + return self._funcs["ncclGetErrorString"](result).decode("utf-8") + + def NCCL_CHECK(self, result: ncclResult_t) -> None: + if result != 0: + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"NCCL error: {error_str}") + + def ncclGetVersion(self) -> str: + version = ctypes.c_int() + self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) + version_str = str(version.value) + # something like 21903 --> "2.19.3" + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + def ncclGetUniqueId(self) -> ncclUniqueId: + unique_id = ncclUniqueId() + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"]( + ctypes.byref(unique_id))) + return unique_id + + def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, + rank: int) -> ncclComm_t: + comm = ncclComm_t() + self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), + world_size, unique_id, + rank)) + return comm + + def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, + datatype, op, comm, + stream)) + + def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, + dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, + dest, comm, stream)) + + def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, + src: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, + comm, stream)) + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + + +__all__ = [ + "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", + "ncclComm_t", "cudaStream_t", "buffer_type" +] diff --git a/ktransformers/server/balance_serve/inference/distributed/utils.py b/ktransformers/server/balance_serve/inference/distributed/utils.py new file mode 100644 index 0000000..475433c --- /dev/null +++ b/ktransformers/server/balance_serve/inference/distributed/utils.py @@ -0,0 +1,219 @@ +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +import dataclasses +import pickle +import time +from collections import deque +from typing import Any, Deque, Dict, Optional, Sequence, Tuple + +import torch +from torch.distributed import TCPStore + +import server.envs as envs + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, "{} is not divisible by {}".format( + numerator, denominator + ) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> Sequence[torch.Tensor]: + """Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # NOTE: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +def get_pp_indices( + num_hidden_layers: int, pp_rank: int, pp_size: int +) -> Tuple[int, int]: + """Try to evenly distribute layers across partitions. + If the number of layers is not divisible by the number of partitions, + the last partition will have the remaining layers. + """ + partition_list_str = envs.VLLM_PP_LAYER_PARTITION + if partition_list_str is not None: + try: + partitions = [int(layer) for layer in partition_list_str.split(",")] + except ValueError as err: + raise ValueError( + "Invalid partition string: {}".format(partition_list_str) + ) from err + if len(partitions) != pp_size: + raise ValueError(f"{len(partitions)=} does not match {pp_size=}.") + if sum(partitions) != num_hidden_layers: + raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.") + start_layer = sum(partitions[:pp_rank]) + end_layer = start_layer + partitions[pp_rank] + else: + layers_per_partition = num_hidden_layers // pp_size + start_layer = pp_rank * layers_per_partition + end_layer = start_layer + layers_per_partition + + if pp_rank == pp_size - 1: + end_layer = num_hidden_layers + + return (start_layer, end_layer) + + +@dataclasses.dataclass +class StatelessProcessGroup: + """A dataclass to hold a metadata store, and the rank, world_size of the + group. Only use it to communicate metadata between processes. + For data-plane communication, create NCCL-related objects. + """ + + rank: int + world_size: int + store: torch._C._distributed_c10d.Store + data_expiration_seconds: int = 3600 # 1 hour + + # dst rank -> counter + send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict) + # src rank -> counter + recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict) + broadcast_send_counter: int = 0 + broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict) + + # A deque to store the data entries, with key and timestamp. + entries: Deque[Tuple[str, float]] = dataclasses.field(default_factory=deque) + + def __post_init__(self): + assert self.rank < self.world_size + self.send_dst_counter = {i: 0 for i in range(self.world_size)} + self.recv_src_counter = {i: 0 for i in range(self.world_size)} + self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)} + + def send_obj(self, obj: Any, dst: int): + """Send an object to a destination rank.""" + self.expire_data() + key = f"send_to/{dst}/{self.send_dst_counter[dst]}" + self.store.set(key, pickle.dumps(obj)) + self.send_dst_counter[dst] += 1 + self.entries.append((key, time.time())) + + def expire_data(self): + """Expire data that is older than `data_expiration_seconds` seconds.""" + while self.entries: + # check the oldest entry + key, timestamp = self.entries[0] + if time.time() - timestamp > self.data_expiration_seconds: + self.store.delete_key(key) + self.entries.popleft() + else: + break + + def recv_obj(self, src: int) -> Any: + """Receive an object from a source rank.""" + obj = pickle.loads( + self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}") + ) + self.recv_src_counter[src] += 1 + return obj + + def broadcast_obj(self, obj: Optional[Any], src: int) -> Any: + """Broadcast an object from a source rank to all other ranks. + It does not clean up after all ranks have received the object. + Use it for limited times, e.g., for initialization. + """ + if self.rank == src: + self.expire_data() + key = f"broadcast_from/{src}/" f"{self.broadcast_send_counter}" + self.store.set(key, pickle.dumps(obj)) + self.broadcast_send_counter += 1 + self.entries.append((key, time.time())) + return obj + else: + key = f"broadcast_from/{src}/" f"{self.broadcast_recv_src_counter[src]}" + recv_obj = pickle.loads(self.store.get(key)) + self.broadcast_recv_src_counter[src] += 1 + return recv_obj + + def all_gather_obj(self, obj: Any) -> list[Any]: + """All gather an object from all ranks.""" + gathered_objs = [] + for i in range(self.world_size): + if i == self.rank: + gathered_objs.append(obj) + self.broadcast_obj(obj, src=self.rank) + else: + recv_obj = self.broadcast_obj(None, src=i) + gathered_objs.append(recv_obj) + return gathered_objs + + def barrier(self): + """A barrier to synchronize all ranks.""" + for i in range(self.world_size): + if i == self.rank: + self.broadcast_obj(None, src=self.rank) + else: + self.broadcast_obj(None, src=i) + + @staticmethod + def create( + host: str, + port: int, + rank: int, + world_size: int, + data_expiration_seconds: int = 3600, + ) -> "StatelessProcessGroup": + """A replacement for `torch.distributed.init_process_group` that does not + pollute the global state. + + If we have process A and process B called `torch.distributed.init_process_group` + to form a group, and then we want to form another group with process A, B, C, + D, it is not possible in PyTorch, because process A and process B have already + formed a group, and process C and process D cannot join that group. This + function is a workaround for this issue. + + `torch.distributed.init_process_group` is a global call, while this function + is a stateless call. It will return a `StatelessProcessGroup` object that can be + used for exchanging metadata. With this function, process A and process B + can call `StatelessProcessGroup.create` to form a group, and then process A, B, + C, and D can call `StatelessProcessGroup.create` to form another group. + """ # noqa + store = TCPStore( + host_name=host, + port=port, + world_size=world_size, + is_master=(rank == 0), + ) + + return StatelessProcessGroup( + rank=rank, + world_size=world_size, + store=store, + data_expiration_seconds=data_expiration_seconds, + ) diff --git a/ktransformers/server/balance_serve/inference/forward_batch.py b/ktransformers/server/balance_serve/inference/forward_batch.py new file mode 100644 index 0000000..4f79bc3 --- /dev/null +++ b/ktransformers/server/balance_serve/inference/forward_batch.py @@ -0,0 +1,284 @@ +''' +Date: 2024-11-12 14:15:16 +LastEditors: Xie Weiyu ervinxie@qq.com +LastEditTime: 2024-11-26 08:12:49 +''' +import torch +from ktransformers.server.balance_serve.settings import sched_ext +from ktransformers.server.balance_serve.inference.query_manager import QueryManager, QueryInfo +import time +from ktransformers.server.config.config import Config +class ForwardBatchInput: + + class ForwardMiniBatch: + q_indptr: torch.Tensor + kv_indptr: torch.Tensor + kv_indices: torch.Tensor + kv_last_page_len: torch.Tensor + kv_len: torch.Tensor + position_ids: torch.Tensor + tokens: torch.Tensor + batch_indices: torch.Tensor + positions: torch.Tensor + chunk_size: int + decode_batch: int + is_last_prefill_chunk: bool + logits_start: list + + temperatures: torch.Tensor + top_ps: torch.Tensor + + def __init__(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, device = torch.device('cuda'), page_size = 256): + batch_decode = len(decode_querys_info) + batch_prefill = len(prefill_querys_info) + + self.q_indptr = torch.tensor([0], device=device, dtype=torch.int32) + self.kv_indptr = torch.tensor([0], device=device, dtype=torch.int32) + self.kv_indices = torch.tensor([], device=device, dtype=torch.int32) + self.kv_len = torch.tensor([], device=device, dtype=torch.int32) + self.kv_last_page_len = torch.tensor([], device=device, dtype=torch.int32) + self.position_ids = torch.tensor([], device=device, dtype=torch.int32) + self.tokens = torch.tensor([], device=device, dtype=torch.int32) + + self.temperatures = torch.tensor([], device=device, dtype=torch.float32) + self.top_ps = torch.tensor([], device=device, dtype=torch.float32) + + self.logits_start = [] + self.decode_batch = batch_decode + self.num_tokens = batch_decode + sum(prefill_l) + self.batch_size = batch_decode + batch_prefill + + for i, prefill_query_info in enumerate(prefill_querys_info): + if prefill_query_info != None: + prefill_kv_block_len = (prefill_query_info.active_position + prefill_l[i] + page_size - 1) // page_size if prefill_query_info is not None else 0 + # print(f"block_len: {prefill_kv_block_len}, page_size: {page_size}") + self.q_indptr = torch.concat((self.q_indptr, torch.tensor([prefill_l[i] + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0) + self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([prefill_kv_block_len + self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0) + self.kv_indices = torch.concat((self.kv_indices, prefill_query_info.block_index[:prefill_kv_block_len]), dim=0) + self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i]) % page_size if (prefill_query_info.active_position + prefill_l[i]) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0) + self.kv_len = torch.concat((self.kv_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i])], device=device, dtype=torch.int32)), dim=0) + self.position_ids = torch.concat((self.position_ids, torch.arange(prefill_s[i], prefill_l[i] + prefill_s[i], device=device, dtype=torch.int32)), dim=0) + self.tokens = torch.concat((self.tokens, prefill_query_info.query_tokens[prefill_s[i]:prefill_s[i] + prefill_l[i]]), dim=0) + self.logits_start.append(prefill_l[i] - 1 if len(self.logits_start) == 0 else sum(prefill_l[:i+1])-1) + + self.temperatures = torch.concat((self.temperatures, torch.tensor([prefill_query_info.temperature], device=device, dtype=torch.float32)), dim=0) + self.top_ps = torch.concat((self.top_ps, torch.tensor([prefill_query_info.top_p], device=device, dtype=torch.float32)), dim=0) + + for decode_query_info in decode_querys_info: + decode_kv_block_len = (decode_query_info.active_position + 1 + page_size - 1) // page_size + self.q_indptr = torch.concat((self.q_indptr, torch.tensor([1 + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0) + self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([decode_kv_block_len+self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0) + self.kv_indices = torch.concat((self.kv_indices, decode_query_info.block_index[:decode_kv_block_len]), dim=0) + self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(decode_query_info.active_position + 1) % page_size if (decode_query_info.active_position + 1) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0) + self.kv_len = torch.concat((self.kv_len, torch.tensor([(decode_query_info.active_position + 1)], device=device, dtype=torch.int32)), dim=0) + self.position_ids = torch.concat((self.position_ids, torch.arange(decode_query_info.active_position, decode_query_info.active_position + 1, device=device, dtype=torch.int32)), dim=0) + if decode_query_info.active_position > 0: + self.tokens = torch.concat((self.tokens, decode_query_info.query_tokens[decode_query_info.active_position:decode_query_info.active_position+1]), dim=0) + else: + self.tokens = torch.concat((self.tokens, torch.tensor([0], device=device, dtype=torch.int32)), dim=0) + self.logits_start.append(0 if len(self.logits_start) == 0 else self.logits_start[-1]+1) + + self.temperatures = torch.concat((self.temperatures, torch.tensor([decode_query_info.temperature], device=device, dtype=torch.float32)), dim=0) + self.top_ps = torch.concat((self.top_ps, torch.tensor([decode_query_info.top_p], device=device, dtype=torch.float32)), dim=0) + + self.q_indptr = self.q_indptr.contiguous() + self.kv_indptr = self.kv_indptr.contiguous() + self.kv_indices = self.kv_indices.contiguous() + self.kv_len = self.kv_len.contiguous() + self.kv_last_page_len = self.kv_last_page_len.contiguous() + self.position_ids = self.position_ids.contiguous() + self.tokens = self.tokens.contiguous() + + self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32) + + def fill(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, device = torch.device('cuda'), page_size = 256): + batch_decode = len(decode_querys_info) + batch_prefill = len(prefill_querys_info) + + self.q_indptr = torch.tensor([0], device=device, dtype=torch.int32) + self.kv_indptr = torch.tensor([0], device=device, dtype=torch.int32) + self.kv_indices = torch.tensor([], device=device, dtype=torch.int32) + self.kv_len = torch.tensor([], device=device, dtype=torch.int32) + self.kv_last_page_len = torch.tensor([], device=device, dtype=torch.int32) + new_position_ids = torch.tensor([], device=device, dtype=torch.int32) + new_tokens = torch.tensor([], device=device, dtype=torch.int32) + + self.temperatures = torch.tensor([], device=device, dtype=torch.float32) + self.top_ps = torch.tensor([], device=device, dtype=torch.float32) + + self.logits_start = [] + self.decode_batch = batch_decode + self.num_tokens = batch_decode + sum(prefill_l) + self.batch_size = batch_decode + batch_prefill + + for i, prefill_query_info in enumerate(prefill_querys_info): + prefill_kv_block_len = (prefill_query_info.active_position + prefill_l[i] + page_size - 1) // page_size if prefill_query_info is not None else 0 + # print(f"block_len: {prefill_kv_block_len}, page_size: {page_size}") + self.q_indptr = torch.concat((self.q_indptr, torch.tensor([prefill_l[i] + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0) + self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([prefill_kv_block_len + self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0) + self.kv_indices = torch.concat((self.kv_indices, prefill_query_info.block_index[:prefill_kv_block_len]), dim=0) + self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i]) % page_size if (prefill_query_info.active_position + prefill_l[i]) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0) + self.kv_len = torch.concat((self.kv_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i])], device=device, dtype=torch.int32)), dim=0) + new_position_ids = torch.concat((new_position_ids, torch.arange(prefill_s[i], prefill_l[i] + prefill_s[i], device=device, dtype=torch.int32)), dim=0) + new_tokens = torch.concat((new_tokens, prefill_query_info.query_tokens[prefill_s[i]:prefill_s[i] + prefill_l[i]]), dim=0) + self.logits_start.append(prefill_l[i] - 1 if len(self.logits_start) == 0 else sum(prefill_l[:i+1])-1) + + self.temperatures = torch.concat((self.temperatures, torch.tensor([prefill_query_info.temperature], device=device, dtype=torch.float32)), dim=0) + self.top_ps = torch.concat((self.top_ps, torch.tensor([prefill_query_info.top_p], device=device, dtype=torch.float32)), dim=0) + + + for decode_query_info in decode_querys_info: + decode_kv_block_len = (decode_query_info.active_position + 1 + page_size - 1) // page_size + self.q_indptr = torch.concat((self.q_indptr, torch.tensor([1 + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0) + self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([decode_kv_block_len+self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0) + self.kv_indices = torch.concat((self.kv_indices, decode_query_info.block_index[:decode_kv_block_len]), dim=0) + self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(decode_query_info.active_position + 1) % page_size if (decode_query_info.active_position + 1) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0) + self.kv_len = torch.concat((self.kv_len, torch.tensor([(decode_query_info.active_position + 1)], device=device, dtype=torch.int32)), dim=0) + new_position_ids = torch.concat((new_position_ids, torch.arange(decode_query_info.active_position, decode_query_info.active_position + 1, device=device, dtype=torch.int32)), dim=0) + if decode_query_info.active_position > 0: + new_tokens = torch.concat((new_tokens, decode_query_info.query_tokens[decode_query_info.active_position:decode_query_info.active_position+1]), dim=0) + else: + new_tokens = torch.concat((new_tokens, torch.tensor([0], device=device, dtype=torch.int32)), dim=0) + self.logits_start.append(0 if len(self.logits_start) == 0 else self.logits_start[-1]+1) + + self.temperatures = torch.concat((self.temperatures, torch.tensor([decode_query_info.temperature], device=device, dtype=torch.float32)), dim=0) + self.top_ps = torch.concat((self.top_ps, torch.tensor([decode_query_info.top_p], device=device, dtype=torch.float32)), dim=0) + + + self.q_indptr = self.q_indptr.contiguous() + self.kv_indptr = self.kv_indptr.contiguous() + self.kv_indices = self.kv_indices.contiguous() + self.kv_len = self.kv_len.contiguous() + self.kv_last_page_len = self.kv_last_page_len.contiguous() + + self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32) + + # copy new_position_ids and new_tokens to self.position_ids and self.tokens + # print("new_position_ids: ", new_position_ids) + # self.print() + self.position_ids[:new_position_ids.size(0)].copy_(new_position_ids) + self.position_ids[new_position_ids.size(0):].zero_() + self.tokens[:new_tokens.size(0)].copy_(new_tokens) + + + forward_minibatchs: list[ForwardMiniBatch] + batch_size: int + minibatch: ForwardMiniBatch + + + + def __init__(self, batch : sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None, device=None, tokens: torch.Tensor = None): + + if batch is None: + return + + + prefill_minibatches = batch.prefill_mini_batches + decode_mini_batches = [item for sublist in batch.decode_mini_batches for item in sublist] + prefill_querys_info = [] + prefill_s = [] + prefill_l = [] + decode_querys_info = [] + self.batch_size = 1 + for (id, s, l) in prefill_minibatches: + prefill_querys_info.append(query_manager.query_map[id]) + prefill_s.append(s) + prefill_l.append(l) + for decode_batch_idx in decode_mini_batches: + if query_manager.query_map[decode_batch_idx].decode_start_time is None: + query_manager.query_map[decode_batch_idx].decode_start_time =time.time() + decode_querys_info.append(query_manager.query_map[decode_batch_idx]) + + + minibatch = ForwardBatchInput.ForwardMiniBatch(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device = query_manager.device, page_size = query_manager.page_size) + + self.minibatch = minibatch + + @classmethod + def gen_max_forward_batch( + cls, + device=None, + tokens: torch.Tensor = None, + num_mini_batches: int = 1, + max_seq_length: int = 1024, # TODO: add to yaml + prefill_query_length: int = (Config().chunk_size - Config().max_decode_batch_size) // Config().max_prefill_batch_size, # TODO: use config + prefill_active_length: int = (Config().chunk_size - Config().max_decode_batch_size) // Config().max_prefill_batch_size, + gen_prefill: bool = True, + decode_batch_size: int = Config().max_decode_batch_size, + decode_active_position: torch.Tensor = None, + page_size = 256, + cuda_lens = 1 + ): + instance = cls() + + instance.batch_size = num_mini_batches + page_size = page_size + + prefill_query_info = [] + offset = 0 + if gen_prefill and prefill_query_length != 0: + for i in range(Config().max_prefill_batch_size): + prefill_query_info.append(QueryInfo(i, prefill_query_length, max_seq_length, page_size, device, offset=offset)) + offset += max_seq_length // page_size + + decode_querys_info = [] + for i in range(min(decode_batch_size, cuda_lens)): + query_info = QueryInfo(i+Config().max_prefill_batch_size, prefill_query_length, max_seq_length, page_size, device, is_prefill=False, offset=offset) + offset += max_seq_length // page_size + if tokens is not None: + query_info.query_tokens[prefill_active_length:prefill_active_length + 1].copy_(tokens) + if decode_active_position is None: + query_info.active_position = prefill_active_length + else: + query_info.active_position = decode_active_position[i] + + decode_querys_info.append(query_info) + + if prefill_query_length*Config().max_prefill_batch_size + len(decode_querys_info) < cuda_lens: + decode_querys_info.append(query_info) + + instance.minibatch = ForwardBatchInput.ForwardMiniBatch(prefill_query_info, decode_querys_info, [0, 0], [prefill_active_length for _ in range(Config().max_prefill_batch_size)], device, page_size) + + return instance + + def fill(self, batch : sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None, page_size = 256): + if batch is None: + return + prefill_minibatches = batch.prefill_mini_batches + decode_mini_batches = [item for sublist in batch.decode_mini_batches for item in sublist] + + prefill_querys_info = [] + prefill_s = [] + prefill_l = [] + decode_querys_info = [] + self.batch_size = 1 + for (id, s, l) in prefill_minibatches: + prefill_querys_info.append(query_manager.query_map[id]) + prefill_s.append(s) + prefill_l.append(l) + for decode_batch_idx in decode_mini_batches: + if query_manager.query_map[decode_batch_idx].decode_start_time is None: + query_manager.query_map[decode_batch_idx].decode_start_time =time.time() + decode_querys_info.append(query_manager.query_map[decode_batch_idx]) + + self.minibatch.fill(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device=query_manager.device, page_size=page_size) + + + +class ForwardBatchOutput: + logits: list[torch.Tensor] + num_batchs: int + batch_sizes: list[int] + generated_tokens_num: list[int] + lm_start: list[int] + + temperatures: list[torch.Tensor] + top_ps: list[torch.Tensor] + + def __init__(self): + self.logits = [] + self.batch_sizes = [] + self.generated_tokens_num = [] + self.top_ps = [] + self.temperatures = [] + pass \ No newline at end of file diff --git a/ktransformers/server/balance_serve/inference/model_runner.py b/ktransformers/server/balance_serve/inference/model_runner.py new file mode 100644 index 0000000..386307b --- /dev/null +++ b/ktransformers/server/balance_serve/inference/model_runner.py @@ -0,0 +1,306 @@ +""" +Date: 2024-11-07 07:02:20 +LastEditors: djw +LastEditTime: 2024-12-10 08:48:32 +""" + +import torch +from torch import nn +import queue +import signal +import queue +from typing import AsyncIterable +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +from contextlib import asynccontextmanager +from pydantic import BaseModel, Field +import asyncio +import multiprocessing +import time +import torch.multiprocessing as mp +import random +import torch.distributed as dist +import zmq +import tempfile +from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput + +from ktransformers.server.config.config import Config +from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM +from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM +from ktransformers.server.balance_serve.inference.query_manager import QueryManager +from ktransformers.server.balance_serve.settings import sched_ext + + + +def pad_num_tokens(num_tokens): + return (num_tokens + 63) // 64 * 64 + +def deduplicate_and_sort(lst): + return sorted(set(lst)) +class ModelRunner: + """A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile.""" + + model: KDeepseekV3ForCausalLM + input: ForwardBatchInput | list[ForwardBatchInput] + output: ForwardBatchOutput + + def __init__(self, model = None, device = None, use_cuda_graph = False, max_decode_batch_size = 1, max_chunk_size = 4096, num_mini_batches: int = 1, page_size = 256): + + self.stream = torch.cuda.Stream(device=device) + # 先注释掉 + self.model = model # Compile and move model to the specified device + self.device = device + self.input = None + self.features_buf = None + self.output = None + self.graph_memory_pool = None + self.cuda_graphs = deduplicate_and_sort([1, 2, 3, Config().max_batch_size, 64, Config().chunk_size]) + self.use_cuda_graph = use_cuda_graph + self.model_time = 0 + self.page_size = page_size + # GPU timing for model execution + self.start_model_event = torch.cuda.Event(enable_timing=True) + self.end_model_event = torch.cuda.Event(enable_timing=True) + if isinstance(self.cuda_graphs, list): + self.graphs = [torch.cuda.CUDAGraph() for _ in range(len(self.cuda_graphs))] + self.page_idx_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))] + self.page_offset_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))] + else: + self.graphs = torch.cuda.CUDAGraph() + self.page_idx_buf = torch.zeros([self.cuda_graphs], dtype=torch.int32, device = self.device) + self.page_offset_buf = torch.zeros([self.cuda_graphs], dtype=torch.int32, device = self.device) + self.num_mini_batches = num_mini_batches + + self.max_chunk_size = max_chunk_size + + self.bsz_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device) + self.num_tokens_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device) + def warmup(self): + + def capture_graphs(cuda_graph_idx=-1): + if cuda_graph_idx != -1: + with torch.cuda.graph(self.graphs[cuda_graph_idx], pool=self.graph_memory_pool, stream=self.stream): + self.outputs_buf[cuda_graph_idx] = self.model(self.input[cuda_graph_idx], self.features_buf[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[cuda_graph_idx], self.page_offset_buf[cuda_graph_idx], cuda_graph_idx=cuda_graph_idx) + self.graph_memory_pool = self.graphs[cuda_graph_idx].pool() + else: + with torch.cuda.graph(self.graphs, pool=self.graph_memory_pool, stream=self.stream): + self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf) + self.graph_memory_pool = self.graphs.pool() + + if isinstance(self.cuda_graphs, list): + self.input = [] + self.features_buf = [] + self.outputs_buf = [] + self.bsz_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device) + self.num_tokens_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device) + for i in range(len(self.cuda_graphs)): + prefill_query_length = (self.cuda_graphs[i] - Config().max_decode_batch_size) // Config().max_prefill_batch_size if self.cuda_graphs[i] > Config().max_decode_batch_size else 0 #@TODO only supprot 2 prefill batch + self.input.append(ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches, prefill_query_length=prefill_query_length, prefill_active_length=prefill_query_length, page_size=self.page_size, cuda_lens = self.cuda_graphs[i])) + + self.features_buf.append(self.model.batch_embeddings(self.input[i])) + batch_size = self.input[i].minibatch.q_indptr.size(0)-1 + num_tokens = self.features_buf[i][0].size(0) + print("capturing cuda graph", batch_size, num_tokens) + self.bsz_tensor_buf[0] = batch_size + self.num_tokens_tensor_buf[0] = num_tokens + + self.model.flash_infer_attn_plan(self.input[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf, + num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, + head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True, + sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) + + page_idx, page_offset = self.model.cache.get_page_table(self.input[i].minibatch.position_ids, self.input[i].minibatch.q_indptr, self.input[i].minibatch.kv_indptr, self.input[i].minibatch.kv_indices, self.num_tokens_tensor_buf) + + self.page_idx_buf[i][:num_tokens].copy_(page_idx[:num_tokens]) + self.page_offset_buf[i][:num_tokens].copy_(page_offset[:num_tokens]) + self.page_idx_buf[i][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size -1) + + self.outputs_buf.append(None) + + torch.cuda.synchronize() + for warm_up_iters in range(11): + with torch.cuda.stream(self.stream): + self.outputs_buf[i] = self.model(self.input[i], self.features_buf[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[i], self.page_offset_buf[i]) + torch.cuda.synchronize() + + capture_graphs(i) + + with torch.cuda.stream(self.stream): + self.graphs[i].replay() + + self.sync(calc_time=False) + print(f"cuda_graph: {i+1}/{len(self.cuda_graphs)}, warmup finished.") + else: + self.input = ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches) + + self.features_buf = self.model.batch_embeddings(self.input) + batch_size = self.input.minibatch.q_indptr.size(0)-1 + num_tokens = self.features_buf[0].size(0) + + + self.bsz_tensor_buf = torch.tensor([batch_size], dtype=torch.int32, device=self.device) + self.num_tokens_tensor_buf = torch.tensor([num_tokens], dtype=torch.int32, device=self.device) + + + self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf, + num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, + head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True, + sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) + + page_idx, page_offset = self.model.cache.get_page_table(self.input.minibatch.position_ids, self.input.minibatch.q_indptr, self.input.minibatch.kv_indptr, self.input.minibatch.kv_indices, self.num_tokens_tensor_buf) + self.page_idx_buf[:num_tokens].copy_(page_idx[:num_tokens]) + self.page_offset_buf[:num_tokens].copy_(page_offset[:num_tokens]) + self.page_idx_buf[num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1) + + + torch.cuda.synchronize() + for warm_up_iters in range(11): + with torch.cuda.stream(self.stream): + self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf) + torch.cuda.synchronize() + + def capture_graphs(): + with torch.cuda.graph(self.graphs, stream=self.stream): + self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf) + # self.graph_memory_pool = self.graphs.pool() + + + capture_graphs() + + with torch.cuda.stream(self.stream): + self.graphs.replay() + + self.sync(calc_time=False) + print("warmup finished.") + + def run(self, batch: sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None): + with torch.cuda.stream(self.stream): + + batch_size = len(batch.prefill_mini_batches) # TODO: calc this + num_tokens = 0 + for i in range(len(batch.decode_mini_batches)): + batch_size += len(batch.decode_mini_batches[i]) + num_tokens += len(batch.decode_mini_batches[i]) + print(f'decode_batch_i: {len(batch.decode_mini_batches[i])},') + + for i in range(len(batch.prefill_mini_batches)): + num_tokens += batch.prefill_mini_batches[i][2] + print(f'prefill_batch_i: {batch.prefill_mini_batches[i][2]},') + + + + if isinstance(self.cuda_graphs, list): + # cuda graph idx equal to min idx i in self.cuda_graphs, that self.cuda_graphs[i] > num_tokens + cuda_graph_idx = next((i for i, token in enumerate(self.cuda_graphs) if token >= num_tokens), len(self.cuda_graphs)) + if cuda_graph_idx == len(self.cuda_graphs): + assert False, "num_tokens is too large" + else: + cuda_graph_idx = -1 + + if self.use_cuda_graph: + if cuda_graph_idx != -1: + self.input[cuda_graph_idx].fill(batch, query_manager, self.page_size) + else: + self.input.fill(batch, query_manager, self.page_size) + else: + self.input = ForwardBatchInput(batch=batch, query_manager=query_manager, device=self.device) + + + if cuda_graph_idx != -1 and self.use_cuda_graph: + self.features = self.model.batch_embeddings(self.input[cuda_graph_idx], device=self.device) + else: + self.features = self.model.batch_embeddings(self.input, device=self.device) + + + self.bsz_tensor_buf.copy_(batch_size) + self.num_tokens_tensor_buf.copy_(torch.tensor([num_tokens], dtype=torch.int32, device=self.device)) + + if self.use_cuda_graph: + if cuda_graph_idx != -1: + self.features_buf[cuda_graph_idx][0].copy_(self.features[0], non_blocking=True) + else: + self.features_buf[0].copy_(self.features[0], non_blocking=True) + """ + if num_tokens_0 > 64: + padded_num_tokens_0 = pad_num_tokens(num_tokens_0) + self.features_buf[0][num_tokens_0:padded_num_tokens_0] = 0 + """ + #self.input.forward_minibatchs[0].print() + # print([[hash(k[i].float().cpu().numpy().tobytes()) for i in self.input.forward_minibatchs[0].kv_indices] for k in self.model.cache.k_caches]) + # print(f"overlap: {overlap}, is_compute_bound: {is_compute_bound}") + + # self.model.flash_infer_attn_plan(self.input, self.bsz_tensors, self.num_tokens_tensors) + + """ + if self.use_cuda_graph: + print("before replay features_buf", self.features_buf[0]) + print("features_buf addr", self.features_buf[0].data_ptr()) + else: + print("before run features", self.features[0]) + """ + if cuda_graph_idx != -1 and self.use_cuda_graph: + self.model.flash_infer_attn_plan(self.input[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf, + num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, + head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True, + sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) + self.start_model_event.record(self.stream) + page_idx, page_offset = self.model.cache.get_page_table(self.input[cuda_graph_idx].minibatch.position_ids, self.input[cuda_graph_idx].minibatch.q_indptr, self.input[cuda_graph_idx].minibatch.kv_indptr, self.input[cuda_graph_idx].minibatch.kv_indices, self.num_tokens_tensor_buf) + if self.use_cuda_graph: + self.page_idx_buf[cuda_graph_idx][:num_tokens].copy_(page_idx[:num_tokens]) + self.page_offset_buf[cuda_graph_idx][:num_tokens].copy_(page_offset[:num_tokens]) + self.page_idx_buf[cuda_graph_idx][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1) + self.replay(cuda_graph_idx) + self.output = ForwardBatchOutput() + + self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps) + self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures) + + self.output.logits.append(self.outputs_buf[cuda_graph_idx].logits[0][self.input[cuda_graph_idx].minibatch.logits_start].clone()) + else: + self.output = self.model(self.input[cuda_graph_idx], self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset) + self.output.logits[0] = self.output.logits[0][self.input[cuda_graph_idx].minibatch.logits_start] + self.end_model_event.record(self.stream) + else: + self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf, + num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, + head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True, + sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) + self.start_model_event.record(self.stream) + page_idx, page_offset = self.model.cache.get_page_table(self.input.minibatch.position_ids, self.input.minibatch.q_indptr, self.input.minibatch.kv_indptr, self.input.minibatch.kv_indices, self.num_tokens_tensor_buf) + if self.use_cuda_graph: + self.page_idx_buf[:num_tokens].copy_(page_idx[:num_tokens]) + self.page_offset_buf[:num_tokens].copy_(page_offset[:num_tokens]) + self.page_idx_buf[num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1) + self.replay(cuda_graph_idx) + self.output = ForwardBatchOutput() + + self.output.top_ps.append(self.input.minibatch.top_ps) + self.output.temperatures.append(self.input.minibatch.temperatures) + + self.output.logits.append(self.outputs_buf.logits[0][self.input.minibatch.logits_start].clone()) + else: + self.output = self.model(self.input, self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset) + self.output.logits[0] = self.output.logits[0][self.input.minibatch.logits_start] + self.output.top_ps.append(self.input.minibatch.top_ps) + self.output.temperatures.append(self.input.minibatch.temperatures) + + self.end_model_event.record(self.stream) + + if not self.use_cuda_graph: + self.output.num_batchs = self.input.batch_size + else: + self.output.num_batchs = self.input[cuda_graph_idx].batch_size + + + def replay(self, cuda_graph_idx=-1): + with torch.cuda.stream(self.stream): + if cuda_graph_idx != -1: + self.graphs[cuda_graph_idx].replay() + else: + self.graphs.replay() + + + def sync(self, calc_time = True): + self.stream.synchronize() + if calc_time: + self.model_time = self.start_model_event.elapsed_time(self.end_model_event) # In ms \ No newline at end of file diff --git a/ktransformers/server/balance_serve/inference/query_manager.py b/ktransformers/server/balance_serve/inference/query_manager.py new file mode 100644 index 0000000..e77ffc8 --- /dev/null +++ b/ktransformers/server/balance_serve/inference/query_manager.py @@ -0,0 +1,158 @@ +''' +Date: 2024-11-14 12:23:45 +LastEditors: djw +LastEditTime: 2024-11-20 04:06:23 +''' +import torch +from ktransformers.server.balance_serve.settings import sched_ext +import random +import time + +class QueryInfo: + id: int + active_position: int + query_length: int + is_prefill: int + block_index: torch.Tensor + query_tokens: torch.Tensor + stop_criteria: list[torch.Tensor] + + temperature: float + top_p: float + + max_length: int + + def __init__(self, id, query_length: int, max_length: int, page_size: int, device: torch.device, is_prefill: bool = True, offset: int = 0, active_position: int = 0, temperature: float = 0.01, top_p: float = 1.0): + self.id = id + self.is_prefill = is_prefill + self.active_position = active_position + self.max_length = max_length - 1 + self.query_tokens = torch.zeros((max_length,), dtype=torch.int, device = device) + self.stop_criteria = [] + self.block_index = torch.arange(offset, offset + (max_length + active_position + page_size - 1) // page_size, dtype=torch.int, device = device) + self.query_length = query_length + self.enqueue_time = time.time() + self.decode_start_time = None + self.speculative_token = {} # {position: (accept, token)} + + self.temperature = temperature + self.top_p = top_p + + def check_stop(self): + if self.active_position >= self.max_length - 2: + return True + + # 遍历每个停止条件 + for stop_tensor in self.stop_criteria: + stop_len = len(stop_tensor) + + # 如果停止条件比 query_tokens 长,跳过 + if stop_len >= self.active_position: + continue + + #print(f"stop_tensor: {stop_tensor}, stop_len: {stop_len}, active_position: {self.active_position}, query_token: {self.query_tokens[self.active_position - stop_len - 1:self.active_position - 1]}") + + if (torch.equal(self.query_tokens[self.active_position - stop_len - 1:self.active_position - 1], stop_tensor) and self.active_position) or self.max_length <= self.active_position + 3: + self.life_time = time.time() - self.enqueue_time + self.decode_duration_time = time.time() - self.decode_start_time + self.decode_tps = (self.active_position - self.query_length) / self.decode_duration_time + print(f"prefill length: {self.query_length}, prefill time: {self.prefill_duration_time}, prefill tps {self.prefill_tps}, decode length: {self.active_position - self.query_length}, decode time: {self.decode_duration_time}, decode tps {self.decode_tps}") + return True # 找到匹配的停止条件 + + + return False # 没有找到任何停止条件 + + + def print(self): + print(f"active_position: {self.active_position}, query_length: {self.query_length}, is_prefill: {self.is_prefill}") + print(f"block_index_shape: {self.block_index.shape}, query_tokens_shape: {self.query_tokens.shape}") + + +class QueryManager: + + max_length: int = 65536 + page_size: int = 256 + device: torch.device + query_map : dict[int, QueryInfo] + + def __init__(self, max_length = 65536, page_size = 256, device = torch.device('cuda')): + self.max_length = max_length + self.page_size = page_size + self.device = device + self.query_map = {} + + def add_query(self, batch: sched_ext.BatchQueryTodo): + + for i in range(len(batch.query_ids)): + id = batch.query_ids[i] + if id not in self.query_map: + print(f"add query id: {id}, batch.query_lengths: {batch.query_lengths[i]}, batch_query_tokens: {batch.query_tokens[i].shape}, batch.block_indexes: {batch.block_indexes[i]}") + assert batch.query_tokens[i].size(0) < self.max_length, "query max length in batchquerytodo exceeds internal max_length" + query_info = QueryInfo(id=id, query_length=batch.query_lengths[i], max_length=batch.query_tokens[i].size(0) + 1, page_size=self.page_size, device=self.device, temperature=batch.sample_options[i].temperature, top_p=batch.sample_options[i].top_p) + query_info.query_tokens[:query_info.query_length].copy_(batch.query_tokens[i][:query_info.query_length].to(self.device)) + + for stop_token_list in batch.stop_criteria[i]: + query_info.stop_criteria.append(torch.tensor(stop_token_list, dtype=torch.int, device = self.device)) + + block_num = batch.block_indexes[i].size(0) + query_info.block_index[:block_num].copy_(batch.block_indexes[i].to(self.device)) + + self.query_map[id] = query_info + + prefill_mini_batches = batch.prefill_mini_batches + for (prefill_id, s, l) in prefill_mini_batches: + if prefill_id == id: + self.query_map[prefill_id].active_position = s + + + def update(self, batch: sched_ext.BatchQueryTodo) -> list[sched_ext.QueryUpdate]: + query_updates = [] + + prefill_mini_batches = batch.prefill_mini_batches + + for (id, s, l) in prefill_mini_batches: + + if id not in self.query_map: + assert False, f"query id {id} not found in query_map" + + # update query_info + query_info = self.query_map[id] + query_info.active_position += l + + if query_info.active_position >= query_info.query_length and query_info.is_prefill: + query_info.is_prefill = False + query_info.prefill_duration_time = time.time() - query_info.enqueue_time + query_info.prefill_tps = query_info.query_length / query_info.prefill_duration_time + + + # generate schedule query_update + query_update = sched_ext.QueryUpdate() + query_update.id = id + query_update.ok = True + query_update.is_prefill = query_info.is_prefill + query_update.active_position = query_info.active_position + # if(not query_info.is_prefill): + query_updates.append(query_update) + + + decode_mini_batches = batch.decode_mini_batches + + for ids in decode_mini_batches: + for id in ids: + if id not in self.query_map: + assert False, f"query id {id} not found in query_map" + + query_info = self.query_map[id] + query_info.active_position += 1 + + query_update = sched_ext.QueryUpdate() + query_update.id = id + query_update.ok = True + query_update.is_prefill = query_info.is_prefill + + query_update.decode_done = query_info.check_stop() + + query_update.active_position = query_info.active_position + query_updates.append(query_update) + + return query_updates \ No newline at end of file diff --git a/ktransformers/server/balance_serve/inference/sampling/penaltylib/__init__.py b/ktransformers/server/balance_serve/inference/sampling/penaltylib/__init__.py new file mode 100644 index 0000000..43fff0f --- /dev/null +++ b/ktransformers/server/balance_serve/inference/sampling/penaltylib/__init__.py @@ -0,0 +1,13 @@ +from .orchestrator import BatchedPenalizerOrchestrator +from .penalizers.frequency_penalty import BatchedFrequencyPenalizer +from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer +from .penalizers.presence_penalty import BatchedPresencePenalizer +from .penalizers.repetition_penalty import BatchedRepetitionPenalizer + +__all__ = [ + "BatchedFrequencyPenalizer", + "BatchedMinNewTokensPenalizer", + "BatchedPresencePenalizer", + "BatchedRepetitionPenalizer", + "BatchedPenalizerOrchestrator", +] diff --git a/ktransformers/server/balance_serve/inference/sampling/penaltylib/orchestrator.py b/ktransformers/server/balance_serve/inference/sampling/penaltylib/orchestrator.py new file mode 100644 index 0000000..c35e8ed --- /dev/null +++ b/ktransformers/server/balance_serve/inference/sampling/penaltylib/orchestrator.py @@ -0,0 +1,376 @@ +import abc +import dataclasses +import typing + +import torch + + +@dataclasses.dataclass +class _ReqLike: + origin_input_ids: typing.Union[torch.Tensor, typing.List[int]] + + +@dataclasses.dataclass +class _BatchLike: + reqs: typing.List[_ReqLike] + + def batch_size(self): + return len(self.reqs) + + +class BatchedPenalizerOrchestrator: + batch: _BatchLike + device: str + vocab_size: int + penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"] + + def __init__( + self, + vocab_size: int, + batch: _BatchLike, + device: str, + Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]], + ): + self.vocab_size = vocab_size + self.batch = batch + self.device = device + + self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers} + + is_required = False + for penalizer in self.penalizers.values(): + pen_is_required = penalizer.prepare_if_required() + is_required |= pen_is_required + self.is_required = is_required + + if self.is_required: + self.cumulate_input_tokens( + input_ids=[req.origin_input_ids for req in self.reqs()] + ) + + def reqs(self): + return self.batch.reqs + + def batch_size(self): + return self.batch.batch_size() + + def cumulate_input_tokens( + self, + input_ids: typing.Union[ + typing.List[torch.Tensor], typing.List[typing.List[int]] + ], + ): + """ + Feed the input tokens to the penalizers. + + Args: + input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens. + """ + token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids) + + for penalizer in self.penalizers.values(): + penalizer.cumulate_input_tokens(input_ids=token_ids) + + def cumulate_output_tokens( + self, + output_ids: typing.Union[ + typing.List[torch.Tensor], typing.List[typing.List[int]] + ], + ): + """ + Feed the output tokens to the penalizers. + + Args: + output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens. + """ + if not self.is_required: + return + + token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids) + + for penalizer in self.penalizers.values(): + penalizer.cumulate_output_tokens(output_ids=token_ids) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + """ + Apply the penalizers to the logits. + Note that it may apply the penalizers in-place. + + Args: + logits (torch.Tensor): The logits to apply the penalizers to. + + Returns: + torch.Tensor: The logits after applying the penalizers. + """ + if not self.is_required: + return + + for penalizer in self.penalizers.values(): + logits = penalizer.apply(logits) + + return logits + + def filter( + self, + indices_to_keep: typing.List[int], + indices_tensor_to_keep: torch.Tensor = None, + ): + """ + Filter the penalizers based on the indices to keep in the batch. + + Args: + indices_to_keep (typing.List[int]): List of indices to keep in the batch. + indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor. + """ + if not self.is_required: + return + + empty_indices = len(indices_to_keep) == 0 + + is_required = False + for penalizer in self.penalizers.values(): + tmp_is_required = penalizer.is_required() + is_required = is_required or tmp_is_required + if not tmp_is_required or empty_indices: + penalizer.teardown() + else: + # create tensor index only when it's needed + if indices_tensor_to_keep is None: + indices_tensor_to_keep = torch.tensor( + indices_to_keep, dtype=torch.int32, device=self.device + ) + + penalizer.filter( + indices_to_keep=indices_to_keep, + indices_tensor_to_keep=indices_tensor_to_keep, + ) + self.is_required = is_required + + def merge(self, their: "BatchedPenalizerOrchestrator"): + """ + Merge the penalizers of another orchestrator into this one. + + Note that this function **must** be called _before_ self.batch.reqs is updated (filtered). + Each unprepared penalizers would have to be prepared (creating tensors, etc.) first before merging. + This step requires the original batch.reqs, before it gets merged with other batch.reqs. + + Args: + their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one. + """ + if not self.is_required and not their.is_required: + return + + self.is_required |= their.is_required + for Penalizer, their_penalizer in their.penalizers.items(): + if Penalizer not in self.penalizers: + raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers") + + self.penalizers[Penalizer].merge(their_penalizer) + + +class _TokenIDs: + """ + A class that wraps token IDs to provide additional utility functions to penalizers. + + Attributes: + orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to. + token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs. + cached_counts (torch.Tensor): The cached occurrence count tensor. + """ + + orchestrator: BatchedPenalizerOrchestrator + token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]] + cached_counts: torch.Tensor = None + + def __init__( + self, + orchestrator: BatchedPenalizerOrchestrator, + token_ids: typing.Union[ + typing.List[torch.Tensor], typing.List[typing.List[int]] + ], + ): + self.orchestrator = orchestrator + + if not isinstance(token_ids[0], torch.Tensor): + token_ids = [ + torch.tensor( + data=ids, dtype=torch.int64, device=self.orchestrator.device + ) + for ids in token_ids + ] + + self.token_ids = token_ids + + def occurrence_count(self) -> torch.Tensor: + """ + Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in the batch. + + Returns: + torch.Tensor: The occurrence count tensor. + """ + if self.cached_counts is not None: + return self.cached_counts + + token_ids = self.token_ids + + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.unsqueeze(1) + + # needs to be long to be used as index in scatter_add + if token_ids.dtype != torch.int64: + token_ids = token_ids.to(torch.int64) + + padded_token_ids = torch.nn.utils.rnn.pad_sequence( + sequences=token_ids, + batch_first=True, + padding_value=self.orchestrator.vocab_size, + ) + + self.cached_counts = torch.zeros( + size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1), + dtype=torch.int64, + device=self.orchestrator.device, + ).scatter_add_( + dim=1, + index=padded_token_ids, + src=torch.ones_like(padded_token_ids), + )[ + :, : self.orchestrator.vocab_size + ] + + return self.cached_counts + + +class _BatchedPenalizer(abc.ABC): + """ + An abstract class for a batched penalizer. + """ + + orchestrator: BatchedPenalizerOrchestrator + _is_prepared: bool = False + + def __init__(self, orchestrator: BatchedPenalizerOrchestrator): + self.orchestrator = orchestrator + + def is_prepared(self) -> bool: + return self._is_prepared + + def is_required(self) -> bool: + return self._is_required() + + def prepare(self): + if not self.is_prepared(): + self._prepare() + self._is_prepared = True + + def prepare_if_required(self): + if self.is_required(): + self.prepare() + return True + else: + return False + + def teardown(self): + if self.is_prepared(): + self._teardown() + self._is_prepared = False + + def cumulate_input_tokens(self, input_ids: _TokenIDs): + if not self.is_prepared(): + return + + self._cumulate_input_tokens(input_ids=input_ids) + + def cumulate_output_tokens(self, output_ids: _TokenIDs): + if not self.is_prepared(): + return + + self._cumulate_output_tokens(output_ids=output_ids) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if not self.is_prepared(): + return logits + + return self._apply(logits=logits) + + def filter( + self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor + ): + if not self.is_prepared(): + return + + self._filter( + indices_to_keep=indices_to_keep, + indices_tensor_to_keep=indices_tensor_to_keep, + ) + + def merge(self, their: "_BatchedPenalizer"): + if not self.is_prepared() and not their.is_prepared(): + return + + self.prepare() + their.prepare() + self._merge(their) + + @abc.abstractmethod + def _is_required(self) -> bool: + """ + Check if the penalizer is required to be prepared. + """ + pass + + @abc.abstractmethod + def _prepare(self): + """ + Prepare the penalizer. + Usually, this is where the penalizer initializes its tensors. + """ + pass + + @abc.abstractmethod + def _teardown(self): + """ + Tear down the penalizer. + Usually, this is where the penalizer frees its tensors. + """ + pass + + @abc.abstractmethod + def _cumulate_input_tokens(self, input_ids: _TokenIDs): + """ + Cumulate the input tokens. + Orchestrator will call this function to feed the input tokens to the penalizer. + """ + pass + + @abc.abstractmethod + def _cumulate_output_tokens(self, output_ids: _TokenIDs): + """ + Cumulate the output tokens. + Orchestrator will call this function to feed the output tokens to the penalizer. + """ + pass + + @abc.abstractmethod + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + """ + Apply the penalizer to the logits. + Penalizers can modify the logits in-place if needed. + """ + pass + + @abc.abstractmethod + def _filter( + self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor + ): + """ + Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch. + """ + pass + + @abc.abstractmethod + def _merge(self, their: "_BatchedPenalizer"): + """ + Merge the penalizer with another penalizer. + """ + pass diff --git a/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/frequency_penalty.py b/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/frequency_penalty.py new file mode 100644 index 0000000..178cb54 --- /dev/null +++ b/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/frequency_penalty.py @@ -0,0 +1,80 @@ +import typing + +import torch + +from ..orchestrator import _BatchedPenalizer, _TokenIDs + + +class BatchedFrequencyPenalizer(_BatchedPenalizer): + """ + Frequency penalizer penalizes tokens based on their frequency in the output. + """ + + frequency_penalties: torch.Tensor = None + cumulated_frequency_penalties: torch.Tensor = None + + def _is_required(self) -> bool: + return any( + req.sampling_params.frequency_penalty != 0.0 + for req in self.orchestrator.reqs() + ) + + def _prepare(self): + self.cumulated_frequency_penalties = ( + torch.tensor( + data=[0.0 for _ in self.orchestrator.reqs()], + dtype=torch.float32, + device=self.orchestrator.device, + ) + .unsqueeze_(1) + .repeat(1, self.orchestrator.vocab_size) + ) + + self.frequency_penalties = ( + torch.tensor( + data=[ + req.sampling_params.frequency_penalty + for req in self.orchestrator.reqs() + ], + dtype=torch.float32, + device=self.orchestrator.device, + ) + .unsqueeze_(1) + .expand_as(self.cumulated_frequency_penalties) + ) + + def _teardown(self): + del self.frequency_penalties + del self.cumulated_frequency_penalties + + self.frequency_penalties = None + self.cumulated_frequency_penalties = None + + def _cumulate_input_tokens(self, input_ids: _TokenIDs): + pass + + def _cumulate_output_tokens(self, output_ids: _TokenIDs): + self.cumulated_frequency_penalties += ( + self.frequency_penalties * output_ids.occurrence_count() + ) + + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + logits -= self.cumulated_frequency_penalties + return logits + + def _filter( + self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor + ): + self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep] + self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[ + indices_tensor_to_keep + ] + + def _merge(self, their: "BatchedFrequencyPenalizer"): + self.frequency_penalties = torch.cat( + [self.frequency_penalties, their.frequency_penalties], dim=0 + ) + self.cumulated_frequency_penalties = torch.cat( + [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties], + dim=0, + ) diff --git a/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/min_new_tokens.py b/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/min_new_tokens.py new file mode 100644 index 0000000..cc97a2e --- /dev/null +++ b/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/min_new_tokens.py @@ -0,0 +1,108 @@ +import typing + +import torch + +from ..orchestrator import _BatchedPenalizer, _TokenIDs + + +class BatchedMinNewTokensPenalizer(_BatchedPenalizer): + """ + Min new tokens penalizer penalizes tokens based on the length of the output. + """ + + min_new_tokens: torch.Tensor = None + stop_token_penalties: torch.Tensor = None + len_output_tokens: torch.Tensor = None + + def _is_required(self) -> bool: + return any( + req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs() + ) + + def _prepare(self): + self.min_new_tokens = torch.tensor( + data=[ + req.sampling_params.min_new_tokens for req in self.orchestrator.reqs() + ], + dtype=torch.int32, + device=self.orchestrator.device, + ).unsqueeze_(1) + + padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence( + sequences=[ + torch.tensor( + data=( + list( + (req.sampling_params.stop_token_ids or set()) + | (req.tokenizer.additional_stop_token_ids or set()) + | {req.tokenizer.eos_token_id} + ) + ), + dtype=torch.int64, + device=self.orchestrator.device, + ) + for req in self.orchestrator.reqs() + ], + batch_first=True, + padding_value=self.orchestrator.vocab_size, + ) + self.stop_token_penalties = torch.zeros( + size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1), + dtype=torch.float32, + device=self.orchestrator.device, + ).scatter_add_( + dim=1, + index=padded_stop_token_ids, + src=torch.full_like( + input=padded_stop_token_ids, + dtype=torch.float32, + fill_value=float("-inf"), + device=self.orchestrator.device, + ), + )[ + :, : self.orchestrator.vocab_size + ] + + self.len_output_tokens = torch.zeros( + size=(self.orchestrator.batch_size(), 1), + dtype=torch.int32, + device=self.orchestrator.device, + ) + + def _teardown(self): + del self.min_new_tokens + del self.stop_token_penalties + del self.len_output_tokens + + self.min_new_tokens = None + self.stop_token_penalties = None + self.len_output_tokens = None + + def _cumulate_input_tokens(self, input_ids: _TokenIDs): + pass + + def _cumulate_output_tokens(self, output_ids: _TokenIDs): + self.len_output_tokens += 1 + + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits) + logits[mask] += self.stop_token_penalties[mask] + return logits + + def _filter( + self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor + ): + self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep] + self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep] + self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep] + + def _merge(self, their: "BatchedMinNewTokensPenalizer"): + self.min_new_tokens = torch.cat( + [self.min_new_tokens, their.min_new_tokens], dim=0 + ) + self.stop_token_penalties = torch.cat( + [self.stop_token_penalties, their.stop_token_penalties], dim=0 + ) + self.len_output_tokens = torch.cat( + [self.len_output_tokens, their.len_output_tokens], dim=0 + ) diff --git a/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/presence_penalty.py b/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/presence_penalty.py new file mode 100644 index 0000000..0593fdd --- /dev/null +++ b/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/presence_penalty.py @@ -0,0 +1,79 @@ +import typing + +import torch + +from ..orchestrator import _BatchedPenalizer, _TokenIDs + + +class BatchedPresencePenalizer(_BatchedPenalizer): + """ + Presence penalizer penalizes tokens based on their presence in the output. + """ + + presence_penalties: torch.Tensor = None + cumulated_presence_penalties: torch.Tensor = None + + def _is_required(self) -> bool: + return any( + req.sampling_params.presence_penalty != 0.0 + for req in self.orchestrator.reqs() + ) + + def _prepare(self): + self.cumulated_presence_penalties = ( + torch.tensor( + data=[0.0 for _ in self.orchestrator.reqs()], + dtype=torch.float32, + device=self.orchestrator.device, + ) + .unsqueeze_(1) + .repeat(1, self.orchestrator.vocab_size) + ) + + self.presence_penalties = ( + torch.tensor( + data=[ + req.sampling_params.presence_penalty + for req in self.orchestrator.reqs() + ], + dtype=torch.float32, + device=self.orchestrator.device, + ) + .unsqueeze_(1) + .expand_as(self.cumulated_presence_penalties) + ) + + def _teardown(self): + del self.presence_penalties + del self.cumulated_presence_penalties + + self.presence_penalties = None + self.cumulated_presence_penalties = None + + def _cumulate_input_tokens(self, input_ids: _TokenIDs): + pass + + def _cumulate_output_tokens(self, output_ids: _TokenIDs): + mask = output_ids.occurrence_count() > 0 + self.cumulated_presence_penalties[mask] = self.presence_penalties[mask] + + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + logits -= self.cumulated_presence_penalties + return logits + + def _filter( + self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor + ): + self.presence_penalties = self.presence_penalties[indices_tensor_to_keep] + self.cumulated_presence_penalties = self.cumulated_presence_penalties[ + indices_tensor_to_keep + ] + + def _merge(self, their: "BatchedPresencePenalizer"): + self.presence_penalties = torch.cat( + [self.presence_penalties, their.presence_penalties], dim=0 + ) + self.cumulated_presence_penalties = torch.cat( + [self.cumulated_presence_penalties, their.cumulated_presence_penalties], + dim=0, + ) diff --git a/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/repetition_penalty.py b/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/repetition_penalty.py new file mode 100644 index 0000000..ea32add --- /dev/null +++ b/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/repetition_penalty.py @@ -0,0 +1,83 @@ +import typing + +import torch + +from ..orchestrator import _BatchedPenalizer, _TokenIDs + + +class BatchedRepetitionPenalizer(_BatchedPenalizer): + """ + Repetition penalizer penalizes tokens based on their repetition in the input and output. + """ + + repetition_penalties: torch.Tensor = None + cumulated_repetition_penalties: torch.Tensor = None + + def _is_required(self) -> bool: + return any( + req.sampling_params.repetition_penalty != 1.0 + for req in self.orchestrator.reqs() + ) + + def _prepare(self): + self.cumulated_repetition_penalties = ( + torch.tensor( + data=[1.0 for _ in self.orchestrator.reqs()], + dtype=torch.float32, + device=self.orchestrator.device, + ) + .unsqueeze_(1) + .repeat(1, self.orchestrator.vocab_size) + ) + + self.repetition_penalties = ( + torch.tensor( + data=[ + req.sampling_params.repetition_penalty + for req in self.orchestrator.reqs() + ], + dtype=torch.float32, + device=self.orchestrator.device, + ) + .unsqueeze_(1) + .expand_as(self.cumulated_repetition_penalties) + ) + + def _teardown(self): + del self.repetition_penalties + del self.cumulated_repetition_penalties + + self.repetition_penalties = None + self.cumulated_repetition_penalties = None + + def _cumulate_input_tokens(self, input_ids: _TokenIDs): + mask = input_ids.occurrence_count() > 0 + self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask] + + def _cumulate_output_tokens(self, output_ids: _TokenIDs): + mask = output_ids.occurrence_count() > 0 + self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask] + + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + return torch.where( + logits > 0, + logits / self.cumulated_repetition_penalties, + logits * self.cumulated_repetition_penalties, + ) + + def _filter( + self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor + ): + self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep] + self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[ + indices_tensor_to_keep + ] + + def _merge(self, their: "BatchedRepetitionPenalizer"): + self.repetition_penalties = torch.cat( + [self.repetition_penalties, their.repetition_penalties], dim=0 + ) + self.cumulated_repetition_penalties = torch.cat( + [self.cumulated_repetition_penalties, their.cumulated_repetition_penalties], + dim=0, + ) diff --git a/ktransformers/server/balance_serve/inference/sampling/sampler.py b/ktransformers/server/balance_serve/inference/sampling/sampler.py new file mode 100644 index 0000000..f491c97 --- /dev/null +++ b/ktransformers/server/balance_serve/inference/sampling/sampler.py @@ -0,0 +1,100 @@ +''' +Date: 2024-11-14 12:23:45 +LastEditors: Xie Weiyu ervinxie@qq.com +LastEditTime: 2024-11-25 08:59:23 +''' +import logging +import torch +from torch import nn +from transformers import GenerationConfig + +from flashinfer.sampling import ( + min_p_sampling_from_probs, + top_k_renorm_probs, + top_k_top_p_sampling_from_logits, + top_p_renorm_probs, +) + +logger = logging.getLogger(__name__) + +class SamplingOptions(): + # Batched sampling params + temperatures: torch.Tensor + top_ps: torch.Tensor + top_ks: torch.Tensor + min_ps: torch.Tensor + + # All requests use greedy sampling + is_all_greedy: bool + + # Dispatch in CUDA graph + need_min_p_sampling: bool + + def __init__(self, bsz = 1, device = torch.device('cuda'), pretrained_config:GenerationConfig = None, temperatures: torch.Tensor = None, top_ps: torch.Tensor = None): + if pretrained_config is None and temperatures is None: + self.temperatures = torch.full((bsz, 1), 0, device=device, dtype=torch.float32) + self.top_ps = torch.ones((bsz, 1), device=device, dtype=torch.float32) + self.top_ks = torch.ones((bsz, 1), device=device, dtype=torch.float32) + self.need_min_p_sampling = False + self.is_all_greedy = True + else: + if temperatures is not None: + self.temperatures = temperatures.unsqueeze(-1) + else: + self.temperatures = torch.full((bsz, 1), pretrained_config.temperature, device=device, dtype=torch.float32) + + if top_ps is not None: + self.top_ps = top_ps.unsqueeze(-1) + else: + self.top_ps = torch.full((bsz, 1), pretrained_config.top_p, device=device, dtype=torch.float32) + self.top_ks = torch.full((bsz, 1), pretrained_config.top_k, device=device, dtype=torch.float32) + self.need_min_p_sampling = False + self.is_all_greedy = False + +class Sampler(nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, + logits: torch.Tensor, + sampling_config: SamplingOptions = None, + ): + if sampling_config == None: + sampling_config = SamplingOptions() + + logits = logits.contiguous() + origin_logits = logits.clone() + if sampling_config.is_all_greedy: + # Use torch.argmax if all requests use greedy sampling + probs = logits + batch_next_token_ids = torch.argmax(logits, -1) + else: + # Post process logits + logits.div_(sampling_config.temperatures) + max_top_k_round, batch_size = 32, logits.shape[0] + if sampling_config.need_min_p_sampling: + probs = torch.softmax(logits, dim=-1) + logits = None + del logits + probs = top_k_renorm_probs(probs, sampling_config.top_ks) + probs = top_p_renorm_probs(probs, sampling_config.top_ps) + batch_next_token_ids = min_p_sampling_from_probs( + probs, sampling_config.min_ps + ) + temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0] + batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32) + else: + # TODO: use different kernel when don't need top_k or top_p + # @TODO get probs + probs = logits + batch_next_token_ids = top_k_top_p_sampling_from_logits( + logits, + sampling_config.top_ks, + sampling_config.top_ps, + filter_apply_order="joint", + ) + temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0] + batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32) + + return batch_next_token_ids.to(torch.int32), probs \ No newline at end of file diff --git a/ktransformers/server/balance_serve/sched_rpc.py b/ktransformers/server/balance_serve/sched_rpc.py new file mode 100644 index 0000000..8294b43 --- /dev/null +++ b/ktransformers/server/balance_serve/sched_rpc.py @@ -0,0 +1,213 @@ +from datetime import datetime +import os +from typing import Optional +import zmq +import pickle +import threading +import torch.multiprocessing as mp +import sys +current_file_path = os.path.abspath(__file__) +# sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..")) +import pickle +import argparse +from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings + + + +if mp.get_start_method(allow_none=True) is None: + print('set start method') + mp.set_start_method('spawn') +else: + print(f'start method already set to {mp.get_start_method(allow_none=True)}') + + +class SchedulerServer: + def __init__(self, settings, main_args): + # 创建 Scheduler 实例并初始化 + self.sched = sched_ext.create_scheduler(settings) + + # 初始化 ZeroMQ 上下文和套接字 + self.context = zmq.Context() + self.frontend = self.context.socket(zmq.ROUTER) + print(f"sched zmq rpc server on port {main_args.sched_port}") + self.frontend.bind(f"tcp://*:{main_args.sched_port}") + + # 创建内部的 DEALER 套接字,用于与工作线程通信 + self.backend = self.context.socket(zmq.DEALER) + self.backend.bind("inproc://backend") + + # 启动调度器 + def run_scheduler(self): + self.sched.run() + + # 停止调度器 + def stop_scheduler(self): + self.sched.stop() + + # 处理客户端请求 + def start_proxy(self): + # 使用 ZMQ 的内置代理,将前端请求分发给后端工作线程 + zmq.proxy(self.frontend, self.backend) + + # 工作线程处理请求 + def worker_routine(self): + worker = self.context.socket(zmq.REP) + worker.connect("inproc://backend") + while True: + try: + # 接收客户端请求 + message = worker.recv() + data = pickle.loads(message) + + method = data.get('method') + params = data.get('params', {}) + # print(f"Received request: {method}") + + if method == 'add_query': + query_add = params.get('query') # 直接是一个 QueryAdd 对象 + # 添加查询 + query_id = self.sched.add_query(query_add) + # 发送响应 + response = {'status': 'ok', 'query_id': query_id} + worker.send(pickle.dumps(response)) + + elif method == 'cancel_query': + query_id = params.get('query_id') + # 假设您的 Scheduler 类实现了 cancel 方法 + self.sched.cancel(query_id) + response = {'status': 'ok'} + worker.send(pickle.dumps(response)) + + elif method == 'update_last_batch': + updates = params.get('updates') # 直接是一个列表,包含 QueryUpdate 对象 + + # 更新最后一个批次 + batch_todo = self.sched.update_last_batch(updates) + + # 直接发送 batch_todo 对象 + response = {'status': 'ok', 'batch_todo': batch_todo} + # print (batch_todo.query_lengths, batch_todo.query_ids) + worker.send(pickle.dumps(response)) + + elif method == 'get_inference_context': + inference_context = self.sched.get_inference_context() + data = { + "k_cache":inference_context.k_cache, + "v_cache":inference_context.v_cache + } + print(f"Serializing KVCache") + data["k_cache"] = [mp.reductions.reduce_tensor(t) for t in data['k_cache']] + data["v_cache"] = [mp.reductions.reduce_tensor(t) for t in data['v_cache']] + # print(data) + response = {'status': 'ok', 'inference_context': data} + + worker.send(pickle.dumps(response)) + # response['inference_context'].k_cache[0][0, 0, 0, 0, 0] = 1 + # print("k_cache update") + + else: + # 未知方法 + response = {'status': 'error', 'message': 'Unknown method'} + worker.send(pickle.dumps(response)) + + except Exception as e: + # 处理异常并发送错误响应 + response = {'status': 'error', 'message': str(e)} + worker.send(pickle.dumps(response)) + + # 启动 RPC 服务 + def start_rpc_service(self): + try: + print("Scheduler RPC service is running...") + + # 在单独的线程中运行调度器 + threading.Thread(target=self.run_scheduler, daemon=True).start() + + # 启动工作线程 + for _ in range(10): # 根据需要调整线程数 + threading.Thread(target=self.worker_routine, daemon=True).start() + + # 启动代理,开始监听请求 + self.start_proxy() + + except KeyboardInterrupt: + print("Shutting down scheduler RPC service...") + self.stop_rpc_service() + + # 停止 RPC 服务 + def stop_rpc_service(self): + self.stop_scheduler() + self.frontend.close() + self.backend.close() + self.context.term() + +def start_server(settings, main_args): + server = SchedulerServer(settings, main_args) + server.start_rpc_service() + + +# Add async client for webserver +class SchedulerClient: + def __init__(self, sched_port): + address=f'tcp://localhost:{sched_port}' + self.address = address + self.context = zmq.Context() + self.socket = self.context.socket(zmq.REQ) + self.socket.connect(self.address) + print(f"Connected to server at {self.address}") + + def __del__(self): + self.socket.close() + self.context.term() + + def send_request(self, method, params=None): + if params is None: + params = {} + request = { + 'method': method, + 'params': params + } + # print(f'send request {request}') + self.socket.send(pickle.dumps(request)) + response = self.socket.recv() + # print(response) + response = pickle.loads(response) + if response.get('status') == 'ok': + return response + else: + raise Exception(f"Error from server: {response.get('message')}") + + def add_query(self, query): + response = self.send_request('add_query', {'query': query}) + return response.get('query_id') + + def cancel_query(self, query_id): + self.send_request('cancel_query', {'query_id': query_id}) + + def update_last_batch(self, updates): + response = self.send_request('update_last_batch', {'updates': updates}) + # print(f"update_last_batch response {response}") + return response.get('batch_todo') + + def rebuild_inferece_context(self,response): + data = response.get('inference_context') + inference_context = sched_ext.InferenceContext() + print('Rebuilding kvcache') + inference_context.k_cache = [fn(*args) for fn,args in data['k_cache']] + inference_context.v_cache = [fn(*args) for fn,args in data['v_cache']] + return inference_context + + def get_inference_context_raw(self): + response = self.send_request('get_inference_context') + return response + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True) + args = parser.parse_args() + with open(args.config, "rb") as f: + main_args = pickle.load(f) + settings = create_sched_settings(main_args) + start_server(settings, main_args) diff --git a/ktransformers/server/balance_serve/settings.py b/ktransformers/server/balance_serve/settings.py new file mode 100644 index 0000000..0a90a86 --- /dev/null +++ b/ktransformers/server/balance_serve/settings.py @@ -0,0 +1,76 @@ +''' +Date: 2024-11-13 09:43:39 +LastEditors: djw +LastEditTime: 2024-11-18 16:41:03 +''' +import sys, os +import yaml, json +from time import sleep + +current_dir = os.path.dirname(__file__) +# sched_path = os.path.abspath(os.path.join(current_dir, '../../../build/balance_serve/sched')) +# sys.path.insert(0, sched_path) +import sched_ext +from transformers import AutoConfig + +def create_sched_settings(args): + default_sample_options = sched_ext.SampleOptions() + model_name = os.path.basename(os.path.normpath(args.model_dir)) + input_model_settings = sched_ext.ModelSettings() + input_model_settings.model_path = args.model_dir + input_model_settings.params_count = int(0) + model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) + input_model_settings.layer_count = model_config.num_hidden_layers + input_model_settings.num_k_heads = 1 # model_config["num_key_value_heads"] + input_model_settings.k_head_dim = 576 + input_model_settings.bytes_per_params = 2 + input_model_settings.bytes_per_kv_cache_element = 2 + settings = sched_ext.Settings() + settings.model_name = model_name + settings.quant_type = "BF16" + settings.model_settings = input_model_settings + settings.page_size = args.page_size + settings.gpu_device_count = 1 # tp + settings.gpu_device_id = [i for i in range(settings.gpu_device_count)] + # settings.gpu_memory_size = args.cache_lens*576*2 + settings.gpu_memory_size = args.gpu_memory_size + settings.memory_utilization_percentage = args.utilization_percentage + max_batch_size = args.max_batch_size + chunk_size = args.chunk_size + + max_decode_batch_size = max_batch_size - 2 + + settings.max_batch_size = max_batch_size + settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2 + settings.sample_options = default_sample_options + settings.sched_metrics_port = args.sched_metrics_port + settings.gpu_only = args.memory_gpu_only + settings.use_self_defined_head_dim = True + settings.self_defined_head_dim = 576 + settings.full_kv_cache_on_each_gpu = True + settings.k_cache_on = True + settings.v_cache_on = False + + settings.kvc2_root_path = '/mnt/data/persist-kvc' + settings.kvc2_config_path = os.path.join(current_dir, "..", "..", "configs") + print(os.path.join(current_dir, "..", "..", "configs")) + settings.memory_pool_size_GB = args.cpu_memory_size_GB + settings.evict_count = 40 + settings.kvc2_metrics_port = args.kvc2_metrics_port + settings.load_from_disk = False + settings.save_to_disk = True + + + settings.strategy_name = args.sched_strategy + + settings.auto_derive() + return settings + + + + + + + + + diff --git a/ktransformers/server/config/config.py b/ktransformers/server/config/config.py index 332e398..e5cbafc 100644 --- a/ktransformers/server/config/config.py +++ b/ktransformers/server/config/config.py @@ -11,6 +11,7 @@ LastEditTime : 2024-08-12 06:31:14 import os import shutil import yaml +import psutil from ktransformers.server.config.singleton import Singleton from typing import Optional @@ -60,7 +61,7 @@ class Config(metaclass=Singleton): self.user_path: str = os.path.expanduser("~") self.localstore_path: str = os.path.join(self.user_path, ".ktransformers") # log configs - self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"])) + self.log_dir = os.path.join(self.localstore_path, cfg["log"]["dir"]) self.log_file = cfg["log"]["file"] self.log_level = cfg["log"]["level"] self.backup_count = cfg["log"]["backup_count"] @@ -74,7 +75,7 @@ class Config(metaclass=Singleton): # db configs self.db_configs: dict = cfg.get("db", {}) self.db_type = self.db_configs.get("type", "") - self.db_host = os.path.join(self.base_path, self.db_configs.get("host", "")) + self.db_host = Config.to_path(self.db_configs.get("host", "")) self.db_port = self.db_configs.get("port", "") self.db_name = self.db_configs.get("database", "") self.db_pool_size = self.db_configs.get("pool_size") @@ -101,11 +102,6 @@ class Config(metaclass=Singleton): self.optimize_config_path: Optional[str] = self.model.get( "optimize_config_path", None ) - self.paged = self.model.get("paged", True) - - self.total_context = self.model.get("total_context", 2**18) - self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1) - self.chunk_prefill_size = self.model.get("chunk_prefill_size", 8192) self.max_new_tokens = self.model.get("max_new_tokens", 2000) self.json_mode = self.model.get("json_mode", False) @@ -138,7 +134,6 @@ class Config(metaclass=Singleton): self.repetition_penalty = self.model.get("repetition_penalty", 1.01) self.frequency_penalty = self.model.get("frequency_penalty", 0.0) self.presence_penalty = self.model.get("presence_penalty", 0.0) - self.max_response_tokens = self.model.get("max_response_tokens", 300) self.response_chunk = self.model.get("response_chunk", 250) self.no_code_formatting = self.model.get("no_code_formatting", False) self.cache_8bit = self.model.get("cache_8bit", False) @@ -155,8 +150,9 @@ class Config(metaclass=Singleton): self.web_cross_domain: bool = self.web.get("open_cross_domain", True) self.mount_web: bool = self.web.get("mount", False) + # ext self.ext: dict = cfg.get("ext", {}) - self.cpu_infer = self.ext.get("cpu_infer", 10) + self.cpu_infer = psutil.cpu_count(logical=False) - 3 # file config self.local_store_configs: dict = cfg.get("local_store", {}) @@ -169,7 +165,6 @@ class Config(metaclass=Singleton): # long context config self.long_context_config: dict = cfg.get("long_context", {}) - self.chunk_size = self.long_context_config.get("chunk_size", 4096) self.max_seq_len = self.long_context_config.get("max_seq_len", 32000) self.block_size = self.long_context_config.get("block_size", 128) self.local_windows_len = self.long_context_config.get("local_windows_len", 4096) @@ -187,3 +182,21 @@ class Config(metaclass=Singleton): # local chat self.local_chat_config: dict = cfg.get("local_chat", {}) self.prompt_file = self.local_chat_config.get("prompt_file", None) + + # asyncserver + self.sched_strategy = cfg['async_server']['sched_strategy'] + self.sched_port = cfg['async_server']['sched_port'] + self.sched_metrics_port = cfg['async_server']['sched_metrics_port'] + self.kvc2_metrics_port = cfg['async_server']['kvc2_metrics_port'] + self.max_batch_size = cfg['async_server']['max_batch_size'] + self.page_size = cfg['attn']['page_size'] + self.chunk_size = cfg['attn']['chunk_size'] + self.memory_gpu_only = cfg['kvc2']['gpu_only'] + self.cache_lens = ((self.cache_lens + self.page_size - 1) // self.page_size) * self.page_size + self.gpu_memory_size = 2*576*61*self.cache_lens + self.utilization_percentage = 1.0 #cfg['kvc2']['utilization_percentage'] + self.cpu_memory_size_GB = cfg['kvc2']['cpu_memory_size_GB'] + # only support 2 prefill task + self.max_prefill_batch_size = 2 + self.max_decode_batch_size = self.max_batch_size - self.max_prefill_batch_size + diff --git a/ktransformers/server/main.py b/ktransformers/server/main.py index f536f9c..8108a3c 100644 --- a/ktransformers/server/main.py +++ b/ktransformers/server/main.py @@ -5,24 +5,20 @@ from fastapi.staticfiles import StaticFiles import uvicorn.logging import uvicorn import sys - +import atexit project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) -sys.path.insert(0, project_dir) from fastapi.middleware.cors import CORSMiddleware from ktransformers.server.args import ArgumentParser from ktransformers.server.config.config import Config -from ktransformers.server.utils.create_interface import create_interface -from ktransformers.server.backend.args import default_args +from ktransformers.server.utils.create_interface import create_interface, GlobalInterface from fastapi.openapi.utils import get_openapi - from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware - - from ktransformers.server.api import router, post_db_creation_operations from ktransformers.server.utils.sql_utils import Base, SQLUtil from ktransformers.server.config.log import logger - +import subprocess +import tempfile def mount_app_routes(mount_app: FastAPI): sql_util = SQLUtil() @@ -34,7 +30,10 @@ def mount_app_routes(mount_app: FastAPI): def create_app(): cfg = Config() - app = FastAPI() + if(hasattr(GlobalInterface.interface, "lifespan")): + app = FastAPI(lifespan=GlobalInterface.interface.lifespan) + else: + app = FastAPI() if Config().web_cross_domain: app.add_middleware( CORSMiddleware, @@ -108,11 +107,32 @@ def main(): arg_parser = ArgumentParser(cfg) - # 初始化消息 args = arg_parser.parse_args() + if args.backend_type == "balance_serve": + import pickle + def cleanup(): + if sched_process.poll() is None: + sched_process.terminate() + + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + pickle.dump(args, temp_file) + temp_file_path = temp_file.name + current_file = __file__ + target_file = os.path.join(os.path.dirname(current_file), "balance_serve", "sched_rpc.py") + target_file = os.path.normpath(target_file) + log_path = os.path.join(args.log_dir, "rpc.log") + log = open(log_path, "a") + sched_process = subprocess.Popen( + ["python3", target_file, "--config", temp_file_path], + stdout=log, + stderr=log + ) + print("sched_rpc started with PID:", sched_process.pid) + atexit.register(cleanup) + create_interface(config=cfg, default_args=cfg) app = create_app() custom_openapi(app) - create_interface(config=cfg, default_args=cfg) + run_api( app=app, host=args.host, @@ -121,6 +141,5 @@ def main(): ssl_certfile=args.ssl_certfile, ) - if __name__ == "__main__": main() diff --git a/ktransformers/server/requirements.txt b/ktransformers/server/requirements.txt index 9a4c9c5..76377d5 100644 --- a/ktransformers/server/requirements.txt +++ b/ktransformers/server/requirements.txt @@ -1,4 +1,4 @@ -torch >= 2.3.0,<=2.3.1 +torch >= 2.3.0 transformers == 4.43.2 fastapi >= 0.111.0 langchain >= 0.2.0 @@ -11,4 +11,6 @@ build ninja wheel colorlog -fire \ No newline at end of file +fire +zmq +psutil \ No newline at end of file diff --git a/ktransformers/server/schemas/endpoints/chat.py b/ktransformers/server/schemas/endpoints/chat.py index eb0081a..a48d4ab 100644 --- a/ktransformers/server/schemas/endpoints/chat.py +++ b/ktransformers/server/schemas/endpoints/chat.py @@ -2,7 +2,7 @@ from typing import List, Optional from typing_extensions import Literal from enum import Enum -from pydantic import BaseModel +from pydantic import BaseModel, Field from ktransformers.server.schemas.base import Object @@ -30,8 +30,8 @@ class ChatCompletionCreate(BaseModel): messages: List[Message] model : str stream : bool = False - temperature: Optional[float] = None - top_p: Optional[float] = None + temperature: Optional[float] = Field(default=1.0) + top_p: Optional[float] = Field(default=1.0) def get_tokenizer_messages(self): return [m.to_tokenizer_message() for m in self.messages] diff --git a/ktransformers/server/utils/create_interface.py b/ktransformers/server/utils/create_interface.py index af0a331..992c831 100644 --- a/ktransformers/server/utils/create_interface.py +++ b/ktransformers/server/utils/create_interface.py @@ -15,6 +15,7 @@ from ktransformers.server.backend.context_manager import ThreadContextManager from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface from ktransformers.server.backend.interfaces.transformers import TransformersInterface from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface + def create_interface(config: Config, default_args: ConfigArgs): if config.backend_type=='transformers': from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface @@ -22,6 +23,8 @@ def create_interface(config: Config, default_args: ConfigArgs): from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface elif config.backend_type == 'ktransformers': from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface + elif config.backend_type == 'balance_serve': + from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface as BackendInterface else: raise NotImplementedError(f'{config.backend_type} not implemented') GlobalInterface.interface = BackendInterface(default_args) @@ -30,9 +33,9 @@ def create_interface(config: Config, default_args: ConfigArgs): class GlobalContextManager: context_manager: ThreadContextManager class GlobalInterface: - interface: TransformersInterface | KTransformersInterface | ExllamaInterface + interface: TransformersInterface | KTransformersInterface | ExllamaInterface -def get_thread_context_manager() -> ThreadContextManager: +def get_thread_context_manager() -> GlobalContextManager: return GlobalContextManager.context_manager -def get_interface() -> TransformersInterface | KTransformersInterface | ExllamaInterface: +def get_interface() -> GlobalInterface: return GlobalInterface.interface \ No newline at end of file diff --git a/ktransformers/tests/mmlu_test_multi.py b/ktransformers/tests/mmlu_test_multi.py new file mode 100644 index 0000000..06c75ab --- /dev/null +++ b/ktransformers/tests/mmlu_test_multi.py @@ -0,0 +1,155 @@ +import argparse +import random +import time +import json +import requests +import pandas as pd +from datasets import load_dataset +import os +import concurrent.futures +import threading + +os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' +os.environ['https_proxy'] = '' +os.environ['http_proxy'] = '' +hint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.' + +class DataEvaluator: + def __init__(self): + self.data = [] + + def load_data(self, file_path): + """ + 从数据文件中加载数据,每条记录对应一个实例 + """ + ds = load_dataset(file_path, "all") + df = pd.DataFrame(ds['test']) + for _, row in df.iterrows(): + self.data.append(row.to_dict()) + + def get_prompt(self, record): + """ + 结合提示信息和记录数据生成完整的题目 + """ + options_str = "\n".join([f"{chr(65 + i)}. {opt}" for i, opt in enumerate(record['choices'])]) + prompt = hint + "\nQuestion: " + record['question'] + "\n" + options_str + "\nAnswer: '" + return prompt + + def post_processing(self, text): + """ + 对生成的文本进行后处理,提取最终答案(只返回最后一个字符) + """ + text = text.lstrip('\n').split('\n')[-1] + return text[-1:] + + def score(self, pred, answer): + """ + 对比预测答案和正确答案,返回得分 + """ + if pred == answer: + return 1 + return 0 + +def generate_text(api_url, question, model_name, stream=False): + headers = { + 'accept': 'application/json', + 'Content-Type': 'application/json', + 'Authorization': 'Bearer ' # 如有需要,请填入 API Key + } + data = { + "messages": [{"content": question, "role": "user"}], + "model": model_name, + "stream": stream, + } + print("POST data:", data) + response = requests.post(api_url, headers=headers, json=data, timeout=5000000) + if response.status_code == 200: + result = response.json() + return result.get('choices', [{}])[0].get('message', {}).get('content', '').strip() + else: + print(f"API Request failed with status code {response.status_code}") + return None + +def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name): + start_total_time = time.time() + total_score = 0 + results = [] + file_lock = threading.Lock() + + # 打乱数据顺序,并选择需要测试的实例数 + random.seed(42) + random.shuffle(data_evaluator.data) + data_subset = data_evaluator.data[:min(concurrent_requests, len(data_evaluator.data))] + + batch_size = 10 # 每批次最多 10 个实例 + + def worker(index, data_item): + nonlocal total_score + question = data_evaluator.get_prompt(data_item) + start_time = time.time() + try: + prediction = generate_text(api_url, question, model_name) + if prediction is None: + raise Exception(f"Failed to get prediction for question: {question}") + # 正确答案:将数字转换成字母(0->A, 1->B, 2->C, 3->D) + answer = chr(data_item['answer'] + 65) + processed_prediction = data_evaluator.post_processing(prediction) + score = data_evaluator.score(processed_prediction, answer) + elapsed_time = time.time() - start_time + result_data = { + "question_id": index, + "answer": answer, + "prediction": processed_prediction, + "real_prediction": prediction, + "score": score, + "time": elapsed_time + } + # 写入结果时加锁保证线程安全 + with file_lock: + with open(result_file, 'a', encoding='utf-8') as f: + json.dump(result_data, f, ensure_ascii=False, indent=4) + f.write("\n") + return result_data + except Exception as e: + print(f"Error processing request {index}: {e}") + return None + + # 按批次处理,每批最多 10 个任务 + for batch_start in range(0, len(data_subset), batch_size): + batch = data_subset[batch_start: batch_start + batch_size] + with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor: + futures = [executor.submit(worker, batch_start + j, data_item) for j, data_item in enumerate(batch)] + for future in concurrent.futures.as_completed(futures): + res = future.result() + if res is not None: + results.append(res) + total_score += res['score'] + + total_time = time.time() - start_total_time + throughput = len(data_subset) / total_time if total_time > 0 else 0 + + with open(log_file, 'a', encoding='utf-8') as log_f: + log_f.write(f"Total Time: {total_time:.2f} seconds\n") + log_f.write(f"Throughput: {throughput:.2f} requests per second\n") + average_score = total_score / len(data_subset) if data_subset else 0 + log_f.write(f"Average Score: {average_score}\n") + log_f.write('-' * 40 + '\n') + + print(f"Results saved to {result_file}") + print(f"Log saved to {log_file}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="API Generate Tester") + parser.add_argument("--concurrent", type=int, default=1000, help="需要测试的实例总数") + parser.add_argument("--file", type=str, default="cais/mmlu", help="数据文件路径") + parser.add_argument("--result", type=str, default="./mmlu_result_silicon.json", help="结果文件保存路径") + parser.add_argument("--log", type=str, default="./mmlu_result_silicon.log", help="日志文件保存路径") + parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="模型名称或路径") + parser.add_argument("--api_url", type=str, default="http://localhost:10006/v1/chat/completions", help="API URL") + + args = parser.parse_args() + + data_evaluator = DataEvaluator() + data_evaluator.load_data(args.file) + + main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model) diff --git a/ktransformers/tests/test_client.py b/ktransformers/tests/test_client.py new file mode 100644 index 0000000..67e0bd0 --- /dev/null +++ b/ktransformers/tests/test_client.py @@ -0,0 +1,115 @@ +import asyncio +import json +import sys +import aiohttp +import random +import argparse +import yaml +import os +import time +from time import sleep + +decodesz = 128 +# Server URL (replace with your server URL) +SERVER_URL = "http://localhost:10002/v1/chat/completions" +bf_list = [1] +decodesz_list = [128] +prompt_list = ['请你介绍下秦始皇', '3.9 和 3.11 哪个大', '抗衰老有何妙招', '给我讲个故事'] +async def fetch_event_stream(session, request_id): + try: + payload = { + "messages": [ + {"role": "system", "content": ""}, + {"role": "user", "content": prompt_list[request_id]} + ], + "model": "DeepSeek-V3", + "temperature": 0.3, + "top_p": 1.0, + "stream": True # 开启流式输出 + } + + headers = { + 'accept': 'application/json', + 'Content-Type': 'application/json' + } + + async with session.post(SERVER_URL, json=payload, headers=headers, timeout=50000) as response: + print(f"Request {request_id}: Connected, status {response.status}") + + if response.status != 200: + print(f"Request {request_id}: Error, status {response.status}") + return + + output_text = "" # 存储当前 response 的所有 token + total_tokens = 0 # 统计总 tokens 数 + decode_start_time = None # 记录 decode 阶段开始时间 + decode_end_time = None # 记录 decode 结束时间 + + async for line in response.content: + try: + decoded_line = line.decode("utf-8").strip() + + # 过滤空行 + if not decoded_line or not decoded_line.startswith("data: "): + continue + + decoded_line = decoded_line[6:].strip() # 去掉 `data: ` + + # 确保 JSON 数据是合法的 + if not decoded_line: + continue + + response_data = json.loads(decoded_line) # 解析 JSON + + # 确保 choices 存在 + choices = response_data.get("choices", []) + if not choices: + continue + + delta = choices[0].get("delta", {}) + token = delta.get("content", "") + + if token: + if decode_start_time is None: + decode_start_time = time.time() # 记录 decode 开始时间 + + output_text += token # 追加 token + sys.stdout.write(token) # 直接输出 token + sys.stdout.flush() # 立即刷新,确保 token 立刻出现在终端 + total_tokens += 1 # 增加 token 计数 + decode_end_time = time.time() # 每次收到 token,更新 decode 结束时间 + + # 检查是否完成 + finish_reason = choices[0].get("finish_reason", None) + if finish_reason: + # print(f"\nRequest {request_id}: Done") + break # 结束流式处理 + + except json.JSONDecodeError as e: + print(f"\nRequest {request_id}: JSON Decode Error - {e}") + except IndexError: + print(f"\nRequest {request_id}: List Index Error - choices is empty") + except Exception as e: + print(f"\nRequest {request_id}: Error parsing stream - {e}") + + # 计算 decode 速度 + if decode_start_time and decode_end_time and total_tokens > 0: + decode_time = decode_end_time - decode_start_time + decode_speed = total_tokens / decode_time if decode_time > 0 else 0 + # print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s") + + except Exception as e: + print(f"\nRequest {request_id}: Exception - {e}") + +async def main(prompt_id): + async with aiohttp.ClientSession() as session: + tasks = [fetch_event_stream(session, prompt_id)] + await asyncio.gather(*tasks) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="Event Stream Request Tester") + parser.add_argument("--question_id", type=int, default=0, required=False) + args = parser.parse_args() + output_file = "ktransformer_test_results.txt" + asyncio.run(main(args.question_id)) diff --git a/ktransformers/tests/test_speed.py b/ktransformers/tests/test_speed.py new file mode 100644 index 0000000..b20f0c0 --- /dev/null +++ b/ktransformers/tests/test_speed.py @@ -0,0 +1,146 @@ +import asyncio +import json +import sys +import aiohttp +import random +import argparse +import yaml +import os +import time +from time import sleep + +decodesz = 128 +# Server URL (replace with your server URL) +decodesz_list = [128] +ktansformer_prompt1024="""在遥远的翡翠森林里,住着各种各样的神奇生物。其中,有一只名叫露露的小狐狸,她与其他狐狸不同,天生长着一双晶莹剔透的翅膀。然而,这双翅膀却从未带她飞翔过。 + 一天,森林里传来一个惊人的消息:藏在森林深处的魔法泉水干涸了,所有生物赖以生存的泉水即将枯竭。他们说,只有传说中的“天空之羽”才能唤醒泉水,让它重新流淌。然而,“天空之羽”藏在一座高耸入云的山峰上,没有任何动物能抵达那里。 + 露露听到这个消息后,决定亲自去寻找“天空之羽”,即便她的翅膀无法飞翔,她也要尝试。最终,露露来到了传说中的高峰脚下,根本无法攀爬。她望着天空,心里充满了不甘:“如果我能飞起来,就不会被这座山挡住了……” + 正当她感到迷茫时,一只年迈的白鹰出现在她面前。 + “孩子,你为什么到这里来?”白鹰用苍老但慈祥的声音问道。 + 露露将森林的困境告诉了白鹰,并说自己愿意付出一切,只要能拯救森林。 + 白鹰沉思了一会儿,缓缓说道:“你的翅膀并不是没有力量,而是你一直害怕它们不能飞翔。相信自己,勇敢跳下去。” + 露露听后,心跳加速,她望着万丈深渊,犹豫不决就在那一瞬间,她竟然真的飞了起来!露露兴奋极了,她终于看到了“天空之羽”——一根散发着金光的羽毛,轻盈地悬浮在空中。露露小心翼翼地将“天空之羽”叼住,振翅返回森林。 + 当她将羽毛放入干涸的泉水中时,一道金光闪耀。整个森林恢复了生机,花草重新绽放,动物们欢欣鼓舞。从那以后,露露成为了森林的英雄,她是翱翔天空的勇士。她让所有动物都明白:只要相信自己,勇敢前行,就能实现自己的梦想。 + 请简述这个故事的内涵 写10000个字。 + 在遥远的翡翠森林里,住着各种各样的神奇生物。其中,有一只名叫露露的小狐狸,她与其他狐狸不同,天生长着一双晶莹剔透的翅膀。然而,这双翅膀却从未带她飞翔过。 + 一天,森林里传来一个惊人的消息:藏在森林深处的魔法泉水干涸了,所有生物赖以生存的泉水即将枯竭。他们说,只有传说中的“天空之羽”才能唤醒泉水,让它重新流淌。然而,“天空之羽”藏在一座高耸入云的山峰上,没有任何动物能抵达那里。 + 露露听到这个消息后,决定亲自去寻找“天空之羽”,即便她的翅膀无法飞翔,她也要尝试。最终,露露来到了传说中的高峰脚下,根本无法攀爬。她望着天空,心里充满了不甘:“如果我能飞起来,就不会被这座山挡住了……” + 正当她感到迷茫时,一只年迈的白鹰出现在她面前。 + “孩子,你为什么到这里来?”白鹰用苍老但慈祥的声音问道。 + 露露将森林的困境告诉了白鹰,并说自己愿意付出一切,只要能拯救森林。 + 白鹰沉思了一会儿,缓缓说道:“你的翅膀并不是没有力量,而是你一直害怕它们不能飞翔。相信自己,勇敢跳下去。” + 露露听后,心跳加速,她望着万丈深渊,犹豫不决就在那一瞬间,她竟然真的飞了起来!露露兴奋极了,她终于看到了“天空之羽”——一根散发着金光的羽毛,轻盈地悬浮在空中。露露小心翼翼地将“天空之羽”叼住,振翅返回森林。 + 当她将羽毛放入干涸的泉水中时,一道金光闪耀。整个森林恢复了生机,花草重新绽放,动物们欢欣鼓舞。从那以后,露露成为了森林的英雄,她是翱翔天空的勇士。她让所有动物都明白:只要相信自己,勇敢前行,就能实现自己的梦想。 + 请简述这个故事的内涵 写10000个字。 + 露露将森林的困境告诉了白鹰,并说自己愿意付出一切,只要能拯救森林。 + 白鹰沉思了一会儿,缓缓说道:“你的翅膀并不是没有力量,而是你一直害怕它们不能飞翔。相信自己,勇敢跳下去。” + 露露听后,心跳加速,她望着万丈深渊,犹豫不决就在那一瞬间,她竟然真的飞了起来!露露兴奋极了,她终于看到了“天空之羽”——一根散发着金光的羽毛,轻盈地悬浮在空中。露露小心翼翼地将“天空之羽”叼住,振翅返回森林。 + 当她将羽毛放入干涸的泉水中时,一道金光闪耀。整个森林恢复了生机,花草重新绽放,动物们欢欣鼓舞。从那以后,露露成为了森林的英雄,她是翱翔天空的勇士。她让所有动物都明白:只要相信自己,勇敢前行,就能实现自己的梦想。 + 请简述这个故事的内涵 写10000个字。想。 + 请简述这个故事的内涵 故事的内涵这个故事的内涵写10000个字""" +async def fetch_event_stream(session, request_id , prompt): + try: + payload = { + "messages": [ + {"role": "system", "content": ""}, + {"role": "user", "content": prompt} + ], + "model": "DeepSeek-V3", + "temperature": 0.3, + "top_p": 1.0, + "stream": True # 开启流式输出 + } + + headers = { + 'accept': 'application/json', + 'Content-Type': 'application/json' + } + + async with session.post(SERVER_URL, json=payload, headers=headers, timeout=500000) as response: + print(f"Request {request_id}: Connected, status {response.status}") + + if response.status != 200: + print(f"Request {request_id}: Error, status {response.status}") + return + + output_text = "" # 存储当前 response 的所有 token + total_tokens = 0 # 统计总 tokens 数 + decode_start_time = None # 记录 decode 阶段开始时间 + decode_end_time = None # 记录 decode 结束时间 + + async for line in response.content: + try: + decoded_line = line.decode("utf-8").strip() + + # 过滤空行 + if not decoded_line or not decoded_line.startswith("data: "): + continue + + decoded_line = decoded_line[6:].strip() # 去掉 `data: ` + + # 确保 JSON 数据是合法的 + if not decoded_line: + continue + + response_data = json.loads(decoded_line) # 解析 JSON + + # 确保 choices 存在 + choices = response_data.get("choices", []) + if not choices: + continue + + delta = choices[0].get("delta", {}) + token = delta.get("content", "") + + if token: + if decode_start_time is None: + decode_start_time = time.time() # 记录 decode 开始时间 + + output_text += token # 追加 token + sys.stdout.write(str(request_id)) + sys.stdout.write(token) # 直接输出 token + sys.stdout.flush() # 立即刷新,确保 token 立刻出现在终端 + total_tokens += 1 # 增加 token 计数 + decode_end_time = time.time() # 每次收到 token,更新 decode 结束时间 + + # 检查是否完成 + finish_reason = choices[0].get("finish_reason", None) + if finish_reason: + # print(f"\nRequest {request_id}: Done") + break # 结束流式处理 + + except json.JSONDecodeError as e: + print(f"\nRequest {request_id}: JSON Decode Error - {e}") + except IndexError: + print(f"\nRequest {request_id}: List Index Error - choices is empty") + except Exception as e: + print(f"\nRequest {request_id}: Error parsing stream - {e}") + + # 计算 decode 速度 + if decode_start_time and decode_end_time and total_tokens > 0: + decode_time = decode_end_time - decode_start_time + decode_speed = total_tokens / decode_time if decode_time > 0 else 0 + # print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s") + + except Exception as e: + print(f"\nRequest {request_id}: Exception - {e}") + +async def main(concurrent_requests , prompt ): + async with aiohttp.ClientSession() as session: + tasks = [fetch_event_stream(session, i , prompt) for i in range(concurrent_requests)] + await asyncio.gather(*tasks) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Event Stream Request Tester") + parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests") + parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048") + parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL") + + args = parser.parse_args() + SERVER_URL = args.api_url + if args.prompt_lens == 1024: + prompt = ktansformer_prompt1024 + elif args.prompt_lens == 2048: + prompt = ktansformer_prompt1024 * 2 + asyncio.run(main(args.concurrent, prompt)) + diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py index 6f3b049..bb21b1c 100644 --- a/ktransformers/util/utils.py +++ b/ktransformers/util/utils.py @@ -18,9 +18,26 @@ from ktransformers.models.custom_cache import StaticCache from ktransformers.util.cuda_graph_runner import CUDAGraphRunner from ktransformers.util.textstream import TextStreamer from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton +import socket warm_uped = False +def get_free_ports(n: int, continue_prot: list): + sockets = [] + ports = [] + for _ in range(n): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + port = s.getsockname()[1] + if port in continue_prot: + s.close() + continue + ports.append(port) + sockets.append(s) + for s in sockets: + s.close() + return ports + def get_compute_capability(device:torch.device = None): if torch.cuda.is_available(): if device is None: @@ -110,7 +127,7 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''): module.load() def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True, - mode = 'normal', force_think: bool = False, chunk_prefill_size = 16384, use_flashinfer_mla = False, + mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False, num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None): import os os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -202,11 +219,11 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud chunk_start = 0 while chunk_start < seq_length: - chunk_end = min(chunk_start + chunk_prefill_size, seq_length) + chunk_end = min(chunk_start + chunk_size, seq_length) if past_key_values != None: past_key_values.cur_idx=cache_position[chunk_start:chunk_end] logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values) - chunk_start += chunk_prefill_size + chunk_start += chunk_size next_token_scores = logits_warper(inputs, logits[:, -1, :]) if generation_config.do_sample: diff --git a/merge_tensors/merge_safetensor_gguf.py b/merge_tensors/merge_safetensor_gguf.py index 69780fe..efeab3b 100644 --- a/merge_tensors/merge_safetensor_gguf.py +++ b/merge_tensors/merge_safetensor_gguf.py @@ -3,7 +3,7 @@ import os # insert the path of the project import sys -sys.path.insert(0, "/home/azure/ktransformers") +# sys.path.insert(0, "/home/azure/ktransformers") import argparse import torch from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf diff --git a/requirements-local_chat.txt b/requirements-local_chat.txt index ad280c0..855b360 100644 --- a/requirements-local_chat.txt +++ b/requirements-local_chat.txt @@ -6,4 +6,4 @@ packaging cpufeature protobuf tiktoken -blobfile \ No newline at end of file +blobfile diff --git a/setup.py b/setup.py index 5c29b8f..e13ceb7 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,8 @@ try: from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME except ImportError: MUSA_HOME=None + +with_balance = os.environ.get("USE_BALANCE_SERVE", "0") == "1" class CpuInstructInfo: CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE") @@ -212,7 +214,7 @@ class VersionInfo: cpu_instruct = self.get_cpu_instruct() backend_version = "" if CUDA_HOME is not None: - backend_version = f"" + backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}" elif MUSA_HOME is not None: backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}" elif ROCM_HOME is not None: @@ -274,11 +276,10 @@ PLAT_TO_CMAKE = { class CMakeExtension(Extension): - def __init__(self, name: str, sourcedir: str = "") -> None: + def __init__(self, name: str, sourcedir: str) -> None: super().__init__(name, sources=[]) - self.sourcedir = os.fspath( - Path(sourcedir).resolve() / "ktransformers" / "ktransformers_ext") - + print(name, sourcedir) + self.sourcedir = sourcedir class CMakeBuild(BuildExtension): @@ -342,16 +343,17 @@ class CMakeBuild(BuildExtension): f"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}"] if self.compiler.compiler_type != "msvc": if not cmake_generator or cmake_generator == "Ninja": - try: - import ninja + pass + # try: + # import ninja - ninja_executable_path = Path(ninja.BIN_DIR) / "ninja" - cmake_args += [ - "-GNinja", - f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}", - ] - except ImportError: - pass + # ninja_executable_path = Path(ninja.BIN_DIR) / "ninja" + # cmake_args += [ + # "-GNinja", + # f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}", + # ] + # except ImportError: + # pass else: # Single config generators are handled "normally" @@ -387,10 +389,12 @@ class CMakeBuild(BuildExtension): build_args += [f"--parallel={cpu_count}"] print("CMake args:", cmake_args) build_temp = Path(ext.sourcedir) / "build" + print("build_temp:", build_temp) + if not build_temp.exists(): build_temp.mkdir(parents=True) result = subprocess.run( - ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True , capture_output=True + ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True , capture_output=True, text=True ) print("Standard output:", result.stdout) print("Standard error:", result.stderr) @@ -400,9 +404,9 @@ class CMakeBuild(BuildExtension): if CUDA_HOME is not None or ROCM_HOME is not None: ops_module = CUDAExtension('KTransformersOps', [ - 'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu', - 'ktransformers/ktransformers_ext/cuda/binding.cpp', - 'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu' + 'csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu', + 'csrc/ktransformers_ext/cuda/binding.cpp', + 'csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu' ], extra_compile_args={ 'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'], @@ -415,7 +419,7 @@ if CUDA_HOME is not None or ROCM_HOME is not None: } ) elif MUSA_HOME is not None: - SimplePorting(cuda_dir_path="ktransformers/ktransformers_ext/cuda", mapping_rule={ + SimplePorting(cuda_dir_path="csrc/ktransformers_ext/cuda", mapping_rule={ # Common rules "at::cuda": "at::musa", "#include ": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"", @@ -423,10 +427,10 @@ elif MUSA_HOME is not None: "nv_bfloat16": "mt_bfloat16", }).run() ops_module = MUSAExtension('KTransformersOps', [ - 'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu', - 'ktransformers/ktransformers_ext/cuda_musa/binding.cpp', + 'csrc/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu', + 'csrc/ktransformers_ext/cuda_musa/binding.cpp', # TODO: Add Marlin support for MUSA. - # 'ktransformers/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu' + # 'csrc/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu' ], extra_compile_args={ 'cxx': ['force_mcc'], @@ -440,12 +444,30 @@ elif MUSA_HOME is not None: else: raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.") +ext_modules = [ + CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")), + ops_module, + CUDAExtension( + 'vLLMMarlin', [ + 'csrc/custom_marlin/binding.cpp', + 'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu', + 'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu', + ], + extra_compile_args={ + 'cxx': ['-O3'], + 'nvcc': ['-O3', '-Xcompiler', '-fPIC'], + }, + ) +] +if with_balance: + print("using balance_serve") + ext_modules.append( + CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve")) + ) + setup( name=VersionInfo.PACKAGE_NAME, version=VersionInfo().get_package_version(), cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild}, - ext_modules=[ - CMakeExtension("cpuinfer_ext"), - ops_module, - ] + ext_modules=ext_modules )