[ADD] support multi-gpu qlen>1 q5_k

chenxl 2024-08-12 11:17:29 +00:00
parent f293803156
commit f5f79f5c0e
63 changed files with 3271 additions and 1285 deletions

.gitignore vendored

@@ -15,3 +15,6 @@ node_modules
 compile_commands.json
 *.egg-info*
 *dist/
+ktransformers/server/local_store/
+ktransformers/server_test1.db
+*.patch

install.bat Normal file

@@ -0,0 +1,16 @@
@echo off
REM clear build dirs
rmdir /S /Q ktransformers\ktransformers_ext\build
rmdir /S /Q ktransformers\ktransformers_ext\cuda\build
rmdir /S /Q ktransformers\ktransformers_ext\cuda\dist
rmdir /S /Q ktransformers\ktransformers_ext\out
del /F /Q ktransformers\ktransformers_ext\cuda\*.egg-info
echo Installing python dependencies from requirements.txt
pip install -r requirements-local_chat.txt
echo Installing ktransformers
set KTRANSFORMERS_FORCE_BUILD=TRUE
pip install . --no-build-isolation
echo Installation completed successfully


@@ -11,5 +11,5 @@ echo "Installing python dependencies from requirements.txt"
 pip install -r requirements-local_chat.txt
 echo "Installing ktransformers"
-pip install . --no-build-isolation
+KTRANSFORMERS_FORCE_BUILD=TRUE pip install . --no-build-isolation
 echo "Installation completed successfully"


@@ -189,7 +189,13 @@ else()
     message(STATUS "Unknown architecture")
 endif()
-find_package(CUDA REQUIRED)
+# message(STATUS "CUDAToolkit_ROOT:${CUDAToolkit_ROOT}")
+# find_package(FindCUDAToolkit REQUIRED)
+# if(CUDAToolkit_FOUND)
+#     message(STATUS "Found CUDA cudart lib at:${CUDAToolkit_LIBRARY_DIR}")
+# else()
+#     message(STATUS "Can't found CUDA lib")
+# endif()
 add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
 add_compile_options("$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>")
@@ -198,7 +204,12 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/pybind11 ${CMAKE_
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llama.cpp ${CMAKE_CURRENT_BINARY_DIR}/third_party/llama.cpp)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
-include_directories("${CUDA_INCLUDE_DIRS}")
+if (WIN32)
+    include_directories("$ENV{CUDA_PATH}/include")
+elseif (UNIX)
+    find_package(CUDA REQUIRED)
+    include_directories("${CUDA_INCLUDE_DIRS}")
+endif()
 aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
 aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2)
@@ -209,4 +220,8 @@ message(STATUS "ALL_SOURCES: ${ALL_SOURCES}")
 pybind11_add_module(${PROJECT_NAME} MODULE ${ALL_SOURCES})
 target_link_libraries(${PROJECT_NAME} PRIVATE llama)
-target_link_libraries(${PROJECT_NAME} PRIVATE "/usr/local/cuda/lib64/libcudart.so")
+if(WIN32)
+    target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart
+elseif(UNIX)
+    target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
+endif()


@@ -3,8 +3,8 @@
  * @Author       : chenht2022
  * @Date         : 2024-07-16 10:43:18
  * @Version      : 1.0.0
- * @LastEditors  : chenht2022
- * @LastEditTime : 2024-07-25 10:33:47
+ * @LastEditors  : chenxl
+ * @LastEditTime : 2024-08-08 04:23:51
  * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
 #ifndef CPUINFER_TASKQUEUE_H
@@ -17,6 +17,44 @@
 #include <queue>
 #include <thread>
 #include <vector>
+#ifdef _WIN32
+#include <windows.h>
+#endif
+
+// Minimal mutex wrapper: a Win32 mutex handle on Windows, std::mutex elsewhere.
+class custom_mutex {
+   private:
+#ifdef _WIN32
+    HANDLE global_mutex;
+#else
+    std::mutex global_mutex;
+#endif
+
+   public:
+    custom_mutex()
+    {
+#ifdef _WIN32
+        global_mutex = CreateMutex(NULL, FALSE, NULL);
+#endif
+    }
+
+    void lock()
+    {
+#ifdef _WIN32
+        WaitForSingleObject(global_mutex, INFINITE);
+#else
+        global_mutex.lock();
+#endif
+    }
+
+    void unlock()
+    {
+#ifdef _WIN32
+        ReleaseMutex(global_mutex);
+#else
+        global_mutex.unlock();
+#endif
+    }
+};
+
 class TaskQueue {
    public:
@@ -32,7 +70,7 @@ class TaskQueue {
     std::queue<std::function<void()>> tasks;
     std::thread worker;
-    std::mutex mutex;
+    custom_mutex mutex;
     std::atomic<bool> sync_flag;
     std::atomic<bool> exit_flag;
 };


@@ -3,8 +3,8 @@
  * @Author       : Azure-Tang
  * @Date         : 2024-07-25 13:38:30
  * @Version      : 1.0.0
- * @LastEditors  : Azure
- * @LastEditTime : 2024-07-26 08:36:03
+ * @LastEditors  : kkk1nak0
+ * @LastEditTime : 2024-08-09 01:45:02
  * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
@@ -23,6 +23,8 @@ PYBIND11_MODULE(KTransformersOps, m) {
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
+          py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",


@@ -12,12 +12,15 @@ int test(){
 }
 torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
 PYBIND11_MODULE(cudaops, m) {
     m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
+          py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("test", &test, "Function to test.");


@ -1,39 +0,0 @@
#include <cuda_fp16.h>
__device__ float ggml_compute_fp16_to_fp32(uint16_t h) {
return __uint2float_rd(h);
}
static inline float ggml_compute_fp16_to_fp32(uint16_t h) {
uint16_t tmp;
memcpy(&tmp, &h, sizeof(ggml_fp16_t));
return (float)tmp;
}
// define the global table for fp16 to fp32 conversion
__device__ float ggml_table_f32_f16[1 << 16];
// CUDA Kernel to init the table
__global__ void init_fp16_to_fp32_table() {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto blk_id = idx; blk_id<(1 << 16); blk_id+=blockDim.x * gridDim.x){
ggml_table_f32_f16[blk_id] = GGML_COMPUTE_FP16_TO_FP32(blk_id);
}
}
#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
extern __device__ float ggml_table_f32_f16[1 << 16]; // Declare as __device__ if used within device code
// This version of the function is designed to be called from within a CUDA kernel
#if !defined(GGML_FP16_TO_FP32)
__device__ float ggml_lookup_fp16_to_fp32(uint16_t f) {
return ggml_table_f32_f16[f];
}
#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
#endif


@@ -3,8 +3,8 @@
  * @Author       : Azure-Tang, Boxin Zhang
  * @Date         : 2024-07-25 13:38:30
  * @Version      : 1.0.0
- * @LastEditors  : Azure
- * @LastEditTime : 2024-07-26 11:58:50
+ * @LastEditors  : kkk1nak0
+ * @LastEditTime : 2024-08-09 07:57:06
  * Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c
  * Copyright (c) 2023-2024 The ggml authors
  * Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
@@ -14,6 +14,7 @@
 #include <torch/extension.h>
 #include <torch/torch.h>
 #include <cstdint>
+#include <c10/cuda/CUDAGuard.h>

 __global__ void dequantize_q8_0_kernel(float* output, const float* scales, const int8_t* qs, int num_blocks, int blk_size) {
     int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -59,6 +60,35 @@ __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size
     }
 }

+__global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
+        float* __restrict__ output_blk = (float*)(output + block_id * 256);
+
+        const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 0)));
+        const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 2)));
+
+        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);
+        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);
+
+        int is = 0;
+        uint8_t sc, m;
+        uint8_t u1 = 1, u2 = 2;
+        uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);
+
+        for (int j = 0; j < 256; j += 64) {
+            get_scale_min_k4(is + 0, scales, &sc, &m);
+            const float d1 = d * sc; const float m1 = min * m;
+            get_scale_min_k4(is + 1, scales, &sc, &m);
+            const float d2 = d * sc; const float m2 = min * m;
+            for (int l = 0; l < 32; ++l) *output_blk++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
+            for (int l = 0; l < 32; ++l) *output_blk++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
+            ql += 32; is += 2;
+            u1 <<= 2; u2 <<= 2;
+        }
+    }
+}
+
 __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
     int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
     for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
@@ -94,6 +124,7 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
 torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device) {
     int num_blocks = data.numel() / blk_size;
+    const at::cuda::OptionalCUDAGuard device_guard(device);
     // create gpu
     auto options_scales = torch::TensorOptions().dtype(torch::kFloat32).device(device).memory_format(torch::MemoryFormat::Contiguous);
     auto options_qs = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
@@ -128,6 +159,7 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
     // data.numel%blk_size should be 0, else raise err
     int num_blocks = data.numel() / blk_size;
+    const at::cuda::OptionalCUDAGuard device_guard(device);
     auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
     auto data_gpu = torch::empty({data.numel()}, options);
@@ -147,6 +179,7 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
 torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device) {
     // data.numel%blk_size should be 0, else raise err
     int num_blocks = data.numel() / blk_size;
+    const at::cuda::OptionalCUDAGuard device_guard(device);
     auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
     auto data_gpu = torch::empty({data.numel()}, options);
@@ -162,3 +195,22 @@ torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device de
     cudaDeviceSynchronize();
     return output;
 }
+
+torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) {
+    int num_blocks = data.numel() / blk_size;
+    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
+    auto data_gpu = torch::empty({data.numel()}, options);
+
+    data_gpu.copy_(data, false);
+
+    // Create output tensor
+    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
+
+    // Launch kernel
+    dequantize_q5_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+
+    cudaDeviceSynchronize();
+    return output;
+}
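As a cross-check of the layout the kernel reads (fp16 d at byte 0, fp16 dmin at 2, scales at 4, qh at 16, ql at 48), here is a rough NumPy reference for one block. The 176-byte block size and the get_scale_min_k4 unpacking follow the standard ggml Q5_K definitions and are assumptions, not part of this diff:

import numpy as np

def get_scale_min_k4(j, scales):
    # 6-bit scale/min unpacking as in ggml (assumed to match the CUDA helper used above).
    if j < 4:
        return scales[j] & 63, scales[j + 4] & 63
    d = (scales[j + 4] & 0xF) | ((scales[j - 4] >> 6) << 4)
    m = (scales[j + 4] >> 4) | ((scales[j] >> 6) << 4)
    return d, m

def dequantize_q5_k_block(block: bytes) -> np.ndarray:
    # Field offsets mirror the kernel: d @ 0, dmin @ 2, scales[12] @ 4, qh[32] @ 16, ql[128] @ 48.
    d = float(np.frombuffer(block, np.float16, 1, 0)[0])
    dmin = float(np.frombuffer(block, np.float16, 1, 2)[0])
    scales = np.frombuffer(block, np.uint8, 12, 4)
    qh = np.frombuffer(block, np.uint8, 32, 16)
    ql = np.frombuffer(block, np.uint8, 128, 48)
    out = np.empty(256, np.float32)
    is_, u1, u2, o, q = 0, 1, 2, 0, 0
    for _ in range(4):  # four chunks of 64 weights, as in the j loop above
        sc, m = get_scale_min_k4(is_, scales);     d1, m1 = d * sc, dmin * m
        sc, m = get_scale_min_k4(is_ + 1, scales); d2, m2 = d * sc, dmin * m
        for l in range(32):
            out[o + l]      = d1 * ((ql[q + l] & 0xF) + (16 if qh[l] & u1 else 0)) - m1
            out[o + 32 + l] = d2 * ((ql[q + l] >> 4)  + (16 if qh[l] & u2 else 0)) - m2
        o += 64; q += 32; is_ += 2; u1 <<= 2; u2 <<= 2
    return out

print(dequantize_q5_k_block(bytes(176))[:4])  # all-zero block dequantizes to zeros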


@@ -3,8 +3,8 @@
  * @Author       : Azure-Tang
  * @Date         : 2024-07-22 09:27:55
  * @Version      : 1.0.0
- * @LastEditors  : Azure
- * @LastEditTime : 2024-07-26 08:38:20
+ * @LastEditors  : kkk1nak0
+ * @LastEditTime : 2024-08-09 01:44:21
  * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
 #pragma once
@@ -15,4 +15,5 @@
 torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);


@@ -23,7 +23,7 @@
  */
 #include "gptq_marlin.cuh"
 #include "gptq_marlin_dtypes.cuh"
+#include <c10/cuda/CUDAGuard.h>

 #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)               \
   static_assert(std::is_same<scalar_t, half>::value ||          \
                 std::is_same<scalar_t, nv_bfloat16>::value,     \
@@ -1703,28 +1703,63 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
       thread_m_blocks = exec_cfg.max_m_blocks;
     }

     // Define kernel configurations
-    if (false) {
-    }
-    CALL_IF(4, 32, 2, 256)
-    CALL_IF(4, 16, 4, 256)
-    CALL_IF(4, 8, 8, 256)
-    CALL_IF(4, 8, 4, 128)
-    CALL_IF(4, 4, 8, 128)
-    CALL_IF(8, 32, 2, 256)
-    CALL_IF(8, 16, 4, 256)
-    CALL_IF(8, 8, 8, 256)
-    CALL_IF(8, 8, 4, 128)
-    CALL_IF(8, 4, 8, 128)
+#define undefined_error TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + \
+        str(prob_n) + ", " + str(prob_k) + "]" + \
+        ", has_act_order = " + str(has_act_order) + \
+        ", num_groups = " + str(num_groups) + \
+        ", group_size = " + str(group_size) + \
+        ", thread_m_blocks = " + str(thread_m_blocks) + \
+        ", thread_n_blocks = " + str(thread_n_blocks) + \
+        ", thread_k_blocks = " + str(thread_k_blocks));
+    if (num_bits == 4 && num_threads == 256)
+    {
+      if (false) {
+      }
+      CALL_IF(4, 32, 2, 256)
+      CALL_IF(4, 16, 4, 256)
+      CALL_IF(4, 8, 8, 256)
+      else {
+        undefined_error
+      }
+    }
+    else if (num_bits == 4 && num_threads == 128)
+    {
+      if (false) {
+      }
+      CALL_IF(4, 8, 4, 128)
+      CALL_IF(4, 4, 8, 128)
+      else {
+        undefined_error
+      }
+    }
+    else if (num_bits == 8 && num_threads == 256)
+    {
+      if (false) {
+      }
+      CALL_IF(8, 32, 2, 256)
+      CALL_IF(8, 16, 4, 256)
+      CALL_IF(8, 8, 8, 256)
+      else {
+        undefined_error
+      }
+    }
+    else if (num_bits == 8 && num_threads == 128)
+    {
+      if (false) {
+      }
+      CALL_IF(8, 8, 4, 128)
+      CALL_IF(8, 4, 8, 128)
+      else {
+        undefined_error
+      }
+    }
     else {
-      TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
-                             str(prob_n) + ", " + str(prob_k) + "]" +
-                             ", has_act_order = " + str(has_act_order) +
-                             ", num_groups = " + str(num_groups) +
-                             ", group_size = " + str(group_size) +
-                             ", thread_m_blocks = " + str(thread_m_blocks) +
-                             ", thread_n_blocks = " + str(thread_n_blocks) +
-                             ", thread_k_blocks = " + str(thread_k_blocks));
+      undefined_error
     }

     A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
@@ -1739,6 +1774,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& perm, torch::Tensor& workspace,
                                int64_t num_bits, int64_t size_m, int64_t size_n,
                                int64_t size_k, bool is_k_full) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
   // Verify num_bits
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);
@@ -1781,7 +1817,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");

   // Alloc buffers
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
   auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
   torch::Tensor c = torch::empty({size_m, size_n}, options);
   torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);


@@ -2,17 +2,25 @@
 from setuptools import setup, Extension
 from torch.utils import cpp_extension
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension

-# setup marlin gemm
-setup(name='KTransformersOps',
-      ext_modules=[
-          CUDAExtension('KTransformersOps', [
+setup(
+    name='KTransformersOps',
+    ext_modules=[
+        CUDAExtension(
+            'KTransformersOps', [
                 'custom_gguf/dequant.cu',
                 'binding.cpp',
                 'gptq_marlin/gptq_marlin.cu',
                 # 'gptq_marlin_repack.cu',
-          ])
-      ],
-      cmdclass={'build_ext': BuildExtension
-      })
+            ],
+            extra_compile_args={
+                'cxx': ['-O3'],
+                'nvcc': [
+                    '-O3',
+                    '--use_fast_math',
+                    '-Xcompiler', '-fPIC',
+                ]
+            },
+        )
+    ],
+    cmdclass={'build_ext': BuildExtension}
+)


@@ -37,7 +37,7 @@ class LinearBindings {
             Args* args_ = (Args*)args;
             args_->cpuinfer->enqueue(&Linear::warm_up, args_->linear);
         }
-        static std::pair<intptr_t, intptr_t> interface(Linear& linear) {
+        static std::pair<intptr_t, intptr_t> cpuinfer_interface(Linear& linear) {
             Args* args = new Args{nullptr, &linear};
             return std::make_pair((intptr_t)&inner, (intptr_t)args);
         }
@@ -55,7 +55,7 @@ class LinearBindings {
             Args* args_ = (Args*)args;
             args_->cpuinfer->enqueue(&Linear::forward, args_->linear, args_->qlen, args_->input, args_->output);
         }
-        static std::pair<intptr_t, intptr_t> interface(Linear& linear, int qlen, intptr_t input, intptr_t output) {
+        static std::pair<intptr_t, intptr_t> cpuinfer_interface(Linear& linear, int qlen, intptr_t input, intptr_t output) {
             Args* args = new Args{nullptr, &linear, qlen, (const void*)input, (void*)output};
             return std::make_pair((intptr_t)&inner, (intptr_t)args);
         }
@@ -74,7 +74,7 @@ class MLPBindings {
             Args* args_ = (Args*)args;
             args_->cpuinfer->enqueue(&MLP::warm_up, args_->mlp);
         }
-        static std::pair<intptr_t, intptr_t> interface(MLP& mlp) {
+        static std::pair<intptr_t, intptr_t> cpuinfer_interface(MLP& mlp) {
             Args* args = new Args{nullptr, &mlp};
             return std::make_pair((intptr_t)&inner, (intptr_t)args);
         }
@@ -92,7 +92,7 @@ class MLPBindings {
             Args* args_ = (Args*)args;
             args_->cpuinfer->enqueue(&MLP::forward, args_->mlp, args_->qlen, args_->input, args_->output);
         }
-        static std::pair<intptr_t, intptr_t> interface(MLP& mlp, int qlen, intptr_t input, intptr_t output) {
+        static std::pair<intptr_t, intptr_t> cpuinfer_interface(MLP& mlp, int qlen, intptr_t input, intptr_t output) {
             Args* args = new Args{nullptr, &mlp, qlen, (const void*)input, (void*)output};
             return std::make_pair((intptr_t)&inner, (intptr_t)args);
         }
@@ -111,7 +111,7 @@ class MOEBindings {
             Args* args_ = (Args*)args;
             args_->cpuinfer->enqueue(&MOE::warm_up, args_->moe);
         }
-        static std::pair<intptr_t, intptr_t> interface(MOE& moe) {
+        static std::pair<intptr_t, intptr_t> cpuinfer_interface(MOE& moe) {
             Args* args = new Args{nullptr, &moe};
             return std::make_pair((intptr_t)&inner, (intptr_t)args);
         }
@@ -132,7 +132,7 @@ class MOEBindings {
             Args* args_ = (Args*)args;
             args_->cpuinfer->enqueue(&MOE::forward, args_->moe, args_->qlen, args_->k, args_->expert_ids, args_->weights, args_->input, args_->output);
         }
-        static std::pair<intptr_t, intptr_t> interface(MOE& moe, int qlen, int k, intptr_t expert_ids, intptr_t weights, intptr_t input, intptr_t output) {
+        static std::pair<intptr_t, intptr_t> cpuinfer_interface(MOE& moe, int qlen, int k, intptr_t expert_ids, intptr_t weights, intptr_t input, intptr_t output) {
            Args* args = new Args{nullptr, &moe, qlen, k, (const uint64_t*)expert_ids, (const float*)weights, (const void*)input, (void*)output};
            return std::make_pair((intptr_t)&inner, (intptr_t)args);
        }
@@ -154,8 +154,8 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
         }));
     py::class_<Linear>(linear_module, "Linear")
         .def(py::init<LinearConfig>())
-        .def("warm_up", &LinearBindings::WarmUpBindinds::interface)
-        .def("forward", &LinearBindings::ForwardBindings::interface);
+        .def("warm_up", &LinearBindings::WarmUpBindinds::cpuinfer_interface)
+        .def("forward", &LinearBindings::ForwardBindings::cpuinfer_interface);

     auto mlp_module = m.def_submodule("mlp");
     py::class_<MLPConfig>(mlp_module, "MLPConfig")
@@ -164,8 +164,8 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
         }));
     py::class_<MLP>(mlp_module, "MLP")
         .def(py::init<MLPConfig>())
-        .def("warm_up", &MLPBindings::WarmUpBindinds::interface)
-        .def("forward", &MLPBindings::ForwardBindings::interface);
+        .def("warm_up", &MLPBindings::WarmUpBindinds::cpuinfer_interface)
+        .def("forward", &MLPBindings::ForwardBindings::cpuinfer_interface);

     auto moe_module = m.def_submodule("moe");
     py::class_<MOEConfig>(moe_module, "MOEConfig")
@@ -174,6 +174,6 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
         }));
     py::class_<MOE>(moe_module, "MOE")
         .def(py::init<MOEConfig>())
-        .def("warm_up", &MOEBindings::WarmUpBindinds::interface)
-        .def("forward", &MOEBindings::ForwardBindings::interface);
+        .def("warm_up", &MOEBindings::WarmUpBindinds::cpuinfer_interface)
+        .def("forward", &MOEBindings::ForwardBindings::cpuinfer_interface);
 }


@@ -1,206 +0,0 @@
import math
import os
import time
from logging import getLogger
import torch
import torch.nn as nn
import transformers
from .quantizer import Quantizer
logger = getLogger(__name__)
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
class GPTQ:
def __init__(self, layer):
self.layer = layer
self.dev = self.layer.weight.device
W = layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
if isinstance(self.layer, transformers.pytorch_utils.Conv1D):
W = W.t()
self.rows = W.shape[0]
self.columns = W.shape[1]
self.H = torch.zeros((self.columns, self.columns), device=self.dev)
self.nsamples = 0
self.quantizer = Quantizer()
def add_batch(self, inp, out):
if os.environ.get("DEBUG"):
self.inp1 = inp
self.out1 = out
if len(inp.shape) == 2:
inp = inp.unsqueeze(0)
tmp = inp.shape[0]
if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.t()
if isinstance(self.layer, nn.Conv2d):
unfold = nn.Unfold(
self.layer.kernel_size,
dilation=self.layer.dilation,
padding=self.layer.padding,
stride=self.layer.stride,
)
inp = unfold(inp)
inp = inp.permute([1, 0, 2])
inp = inp.flatten(1)
self.H *= self.nsamples / (self.nsamples + tmp)
self.nsamples += tmp
# inp = inp.float()
inp = math.sqrt(2 / self.nsamples) * inp.float()
# self.H += 2 / self.nsamples * inp.matmul(inp.t())
self.H += inp.matmul(inp.t())
def fasterquant(
self,
blocksize=128,
percdamp=0.01,
group_size=-1,
actorder=False,
static_groups=False,
):
W = self.layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
if isinstance(self.layer, transformers.Conv1D):
W = W.t()
W = W.float()
tick = time.time()
if not self.quantizer.ready():
self.quantizer.find_params(W, weight=True)
H = self.H
del self.H
dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0
g_idx = []
scale = []
zero = []
now_idx = 1
if static_groups:
import copy
groups = []
for i in range(0, self.columns, group_size):
quantizer = copy.deepcopy(self.quantizer)
quantizer.find_params(W[:, i : (i + group_size)], weight=True)
scale.append(quantizer.scale)
zero.append(quantizer.zero)
groups.append(quantizer)
if actorder:
perm = torch.argsort(torch.diag(H), descending=True)
W = W[:, perm]
H = H[perm][:, perm]
invperm = torch.argsort(perm)
Losses = torch.zeros_like(W)
Q = torch.zeros_like(W)
damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(self.columns, device=self.dev)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
count = i2 - i1
W1 = W[:, i1:i2].clone()
Q1 = torch.zeros_like(W1)
Err1 = torch.zeros_like(W1)
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]
for i in range(count):
w = W1[:, i]
d = Hinv1[i, i]
if group_size != -1:
if not static_groups:
if (i1 + i) % group_size == 0:
self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + group_size)], weight=True)
if ((i1 + i) // group_size) - now_idx == -1:
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
now_idx += 1
else:
idx = i1 + i
if actorder:
idx = perm[idx]
self.quantizer = groups[idx // group_size]
q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d**2
err1 = (w - q) / d
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
Err1[:, i] = err1
Q[:, i1:i2] = Q1
Losses[:, i1:i2] = Losses1 / 2
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
if os.environ.get("DEBUG"):
self.layer.weight.data[:, :i2] = Q[:, :i2]
self.layer.weight.data[:, i2:] = W[:, i2:]
logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
logger.debug(torch.sum(Losses))
torch.cuda.synchronize()
logger.info(f"duration: {(time.time() - tick)}")
logger.info(f"avg loss: {torch.sum(Losses).item() / self.nsamples}")
group_size = group_size if group_size != -1 else self.columns
if static_groups and actorder:
g_idx = [perm[i] // group_size for i in range(self.columns)]
else:
g_idx = [i // group_size for i in range(self.columns)]
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
if actorder:
Q = Q[:, invperm]
g_idx = g_idx[invperm]
if isinstance(self.layer, transformers.Conv1D):
Q = Q.t()
self.layer.weight.data = Q.reshape(self.layer.weight.shape).type_as(self.layer.weight.data)
if os.environ.get("DEBUG"):
logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
if scale == []:
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
scale = torch.cat(scale, dim=1)
zero = torch.cat(zero, dim=1)
return scale, zero, g_idx
def free(self):
if os.environ.get("DEBUG"):
self.inp1 = None
self.out1 = None
self.H = None
self.Losses = None
self.Trace = None
torch.cuda.empty_cache()
__all__ = ["GPTQ"]


@@ -1,458 +0,0 @@
import enum
from enum import Enum
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
logger = init_logger(__name__)
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
# Permutations for Marlin scale shuffling
def get_scale_perms(num_bits: int):
scale_perm: List[int] = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: List[int] = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single
def get_pack_factor(num_bits: int):
assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
), f"Unsupported num_bits = {num_bits}"
return 32 // num_bits
def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
group_size: int, num_bits: int):
scale_perm, scale_perm_single = get_scale_perms(num_bits)
if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
s = s.reshape((-1, size_n)).contiguous()
return s
class GPTQMarlinConfig(QuantizationConfig):
"""Config class for GPTQ Marlin"""
def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool) -> None:
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
# Verify
if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
raise ValueError(
f"Marlin does not support weight_bits = {self.weight_bits}. "
f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
"are supported.")
if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
raise ValueError(
f"Marlin does not support group_size = {self.group_size}. "
f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
"are supported.")
if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
raise ValueError(
f"Marlin does not support is_sym = {self.is_sym}. "
f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")
# Init
self.pack_factor = get_pack_factor(weight_bits)
self.tile_size = GPTQ_MARLIN_TILE
self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL
def __repr__(self) -> str:
return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act})")
@classmethod
def get_name(cls) -> str:
return "gptq_marlin"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
return cls(weight_bits, group_size, desc_act, is_sym)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
can_convert = cls.is_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
if can_convert and user_quant == "gptq":
logger.info("Detected that the model can run with gptq_marlin"
", however you specified quantization=gptq explicitly,"
" so forcing gptq. Use quantization=gptq_marlin for"
" faster inference")
return None
def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
if isinstance(layer, LinearBase):
return GPTQMarlinLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
@classmethod
def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
sym = quant_config.get("sym", None)
desc_act = quant_config.get("desc_act", None)
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
return False
# If the capability of the device is too low, cannot convert.
major, minor = torch.cuda.get_device_capability()
device_capability = major * 10 + minor
if device_capability < cls.get_min_capability():
return False
# Otherwise, can convert if model satisfies marlin constraints.
return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
and sym in GPTQ_MARLIN_SUPPORTED_SYM)
class GPTQMarlinState(Enum):
REPACK = enum.auto()
READY = enum.auto()
class GPTQMarlinLinearMethod(LinearMethodBase):
"""Linear method for GPTQ Marlin.
Args:
quant_config: The GPTQ Marlin quantization config.
"""
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
del output_size
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
# Validate dtype
if params_dtype not in [torch.float16, torch.bfloat16]:
raise ValueError(f"The params dtype must be float16 "
f"or bfloat16, but got {params_dtype}")
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.min_thread_n != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f" min_thread_n = {self.quant_config.min_thread_n}.")
# Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_thread_k != 0:
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible "
f"by min_thread_k = {self.quant_config.min_thread_k}.")
if (group_size < input_size
and input_size_per_partition % group_size != 0):
raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition}"
f" is not divisible by group_size = {group_size}.")
# Detect sharding of scales/zp
# By default, no sharding over "input dim"
scales_and_zp_size = input_size // group_size
scales_and_zp_input_dim = None
if self.quant_config.desc_act:
# Act-order case
assert self.quant_config.group_size != -1
is_k_full = input_size_per_partition == input_size
else:
# No act-order case
# K is always full due to full alignment with
# group-size and shard of scales/zp
is_k_full = True
# If this is a row-parallel case, then shard scales/zp
if (input_size != input_size_per_partition
and self.quant_config.group_size != -1):
scales_and_zp_size = input_size_per_partition // group_size
scales_and_zp_input_dim = 0
# Init buffers
# Quantized weights
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
**extra_weight_attrs,
"input_dim": 0,
"output_dim": 1,
"packed_dim": 0,
"pack_factor": self.quant_config.pack_factor,
},
)
# Activation order
g_idx = Parameter(
torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs(
g_idx,
{
**extra_weight_attrs, "input_dim": 0,
"ignore_warning": True
},
)
g_idx_sort_indices = torch.empty(
g_idx.shape,
dtype=torch.int32,
)
# Scales
scales = Parameter(
torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
},
)
# Quantized zero-points
qzeros = Parameter(
torch.empty(
scales_and_zp_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
device="meta",
),
requires_grad=False,
)
set_weight_attrs(
qzeros,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
},
)
# Allocate marlin workspace
max_workspace_size = (
output_size_per_partition //
self.quant_config.min_thread_n) * self.quant_config.max_parallel
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
requires_grad=False)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)
layer.g_idx_sort_indices = g_idx_sort_indices
layer.workspace = workspace
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size
layer.is_k_full = is_k_full
layer.marlin_state = GPTQMarlinState.REPACK
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
size_m = reshaped_x.shape[0]
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
full_size_k = layer.input_size
out_shape = x.shape[:-1] + (part_size_n, )
if layer.marlin_state == GPTQMarlinState.REPACK:
layer.marlin_state = GPTQMarlinState.READY
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(name, new_t):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr(layer, name).resize_(new_t.shape)
getattr(layer, name).copy_(new_t)
del new_t
cur_device = layer.qweight.device
# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx
g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
sorted_g_idx = layer.g_idx[g_idx_sort_indices]
replace_tensor("g_idx", sorted_g_idx)
replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
else:
# Reset g_idx related tensors
layer.g_idx = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
layer.g_idx_sort_indices = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
# Repack weights
marlin_qweight = ops.gptq_marlin_repack(
layer.qweight,
layer.g_idx_sort_indices,
part_size_k,
part_size_n,
self.quant_config.weight_bits,
)
replace_tensor("qweight", marlin_qweight)
# Permute scales
scales_size_k = part_size_k
scales_size_n = part_size_n
if self.quant_config.desc_act:
scales_size_k = full_size_k
marlin_scales = marlin_permute_scales(
layer.scales,
scales_size_k,
scales_size_n,
self.quant_config.group_size,
self.quant_config.weight_bits,
)
replace_tensor("scales", marlin_scales)
output = ops.gptq_marlin_gemm(
reshaped_x,
layer.qweight,
layer.scales,
layer.g_idx,
layer.g_idx_sort_indices,
layer.workspace,
self.quant_config.weight_bits,
size_m,
part_size_n,
part_size_k,
layer.is_k_full,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)


@@ -1,140 +0,0 @@
from logging import getLogger
import torch
import torch.nn as nn
logger = getLogger(__name__)
def quantize(x, scale, zero, maxq):
if maxq < 0:
return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
return scale * (q - zero)
class Quantizer(nn.Module):
def __init__(self, shape=1):
super(Quantizer, self).__init__()
self.register_buffer("maxq", torch.tensor(0))
self.register_buffer("scale", torch.zeros(shape))
self.register_buffer("zero", torch.zeros(shape))
def configure(
self,
bits,
perchannel=False,
sym=True,
mse=False,
norm=2.4,
grid=100,
maxshrink=0.8,
trits=False,
):
self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel
self.sym = sym
self.mse = mse
self.norm = norm
self.grid = grid
self.maxshrink = maxshrink
if trits:
self.maxq = torch.tensor(-1)
def find_params(self, x, weight=False):
dev = x.device
self.maxq = self.maxq.to(dev)
shape = x.shape
if self.perchannel:
if weight:
x = x.flatten(1)
else:
if len(shape) == 4:
x = x.permute([1, 0, 2, 3])
x = x.flatten(1)
if len(shape) == 3:
x = x.reshape((-1, shape[-1])).t()
if len(shape) == 2:
x = x.t()
else:
x = x.flatten().unsqueeze(0)
tmp = torch.zeros(x.shape[0], device=dev)
xmin = torch.minimum(x.min(1)[0], tmp)
xmax = torch.maximum(x.max(1)[0], tmp)
if self.sym:
xmax = torch.maximum(torch.abs(xmin), xmax)
tmp = xmin < 0
if torch.any(tmp):
xmin[tmp] = -xmax[tmp]
tmp = (xmin == 0) & (xmax == 0)
xmin[tmp] = -1
xmax[tmp] = +1
if self.maxq < 0:
self.scale = xmax
self.zero = xmin
else:
self.scale = (xmax - xmin) / self.maxq
if self.sym:
self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
else:
self.zero = torch.round(-xmin / self.scale)
if self.mse:
best = torch.full([x.shape[0]], float("inf"), device=dev)
for i in range(int(self.maxshrink * self.grid)):
p = 1 - i / self.grid
xmin1 = p * xmin
xmax1 = p * xmax
scale1 = (xmax1 - xmin1) / self.maxq
zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
q -= x
q.abs_()
q.pow_(self.norm)
err = torch.sum(q, 1)
tmp = err < best
if torch.any(tmp):
best[tmp] = err[tmp]
self.scale[tmp] = scale1[tmp]
self.zero[tmp] = zero1[tmp]
if not self.perchannel:
if weight:
tmp = shape[0]
else:
tmp = shape[1] if len(shape) != 3 else shape[2]
self.scale = self.scale.repeat(tmp)
self.zero = self.zero.repeat(tmp)
if weight:
shape = [-1] + [1] * (len(shape) - 1)
self.scale = self.scale.reshape(shape)
self.zero = self.zero.reshape(shape)
return
if len(shape) == 4:
self.scale = self.scale.reshape((1, -1, 1, 1))
self.zero = self.zero.reshape((1, -1, 1, 1))
if len(shape) == 3:
self.scale = self.scale.reshape((1, 1, -1))
self.zero = self.zero.reshape((1, 1, -1))
if len(shape) == 2:
self.scale = self.scale.unsqueeze(0)
self.zero = self.zero.unsqueeze(0)
def quantize(self, x):
if self.ready():
return quantize(x, self.scale, self.zero, self.maxq)
return x
def enabled(self):
return self.maxq > 0
def ready(self):
return torch.all(self.scale != 0)
__all__ = ["Quantizer"]


@@ -1,99 +0,0 @@
import torch
import enum
from enum import Enum
from typing import Any, Dict, List, Optional
from torch.nn.parameter import Parameter
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
size_m = reshaped_x.shape[0]
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
full_size_k = layer.input_size
out_shape = x.shape[:-1] + (part_size_n, )
if layer.marlin_state == GPTQMarlinState.REPACK:
layer.marlin_state = GPTQMarlinState.READY
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(name, new_t):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr(layer, name).resize_(new_t.shape)
getattr(layer, name).copy_(new_t)
del new_t
cur_device = layer.qweight.device
# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx
g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
sorted_g_idx = layer.g_idx[g_idx_sort_indices]
replace_tensor("g_idx", sorted_g_idx)
replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
else:
# Reset g_idx related tensors
layer.g_idx = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
layer.g_idx_sort_indices = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
# Repack weights
marlin_qweight = ops.gptq_marlin_repack(
layer.qweight,
layer.g_idx_sort_indices,
part_size_k,
part_size_n,
self.quant_config.weight_bits,
)
replace_tensor("qweight", marlin_qweight)
# Permute scales
scales_size_k = part_size_k
scales_size_n = part_size_n
if self.quant_config.desc_act:
scales_size_k = full_size_k
marlin_scales = marlin_permute_scales(
layer.scales,
scales_size_k,
scales_size_n,
self.quant_config.group_size,
self.quant_config.weight_bits,
)
replace_tensor("scales", marlin_scales)
output = ops.gptq_marlin_gemm(
reshaped_x,
layer.qweight,
layer.scales,
layer.g_idx,
layer.g_idx_sort_indices,
layer.workspace,
self.quant_config.weight_bits,
size_m,
part_size_n,
part_size_k,
layer.is_k_full,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)


@@ -220,7 +220,7 @@ def compute_max_diff(output, output_ref):

 class MarlinWorkspace:

-    def __init__(self, out_features, min_thread_n, max_parallel):
+    def __init__(self, out_features, min_thread_n, max_parallel, device):
         assert (out_features % min_thread_n == 0), (
             "out_features = {} is undivisible by min_thread_n = {}".format(
                 out_features, min_thread_n))
@@ -229,4 +229,4 @@ class MarlinWorkspace:
         self.scratch = torch.zeros(max_workspace_size,
                                    dtype=torch.int,
-                                   device="cuda")
+                                   device=device)


@@ -47,13 +47,13 @@ void Linear::forward_many(int qlen, const void* input, void* output, Backend* ba
     int nth = config_.output_size / config_.stride;
     backend->do_work_stealing_job(nth, [&](int task_id) {
         int ith = task_id;
-        void* proj_ptr = proj_ + ith * config_.stride * config_.input_size * ggml_type_size(config_.proj_type) / ggml_blck_size(config_.proj_type);
+        void* proj_ptr = (uint8_t*)proj_ + ith * config_.stride * config_.input_size * ggml_type_size(config_.proj_type) / ggml_blck_size(config_.proj_type);
         float* proj_output_ptr = proj_output_ + ith * config_.stride;
         llamafile_sgemm(config_.stride, qlen, config_.input_size / ggml_blck_size(config_.proj_type), proj_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_input_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_output_ptr, config_.output_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.proj_type, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
         if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {
             for (int i = 0; i < qlen; i++) {
                 float* output_fp32_ptr = proj_output_ + i * config_.output_size + ith * config_.stride;
-                void* output_ptr = output + i * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
+                void* output_ptr = (uint8_t*)output + i * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
                 from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);
             }
         }
@@ -69,5 +69,5 @@ void Linear::forward(int qlen, const void* input, void* output, Backend* backend
     }
     int forward_len = std::min(qlen, config_.group_max_len);
     forward_many(forward_len, input, output, backend);
-    forward(qlen - forward_len, input + forward_len * config_.input_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), output + forward_len * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
+    forward(qlen - forward_len, (uint8_t*)input + forward_len * config_.input_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
 }


@@ -74,10 +74,10 @@ void MLP::forward_many(int qlen, const void* input, void* output, Backend* backe
     int nth = config_.intermediate_size / config_.stride;
     backend->do_work_stealing_job(nth, [&](int task_id) {
         int ith = task_id;
-        void* gate_proj_ptr = gate_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
+        void* gate_proj_ptr = (uint8_t*)gate_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
         float* gate_output_ptr = gate_output_ + ith * config_.stride;
         llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
-        void* up_proj_ptr = up_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
+        void* up_proj_ptr = (uint8_t*)up_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
         float* up_output_ptr = up_output_ + ith * config_.stride;
         llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
         for (int i = 0; i < qlen; i++) {
@@ -86,7 +86,7 @@ void MLP::forward_many(int qlen, const void* input, void* output, Backend* backe
             }
             if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) {
                 float* intermediate_fp32_ptr = intermediate_fp32_ + i * config_.intermediate_size + ith * config_.stride;
-                void* down_input_ptr = down_input_ + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
+                void* down_input_ptr = (uint8_t*)down_input_ + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
                 from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
             }
         }
@@ -97,13 +97,13 @@ void MLP::forward_many(int qlen, const void* input, void* output, Backend* backe
     nth = config_.hidden_size / config_.stride;
     backend->do_work_stealing_job(nth, [&](int task_id) {
         int ith = task_id;
-        void* down_proj_ptr = down_proj_ + ith * config_.stride * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
+        void* down_proj_ptr = (uint8_t*)down_proj_ + ith * config_.stride * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
         float* down_output_ptr = down_output_ + ith * config_.stride;
         llamafile_sgemm(config_.stride, qlen, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
         if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {
             for (int i = 0; i < qlen; i++) {
                 float* output_fp32_ptr = down_output_ + i * config_.hidden_size + ith * config_.stride;
-                void* output_ptr = output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
+                void* output_ptr = (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type); from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);
} }
} }
@ -119,5 +119,5 @@ void MLP::forward(int qlen, const void* input, void* output, Backend* backend) {
} }
int forward_len = std::min(qlen, config_.group_max_len); int forward_len = std::min(qlen, config_.group_max_len);
forward_many(forward_len, input, output, backend); forward_many(forward_len, input, output, backend);
forward(qlen - forward_len, input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend); forward(qlen - forward_len, (uint8_t*)input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
} }

View file

@ -9,7 +9,7 @@
**/ **/
#include "moe.h" #include "moe.h"
#include <iostream> #include <iostream>
#include "unistd.h" #include <cstdint>
MOE::MOE(MOEConfig config) { MOE::MOE(MOEConfig config) {
config_ = config; config_ = config;
@ -60,7 +60,7 @@ MOE::MOE(MOEConfig config) {
m_local_pos_.resize(config_.group_max_len); m_local_pos_.resize(config_.group_max_len);
for (int i = 0; i < config_.group_max_len; i++) { for (int i = 0; i < config_.group_max_len; i++) {
m_local_pos_[i].reserve(config_.expert_num); m_local_pos_[i].resize(config_.routed_expert_num);
} }
m_local_num_.resize(config_.expert_num); m_local_num_.resize(config_.expert_num);
m_local_gate_input_ptr_.resize(config_.expert_num); m_local_gate_input_ptr_.resize(config_.expert_num);
@ -125,10 +125,10 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c
int expert_idx = task_id / nth; int expert_idx = task_id / nth;
uint64_t expert_id = expert_ids[expert_idx]; uint64_t expert_id = expert_ids[expert_idx];
int ith = task_id % nth; int ith = task_id % nth;
void* gate_proj_ptr = gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
float* gate_output_ptr = s_gate_output_[expert_idx] + ith * config_.stride; float* gate_output_ptr = s_gate_output_[expert_idx] + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
void* up_proj_ptr = up_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride; float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) { for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
@ -153,7 +153,7 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c
} }
for (int expert_idx = 0; expert_idx < k; expert_idx++) { for (int expert_idx = 0; expert_idx < k; expert_idx++) {
uint64_t expert_id = expert_ids[expert_idx]; uint64_t expert_id = expert_ids[expert_idx];
void* down_proj_ptr = down_proj_ + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
float* down_output_ptr = s_down_output_[expert_idx] + ith * config_.stride; float* down_output_ptr = s_down_output_[expert_idx] + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), s_down_input_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); llamafile_sgemm(config_.stride, 1, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), s_down_input_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) { for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
@ -162,7 +162,7 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c
} }
if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) { if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {
float* output_fp32_ptr = s_output_fp32_ + ith * config_.stride; float* output_fp32_ptr = s_output_fp32_ + ith * config_.stride;
void* output_ptr = output + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); void* output_ptr = (uint8_t*)output + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type); from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);
} }
}); });
@ -195,9 +195,9 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
const void* gate_input_ptr; const void* gate_input_ptr;
const void* up_input_ptr; const void* up_input_ptr;
if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) { if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
gate_input_ptr = up_input_ptr = input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); gate_input_ptr = up_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
} else { } else {
to_float(input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), m_input_fp32_[i], config_.hidden_size, config_.hidden_type); to_float((uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), m_input_fp32_[i], config_.hidden_size, config_.hidden_type);
if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) { if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type); from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
gate_input_ptr = up_input_ptr = m_gate_input_[i]; gate_input_ptr = up_input_ptr = m_gate_input_[i];
@ -206,13 +206,13 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type); from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
gate_input_ptr = m_gate_input_[i]; gate_input_ptr = m_gate_input_[i];
} else { } else {
gate_input_ptr = input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); gate_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
} }
if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) { if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
from_float(m_input_fp32_[i], m_up_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type); from_float(m_input_fp32_[i], m_up_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
up_input_ptr = m_up_input_[i]; up_input_ptr = m_up_input_[i];
} else { } else {
up_input_ptr = input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); up_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
} }
} }
} }
@ -227,11 +227,11 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
int expert_idx = task_id / nth; int expert_idx = task_id / nth;
int ith = task_id % nth; int ith = task_id % nth;
void* gate_input_ptr = m_local_gate_input_ptr_[expert_idx]; void* gate_input_ptr = m_local_gate_input_ptr_[expert_idx];
void* gate_proj_ptr = gate_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
float* gate_output_ptr = m_local_gate_output_ptr_[expert_idx] + ith * stride; float* gate_output_ptr = m_local_gate_output_ptr_[expert_idx] + ith * stride;
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
void* up_input_ptr = m_local_up_input_ptr_[expert_idx]; void* up_input_ptr = m_local_up_input_ptr_[expert_idx];
void* up_proj_ptr = up_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride; float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride;
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = 0; i < m_local_num_[expert_idx]; i++) { for (int i = 0; i < m_local_num_[expert_idx]; i++) {
@ -249,7 +249,7 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
int expert_idx = task_id / nth; int expert_idx = task_id / nth;
int ith = task_id % nth; int ith = task_id % nth;
void* down_input_ptr = m_local_down_input_ptr_[expert_idx]; void* down_input_ptr = m_local_down_input_ptr_[expert_idx];
void* down_proj_ptr = down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + ith * stride; float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + ith * stride;
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
}); });
@ -262,18 +262,18 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
m_output_fp32_[i][e] += m_local_down_output_ptr_[expert_ids[i * k + j]][m_local_pos_[i][j] * config_.hidden_size + e] * weights[i * k + j]; m_output_fp32_[i][e] += m_local_down_output_ptr_[expert_ids[i * k + j]][m_local_pos_[i][j] * config_.hidden_size + e] * weights[i * k + j];
} }
} }
from_float(m_output_fp32_[i], output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), config_.hidden_size, config_.hidden_type); from_float(m_output_fp32_[i], (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), config_.hidden_size, config_.hidden_type);
}); });
} }
void MOE::forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) { void MOE::forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) {
if (qlen < config_.group_min_len) { if (qlen < config_.group_min_len) {
for (int i = 0; i < qlen; i++) { for (int i = 0; i < qlen; i++) {
forward_one(k, expert_ids + i * k, weights + i * k, input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend); forward_one(k, expert_ids + i * k, weights + i * k, (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
} }
return; return;
} }
int forward_len = std::min(config_.group_max_len, qlen); int forward_len = std::min(config_.group_max_len, qlen);
forward_many(forward_len, k, expert_ids, weights, input, output, backend); forward_many(forward_len, k, expert_ids, weights, input, output, backend);
forward(qlen - forward_len, k, expert_ids + forward_len * k, weights + forward_len * k, input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend); forward(qlen - forward_len, k, expert_ids + forward_len * k, weights + forward_len * k, (uint8_t*)input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
} }
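For readers new to the grouped path: forward_many first buckets every (token, routed-expert) pair by expert so that each expert runs one batched GEMM over its tokens, and m_local_pos_ records where each token lands inside its expert's batch. A rough NumPy sketch of that bookkeeping (array names and shapes are illustrative):

```python
import numpy as np

def bucket_tokens_by_expert(expert_ids: np.ndarray, num_experts: int):
    """expert_ids: [qlen, k] routed expert indices per token."""
    local_num = np.zeros(num_experts, dtype=np.int64)  # tokens assigned to each expert
    local_pos = np.zeros_like(expert_ids)              # row of (token, j) inside its expert's batch
    for t in range(expert_ids.shape[0]):
        for j in range(expert_ids.shape[1]):
            e = int(expert_ids[t, j])
            local_pos[t, j] = local_num[e]
            local_num[e] += 1
    return local_num, local_pos
```

This also shows why each m_local_pos_[i] needs k (= routed_expert_num) writable slots, which the reserve-to-resize change above provides.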

View file

@ -49,7 +49,7 @@ void SharedMemBuffer::dealloc(void* object) {
void SharedMemBuffer::arrange(std::vector<std::pair<void**, uint64_t>> requests) { void SharedMemBuffer::arrange(std::vector<std::pair<void**, uint64_t>> requests) {
uint64_t offset = 0; uint64_t offset = 0;
for (auto& request : requests) { for (auto& request : requests) {
*(request.first) = buffer_ + offset; *(request.first) = (uint8_t*)buffer_ + offset;
offset += request.second; offset += request.second;
} }
} }
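The cast in arrange() is the same byte-pointer trick: every requester receives a slice of one shared allocation at a running byte offset. A tiny illustrative equivalent in Python (names assumed):

```python
def arrange(buffer: bytearray, requests):
    # requests: iterable of (name, size_bytes); each name receives a view into the
    # shared buffer starting at the running offset, mirroring (uint8_t*)buffer_ + offset.
    views, offset = {}, 0
    for name, size in requests:
        views[name] = memoryview(buffer)[offset:offset + size]
        offset += size
    return views
```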

11 ktransformers/local_chat.py Normal file → Executable file
View file

@ -31,18 +31,21 @@ import fire
from ktransformers.optimize.optimize import optimize_and_load_gguf from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate from ktransformers.util.utils import prefill_and_generate
from ktransformers.server.config.config import Config from ktransformers.server.config.config import Config
custom_models = { custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM, "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
"Qwen2MoeForCausalLM": Qwen2MoeForCausalLM, "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
"MixtralForCausalLM": MixtralForCausalLM,
} }
ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
default_optimize_rules ={ default_optimize_rules ={
"DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml", "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
"MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml",
} }
def local_chat( def local_chat(
@ -50,7 +53,8 @@ def local_chat(
optimize_rule_path: str = None, optimize_rule_path: str = None,
gguf_path: str = None, gguf_path: str = None,
max_new_tokens: int = 1000, max_new_tokens: int = 1000,
cpu_infer: int = Config().cpu_infer cpu_infer: int = Config().cpu_infer,
use_cuda_graph: bool = True,
): ):
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
@ -64,6 +68,8 @@ def local_chat(
print("using custom modeling_xxx.py.") print("using custom modeling_xxx.py.")
if "Qwen2Moe" in config.architectures[0]: # Qwen2Moe must use flash_attention_2 to avoid overflow. if "Qwen2Moe" in config.architectures[0]: # Qwen2Moe must use flash_attention_2 to avoid overflow.
config._attn_implementation = "flash_attention_2" config._attn_implementation = "flash_attention_2"
if "Mixtral" in config.architectures[0]:
config._attn_implementation = "flash_attention_2"
model = custom_models[config.architectures[0]](config) model = custom_models[config.architectures[0]](config)
else: else:
model = AutoModelForCausalLM.from_config( model = AutoModelForCausalLM.from_config(
@ -100,7 +106,6 @@ def local_chat(
while True: while True:
content = input("Chat: ") content = input("Chat: ")
# if content is num
if content == "": if content == "":
content = "Please write a piece of quicksort code in C++." content = "Please write a piece of quicksort code in C++."
@ -109,7 +114,7 @@ def local_chat(
messages, add_generation_prompt=True, return_tensors="pt" messages, add_generation_prompt=True, return_tensors="pt"
) )
torch.set_default_dtype(torch.bfloat16) # TODO: Remove this, replace dtype using config torch.set_default_dtype(torch.bfloat16) # TODO: Remove this, replace dtype using config
generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens) generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph)
if __name__ == "__main__": if __name__ == "__main__":
fire.Fire(local_chat) fire.Fire(local_chat)
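With the new use_cuda_graph flag threaded through to prefill_and_generate and Mixtral registered in custom_models / default_optimize_rules, an invocation might look like the sketch below. Paths are placeholders, and the leading model_path parameter is an assumption from context since the full signature sits outside this hunk:

```python
from ktransformers.local_chat import local_chat

# Hypothetical invocation; adjust paths to your checkpoint and GGUF files.
local_chat(
    model_path="/models/Mixtral-8x7B-Instruct-v0.1",        # HF config + tokenizer dir (assumed)
    gguf_path="/models/mixtral-8x7b-instruct-q5_k_m.gguf",
    max_new_tokens=512,
    cpu_infer=32,
    use_cuda_graph=True,                                     # new flag added in this change
)
```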

View file

@ -22,13 +22,14 @@ class StaticCache(transformers.StaticCache):
The maximum batch size with which the model will be used. The maximum batch size with which the model will be used.
max_cache_len (`int`): max_cache_len (`int`):
The maximum sequence length with which the model will be used. The maximum sequence length with which the model will be used.
device (`torch.device`): device (`torch.device` or `dict`):
The device on which the cache should be initialized. Should be the same as the layer. The device on which the cache should be initialized. Should be the same as the layer.
If a `dict`, it should contain the `device` key with the device name as the value.
dtype (*optional*, defaults to `torch.float32`): dtype (*optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer. The default `dtype` to use when initializing the layer.
""" """
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device: torch.device| dict, dtype=None) -> None:
Cache.__init__(self) Cache.__init__(self)
self.max_batch_size = max_batch_size self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
@ -46,6 +47,7 @@ class StaticCache(transformers.StaticCache):
self.value_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = []
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
if config.architectures[0] == "DeepseekV2ForCausalLM": if config.architectures[0] == "DeepseekV2ForCausalLM":
# TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
# key_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.qk_rope_head_dim + config.qk_nope_head_dim) # key_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.qk_rope_head_dim + config.qk_nope_head_dim)
# value_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.v_head_dim) # value_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.v_head_dim)
key_shape = (max_batch_size, 1, self.max_cache_len, config.qk_rope_head_dim) key_shape = (max_batch_size, 1, self.max_cache_len, config.qk_rope_head_dim)
@ -56,11 +58,15 @@ class StaticCache(transformers.StaticCache):
self.past_tokens = [] self.past_tokens = []
self.num_hidden_layers = config.num_hidden_layers self.num_hidden_layers = config.num_hidden_layers
for _ in range(self.num_hidden_layers): for idx in range(self.num_hidden_layers):
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
# breaks when updating the cache. # breaks when updating the cache.
new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=device) if isinstance(device, dict):
new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=device) target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
else:
target_device = device
new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)
new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)
torch._dynamo.mark_static_address(new_layer_key_cache) torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache) torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache) self.key_cache.append(new_layer_key_cache)
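The device argument may now be a dict mapping each layer's attention block to its placement; the lookup above reads device[f"blk.{idx}.self_attn"]["generate_device"]. A minimal sketch of building such a map for a two-GPU split (layer count and split point are made-up values):

```python
num_hidden_layers, split = 60, 30
device_map = {
    f"blk.{i}.self_attn": {"generate_device": "cuda:0" if i < split else "cuda:1"}
    for i in range(num_hidden_layers)
}
# cache = StaticCache(config, max_batch_size=1, max_cache_len=4096,
#                     device=device_map, dtype=torch.float16)
```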

View file

@ -1048,7 +1048,7 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention):
""" """
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores. first unpad the input, then computes the attention scores and pad the final attention scores.
Args: # Args:
query_states (`torch.Tensor`): query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`): key_states (`torch.Tensor`):
@ -1245,12 +1245,14 @@ class DeepseekV2DecoderLayer(nn.Module):
cache_position=cache_position, cache_position=cache_position,
**kwargs, **kwargs,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
# Fully Connected # Fully Connected
residual = hidden_states residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
outputs = (hidden_states,) outputs = (hidden_states,)

File diff suppressed because it is too large

View file

@ -10,6 +10,7 @@ from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import InferenceState from ktransformers.util.utils import InferenceState
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding): class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
def __init__(self, def __init__(self,
@ -17,12 +18,16 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
gguf_loader : GGUFLoader, gguf_loader : GGUFLoader,
config: PretrainedConfig, config: PretrainedConfig,
orig_module: nn.Module, orig_module: nn.Module,
device: str = "cuda", # device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs): **kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
self.orig_module.__init__(orig_module.dim, self.orig_module.__init__(orig_module.dim,
orig_module.max_position_embeddings, orig_module.max_position_embeddings,
orig_module.base) orig_module.base)
self.generate_device = generate_device
self.prefill_device = prefill_device
def load(self): def load(self):
self.orig_module.__init__(self.orig_module.dim, self.orig_module.__init__(self.orig_module.dim,
@ -36,9 +41,11 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
gguf_loader : GGUFLoader, gguf_loader : GGUFLoader,
config: PretrainedConfig, config: PretrainedConfig,
orig_module: nn.Module, orig_module: nn.Module,
device: str = "cuda", # device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs): **kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
self.orig_module.__init__(orig_module.dim, self.orig_module.__init__(orig_module.dim,
orig_module.max_position_embeddings, orig_module.max_position_embeddings,
orig_module.base, orig_module.base,
@ -49,13 +56,15 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
orig_module.beta_slow, orig_module.beta_slow,
orig_module.mscale, orig_module.mscale,
orig_module.mscale_all_dim) orig_module.mscale_all_dim)
self.generate_device = generate_device
self.prefill_device = prefill_device
def load(self): def load(self):
self.orig_module.__init__(self.orig_module.dim, self.orig_module.__init__(self.orig_module.dim,
self.orig_module.max_position_embeddings, self.orig_module.max_position_embeddings,
self.orig_module.base, self.orig_module.base,
self.device, self.generate_device,
self.orig_module.scaling_factor, self.orig_module.scaling_factor,
self.orig_module.original_max_position_embeddings, self.orig_module.original_max_position_embeddings,
self.orig_module.beta_fast, self.orig_module.beta_fast,
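The rotary-embedding wrappers now take separate generate_device / prefill_device keyword arguments instead of a single device. In practice these come from the YAML optimize rules; a hedged sketch of direct construction, with the wrapper class passed in rather than imported (module path not asserted):

```python
def build_rope(rope_cls, key, gguf_loader, config, orig_rope):
    # rope_cls is e.g. the YarnRotaryEmbedding wrapper above; argument names follow the diff.
    return rope_cls(
        key=key,                        # illustrative module key, e.g. "blk.0.self_attn.rotary_emb"
        gguf_loader=gguf_loader,
        config=config,
        orig_module=orig_rope,
        generate_device="cuda:0",       # decode-phase placement
        prefill_device="cuda:1",        # prefill-phase placement
    )
```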

View file

@ -5,8 +5,8 @@ Description :
Author : Azure-Tang, Boxin Zhang, chenht2022 Author : Azure-Tang, Boxin Zhang, chenht2022
Date : 2024-07-25 11:25:24 Date : 2024-07-25 11:25:24
Version : 0.1.0 Version : 0.1.0
LastEditors : Azure LastEditors : kkk1nak0
LastEditTime : 2024-07-26 09:27:41 LastEditTime : 2024-08-11 12:14:39
Copyright (c) 2024 by KVCache.AI, All Rights Reserved. Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
''' '''
@ -19,7 +19,9 @@ import torch
import sys, os import sys, os
from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.operators.base_operator import BaseInjectedModule
sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/build") sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext import cpuinfer_ext
from cpuinfer_ext.moe import MOEConfig, MOE from cpuinfer_ext.moe import MOEConfig, MOE
import ctypes import ctypes
@ -78,6 +80,25 @@ class MLPExpertsBase(ABC):
gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"] gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"]
up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"] up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"] down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
elif key + ".ffn_down.0.weight" in self.gguf_loader.tensor_info:
# for supporting Mixtral-8x7B-Instruct
gate = []
up = []
down = []
for i in range(8):
gatei, upi, downi = f".ffn_gate.{i}.weight", f".ffn_up.{i}.weight", f".ffn_down.{i}.weight"
targets = [gatei, upi, downi]
tensors = self.load_multi(key, targets, device=device)
gate_it, up_it, down_it = tensors[gatei], tensors[upi], tensors[downi]
gate.append(gate_it)
up.append(up_it)
down.append(down_it)
gate = torch.stack(gate)
up = torch.stack(up)
down = torch.stack(down)
gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate.0.weight"]["ggml_type"]
up_type = self.gguf_loader.tensor_info[key + ".ffn_up.0.weight"]["ggml_type"]
down_type = self.gguf_loader.tensor_info[key + ".ffn_down.0.weight"]["ggml_type"]
else: else:
raise ValueError(f"Experts {key} not found in gguf_loader") raise ValueError(f"Experts {key} not found in gguf_loader")
res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}} res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
@ -94,7 +115,8 @@ class MLPCPUExperts(MLPExpertsBase):
expert_ids_cpu:Tensor = None expert_ids_cpu:Tensor = None
weights_cpu:Tensor = None weights_cpu:Tensor = None
output_cpu:Tensor = None output_cpu:Tensor = None
output_gpu:Tensor = None output_gpu_map:dict = {} # Manage output tensor buffers on different GPUs
#stream_map:dict = {} # Manage cuda stream on different gpu
CPU_INFER = cpuinfer_ext.CPUInfer(Config().cpu_infer) CPU_INFER = cpuinfer_ext.CPUInfer(Config().cpu_infer)
def __init__( def __init__(
self, self,
@ -113,81 +135,83 @@ class MLPCPUExperts(MLPExpertsBase):
self.out_device = out_device self.out_device = out_device
def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = None, warmup:bool = False): def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = None, warmup:bool = False):
if device: with torch.device(self.out_device):
assert device.lower() == "cpu", "MLPCPUExperts can only be loaded on CPU, Parameter \"device\" can be cpu or None." if device:
if w is None: w = self.load_weights()[self.key] assert device.lower() == "cpu", "MLPCPUExperts can only be loaded on CPU, Parameter \"device\" can be cpu or None."
self.gate = w["gate"] if w is None: w = self.load_weights()[self.key]
self.up = w["up"] self.gate = w["gate"]
self.down = w["down"] self.up = w["up"]
self.gate_type = w["gate_type"] self.down = w["down"]
self.up_type = w["up_type"] self.gate_type = w["gate_type"]
self.down_type = w["down_type"] self.up_type = w["up_type"]
gate_ptr = ctypes.addressof( self.down_type = w["down_type"]
ctypes.cast(self.gate.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents gate_ptr = ctypes.addressof(
) ctypes.cast(self.gate.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
up_ptr = ctypes.addressof( )
ctypes.cast(self.up.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents up_ptr = ctypes.addressof(
) ctypes.cast(self.up.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
down_ptr = ctypes.addressof( )
ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents down_ptr = ctypes.addressof(
) ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
# print(self.gate_qtype, self.up_qtype, self.down_qtype) )
n_routed_experts = self.n_routed_experts # print(self.gate_qtype, self.up_qtype, self.down_qtype)
# n_routed_experts = len(self.orig_module) n_routed_experts = self.n_routed_experts
moe_config = MOEConfig( # n_routed_experts = len(self.orig_module)
n_routed_experts, moe_config = MOEConfig(
self.config.num_experts_per_tok, n_routed_experts,
self.config.hidden_size, self.config.num_experts_per_tok,
self.config.moe_intermediate_size, self.config.hidden_size,
64, self.config.moe_intermediate_size,
10, 64,
1024, 10,
gate_ptr, 1024,
up_ptr, gate_ptr,
down_ptr, up_ptr,
self.gate_type, down_ptr,
self.up_type, self.gate_type,
self.down_type, self.up_type,
30, # TODO: get from model.dtype self.down_type,
) 30, # TODO: get from model.dtype
# print(n_routed_experts, hidden_size, moe_intermediate_size) )
num_experts_per_tok = self.config.num_experts_per_tok # print(n_routed_experts, hidden_size, moe_intermediate_size)
self.moe = MOE(moe_config) num_experts_per_tok = self.config.num_experts_per_tok
self.cpu_infer = MLPCPUExperts.CPU_INFER self.moe = MOE(moe_config)
if warmup: self.cpu_infer = MLPCPUExperts.CPU_INFER
self.cpu_infer.submit(self.moe.warm_up()) if warmup:
self.cpu_infer.sync() self.cpu_infer.submit(self.moe.warm_up())
if MLPCPUExperts.output_gpu == None: self.cpu_infer.sync()
MLPCPUExperts.input_tensor_cpu = torch.empty((self.config.hidden_size), device="cpu", pin_memory=True) if self.out_device not in MLPCPUExperts.output_gpu_map:
MLPCPUExperts.expert_ids_cpu = torch.empty((num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) MLPCPUExperts.output_gpu_map[self.out_device] = torch.zeros((self.config.hidden_size), device=self.out_device)
MLPCPUExperts.weights_cpu = torch.empty((num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) if MLPCPUExperts.input_tensor_cpu == None:
MLPCPUExperts.output_cpu = torch.empty((self.config.hidden_size), device="cpu", pin_memory=True) MLPCPUExperts.input_tensor_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True)
MLPCPUExperts.output_gpu = torch.empty((self.config.hidden_size), device=self.out_device) MLPCPUExperts.expert_ids_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
MLPCPUExperts.weights_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
MLPCPUExperts.output_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
def submit_for_one_decode(self, input_tensor, expert_ids, weights): def submit_for_one_decode(self, input_tensor, expert_ids, weights):
MLPCPUExperts.input_tensor_cpu.copy_(input_tensor, non_blocking=True) MLPCPUExperts.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
MLPCPUExperts.expert_ids_cpu.copy_(expert_ids, non_blocking=True) MLPCPUExperts.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
MLPCPUExperts.weights_cpu.copy_(weights, non_blocking=True) MLPCPUExperts.weights_cpu.copy_(weights, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(1, expert_ids.size(0), MLPCPUExperts.expert_ids_cpu.data_ptr(), MLPCPUExperts.weights_cpu.data_ptr(), MLPCPUExperts.input_tensor_cpu.data_ptr(), MLPCPUExperts.output_cpu.data_ptr())) self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(0), MLPCPUExperts.expert_ids_cpu.data_ptr(), MLPCPUExperts.weights_cpu.data_ptr(), MLPCPUExperts.input_tensor_cpu.data_ptr(), MLPCPUExperts.output_cpu.data_ptr()))
def sync_for_one_decode(self): def sync_for_one_decode(self):
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream) self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
MLPCPUExperts.output_gpu.copy_(MLPCPUExperts.output_cpu, non_blocking=True) MLPCPUExperts.output_gpu_map[self.out_device].copy_(MLPCPUExperts.output_cpu, non_blocking=True)
#print("capturing experts finish") return MLPCPUExperts.output_gpu_map[self.out_device]
return MLPCPUExperts.output_gpu
def forward(self, input_tensor, expert_ids, weights): def forward(self, input_tensor, expert_ids, weights):
# generate, capture and run cuda graph # generate, capture and run cuda graph
# print(expert_ids)
if input_tensor.size(0)==1: if input_tensor.size(0)==1:
# TODO: this branch is unreachable, but the shapes of input_tensor ([1, hidden_size]) and input_tensor_cpu ([hidden_size]) are not compatible
#print("capturing experts") #print("capturing experts")
MLPCPUExperts.input_tensor_cpu.copy_(input_tensor, non_blocking=True) MLPCPUExperts.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
MLPCPUExperts.expert_ids_cpu.copy_(expert_ids, non_blocking=True) MLPCPUExperts.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
MLPCPUExperts.weights_cpu.copy_(weights, non_blocking=True) MLPCPUExperts.weights_cpu.copy_(weights, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(1, expert_ids.size(1), MLPCPUExperts.expert_ids_cpu.data_ptr(), MLPCPUExperts.weights_cpu.data_ptr(), MLPCPUExperts.input_tensor_cpu.data_ptr(), MLPCPUExperts.output_cpu.data_ptr())) self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(1, expert_ids.size(1), MLPCPUExperts.expert_ids_cpu.data_ptr(), MLPCPUExperts.weights_cpu.data_ptr(), MLPCPUExperts.input_tensor_cpu.data_ptr(), MLPCPUExperts.output_cpu.data_ptr()))
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream) self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
MLPCPUExperts.output_gpu.copy_(MLPCPUExperts.output_cpu, non_blocking=True) MLPCPUExperts.output_gpu_map[self.out_device].copy_(MLPCPUExperts.output_cpu, non_blocking=True)
#print("capturing experts finish") return MLPCPUExperts.output_gpu_map[self.out_device]
return MLPCPUExperts.output_gpu
else: else:
input_tensor = input_tensor.contiguous().cpu() input_tensor = input_tensor.contiguous().cpu()
expert_ids = expert_ids.contiguous().cpu() expert_ids = expert_ids.contiguous().cpu()
@ -195,7 +219,7 @@ class MLPCPUExperts(MLPExpertsBase):
output = torch.empty_like(input_tensor).contiguous() output = torch.empty_like(input_tensor).contiguous()
self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr())) self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr()))
self.cpu_infer.sync() self.cpu_infer.sync()
return output.to(device=object.__getattribute__(self, "device")) return output.to(device=object.__getattribute__(self, "out_device"))
def unload(self): def unload(self):
return return
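Putting the pieces above together, the single-token decode path stages tensors into pinned CPU buffers, enqueues the CPU MoE kernel on the caller's CUDA stream, then copies the result into this GPU's entry in output_gpu_map. A sketch that mirrors submit_for_one_decode / sync_for_one_decode, not a drop-in replacement:

```python
import torch

def decode_one_token(experts, hidden, expert_ids, weights):
    """Sketch of the decode path; attribute names follow the diff above."""
    cls = type(experts)
    # Stage inputs in pinned host buffers so the copies can be asynchronous.
    cls.input_tensor_cpu.copy_(hidden, non_blocking=True)
    cls.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
    cls.weights_cpu.copy_(weights, non_blocking=True)
    stream = torch.cuda.current_stream(experts.out_device).cuda_stream
    experts.cpu_infer.submit_with_cuda_stream(
        stream,
        experts.moe.forward(1, expert_ids.size(0),
                            cls.expert_ids_cpu.data_ptr(), cls.weights_cpu.data_ptr(),
                            cls.input_tensor_cpu.data_ptr(), cls.output_cpu.data_ptr()))
    experts.cpu_infer.sync_with_cuda_stream(stream)
    # Copy the CPU result into the per-device GPU buffer introduced by output_gpu_map.
    cls.output_gpu_map[experts.out_device].copy_(cls.output_cpu, non_blocking=True)
    return cls.output_gpu_map[experts.out_device]
```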
@ -222,6 +246,24 @@ class MLPCPUExperts(MLPExpertsBase):
gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"] gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"]
up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"] up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"] down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
elif key + ".ffn_down.0.weight" in self.gguf_loader.tensor_info:
# for supporting Mixtral-8x7B-Instruct
gate = []
up = []
down = []
for i in range(8):
gate_it = self.gguf_loader.get_mmap_tensor(f"{key}.ffn_gate.{i}.weight")
up_it = self.gguf_loader.get_mmap_tensor(f"{key}.ffn_up.{i}.weight")
down_it = self.gguf_loader.get_mmap_tensor(f"{key}.ffn_down.{i}.weight")
gate.append(gate_it)
up.append(up_it)
down.append(down_it)
gate = np.stack(gate)
up = np.stack(up)
down = np.stack(down)
gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate.0.weight"]["ggml_type"]
up_type = self.gguf_loader.tensor_info[key + ".ffn_up.0.weight"]["ggml_type"]
down_type = self.gguf_loader.tensor_info[key + ".ffn_down.0.weight"]["ggml_type"]
else: else:
raise ValueError(f"Experts {key} not found in gguf_loader") raise ValueError(f"Experts {key} not found in gguf_loader")
res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}} res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
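Mixtral GGUF files store the routed experts as eight separate tensors per layer (ffn_gate.{i} / ffn_up.{i} / ffn_down.{i}) rather than the fused ffn_*_exps tensors the loader already handled; the new branch stacks them back into one array per projection. The key layout it expects, sketched below (the blk.0 prefix is illustrative):

```python
key = "blk.0"   # one transformer layer's tensor prefix in the GGUF file
expert_keys = [
    (f"{key}.ffn_gate.{i}.weight", f"{key}.ffn_up.{i}.weight", f"{key}.ffn_down.{i}.weight")
    for i in range(8)   # Mixtral-8x7B has 8 routed experts
]
# Each per-expert tensor is read with gguf_loader.get_mmap_tensor(...) and the eight
# arrays are stacked (np.stack / torch.stack) so downstream code sees the same
# [num_experts, ...] layout as the fused ffn_*_exps.weight tensors.
```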
@ -360,6 +402,11 @@ class MLPExpertsTorch(MLPExpertsBase):
def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
org_device = hidden_states_cpu.device
hidden_states_cpu = hidden_states_cpu.to(self.device)
selected_experts_cpu = selected_experts_cpu.to(self.device)
routing_weights_cpu = routing_weights_cpu.to(self.device)
batch_sequence_length, hidden_dim = hidden_states_cpu.size() batch_sequence_length, hidden_dim = hidden_states_cpu.size()
final_hidden_states = torch.zeros( final_hidden_states = torch.zeros(
@ -388,27 +435,29 @@ class MLPExpertsTorch(MLPExpertsBase):
# the `top_x` tensor here. # the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states) final_hidden_states.index_add_(0, top_x, current_hidden_states)
return final_hidden_states.to(org_dtype)
return final_hidden_states.to(org_dtype, device=org_device)
EXPERTS_MAP = { EXPERTS_MAP = {
"MLPCPUExperts": MLPCPUExperts, "MLPCPUExperts": MLPCPUExperts,
"MLPExpertsTorch": MLPExpertsTorch, "MLPExpertsTorch": MLPExpertsTorch,
"MLPExpertsMarlin": MLPExpertsMarlin, "MLPExpertsMarlin": MLPExpertsMarlin,
} }
class KTransformersMLPExpert(BaseInjectedModule, MLPExpertsBase): class KTransformersMLPExpert(BaseInjectedModule, MLPExpertsBase):
def __init__(self, def __init__(self,
key: str, key: str,
gguf_loader: GGUFLoader, gguf_loader: GGUFLoader,
config: PretrainedConfig, config: PretrainedConfig,
orig_module: nn.Module, orig_module: nn.Module,
device: str = "cuda", # device: str = "cuda",
prefill_device:str = "cuda", prefill_device:str = "cuda",
prefill_mlp_type: str | None = "MLPExpertsTorch", prefill_mlp_type: str | None = "MLPExpertsTorch",
generate_device: str = "cpu", generate_device: str = "cpu",
generate_mlp_type: str | None = "MLPCPUExperts", generate_mlp_type: str | None = "MLPCPUExperts",
**kwargs): **kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
MLPExpertsBase.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) MLPExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
if generate_mlp_type is not None: if generate_mlp_type is not None:
self.generate_experts = EXPERTS_MAP[generate_mlp_type](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs) self.generate_experts = EXPERTS_MAP[generate_mlp_type](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
else: else:
@ -471,6 +520,7 @@ class KTransformersMLPExpert(BaseInjectedModule, MLPExpertsBase):
from ktransformers.models.modeling_deepseek import DeepseekV2MoE from ktransformers.models.modeling_deepseek import DeepseekV2MoE
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
class Qwen2MoeSparseMoeBlockInjected(BaseInjectedModule, Qwen2MoeSparseMoeBlock): class Qwen2MoeSparseMoeBlockInjected(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
@ -578,7 +628,6 @@ class Qwen2MoeSparseMoeBlockInjected(BaseInjectedModule, Qwen2MoeSparseMoeBlock)
return final_hidden_states return final_hidden_states
class DeepseekV2MoEInjected(BaseInjectedModule, DeepseekV2MoE): class DeepseekV2MoEInjected(BaseInjectedModule, DeepseekV2MoE):
def forward(self, hidden_states): def forward(self, hidden_states):
identity = hidden_states identity = hidden_states
@ -587,7 +636,7 @@ class DeepseekV2MoEInjected(BaseInjectedModule, DeepseekV2MoE):
topk_idx, topk_weight, aux_loss = self.gate(hidden_states) topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
if sequence_length == 1: if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode"):
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0]) self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
if self.config.n_shared_experts is not None: if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity).squeeze(0) y_ = self.shared_experts(identity).squeeze(0)
@ -677,3 +726,102 @@ class DeepseekV2MoEInjected(BaseInjectedModule, DeepseekV2MoE):
.type(new_x.dtype) .type(new_x.dtype)
) )
return final_out return final_out
class MisrtalSparseMoEBlockInjected(BaseInjectedModule, MixtralSparseMoeBlock):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
orig_shape = hidden_states.shape
batch_size, sequence_length, hidden_dim = hidden_states.shape
if self.training and self.jitter_noise > 0:
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode"):
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], selected_experts[0], routing_weights[0])
y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
y.resize_(*orig_shape)
return y, router_logits
hidden_states_expert = hidden_states.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else hidden_states.cpu()
selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else selected_experts.cpu()
routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else routing_weights.cpu()
if isinstance(self.experts, MLPExpertsBase):
y = (
self.moe_on_cpuinfer(
hidden_states_expert, selected_experts_expert, routing_weights_expert
)
.view(*orig_shape)
.to(device=hidden_states.device)
)
elif hidden_states_expert.size(0) > 10:
y = self.moe_infer(
hidden_states_expert, selected_experts_expert, routing_weights_expert, orig_shape
).to(device=hidden_states.device)
else:
y = self.moe_infer_simple(
hidden_states_expert, selected_experts_expert, routing_weights_expert
).to(device=hidden_states.device)
y.resize_(*orig_shape)
return y, router_logits
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
outs = self.experts(x, topk_ids, topk_weight)
return outs
@torch.no_grad()
# TODO: there may be bugs here
def moe_infer_simple(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
'''
hidden_states_cpu: [num_tokens, hidden_size]
topk_ids, topk_weight: [num_tokens, num_selected_experts]
'''
outs = torch.zeros_like(hidden_states_cpu)
for token_idx in range(selected_experts_cpu.size(0)):
for expert_idx in range(selected_experts_cpu.size(1)):
expert = self.experts[selected_experts_cpu[token_idx, expert_idx]]
outs[token_idx] += expert.forward(hidden_states_cpu[token_idx]) * routing_weights_cpu[token_idx, expert_idx]
return outs
@torch.no_grad()
# TODO: may have bugs here
def moe_infer(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor, orig_shape: tuple) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = orig_shape
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device
)
# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be solicited
expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.num_experts).permute(2, 1, 0)
# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer.forward(current_state) * routing_weights_cpu[top_x, idx, None]
# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states_cpu.dtype))
return final_hidden_states
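A minimal, standalone sketch of the expert-mask routing pattern that moe_infer implements above. The toy experts, random router, and tensor sizes below are illustrative assumptions, not part of this commit.

import torch
import torch.nn.functional as F

num_experts, top_k, hidden_dim, num_tokens = 4, 2, 8, 5
tokens = torch.randn(num_tokens, hidden_dim)
experts = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)])

router_logits = torch.randn(num_tokens, num_experts)
routing_weights = F.softmax(router_logits, dim=1)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

out = torch.zeros_like(tokens)
# [num_experts, top_k, num_tokens]: mask[e, k, t] == 1 iff token t picked expert e in slot k
expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
for e in range(num_experts):
    idx, top_x = torch.where(expert_mask[e])   # routing slots and token indices assigned to expert e
    if top_x.numel() == 0:
        continue
    current_state = tokens[top_x]              # gather only the tokens routed to this expert
    out.index_add_(0, top_x, experts[e](current_state) * routing_weights[top_x, idx, None])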

View file

@ -6,7 +6,7 @@ Author : Azure-Tang
Date : 2024-07-25 11:25:24 Date : 2024-07-25 11:25:24
Version : 1.0.0 Version : 1.0.0
LastEditors : Azure LastEditors : Azure
LastEditTime : 2024-07-26 09:27:48 LastEditTime : 2024-08-08 10:09:14
Copyright (c) 2024 by KVCache.AI, All Rights Reserved. Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
''' '''
@ -45,6 +45,8 @@ from ktransformers.models.modeling_deepseek import BaseModelOutputWithPast, Deep
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.utils import InferenceState from ktransformers.util.utils import InferenceState
from ktransformers.util.custom_gguf import GGUFLoader
from transformers.configuration_utils import PretrainedConfig
if is_flash_attn_2_available(): if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn import flash_attn_func, flash_attn_varlen_func
@ -73,34 +75,6 @@ QWEN2MOE_START_DOCSTRING = r"""
[`~PreTrainedModel.from_pretrained`] method to load the model weights. [`~PreTrainedModel.from_pretrained`] method to load the model weights.
""" """
@add_start_docstrings(
"The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.",
QWEN2MOE_START_DOCSTRING,
)
class Qwen2MoePreTrainedModel(PreTrainedModel):
config_class = Qwen2MoeConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Qwen2MoeDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
QWEN2MOE_INPUTS_DOCSTRING = r""" QWEN2MOE_INPUTS_DOCSTRING = r"""
Args: Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@ -177,13 +151,11 @@ QWEN2MOE_INPUTS_DOCSTRING = r"""
the complete sequence length. the complete sequence length.
""" """
from ktransformers.util.custom_gguf import GGUFLoader
from transformers.configuration_utils import PretrainedConfig
@add_start_docstrings( @add_start_docstrings(
"The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.", "The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.",
QWEN2MOE_START_DOCSTRING, QWEN2MOE_START_DOCSTRING,
) )
class Qwen2MoeModelPerLayerPrefill(BaseInjectedModule): class Qwen2MoeModelKTransformers(BaseInjectedModule):
""" """
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2MoeDecoderLayer`] Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2MoeDecoderLayer`]
@ -198,10 +170,13 @@ class Qwen2MoeModelPerLayerPrefill(BaseInjectedModule):
orig_module: nn.Module, orig_module: nn.Module,
device: str = "cuda", device: str = "cuda",
per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
transfer_map: dict = None,
**kwargs, **kwargs,
): ):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
self.transfer_map = transfer_map
self.stream_device_map = dict()
@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
def forward( def forward(
@ -287,7 +262,20 @@ class Qwen2MoeModelPerLayerPrefill(BaseInjectedModule):
all_router_logits = () if output_router_logits else None all_router_logits = () if output_router_logits else None
next_decoder_cache = None next_decoder_cache = None
for decoder_layer in self.layers: for i, decoder_layer in enumerate(self.layers):
if self.transfer_map is not None and i in self.transfer_map:
prev_stream = torch.cuda.current_stream()
cur_device = self.transfer_map[i]
if cur_device not in self.stream_device_map:
self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
torch.cuda.set_device(cur_device)
self.stream_device_map[cur_device].wait_stream(prev_stream)
torch.cuda.set_stream(self.stream_device_map[cur_device])
hidden_states = hidden_states.to(self.transfer_map[i], non_blocking = True)
causal_mask = causal_mask.to(self.transfer_map[i], non_blocking = True) if causal_mask is not None else None
position_ids = position_ids.to(self.transfer_map[i], non_blocking = True) if position_ids is not None else None
cache_position = cache_position.to(self.transfer_map[i], non_blocking = True) if cache_position is not None else None
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
@ -463,7 +451,7 @@ DeepseekV2_INPUTS_DOCSTRING = r"""
""" """
class DeepseekV2ModelPerLayerPrefill(BaseInjectedModule): class DeepseekV2ModelKTransformers(BaseInjectedModule):
""" """
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`] Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`]
@ -478,10 +466,13 @@ class DeepseekV2ModelPerLayerPrefill(BaseInjectedModule):
orig_module: nn.Module, orig_module: nn.Module,
device: str = "cuda", device: str = "cuda",
per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
transfer_map: dict = None,
**kwargs, **kwargs,
): ):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
self.transfer_map = transfer_map
self.stream_device_map = dict()
@add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
def forward( def forward(
@ -584,7 +575,20 @@ class DeepseekV2ModelPerLayerPrefill(BaseInjectedModule):
t_cpu = 0 t_cpu = 0
t_f = 0 t_f = 0
for decoder_layer in self.layers: for i, decoder_layer in enumerate(self.layers):
if self.transfer_map is not None and i in self.transfer_map:
prev_stream = torch.cuda.current_stream()
cur_device = self.transfer_map[i]
if cur_device not in self.stream_device_map:
self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
torch.cuda.set_device(cur_device)
self.stream_device_map[cur_device].wait_stream(prev_stream)
torch.cuda.set_stream(self.stream_device_map[cur_device])
hidden_states = hidden_states.to(self.transfer_map[i], non_blocking = True)
causal_mask = causal_mask.to(self.transfer_map[i], non_blocking = True) if causal_mask is not None else None
position_ids = position_ids.to(self.transfer_map[i], non_blocking = True) if position_ids is not None else None
cache_position = cache_position.to(self.transfer_map[i], non_blocking = True) if cache_position is not None else None
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)

View file

@ -176,7 +176,7 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
self.act_order = act_order self.act_order = act_order
self.is_k_full = is_k_full self.is_k_full = is_k_full
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = "cuda"): def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if device is None: device = self.device if device is None: device = self.device
assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device" assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
if w is None: w = self.load_weight(device=device) if w is None: w = self.load_weight(device=device)
@ -200,7 +200,7 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
weight, self.num_bits, self.group_size, self.act_order weight, self.num_bits, self.group_size, self.act_order
) )
self.workspace = MarlinWorkspace( self.workspace = MarlinWorkspace(
self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device
) )
self.marlin_q_w = marlin_q_w self.marlin_q_w = marlin_q_w
self.marlin_s = marlin_s self.marlin_s = marlin_s
@ -247,7 +247,6 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
LINEAR_MAP = { LINEAR_MAP = {
"QuantizedLinearMarlin": QuantizedLinearMarlin, "QuantizedLinearMarlin": QuantizedLinearMarlin,
"QuantizedLinearTorch": QuantizedLinearTorch, "QuantizedLinearTorch": QuantizedLinearTorch,
"QuantizedLinearTorch": QuantizedLinearTorch,
} }
class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase): class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
@ -257,15 +256,15 @@ class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
gguf_loader: GGUFLoader, gguf_loader: GGUFLoader,
config: PretrainedConfig, config: PretrainedConfig,
orig_module: nn.Module, orig_module: nn.Module,
device: str = "cuda", # device: str = "cuda",
generate_device: str = "cuda", generate_device: str = "cuda",
generate_op: str| None = "QuantizedLinearMarlin", generate_op: str| None = "QuantizedLinearMarlin",
prefill_device: str = "cuda", prefill_device: str = "cuda",
prefill_op: str| None = "QuantizedLinearTorch", prefill_op: str| None = "QuantizedLinearTorch",
**kwargs, **kwargs,
): ):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
QuantizedLinearBase.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) QuantizedLinearBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
# build all the linear operators # build all the linear operators
if prefill_op is not None: if prefill_op is not None:
assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported" assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported"
@ -289,7 +288,6 @@ class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs) self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)
else: else:
self.generate_linear = None self.generate_linear = None
self.device = device
self.mode = InferenceState.UNLOAD self.mode = InferenceState.UNLOAD
def forward(self, x): def forward(self, x):

View file

@ -1,6 +1,6 @@
''' '''
Description : Description :
Author : Boxin Zhang Author : Boxin Zhang, Azure-Tang
Version : 0.1.0 Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved. Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
''' '''
@ -15,6 +15,7 @@ from transformers.configuration_utils import PretrainedConfig
from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf
from ktransformers.util.utils import set_module, load_weights from ktransformers.util.utils import set_module, load_weights
import itertools import itertools
import copy
def inject(module, local_optimization_dict, model_config:AutoConfig ,gguf_loader:GGUFLoader, prefix=''): def inject(module, local_optimization_dict, model_config:AutoConfig ,gguf_loader:GGUFLoader, prefix=''):
for name, child in module._modules.items(): for name, child in module._modules.items():
@ -22,18 +23,20 @@ def inject(module, local_optimization_dict, model_config:AutoConfig ,gguf_loader
child_prefix = prefix + name child_prefix = prefix + name
if child_prefix in local_optimization_dict: if child_prefix in local_optimization_dict:
inject_module_meta=local_optimization_dict[child_prefix] inject_module_meta=local_optimization_dict[child_prefix]
if isinstance(inject_module_meta, Mapping): if inject_module_meta["class"] != "default":
import_path = inject_module_meta["class"].split(".") import_path = inject_module_meta["class"].split(".")
import_module_name = ".".join(import_path[:-1]) import_module_name = ".".join(import_path[:-1])
gguf_loader.tensor_device_map[inject_module_meta["key"]] = inject_module_meta["kwargs"] if "kwargs" in inject_module_meta else dict()
import_class_name = import_path[-1] import_class_name = import_path[-1]
module_cls=getattr(__import__(import_module_name, fromlist=[""]), import_class_name) module_cls=getattr(__import__(import_module_name, fromlist=[""]), import_class_name)
print(f"Injecting {child_prefix} as", import_module_name, ".", import_class_name) print(f"Injecting {child_prefix} as", import_module_name, ".", import_class_name)
inject_module=module_cls(key = inject_module_meta["key"], gguf_loader = gguf_loader, config = model_config, orig_module=child, device = inject_module_meta["device"], **inject_module_meta["kwargs"]) inject_module=module_cls(key = inject_module_meta["key"], gguf_loader = gguf_loader, config = model_config, orig_module=child, **inject_module_meta["kwargs"])
set_module(module, name, inject_module) set_module(module, name, inject_module)
elif isinstance(inject_module_meta, str): elif inject_module_meta["class"] == "default":
assert inject_module_meta=="default", "for str inject_module_meta, only support \"default\"." print(f"Injecting {child_prefix} as default")
gguf_loader.tensor_device_map[inject_module_meta["key"]] = inject_module_meta["kwargs"] if "kwargs" in inject_module_meta else dict()
else: else:
raise Exception("inject_module_meta must be a dict or str") raise Exception("inject_module_meta[\"class\"] must be \"default\" or a class path")
child_prefix += "." child_prefix += "."
child_optimization_dict = {k: v for k, v in local_optimization_dict.items() if k.startswith(child_prefix)} child_optimization_dict = {k: v for k, v in local_optimization_dict.items() if k.startswith(child_prefix)}
inject(child, child_optimization_dict, model_config, gguf_loader, child_prefix) inject(child, child_optimization_dict, model_config, gguf_loader, child_prefix)
@ -57,6 +60,8 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
for rule in rule_list: for rule in rule_list:
#print(rule) #print(rule)
match_meta = rule["match"] match_meta = rule["match"]
if "class" not in match_meta and "name" not in match_meta:
raise Exception("match must have at least one of \"class\" and \"name\"")
if "class" in match_meta: if "class" in match_meta:
import_path = match_meta["class"].split(".") import_path = match_meta["class"].split(".")
import_module_name = ".".join(import_path[:-1]) import_module_name = ".".join(import_path[:-1])
@ -67,16 +72,29 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
if "name" in match_meta: if "name" in match_meta:
if re.search(match_meta["name"], module_name) is None: if re.search(match_meta["name"], module_name) is None:
continue continue
replace_meta = rule["replace"] if "replace" not in rule:
out_data[module_name]={"key": translated_name, raise Exception("replace must be in rule")
"class": replace_meta["class"], if "replace" in rule:
"device": replace_meta["device"] if "device" in replace_meta else default_device, replace_meta = rule["replace"]
"kwargs": replace_meta["kwargs"] if "kwargs" in replace_meta else dict()} if module_name not in out_data:
out_data[module_name]={"key": translated_name,
"class": replace_meta["class"] if "class" in replace_meta else "default",
# "device": replace_meta["device"] if "device" in replace_meta else default_device,
"kwargs": copy.deepcopy(replace_meta["kwargs"]) if "kwargs" in replace_meta else dict()}
else:
if out_data[module_name]["class"] == "default":
out_data[module_name]["class"] = replace_meta["class"] if "class" in replace_meta else "default"
out_data[module_name]["kwargs"].update(copy.deepcopy(replace_meta["kwargs"]) if "kwargs" in replace_meta else dict())
if "recursive" in rule: if "recursive" in rule:
recursive = bool(rule["recursive"]) recursive = bool(rule["recursive"])
if module_name not in out_data: if module_name not in out_data:
out_data[module_name]="default" out_data[module_name]= {
"class": "default",
"key": translated_name,
"kwargs": {"generate_device": default_device,
"prefill_device": default_device}
}
#print(out_data[module_name]) #print(out_data[module_name])
#input() #input()
@ -88,6 +106,14 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
gen_optimize_config(child, out_data, rule_list, child_prefix) gen_optimize_config(child, out_data, rule_list, child_prefix)
def translate_model_config(model_config: PretrainedConfig):
# for supporting some special models
if model_config.model_type == "mixtral":
model_config.moe_intermediate_size = model_config.intermediate_size
return model_config
def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, model_config: PretrainedConfig, default_device: str = "cuda:0"): def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, model_config: PretrainedConfig, default_device: str = "cuda:0"):
with open(rule_file, 'r', encoding='utf-8') as f: with open(rule_file, 'r', encoding='utf-8') as f:
rule_list = yaml.load(f.read(), Loader=yaml.FullLoader) rule_list = yaml.load(f.read(), Loader=yaml.FullLoader)
@ -95,8 +121,11 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo
optimize_config = dict() optimize_config = dict()
gen_optimize_config(module, optimize_config, rule_list, default_device = default_device) gen_optimize_config(module, optimize_config, rule_list, default_device = default_device)
model_config = translate_model_config(model_config)
gguf_loader=GGUFLoader(gguf_path) gguf_loader=GGUFLoader(gguf_path)
with torch.device("meta"): with torch.device("meta"):
inject(module, optimize_config, model_config, gguf_loader) inject(module, optimize_config, model_config, gguf_loader)
load_weights(module, gguf_loader) load_weights(module, gguf_loader)
model_config.gguf_loader = gguf_loader
del_meta(module) del_meta(module)
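A hedged, self-contained illustration of the rule-merging behaviour gen_optimize_config implements above: the first matching rule that names a non-"default" class wins, and later matching rules only merge in extra kwargs. The rule contents and module name below are made up for illustration.

import copy
import re

rules = [
    {"match": {"name": r"^model\.layers\.0\."},
     "replace": {"class": "default",
                 "kwargs": {"generate_device": "cuda:0", "prefill_device": "cuda:0"}}},
    {"match": {"name": r"\.mlp$"},
     "replace": {"class": "my.pkg.MoEInjected",   # hypothetical class path
                 "kwargs": {"generate_device": "cuda:0"}}},
]

out_data = {}
module_name = "model.layers.0.mlp"
for rule in rules:
    if re.search(rule["match"]["name"], module_name) is None:
        continue
    replace_meta = rule["replace"]
    if module_name not in out_data:
        out_data[module_name] = {"class": replace_meta.get("class", "default"),
                                 "kwargs": copy.deepcopy(replace_meta.get("kwargs", {}))}
    else:
        if out_data[module_name]["class"] == "default":
            out_data[module_name]["class"] = replace_meta.get("class", "default")
        out_data[module_name]["kwargs"].update(copy.deepcopy(replace_meta.get("kwargs", {})))

print(out_data)  # class is upgraded from "default" to my.pkg.MoEInjected, kwargs merged from both rules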

View file

@ -0,0 +1,228 @@
- match:
name: "^model\\.layers\\.([0-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "(^model\\.layers\\.([1][0-9])\\.)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "(^model\\.layers\\.([2][0-9])\\.)"
replace:
class: "default"
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
- match:
name: "(^model\\.layers\\.([345][0-9])\\.)|(^model.norm)|(^lm_head)"
replace:
class: "default"
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model\\.layers\\.([0-9])\\."
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([1][0-9])\\."
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.([2][0-9])\\."
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
- match:
name: "^model\\.layers\\.([345][0-9])\\."
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
- match:
name: "^model\\.layers\\.([1][0-9])\\.(?!self_attn).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "QuantizedLinearMarlin"
prefill_op: "QuantizedLinearTorch"
- match:
name: "^model\\.layers\\.([1][0-9])\\.(?!self_attn).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "QuantizedLinearMarlin"
prefill_op: "QuantizedLinearTorch"
- match:
name: "^model\\.layers\\.([2][0-9])\\.(?!self_attn).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
generate_op: "QuantizedLinearMarlin"
prefill_op: "QuantizedLinearTorch"
- match:
name: "^model\\.layers\\.([345][0-9])\\.(?!self_attn).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
generate_op: "QuantizedLinearMarlin"
prefill_op: "QuantizedLinearTorch"
- match:
name: "^model\\.layers\\.([0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([1][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.([2][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
- match:
name: "^model\\.layers\\.([345][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
- match:
name: "^model\\.layers\\.([0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda:0"
prefill_mlp_type: "MLPExpertsTorch"
generate_device: "cpu"
generate_mlp_type: "MLPCPUExperts"
out_device: "cuda:0"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.([1][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda:1"
prefill_mlp_type: "MLPExpertsTorch"
generate_device: "cpu"
generate_mlp_type: "MLPCPUExperts"
out_device: "cuda:1"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.([2][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda:2"
prefill_mlp_type: "MLPExpertsTorch"
generate_device: "cpu"
generate_mlp_type: "MLPCPUExperts"
out_device: "cuda:2"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.([345][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda:3"
prefill_mlp_type: "MLPExpertsTorch"
generate_device: "cpu"
generate_mlp_type: "MLPCPUExperts"
out_device: "cuda:3"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.([0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([1][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.([2][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
- match:
name: "^model\\.layers\\.([345][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.layer_wise_prefill.DeepseekV2ModelKTransformers"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
transfer_map:
10: "cuda:1"
20: "cuda:2"
30: "cuda:3"

View file

@ -0,0 +1,126 @@
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)|(lm_head)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([345][0-9])\\."
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "QuantizedLinearMarlin"
prefill_op: "QuantizedLinearTorch"
- match:
name: "^model\\.layers\\.([345][0-9])\\.(?!self_attn).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "QuantizedLinearMarlin"
prefill_op: "QuantizedLinearTorch"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([345][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda:0"
prefill_mlp_type: "MLPExpertsTorch"
generate_device: "cpu"
generate_mlp_type: "MLPCPUExperts"
out_device: "cuda:0"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.([345][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda:1"
prefill_mlp_type: "MLPExpertsTorch"
generate_device: "cpu"
generate_mlp_type: "MLPCPUExperts"
out_device: "cuda:1"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([345][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.layer_wise_prefill.DeepseekV2ModelKTransformers"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
transfer_map:
30: "cuda:1"

View file

@ -1,3 +1,10 @@
- match:
name: "^model\\.layers\\..*\\.|^lm_head"
replace:
class: "default"
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match: - match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace: replace:
@ -21,12 +28,11 @@
name: "^model\\.layers\\..*\\.mlp\\.experts$" name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace: replace:
class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
device: "cpu" # which devices to load this module when initializing
kwargs: kwargs:
prefill_device: "cuda" prefill_device: "cuda"
prefill_mlp_type: "MLPExpertsTorch" prefill_mlp_type: "MLPExpertsTorch"
generate_device: "cpu" generate_device: "cpu"
generate_mlp_type: "MLPCPUExperts" generate_mlp_type: "MLPCPUExperts"
out_device: "cuda" out_device: "cuda"
recursive: False # don't recursively inject submodules of this module recursive: False # don't recursively inject submodules of this module
- match: - match:
@ -36,6 +42,13 @@
- match: - match:
name: "^model$" name: "^model$"
replace: replace:
class: "ktransformers.operators.layer_wise_prefill.DeepseekV2ModelPerLayerPrefill" class: "ktransformers.operators.layer_wise_prefill.DeepseekV2ModelKTransformers"
kwargs: kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

View file

@ -0,0 +1,126 @@
- match:
name: "^model\\.layers\\.(0|[1-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)|(lm_head)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model\\.layers\\.(0|[1-9])\\."
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([12][0-9])\\."
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9])\\.(?!self_attn).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "QuantizedLinearMarlin"
prefill_op: "QuantizedLinearTorch"
- match:
name: "^model\\.layers\\.([12][0-9])\\.(?!self_attn).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "QuantizedLinearMarlin"
prefill_op: "QuantizedLinearTorch"
- match:
name: "^model\\.layers\\.(0|[1-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([12][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda:0"
prefill_mlp_type: "MLPExpertsTorch"
generate_device: "cpu"
generate_mlp_type: "MLPCPUExperts"
out_device: "cuda:0"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.([12][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda:1"
prefill_mlp_type: "MLPExpertsTorch"
generate_device: "cpu"
generate_mlp_type: "MLPCPUExperts"
out_device: "cuda:1"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.(0|[1-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([12][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.layer_wise_prefill.DeepseekV2ModelKTransformers"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
transfer_map:
10: "cuda:1"

View file

@ -0,0 +1,45 @@
- match:
name: "^model\\.layers\\..*\\."
replace:
class: "default"
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_mixtral.MixtralRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.RotaryEmbedding
- match:
name: "^model\\.layers\\..*$"
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "QuantizedLinearMarlin"
prefill_op: "QuantizedLinearTorch"
- match:
name: "^model\\.layers\\..*\\.block_sparse_moe$"
class: ktransformers.models.modeling_mixtral.MixtralSparseMoeBlock
replace:
class: ktransformers.operators.experts.MisrtalSparseMoEBlockInjected
- match:
name: "^model\\.layers\\..*\\.block_sparse_moe\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersMLPExpert
kwargs:
prefill_device: "cuda"
prefill_mlp_type: "MLPExpertsTorch"
generate_device: "cpu"
generate_mlp_type: "MLPCPUExperts"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

View file

@ -0,0 +1,111 @@
- match:
name: "^model\\.layers\\.([012])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([012])\\."
class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.RotaryEmbedding
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([012])$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "QuantizedLinearMarlin"
prefill_op: "QuantizedLinearTorch"
- match:
name: "^model\\.layers\\.([012])\\.mlp$"
class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock
replace:
class: ktransformers.operators.experts.Qwen2MoeSparseMoeBlockInjected # mlp module with custom forward function
- match:
name: "^model\\.layers\\.([012])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
# device: "cpu" # which devices to load this module when initializing
kwargs:
prefill_device: "cuda:0"
prefill_mlp_type: "MLPExpertsTorch"
generate_device: "cpu"
generate_mlp_type: "MLPCPUExperts"
out_device: "cuda:0"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.([12][0-9]|[3-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.([12][0-9]|[3-9])\\."
class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.RotaryEmbedding
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.([12][0-9]|[3-9])$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "QuantizedLinearMarlin"
prefill_op: "QuantizedLinearTorch"
- match:
name: "^model\\.layers\\.([12][0-9]|[3-9])\\.mlp$"
class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock
replace:
class: ktransformers.operators.experts.Qwen2MoeSparseMoeBlockInjected # mlp module with custom forward function
- match:
name: "^model\\.layers\\.([12][0-9]|[3-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
# device: "cpu" # which devices to load this module when initializing
kwargs:
prefill_device: "cuda:1"
prefill_mlp_type: "MLPExpertsTorch"
generate_device: "cpu"
generate_mlp_type: "MLPCPUExperts"
out_device: "cuda:1"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "(^model.norm)|(^lm_head)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.layer_wise_prefill.Qwen2MoeModelKTransformers"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
transfer_map:
3: "cuda:1"

View file

@ -1,3 +1,10 @@
- match:
name: "^model\\.layers\\..*\\."
replace:
class: "default"
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match: - match:
class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
replace: replace:
@ -21,7 +28,7 @@
name: "^model\\.layers\\..*\\.mlp\\.experts$" name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace: replace:
class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
device: "cpu" # which devices to load this module when initializing # device: "cpu" # which devices to load this module when initializing
kwargs: kwargs:
prefill_device: "cuda" prefill_device: "cuda"
prefill_mlp_type: "MLPExpertsTorch" prefill_mlp_type: "MLPExpertsTorch"
@ -32,6 +39,13 @@
- match: - match:
name: "^model$" name: "^model$"
replace: replace:
class: "ktransformers.operators.layer_wise_prefill.Qwen2MoeModelPerLayerPrefill" class: "ktransformers.operators.layer_wise_prefill.Qwen2MoeModelKTransformers"
kwargs: kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

View file

@ -1,12 +1,9 @@
import os import os
os.environ["CUDA_VISIBLE_DEVICES"]="1" # os.environ["CUDA_VISIBLE_DEVICES"]="1,2"
# add path # add path
import sys import sys
current_path = os.path.abspath(os.path.dirname(__file__)) current_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(current_path+"/../..") sys.path.append(current_path+"/../..")
import pycuda.autoinit
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
import numpy as np import numpy as np
# from ktransformers.operators.linear import KTransformerLinear, QuantizedLinearMarlin # from ktransformers.operators.linear import KTransformerLinear, QuantizedLinearMarlin
# from ktransformers.operators.experts import KTransformersMLPExpert, MLPExpertsTorch # from ktransformers.operators.experts import KTransformersMLPExpert, MLPExpertsTorch
@ -18,36 +15,23 @@ import time
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
) )
import os
# CUDA_LAUNCH_BLOCKING=1
os.environ["CUDA_LAUNCH_BLOCKING"]="1"
gguf_config = GGUFLoader("/data/Qwen2-57B-A14B-Instruct-GGUF/q4_k_m") gguf_config = GGUFLoader("/data/Qwen2-57B-A14B-Instruct-GGUF/q4_k_m")
model_name = "/data/Qwen2-57B-A14B-Instruct" model_name = "/data/Qwen2-57B-A14B-Instruct"
key = "blk.0."
target = "ffn_down_exps.weight"
t1 = time.time()
q_weight_cpu = gguf_config.load_gguf_tensor(key+target, "cpu")
# q_weight_cpu = torch.from_numpy(q_weight_cpu)
t2 = time.time()
q_weight_gpu = gguf_config.load_gguf_tensor(key+target, "cuda")
t3 = time.time()
print()
allclose = torch.allclose(q_weight_cpu, q_weight_gpu.cpu().to(torch.float32), atol=1e-6)
print(f"Q6k {key+target}")
print("load gguf tensor from cpu cost: ", t2-t1)
print("load gguf tensor from gpu cost: ", t3-t2)
print("allclose: ", allclose)
# Q4k
key = "blk.1." key = "blk.1."
target = "ffn_up_shexp.weight" target = "attn_q.weight"
t1 = time.time() t1 = time.time()
q_weight_cpu = gguf_config.load_gguf_tensor(key+target, "cpu") q_weight_cpu = gguf_config.load_gguf_tensor(key+target, "cpu")
# q_weight_cpu = torch.from_numpy(q_weight_cpu) # q_weight_cpu = torch.from_numpy(q_weight_cpu)
t2 = time.time() t2 = time.time()
q_weight_gpu = gguf_config.load_gguf_tensor(key+target, "cuda") q_weight_gpu = gguf_config.load_gguf_tensor(key+target, "cuda:0")
t3 = time.time() t3 = time.time()
print() print()
allclose = torch.allclose(q_weight_cpu, q_weight_gpu.cpu(), atol=1e-6) allclose = torch.allclose(q_weight_cpu, q_weight_gpu.cpu(), atol=1e-6)
@ -55,3 +39,20 @@ print(f"Q4k {key+target}")
print("load gguf tensor from cpu cost: ", t2-t1) print("load gguf tensor from cpu cost: ", t2-t1)
print("load gguf tensor from gpu cost: ", t3-t2) print("load gguf tensor from gpu cost: ", t3-t2)
print("allclose: ", allclose) print("allclose: ", allclose)
# Q6k
key = "blk.0."
target = "ffn_down_exps.weight"
t1 = time.time()
q_weight_cpu = gguf_config.load_gguf_tensor(key+target, "cpu")
t2 = time.time()
q_weight_gpu = gguf_config.load_gguf_tensor(key+target, "cuda:0")
t3 = time.time()
print()
allclose = torch.allclose(q_weight_cpu, q_weight_gpu.cpu().to(torch.float32), atol=1e-6)
print(f"Q6k {key+target}")
print("load gguf tensor from cpu cost: ", t2-t1)
print("load gguf tensor from gpu cost: ", t3-t2)
print("allclose: ", allclose)

View file

@ -11,7 +11,7 @@ from ktransformers.operators.linear import KTransformerLinear, QuantizedLinearMa
from ktransformers.operators.experts import KTransformersMLPExpert, MLPExpertsTorch from ktransformers.operators.experts import KTransformersMLPExpert, MLPExpertsTorch
from ktransformers.util.custom_gguf import GGUFLoader, dequantize_q4_k_gpu, dequantize_q4_k from ktransformers.util.custom_gguf import GGUFLoader, dequantize_q4_k_gpu, dequantize_q4_k
import torch import torch
import CudaOps import KTransformersOps
torch.set_default_dtype(torch.bfloat16) torch.set_default_dtype(torch.bfloat16)
import time import time
from transformers import ( from transformers import (

View file

@ -21,6 +21,7 @@ class CUDAGraphRunner:
position_ids, position_ids,
cache_position, cache_position,
past_key_values, past_key_values,
main_device,
**kwargs, **kwargs,
) -> None: ) -> None:
assert self.graph is None assert self.graph is None
@ -29,15 +30,24 @@ class CUDAGraphRunner:
self.graph = torch.cuda.CUDAGraph() self.graph = torch.cuda.CUDAGraph()
#self.graph.enable_debug_mode() #self.graph.enable_debug_mode()
self.model = model self.model = model
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to("cuda") inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(main_device)
with torch.cuda.graph(self.graph): # torch.cuda.set_device can't set "cuda", must have an index
if main_device == "cuda":
main_device = "cuda:0"
torch.cuda.set_device(main_device)
self.main_device = main_device
capture_stream = torch.cuda.Stream()
with torch.cuda.graph(self.graph, stream = capture_stream):
logits=model(inputs_embeds=inputs_embeds, logits=model(inputs_embeds=inputs_embeds,
position_ids=position_ids, position_ids=position_ids,
cache_position=cache_position, cache_position=cache_position,
past_key_values=past_key_values, past_key_values=past_key_values,
**kwargs)[0] **kwargs)[0]
capture_stream.wait_stream(torch.cuda.current_stream())
torch.cuda.set_device(main_device)
torch.cuda.set_stream(capture_stream)
past_key_values.change_seq_length(-1) past_key_values.change_seq_length(-1)
torch.cuda.synchronize() torch.cuda.synchronize(self.main_device)
#self.graph.debug_dump("cuda_graph_hooked.dot") #self.graph.debug_dump("cuda_graph_hooked.dot")
# Save the input and output buffers. # Save the input and output buffers.
@ -65,7 +75,7 @@ class CUDAGraphRunner:
#print("begin replay") #print("begin replay")
#time.sleep(1) #time.sleep(1)
self.graph.replay() self.graph.replay()
torch.cuda.synchronize() torch.cuda.synchronize(self.main_device)
# Return the output tensor. # Return the output tensor.
return self.output_buffers["logits"] return self.output_buffers["logits"]
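A hedged, minimal sketch of the capture pattern used above: pin an indexed device, warm up and capture on a side stream, then replay and synchronize against that device. The tiny linear layer and shapes are illustrative, and a CUDA GPU is required.

import torch

if torch.cuda.is_available():
    main_device = "cuda:0"                      # set_device needs an indexed device, not plain "cuda"
    torch.cuda.set_device(main_device)
    layer = torch.nn.Linear(16, 16).to(main_device)
    static_in = torch.randn(1, 16, device=main_device)

    graph = torch.cuda.CUDAGraph()
    capture_stream = torch.cuda.Stream()
    capture_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(capture_stream):     # warm-up on the side stream before capture
        for _ in range(3):
            static_out = layer(static_in)
    torch.cuda.current_stream().wait_stream(capture_stream)

    with torch.cuda.graph(graph, stream=capture_stream):
        static_out = layer(static_in)

    static_in.copy_(torch.randn(1, 16, device=main_device))
    graph.replay()                              # recomputes static_out in place
    torch.cuda.synchronize(main_device)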

View file

@ -5,8 +5,11 @@ Description :
Author : Azure-Tang, Boxin Zhang, chenht2022 Author : Azure-Tang, Boxin Zhang, chenht2022
Date : 2024-07-26 08:48:54 Date : 2024-07-26 08:48:54
Version : 1.0.0 Version : 1.0.0
LastEditors : Azure LastEditors : kkk1nak0
LastEditTime : 2024-07-26 09:28:25 LastEditTime : 2024-08-09 08:03:44
Adapted from https://github.com/99991/pygguf/blob/main/gguf.py
Copyright (c) 2023-2024 The ggml authors
Copyright (c) 2024 Thomas Germer
Copyright (c) 2024 by KVCache.AI, All Rights Reserved. Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
''' '''
# copied from llama.cpp/gguf-py/gguf/constants.py to satisfy dependence of gguf # copied from llama.cpp/gguf-py/gguf/constants.py to satisfy dependence of gguf
@ -15,6 +18,7 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import struct import struct
import warnings import warnings
import numpy as np import numpy as np
import re
import numpy.typing as npt import numpy.typing as npt
from typing import Sequence from typing import Sequence
import os import os
@ -96,6 +100,8 @@ def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantization
GGML_TYPES = { GGML_TYPES = {
"F32": 0, "F32": 0,
"F16": 1, "F16": 1,
"Q4_0": 2,
"Q5_0": 6,
"Q8_0": 8, "Q8_0": 8,
"Q2_K": 10, "Q2_K": 10,
"Q3_K": 11, "Q3_K": 11,
@ -109,6 +115,8 @@ GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()}
GGML_BLOCK_SIZES = { GGML_BLOCK_SIZES = {
"F32": 4, "F32": 4,
"F16": 2, "F16": 2,
"Q4_0": 2 + 16,
"Q5_0": 2 + 4 + 16,
"Q8_0": 2 + 32, "Q8_0": 2 + 32,
"Q2_K": 256 // 16 + 256 // 4 + 2 + 2, "Q2_K": 256 // 16 + 256 // 4 + 2 + 2,
"Q3_K": 256 // 8 + 256 // 4 + 12 + 2, "Q3_K": 256 // 8 + 256 // 4 + 12 + 2,
@ -120,6 +128,8 @@ GGML_BLOCK_SIZES = {
GGML_ELEMENTS_PER_BLOCK = { GGML_ELEMENTS_PER_BLOCK = {
"F32": 1, "F32": 1,
"F16": 1, "F16": 1,
"Q4_0": 32,
"Q5_0": 32,
"Q8_0": 32, "Q8_0": 32,
"Q2_K": 256, "Q2_K": 256,
"Q3_K": 256, "Q3_K": 256,
@ -128,14 +138,6 @@ GGML_ELEMENTS_PER_BLOCK = {
"Q6_K": 256, "Q6_K": 256,
} }
# DATA_TYPES = {
# "uint32": 4,
# "int32": 5,
# "float32": 6,
# "string": 8,
# "array": 9,
# "uint64": 10,
# }
DATA_TYPES = { DATA_TYPES = {
"uint8": 0, "uint8": 0,
"int8": 1, "int8": 1,
@ -167,6 +169,7 @@ class GGUFLoader:
self.tensor_file_map = {} self.tensor_file_map = {}
self.file_data_map = {} self.file_data_map = {}
self.gguf_file_meta = {} self.gguf_file_meta = {}
self.tensor_device_map = {}
# Walk through all the .gguf files in the directory # Walk through all the .gguf files in the directory
for root, dirs, files in os.walk(gguf_path): for root, dirs, files in os.walk(gguf_path):
@ -283,14 +286,27 @@ class GGUFLoader:
data = self.get_mmap_tensor(name) data = self.get_mmap_tensor(name)
if "cuda" in device.lower(): if "cuda" in device.lower():
values = GGML_DEQUANTIZE_GPU[ggml_name](data, device) values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
#values = GGML_DEQUANTIZE[ggml_name](data)
#print("load_gguf_tensor")
#values = torch.from_numpy(values).to(device = device)
else: else:
values = GGML_DEQUANTIZE[ggml_name](data) values = GGML_DEQUANTIZE[ggml_name](data)
values = torch.from_numpy(values) values = torch.from_numpy(values)
return values.view(shape[::-1]) values = values.view(shape[::-1])
if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
n_head = self.gguf_file_meta['llama.attention.head_count']
values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])
.swapaxes(1, 2)
.reshape(values.shape))
elif "attn_k" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
n_head = self.gguf_file_meta['llama.attention.head_count_kv']
values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])
.swapaxes(1, 2)
.reshape(values.shape))
return values
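A toy check (shapes are illustrative, not taken from any GGUF file) that the reshape/swapaxes above is the inverse of the head permutation llama.cpp's converter applies to llama attn_q/attn_k weights when exporting HF checkpoints to GGUF:

import torch

n_head, rows, in_dim = 2, 8, 3                      # rows must be divisible by 2 * n_head
hf = torch.arange(rows * in_dim, dtype=torch.float32).reshape(rows, in_dim)

def to_gguf(w, n_head):                             # permutation applied when exporting HF -> GGUF
    return (w.reshape(n_head, 2, w.shape[0] // n_head // 2, *w.shape[1:])
             .swapaxes(1, 2).reshape(w.shape))

def to_hf(w, n_head):                               # inverse, as applied in load_gguf_tensor above
    return (w.reshape(n_head, w.shape[0] // n_head // 2, 2, *w.shape[1:])
             .swapaxes(1, 2).reshape(w.shape))

assert torch.equal(to_hf(to_gguf(hf, n_head), n_head), hf)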
def read_value(f, data_type): def read_value(f, data_type):
if data_type == DATA_TYPES["string"]: if data_type == DATA_TYPES["string"]:
@ -375,7 +391,7 @@ def dequantize_q2_k(data):
return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4) return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
def dequantize_q2_k_gpu(data): def dequantize_q2_k_gpu(data):
pass raise NotImplementedError()
def dequantize_q3_k(data): def dequantize_q3_k(data):
# C implementation # C implementation
@ -420,7 +436,7 @@ def dequantize_q3_k(data):
], axis=1) ], axis=1)
def dequantize_q3_k_gpu(data): def dequantize_q3_k_gpu(data):
pass raise NotImplementedError()
def dequantize_q4_k(data): def dequantize_q4_k(data):
# C implementation # C implementation
@ -429,20 +445,16 @@ def dequantize_q4_k(data):
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L116 # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L116
block_size = GGML_BLOCK_SIZES["Q4_K"] block_size = GGML_BLOCK_SIZES["Q4_K"]
num_blocks = len(data) // block_size num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
# Casting to float32 because float16 is very slow on CPU # Casting to float32 because float16 is very slow on CPU
scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32) scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32)
scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32) scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32)
qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1) qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32) qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32)
# Dequantize scales and offsets (6 bits and 4 + 2 bits) # Dequantize scales and offsets (6 bits and 4 + 2 bits)
factors = scale_factors * np.concatenate([qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1) factors = scale_factors * np.concatenate([qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1)
offsets = scale_offsets * np.concatenate([qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1) offsets = scale_offsets * np.concatenate([qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1)
# Interleave low and high quantized bits # Interleave low and high quantized bits
qs2 = np.stack([qs2 & 0xf, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32) qs2 = np.stack([qs2 & 0xf, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32)
# Dequantize final weights using scales and offsets # Dequantize final weights using scales and offsets
@ -512,9 +524,14 @@ def dequantize_q5_k(data):
d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8, d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
], axis=1) ], axis=1)
def dequantize_q5_k_gpu(data): def dequantize_q5_k_gpu(data, device:str ="cuda"):
pass block_size = GGML_BLOCK_SIZES["Q5_K"]
data = np.frombuffer(data, dtype=data.dtype)
device = torch.device(device)
# TODO: this and from_numpy in other functions will trigger a warning that the numpy array is not writable;
# the best fix is to pass the raw pointer to KTransformersOps instead of a Tensor.
data = torch.from_numpy(data)
return KTransformersOps.dequantize_q5_k(data, block_size, device)
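The TODO above can be sidestepped by copying the buffer before handing it to torch; a minimal sketch, assuming the GGML_BLOCK_SIZES table and KTransformersOps binding defined in this file (the helper name is illustrative):

import numpy as np
import torch

def dequantize_q5_k_gpu_nowarn(data, device: str = "cuda"):
    # Copying makes the array writable, so torch.from_numpy no longer warns;
    # it costs an extra host copy, which passing a raw pointer (per the TODO) would avoid.
    block_size = GGML_BLOCK_SIZES["Q5_K"]
    arr = np.frombuffer(data, dtype=data.dtype).copy()
    return KTransformersOps.dequantize_q5_k(torch.from_numpy(arr), block_size, torch.device(device))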
def dequantize_q6_k(data): def dequantize_q6_k(data):
# C implementation # C implementation
@ -571,7 +588,49 @@ def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda"):
num_blocks = len(data) // block_size num_blocks = len(data) // block_size
data = np.frombuffer(data, dtype=data.dtype) data = np.frombuffer(data, dtype=data.dtype)
data = torch.from_numpy(data) data = torch.from_numpy(data)
return KTransformersOps.dequantize_q6_k(data, 210, device) return KTransformersOps.dequantize_q6_k(data, block_size, device)
def dequantize_q4_0(data):
# C implementation
# https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1515
# C struct definition
# https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-common.h#L141
num_blocks = len(data) // GGML_BLOCK_SIZES["Q4_0"]
scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 8)[:, :1].astype(np.float32)
qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 16)[:, 2:]
return np.concatenate([
scales * ((qs & 0xf).astype(np.int8) - 8),
scales * ((qs >> 4).astype(np.int8) - 8),
], axis=1)
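A Q4_0 block is 18 bytes: one float16 scale followed by 16 bytes that pack 32 four-bit values, so each block expands to 32 weights. A small self-check of that layout using the function above (the data is synthetic, purely for illustration):

import numpy as np

block = np.zeros(18, dtype=np.uint8)
block[0:2] = np.array([1.0], dtype=np.float16).view(np.uint8)  # scale d = 1.0
block[2:] = 0x98  # low nibble 8 -> 8 - 8 = 0, high nibble 9 -> 9 - 8 = 1

out = dequantize_q4_0(block.tobytes())
assert out.shape == (1, 32)  # one block -> 32 weights
assert np.allclose(out[0, :16], 0.0) and np.allclose(out[0, 16:], 1.0)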
def dequantize_q4_0_gpu(data):
raise NotImplementedError()
def dequantize_q5_0(data):
# C implementation
# https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1556
# C struct definition
# https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-common.h#L161
num_blocks = len(data) // GGML_BLOCK_SIZES["Q5_0"]
scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 2 + 8)[:, :1].astype(np.float32)
qh = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2:2 + 4]
qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2 + 4:]
bits = np.unpackbits(qh, axis=-1, bitorder="little")
x0 = ((qs & 0xf).astype(np.int8) | (bits[:, :16] << 4)) - 16
x1 = ((qs >> 4).astype(np.int8) | (bits[:, 16:] << 4)) - 16
return np.concatenate([
scales * x0,
scales * x1,
], axis=1)
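Q5_0 keeps the fifth bit of each of the 32 weights in the 4-byte qh field, and np.unpackbits with bitorder="little" expands those back into per-weight high bits; a worked one-value example consistent with the function above:

import numpy as np

qh = np.array([[0b00000001, 0, 0, 0]], dtype=np.uint8)  # only weight 0 has its high bit set
bits = np.unpackbits(qh, axis=-1, bitorder="little")
assert bits[0, 0] == 1 and bits[0, 1:].sum() == 0
# with low nibble 0b0111, the reconstructed value is (7 | 1 << 4) - 16 = 7 before scaling
assert (0b0111 | (bits[0, 0] << 4)) - 16 == 7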
def dequantize_q5_0_gpu(data):
raise NotImplementedError()
def dequantize_q8_0(data): def dequantize_q8_0(data):
# C struct definition # C struct definition
@ -615,6 +674,8 @@ def dequantize_f16_gpu(data, device):
GGML_DEQUANTIZE = { GGML_DEQUANTIZE = {
"F32": dequantize_f32, "F32": dequantize_f32,
"F16": dequantize_f16, "F16": dequantize_f16,
"Q4_0": dequantize_q4_0,
"Q5_0": dequantize_q5_0,
"Q8_0": dequantize_q8_0, "Q8_0": dequantize_q8_0,
"Q2_K": dequantize_q2_k, "Q2_K": dequantize_q2_k,
"Q3_K": dequantize_q3_k, "Q3_K": dequantize_q3_k,
@ -626,6 +687,8 @@ GGML_DEQUANTIZE = {
GGML_DEQUANTIZE_GPU = { GGML_DEQUANTIZE_GPU = {
"F32": dequantize_f32_gpu, "F32": dequantize_f32_gpu,
"F16": dequantize_f16_gpu, "F16": dequantize_f16_gpu,
"Q4_0": dequantize_q4_0_gpu,
"Q5_0": dequantize_q5_0_gpu,
"Q8_0": dequantize_q8_0_gpu, "Q8_0": dequantize_q8_0_gpu,
"Q2_K": dequantize_q2_k_gpu, "Q2_K": dequantize_q2_k_gpu,
"Q3_K": dequantize_q3_k_gpu, "Q3_K": dequantize_q3_k_gpu,
@ -634,7 +697,34 @@ GGML_DEQUANTIZE_GPU = {
"Q6_K": dequantize_q6_k_gpu, "Q6_K": dequantize_q6_k_gpu,
} }
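Both tables are keyed by the GGUF type name, so the loader can dispatch on a tensor's quantization type; a minimal sketch of that dispatch, where raw_q5k_data stands in for the raw block data read from the .gguf file (assumed here to be a numpy byte view, as elsewhere in this file):

ggml_name = "Q5_K"
cpu_weights = GGML_DEQUANTIZE[ggml_name](raw_q5k_data)                # numpy float32 path
gpu_weights = GGML_DEQUANTIZE_GPU[ggml_name](raw_q5k_data, "cuda:0")  # torch.Tensor on the GPU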
def translate_name_to_gguf_mixtral(name):
replacement_template = {
"w1.weight": "ffn_gate",
"w2.weight": "ffn_down",
"w3.weight": "ffn_up"
}
pattern = re.compile(r"model.layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.(w\d\.weight)")
def replace_match(match):
blk_id = match.group(1)
expert_id = match.group(2)
weight_type = match.group(3)
if weight_type in replacement_template:
return f"blk.{blk_id}.{replacement_template[weight_type]}.{expert_id}.weight"
else:
return match.group(0)
new_name = re.sub(pattern, replace_match, name)
return new_name
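For example, the regex only rewrites Mixtral expert weights and leaves every other name untouched:

>>> translate_name_to_gguf_mixtral("model.layers.3.block_sparse_moe.experts.7.w2.weight")
'blk.3.ffn_down.7.weight'
>>> translate_name_to_gguf_mixtral("model.layers.3.self_attn.q_proj.weight")
'model.layers.3.self_attn.q_proj.weight'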
def translate_name_to_gguf(name): def translate_name_to_gguf(name):
name = translate_name_to_gguf_mixtral(name)
name = name.replace("lm_head.", "output.") name = name.replace("lm_head.", "output.")
name = name.replace("model.embed_tokens.", "token_embd.") name = name.replace("model.embed_tokens.", "token_embd.")
name = name.replace("model.norm.", "output_norm.") name = name.replace("model.norm.", "output_norm.")
@ -671,9 +761,14 @@ def translate_name_to_gguf(name):
name = name.replace(".mlp.experts.ffn_gate_exps", ".ffn_gate_exps") name = name.replace(".mlp.experts.ffn_gate_exps", ".ffn_gate_exps")
name = name.replace(".mlp.experts.ffn_up_exps", ".ffn_up_exps") name = name.replace(".mlp.experts.ffn_up_exps", ".ffn_up_exps")
name = name.replace(".block_sparse_moe.gate.", ".ffn_gate_inp.")
name = name.replace(".block_sparse_moe.experts", "")
return name return name
if __name__ == '__main__': if __name__ == '__main__':
gguf_path = '/mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH' gguf_path = '/mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH'
loader = GGUFLoader(gguf_path) loader = GGUFLoader(gguf_path)
loader.load_gguf_tensor('token_embd.weight') loader.load_gguf_tensor('token_embd.weight')

View file

@ -39,6 +39,22 @@ def set_param(module: nn.Module, name: str, weights: torch.Tensor):
param.unsqueeze_(0) param.unsqueeze_(0)
setattr(module, name, param) setattr(module, name, param)
def get_device(gguf_module_key:str, device_map:dict):
if gguf_module_key in device_map:
return device_map[gguf_module_key]["generate_device"]
else:
return "cuda"
def get_all_used_cuda_device(device_map:dict):
all_device_list = set()
for key in device_map:
if "generate_device" in device_map[key]: all_device_list.add(device_map[key]["generate_device"])
if "prefill_device" in device_map[key]: all_device_list.add(device_map[key]["prefill_device"])
if "cpu" in all_device_list:
all_device_list.remove("cpu")
all_device_list = list(all_device_list)
return all_device_list
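Both helpers expect the per-module placement dict that the injection config produces; an illustrative device_map with the shape they assume (module keys and devices here are made up):

device_map = {
    "blk.0.self_attn": {"generate_device": "cuda:0", "prefill_device": "cuda:0"},
    "blk.1.self_attn": {"generate_device": "cuda:1", "prefill_device": "cuda:1"},
    "token_embd":      {"generate_device": "cpu",    "prefill_device": "cpu"},
}

get_device("blk.1.self_attn", device_map)    # -> "cuda:1"
get_device("some.other.module", device_map)  # -> "cuda" (default when the key is missing)
get_all_used_cuda_device(device_map)         # -> ["cuda:0", "cuda:1"] (order may vary; "cpu" is dropped)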
def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str = ""): def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str = ""):
prefix = prefix.replace("orig_module.", "") prefix = prefix.replace("orig_module.", "")
persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set} persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
@ -47,18 +63,19 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
for name, param in local_state.items(): for name, param in local_state.items():
key = prefix + name key = prefix + name
translated_key = translate_name_to_gguf(key) translated_key = translate_name_to_gguf(key)
print("default loading weights", key, translated_key)
if translated_key in gguf_loader.tensor_file_map: if translated_key in gguf_loader.tensor_file_map:
target_dtype = torch.get_default_dtype() target_dtype = torch.get_default_dtype()
device = "cpu" if "embd" in translated_key else "cuda" device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
print(f"loading {translated_key} to {device}")
# device = "cpu" if "embd" in translated_key else "cuda"
weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype) weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
set_param(module, name, weights) set_param(module, name, weights)
del weights del weights
else: else:
#print(load_config.tensor_file_map.keys()) #print(load_config.tensor_file_map.keys())
raise Exception(f"can't fand {translated_key} in GGUF file!") raise Exception(f"can't find {translated_key} in GGUF file!")
def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix='', return_when_injected:bool = False, only_load_injected:bool = False): def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
# print(f"recursively loading weights {prefix},{return_when_injected=}, {only_load_injected=}") # print(f"recursively loading weights {prefix},{return_when_injected=}, {only_load_injected=}")
if not isinstance(module, base_operator.BaseInjectedModule): if not isinstance(module, base_operator.BaseInjectedModule):
load_cur_state_dict(module, gguf_loader, prefix) load_cur_state_dict(module, gguf_loader, prefix)
@ -67,26 +84,35 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix='', return_whe
else: else:
module.load() module.load()
def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000): def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True):
import os import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch._dynamo.config.suppress_errors = True torch._dynamo.config.suppress_errors = True
batch_size, seq_length = inputs.shape batch_size, seq_length = inputs.shape
torch_device = inputs.device device_map = model.config.gguf_loader.tensor_device_map
torch_device = get_device('blk.0.self_attn', device_map)
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
inputs = inputs.to(torch_device)
all_cuda_device = get_all_used_cuda_device(device_map)
tokens = [] tokens = []
def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values): def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, use_cuda_graph: bool = True):
logits = cuda_graph_runner(cur_token, position_ids, cache_position) if use_cuda_graph:
logits = cuda_graph_runner(cur_token, position_ids, cache_position)
else:
# custom_stream = torch.cuda.Stream()
torch.cuda.set_device(torch_device)
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
# with torch.cuda.stream(custom_stream):
logits=model(inputs_embeds=inputs_embeds,
position_ids=position_ids,
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False, use_cache=True)[0]
past_key_values.change_seq_length(1) past_key_values.change_seq_length(1)
""" for device in all_cuda_device:
with torch.cuda.stream(custom_stream): torch.cuda.synchronize(device)
logits=model(cur_token,
position_ids=position_ids,
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False, use_cache=True)[0]
#"""
torch.cuda.synchronize()
#print(logits) #print(logits)
next_token_scores = logits_warper(inputs, logits[:, -1, :]) next_token_scores = logits_warper(inputs, logits[:, -1, :])
if generation_config.do_sample: if generation_config.do_sample:
@ -96,10 +122,11 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
next_token = torch.argmax(next_token_scores, dim=-1) next_token = torch.argmax(next_token_scores, dim=-1)
return next_token return next_token
torch.cuda.set_device(torch_device)
with torch.no_grad(): with torch.no_grad():
stream = TextStreamer(tokenizer) stream = TextStreamer(tokenizer)
past_key_values = StaticCache( past_key_values = StaticCache(
config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = torch_device, dtype = model.dtype config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
) )
cache_position = torch.arange(seq_length, device=torch_device) cache_position = torch.arange(seq_length, device=torch_device)
generated_ids = torch.zeros( generated_ids = torch.zeros(
@ -108,23 +135,22 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int) generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
past_key_values.cur_idx=cache_position past_key_values.cur_idx=cache_position
start_time = time.time() start_time = time.time()
#custom_stream = torch.cuda.Stream()
inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to("cuda") inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
logits = model( logits = model(
inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
)[0][:,-1,:].unsqueeze(0).clone() )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
generation_config, model_kwargs = model._prepare_generation_config( generation_config, model_kwargs = model._prepare_generation_config(
None, max_length=max_new_tokens, None, max_length=max_new_tokens,
do_sample=True, top_k=5, top_p=0.85, temperature=0.1 # change this to modify generate config do_sample=True, top_k=5, top_p=0.85, temperature=0.1 # change this to modify generate config
) )
try: # transformers==4.43 try: # transformers==4.43
logits_warper = ( logits_warper = (
model._get_logits_warper(generation_config,device=inputs.device) if generation_config.do_sample else None model._get_logits_warper(generation_config,device=inputs.device)
) )
except: except:
logits_warper = ( logits_warper = (
model._get_logits_warper(generation_config) if generation_config.do_sample else None model._get_logits_warper(generation_config)
) )
next_token_scores = logits_warper(inputs, logits[:, -1, :]) next_token_scores = logits_warper(inputs, logits[:, -1, :])
if generation_config.do_sample: if generation_config.do_sample:
@ -136,7 +162,6 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
prefill_count = seq_length prefill_count = seq_length
prefill_time = first_token_time prefill_time = first_token_time
print(stream.put(next_token.item()), end="", flush=True) print(stream.put(next_token.item()), end="", flush=True)
generated_ids[:, seq_length] = next_token generated_ids[:, seq_length] = next_token
tokens.append(next_token) tokens.append(next_token)
@ -145,11 +170,15 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
seq_length += 1 seq_length += 1
cuda_graph_runner = CUDAGraphRunner() if use_cuda_graph:
cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, return_dict=False, use_cache=True) cuda_graph_runner = CUDAGraphRunner()
cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
else:
cuda_graph_runner = None
start_time = time.time() start_time = time.time()
for _ in range(1, max_new_tokens): for _ in range(1, max_new_tokens):
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values) next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1) inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
generated_ids[:, cache_position] = next_token.int() generated_ids[:, cache_position] = next_token.int()
tokens.append(next_token.int()) tokens.append(next_token.int())
@ -163,6 +192,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
cache_position += 1 cache_position += 1
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
total_time = time.time() - start_time total_time = time.time() - start_time
tokens_generated = len(tokens) tokens_generated = len(tokens)
tokens_per_second = tokens_generated / total_time tokens_per_second = tokens_generated / total_time
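With the new flag a caller can skip CUDA-graph capture and decode eagerly, which is the safer choice while debugging multi-GPU placements; a usage sketch with model and tokenizer construction omitted (the prompt string is illustrative):

input_ids = tokenizer("Explain the KV cache in one sentence.", return_tensors="pt").input_ids
prefill_and_generate(model, tokenizer, input_ids, max_new_tokens=256, use_cuda_graph=False)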

View file

@ -3,7 +3,8 @@ requires = [
"setuptools", "setuptools",
"torch >= 2.3.0", "torch >= 2.3.0",
"ninja", "ninja",
"packaging" "packaging",
"cpufeature"
] ]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"

View file

@ -6,7 +6,7 @@ Author : chenxl
Date : 2024-07-27 16:15:27 Date : 2024-07-27 16:15:27
Version : 1.0.0 Version : 1.0.0
LastEditors : chenxl LastEditors : chenxl
LastEditTime : 2024-07-31 09:44:46 LastEditTime : 2024-08-08 02:45:15
Adapted from: Adapted from:
https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
Copyright (c) 2023, Tri Dao. Copyright (c) 2023, Tri Dao.
@ -19,6 +19,7 @@ import re
import ast import ast
import subprocess import subprocess
import platform import platform
import shutil
import http.client import http.client
import urllib.request import urllib.request
import urllib.error import urllib.error
@ -27,6 +28,7 @@ from packaging.version import parse
import torch.version import torch.version
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
from setuptools import setup, Extension from setuptools import setup, Extension
from cpufeature.extension import CPUFeature
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
class CpuInstructInfo: class CpuInstructInfo:
@ -67,6 +69,8 @@ class VersionInfo:
""" """
if sys.platform.startswith("linux"): if sys.platform.startswith("linux"):
return f'linux_{platform.uname().machine}' return f'linux_{platform.uname().machine}'
elif sys.platform == "win32":
return "win_amd64"
else: else:
raise ValueError("Unsupported platform: {}".format(sys.platform)) raise ValueError("Unsupported platform: {}".format(sys.platform))
@ -97,6 +101,15 @@ class VersionInfo:
return 'avx2' return 'avx2'
raise ValueError( raise ValueError(
"Unsupported cpu Instructions: {}".format(flags_line)) "Unsupported cpu Instructions: {}".format(flags_line))
elif sys.platform == "win32":
if CPUFeature.get("AVX512bw", False):
return 'fancy'
if CPUFeature.get("AVX512f", False):
return 'avx512'
if CPUFeature.get("AVX2", False):
return 'avx2'
raise ValueError(
"Unsupported cpu Instructions: {}".format(str(CPUFeature)))
else: else:
raise ValueError("Unsupported platform: {}".format(sys.platform)) raise ValueError("Unsupported platform: {}".format(sys.platform))
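On Windows the SIMD probe goes through cpufeature instead of /proc/cpuinfo; the same flags can be inspected standalone, using the import path and flag names from the code above:

from cpufeature.extension import CPUFeature

for flag in ("AVX512bw", "AVX512f", "AVX2"):
    print(flag, CPUFeature.get(flag, False))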
@ -154,7 +167,7 @@ class BuildWheelsCommand(_bdist_wheel):
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
print("Raw wheel path", wheel_path) print("Raw wheel path", wheel_path)
os.rename(wheel_filename, wheel_path) shutil.move(wheel_filename, wheel_path)
except (urllib.error.HTTPError, urllib.error.URLError, http.client.RemoteDisconnected): except (urllib.error.HTTPError, urllib.error.URLError, http.client.RemoteDisconnected):
print("Precompiled wheel not found. Building from source...") print("Precompiled wheel not found. Building from source...")
# If the wheel could not be downloaded, build from source # If the wheel could not be downloaded, build from source

View file

@ -22,7 +22,7 @@
#include <cstring> #include <cstring>
#include <type_traits> #include <type_traits>
#if defined __x86_64__ || defined __aarch64__ #if defined __x86_64__ || defined __aarch64__ || defined(_M_X64)
#include "llama.cpp/ggml-impl.h" #include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h" #include "llama.cpp/ggml-quants.h"
@ -225,7 +225,7 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const voi
return true; return true;
} }
#if defined __x86_64__ #if defined __x86_64__ || defined(_M_X64)
#if defined HAVE_FANCY_SIMD #if defined HAVE_FANCY_SIMD
#undef HAVE_FANCY_SIMD #undef HAVE_FANCY_SIMD
@ -1412,7 +1412,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int) { bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int) {
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00); if (ne00 % ggml_blck_size(GGML_TYPE_Q8_K) == 0)
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);
switch (typeA) { switch (typeA) {
case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K:

View file

@ -3,6 +3,6 @@
// Copyrigth 2024 Iwan Kawrakow. // Copyrigth 2024 Iwan Kawrakow.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__ #if defined(__x86_64__) || defined(_M_X64)
#include "iqk_mul_mat.inc" #include "iqk_mul_mat.inc"
#endif // __x86_64__ #endif // __x86_64__

View file

@ -3,7 +3,7 @@
// Copyrigth 2024 Iwan Kawrakow. // Copyrigth 2024 Iwan Kawrakow.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__ #if defined(__x86_64__) || defined(_M_X64)
#define iqk_mul_mat iqk_mul_mat_zen4 #define iqk_mul_mat iqk_mul_mat_zen4
#define iqk_mul_mat_moe iqk_mul_mat_moe_zen4 #define iqk_mul_mat_moe iqk_mul_mat_moe_zen4
#include "iqk_mul_mat.inc" #include "iqk_mul_mat.inc"

View file

@ -22,19 +22,22 @@
#include "sgemm.h" #include "sgemm.h"
// #include <cosmo.h> // #include <cosmo.h>
#include <cpuid.h> // #include <cpuid.h>
// #include <libc/sysv/consts/hwcap.h> // #include <libc/sysv/consts/hwcap.h>
#include <stdio.h> #include <stdio.h>
#include <sys/auxv.h> // #include <sys/auxv.h>
#include <cassert> #include <cassert>
// #include "llamafile.h" // #include "llamafile.h"
static const struct GemmFuncs { static const struct GemmFuncs {
typeof(llamafile_sgemm)* sgemm; bool (*sgemm)(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
typeof(llamafile_mixmul)* mixmul; bool (*mixmul)(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported; bool (*iqk_mixmul)(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);
// typeof(llamafile_sgemm)* sgemm;
// typeof(llamafile_mixmul)* mixmul;
// typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;
GemmFuncs() { GemmFuncs() {
#ifdef __x86_64__ #if defined(__x86_64__) || defined(_M_X64)
// if (X86_HAVE(AVX)) { // if (X86_HAVE(AVX)) {
// if (X86_HAVE(FMA)) { // if (X86_HAVE(FMA)) {
// if (X86_HAVE(AVX2)) { // if (X86_HAVE(AVX2)) {
@ -86,10 +89,12 @@ static const struct GemmFuncs {
// sgemm = llamafile_sgemm_unsupported; // sgemm = llamafile_sgemm_unsupported;
// mixmul = llamafile_mixmul_unsupported; // mixmul = llamafile_mixmul_unsupported;
// } // }
#if defined(__AVX__) #if defined(__AVX__)
#if defined(__FMA__) #if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))
#if defined(__AVX2__) #if defined(__AVX2__)
#if defined(__AVX512F__) #if defined(__AVX512F__)
printf("__AVX512F__\n");
#if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__) #if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__)
// AMD Zen4+ (2023-) // AMD Zen4+ (2023-)
sgemm = llamafile_sgemm_amd_zen4; sgemm = llamafile_sgemm_amd_zen4;

View file

@ -223,7 +223,7 @@ inline float32x4_t badder(float32x4_t a, float b, float32x4_t c, float32x4_t* e)
} }
#endif #endif
#if defined(__FMA__) #if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <> template <>
inline __m256 madd(__m256 a, __m256 b, __m256 c) { inline __m256 madd(__m256 a, __m256 b, __m256 c) {

View file

@ -3,7 +3,7 @@
// Copyrigth 2024 Mozilla Foundation. // Copyrigth 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__ #if defined(__x86_64__) || defined(_M_X64)
#define llamafile_mixmul llamafile_mixmul_amd_avx #define llamafile_mixmul llamafile_mixmul_amd_avx
#include "tinyblas_cpu_mixmul.inc" #include "tinyblas_cpu_mixmul.inc"

View file

@ -3,7 +3,7 @@
// Copyrigth 2024 Mozilla Foundation. // Copyrigth 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__ #if defined(__x86_64__) || defined(_M_X64)
#define llamafile_mixmul llamafile_mixmul_amd_avx2 #define llamafile_mixmul llamafile_mixmul_amd_avx2
#include "tinyblas_cpu_mixmul.inc" #include "tinyblas_cpu_mixmul.inc"
#endif // __x86_64__ #endif // __x86_64__

View file

@ -3,7 +3,7 @@
// Copyrigth 2024 Mozilla Foundation. // Copyrigth 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__ #if defined(__x86_64__) || defined(_M_X64)
#define llamafile_mixmul llamafile_mixmul_amd_avx512f #define llamafile_mixmul llamafile_mixmul_amd_avx512f
#include "tinyblas_cpu_mixmul.inc" #include "tinyblas_cpu_mixmul.inc"
#endif // __x86_64__ #endif // __x86_64__

View file

@ -3,7 +3,7 @@
// Copyrigth 2024 Mozilla Foundation. // Copyrigth 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__ #if defined(__x86_64__) || defined(_M_X64)
#define llamafile_mixmul llamafile_mixmul_amd_avxvnni #define llamafile_mixmul llamafile_mixmul_amd_avxvnni
#include "tinyblas_cpu_mixmul.inc" #include "tinyblas_cpu_mixmul.inc"
#endif // __x86_64__ #endif // __x86_64__

View file

@ -3,7 +3,7 @@
// Copyrigth 2024 Mozilla Foundation. // Copyrigth 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__ #if defined(__x86_64__) || defined(_M_X64)
#define llamafile_mixmul llamafile_mixmul_amd_fma #define llamafile_mixmul llamafile_mixmul_amd_fma
#include "tinyblas_cpu_mixmul.inc" #include "tinyblas_cpu_mixmul.inc"
#endif // __x86_64__ #endif // __x86_64__

View file

@ -3,7 +3,7 @@
// Copyrigth 2024 Mozilla Foundation. // Copyrigth 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__ #if defined(__x86_64__) || defined(_M_X64)
#define llamafile_mixmul llamafile_mixmul_amd_zen4 #define llamafile_mixmul llamafile_mixmul_amd_zen4
#include "tinyblas_cpu_mixmul.inc" #include "tinyblas_cpu_mixmul.inc"
#endif // __x86_64__ #endif // __x86_64__

View file

@ -321,8 +321,8 @@ bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void
assert(ith < nth); assert(ith < nth);
#if QK_K == 256 #if QK_K == 256
#if defined(__x86_64__) #if defined(__x86_64__) || defined(_M_X64)
#if defined(__AVX2__) && defined(__FMA__) #if defined(__AVX2__) && (defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))))
// if (X86_CHECK(AVX2) && X86_CHECK(FMA)) { // if (X86_CHECK(AVX2) && X86_CHECK(FMA)) {
if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) { if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {
if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) { if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) {

View file

@ -3,7 +3,7 @@
// Copyrigth 2024 Mozilla Foundation. // Copyrigth 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__ #if defined(__x86_64__) || defined(_M_X64)
#define llamafile_sgemm llamafile_sgemm_amd_avx #define llamafile_sgemm llamafile_sgemm_amd_avx
#include "tinyblas_cpu_sgemm.inc" #include "tinyblas_cpu_sgemm.inc"
#endif // __x86_64__ #endif // __x86_64__

View file

@ -3,7 +3,7 @@
// Copyrigth 2024 Mozilla Foundation. // Copyrigth 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__ #if defined(__x86_64__) || defined(_M_X64)
#define llamafile_sgemm llamafile_sgemm_amd_avx2 #define llamafile_sgemm llamafile_sgemm_amd_avx2
#include "tinyblas_cpu_sgemm.inc" #include "tinyblas_cpu_sgemm.inc"
#endif // __x86_64__ #endif // __x86_64__

View file

@ -3,7 +3,7 @@
// Copyrigth 2024 Mozilla Foundation. // Copyrigth 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__ #if defined(__x86_64__) || defined(_M_X64)
#define llamafile_sgemm llamafile_sgemm_amd_avx512f #define llamafile_sgemm llamafile_sgemm_amd_avx512f
#include "tinyblas_cpu_sgemm.inc" #include "tinyblas_cpu_sgemm.inc"
#endif // __x86_64__ #endif // __x86_64__

View file

@ -3,7 +3,7 @@
// Copyrigth 2024 Mozilla Foundation. // Copyrigth 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__ #if defined(__x86_64__) || defined(_M_X64)
#define llamafile_sgemm llamafile_sgemm_amd_avxvnni #define llamafile_sgemm llamafile_sgemm_amd_avxvnni
#include "tinyblas_cpu_sgemm.inc" #include "tinyblas_cpu_sgemm.inc"
#endif // __x86_64__ #endif // __x86_64__

View file

@ -3,7 +3,7 @@
// Copyrigth 2024 Mozilla Foundation. // Copyrigth 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__ #if defined(__x86_64__) || defined(_M_X64)
#define llamafile_sgemm llamafile_sgemm_amd_fma #define llamafile_sgemm llamafile_sgemm_amd_fma
#include "tinyblas_cpu_sgemm.inc" #include "tinyblas_cpu_sgemm.inc"
#endif // __x86_64__ #endif // __x86_64__

View file

@ -3,7 +3,7 @@
// Copyrigth 2024 Mozilla Foundation. // Copyrigth 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__ #if defined(__x86_64__) || defined(_M_X64)
#define llamafile_sgemm llamafile_sgemm_amd_zen4 #define llamafile_sgemm llamafile_sgemm_amd_zen4
#define iqk_mul_mat iqk_mul_mat_zen4 #define iqk_mul_mat iqk_mul_mat_zen4
#include "tinyblas_cpu_sgemm.inc" #include "tinyblas_cpu_sgemm.inc"