diff --git a/.gitignore b/.gitignore index 718ea55..1bb8666 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,7 @@ node_modules .DS_Store compile_commands.json *.egg-info* -*dist/ \ No newline at end of file +*dist/ +ktransformers/server/local_store/ +ktransformers/server_test1.db +*.patch \ No newline at end of file diff --git a/install.bat b/install.bat new file mode 100644 index 0000000..dc429e4 --- /dev/null +++ b/install.bat @@ -0,0 +1,16 @@ +@echo off + +REM clear build dirs +rmdir /S /Q ktransformers\ktransformers_ext\build +rmdir /S /Q ktransformers\ktransformers_ext\cuda\build +rmdir /S /Q ktransformers\ktransformers_ext\cuda\dist +rmdir /S /Q ktransformers\ktransformers_ext\out +del /F /Q ktransformers\ktransformers_ext\cuda\*.egg-info + +echo Installing python dependencies from requirements.txt +pip install -r requirements-local_chat.txt + +echo Installing ktransformers +set KTRANSFORMERS_FORCE_BUILD=TRUE +pip install . --no-build-isolation +echo Installation completed successfully \ No newline at end of file diff --git a/install.sh b/install.sh index fa5ba18..ffb7aca 100644 --- a/install.sh +++ b/install.sh @@ -11,5 +11,5 @@ echo "Installing python dependencies from requirements.txt" pip install -r requirements-local_chat.txt echo "Installing ktransformers" -pip install . --no-build-isolation +KTRANSFORMERS_FORCE_BUILD=TRUE pip install . --no-build-isolation echo "Installation completed successfully" \ No newline at end of file diff --git a/ktransformers/ktransformers_ext/CMakeLists.txt b/ktransformers/ktransformers_ext/CMakeLists.txt index c3d4f5b..e6e0518 100644 --- a/ktransformers/ktransformers_ext/CMakeLists.txt +++ b/ktransformers/ktransformers_ext/CMakeLists.txt @@ -189,7 +189,13 @@ else() message(STATUS "Unknown architecture") endif() -find_package(CUDA REQUIRED) +# message(STATUS "CUDAToolkit_ROOT:${CUDAToolkit_ROOT}") +# find_package(FindCUDAToolkit REQUIRED) +# if(CUDAToolkit_FOUND) +# message(STATUS "Found CUDA cudart lib at:${CUDAToolkit_LIBRARY_DIR}") +# else() +# message(STATUS "Can't found CUDA lib") +# endif() add_compile_options("$<$:${ARCH_FLAGS}>") add_compile_options("$<$:${ARCH_FLAGS}>") @@ -198,7 +204,12 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/pybind11 ${CMAKE_ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llama.cpp ${CMAKE_CURRENT_BINARY_DIR}/third_party/llama.cpp) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party) -include_directories("${CUDA_INCLUDE_DIRS}") +if (WIN32) + include_directories("$ENV{CUDA_PATH}/include") +elseif (UNIX) + find_package(CUDA REQUIRED) + include_directories("${CUDA_INCLUDE_DIRS}") +endif() aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1) aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2) @@ -209,4 +220,8 @@ message(STATUS "ALL_SOURCES: ${ALL_SOURCES}") pybind11_add_module(${PROJECT_NAME} MODULE ${ALL_SOURCES}) target_link_libraries(${PROJECT_NAME} PRIVATE llama) -target_link_libraries(${PROJECT_NAME} PRIVATE "/usr/local/cuda/lib64/libcudart.so") \ No newline at end of file +if(WIN32) + target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart +elseif(UNIX) + target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so") +endif() \ No newline at end of file diff --git a/ktransformers/ktransformers_ext/cpu_backend/task_queue.h b/ktransformers/ktransformers_ext/cpu_backend/task_queue.h index b4212fd..d4e6d8a 100644 --- 
a/ktransformers/ktransformers_ext/cpu_backend/task_queue.h
+++ b/ktransformers/ktransformers_ext/cpu_backend/task_queue.h
@@ -3,8 +3,8 @@
  * @Author : chenht2022
  * @Date : 2024-07-16 10:43:18
  * @Version : 1.0.0
- * @LastEditors : chenht2022
- * @LastEditTime : 2024-07-25 10:33:47
+ * @LastEditors : chenxl
+ * @LastEditTime : 2024-08-08 04:23:51
  * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
  **/
 #ifndef CPUINFER_TASKQUEUE_H
@@ -17,6 +17,44 @@
 #include
 #include
 #include
+#ifdef _WIN32
+#include <windows.h>
+#endif
+
+class custom_mutex {
+private:
+#ifdef _WIN32
+    HANDLE global_mutex;
+#else
+    std::mutex global_mutex;
+#endif
+
+public:
+    custom_mutex()
+    {
+#ifdef _WIN32
+        global_mutex = CreateMutex(NULL, FALSE, NULL);
+#endif
+    }
+
+    void lock()
+    {
+#ifdef _WIN32
+        WaitForSingleObject(global_mutex, INFINITE);
+#else
+        global_mutex.lock();
+#endif
+    }
+
+    void unlock()
+    {
+#ifdef _WIN32
+        ReleaseMutex(global_mutex);
+#else
+        global_mutex.unlock();
+#endif
+    }
+};
 
 class TaskQueue {
    public:
@@ -32,7 +70,7 @@ class TaskQueue {
 
     std::queue<std::function<void()>> tasks;
     std::thread worker;
-    std::mutex mutex;
+    custom_mutex mutex;
     std::atomic<bool> sync_flag;
     std::atomic<bool> exit_flag;
 };
diff --git a/ktransformers/ktransformers_ext/cuda/binding.cpp b/ktransformers/ktransformers_ext/cuda/binding.cpp
index 2d5da68..f17382d 100644
--- a/ktransformers/ktransformers_ext/cuda/binding.cpp
+++ b/ktransformers/ktransformers_ext/cuda/binding.cpp
@@ -3,8 +3,8 @@
  * @Author : Azure-Tang
  * @Date : 2024-07-25 13:38:30
  * @Version : 1.0.0
- * @LastEditors : Azure
- * @LastEditTime : 2024-07-26 08:36:03
+ * @LastEditors : kkk1nak0
+ * @LastEditTime : 2024-08-09 01:45:02
  * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
  **/
 
@@ -23,6 +23,8 @@ PYBIND11_MODULE(KTransformersOps, m) {
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
+          py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp b/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp
index ea52e8f..2cb46fc 100644
--- a/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp
+++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp
@@ -12,12 +12,15 @@ int test(){
 }
 
 torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
 
 PYBIND11_MODULE(cudaops, m) {
     m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
+          py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("test", &test, "Function to test.");
diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/custom_ggml.h
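Note on the custom_mutex shim in task_queue.h above: because it only needs to expose lock() and unlock(), it satisfies BasicLockable, so RAII guards keep working exactly as they did with std::mutex (on Windows the underlying HANDLE should also eventually be released with CloseHandle). A minimal, self-contained sketch of that usage, with a POSIX-only stand-in rather than the real header:

#include <functional>
#include <mutex>
#include <queue>

// Stand-in with the same lock()/unlock() surface as custom_mutex
// (POSIX branch only; the Windows branch wraps a HANDLE instead).
class custom_mutex {
    std::mutex inner_;
public:
    void lock()   { inner_.lock(); }
    void unlock() { inner_.unlock(); }
};

int main() {
    custom_mutex mutex;
    std::queue<std::function<void()>> tasks;
    {
        std::lock_guard<custom_mutex> guard(mutex); // locks here
        tasks.push([] { /* enqueue work */ });
    }                                               // unlocks at scope exit
    return static_cast<int>(tasks.size()) - 1;      // 0 on success
}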
b/ktransformers/ktransformers_ext/cuda/custom_gguf/custom_ggml.h deleted file mode 100644 index 333dc69..0000000 --- a/ktransformers/ktransformers_ext/cuda/custom_gguf/custom_ggml.h +++ /dev/null @@ -1,39 +0,0 @@ - - - -#include - - -__device__ float ggml_compute_fp16_to_fp32(uint16_t h) { - return __uint2float_rd(h); -} - -static inline float ggml_compute_fp16_to_fp32(uint16_t h) { - uint16_t tmp; - memcpy(&tmp, &h, sizeof(ggml_fp16_t)); - return (float)tmp; -} - -// define the global table for fp16 to fp32 conversion -__device__ float ggml_table_f32_f16[1 << 16]; - -// CUDA Kernel to init the table -__global__ void init_fp16_to_fp32_table() { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - for (auto blk_id = idx; blk_id<(1 << 16); blk_id+=blockDim.x * gridDim.x){ - ggml_table_f32_f16[blk_id] = GGML_COMPUTE_FP16_TO_FP32(blk_id); - } -} - -#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) - -extern __device__ float ggml_table_f32_f16[1 << 16]; // Declare as __device__ if used within device code - -// This version of the function is designed to be called from within a CUDA kernel -#if !defined(GGML_FP16_TO_FP32) -__device__ float ggml_lookup_fp16_to_fp32(uint16_t f) { - return ggml_table_f32_f16[f]; -} - -#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x) -#endif \ No newline at end of file diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu index 38f4842..aaa6453 100644 --- a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu +++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu @@ -3,8 +3,8 @@ * @Author : Azure-Tang, Boxin Zhang * @Date : 2024-07-25 13:38:30 * @Version : 1.0.0 - * @LastEditors : Azure - * @LastEditTime : 2024-07-26 11:58:50 + * @LastEditors : kkk1nak0 + * @LastEditTime : 2024-08-09 07:57:06 * Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c * Copyright (c) 2023-2024 The ggml authors * Copyright (c) 2024 by KVCache.AI, All Rights Reserved. @@ -14,6 +14,7 @@ #include #include #include +#include __global__ void dequantize_q8_0_kernel(float* output, const float* scales, const int8_t* qs, int num_blocks, int blk_size) { int global_idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -59,6 +60,35 @@ __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size } } +__global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) { + int global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (auto block_id=global_idx; block_id(data + block_id * blk_size + 0))); + const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 2))); + + const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16); + const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48); + + int is = 0; + uint8_t sc, m; + uint8_t u1 = 1, u2 = 2; + uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4); + + for (int j = 0; j < 256; j += 64) { + get_scale_min_k4(is + 0, scales, &sc, &m); + const float d1 = d * sc; const float m1 = min * m; + get_scale_min_k4(is + 1, scales, &sc, &m); + const float d2 = d * sc; const float m2 = min * m; + for (int l = 0; l < 32; ++l) *output_blk++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1; + for (int l = 0; l < 32; ++l) *output_blk++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 
16 : 0)) - m2; + ql += 32; is += 2; + u1 <<= 2; u2 <<= 2; + } + } +} + __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) { int global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (auto block_id=global_idx; block_id>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + + cudaDeviceSynchronize(); + return output; +} \ No newline at end of file diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h b/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h index 9af8f30..f5fde87 100644 --- a/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h +++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h @@ -3,8 +3,8 @@ * @Author : Azure-Tang * @Date : 2024-07-22 09:27:55 * @Version : 1.0.0 - * @LastEditors : Azure - * @LastEditTime : 2024-07-26 08:38:20 + * @LastEditors : kkk1nak0 + * @LastEditTime : 2024-08-09 01:44:21 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #pragma once @@ -15,4 +15,5 @@ torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device); torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device); +torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device); torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device); \ No newline at end of file diff --git a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu index e8e5153..54e538a 100644 --- a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu +++ b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu @@ -23,7 +23,7 @@ */ #include "gptq_marlin.cuh" #include "gptq_marlin_dtypes.cuh" - +#include #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ static_assert(std::is_same::value || \ std::is_same::value, \ @@ -1703,28 +1703,63 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, thread_m_blocks = exec_cfg.max_m_blocks; } + + // Define kernel configurations - if (false) { +#define undefined_error TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + \ + str(prob_n) + ", " + str(prob_k) + "]" + \ + ", has_act_order = " + str(has_act_order) + \ + ", num_groups = " + str(num_groups) + \ + ", group_size = " + str(group_size) + \ + ", thread_m_blocks = " + str(thread_m_blocks) + \ + ", thread_n_blocks = " + str(thread_n_blocks) + \ + ", thread_k_blocks = " + str(thread_k_blocks)); + + + if (num_bits == 4 && num_threads == 256) + { + if (false) { + } + CALL_IF(4, 32, 2, 256) + CALL_IF(4, 16, 4, 256) + CALL_IF(4, 8, 8, 256) + else { + undefined_error + } + } + else if (num_bits == 4 && num_threads == 128) + { + if (false) { + } + CALL_IF(4, 8, 4, 128) + CALL_IF(4, 4, 8, 128) + else { + undefined_error + } + } + else if (num_bits == 8 && num_threads == 256) + { + if (false) { + } + CALL_IF(8, 32, 2, 256) + CALL_IF(8, 16, 4, 256) + CALL_IF(8, 8, 8, 256) + else { + undefined_error + } + } + else if (num_bits == 8 && num_threads == 128) + { + if (false) { + } + CALL_IF(8, 8, 4, 128) + CALL_IF(8, 4, 8, 128) + else { + undefined_error + } } - CALL_IF(4, 32, 2, 256) - CALL_IF(4, 16, 4, 256) - CALL_IF(4, 8, 8, 256) - CALL_IF(4, 8, 4, 128) - CALL_IF(4, 4, 8, 128) - CALL_IF(8, 32, 2, 256) - CALL_IF(8, 16, 4, 256) - CALL_IF(8, 8, 8, 256) - CALL_IF(8, 8, 4, 128) - CALL_IF(8, 4, 8, 128) else { - TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + - str(prob_n) + ", " + str(prob_k) + "]" + - ", 
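The dequantize_q5_k_kernel above indexes each super-block by hand: d at byte 0, dmin at byte 2, the packed scales at byte 4, qh at byte 16 and ql at byte 48, writing 256 floats per block. A small host-side sketch that cross-checks those offsets against the standard GGUF Q5_K super-block layout (QK_K = 256, i.e. blk_size = 176 bytes); the struct name is illustrative, not taken from the repo:

#include <cstddef>
#include <cstdint>

// Mirror of the GGUF Q5_K super-block: two fp16 values, 12 bytes of packed
// 6-bit scales/mins, 32 bytes of high bits, 128 bytes of low nibbles
// -> 176 bytes per block, 256 dequantized floats.
struct q5_k_block_view {
    std::uint16_t d;          // super-block scale (fp16), byte 0
    std::uint16_t dmin;       // super-block min (fp16), byte 2
    std::uint8_t  scales[12]; // packed scales/mins, byte 4
    std::uint8_t  qh[32];     // high bits of the 5-bit quants, byte 16
    std::uint8_t  ql[128];    // low 4 bits of the quants, byte 48
};

static_assert(offsetof(q5_k_block_view, scales) == 4,   "scales offset");
static_assert(offsetof(q5_k_block_view, qh)     == 16,  "qh offset");
static_assert(offsetof(q5_k_block_view, ql)     == 48,  "ql offset");
static_assert(sizeof(q5_k_block_view)           == 176, "blk_size for Q5_K");

int main() { return 0; }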
has_act_order = " + str(has_act_order) + - ", num_groups = " + str(num_groups) + - ", group_size = " + str(group_size) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); + undefined_error } A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; @@ -1739,6 +1774,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& perm, torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); // Verify num_bits TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); @@ -1781,7 +1817,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); // Alloc buffers - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); torch::Tensor c = torch::empty({size_m, size_n}, options); torch::Tensor a_tmp = torch::empty({size_m, size_k}, options); diff --git a/ktransformers/ktransformers_ext/cuda/setup.py b/ktransformers/ktransformers_ext/cuda/setup.py index baf0808..156bb0e 100644 --- a/ktransformers/ktransformers_ext/cuda/setup.py +++ b/ktransformers/ktransformers_ext/cuda/setup.py @@ -2,17 +2,25 @@ from setuptools import setup, Extension from torch.utils import cpp_extension from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -# setup marlin gemm -setup(name='KTransformersOps', - ext_modules=[ - CUDAExtension('KTransformersOps', [ +setup( + name='KTransformersOps', + ext_modules=[ + CUDAExtension( + 'KTransformersOps', [ 'custom_gguf/dequant.cu', 'binding.cpp', 'gptq_marlin/gptq_marlin.cu', # 'gptq_marlin_repack.cu', - ]) - ], - cmdclass={'build_ext': BuildExtension -}) - + ], + extra_compile_args={ + 'cxx': ['-O3'], + 'nvcc': [ + '-O3', + '--use_fast_math', + '-Xcompiler', '-fPIC', + ] + }, + ) + ], + cmdclass={'build_ext': BuildExtension} +) \ No newline at end of file diff --git a/ktransformers/ktransformers_ext/ext_bindings.cpp b/ktransformers/ktransformers_ext/ext_bindings.cpp index ef30037..c220a9b 100644 --- a/ktransformers/ktransformers_ext/ext_bindings.cpp +++ b/ktransformers/ktransformers_ext/ext_bindings.cpp @@ -37,7 +37,7 @@ class LinearBindings { Args* args_ = (Args*)args; args_->cpuinfer->enqueue(&Linear::warm_up, args_->linear); } - static std::pair interface(Linear& linear) { + static std::pair cpuinfer_interface(Linear& linear) { Args* args = new Args{nullptr, &linear}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } @@ -55,7 +55,7 @@ class LinearBindings { Args* args_ = (Args*)args; args_->cpuinfer->enqueue(&Linear::forward, args_->linear, args_->qlen, args_->input, args_->output); } - static std::pair interface(Linear& linear, int qlen, intptr_t input, intptr_t output) { + static std::pair cpuinfer_interface(Linear& linear, int qlen, intptr_t input, intptr_t output) { Args* args = new Args{nullptr, &linear, qlen, (const void*)input, (void*)output}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } @@ -74,7 +74,7 @@ class MLPBindings { Args* args_ = (Args*)args; args_->cpuinfer->enqueue(&MLP::warm_up, args_->mlp); } - static std::pair interface(MLP& mlp) { + static std::pair cpuinfer_interface(MLP& mlp) { Args* args = new Args{nullptr, &mlp}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } @@ -92,7 +92,7 
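The interface -> cpuinfer_interface rename running through the ext_bindings.cpp hunks here is presumably needed because task_queue.h now pulls in <windows.h> on Windows, and the Windows headers effectively define interface as a macro for struct, so a method literally named interface() no longer parses. A Windows-only illustration with made-up names:

// Windows-only illustration; names are illustrative, not from the repo.
#include <windows.h>   // by default drags in objbase.h: interface becomes a macro
#include <cstdint>
#include <utility>

struct ForwardBindings {
    // static std::pair<std::intptr_t, std::intptr_t> interface();
    //   ^ would expand to "... struct();" and fail to compile
    static std::pair<std::intptr_t, std::intptr_t> cpuinfer_interface() {
        return { std::intptr_t{0}, std::intptr_t{0} };
    }
};

int main() {
    auto p = ForwardBindings::cpuinfer_interface();
    return static_cast<int>(p.first + p.second);
}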
@@ class MLPBindings { Args* args_ = (Args*)args; args_->cpuinfer->enqueue(&MLP::forward, args_->mlp, args_->qlen, args_->input, args_->output); } - static std::pair interface(MLP& mlp, int qlen, intptr_t input, intptr_t output) { + static std::pair cpuinfer_interface(MLP& mlp, int qlen, intptr_t input, intptr_t output) { Args* args = new Args{nullptr, &mlp, qlen, (const void*)input, (void*)output}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } @@ -111,7 +111,7 @@ class MOEBindings { Args* args_ = (Args*)args; args_->cpuinfer->enqueue(&MOE::warm_up, args_->moe); } - static std::pair interface(MOE& moe) { + static std::pair cpuinfer_interface(MOE& moe) { Args* args = new Args{nullptr, &moe}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } @@ -132,7 +132,7 @@ class MOEBindings { Args* args_ = (Args*)args; args_->cpuinfer->enqueue(&MOE::forward, args_->moe, args_->qlen, args_->k, args_->expert_ids, args_->weights, args_->input, args_->output); } - static std::pair interface(MOE& moe, int qlen, int k, intptr_t expert_ids, intptr_t weights, intptr_t input, intptr_t output) { + static std::pair cpuinfer_interface(MOE& moe, int qlen, int k, intptr_t expert_ids, intptr_t weights, intptr_t input, intptr_t output) { Args* args = new Args{nullptr, &moe, qlen, k, (const uint64_t*)expert_ids, (const float*)weights, (const void*)input, (void*)output}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } @@ -154,8 +154,8 @@ PYBIND11_MODULE(cpuinfer_ext, m) { })); py::class_(linear_module, "Linear") .def(py::init()) - .def("warm_up", &LinearBindings::WarmUpBindinds::interface) - .def("forward", &LinearBindings::ForwardBindings::interface); + .def("warm_up", &LinearBindings::WarmUpBindinds::cpuinfer_interface) + .def("forward", &LinearBindings::ForwardBindings::cpuinfer_interface); auto mlp_module = m.def_submodule("mlp"); py::class_(mlp_module, "MLPConfig") @@ -164,8 +164,8 @@ PYBIND11_MODULE(cpuinfer_ext, m) { })); py::class_(mlp_module, "MLP") .def(py::init()) - .def("warm_up", &MLPBindings::WarmUpBindinds::interface) - .def("forward", &MLPBindings::ForwardBindings::interface); + .def("warm_up", &MLPBindings::WarmUpBindinds::cpuinfer_interface) + .def("forward", &MLPBindings::ForwardBindings::cpuinfer_interface); auto moe_module = m.def_submodule("moe"); py::class_(moe_module, "MOEConfig") @@ -174,6 +174,6 @@ PYBIND11_MODULE(cpuinfer_ext, m) { })); py::class_(moe_module, "MOE") .def(py::init()) - .def("warm_up", &MOEBindings::WarmUpBindinds::interface) - .def("forward", &MOEBindings::ForwardBindings::interface); + .def("warm_up", &MOEBindings::WarmUpBindinds::cpuinfer_interface) + .def("forward", &MOEBindings::ForwardBindings::cpuinfer_interface); } diff --git a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq.py b/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq.py deleted file mode 100644 index cda3e7a..0000000 --- a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq.py +++ /dev/null @@ -1,206 +0,0 @@ -import math -import os -import time -from logging import getLogger - -import torch -import torch.nn as nn -import transformers - -from .quantizer import Quantizer - - -logger = getLogger(__name__) - -torch.backends.cuda.matmul.allow_tf32 = False -torch.backends.cudnn.allow_tf32 = False - - -class GPTQ: - def __init__(self, layer): - self.layer = layer - self.dev = self.layer.weight.device - W = layer.weight.data.clone() - if isinstance(self.layer, nn.Conv2d): - W = W.flatten(1) - if isinstance(self.layer, 
transformers.pytorch_utils.Conv1D): - W = W.t() - self.rows = W.shape[0] - self.columns = W.shape[1] - self.H = torch.zeros((self.columns, self.columns), device=self.dev) - self.nsamples = 0 - self.quantizer = Quantizer() - - def add_batch(self, inp, out): - if os.environ.get("DEBUG"): - self.inp1 = inp - self.out1 = out - if len(inp.shape) == 2: - inp = inp.unsqueeze(0) - tmp = inp.shape[0] - if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): - if len(inp.shape) == 3: - inp = inp.reshape((-1, inp.shape[-1])) - inp = inp.t() - if isinstance(self.layer, nn.Conv2d): - unfold = nn.Unfold( - self.layer.kernel_size, - dilation=self.layer.dilation, - padding=self.layer.padding, - stride=self.layer.stride, - ) - inp = unfold(inp) - inp = inp.permute([1, 0, 2]) - inp = inp.flatten(1) - self.H *= self.nsamples / (self.nsamples + tmp) - self.nsamples += tmp - # inp = inp.float() - inp = math.sqrt(2 / self.nsamples) * inp.float() - # self.H += 2 / self.nsamples * inp.matmul(inp.t()) - self.H += inp.matmul(inp.t()) - - def fasterquant( - self, - blocksize=128, - percdamp=0.01, - group_size=-1, - actorder=False, - static_groups=False, - ): - W = self.layer.weight.data.clone() - if isinstance(self.layer, nn.Conv2d): - W = W.flatten(1) - if isinstance(self.layer, transformers.Conv1D): - W = W.t() - W = W.float() - - tick = time.time() - - if not self.quantizer.ready(): - self.quantizer.find_params(W, weight=True) - - H = self.H - del self.H - dead = torch.diag(H) == 0 - H[dead, dead] = 1 - W[:, dead] = 0 - - g_idx = [] - scale = [] - zero = [] - now_idx = 1 - - if static_groups: - import copy - - groups = [] - for i in range(0, self.columns, group_size): - quantizer = copy.deepcopy(self.quantizer) - quantizer.find_params(W[:, i : (i + group_size)], weight=True) - scale.append(quantizer.scale) - zero.append(quantizer.zero) - groups.append(quantizer) - - if actorder: - perm = torch.argsort(torch.diag(H), descending=True) - W = W[:, perm] - H = H[perm][:, perm] - invperm = torch.argsort(perm) - - Losses = torch.zeros_like(W) - Q = torch.zeros_like(W) - - damp = percdamp * torch.mean(torch.diag(H)) - diag = torch.arange(self.columns, device=self.dev) - H[diag, diag] += damp - H = torch.linalg.cholesky(H) - H = torch.cholesky_inverse(H) - H = torch.linalg.cholesky(H, upper=True) - Hinv = H - - for i1 in range(0, self.columns, blocksize): - i2 = min(i1 + blocksize, self.columns) - count = i2 - i1 - - W1 = W[:, i1:i2].clone() - Q1 = torch.zeros_like(W1) - Err1 = torch.zeros_like(W1) - Losses1 = torch.zeros_like(W1) - Hinv1 = Hinv[i1:i2, i1:i2] - - for i in range(count): - w = W1[:, i] - d = Hinv1[i, i] - - if group_size != -1: - if not static_groups: - if (i1 + i) % group_size == 0: - self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + group_size)], weight=True) - - if ((i1 + i) // group_size) - now_idx == -1: - scale.append(self.quantizer.scale) - zero.append(self.quantizer.zero) - now_idx += 1 - else: - idx = i1 + i - if actorder: - idx = perm[idx] - self.quantizer = groups[idx // group_size] - - q = self.quantizer.quantize(w.unsqueeze(1)).flatten() - Q1[:, i] = q - Losses1[:, i] = (w - q) ** 2 / d**2 - - err1 = (w - q) / d - W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - Err1[:, i] = err1 - - Q[:, i1:i2] = Q1 - Losses[:, i1:i2] = Losses1 / 2 - - W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) - - if os.environ.get("DEBUG"): - self.layer.weight.data[:, :i2] = Q[:, :i2] - self.layer.weight.data[:, i2:] = W[:, i2:] - logger.debug(torch.sum((self.layer(self.inp1) 
- self.out1) ** 2)) - logger.debug(torch.sum(Losses)) - - torch.cuda.synchronize() - logger.info(f"duration: {(time.time() - tick)}") - logger.info(f"avg loss: {torch.sum(Losses).item() / self.nsamples}") - - group_size = group_size if group_size != -1 else self.columns - if static_groups and actorder: - g_idx = [perm[i] // group_size for i in range(self.columns)] - else: - g_idx = [i // group_size for i in range(self.columns)] - g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) - if actorder: - Q = Q[:, invperm] - g_idx = g_idx[invperm] - - if isinstance(self.layer, transformers.Conv1D): - Q = Q.t() - self.layer.weight.data = Q.reshape(self.layer.weight.shape).type_as(self.layer.weight.data) - if os.environ.get("DEBUG"): - logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) - - if scale == []: - scale.append(self.quantizer.scale) - zero.append(self.quantizer.zero) - scale = torch.cat(scale, dim=1) - zero = torch.cat(zero, dim=1) - return scale, zero, g_idx - - def free(self): - if os.environ.get("DEBUG"): - self.inp1 = None - self.out1 = None - self.H = None - self.Losses = None - self.Trace = None - torch.cuda.empty_cache() - - -__all__ = ["GPTQ"] diff --git a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq_marlin.py b/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq_marlin.py deleted file mode 100644 index 599070f..0000000 --- a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq_marlin.py +++ /dev/null @@ -1,458 +0,0 @@ -import enum -from enum import Enum -from typing import Any, Dict, List, Optional - -import torch -from torch.nn.parameter import Parameter - -from vllm import _custom_ops as ops -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - set_weight_attrs) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) - -logger = init_logger(__name__) - -GPTQ_MARLIN_TILE = 16 -GPTQ_MARLIN_MIN_THREAD_N = 64 -GPTQ_MARLIN_MIN_THREAD_K = 128 -GPTQ_MARLIN_MAX_PARALLEL = 16 - -GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8] -GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] -GPTQ_MARLIN_SUPPORTED_SYM = [True] - - -# Permutations for Marlin scale shuffling -def get_scale_perms(num_bits: int): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -def get_pack_factor(num_bits: int): - assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS - ), f"Unsupported num_bits = {num_bits}" - return 32 // num_bits - - -def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, - group_size: int, num_bits: int): - scale_perm, scale_perm_single = get_scale_perms(num_bits) - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - -class GPTQMarlinConfig(QuantizationConfig): - """Config class for GPTQ Marlin""" - - def __init__(self, weight_bits: int, group_size: int, desc_act: bool, - is_sym: bool) -> None: - if desc_act and group_size == -1: - # In this case, act_order == True is the same as act_order == False - # (since we have only one group per output channel) - desc_act = False - - self.weight_bits = weight_bits - 
self.group_size = group_size - self.desc_act = desc_act - self.is_sym = is_sym - - # Verify - if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS: - raise ValueError( - f"Marlin does not support weight_bits = {self.weight_bits}. " - f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} " - "are supported.") - if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: - raise ValueError( - f"Marlin does not support group_size = {self.group_size}. " - f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.") - if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM: - raise ValueError( - f"Marlin does not support is_sym = {self.is_sym}. " - f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.") - - # Init - self.pack_factor = get_pack_factor(weight_bits) - self.tile_size = GPTQ_MARLIN_TILE - self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N - self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K - self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL - - def __repr__(self) -> str: - return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size}, " - f"desc_act={self.desc_act})") - - @classmethod - def get_name(cls) -> str: - return "gptq_marlin" - - @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.half, torch.bfloat16] - - @classmethod - def get_min_capability(cls) -> int: - return 80 - - @classmethod - def get_config_filenames(cls) -> List[str]: - return ["quantize_config.json"] - - @classmethod - def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": - weight_bits = cls.get_from_keys(config, ["bits"]) - group_size = cls.get_from_keys(config, ["group_size"]) - desc_act = cls.get_from_keys(config, ["desc_act"]) - is_sym = cls.get_from_keys(config, ["sym"]) - return cls(weight_bits, group_size, desc_act, is_sym) - - @classmethod - def override_quantization_method(cls, hf_quant_cfg, - user_quant) -> Optional[str]: - can_convert = cls.is_marlin_compatible(hf_quant_cfg) - - is_valid_user_quant = (user_quant is None or user_quant == "marlin") - - if can_convert and is_valid_user_quant: - msg = ("The model is convertible to {} during runtime." - " Using {} kernel.".format(cls.get_name(), cls.get_name())) - logger.info(msg) - return cls.get_name() - - if can_convert and user_quant == "gptq": - logger.info("Detected that the model can run with gptq_marlin" - ", however you specified quantization=gptq explicitly," - " so forcing gptq. Use quantization=gptq_marlin for" - " faster inference") - return None - - def get_quant_method( - self, - layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]: - if isinstance(layer, LinearBase): - return GPTQMarlinLinearMethod(self) - return None - - def get_scaled_act_names(self) -> List[str]: - return [] - - @classmethod - def is_marlin_compatible(cls, quant_config: Dict[str, Any]): - # Extract data from quant config. - num_bits = quant_config.get("bits", None) - group_size = quant_config.get("group_size", None) - sym = quant_config.get("sym", None) - desc_act = quant_config.get("desc_act", None) - - # If we cannot find the info needed in the config, cannot convert. - if (num_bits is None or group_size is None or sym is None - or desc_act is None): - return False - - # If the capability of the device is too low, cannot convert. - major, minor = torch.cuda.get_device_capability() - device_capability = major * 10 + minor - if device_capability < cls.get_min_capability(): - return False - - # Otherwise, can convert if model satisfies marlin constraints. 
- return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS - and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES - and sym in GPTQ_MARLIN_SUPPORTED_SYM) - - -class GPTQMarlinState(Enum): - REPACK = enum.auto() - READY = enum.auto() - - -class GPTQMarlinLinearMethod(LinearMethodBase): - """Linear method for GPTQ Marlin. - - Args: - quant_config: The GPTQ Marlin quantization config. - """ - - def __init__(self, quant_config: GPTQMarlinConfig) -> None: - self.quant_config = quant_config - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ) -> None: - del output_size - - # Normalize group_size - if self.quant_config.group_size != -1: - group_size = self.quant_config.group_size - else: - group_size = input_size - - # Validate dtype - if params_dtype not in [torch.float16, torch.bfloat16]: - raise ValueError(f"The params dtype must be float16 " - f"or bfloat16, but got {params_dtype}") - - # Validate output_size_per_partition - output_size_per_partition = sum(output_partition_sizes) - if output_size_per_partition % self.quant_config.min_thread_n != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {self.quant_config.min_thread_n}.") - - # Validate input_size_per_partition - if input_size_per_partition % self.quant_config.min_thread_k != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {self.quant_config.min_thread_k}.") - - if (group_size < input_size - and input_size_per_partition % group_size != 0): - raise ValueError( - f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}.") - - # Detect sharding of scales/zp - - # By default, no sharding over "input dim" - scales_and_zp_size = input_size // group_size - scales_and_zp_input_dim = None - - if self.quant_config.desc_act: - # Act-order case - assert self.quant_config.group_size != -1 - - is_k_full = input_size_per_partition == input_size - - else: - # No act-order case - - # K is always full due to full alignment with - # group-size and shard of scales/zp - is_k_full = True - - # If this is a row-parallel case, then shard scales/zp - if (input_size != input_size_per_partition - and self.quant_config.group_size != -1): - scales_and_zp_size = input_size_per_partition // group_size - scales_and_zp_input_dim = 0 - - # Init buffers - - # Quantized weights - qweight = Parameter( - torch.empty( - input_size_per_partition // self.quant_config.pack_factor, - output_size_per_partition, - dtype=torch.int32, - ), - requires_grad=False, - ) - set_weight_attrs( - qweight, - { - **extra_weight_attrs, - "input_dim": 0, - "output_dim": 1, - "packed_dim": 0, - "pack_factor": self.quant_config.pack_factor, - }, - ) - - # Activation order - g_idx = Parameter( - torch.empty( - input_size_per_partition, - dtype=torch.int32, - ), - requires_grad=False, - ) - # Ignore warning from fused linear layers such as QKVParallelLinear. 
- set_weight_attrs( - g_idx, - { - **extra_weight_attrs, "input_dim": 0, - "ignore_warning": True - }, - ) - - g_idx_sort_indices = torch.empty( - g_idx.shape, - dtype=torch.int32, - ) - - # Scales - scales = Parameter( - torch.empty( - scales_and_zp_size, - output_size_per_partition, - dtype=params_dtype, - ), - requires_grad=False, - ) - set_weight_attrs( - scales, - { - **extra_weight_attrs, - "input_dim": scales_and_zp_input_dim, - "output_dim": 1, - }, - ) - - # Quantized zero-points - qzeros = Parameter( - torch.empty( - scales_and_zp_size, - output_size_per_partition // self.quant_config.pack_factor, - dtype=torch.int32, - device="meta", - ), - requires_grad=False, - ) - set_weight_attrs( - qzeros, - { - **extra_weight_attrs, - "input_dim": scales_and_zp_input_dim, - "output_dim": 1, - "packed_dim": 1, - "pack_factor": self.quant_config.pack_factor, - }, - ) - - # Allocate marlin workspace - max_workspace_size = ( - output_size_per_partition // - self.quant_config.min_thread_n) * self.quant_config.max_parallel - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - requires_grad=False) - - layer.register_parameter("qweight", qweight) - layer.register_parameter("g_idx", g_idx) - layer.register_parameter("scales", scales) - layer.register_parameter("qzeros", qzeros) - layer.g_idx_sort_indices = g_idx_sort_indices - layer.workspace = workspace - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.input_size = input_size - layer.is_k_full = is_k_full - layer.marlin_state = GPTQMarlinState.REPACK - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - reshaped_x = x.reshape(-1, x.shape[-1]) - - size_m = reshaped_x.shape[0] - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - full_size_k = layer.input_size - - out_shape = x.shape[:-1] + (part_size_n, ) - - if layer.marlin_state == GPTQMarlinState.REPACK: - layer.marlin_state = GPTQMarlinState.READY - - # Newly generated tensors need to replace existing tensors that are - # already registered as parameters by vLLM (and won't be freed) - def replace_tensor(name, new_t): - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - cur_device = layer.qweight.device - - # Process act_order - if self.quant_config.desc_act: - # Get sorting based on g_idx - g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int) - - sorted_g_idx = layer.g_idx[g_idx_sort_indices] - - replace_tensor("g_idx", sorted_g_idx) - replace_tensor("g_idx_sort_indices", g_idx_sort_indices) - - else: - # Reset g_idx related tensors - layer.g_idx = Parameter( - torch.empty(0, dtype=torch.int, device=cur_device), - requires_grad=False, - ) - layer.g_idx_sort_indices = Parameter( - torch.empty(0, dtype=torch.int, device=cur_device), - requires_grad=False, - ) - - # Repack weights - marlin_qweight = ops.gptq_marlin_repack( - layer.qweight, - layer.g_idx_sort_indices, - part_size_k, - part_size_n, - self.quant_config.weight_bits, - ) - replace_tensor("qweight", marlin_qweight) - - # Permute scales - scales_size_k = part_size_k - scales_size_n = part_size_n - if self.quant_config.desc_act: - scales_size_k = full_size_k - - marlin_scales = marlin_permute_scales( - layer.scales, - scales_size_k, - scales_size_n, - self.quant_config.group_size, - 
self.quant_config.weight_bits, - ) - replace_tensor("scales", marlin_scales) - - output = ops.gptq_marlin_gemm( - reshaped_x, - layer.qweight, - layer.scales, - layer.g_idx, - layer.g_idx_sort_indices, - layer.workspace, - self.quant_config.weight_bits, - size_m, - part_size_n, - part_size_k, - layer.is_k_full, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) diff --git a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/quantizer.py b/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/quantizer.py deleted file mode 100644 index e945a70..0000000 --- a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/quantizer.py +++ /dev/null @@ -1,140 +0,0 @@ -from logging import getLogger - -import torch -import torch.nn as nn - - -logger = getLogger(__name__) - - -def quantize(x, scale, zero, maxq): - if maxq < 0: - return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero - q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) - return scale * (q - zero) - - -class Quantizer(nn.Module): - def __init__(self, shape=1): - super(Quantizer, self).__init__() - self.register_buffer("maxq", torch.tensor(0)) - self.register_buffer("scale", torch.zeros(shape)) - self.register_buffer("zero", torch.zeros(shape)) - - def configure( - self, - bits, - perchannel=False, - sym=True, - mse=False, - norm=2.4, - grid=100, - maxshrink=0.8, - trits=False, - ): - self.maxq = torch.tensor(2**bits - 1) - self.perchannel = perchannel - self.sym = sym - self.mse = mse - self.norm = norm - self.grid = grid - self.maxshrink = maxshrink - if trits: - self.maxq = torch.tensor(-1) - - def find_params(self, x, weight=False): - dev = x.device - self.maxq = self.maxq.to(dev) - - shape = x.shape - if self.perchannel: - if weight: - x = x.flatten(1) - else: - if len(shape) == 4: - x = x.permute([1, 0, 2, 3]) - x = x.flatten(1) - if len(shape) == 3: - x = x.reshape((-1, shape[-1])).t() - if len(shape) == 2: - x = x.t() - else: - x = x.flatten().unsqueeze(0) - - tmp = torch.zeros(x.shape[0], device=dev) - xmin = torch.minimum(x.min(1)[0], tmp) - xmax = torch.maximum(x.max(1)[0], tmp) - - if self.sym: - xmax = torch.maximum(torch.abs(xmin), xmax) - tmp = xmin < 0 - if torch.any(tmp): - xmin[tmp] = -xmax[tmp] - tmp = (xmin == 0) & (xmax == 0) - xmin[tmp] = -1 - xmax[tmp] = +1 - - if self.maxq < 0: - self.scale = xmax - self.zero = xmin - else: - self.scale = (xmax - xmin) / self.maxq - if self.sym: - self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) - else: - self.zero = torch.round(-xmin / self.scale) - - if self.mse: - best = torch.full([x.shape[0]], float("inf"), device=dev) - for i in range(int(self.maxshrink * self.grid)): - p = 1 - i / self.grid - xmin1 = p * xmin - xmax1 = p * xmax - scale1 = (xmax1 - xmin1) / self.maxq - zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero - q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) - q -= x - q.abs_() - q.pow_(self.norm) - err = torch.sum(q, 1) - tmp = err < best - if torch.any(tmp): - best[tmp] = err[tmp] - self.scale[tmp] = scale1[tmp] - self.zero[tmp] = zero1[tmp] - if not self.perchannel: - if weight: - tmp = shape[0] - else: - tmp = shape[1] if len(shape) != 3 else shape[2] - self.scale = self.scale.repeat(tmp) - self.zero = self.zero.repeat(tmp) - - if weight: - shape = [-1] + [1] * (len(shape) - 1) - self.scale = self.scale.reshape(shape) - self.zero = self.zero.reshape(shape) - return - if len(shape) == 4: - self.scale = 
self.scale.reshape((1, -1, 1, 1)) - self.zero = self.zero.reshape((1, -1, 1, 1)) - if len(shape) == 3: - self.scale = self.scale.reshape((1, 1, -1)) - self.zero = self.zero.reshape((1, 1, -1)) - if len(shape) == 2: - self.scale = self.scale.unsqueeze(0) - self.zero = self.zero.unsqueeze(0) - - def quantize(self, x): - if self.ready(): - return quantize(x, self.scale, self.zero, self.maxq) - return x - - def enabled(self): - return self.maxq > 0 - - def ready(self): - return torch.all(self.scale != 0) - - -__all__ = ["Quantizer"] diff --git a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/repack.py b/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/repack.py deleted file mode 100644 index 987f05b..0000000 --- a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/repack.py +++ /dev/null @@ -1,99 +0,0 @@ -import torch -import enum -from enum import Enum -from typing import Any, Dict, List, Optional -from torch.nn.parameter import Parameter - -def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - reshaped_x = x.reshape(-1, x.shape[-1]) - - size_m = reshaped_x.shape[0] - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - full_size_k = layer.input_size - - out_shape = x.shape[:-1] + (part_size_n, ) - - if layer.marlin_state == GPTQMarlinState.REPACK: - layer.marlin_state = GPTQMarlinState.READY - - # Newly generated tensors need to replace existing tensors that are - # already registered as parameters by vLLM (and won't be freed) - def replace_tensor(name, new_t): - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - cur_device = layer.qweight.device - - # Process act_order - if self.quant_config.desc_act: - # Get sorting based on g_idx - g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int) - - sorted_g_idx = layer.g_idx[g_idx_sort_indices] - - replace_tensor("g_idx", sorted_g_idx) - replace_tensor("g_idx_sort_indices", g_idx_sort_indices) - - else: - # Reset g_idx related tensors - layer.g_idx = Parameter( - torch.empty(0, dtype=torch.int, device=cur_device), - requires_grad=False, - ) - layer.g_idx_sort_indices = Parameter( - torch.empty(0, dtype=torch.int, device=cur_device), - requires_grad=False, - ) - - # Repack weights - marlin_qweight = ops.gptq_marlin_repack( - layer.qweight, - layer.g_idx_sort_indices, - part_size_k, - part_size_n, - self.quant_config.weight_bits, - ) - replace_tensor("qweight", marlin_qweight) - - # Permute scales - scales_size_k = part_size_k - scales_size_n = part_size_n - if self.quant_config.desc_act: - scales_size_k = full_size_k - - marlin_scales = marlin_permute_scales( - layer.scales, - scales_size_k, - scales_size_n, - self.quant_config.group_size, - self.quant_config.weight_bits, - ) - replace_tensor("scales", marlin_scales) - - output = ops.gptq_marlin_gemm( - reshaped_x, - layer.qweight, - layer.scales, - layer.g_idx, - layer.g_idx_sort_indices, - layer.workspace, - self.quant_config.weight_bits, - size_m, - part_size_n, - part_size_k, - layer.is_k_full, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) diff --git a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py b/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py index 
7b0398f..accbc00 100644 --- a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py +++ b/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py @@ -220,7 +220,7 @@ def compute_max_diff(output, output_ref): class MarlinWorkspace: - def __init__(self, out_features, min_thread_n, max_parallel): + def __init__(self, out_features, min_thread_n, max_parallel, device): assert (out_features % min_thread_n == 0), ( "out_features = {} is undivisible by min_thread_n = {}".format( out_features, min_thread_n)) @@ -229,4 +229,4 @@ class MarlinWorkspace: self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, - device="cuda") + device=device) diff --git a/ktransformers/ktransformers_ext/operators/llamafile/linear.cpp b/ktransformers/ktransformers_ext/operators/llamafile/linear.cpp index 7dcba57..81e5006 100644 --- a/ktransformers/ktransformers_ext/operators/llamafile/linear.cpp +++ b/ktransformers/ktransformers_ext/operators/llamafile/linear.cpp @@ -47,13 +47,13 @@ void Linear::forward_many(int qlen, const void* input, void* output, Backend* ba int nth = config_.output_size / config_.stride; backend->do_work_stealing_job(nth, [&](int task_id) { int ith = task_id; - void* proj_ptr = proj_ + ith * config_.stride * config_.input_size * ggml_type_size(config_.proj_type) / ggml_blck_size(config_.proj_type); + void* proj_ptr = (uint8_t*)proj_ + ith * config_.stride * config_.input_size * ggml_type_size(config_.proj_type) / ggml_blck_size(config_.proj_type); float* proj_output_ptr = proj_output_ + ith * config_.stride; llamafile_sgemm(config_.stride, qlen, config_.input_size / ggml_blck_size(config_.proj_type), proj_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_input_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_output_ptr, config_.output_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.proj_type, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) { for (int i = 0; i < qlen; i++) { float* output_fp32_ptr = proj_output_ + i * config_.output_size + ith * config_.stride; - void* output_ptr = output + i * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); + void* output_ptr = (uint8_t*)output + i * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type); } } @@ -69,5 +69,5 @@ void Linear::forward(int qlen, const void* input, void* output, Backend* backend } int forward_len = std::min(qlen, config_.group_max_len); forward_many(forward_len, input, output, backend); - forward(qlen - forward_len, input + forward_len * config_.input_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), output + forward_len * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend); + forward(qlen - forward_len, (uint8_t*)input + forward_len * config_.input_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend); } \ No newline at 
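The (uint8_t*) casts added to linear.cpp here (and to mlp.cpp and moe.cpp below) replace pointer arithmetic on void*, which GCC accepts as an extension but MSVC rejects. A short sketch of the byte-offset pattern, with illustrative names:

#include <cstddef>
#include <cstdint>

// Equivalent to the non-portable `base + row * row_bytes` on a void*:
// arithmetic on void* is a GNU extension, so the byte offset is taken
// through uint8_t*, which MSVC also accepts.
inline void* byte_offset(void* base, std::size_t row, std::size_t row_bytes) {
    return static_cast<std::uint8_t*>(base) + row * row_bytes;
}

int main() {
    float buffer[8] = {};
    // start of "row" 1 when each row is 4 floats (16 bytes) wide:
    void* row1 = byte_offset(buffer, 1, 4 * sizeof(float));
    return row1 == static_cast<void*>(buffer + 4) ? 0 : 1;
}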
end of file diff --git a/ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp b/ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp index 8ef092f..abad01e 100644 --- a/ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp +++ b/ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp @@ -74,10 +74,10 @@ void MLP::forward_many(int qlen, const void* input, void* output, Backend* backe int nth = config_.intermediate_size / config_.stride; backend->do_work_stealing_job(nth, [&](int task_id) { int ith = task_id; - void* gate_proj_ptr = gate_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); + void* gate_proj_ptr = (uint8_t*)gate_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); float* gate_output_ptr = gate_output_ + ith * config_.stride; llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); - void* up_proj_ptr = up_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); + void* up_proj_ptr = (uint8_t*)up_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); float* up_output_ptr = up_output_ + ith * config_.stride; llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); for (int i = 0; i < qlen; i++) { @@ -86,7 +86,7 @@ void MLP::forward_many(int qlen, const void* input, void* output, Backend* backe } if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) { float* intermediate_fp32_ptr = intermediate_fp32_ + i * config_.intermediate_size + ith * config_.stride; - void* down_input_ptr = down_input_ + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type); + void* down_input_ptr = (uint8_t*)down_input_ + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type); from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type); } } @@ -97,13 +97,13 @@ void MLP::forward_many(int qlen, const void* input, void* output, Backend* backe nth = 
config_.hidden_size / config_.stride; backend->do_work_stealing_job(nth, [&](int task_id) { int ith = task_id; - void* down_proj_ptr = down_proj_ + ith * config_.stride * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); + void* down_proj_ptr = (uint8_t*)down_proj_ + ith * config_.stride * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); float* down_output_ptr = down_output_ + ith * config_.stride; llamafile_sgemm(config_.stride, qlen, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) { for (int i = 0; i < qlen; i++) { float* output_fp32_ptr = down_output_ + i * config_.hidden_size + ith * config_.stride; - void* output_ptr = output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); + void* output_ptr = (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type); } } @@ -119,5 +119,5 @@ void MLP::forward(int qlen, const void* input, void* output, Backend* backend) { } int forward_len = std::min(qlen, config_.group_max_len); forward_many(forward_len, input, output, backend); - forward(qlen - forward_len, input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend); + forward(qlen - forward_len, (uint8_t*)input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend); } \ No newline at end of file diff --git a/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp b/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp index 8010f54..d75db65 100644 --- a/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp +++ b/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp @@ -9,7 +9,7 @@ **/ #include "moe.h" #include -#include "unistd.h" +#include MOE::MOE(MOEConfig config) { config_ = config; @@ -60,7 +60,7 @@ MOE::MOE(MOEConfig config) { m_local_pos_.resize(config_.group_max_len); for (int i = 0; i < config_.group_max_len; i++) { - m_local_pos_[i].reserve(config_.expert_num); + m_local_pos_[i].resize(config_.routed_expert_num); } m_local_num_.resize(config_.expert_num); m_local_gate_input_ptr_.resize(config_.expert_num); @@ -125,10 +125,10 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c int expert_idx = task_id / nth; uint64_t expert_id = expert_ids[expert_idx]; int ith = task_id % nth; - void* gate_proj_ptr = gate_proj_ + (expert_id * config_.intermediate_size + ith * 
config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); + void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); float* gate_output_ptr = s_gate_output_[expert_idx] + ith * config_.stride; llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); - void* up_proj_ptr = up_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); + void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride; llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) { @@ -153,7 +153,7 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c } for (int expert_idx = 0; expert_idx < k; expert_idx++) { uint64_t expert_id = expert_ids[expert_idx]; - void* down_proj_ptr = down_proj_ + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); + void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); float* down_output_ptr = s_down_output_[expert_idx] + ith * config_.stride; llamafile_sgemm(config_.stride, 1, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), s_down_input_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) { @@ -162,7 +162,7 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c } if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) { float* output_fp32_ptr = s_output_fp32_ + ith * config_.stride; - void* output_ptr = output + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); + void* output_ptr = (uint8_t*)output + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type); } }); @@ -195,9 +195,9 @@ void 
MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* const void* gate_input_ptr; const void* up_input_ptr; if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) { - gate_input_ptr = up_input_ptr = input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); + gate_input_ptr = up_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); } else { - to_float(input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), m_input_fp32_[i], config_.hidden_size, config_.hidden_type); + to_float((uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), m_input_fp32_[i], config_.hidden_size, config_.hidden_type); if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) { from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type); gate_input_ptr = up_input_ptr = m_gate_input_[i]; @@ -206,13 +206,13 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type); gate_input_ptr = m_gate_input_[i]; } else { - gate_input_ptr = input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); + gate_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); } if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) { from_float(m_input_fp32_[i], m_up_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type); up_input_ptr = m_up_input_[i]; } else { - up_input_ptr = input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); + up_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); } } } @@ -227,11 +227,11 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* int expert_idx = task_id / nth; int ith = task_id % nth; void* gate_input_ptr = m_local_gate_input_ptr_[expert_idx]; - void* gate_proj_ptr = gate_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); + void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); float* gate_output_ptr = m_local_gate_output_ptr_[expert_idx] + ith * stride; llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); void* up_input_ptr = m_local_up_input_ptr_[expert_idx]; 
- void* up_proj_ptr = up_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); + void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride; llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); for (int i = 0; i < m_local_num_[expert_idx]; i++) { @@ -249,7 +249,7 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* int expert_idx = task_id / nth; int ith = task_id % nth; void* down_input_ptr = m_local_down_input_ptr_[expert_idx]; - void* down_proj_ptr = down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); + void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + ith * stride; llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); }); @@ -262,18 +262,18 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* m_output_fp32_[i][e] += m_local_down_output_ptr_[expert_ids[i * k + j]][m_local_pos_[i][j] * config_.hidden_size + e] * weights[i * k + j]; } } - from_float(m_output_fp32_[i], output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), config_.hidden_size, config_.hidden_type); + from_float(m_output_fp32_[i], (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), config_.hidden_size, config_.hidden_type); }); } void MOE::forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) { if (qlen < config_.group_min_len) { for (int i = 0; i < qlen; i++) { - forward_one(k, expert_ids + i * k, weights + i * k, input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend); + forward_one(k, expert_ids + i * k, weights + i * k, (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend); } return; } int forward_len = std::min(config_.group_max_len, 
qlen); forward_many(forward_len, k, expert_ids, weights, input, output, backend); - forward(qlen - forward_len, k, expert_ids + forward_len * k, weights + forward_len * k, input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend); + forward(qlen - forward_len, k, expert_ids + forward_len * k, weights + forward_len * k, (uint8_t*)input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend); } \ No newline at end of file diff --git a/ktransformers/ktransformers_ext/operators/llamafile/shared_mem_buffer.cpp b/ktransformers/ktransformers_ext/operators/llamafile/shared_mem_buffer.cpp index b1599da..dc2d65d 100644 --- a/ktransformers/ktransformers_ext/operators/llamafile/shared_mem_buffer.cpp +++ b/ktransformers/ktransformers_ext/operators/llamafile/shared_mem_buffer.cpp @@ -49,7 +49,7 @@ void SharedMemBuffer::dealloc(void* object) { void SharedMemBuffer::arrange(std::vector> requests) { uint64_t offset = 0; for (auto& request : requests) { - *(request.first) = buffer_ + offset; + *(request.first) = (uint8_t*)buffer_ + offset; offset += request.second; } } diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py old mode 100644 new mode 100755 index 59839be..b5782d1 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -31,18 +31,21 @@ import fire from ktransformers.optimize.optimize import optimize_and_load_gguf from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM +from ktransformers.models.modeling_mixtral import MixtralForCausalLM from ktransformers.util.utils import prefill_and_generate from ktransformers.server.config.config import Config custom_models = { "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM, "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM, + "MixtralForCausalLM": MixtralForCausalLM, } ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" default_optimize_rules ={ "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml", + "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml", } def local_chat( @@ -50,7 +53,8 @@ def local_chat( optimize_rule_path: str = None, gguf_path: str = None, max_new_tokens: int = 1000, - cpu_infer: int = Config().cpu_infer + cpu_infer: int = Config().cpu_infer, + use_cuda_graph: bool = True, ): torch.set_grad_enabled(False) @@ -64,6 +68,8 @@ def local_chat( print("using custom modeling_xxx.py.") if "Qwen2Moe" in config.architectures[0]: # Qwen2Moe must use flash_attention_2 to avoid overflow. config._attn_implementation = "flash_attention_2" + if "Mixtral" in config.architectures[0]: + config._attn_implementation = "flash_attention_2" model = custom_models[config.architectures[0]](config) else: model = AutoModelForCausalLM.from_config( @@ -100,7 +106,6 @@ def local_chat( while True: content = input("Chat: ") - # if content is num if content == "": content = "Please write a piece of quicksort code in C++." 
@@ -109,7 +114,7 @@ def local_chat( messages, add_generation_prompt=True, return_tensors="pt" ) torch.set_default_dtype(torch.bfloat16) # TODO: Remove this, replace dtype using config - generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens) + generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph) if __name__ == "__main__": - fire.Fire(local_chat) + fire.Fire(local_chat) \ No newline at end of file diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py index 385e6ec..dbaea57 100644 --- a/ktransformers/models/custom_cache.py +++ b/ktransformers/models/custom_cache.py @@ -22,13 +22,14 @@ class StaticCache(transformers.StaticCache): The maximum batch size with which the model will be used. max_cache_len (`int`): The maximum sequence length with which the model will be used. - device (`torch.device`): + device (`torch.device` or `dict`): The device on which the cache should be initialized. Should be the same as the layer. + If a `dict`, it should map each layer key (e.g. `blk.{idx}.self_attn`) to a per-layer config whose `generate_device` entry names the device for that layer's cache. dtype (*optional*, defaults to `torch.float32`): The default `dtype` to use when initializing the layer. """ - def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: + def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device: torch.device | dict, dtype=None) -> None: Cache.__init__(self) self.max_batch_size = max_batch_size self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len @@ -46,6 +47,7 @@ class StaticCache(transformers.StaticCache): self.value_cache: List[torch.Tensor] = [] cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) if config.architectures[0] == "DeepseekV2ForCausalLM": + # TODO: for deepseek, cache_shape differs depending on whether Absorbed MLA is used; check this automatically # key_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.qk_rope_head_dim + config.qk_nope_head_dim) # value_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.v_head_dim) key_shape = (max_batch_size, 1, self.max_cache_len, config.qk_rope_head_dim) @@ -56,11 +58,15 @@ class StaticCache(transformers.StaticCache): self.past_tokens = [] self.num_hidden_layers = config.num_hidden_layers - for _ in range(self.num_hidden_layers): + for idx in range(self.num_hidden_layers): # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # breaks when updating the cache. 
- new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=device) - new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=device) + if isinstance(device, dict): + target_device = device[f"blk.{idx}.self_attn"]["generate_device"] + else: + target_device = device + new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device) + new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device) torch._dynamo.mark_static_address(new_layer_key_cache) torch._dynamo.mark_static_address(new_layer_value_cache) self.key_cache.append(new_layer_key_cache) diff --git a/ktransformers/models/modeling_deepseek.py b/ktransformers/models/modeling_deepseek.py index 81fee86..692020d 100644 --- a/ktransformers/models/modeling_deepseek.py +++ b/ktransformers/models/modeling_deepseek.py @@ -1048,7 +1048,7 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token first unpad the input, then computes the attention scores and pad the final attention scores. - Args: + # Args: query_states (`torch.Tensor`): Input query states to be passed to Flash Attention API key_states (`torch.Tensor`): @@ -1245,12 +1245,14 @@ class DeepseekV2DecoderLayer(nn.Module): cache_position=cache_position, **kwargs, ) + hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states outputs = (hidden_states,) diff --git a/ktransformers/models/modeling_mixtral.py b/ktransformers/models/modeling_mixtral.py new file mode 100644 index 0000000..87d8cf1 --- /dev/null +++ b/ktransformers/models/modeling_mixtral.py @@ -0,0 +1,1735 @@ +# coding=utf-8 +''' +Description : +Author : kkk1nak0 +Date : 2024-07-29 02:58:57 +Version : 1.0.0 +LastEditors : kkk1nak0 +LastEditTime : 2024-08-02 06:08:34 +''' + +# Adapted from +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Mixtral model.""" + +import inspect +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_causal_attention_mask, +) +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13 +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available +from transformers.models.mixtral.configuration_mixtral import MixtralConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func, flash_attn_func, flash_attn_with_kvcache + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MixtralConfig" + + +def load_balancing_loss_func( + gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None +) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + attention_mask (`torch.Tensor`, None): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + num_experts (`int`, *optional*): + Number of experts + + Returns: + The auxiliary loss. 
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral +class MixtralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MixtralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache +class MixtralRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, 
base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +# TODO @longjie no longer copied from Mistral after static cache +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache +class MixtralAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = MixtralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache +class MixtralFlashAttention2(MixtralAttention): + """ + Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. 
If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, position_ids) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and self.config.use_sliding_window + ) + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." + ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # we slice the states for static kv cache to be supported in FA2. Not sure it's a must as compile fails + # for bsz == 1, avoid using slice to capture cuda graph + if cache_position is not None and q_len > 1: + key_states = key_states[:, :, : cache_position[-1] + 1, :] + value_states = value_states[:, :, : cache_position[-1] + 1, :] + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reshape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self.config, "sliding_window", None), + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids, + dropout, + sliding_window, + is_causal, + softmax_scale=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + + """ + + # Decide whether to use SWA or not by layer index. 
+ # if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: + # use_sliding_windows = False + use_sliding_windows = False + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, q_len + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=is_causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=is_causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len) + else: + if not use_sliding_windows: + if q_len == 1: + position_ids = position_ids.to(dtype=torch.int32).squeeze(1) + attn_output = flash_attn_with_kvcache( + query_states, + key_states, + value_states, + cache_seqlens=position_ids, + softmax_scale=softmax_scale, + causal=is_causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=is_causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=is_causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + return attn_output + + # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + + +# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache +class MixtralSdpaAttention(MixtralAttention): + """ + Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MixtralAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. 
+ if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +MIXTRAL_ATTENTION_CLASSES = { + "eager": MixtralAttention, + "flash_attention_2": MixtralFlashAttention2, + "sdpa": MixtralSdpaAttention, +} + + +class MixtralBlockSparseTop2MLP(nn.Module): + def __init__(self, config: MixtralConfig): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # gate + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) # down + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # up + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class MixtralSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accomodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. 
+ """ + + def __init__(self, config): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + # gating + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + + self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) + + # Jitter parameters + self.jitter_noise = config.router_jitter_noise + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + if self.training and self.jitter_noise > 0: + hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class MixtralDecoderLayer(nn.Module): + def __init__(self, config: MixtralConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.block_sparse_moe = MixtralSparseMoeBlock(config) + self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, router_logits = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +MIXTRAL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MixtralConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", + MIXTRAL_START_DOCSTRING, +) +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral +class MixtralPreTrainedModel(PreTrainedModel): + config_class = MixtralConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MixtralDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MIXTRAL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. 
+ + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", + MIXTRAL_START_DOCSTRING, +) +# copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral +# TODO @longjie no longer copied from Mistral after static cache +class MixtralModel(MixtralPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
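The legacy `past_key_values` layout described above is a per-layer tuple of `(key, value)` tensors of shape `(batch, num_heads, seq_len, head_dim)`; the model code further down converts it into a `Cache` object. A small round-trip sketch with toy shapes, assuming the standard `transformers` package:

```python
import torch
from transformers import DynamicCache

batch, num_heads, seq_len, head_dim, num_layers = 1, 2, 3, 4, 2
legacy = tuple(
    (torch.zeros(batch, num_heads, seq_len, head_dim),
     torch.zeros(batch, num_heads, seq_len, head_dim))
    for _ in range(num_layers)
)

cache = DynamicCache.from_legacy_cache(legacy)   # tuple-of-tuples -> Cache object
print(cache.get_seq_length())                    # 3
roundtrip = cache.to_legacy_cache()              # and back again
print(len(roundtrip), roundtrip[0][0].shape)     # 2 torch.Size([1, 2, 3, 4])
```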
Each layer is a [`MixtralDecoderLayer`] + + Args: + config: MixtralConfig + """ + + def __init__(self, config: MixtralConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Ignore copy + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + use_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache) and not self.training: + use_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
" + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. 
See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
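The additive causal mask assembled above is easy to inspect in isolation: key positions a query may not attend to are filled with the dtype minimum, everything else stays zero, and the padding mask is merged in afterwards. A standalone sketch with toy sizes:

```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length = 4, 6                   # 4 new tokens, 6 cache slots in total
cache_position = torch.arange(2, 2 + sequence_length)   # the new tokens sit at positions 2..5

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
# row i now blocks every key position greater than cache_position[i]
print((causal_mask == min_dtype).int())
# tensor([[0, 0, 0, 1, 1, 1],
#         [0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0, 0]], dtype=torch.int32)
```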
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class MixtralForCausalLM(MixtralPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = MixtralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + # Ignore copy + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + output_router_logits=False, + position_ids=None, + use_cache=True, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = 
{"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + } + ) + return model_inputs + + +@add_start_docstrings( + """ + The Mixtral Model transformer with a sequence classification head on top (linear layer). + + [`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + MIXTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL +class MixtralForSequenceClassification(MixtralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MixtralModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Mixtral Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + MIXTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL +class MixtralForTokenClassification(MixtralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MixtralModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/ktransformers/operators/RoPE.py b/ktransformers/operators/RoPE.py index 5fcce4f..9dc233b 100644 --- a/ktransformers/operators/RoPE.py +++ b/ktransformers/operators/RoPE.py @@ -10,6 +10,7 @@ from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_gguf import GGUFLoader from ktransformers.util.utils import InferenceState from transformers.configuration_utils import PretrainedConfig + # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding): def __init__(self, @@ -17,12 +18,16 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding): gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, - device: str = "cuda", + # device: str = "cuda", + generate_device: str = "cuda", + prefill_device: str = "cuda", **kwargs): - BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) self.orig_module.__init__(orig_module.dim, orig_module.max_position_embeddings, orig_module.base) + self.generate_device = generate_device + self.prefill_device = prefill_device def load(self): self.orig_module.__init__(self.orig_module.dim, @@ -36,9 +41,11 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding): gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, - device: str = "cuda", + # device: str = "cuda", + generate_device: str = "cuda", + prefill_device: str = "cuda", **kwargs): - BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) self.orig_module.__init__(orig_module.dim, orig_module.max_position_embeddings, orig_module.base, @@ -49,13 +56,15 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding): orig_module.beta_slow, orig_module.mscale, orig_module.mscale_all_dim) + self.generate_device = generate_device + self.prefill_device = prefill_device def load(self): self.orig_module.__init__(self.orig_module.dim, self.orig_module.max_position_embeddings, self.orig_module.base, - self.device, + self.generate_device, self.orig_module.scaling_factor, self.orig_module.original_max_position_embeddings, self.orig_module.beta_fast, diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index 0369f5f..7028c74 100644 --- 
a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -5,8 +5,8 @@ Description : Author : Azure-Tang, Boxin Zhang, chenht2022 Date : 2024-07-25 11:25:24 Version : 0.1.0 -LastEditors : Azure -LastEditTime : 2024-07-26 09:27:41 +LastEditors : kkk1nak0 +LastEditTime : 2024-08-11 12:14:39 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' @@ -19,7 +19,9 @@ import torch import sys, os from ktransformers.operators.base_operator import BaseInjectedModule -sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/build") +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build")) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release")) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug")) import cpuinfer_ext from cpuinfer_ext.moe import MOEConfig, MOE import ctypes @@ -78,6 +80,25 @@ class MLPExpertsBase(ABC): gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"] up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"] down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"] + elif key + ".ffn_down.0.weight" in self.gguf_loader.tensor_info: + # for supporting Mixtral-8x7B-Instuct + gate = [] + up = [] + down = [] + for i in range(8): + gatei, upi, downi = f".ffn_gate.{i}.weight", f".ffn_up.{i}.weight", f".ffn_down.{i}.weight" + targets = [gatei, upi, downi] + tensors = self.load_multi(key, targets, device=device) + gate_it, up_it, down_it = tensors[gatei], tensors[upi], tensors[downi] + gate.append(gate_it) + up.append(up_it) + down.append(down_it) + gate = torch.stack(gate) + up = torch.stack(up) + down = torch.stack(down) + gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate.0.weight"]["ggml_type"] + up_type = self.gguf_loader.tensor_info[key + ".ffn_up.0.weight"]["ggml_type"] + down_type = self.gguf_loader.tensor_info[key + ".ffn_down.0.weight"]["ggml_type"] else: raise ValueError(f"Experts {key} not found in gguf_loader") res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}} @@ -94,7 +115,8 @@ class MLPCPUExperts(MLPExpertsBase): expert_ids_cpu:Tensor = None weights_cpu:Tensor = None output_cpu:Tensor = None - output_gpu:Tensor = None + output_gpu_map:dict = {} # Manage output tensor buffer on different gpu + #stream_map:dict = {} # Manage cuda stream on different gpu CPU_INFER = cpuinfer_ext.CPUInfer(Config().cpu_infer) def __init__( self, @@ -113,81 +135,83 @@ class MLPCPUExperts(MLPExpertsBase): self.out_device = out_device def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = None, warmup:bool = False): - if device: - assert device.lower() == "cpu", "MLPCPUExperts can only be loaded on CPU, Parameter \"device\" can be cpu or None." 
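The Mixtral branch added to `load_weights` above reads eight per-expert GGUF tensors (`ffn_gate.{i}.weight`, `ffn_up.{i}.weight`, `ffn_down.{i}.weight`) and stacks them into the fused per-layer layout the rest of the loader expects. A schematic sketch with a stand-in loader; `FakeLoader`, the key and the shapes are made up for illustration:

```python
import numpy as np

class FakeLoader:
    """Stand-in for GGUFLoader.get_mmap_tensor, used only for this sketch."""
    def get_mmap_tensor(self, name: str) -> np.ndarray:
        return np.zeros((32, 16), dtype=np.float16)   # toy shape for one expert's weight

loader = FakeLoader()
key, num_experts = "blk.0", 8

gate = np.stack([loader.get_mmap_tensor(f"{key}.ffn_gate.{i}.weight") for i in range(num_experts)])
up   = np.stack([loader.get_mmap_tensor(f"{key}.ffn_up.{i}.weight")   for i in range(num_experts)])
down = np.stack([loader.get_mmap_tensor(f"{key}.ffn_down.{i}.weight") for i in range(num_experts)])
print(gate.shape)  # (8, 32, 16): one fused array covering all experts of the layer
```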
- if w is None: w = self.load_weights()[self.key] - self.gate = w["gate"] - self.up = w["up"] - self.down = w["down"] - self.gate_type = w["gate_type"] - self.up_type = w["up_type"] - self.down_type = w["down_type"] - gate_ptr = ctypes.addressof( - ctypes.cast(self.gate.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents - ) - up_ptr = ctypes.addressof( - ctypes.cast(self.up.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents - ) - down_ptr = ctypes.addressof( - ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents - ) - # print(self.gate_qtype, self.up_qtype, self.down_qtype) - n_routed_experts = self.n_routed_experts - # n_routed_experts = len(self.orig_module) - moe_config = MOEConfig( - n_routed_experts, - self.config.num_experts_per_tok, - self.config.hidden_size, - self.config.moe_intermediate_size, - 64, - 10, - 1024, - gate_ptr, - up_ptr, - down_ptr, - self.gate_type, - self.up_type, - self.down_type, - 30, # TODO: get from model.dtype - ) - # print(n_routed_experts, hidden_size, moe_intermediate_size) - num_experts_per_tok = self.config.num_experts_per_tok - self.moe = MOE(moe_config) - self.cpu_infer = MLPCPUExperts.CPU_INFER - if warmup: - self.cpu_infer.submit(self.moe.warm_up()) - self.cpu_infer.sync() - if MLPCPUExperts.output_gpu == None: - MLPCPUExperts.input_tensor_cpu = torch.empty((self.config.hidden_size), device="cpu", pin_memory=True) - MLPCPUExperts.expert_ids_cpu = torch.empty((num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) - MLPCPUExperts.weights_cpu = torch.empty((num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) - MLPCPUExperts.output_cpu = torch.empty((self.config.hidden_size), device="cpu", pin_memory=True) - MLPCPUExperts.output_gpu = torch.empty((self.config.hidden_size), device=self.out_device) - + with torch.device(self.out_device): + if device: + assert device.lower() == "cpu", "MLPCPUExperts can only be loaded on CPU, Parameter \"device\" can be cpu or None." 
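The `load()` body being rewritten here hands raw data pointers of the numpy-backed expert weights to the C++ MoE backend. The `ctypes` construct it uses is just another way of reading the buffer address; a small check with a throwaway array:

```python
import ctypes
import numpy as np
import torch

arr = np.arange(8, dtype=np.float32)

# same construct as in load(): cast the buffer address to a pointer and take its address back
addr_via_cast = ctypes.addressof(
    ctypes.cast(arr.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
)
print(addr_via_cast == arr.ctypes.data)   # True: both are the raw buffer address

t = torch.from_numpy(arr)
print(t.data_ptr() == arr.ctypes.data)    # True: the torch view shares the same buffer
```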
+ if w is None: w = self.load_weights()[self.key] + self.gate = w["gate"] + self.up = w["up"] + self.down = w["down"] + self.gate_type = w["gate_type"] + self.up_type = w["up_type"] + self.down_type = w["down_type"] + gate_ptr = ctypes.addressof( + ctypes.cast(self.gate.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents + ) + up_ptr = ctypes.addressof( + ctypes.cast(self.up.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents + ) + down_ptr = ctypes.addressof( + ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents + ) + # print(self.gate_qtype, self.up_qtype, self.down_qtype) + n_routed_experts = self.n_routed_experts + # n_routed_experts = len(self.orig_module) + moe_config = MOEConfig( + n_routed_experts, + self.config.num_experts_per_tok, + self.config.hidden_size, + self.config.moe_intermediate_size, + 64, + 10, + 1024, + gate_ptr, + up_ptr, + down_ptr, + self.gate_type, + self.up_type, + self.down_type, + 30, # TODO: get from model.dtype + ) + # print(n_routed_experts, hidden_size, moe_intermediate_size) + num_experts_per_tok = self.config.num_experts_per_tok + self.moe = MOE(moe_config) + self.cpu_infer = MLPCPUExperts.CPU_INFER + if warmup: + self.cpu_infer.submit(self.moe.warm_up()) + self.cpu_infer.sync() + if self.out_device not in MLPCPUExperts.output_gpu_map: + MLPCPUExperts.output_gpu_map[self.out_device] = torch.zeros((self.config.hidden_size), device=self.out_device) + if MLPCPUExperts.input_tensor_cpu == None: + MLPCPUExperts.input_tensor_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True) + MLPCPUExperts.expert_ids_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) + MLPCPUExperts.weights_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) + MLPCPUExperts.output_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) + def submit_for_one_decode(self, input_tensor, expert_ids, weights): MLPCPUExperts.input_tensor_cpu.copy_(input_tensor, non_blocking=True) MLPCPUExperts.expert_ids_cpu.copy_(expert_ids, non_blocking=True) MLPCPUExperts.weights_cpu.copy_(weights, non_blocking=True) - self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(1, expert_ids.size(0), MLPCPUExperts.expert_ids_cpu.data_ptr(), MLPCPUExperts.weights_cpu.data_ptr(), MLPCPUExperts.input_tensor_cpu.data_ptr(), MLPCPUExperts.output_cpu.data_ptr())) - + self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(0), MLPCPUExperts.expert_ids_cpu.data_ptr(), MLPCPUExperts.weights_cpu.data_ptr(), MLPCPUExperts.input_tensor_cpu.data_ptr(), MLPCPUExperts.output_cpu.data_ptr())) + def sync_for_one_decode(self): - self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream) - MLPCPUExperts.output_gpu.copy_(MLPCPUExperts.output_cpu, non_blocking=True) - #print("capturing experts finish") - return MLPCPUExperts.output_gpu + self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream) + MLPCPUExperts.output_gpu_map[self.out_device].copy_(MLPCPUExperts.output_cpu, non_blocking=True) + return MLPCPUExperts.output_gpu_map[self.out_device] def forward(self, input_tensor, expert_ids, weights): # generate, capture and run cuda graph + # print(expert_ids) if input_tensor.size(0)==1: + # TODO: this branch is unreachable, but the shape of input_tensor([1,hidden_size]) and 
input_tensor_cpu([hidden_size]) is not compatible #print("capturing experts") MLPCPUExperts.input_tensor_cpu.copy_(input_tensor, non_blocking=True) MLPCPUExperts.expert_ids_cpu.copy_(expert_ids, non_blocking=True) MLPCPUExperts.weights_cpu.copy_(weights, non_blocking=True) self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(1, expert_ids.size(1), MLPCPUExperts.expert_ids_cpu.data_ptr(), MLPCPUExperts.weights_cpu.data_ptr(), MLPCPUExperts.input_tensor_cpu.data_ptr(), MLPCPUExperts.output_cpu.data_ptr())) self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream) - MLPCPUExperts.output_gpu.copy_(MLPCPUExperts.output_cpu, non_blocking=True) - #print("capturing experts finish") - return MLPCPUExperts.output_gpu + MLPCPUExperts.output_gpu_map[self.out_device].copy_(MLPCPUExperts.output_cpu, non_blocking=True) + return MLPCPUExperts.output_gpu_map[self.out_device] else: input_tensor = input_tensor.contiguous().cpu() expert_ids = expert_ids.contiguous().cpu() @@ -195,7 +219,7 @@ class MLPCPUExperts(MLPExpertsBase): output = torch.empty_like(input_tensor).contiguous() self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr())) self.cpu_infer.sync() - return output.to(device=object.__getattribute__(self, "device")) + return output.to(device=object.__getattribute__(self, "out_device")) def unload(self): return @@ -222,6 +246,24 @@ class MLPCPUExperts(MLPExpertsBase): gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"] up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"] down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"] + elif key + ".ffn_down.0.weight" in self.gguf_loader.tensor_info: + # for supporting Mixtral-8x7B-Instuct + gate = [] + up = [] + down = [] + for i in range(8): + gate_it = self.gguf_loader.get_mmap_tensor(f"{key}.ffn_gate.{i}.weight") + up_it = self.gguf_loader.get_mmap_tensor(f"{key}.ffn_up.{i}.weight") + down_it = self.gguf_loader.get_mmap_tensor(f"{key}.ffn_down.{i}.weight") + gate.append(gate_it) + up.append(up_it) + down.append(down_it) + gate = np.stack(gate) + up = np.stack(up) + down = np.stack(down) + gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate.0.weight"]["ggml_type"] + up_type = self.gguf_loader.tensor_info[key + ".ffn_up.0.weight"]["ggml_type"] + down_type = self.gguf_loader.tensor_info[key + ".ffn_down.0.weight"]["ggml_type"] else: raise ValueError(f"Experts {key} not found in gguf_loader") res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}} @@ -299,7 +341,7 @@ class MLPExpertsMarlin(MLPExpertsBase): gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"] up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"] down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"] - # tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"]) + # tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"]) res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}} return res @@ -359,6 +401,11 @@ class MLPExpertsTorch(MLPExpertsBase): self.down = None def forward(self, hidden_states_cpu: torch.Tensor, 
selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor: + + org_device = hidden_states_cpu.device + hidden_states_cpu = hidden_states_cpu.to(self.device) + selected_experts_cpu = selected_experts_cpu.to(self.device) + routing_weights_cpu = routing_weights_cpu.to(self.device) batch_sequence_length, hidden_dim = hidden_states_cpu.size() @@ -388,27 +435,29 @@ class MLPExpertsTorch(MLPExpertsBase): # the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states) - return final_hidden_states.to(org_dtype) + + return final_hidden_states.to(org_dtype, device=org_device) EXPERTS_MAP = { "MLPCPUExperts": MLPCPUExperts, "MLPExpertsTorch": MLPExpertsTorch, "MLPExpertsMarlin": MLPExpertsMarlin, } + class KTransformersMLPExpert(BaseInjectedModule, MLPExpertsBase): def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, - device: str = "cuda", + # device: str = "cuda", prefill_device:str = "cuda", prefill_mlp_type: str | None = "MLPExpertsTorch", generate_device: str = "cpu", generate_mlp_type: str | None = "MLPCPUExperts", **kwargs): - BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) - MLPExpertsBase.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + MLPExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) if generate_mlp_type is not None: self.generate_experts = EXPERTS_MAP[generate_mlp_type](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs) else: @@ -471,6 +520,7 @@ class KTransformersMLPExpert(BaseInjectedModule, MLPExpertsBase): from ktransformers.models.modeling_deepseek import DeepseekV2MoE from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock +from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock class Qwen2MoeSparseMoeBlockInjected(BaseInjectedModule, Qwen2MoeSparseMoeBlock): @@ -578,7 +628,6 @@ class Qwen2MoeSparseMoeBlockInjected(BaseInjectedModule, Qwen2MoeSparseMoeBlock) return final_hidden_states - class DeepseekV2MoEInjected(BaseInjectedModule, DeepseekV2MoE): def forward(self, hidden_states): identity = hidden_states @@ -587,7 +636,7 @@ class DeepseekV2MoEInjected(BaseInjectedModule, DeepseekV2MoE): topk_idx, topk_weight, aux_loss = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - if sequence_length == 1: + if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode"): self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0]) if self.config.n_shared_experts is not None: y_ = self.shared_experts(identity).squeeze(0) @@ -677,3 +726,102 @@ class DeepseekV2MoEInjected(BaseInjectedModule, DeepseekV2MoE): .type(new_x.dtype) ) return final_out + +class MisrtalSparseMoEBlockInjected(BaseInjectedModule, MixtralSparseMoeBlock): + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + orig_shape = hidden_states.shape + batch_size, sequence_length, hidden_dim = hidden_states.shape + if self.training and self.jitter_noise > 0: + hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + 
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode"): + self.experts.generate_experts.submit_for_one_decode(hidden_states[0], selected_experts[0], routing_weights[0]) + y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0) + y.resize_(*orig_shape) + return y, router_logits + + hidden_states_expert = hidden_states.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else hidden_states_expert.cpu() + selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else selected_experts_expert.cpu() + routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else routing_weights_expert.cpu() + + if isinstance(self.experts, MLPExpertsBase): + y = ( + self.moe_on_cpuinfer( + hidden_states_expert, selected_experts_expert, routing_weights_expert + ) + .view(*orig_shape) + .to(device=hidden_states.device) + ) + elif hidden_states_expert.size(0) > 10: + y = self.moe_infer( + hidden_states_expert, selected_experts_expert, routing_weights_expert, orig_shape + ).to(device=hidden_states.device) + else: + y = self.moe_infer_simple( + hidden_states_expert, selected_experts_expert, routing_weights_expert + ).to(device=hidden_states.device) + + y.resize_(*orig_shape) + return y, router_logits + + @torch.no_grad() + def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: + outs = torch.empty_like(x) + outs = self.experts(x, topk_ids, topk_weight) + return outs + + @torch.no_grad() + # TODO may bugs here + def moe_infer_simple(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor: + ''' + hidden_states_cpu: [num_tokens, hidden_size] + topk_ids, topk_weight: [num_tokens, num_selected_experts] + ''' + outs = torch.zeros_like(hidden_states_cpu) + for token_idx in range(selected_experts_cpu.size(0)): + for expert_idx in range(selected_experts_cpu.size(1)): + expert = self.experts[selected_experts_cpu[token_idx, expert_idx]] + outs[token_idx] += expert.forward(hidden_states_cpu[token_idx]) * routing_weights_cpu[token_idx, expert_idx] + return outs + + @torch.no_grad() + # TODO may bugs here + def moe_infer(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor, orig_shape: tuple) -> torch.Tensor: + + batch_size, sequence_length, hidden_dim = orig_shape + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Index the correct hidden states and compute the expert hidden state for 
+ # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer.forward(current_state) * routing_weights_cpu[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states_cpu.dtype)) + + return final_hidden_states \ No newline at end of file diff --git a/ktransformers/operators/layer_wise_prefill.py b/ktransformers/operators/layer_wise_prefill.py index 61efed8..2a1d1fe 100644 --- a/ktransformers/operators/layer_wise_prefill.py +++ b/ktransformers/operators/layer_wise_prefill.py @@ -6,7 +6,7 @@ Author : Azure-Tang Date : 2024-07-25 11:25:24 Version : 1.0.0 LastEditors : Azure -LastEditTime : 2024-07-26 09:27:48 +LastEditTime : 2024-08-08 10:09:14 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' @@ -45,6 +45,8 @@ from ktransformers.models.modeling_deepseek import BaseModelOutputWithPast, Deep from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.utils import InferenceState +from ktransformers.util.custom_gguf import GGUFLoader +from transformers.configuration_utils import PretrainedConfig if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func @@ -73,34 +75,6 @@ QWEN2MOE_START_DOCSTRING = r""" [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ - -@add_start_docstrings( - "The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.", - QWEN2MOE_START_DOCSTRING, -) -class Qwen2MoePreTrainedModel(PreTrainedModel): - config_class = Qwen2MoeConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2MoeDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_static_cache = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - QWEN2MOE_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -177,13 +151,11 @@ QWEN2MOE_INPUTS_DOCSTRING = r""" the complete sequence length. """ -from ktransformers.util.custom_gguf import GGUFLoader -from transformers.configuration_utils import PretrainedConfig @add_start_docstrings( "The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.", QWEN2MOE_START_DOCSTRING, ) -class Qwen2MoeModelPerLayerPrefill(BaseInjectedModule): +class Qwen2MoeModelKTransformers(BaseInjectedModule): """ Transformer decoder consisting of *config.num_hidden_layers* layers. 
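The single-token decode path earlier in this file (`submit_for_one_decode` / `sync_for_one_decode`) relies on pinned host buffers and non-blocking copies to hand activations between the GPU stream and the CPU experts. A minimal sketch of that staging pattern, guarded so it also runs on CPU-only machines; the doubling stands in for the CPU expert computation:

```python
import torch

hidden_size = 16
if torch.cuda.is_available():
    staging_in = torch.zeros(hidden_size, device="cpu", pin_memory=True)
    staging_out = torch.zeros(hidden_size, device="cpu", pin_memory=True)
    gpu_in = torch.randn(hidden_size, device="cuda:0")
    gpu_out = torch.empty(hidden_size, device="cuda:0")

    staging_in.copy_(gpu_in, non_blocking=True)    # GPU -> pinned CPU on the current stream
    torch.cuda.synchronize()                       # data must have landed before the CPU touches it
    staging_out.copy_(staging_in * 2.0)            # stand-in for the CPU expert computation
    gpu_out.copy_(staging_out, non_blocking=True)  # pinned CPU -> GPU
    torch.cuda.synchronize()
    print(torch.allclose(gpu_out, gpu_in * 2.0))   # True
else:
    print("CUDA not available; the pinned-memory hand-off is skipped.")
```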
Each layer is a [`Qwen2MoeDecoderLayer`] @@ -198,10 +170,13 @@ class Qwen2MoeModelPerLayerPrefill(BaseInjectedModule): orig_module: nn.Module, device: str = "cuda", per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill + transfer_map: dict = None, **kwargs, ): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold + self.transfer_map = transfer_map + self.stream_device_map = dict() @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) def forward( @@ -287,7 +262,20 @@ class Qwen2MoeModelPerLayerPrefill(BaseInjectedModule): all_router_logits = () if output_router_logits else None next_decoder_cache = None - for decoder_layer in self.layers: + for i, decoder_layer in enumerate(self.layers): + if self.transfer_map is not None and i in self.transfer_map: + prev_stream = torch.cuda.current_stream() + cur_device = self.transfer_map[i] + if cur_device not in self.stream_device_map: + self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device) + torch.cuda.set_device(cur_device) + self.stream_device_map[cur_device].wait_stream(prev_stream) + torch.cuda.set_stream(self.stream_device_map[cur_device]) + hidden_states = hidden_states.to(self.transfer_map[i], non_blocking = True) + causal_mask = causal_mask.to(self.transfer_map[i], non_blocking = True) if causal_mask is not None else None + position_ids = position_ids.to(self.transfer_map[i], non_blocking = True) if position_ids is not None else None + cache_position = cache_position.to(self.transfer_map[i], non_blocking = True) if cache_position is not None else None + if output_hidden_states: all_hidden_states += (hidden_states,) @@ -463,7 +451,7 @@ DeepseekV2_INPUTS_DOCSTRING = r""" """ -class DeepseekV2ModelPerLayerPrefill(BaseInjectedModule): +class DeepseekV2ModelKTransformers(BaseInjectedModule): """ Transformer decoder consisting of *config.num_hidden_layers* layers. 
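The `transfer_map` wired in above maps a decoder-layer index to the device the pipeline should hop to at that layer; everything up to the next key stays on the previous device. A toy placement showing how the hand-off is read (stream handling omitted, and the map itself is illustrative rather than one of the shipped YAML rules):

```python
import torch

transfer_map = {0: "cuda:0", 16: "cuda:1"}   # layers 0-15 on cuda:0, 16 onwards on cuda:1
num_layers = 32

hidden_states = torch.randn(1, 8, 64)
if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
    for i in range(num_layers):
        if i in transfer_map:
            hidden_states = hidden_states.to(transfer_map[i], non_blocking=True)
        # decoder_layer[i](hidden_states) would run here, on hidden_states.device
    print(hidden_states.device)              # cuda:1 after the hop at layer 16
else:
    print("Fewer than two GPUs available; the placement sketch is skipped.")
```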
Each layer is a [`DeepseekV2DecoderLayer`] @@ -478,10 +466,13 @@ class DeepseekV2ModelPerLayerPrefill(BaseInjectedModule): orig_module: nn.Module, device: str = "cuda", per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill + transfer_map: dict = None, **kwargs, ): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold + self.transfer_map = transfer_map + self.stream_device_map = dict() @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) def forward( @@ -584,7 +575,20 @@ class DeepseekV2ModelPerLayerPrefill(BaseInjectedModule): t_cpu = 0 t_f = 0 - for decoder_layer in self.layers: + for i, decoder_layer in enumerate(self.layers): + if self.transfer_map is not None and i in self.transfer_map: + prev_stream = torch.cuda.current_stream() + cur_device = self.transfer_map[i] + if cur_device not in self.stream_device_map: + self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device) + torch.cuda.set_device(cur_device) + self.stream_device_map[cur_device].wait_stream(prev_stream) + torch.cuda.set_stream(self.stream_device_map[cur_device]) + hidden_states = hidden_states.to(self.transfer_map[i], non_blocking = True) + causal_mask = causal_mask.to(self.transfer_map[i], non_blocking = True) if causal_mask is not None else None + position_ids = position_ids.to(self.transfer_map[i], non_blocking = True) if position_ids is not None else None + cache_position = cache_position.to(self.transfer_map[i], non_blocking = True) if cache_position is not None else None + if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index e264323..90b5506 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -176,7 +176,7 @@ class QuantizedLinearMarlin(QuantizedLinearBase): self.act_order = act_order self.is_k_full = is_k_full - def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = "cuda"): + def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): if device is None: device = self.device assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device" if w is None: w = self.load_weight(device=device) @@ -200,7 +200,7 @@ class QuantizedLinearMarlin(QuantizedLinearBase): weight, self.num_bits, self.group_size, self.act_order ) self.workspace = MarlinWorkspace( - self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL + self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device ) self.marlin_q_w = marlin_q_w self.marlin_s = marlin_s @@ -247,7 +247,6 @@ class QuantizedLinearMarlin(QuantizedLinearBase): LINEAR_MAP = { "QuantizedLinearMarlin": QuantizedLinearMarlin, "QuantizedLinearTorch": QuantizedLinearTorch, - "QuantizedLinearTorch": QuantizedLinearTorch, } class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase): @@ -257,15 +256,15 @@ class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase): gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, - device: str = "cuda", + # device: str = "cuda", generate_device: str = "cuda", generate_op: str| None = "QuantizedLinearMarlin", prefill_device: str = "cuda", prefill_op: str| None = "QuantizedLinearTorch", **kwargs, ): - BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) - QuantizedLinearBase.__init__(self, 
key, gguf_loader, config, orig_module, device, **kwargs) + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + QuantizedLinearBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) # build all the linear operators if prefill_op is not None: assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported" @@ -289,7 +288,6 @@ class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase): self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs) else: self.generate_linear = None - self.device = device self.mode = InferenceState.UNLOAD def forward(self, x): diff --git a/ktransformers/optimize/optimize.py b/ktransformers/optimize/optimize.py index 7062166..36ab62d 100644 --- a/ktransformers/optimize/optimize.py +++ b/ktransformers/optimize/optimize.py @@ -1,6 +1,6 @@ ''' Description : -Author : Boxin Zhang +Author : Boxin Zhang, Azure-Tang Version : 0.1.0 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' @@ -15,6 +15,7 @@ from transformers.configuration_utils import PretrainedConfig from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf from ktransformers.util.utils import set_module, load_weights import itertools +import copy def inject(module, local_optimization_dict, model_config:AutoConfig ,gguf_loader:GGUFLoader, prefix=''): for name, child in module._modules.items(): @@ -22,18 +23,20 @@ def inject(module, local_optimization_dict, model_config:AutoConfig ,gguf_loader child_prefix = prefix + name if child_prefix in local_optimization_dict: inject_module_meta=local_optimization_dict[child_prefix] - if isinstance(inject_module_meta, Mapping): + if inject_module_meta["class"] != "default": import_path = inject_module_meta["class"].split(".") import_module_name = ".".join(import_path[:-1]) + gguf_loader.tensor_device_map[inject_module_meta["key"]] = inject_module_meta["kwargs"] if "kwargs" in inject_module_meta else dict() import_class_name = import_path[-1] module_cls=getattr(__import__(import_module_name, fromlist=[""]), import_class_name) print(f"Injecting {child_prefix} as", import_module_name, ".", import_class_name) - inject_module=module_cls(key = inject_module_meta["key"], gguf_loader = gguf_loader, config = model_config, orig_module=child, device = inject_module_meta["device"], **inject_module_meta["kwargs"]) + inject_module=module_cls(key = inject_module_meta["key"], gguf_loader = gguf_loader, config = model_config, orig_module=child, **inject_module_meta["kwargs"]) set_module(module, name, inject_module) - elif isinstance(inject_module_meta, str): - assert inject_module_meta=="default", "for str inject_module_meta, only support \"default\"." + elif inject_module_meta["class"] == "default": + print(f"Injecting {child_prefix} as default") + gguf_loader.tensor_device_map[inject_module_meta["key"]] = inject_module_meta["kwargs"] if "kwargs" in inject_module_meta else dict() else: - raise Exception("inject_module_meta must be a dict or str") + raise Exception("inject_module_meta[\"class\"] must be \"default\" or a class path") child_prefix += "." 
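With the rule metadata now carrying device placement inside kwargs, inject() only has to resolve the dotted "class" path and pass the kwargs straight to the replacement's constructor. A condensed sketch of that dispatch, using a hypothetical rule entry (field values are illustrative):

    import importlib

    def build_injected_module(meta, child, gguf_loader, model_config):
        # meta is one entry from the optimize dict, e.g.
        # {"key": "blk.0.attn_q", "class": "pkg.mod.MyLinear",
        #  "kwargs": {"generate_device": "cuda:0", "prefill_device": "cuda:0"}}
        module_path, _, class_name = meta["class"].rpartition(".")
        cls = getattr(importlib.import_module(module_path), class_name)
        return cls(key=meta["key"], gguf_loader=gguf_loader, config=model_config,
                   orig_module=child, **meta["kwargs"])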
child_optimization_dict = {k: v for k, v in local_optimization_dict.items() if k.startswith(child_prefix)} inject(child, child_optimization_dict, model_config, gguf_loader, child_prefix) @@ -57,6 +60,8 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p for rule in rule_list: #print(rule) match_meta = rule["match"] + if "class" not in match_meta and "name" not in match_meta: + raise Exception("match must have at least one of \"class\" and \"name\"") if "class" in match_meta: import_path = match_meta["class"].split(".") import_module_name = ".".join(import_path[:-1]) @@ -67,16 +72,29 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p if "name" in match_meta: if re.search(match_meta["name"], module_name) is None: continue - replace_meta = rule["replace"] - out_data[module_name]={"key": translated_name, - "class": replace_meta["class"], - "device": replace_meta["device"] if "device" in replace_meta else default_device, - "kwargs": replace_meta["kwargs"] if "kwargs" in replace_meta else dict()} + if "replace" not in rule: + raise Exception("replace must be in rule") + if "replace" in rule: + replace_meta = rule["replace"] + if module_name not in out_data: + out_data[module_name]={"key": translated_name, + "class": replace_meta["class"] if "class" in replace_meta else "default", + # "device": replace_meta["device"] if "device" in replace_meta else default_device, + "kwargs": copy.deepcopy(replace_meta["kwargs"]) if "kwargs" in replace_meta else dict()} + else: + if out_data[module_name]["class"] == "default": + out_data[module_name]["class"] = replace_meta["class"] if "class" in replace_meta else "default" + out_data[module_name]["kwargs"].update(copy.deepcopy(replace_meta["kwargs"]) if "kwargs" in replace_meta else dict()) if "recursive" in rule: recursive = bool(rule["recursive"]) if module_name not in out_data: - out_data[module_name]="default" + out_data[module_name]= { + "class": "default", + "key": translated_name, + "kwargs": {"generate_device": default_device, + "prefill_device": default_device} + } #print(out_data[module_name]) #input() @@ -88,6 +106,14 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p gen_optimize_config(child, out_data, rule_list, child_prefix) +def translate_model_config(model_config: PretrainedConfig): + # for supporting some special model + if model_config.model_type == "mixtral": + model_config.moe_intermediate_size = model_config.intermediate_size + + return model_config + + def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, model_config: PretrainedConfig, default_device: str = "cuda:0"): with open(rule_file, 'r', encoding='utf-8') as f: rule_list = yaml.load(f.read(), Loader=yaml.FullLoader) @@ -95,8 +121,11 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo optimize_config = dict() gen_optimize_config(module, optimize_config, rule_list, default_device = default_device) + model_config = translate_model_config(model_config) + gguf_loader=GGUFLoader(gguf_path) with torch.device("meta"): inject(module, optimize_config, model_config, gguf_loader) load_weights(module, gguf_loader) + model_config.gguf_loader = gguf_loader del_meta(module) diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml new file mode 100644 index 0000000..1d6b46f --- /dev/null +++ 
b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml @@ -0,0 +1,228 @@ +- match: + name: "^model\\.layers\\.([0-9])\\." + replace: + class: "default" + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "(^model\\.layers\\.([1][0-9])\\.)" + replace: + class: "default" + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" +- match: + name: "(^model\\.layers\\.([2][0-9])\\.)" + replace: + class: "default" + kwargs: + generate_device: "cuda:2" + prefill_device: "cuda:2" +- match: + name: "(^model\\.layers\\.([345][0-9])\\.)|(^model.norm)|(^lm_head)" + replace: + class: "default" + kwargs: + generate_device: "cuda:3" + prefill_device: "cuda:3" + +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +- match: + name: "^model\\.layers\\.([0-9])\\." + class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([1][0-9])\\." + class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" +- match: + name: "^model\\.layers\\.([2][0-9])\\." + class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "cuda:2" + prefill_device: "cuda:2" +- match: + name: "^model\\.layers\\.([345][0-9])\\." + class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "cuda:3" + prefill_device: "cuda:3" + +- match: + name: "^model\\.layers\\.([1][0-9])\\.(?!self_attn).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + generate_op: "QuantizedLinearMarlin" + prefill_op: "QuantizedLinearTorch" +- match: + name: "^model\\.layers\\.([1][0-9])\\.(?!self_attn).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + generate_op: "QuantizedLinearMarlin" + prefill_op: "QuantizedLinearTorch" +- match: + name: "^model\\.layers\\.([2][0-9])\\.(?!self_attn).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:2" + prefill_device: "cuda:2" + generate_op: "QuantizedLinearMarlin" + prefill_op: "QuantizedLinearTorch" +- match: + name: "^model\\.layers\\.([345][0-9])\\.(?!self_attn).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:3" + prefill_device: "cuda:3" + generate_op: 
"QuantizedLinearMarlin" + prefill_op: "QuantizedLinearTorch" + +- match: + name: "^model\\.layers\\.([0-9])\\.mlp$" + class: ktransformers.models.modeling_deepseek.DeepseekV2MoE + replace: + class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([1][0-9])\\.mlp$" + class: ktransformers.models.modeling_deepseek.DeepseekV2MoE + replace: + class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" +- match: + name: "^model\\.layers\\.([2][0-9])\\.mlp$" + class: ktransformers.models.modeling_deepseek.DeepseekV2MoE + replace: + class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function + kwargs: + generate_device: "cuda:2" + prefill_device: "cuda:2" +- match: + name: "^model\\.layers\\.([345][0-9])\\.mlp$" + class: ktransformers.models.modeling_deepseek.DeepseekV2MoE + replace: + class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function + kwargs: + generate_device: "cuda:3" + prefill_device: "cuda:3" + +- match: + name: "^model\\.layers\\.([0-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda:0" + prefill_mlp_type: "MLPExpertsTorch" + generate_device: "cpu" + generate_mlp_type: "MLPCPUExperts" + out_device: "cuda:0" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\.([1][0-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda:1" + prefill_mlp_type: "MLPExpertsTorch" + generate_device: "cpu" + generate_mlp_type: "MLPCPUExperts" + out_device: "cuda:1" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\.([2][0-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda:2" + prefill_mlp_type: "MLPExpertsTorch" + generate_device: "cpu" + generate_mlp_type: "MLPCPUExperts" + out_device: "cuda:2" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\.([345][0-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda:3" + prefill_mlp_type: "MLPExpertsTorch" + generate_device: "cpu" + generate_mlp_type: "MLPCPUExperts" + out_device: "cuda:3" + recursive: False # don't recursively inject submodules of this module + +- match: + name: "^model\\.layers\\.([0-9])\\.self_attn$" + replace: + class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([1][0-9])\\.self_attn$" + replace: + class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" +- match: + name: "^model\\.layers\\.([2][0-9])\\.self_attn$" + replace: + class: ktransformers.operators.attention.DeepseekV2AttentionInjected # 
optimized MLA implementation + kwargs: + generate_device: "cuda:2" + prefill_device: "cuda:2" +- match: + name: "^model\\.layers\\.([345][0-9])\\.self_attn$" + replace: + class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation + kwargs: + generate_device: "cuda:3" + prefill_device: "cuda:3" + +- match: + name: "^model$" + replace: + class: "ktransformers.operators.layer_wise_prefill.DeepseekV2ModelKTransformers" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill + transfer_map: + 10: "cuda:1" + 20: "cuda:2" + 30: "cuda:3" \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml new file mode 100644 index 0000000..45af034 --- /dev/null +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml @@ -0,0 +1,126 @@ +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." + replace: + class: "default" + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + +- match: + name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)|(lm_head)" + replace: + class: "default" + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." + class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([345][0-9])\\." + class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + generate_op: "QuantizedLinearMarlin" + prefill_op: "QuantizedLinearTorch" + +- match: + name: "^model\\.layers\\.([345][0-9])\\.(?!self_attn).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + generate_op: "QuantizedLinearMarlin" + prefill_op: "QuantizedLinearTorch" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$" + class: ktransformers.models.modeling_deepseek.DeepseekV2MoE + replace: + class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([345][0-9])\\.mlp$" + class: ktransformers.models.modeling_deepseek.DeepseekV2MoE + replace: + class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$" + replace: + class: 
ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda:0" + prefill_mlp_type: "MLPExpertsTorch" + generate_device: "cpu" + generate_mlp_type: "MLPCPUExperts" + out_device: "cuda:0" + recursive: False # don't recursively inject submodules of this module + +- match: + name: "^model\\.layers\\.([345][0-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda:1" + prefill_mlp_type: "MLPExpertsTorch" + generate_device: "cpu" + generate_mlp_type: "MLPCPUExperts" + out_device: "cuda:1" + recursive: False # don't recursively inject submodules of this module + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$" + replace: + class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([345][0-9])\\.self_attn$" + replace: + class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" +- match: + name: "^model$" + replace: + class: "ktransformers.operators.layer_wise_prefill.DeepseekV2ModelKTransformers" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill + transfer_map: + 30: "cuda:1" \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml index 025bd2b..328c9d7 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml @@ -1,3 +1,10 @@ +- match: + name: "^model\\.layers\\..*\\.|^lm_head" + replace: + class: "default" + kwargs: + generate_device: "cuda" + prefill_device: "cuda" - match: class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding replace: @@ -21,12 +28,11 @@ name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert paralleism - device: "cpu" # which devices to load this module when initializing kwargs: prefill_device: "cuda" prefill_mlp_type: "MLPExpertsTorch" generate_device: "cpu" - generate_mlp_type: "MLPCPUExperts" + generate_mlp_type: "MLPCPUExperts" out_device: "cuda" recursive: False # don't recursively inject submodules of this module - match: @@ -36,6 +42,13 @@ - match: name: "^model$" replace: - class: "ktransformers.operators.layer_wise_prefill.DeepseekV2ModelPerLayerPrefill" + class: "ktransformers.operators.layer_wise_prefill.DeepseekV2ModelKTransformers" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml new file mode 100644 index 0000000..c9c1809 --- /dev/null +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml @@ -0,0 +1,126 @@ +- match: + name: "^model\\.layers\\.(0|[1-9])\\." 
+ replace: + class: "default" + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + +- match: + name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)|(lm_head)" + replace: + class: "default" + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +- match: + name: "^model\\.layers\\.(0|[1-9])\\." + class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([12][0-9])\\." + class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9])\\.(?!self_attn).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + generate_op: "QuantizedLinearMarlin" + prefill_op: "QuantizedLinearTorch" + +- match: + name: "^model\\.layers\\.([12][0-9])\\.(?!self_attn).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + generate_op: "QuantizedLinearMarlin" + prefill_op: "QuantizedLinearTorch" + +- match: + name: "^model\\.layers\\.(0|[1-9])\\.mlp$" + class: ktransformers.models.modeling_deepseek.DeepseekV2MoE + replace: + class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([12][0-9])\\.mlp$" + class: ktransformers.models.modeling_deepseek.DeepseekV2MoE + replace: + class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda:0" + prefill_mlp_type: "MLPExpertsTorch" + generate_device: "cpu" + generate_mlp_type: "MLPCPUExperts" + out_device: "cuda:0" + recursive: False # don't recursively inject submodules of this module + +- match: + name: "^model\\.layers\\.([12][0-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda:1" + prefill_mlp_type: "MLPExpertsTorch" + generate_device: "cpu" + generate_mlp_type: "MLPCPUExperts" + out_device: "cuda:1" + recursive: False # don't recursively inject submodules of this module + +- match: + name: "^model\\.layers\\.(0|[1-9])\\.self_attn$" + replace: + class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([12][0-9])\\.self_attn$" + replace: + class: 
ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" +- match: + name: "^model$" + replace: + class: "ktransformers.operators.layer_wise_prefill.DeepseekV2ModelKTransformers" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill + transfer_map: + 10: "cuda:1" \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/Mixtral.yaml b/ktransformers/optimize/optimize_rules/Mixtral.yaml new file mode 100644 index 0000000..5bd6705 --- /dev/null +++ b/ktransformers/optimize/optimize_rules/Mixtral.yaml @@ -0,0 +1,45 @@ +- match: + name: "^model\\.layers\\..*\\." + replace: + class: "default" + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + class: ktransformers.models.modeling_mixtral.MixtralRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.RotaryEmbedding +- match: + name: "^model\\.layers\\..*$" + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "QuantizedLinearMarlin" + prefill_op: "QuantizedLinearTorch" +- match: + name: "^model\\.layers\\..*\\.block_sparse_moe$" + class: ktransformers.models.modeling_mixtral.MixtralSparseMoeBlock + replace: + class: ktransformers.operators.experts.MisrtalSparseMoEBlockInjected +- match: + name: "^model\\.layers\\..*\\.block_sparse_moe\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersMLPExpert + kwargs: + prefill_device: "cuda" + prefill_mlp_type: "MLPExpertsTorch" + generate_device: "cpu" + generate_mlp_type: "MLPCPUExperts" + out_device: "cuda" + recursive: False # don't recursively inject submodules of this module + +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" diff --git a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml new file mode 100644 index 0000000..82415aa --- /dev/null +++ b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml @@ -0,0 +1,111 @@ +- match: + name: "^model\\.layers\\.([012])\\." + replace: + class: "default" + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([012])\\." 
+ class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.RotaryEmbedding + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([012])$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + generate_op: "QuantizedLinearMarlin" + prefill_op: "QuantizedLinearTorch" +- match: + name: "^model\\.layers\\.([012])\\.mlp$" + class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock + replace: + class: ktransformers.operators.experts.Qwen2MoeSparseMoeBlockInjected # mlp module with custom forward function +- match: + name: "^model\\.layers\\.([012])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert paralleism + # device: "cpu" # which devices to load this module when initializing + kwargs: + prefill_device: "cuda:0" + prefill_mlp_type: "MLPExpertsTorch" + generate_device: "cpu" + generate_mlp_type: "MLPCPUExperts" + out_device: "cuda:0" + recursive: False # don't recursively inject submodules of this module + +- match: + name: "^model\\.layers\\.([12][0-9]|[3-9])\\." + replace: + class: "default" + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" +- match: + name: "^model\\.layers\\.([12][0-9]|[3-9])\\." + class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.RotaryEmbedding + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" +- match: + name: "^model\\.layers\\.([12][0-9]|[3-9])$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + generate_op: "QuantizedLinearMarlin" + prefill_op: "QuantizedLinearTorch" +- match: + name: "^model\\.layers\\.([12][0-9]|[3-9])\\.mlp$" + class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock + replace: + class: ktransformers.operators.experts.Qwen2MoeSparseMoeBlockInjected # mlp module with custom forward function +- match: + name: "^model\\.layers\\.([12][0-9]|[3-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert paralleism + # device: "cpu" # which devices to load this module when initializing + kwargs: + prefill_device: "cuda:1" + prefill_mlp_type: "MLPExpertsTorch" + generate_device: "cpu" + generate_mlp_type: "MLPCPUExperts" + out_device: "cuda:1" + recursive: False # don't recursively inject submodules of this module + +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +- match: + name: "(^model.norm)|(^lm_head)" + replace: + class: "default" + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model$" + replace: + class: "ktransformers.operators.layer_wise_prefill.Qwen2MoeModelKTransformers" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill + transfer_map: + 3: "cuda:1" + diff --git a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml 
b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml index 2b4e312..3fd59cb 100644 --- a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml +++ b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml @@ -1,3 +1,10 @@ +- match: + name: "^model\\.layers\\..*\\." + replace: + class: "default" + kwargs: + generate_device: "cuda" + prefill_device: "cuda" - match: class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding replace: @@ -21,7 +28,7 @@ name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert paralleism - device: "cpu" # which devices to load this module when initializing + # device: "cpu" # which devices to load this module when initializing kwargs: prefill_device: "cuda" prefill_mlp_type: "MLPExpertsTorch" @@ -32,6 +39,13 @@ - match: name: "^model$" replace: - class: "ktransformers.operators.layer_wise_prefill.Qwen2MoeModelPerLayerPrefill" + class: "ktransformers.operators.layer_wise_prefill.Qwen2MoeModelKTransformers" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" \ No newline at end of file diff --git a/ktransformers/tests/dequant_gpu.py b/ktransformers/tests/dequant_gpu.py index 4fbca1c..9c839c1 100644 --- a/ktransformers/tests/dequant_gpu.py +++ b/ktransformers/tests/dequant_gpu.py @@ -1,12 +1,9 @@ import os -os.environ["CUDA_VISIBLE_DEVICES"]="1" +# os.environ["CUDA_VISIBLE_DEVICES"]="1,2" # add path import sys current_path = os.path.abspath(os.path.dirname(__file__)) sys.path.append(current_path+"/../..") -import pycuda.autoinit -import pycuda.driver as cuda -from pycuda.compiler import SourceModule import numpy as np # from ktransformers.operators.linear import KTransformerLinear, QuantizedLinearMarlin # from ktransformers.operators.experts import KTransformersMLPExpert, MLPExpertsTorch @@ -18,36 +15,23 @@ import time from transformers import ( AutoConfig, ) +import os +# CUDA_LAUNCH_BLOCKING=1 +os.environ["CUDA_LAUNCH_BLOCKING"]="1" gguf_config = GGUFLoader("/data/Qwen2-57B-A14B-Instruct-GGUF/q4_k_m") model_name = "/data/Qwen2-57B-A14B-Instruct" -key = "blk.0." -target = "ffn_down_exps.weight" - -t1 = time.time() -q_weight_cpu = gguf_config.load_gguf_tensor(key+target, "cpu") -# q_weight_cpu = torch.from_numpy(q_weight_cpu) - -t2 = time.time() -q_weight_gpu = gguf_config.load_gguf_tensor(key+target, "cuda") -t3 = time.time() -print() -allclose = torch.allclose(q_weight_cpu, q_weight_gpu.cpu().to(torch.float32), atol=1e-6) -print(f"Q6k {key+target}") -print("load gguf tensor from cpu cost: ", t2-t1) -print("load gguf tensor from gpu cost: ", t3-t2) -print("allclose: ", allclose) - +# Q4k key = "blk.1." -target = "ffn_up_shexp.weight" +target = "attn_q.weight" t1 = time.time() q_weight_cpu = gguf_config.load_gguf_tensor(key+target, "cpu") # q_weight_cpu = torch.from_numpy(q_weight_cpu) t2 = time.time() -q_weight_gpu = gguf_config.load_gguf_tensor(key+target, "cuda") +q_weight_gpu = gguf_config.load_gguf_tensor(key+target, "cuda:0") t3 = time.time() print() allclose = torch.allclose(q_weight_cpu, q_weight_gpu.cpu(), atol=1e-6) @@ -55,3 +39,20 @@ print(f"Q4k {key+target}") print("load gguf tensor from cpu cost: ", t2-t1) print("load gguf tensor from gpu cost: ", t3-t2) print("allclose: ", allclose) + + +# Q6k +key = "blk.0." 
+target = "ffn_down_exps.weight" + +t1 = time.time() +q_weight_cpu = gguf_config.load_gguf_tensor(key+target, "cpu") +t2 = time.time() +q_weight_gpu = gguf_config.load_gguf_tensor(key+target, "cuda:0") +t3 = time.time() +print() +allclose = torch.allclose(q_weight_cpu, q_weight_gpu.cpu().to(torch.float32), atol=1e-6) +print(f"Q6k {key+target}") +print("load gguf tensor from cpu cost: ", t2-t1) +print("load gguf tensor from gpu cost: ", t3-t2) +print("allclose: ", allclose) diff --git a/ktransformers/tests/dequant_gpu_t.py b/ktransformers/tests/dequant_gpu_t.py index 3efcdf3..8abc89d 100644 --- a/ktransformers/tests/dequant_gpu_t.py +++ b/ktransformers/tests/dequant_gpu_t.py @@ -11,7 +11,7 @@ from ktransformers.operators.linear import KTransformerLinear, QuantizedLinearMa from ktransformers.operators.experts import KTransformersMLPExpert, MLPExpertsTorch from ktransformers.util.custom_gguf import GGUFLoader, dequantize_q4_k_gpu, dequantize_q4_k import torch -import CudaOps +import KTransformersOps torch.set_default_dtype(torch.bfloat16) import time from transformers import ( diff --git a/ktransformers/util/cuda_graph_runner.py b/ktransformers/util/cuda_graph_runner.py index 2ac7a17..c7a9c87 100644 --- a/ktransformers/util/cuda_graph_runner.py +++ b/ktransformers/util/cuda_graph_runner.py @@ -21,6 +21,7 @@ class CUDAGraphRunner: position_ids, cache_position, past_key_values, + main_device, **kwargs, ) -> None: assert self.graph is None @@ -29,15 +30,24 @@ class CUDAGraphRunner: self.graph = torch.cuda.CUDAGraph() #self.graph.enable_debug_mode() self.model = model - inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to("cuda") - with torch.cuda.graph(self.graph): + inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(main_device) + # torch.cuda.set_device can't set "cuda", must have a index + if main_device == "cuda": + main_device = "cuda:0" + torch.cuda.set_device(main_device) + self.main_device = main_device + capture_stream = torch.cuda.Stream() + with torch.cuda.graph(self.graph, stream = capture_stream): logits=model(inputs_embeds=inputs_embeds, position_ids=position_ids, cache_position=cache_position, past_key_values=past_key_values, **kwargs)[0] + capture_stream.wait_stream(torch.cuda.current_stream()) + torch.cuda.set_device(main_device) + torch.cuda.set_stream(capture_stream) past_key_values.change_seq_length(-1) - torch.cuda.synchronize() + torch.cuda.synchronize(self.main_device) #self.graph.debug_dump("cuda_graph_hooked.dot") # Save the input and output buffers. @@ -65,7 +75,7 @@ class CUDAGraphRunner: #print("begin replay") #time.sleep(1) self.graph.replay() - torch.cuda.synchronize() + torch.cuda.synchronize(self.main_device) # Return the output tensor. return self.output_buffers["logits"] diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py index 643713e..fe796a7 100644 --- a/ktransformers/util/custom_gguf.py +++ b/ktransformers/util/custom_gguf.py @@ -5,8 +5,11 @@ Description : Author : Azure-Tang, Boxin Zhang, chenht2022 Date : 2024-07-26 08:48:54 Version : 1.0.0 -LastEditors : Azure -LastEditTime : 2024-07-26 09:28:25 +LastEditors : kkk1nak0 +LastEditTime : 2024-08-09 08:03:44 +Adapted from https://github.com/99991/pygguf/blob/main/gguf.py +Copyright (c) 2023-2024 The ggml authors +Copyright (c) 2024 Thomas Germer Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' # copied from llama.cpp/gguf-py/gguf/constants.py to satisfy dependence of gguf @@ -15,6 +18,7 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
import struct import warnings import numpy as np +import re import numpy.typing as npt from typing import Sequence import os @@ -96,6 +100,8 @@ def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantization GGML_TYPES = { "F32": 0, "F16": 1, + "Q4_0": 2, + "Q5_0": 6, "Q8_0": 8, "Q2_K": 10, "Q3_K": 11, @@ -109,6 +115,8 @@ GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()} GGML_BLOCK_SIZES = { "F32": 4, "F16": 2, + "Q4_0": 2 + 16, + "Q5_0": 2 + 4 + 16, "Q8_0": 2 + 32, "Q2_K": 256 // 16 + 256 // 4 + 2 + 2, "Q3_K": 256 // 8 + 256 // 4 + 12 + 2, @@ -120,6 +128,8 @@ GGML_BLOCK_SIZES = { GGML_ELEMENTS_PER_BLOCK = { "F32": 1, "F16": 1, + "Q4_0": 32, + "Q5_0": 32, "Q8_0": 32, "Q2_K": 256, "Q3_K": 256, @@ -128,14 +138,6 @@ GGML_ELEMENTS_PER_BLOCK = { "Q6_K": 256, } -# DATA_TYPES = { -# "uint32": 4, -# "int32": 5, -# "float32": 6, -# "string": 8, -# "array": 9, -# "uint64": 10, -# } DATA_TYPES = { "uint8": 0, "int8": 1, @@ -167,6 +169,7 @@ class GGUFLoader: self.tensor_file_map = {} self.file_data_map = {} self.gguf_file_meta = {} + self.tensor_device_map = {} # Walk through all the .gguf files in the directory for root, dirs, files in os.walk(gguf_path): @@ -272,7 +275,7 @@ class GGUFLoader: def load_gguf_tensor(self, name: str, device:str = "cpu")->torch.Tensor: t = self.tensor_info[name] - + shape = t["shape"] ggml_type = t["ggml_type"] @@ -282,15 +285,28 @@ class GGUFLoader: ggml_name = GGML_NAMES[ggml_type] data = self.get_mmap_tensor(name) - if "cuda" in device.lower(): values = GGML_DEQUANTIZE_GPU[ggml_name](data, device) + #values = GGML_DEQUANTIZE[ggml_name](data) + #print("load_gguf_tensor") + #values = torch.from_numpy(values).to(device = device) else: values = GGML_DEQUANTIZE[ggml_name](data) values = torch.from_numpy(values) - - return values.view(shape[::-1]) + + values = values.view(shape[::-1]) + if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]: + n_head = self.gguf_file_meta['llama.attention.head_count'] + values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:]) + .swapaxes(1, 2) + .reshape(values.shape)) + elif "attn_k" in name and self.gguf_file_meta['general.architecture'] in ["llama"]: + n_head = self.gguf_file_meta['llama.attention.head_count_kv'] + values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:]) + .swapaxes(1, 2) + .reshape(values.shape)) + return values def read_value(f, data_type): if data_type == DATA_TYPES["string"]: @@ -375,7 +391,7 @@ def dequantize_q2_k(data): return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4) def dequantize_q2_k_gpu(data): - pass + raise NotImplementedError() def dequantize_q3_k(data): # C implementation @@ -420,7 +436,7 @@ def dequantize_q3_k(data): ], axis=1) def dequantize_q3_k_gpu(data): - pass + raise NotImplementedError() def dequantize_q4_k(data): # C implementation @@ -429,20 +445,16 @@ def dequantize_q4_k(data): # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L116 block_size = GGML_BLOCK_SIZES["Q4_K"] num_blocks = len(data) // block_size - data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) - # Casting to float32 because float16 is very slow on CPU scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32) scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32) qs1 = data_u8[:, 
4:16].reshape(num_blocks, 12, 1) qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32) - # Dequantize scales and offsets (6 bits and 4 + 2 bits) factors = scale_factors * np.concatenate([qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1) offsets = scale_offsets * np.concatenate([qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1) - # Interleave low and high quantized bits qs2 = np.stack([qs2 & 0xf, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32) # Dequantize final weights using scales and offsets @@ -512,9 +524,14 @@ def dequantize_q5_k(data): d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8, ], axis=1) -def dequantize_q5_k_gpu(data): - pass - +def dequantize_q5_k_gpu(data, device:str ="cuda"): + block_size = GGML_BLOCK_SIZES["Q5_K"] + data = np.frombuffer(data, dtype=data.dtype) + device = torch.device(device) + # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, + # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. + data = torch.from_numpy(data) + return KTransformersOps.dequantize_q5_k(data, block_size, device) def dequantize_q6_k(data): # C implementation @@ -571,7 +588,49 @@ def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda"): num_blocks = len(data) // block_size data = np.frombuffer(data, dtype=data.dtype) data = torch.from_numpy(data) - return KTransformersOps.dequantize_q6_k(data, 210, device) + return KTransformersOps.dequantize_q6_k(data, block_size, device) + +def dequantize_q4_0(data): + # C implementation + # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1515 + # C struct definition + # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-common.h#L141 + num_blocks = len(data) // GGML_BLOCK_SIZES["Q4_0"] + + scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 8)[:, :1].astype(np.float32) + qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 16)[:, 2:] + + return np.concatenate([ + scales * ((qs & 0xf).astype(np.int8) - 8), + scales * ((qs >> 4).astype(np.int8) - 8), + ], axis=1) + +def dequantize_q4_0_gpu(data): + raise NotImplementedError() + +def dequantize_q5_0(data): + # C implementation + # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1556 + # C struct definition + # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-common.h#L161 + num_blocks = len(data) // GGML_BLOCK_SIZES["Q5_0"] + + scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 2 + 8)[:, :1].astype(np.float32) + qh = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2:2 + 4] + qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2 + 4:] + + bits = np.unpackbits(qh, axis=-1, bitorder="little") + + x0 = ((qs & 0xf).astype(np.int8) | (bits[:, :16] << 4)) - 16 + x1 = ((qs >> 4).astype(np.int8) | (bits[:, 16:] << 4)) - 16 + + return np.concatenate([ + scales * x0, + scales * x1, + ], axis=1) + +def dequantize_q5_0_gpu(data): + raise NotImplementedError() def dequantize_q8_0(data): # C struct definition @@ -615,6 +674,8 @@ def dequantize_f16_gpu(data, device): GGML_DEQUANTIZE = { "F32": dequantize_f32, "F16": dequantize_f16, + "Q4_0": dequantize_q4_0, + "Q5_0": dequantize_q5_0, "Q8_0": dequantize_q8_0, "Q2_K": dequantize_q2_k, "Q3_K": dequantize_q3_k, @@ -626,6 +687,8 @@ GGML_DEQUANTIZE = { 
GGML_DEQUANTIZE_GPU = { "F32": dequantize_f32_gpu, "F16": dequantize_f16_gpu, + "Q4_0": dequantize_q4_0_gpu, + "Q5_0": dequantize_q5_0_gpu, "Q8_0": dequantize_q8_0_gpu, "Q2_K": dequantize_q2_k_gpu, "Q3_K": dequantize_q3_k_gpu, @@ -634,7 +697,34 @@ GGML_DEQUANTIZE_GPU = { "Q6_K": dequantize_q6_k_gpu, } + +def translate_name_to_gguf_mixtral(name): + + replacement_template = { + "w1.weight": "ffn_gate", + "w2.weight": "ffn_down", + "w3.weight": "ffn_up" + } + + pattern = re.compile(r"model.layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.(w\d\.weight)") + + def replace_match(match): + blk_id = match.group(1) + expert_id = match.group(2) + weight_type = match.group(3) + if weight_type in replacement_template: + return f"blk.{blk_id}.{replacement_template[weight_type]}.{expert_id}.weight" + else: + return match.group(0) + + new_name = re.sub(pattern, replace_match, name) + + return new_name + def translate_name_to_gguf(name): + + name = translate_name_to_gguf_mixtral(name) + name = name.replace("lm_head.", "output.") name = name.replace("model.embed_tokens.", "token_embd.") name = name.replace("model.norm.", "output_norm.") @@ -671,9 +761,14 @@ def translate_name_to_gguf(name): name = name.replace(".mlp.experts.ffn_gate_exps", ".ffn_gate_exps") name = name.replace(".mlp.experts.ffn_up_exps", ".ffn_up_exps") + + name = name.replace(".block_sparse_moe.gate.", ".ffn_gate_inp.") + name = name.replace(".block_sparse_moe.experts", "") + return name if __name__ == '__main__': gguf_path = '/mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH' loader = GGUFLoader(gguf_path) loader.load_gguf_tensor('token_embd.weight') + diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py index 7976e56..7993d62 100644 --- a/ktransformers/util/utils.py +++ b/ktransformers/util/utils.py @@ -39,6 +39,22 @@ def set_param(module: nn.Module, name: str, weights: torch.Tensor): param.unsqueeze_(0) setattr(module, name, param) +def get_device(gguf_module_key:str, device_map:dict): + if gguf_module_key in device_map: + return device_map[gguf_module_key]["generate_device"] + else: + return "cuda" + +def get_all_used_cuda_device(device_map:dict): + all_device_list = set() + for key in device_map: + all_device_list.add(device_map[key]["generate_device"]) if "generate_device" in device_map[key] else None + all_device_list.add(device_map[key]["prefill_device"]) if "prefill_device" in device_map[key] else None + if "cpu" in all_device_list: + all_device_list.remove("cpu") + all_device_list = list(all_device_list) + return all_device_list + def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str = ""): prefix = prefix.replace("orig_module.", "") persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set} @@ -47,18 +63,19 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str for name, param in local_state.items(): key = prefix + name translated_key = translate_name_to_gguf(key) - print("default loading weights", key, translated_key) if translated_key in gguf_loader.tensor_file_map: target_dtype = torch.get_default_dtype() - device = "cpu" if "embd" in translated_key else "cuda" + device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map) + print(f"loading {translated_key} to {device}") + # device = "cpu" if "embd" in translated_key else "cuda" weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype) set_param(module, name, weights) del weights else: 
#print(load_config.tensor_file_map.keys()) - raise Exception(f"can't fand {translated_key} in GGUF file!") + raise Exception(f"can't find {translated_key} in GGUF file!") -def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix='', return_when_injected:bool = False, only_load_injected:bool = False): +def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''): # print(f"recursively loading weights {prefix},{return_when_injected=}, {only_load_injected=}") if not isinstance(module, base_operator.BaseInjectedModule): load_cur_state_dict(module, gguf_loader, prefix) @@ -66,27 +83,36 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix='', return_whe load_weights(child, gguf_loader, prefix+name+".") else: module.load() - -def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000): + +def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True): import os os.environ["TOKENIZERS_PARALLELISM"] = "false" torch._dynamo.config.suppress_errors = True batch_size, seq_length = inputs.shape - torch_device = inputs.device + device_map = model.config.gguf_loader.tensor_device_map + torch_device = get_device('blk.0.self_attn', device_map) + torch_device = "cuda:0" if torch_device == "cuda" else torch_device + inputs = inputs.to(torch_device) + all_cuda_device = get_all_used_cuda_device(device_map) + tokens = [] - def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values): - logits = cuda_graph_runner(cur_token, position_ids, cache_position) + def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, use_cuda_graph: bool = True): + if use_cuda_graph: + logits = cuda_graph_runner(cur_token, position_ids, cache_position) + else: + # custom_stream = torch.cuda.Stream() + torch.cuda.set_device(torch_device) + inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device) + # with torch.cuda.stream(custom_stream): + logits=model(inputs_embeds=inputs_embeds, + position_ids=position_ids, + cache_position=cache_position, + past_key_values=past_key_values, + return_dict=False, use_cache=True)[0] past_key_values.change_seq_length(1) - """ - with torch.cuda.stream(custom_stream): - logits=model(cur_token, - position_ids=position_ids, - cache_position=cache_position, - past_key_values=past_key_values, - return_dict=False, use_cache=True)[0] - #""" - torch.cuda.synchronize() + for device in all_cuda_device: + torch.cuda.synchronize(device) #print(logits) next_token_scores = logits_warper(inputs, logits[:, -1, :]) if generation_config.do_sample: @@ -95,11 +121,12 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000): else: next_token = torch.argmax(next_token_scores, dim=-1) return next_token - + + torch.cuda.set_device(torch_device) with torch.no_grad(): stream = TextStreamer(tokenizer) past_key_values = StaticCache( - config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = torch_device, dtype = model.dtype + config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype ) cache_position = torch.arange(seq_length, device=torch_device) generated_ids = torch.zeros( @@ -108,23 +135,22 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000): generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int) past_key_values.cur_idx=cache_position start_time = time.time() - #custom_stream = 
torch.cuda.Stream() - inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to("cuda") + inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device) logits = model( inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True - )[0][:,-1,:].unsqueeze(0).clone() + )[0][:,-1,:].unsqueeze(0).clone().to(torch_device) generation_config, model_kwargs = model._prepare_generation_config( None, max_length=max_new_tokens, do_sample=True, top_k=5, top_p=0.85, temperature=0.1 # change this to modify generate config ) try: # transformers==4.43 logits_warper = ( - model._get_logits_warper(generation_config,device=inputs.device) if generation_config.do_sample else None + model._get_logits_warper(generation_config,device=inputs.device) ) except: logits_warper = ( - model._get_logits_warper(generation_config) if generation_config.do_sample else None + model._get_logits_warper(generation_config) ) next_token_scores = logits_warper(inputs, logits[:, -1, :]) if generation_config.do_sample: @@ -136,7 +162,6 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000): prefill_count = seq_length prefill_time = first_token_time - print(stream.put(next_token.item()), end="", flush=True) generated_ids[:, seq_length] = next_token tokens.append(next_token) @@ -144,12 +169,16 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000): cache_position = torch.tensor([seq_length], device=torch_device) position_ids = cache_position.unsqueeze(0) seq_length += 1 - - cuda_graph_runner = CUDAGraphRunner() - cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, return_dict=False, use_cache=True) + + if use_cuda_graph: + cuda_graph_runner = CUDAGraphRunner() + cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True) + else: + cuda_graph_runner = None + start_time = time.time() for _ in range(1, max_new_tokens): - next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values) + next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device) inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1) generated_ids[:, cache_position] = next_token.int() tokens.append(next_token.int()) @@ -162,6 +191,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000): print(stream.put(next_token.item()), end="", flush=True) cache_position += 1 position_ids = cache_position.unsqueeze(0) + total_time = time.time() - start_time tokens_generated = len(tokens) diff --git a/pyproject.toml b/pyproject.toml index 8cfe290..863fcb4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,8 @@ requires = [ "setuptools", "torch >= 2.3.0", "ninja", - "packaging" + "packaging", + "cpufeature" ] build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index 1b2d3cf..aeff5f6 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ Author : chenxl Date : 2024-07-27 16:15:27 Version : 1.0.0 LastEditors : chenxl -LastEditTime : 2024-07-31 09:44:46 +LastEditTime : 2024-08-08 02:45:15 Adapted from: https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py Copyright (c) 2023, Tri Dao. 
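Since generation may now touch several GPUs plus the CPU expert kernels, the decode path synchronizes every CUDA device named in the tensor device map rather than a single default device. A reduced sketch of that helper and its use, with a made-up device map (the real one is gguf_loader.tensor_device_map):

    import torch

    def get_all_used_cuda_device(device_map: dict) -> list:
        # collect every CUDA device referenced by generate/prefill placements
        devices = set()
        for placement in device_map.values():
            devices.add(placement.get("generate_device", "cpu"))
            devices.add(placement.get("prefill_device", "cpu"))
        devices.discard("cpu")
        return list(devices)

    device_map = {
        "blk.0.self_attn": {"generate_device": "cuda:0", "prefill_device": "cuda:0"},
        "blk.30.self_attn": {"generate_device": "cuda:1", "prefill_device": "cuda:1"},
    }
    for dev in get_all_used_cuda_device(device_map):
        torch.cuda.synchronize(dev)   # wait for every participating GPU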
@@ -19,6 +19,7 @@ import re import ast import subprocess import platform +import shutil import http.client import urllib.request import urllib.error @@ -27,6 +28,7 @@ from packaging.version import parse import torch.version from wheel.bdist_wheel import bdist_wheel as _bdist_wheel from setuptools import setup, Extension +from cpufeature.extension import CPUFeature from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME class CpuInstructInfo: @@ -67,6 +69,8 @@ class VersionInfo: """ if sys.platform.startswith("linux"): return f'linux_{platform.uname().machine}' + elif sys.platform == "win32": + return "win_amd64" else: raise ValueError("Unsupported platform: {}".format(sys.platform)) @@ -97,6 +101,15 @@ class VersionInfo: return 'avx2' raise ValueError( "Unsupported cpu Instructions: {}".format(flags_line)) + elif sys.platform == "win32": + if CPUFeature.get("AVX512bw", False): + return 'fancy' + if CPUFeature.get("AVX512f", False): + return 'avx512' + if CPUFeature.get("AVX2", False): + return 'avx2' + raise ValueError( + "Unsupported cpu Instructions: {}".format(str(CPUFeature))) else: raise ValueError("Unsupported platform: {}".format(sys.platform)) @@ -154,7 +167,7 @@ class BuildWheelsCommand(_bdist_wheel): wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") print("Raw wheel path", wheel_path) - os.rename(wheel_filename, wheel_path) + shutil.move(wheel_filename, wheel_path) except (urllib.error.HTTPError, urllib.error.URLError, http.client.RemoteDisconnected): print("Precompiled wheel not found. Building from source...") # If the wheel could not be downloaded, build from source diff --git a/third_party/llamafile/iqk_mul_mat.inc b/third_party/llamafile/iqk_mul_mat.inc index 150a8f9..5e9d688 100644 --- a/third_party/llamafile/iqk_mul_mat.inc +++ b/third_party/llamafile/iqk_mul_mat.inc @@ -22,7 +22,7 @@ #include #include -#if defined __x86_64__ || defined __aarch64__ +#if defined __x86_64__ || defined __aarch64__ || defined(_M_X64) #include "llama.cpp/ggml-impl.h" #include "llama.cpp/ggml-quants.h" @@ -225,7 +225,7 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const voi return true; } -#if defined __x86_64__ +#if defined __x86_64__ || defined(_M_X64) #if defined HAVE_FANCY_SIMD #undef HAVE_FANCY_SIMD @@ -1412,7 +1412,8 @@ template void MulMat::set_functions(MulMat& m) { bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int) { - row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00); + if (ne00 % ggml_blck_size(GGML_TYPE_Q8_K) == 0) + row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00); switch (typeA) { case GGML_TYPE_Q2_K: diff --git a/third_party/llamafile/iqk_mul_mat_amd_avx2.cpp b/third_party/llamafile/iqk_mul_mat_amd_avx2.cpp index 9e3de18..bfd12da 100644 --- a/third_party/llamafile/iqk_mul_mat_amd_avx2.cpp +++ b/third_party/llamafile/iqk_mul_mat_amd_avx2.cpp @@ -3,6 +3,6 @@ // Copyrigth 2024 Iwan Kawrakow. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. -#ifdef __x86_64__ +#if defined(__x86_64__) || defined(_M_X64) #include "iqk_mul_mat.inc" #endif // __x86_64__ diff --git a/third_party/llamafile/iqk_mul_mat_amd_zen4.cpp b/third_party/llamafile/iqk_mul_mat_amd_zen4.cpp index 4d0a979..f0f439f 100644 --- a/third_party/llamafile/iqk_mul_mat_amd_zen4.cpp +++ b/third_party/llamafile/iqk_mul_mat_amd_zen4.cpp @@ -3,7 +3,7 @@ // Copyrigth 2024 Iwan Kawrakow. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. 
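On Windows there is no /proc/cpuinfo to parse, so setup.py asks the cpufeature package which SIMD extensions the host supports and picks the matching prebuilt-wheel flavor. The selection logic, condensed into a standalone helper that mirrors the Linux path:

    from cpufeature.extension import CPUFeature

    def cpu_instruct_flavor() -> str:
        # strongest instruction set wins
        if CPUFeature.get("AVX512bw", False):
            return "fancy"
        if CPUFeature.get("AVX512f", False):
            return "avx512"
        if CPUFeature.get("AVX2", False):
            return "avx2"
        raise ValueError(f"Unsupported cpu instructions: {CPUFeature}")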
diff --git a/third_party/llamafile/iqk_mul_mat.inc b/third_party/llamafile/iqk_mul_mat.inc
index 150a8f9..5e9d688 100644
--- a/third_party/llamafile/iqk_mul_mat.inc
+++ b/third_party/llamafile/iqk_mul_mat.inc
@@ -22,7 +22,7 @@
 #include
 #include

-#if defined __x86_64__ || defined __aarch64__
+#if defined __x86_64__ || defined __aarch64__ || defined(_M_X64)

 #include "llama.cpp/ggml-impl.h"
 #include "llama.cpp/ggml-quants.h"
@@ -225,7 +225,7 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const voi
     return true;
 }

-#if defined __x86_64__
+#if defined __x86_64__ || defined(_M_X64)

 #if defined HAVE_FANCY_SIMD
 #undef HAVE_FANCY_SIMD
@@ -1412,7 +1412,8 @@ template void MulMat::set_functions(MulMat& m) {

 bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int) {

-    row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);
+    if (ne00 % ggml_blck_size(GGML_TYPE_Q8_K) == 0)
+        row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);

     switch (typeA) {
     case GGML_TYPE_Q2_K:
diff --git a/third_party/llamafile/iqk_mul_mat_amd_avx2.cpp b/third_party/llamafile/iqk_mul_mat_amd_avx2.cpp
index 9e3de18..bfd12da 100644
--- a/third_party/llamafile/iqk_mul_mat_amd_avx2.cpp
+++ b/third_party/llamafile/iqk_mul_mat_amd_avx2.cpp
@@ -3,6 +3,6 @@
 // Copyrigth 2024 Iwan Kawrakow.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #include "iqk_mul_mat.inc"
 #endif // __x86_64__
diff --git a/third_party/llamafile/iqk_mul_mat_amd_zen4.cpp b/third_party/llamafile/iqk_mul_mat_amd_zen4.cpp
index 4d0a979..f0f439f 100644
--- a/third_party/llamafile/iqk_mul_mat_amd_zen4.cpp
+++ b/third_party/llamafile/iqk_mul_mat_amd_zen4.cpp
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Iwan Kawrakow.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define iqk_mul_mat iqk_mul_mat_zen4
 #define iqk_mul_mat_moe iqk_mul_mat_moe_zen4
 #include "iqk_mul_mat.inc"
diff --git a/third_party/llamafile/sgemm.cpp b/third_party/llamafile/sgemm.cpp
index 7ec34ff..6a7cab4 100644
--- a/third_party/llamafile/sgemm.cpp
+++ b/third_party/llamafile/sgemm.cpp
@@ -22,19 +22,22 @@
 #include "sgemm.h"
 // #include
-#include
+// #include
 // #include
 #include
-#include
+// #include
 #include
 // #include "llamafile.h"

 static const struct GemmFuncs {
-    typeof(llamafile_sgemm)* sgemm;
-    typeof(llamafile_mixmul)* mixmul;
-    typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;
+    bool (*sgemm)(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
+    bool (*mixmul)(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
+    bool (*iqk_mixmul)(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);
+    // typeof(llamafile_sgemm)* sgemm;
+    // typeof(llamafile_mixmul)* mixmul;
+    // typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;
     GemmFuncs() {
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
         // if (X86_HAVE(AVX)) {
         //   if (X86_HAVE(FMA)) {
         //     if (X86_HAVE(AVX2)) {
@@ -86,10 +89,12 @@ static const struct GemmFuncs {
         //   sgemm = llamafile_sgemm_unsupported;
        //   mixmul = llamafile_mixmul_unsupported;
         // }
+
#if defined(__AVX__)
-#if defined(__FMA__)
+#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))
 #if defined(__AVX2__)
 #if defined(__AVX512F__)
+        printf("__AVX512F__\n");
 #if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__)
         // AMD Zen4+ (2023-)
         sgemm = llamafile_sgemm_amd_zen4;
diff --git a/third_party/llamafile/tinyblas_cpu.h b/third_party/llamafile/tinyblas_cpu.h
index f361c0c..962c47c 100644
--- a/third_party/llamafile/tinyblas_cpu.h
+++ b/third_party/llamafile/tinyblas_cpu.h
@@ -223,7 +223,7 @@ inline float32x4_t badder(float32x4_t a, float b, float32x4_t c, float32x4_t* e)
 }
 #endif

-#if defined(__FMA__)
+#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))
 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
 template <>
 inline __m256 madd(__m256 a, __m256 b, __m256 c) {
diff --git a/third_party/llamafile/tinyblas_cpu_mixmul_amd_avx.cpp b/third_party/llamafile/tinyblas_cpu_mixmul_amd_avx.cpp
index 255f873..5cbf5df 100644
--- a/third_party/llamafile/tinyblas_cpu_mixmul_amd_avx.cpp
+++ b/third_party/llamafile/tinyblas_cpu_mixmul_amd_avx.cpp
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_mixmul llamafile_mixmul_amd_avx
 #include "tinyblas_cpu_mixmul.inc"
diff --git a/third_party/llamafile/tinyblas_cpu_mixmul_amd_avx2.cpp b/third_party/llamafile/tinyblas_cpu_mixmul_amd_avx2.cpp
index 552d1aa..95d44bf 100644
--- a/third_party/llamafile/tinyblas_cpu_mixmul_amd_avx2.cpp
+++ b/third_party/llamafile/tinyblas_cpu_mixmul_amd_avx2.cpp
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_mixmul llamafile_mixmul_amd_avx2
 #include "tinyblas_cpu_mixmul.inc"
 #endif // __x86_64__
diff --git a/third_party/llamafile/tinyblas_cpu_mixmul_amd_avx512f.cpp b/third_party/llamafile/tinyblas_cpu_mixmul_amd_avx512f.cpp
index b5e5183..82ab637 100644
--- a/third_party/llamafile/tinyblas_cpu_mixmul_amd_avx512f.cpp
+++ b/third_party/llamafile/tinyblas_cpu_mixmul_amd_avx512f.cpp
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_mixmul llamafile_mixmul_amd_avx512f
 #include "tinyblas_cpu_mixmul.inc"
 #endif // __x86_64__
diff --git a/third_party/llamafile/tinyblas_cpu_mixmul_amd_avxvnni.cpp b/third_party/llamafile/tinyblas_cpu_mixmul_amd_avxvnni.cpp
index c2b2790..2726ac8 100644
--- a/third_party/llamafile/tinyblas_cpu_mixmul_amd_avxvnni.cpp
+++ b/third_party/llamafile/tinyblas_cpu_mixmul_amd_avxvnni.cpp
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_mixmul llamafile_mixmul_amd_avxvnni
 #include "tinyblas_cpu_mixmul.inc"
 #endif // __x86_64__
diff --git a/third_party/llamafile/tinyblas_cpu_mixmul_amd_fma.cpp b/third_party/llamafile/tinyblas_cpu_mixmul_amd_fma.cpp
index 6fd25c9..4d4c4d8 100644
--- a/third_party/llamafile/tinyblas_cpu_mixmul_amd_fma.cpp
+++ b/third_party/llamafile/tinyblas_cpu_mixmul_amd_fma.cpp
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_mixmul llamafile_mixmul_amd_fma
 #include "tinyblas_cpu_mixmul.inc"
 #endif // __x86_64__
diff --git a/third_party/llamafile/tinyblas_cpu_mixmul_amd_zen4.cpp b/third_party/llamafile/tinyblas_cpu_mixmul_amd_zen4.cpp
index aaac6e1..3d478c1 100644
--- a/third_party/llamafile/tinyblas_cpu_mixmul_amd_zen4.cpp
+++ b/third_party/llamafile/tinyblas_cpu_mixmul_amd_zen4.cpp
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_mixmul llamafile_mixmul_amd_zen4
 #include "tinyblas_cpu_mixmul.inc"
 #endif // __x86_64__
diff --git a/third_party/llamafile/tinyblas_cpu_sgemm.inc b/third_party/llamafile/tinyblas_cpu_sgemm.inc
index c9d1f47..634dc3e 100644
--- a/third_party/llamafile/tinyblas_cpu_sgemm.inc
+++ b/third_party/llamafile/tinyblas_cpu_sgemm.inc
@@ -321,8 +321,8 @@ bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void
     assert(ith < nth);

 #if QK_K == 256
-#if defined(__x86_64__)
-#if defined(__AVX2__) && defined(__FMA__)
+#if defined(__x86_64__) || defined(_M_X64)
+#if defined(__AVX2__) && (defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))))
     // if (X86_CHECK(AVX2) && X86_CHECK(FMA)) {
     if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {
         if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) {
diff --git a/third_party/llamafile/tinyblas_cpu_sgemm_amd_avx.cpp b/third_party/llamafile/tinyblas_cpu_sgemm_amd_avx.cpp
index e57eda6..439e55d 100644
--- a/third_party/llamafile/tinyblas_cpu_sgemm_amd_avx.cpp
+++ b/third_party/llamafile/tinyblas_cpu_sgemm_amd_avx.cpp
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_sgemm llamafile_sgemm_amd_avx
 #include "tinyblas_cpu_sgemm.inc"
 #endif // __x86_64__
diff --git a/third_party/llamafile/tinyblas_cpu_sgemm_amd_avx2.cpp b/third_party/llamafile/tinyblas_cpu_sgemm_amd_avx2.cpp
index 0e1fe84..4b46f01 100644
--- a/third_party/llamafile/tinyblas_cpu_sgemm_amd_avx2.cpp
+++ b/third_party/llamafile/tinyblas_cpu_sgemm_amd_avx2.cpp
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_sgemm llamafile_sgemm_amd_avx2
 #include "tinyblas_cpu_sgemm.inc"
 #endif // __x86_64__
diff --git a/third_party/llamafile/tinyblas_cpu_sgemm_amd_avx512f.cpp b/third_party/llamafile/tinyblas_cpu_sgemm_amd_avx512f.cpp
index cafcaa2..16425e4 100644
--- a/third_party/llamafile/tinyblas_cpu_sgemm_amd_avx512f.cpp
+++ b/third_party/llamafile/tinyblas_cpu_sgemm_amd_avx512f.cpp
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_sgemm llamafile_sgemm_amd_avx512f
 #include "tinyblas_cpu_sgemm.inc"
 #endif // __x86_64__
diff --git a/third_party/llamafile/tinyblas_cpu_sgemm_amd_avxvnni.cpp b/third_party/llamafile/tinyblas_cpu_sgemm_amd_avxvnni.cpp
index 5d2ddce..a4ac488 100644
--- a/third_party/llamafile/tinyblas_cpu_sgemm_amd_avxvnni.cpp
+++ b/third_party/llamafile/tinyblas_cpu_sgemm_amd_avxvnni.cpp
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_sgemm llamafile_sgemm_amd_avxvnni
 #include "tinyblas_cpu_sgemm.inc"
 #endif // __x86_64__
diff --git a/third_party/llamafile/tinyblas_cpu_sgemm_amd_fma.cpp b/third_party/llamafile/tinyblas_cpu_sgemm_amd_fma.cpp
index 275c9b4..e1559da 100644
--- a/third_party/llamafile/tinyblas_cpu_sgemm_amd_fma.cpp
+++ b/third_party/llamafile/tinyblas_cpu_sgemm_amd_fma.cpp
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_sgemm llamafile_sgemm_amd_fma
 #include "tinyblas_cpu_sgemm.inc"
 #endif // __x86_64__
diff --git a/third_party/llamafile/tinyblas_cpu_sgemm_amd_zen4.cpp b/third_party/llamafile/tinyblas_cpu_sgemm_amd_zen4.cpp
index 01924a7..f524ba1 100644
--- a/third_party/llamafile/tinyblas_cpu_sgemm_amd_zen4.cpp
+++ b/third_party/llamafile/tinyblas_cpu_sgemm_amd_zen4.cpp
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_sgemm llamafile_sgemm_amd_zen4
 #define iqk_mul_mat iqk_mul_mat_zen4
 #include "tinyblas_cpu_sgemm.inc"