Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-05 20:19:51 +00:00
Commit 3986e2d2cf: 31 changed files with 1713 additions and 114 deletions
@@ -23,6 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

<h2 id="Updates">🔥 Updates</h2>

* **Mar 15, 2025**: Support ROCm on AMD GPU ([Tutorial](./doc/en/ROCm.md)).
* **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and [IQ1_S/FP8 hybrid](./doc/en/fp8_kernel.md) weights. Support 139K [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context) for DeepSeek-V3 and R1 in 24GB VRAM.
* **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context).
* **Feb 15, 2025**: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 Tokens/s), update [docs](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/).
@@ -22,6 +22,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

<h2 id="Updates">🔥 Updates</h2>

* **Mar 15, 2025**: Support ROCm on AMD GPU ([Tutorial](./doc/en/ROCm.md)).
* **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and [IQ1_S/FP8 hybrid](./doc/en/fp8_kernel.md) weights. Support 139K [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context) for DeepSeek-V3 and R1 in 24GB VRAM.
* **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context).
* **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~28x speedup. The detailed tutorial is [here](./en/DeepseekR1_V3_tutorial.md).
@@ -10,6 +10,7 @@

- [Injection Tutorial](en/injection_tutorial.md)
- [Multi-GPU Tutorial](en/multi-gpu-tutorial.md)
- [Use FP8 GPU Kernel](en/fp8_kernel.md)
- [Use AMD GPU](en/ROCm.md)

# Server
- [Server](en/api/server/server.md)
- [Website](en/api/server/website.md)
doc/en/ROCm.md (new file, 96 lines)
@@ -0,0 +1,96 @@
# ROCm Support for ktransformers (Beta)

## Introduction

### Overview
In our effort to expand GPU architecture support beyond NVIDIA, we are excited to introduce **AMD GPU support through ROCm** in ktransformers (Beta release). This implementation has been tested and developed using EPYC 9274F processors and AMD Radeon 7900xtx GPUs.

## Installation Guide

### 1. Install ROCm Driver
Begin by installing the ROCm drivers for your AMD GPU:
- [Official ROCm Installation Guide for Radeon GPUs](https://rocm.docs.amd.com/projects/radeon/en/latest/docs/install/native_linux/install-radeon.html)

### 2. Set Up Conda Environment
We recommend using Miniconda3/Anaconda3 for environment management:

```bash
# Download Miniconda
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh

# Create environment
conda create --name ktransformers python=3.11
conda activate ktransformers

# Install required libraries
conda install -c conda-forge libstdcxx-ng

# Verify GLIBCXX version (should include 3.4.32)
strings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX
```

> **Note:** Adjust the Anaconda path if your installation directory differs from `~/anaconda3`.
### 3. Install PyTorch for ROCm
Install PyTorch with ROCm 6.2.4 support:

```bash
pip3 install torch torchvision torchaudio \
    --index-url https://download.pytorch.org/whl/rocm6.2.4
pip3 install packaging ninja cpufeature numpy
```

> **Tip:** For other ROCm versions, visit [PyTorch Previous Versions](https://pytorch.org/get-started/previous-versions/).
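After installation, it is worth confirming that this PyTorch build actually sees the AMD GPU. A minimal sanity check (not part of the original guide; on ROCm builds the `torch.cuda` API is backed by HIP):

```python
import torch

print(torch.__version__)              # a ROCm wheel reports e.g. "...+rocm6.2.4"
print(torch.version.hip)              # ROCm/HIP version string; None on CUDA-only builds
print(torch.cuda.is_available())      # True once the ROCm driver and permissions are set up
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # e.g. the Radeon GPU name
```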
### 4. Build ktransformers

```bash
# Clone repository
git clone https://github.com/kvcache-ai/ktransformers.git
cd ktransformers
git submodule update --init

# Optional: Compile web interface
# See: api/server/website.md

# Install dependencies
bash install.sh
```
## Running DeepSeek-R1 Models

### Configuration for 24GB VRAM GPUs
Use our optimized configuration for constrained VRAM:

```bash
python ktransformers/local_chat.py \
    --model_path deepseek-ai/DeepSeek-R1 \
    --gguf_path <path_to_gguf_files> \
    --optimize_config_path ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml \
    --cpu_infer <cpu_cores + 1>
```

> **Beta Note:** The current Q8 linear implementation (the alternative to Marlin on ROCm) shows suboptimal performance. Expect optimizations in future releases.
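The `--cpu_infer` placeholder above is written as `<cpu_cores + 1>`, i.e. the CPU core count plus one. A quick way to look the counts up (a sketch; `psutil` is an extra dependency, not required by ktransformers):

```python
import os
import psutil  # optional helper for physical-core counts

physical = psutil.cpu_count(logical=False) or os.cpu_count()  # physical cores (fallback: logical)
logical = os.cpu_count()                                      # hardware threads
print(f"physical cores: {physical}, logical CPUs: {logical}")
print(f"suggested --cpu_infer: {physical + 1}")
```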
### Configuration for 40GB+ VRAM GPUs
For better performance on high-VRAM GPUs:

1. Modify `DeepSeek-V3-Chat.yaml`, replacing all instances of `KLinearMarlin` with `KLinearTorch`.

2. Execute with:
   ```bash
   python ktransformers/local_chat.py \
       --model_path deepseek-ai/DeepSeek-R1 \
       --gguf_path <path_to_gguf_files> \
       --optimize_config_path <modified_yaml_path> \
       --cpu_infer <cpu_cores + 1>
   ```

> **Tip:** If you have two 24GB AMD GPUs, you can apply the same modification to `ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml` and run with that file instead. A scripted version of the substitution is sketched below.
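The substitution in step 1 can also be scripted rather than edited by hand. A minimal sketch, assuming the stock ROCm rule file lives at the path used earlier in this guide (the output filename is just an illustrative choice, not an existing file):

```python
from pathlib import Path

# Copy an optimize-rule file, swapping KLinearMarlin for KLinearTorch.
src = Path("ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml")
dst = Path("ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat-torch.yaml")  # hypothetical name

dst.write_text(src.read_text().replace("KLinearMarlin", "KLinearTorch"))
print(f"wrote {dst}; pass it via --optimize_config_path")
```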
## Known Limitations
- Marlin operations not supported on ROCm platform
- Current Q8 linear implementation shows reduced performance (Beta limitation)
@@ -32,6 +32,7 @@ endif()
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF)
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)

# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
@@ -201,6 +202,31 @@ endif()
#     message(STATUS "Can't found CUDA lib")
# endif()

if (NOT EXISTS $ENV{ROCM_PATH})
    if (NOT EXISTS /opt/rocm)
        set(ROCM_PATH /usr)
    else()
        set(ROCM_PATH /opt/rocm)
    endif()
else()
    set(ROCM_PATH $ENV{ROCM_PATH})
endif()

list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake")

if (NOT EXISTS $ENV{MUSA_PATH})
    if (NOT EXISTS /opt/musa)
        set(MUSA_PATH /usr/local/musa)
    else()
        set(MUSA_PATH /opt/musa)
    endif()
else()
    set(MUSA_PATH $ENV{MUSA_PATH})
endif()

list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")

add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
add_compile_options("$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>")
@@ -218,6 +244,14 @@ elseif (UNIX)
        add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
    endif()

    if (KTRANSFORMERS_USE_ROCM)
        find_package(HIP REQUIRED)
        if(HIP_FOUND)
            include_directories("${HIP_INCLUDE_DIRS}")
            add_compile_definitions(KTRANSFORMERS_USE_ROCM=1)
        endif()
    endif()

    if (KTRANSFORMERS_USE_MUSA)
        if (NOT EXISTS $ENV{MUSA_PATH})
            if (NOT EXISTS /opt/musa)
@@ -258,6 +292,11 @@ elseif(UNIX)
    endif()
    target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
endif()
if (KTRANSFORMERS_USE_ROCM)
    add_compile_definitions(USE_HIP=1)
    target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so")
    message(STATUS "Building for HIP")
endif()
if(KTRANSFORMERS_USE_MUSA)
    target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
endif()
@ -7,79 +7,83 @@
|
|||
* @LastEditTime : 2024-08-07 09:47:43
|
||||
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
**/
|
||||
#ifndef CPUINFER_CPUINFER_H
|
||||
#define CPUINFER_CPUINFER_H
|
||||
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#ifdef KTRANSFORMERS_USE_CUDA
|
||||
#include "vendors/cuda.h"
|
||||
#elif KTRANSFORMERS_USE_MUSA
|
||||
#include "vendors/musa.h"
|
||||
#endif
|
||||
|
||||
#include "backend.h"
|
||||
#include "task_queue.h"
|
||||
|
||||
#include "llama.cpp/ggml-impl.h"
|
||||
|
||||
class CPUInfer {
|
||||
public:
|
||||
CPUInfer(int thread_num) {
|
||||
backend_ = new Backend(thread_num - 1);
|
||||
task_queue_ = new TaskQueue();
|
||||
for (int i = 0; i < (1 << 16); ++i) {
|
||||
ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(i);
|
||||
}
|
||||
}
|
||||
|
||||
~CPUInfer() {
|
||||
delete backend_;
|
||||
delete task_queue_;
|
||||
}
|
||||
|
||||
template <typename Func, typename Obj, typename... Args>
|
||||
void enqueue(Func f, Obj* obj, Args... args) {
|
||||
task_queue_->enqueue([=]() {
|
||||
std::invoke(f, *obj, args..., backend_);
|
||||
});
|
||||
}
|
||||
|
||||
void submit(std::pair<intptr_t, intptr_t> params) {
|
||||
void (*func)(void*) = (void (*)(void*))params.first;
|
||||
void* args = (void*)params.second;
|
||||
*((CPUInfer**)args) = this;
|
||||
func(args);
|
||||
}
|
||||
|
||||
void sync() {
|
||||
task_queue_->sync();
|
||||
}
|
||||
|
||||
void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair<intptr_t, intptr_t> params) {
|
||||
void (*func)(void*) = (void (*)(void*))params.first;
|
||||
void* args = (void*)params.second;
|
||||
*((CPUInfer**)args) = this;
|
||||
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
|
||||
}
|
||||
|
||||
static void sync_(void* cpu_infer_ptr) {
|
||||
CPUInfer* cpuinfer = (CPUInfer*)cpu_infer_ptr;
|
||||
cpuinfer->sync();
|
||||
}
|
||||
|
||||
void sync_with_cuda_stream(intptr_t user_cuda_stream) {
|
||||
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);
|
||||
}
|
||||
|
||||
public:
|
||||
Backend* backend_;
|
||||
TaskQueue* task_queue_;
|
||||
};
|
||||
|
||||
#endif
|
||||
#ifndef CPUINFER_CPUINFER_H
|
||||
#define CPUINFER_CPUINFER_H
|
||||
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#ifdef KTRANSFORMERS_USE_CUDA
|
||||
#include "vendors/cuda.h"
|
||||
#elif KTRANSFORMERS_USE_MUSA
|
||||
#include "vendors/musa.h"
|
||||
#elif KTRANSFORMERS_USE_ROCM
|
||||
#define __HIP_PLATFORM_AMD__
|
||||
#include "vendors/hip.h"
|
||||
#endif
|
||||
|
||||
#include "backend.h"
|
||||
#include "task_queue.h"
|
||||
#include "../vendors/vendor.h"
|
||||
|
||||
#include "llama.cpp/ggml-impl.h"
|
||||
|
||||
class CPUInfer {
|
||||
public:
|
||||
CPUInfer(int thread_num) {
|
||||
backend_ = new Backend(thread_num - 1);
|
||||
task_queue_ = new TaskQueue();
|
||||
for (int i = 0; i < (1 << 16); ++i) {
|
||||
ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(i);
|
||||
}
|
||||
}
|
||||
|
||||
~CPUInfer() {
|
||||
delete backend_;
|
||||
delete task_queue_;
|
||||
}
|
||||
|
||||
template <typename Func, typename Obj, typename... Args>
|
||||
void enqueue(Func f, Obj* obj, Args... args) {
|
||||
task_queue_->enqueue([=]() {
|
||||
std::invoke(f, *obj, args..., backend_);
|
||||
});
|
||||
}
|
||||
|
||||
void submit(std::pair<intptr_t, intptr_t> params) {
|
||||
void (*func)(void*) = (void (*)(void*))params.first;
|
||||
void* args = (void*)params.second;
|
||||
*((CPUInfer**)args) = this;
|
||||
func(args);
|
||||
}
|
||||
|
||||
void sync() {
|
||||
task_queue_->sync();
|
||||
}
|
||||
|
||||
void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair<intptr_t, intptr_t> params) {
|
||||
void (*func)(void*) = (void (*)(void*))params.first;
|
||||
void* args = (void*)params.second;
|
||||
*((CPUInfer**)args) = this;
|
||||
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
|
||||
}
|
||||
|
||||
static void sync_(void* cpu_infer_ptr) {
|
||||
CPUInfer* cpuinfer = (CPUInfer*)cpu_infer_ptr;
|
||||
cpuinfer->sync();
|
||||
}
|
||||
|
||||
void sync_with_cuda_stream(intptr_t user_cuda_stream) {
|
||||
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);
|
||||
}
|
||||
|
||||
public:
|
||||
Backend* backend_;
|
||||
TaskQueue* task_queue_;
|
||||
};
|
||||
|
||||
#endif
|
|
@ -1,3 +1,15 @@
|
|||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#if CUDART_VERSION < 11020
|
||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
||||
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
|
||||
#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
||||
#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
||||
#define cublasComputeType_t cudaDataType_t
|
||||
#endif // CUDART_VERSION < 11020
|
||||
|
|
ktransformers/ktransformers_ext/cpu_backend/vendors/hip.h (new vendored file, 172 lines)
@@ -0,0 +1,172 @@
#pragma once
|
||||
|
||||
#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hipblas/hipblas.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_bfloat16.h>
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
// for rocblas_initialize()
|
||||
#include "rocblas/rocblas.h"
|
||||
#endif // __HIP_PLATFORM_AMD__
|
||||
|
||||
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
|
||||
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
|
||||
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
||||
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_OP_N HIPBLAS_OP_N
|
||||
#define CUBLAS_OP_T HIPBLAS_OP_T
|
||||
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
||||
#define CUBLAS_TF32_TENSOR_OP_MATH 0
|
||||
#define CUDA_R_16F HIPBLAS_R_16F
|
||||
#define CUDA_R_32F HIPBLAS_R_32F
|
||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
|
||||
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
|
||||
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
|
||||
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
|
||||
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
|
||||
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
|
||||
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
|
||||
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
||||
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
|
||||
#define cublasCreate hipblasCreate
|
||||
#define cublasDestroy hipblasDestroy
|
||||
#define cublasGemmEx hipblasGemmEx
|
||||
#define cublasGemmBatchedEx hipblasGemmBatchedEx
|
||||
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
|
||||
#define cublasHandle_t hipblasHandle_t
|
||||
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
||||
#define cublasSetStream hipblasSetStream
|
||||
#define cublasSgemm hipblasSgemm
|
||||
#define cublasStatus_t hipblasStatus_t
|
||||
#define cublasOperation_t hipblasOperation_t
|
||||
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
|
||||
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||
#define cudaDeviceProp hipDeviceProp_t
|
||||
#define cudaDeviceSynchronize hipDeviceSynchronize
|
||||
#define cudaError_t hipError_t
|
||||
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
||||
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
|
||||
#define cudaEventCreateWithFlags hipEventCreateWithFlags
|
||||
#define cudaEventDisableTiming hipEventDisableTiming
|
||||
#define cudaEventRecord hipEventRecord
|
||||
#define cudaEventSynchronize hipEventSynchronize
|
||||
#define cudaEvent_t hipEvent_t
|
||||
#define cudaEventDestroy hipEventDestroy
|
||||
#define cudaFree hipFree
|
||||
#define cudaFreeHost hipHostFree
|
||||
#define cudaGetDevice hipGetDevice
|
||||
#define cudaGetDeviceCount hipGetDeviceCount
|
||||
#define cudaGetDeviceProperties hipGetDeviceProperties
|
||||
#define cudaGetErrorString hipGetErrorString
|
||||
#define cudaGetLastError hipGetLastError
|
||||
#define cudaHostRegister hipHostRegister
|
||||
#define cudaHostRegisterPortable hipHostRegisterPortable
|
||||
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
|
||||
#define cudaHostUnregister hipHostUnregister
|
||||
#define cudaLaunchHostFunc hipLaunchHostFunc
|
||||
#define cudaMalloc hipMalloc
|
||||
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
|
||||
#define cudaMemcpy hipMemcpy
|
||||
#define cudaMemcpyAsync hipMemcpyAsync
|
||||
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
|
||||
#define cudaMemcpy2DAsync hipMemcpy2DAsync
|
||||
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
|
||||
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
|
||||
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
|
||||
#define cudaMemcpyKind hipMemcpyKind
|
||||
#define cudaMemset hipMemset
|
||||
#define cudaMemsetAsync hipMemsetAsync
|
||||
#define cudaMemGetInfo hipMemGetInfo
|
||||
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
|
||||
#define cudaSetDevice hipSetDevice
|
||||
#define cuDeviceGet hipDeviceGet
|
||||
#define CUdevice hipDevice_t
|
||||
#define CUdeviceptr hipDeviceptr_t
|
||||
#define cuMemUnmap hipMemUnmap
|
||||
#define CUmemAccessDesc hipMemAccessDesc
|
||||
#define cuMemAddressFree hipMemAddressFree
|
||||
#define cuMemRelease hipMemRelease
|
||||
#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
|
||||
#define cuMemCreate hipMemCreate
|
||||
#define cuMemAddressReserve hipMemAddressReserve
|
||||
#define cuMemMap hipMemMap
|
||||
#define cuMemSetAccess hipMemSetAccess
|
||||
#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
|
||||
#define CUmemAllocationProp hipMemAllocationProp
|
||||
#define cuDeviceGetAttribute hipDeviceGetAttribute
|
||||
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
|
||||
#define cudaStreamDestroy hipStreamDestroy
|
||||
#define cudaStreamFireAndForget hipStreamFireAndForget
|
||||
#define cudaStreamNonBlocking hipStreamNonBlocking
|
||||
#define cudaStreamPerThread hipStreamPerThread
|
||||
#define cudaStreamSynchronize hipStreamSynchronize
|
||||
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
|
||||
#define cudaGraphExec_t hipGraphExec_t
|
||||
#define cudaGraphNode_t hipGraphNode_t
|
||||
#define cudaKernelNodeParams hipKernelNodeParams
|
||||
#define cudaKernelNodeParams hipKernelNodeParams
|
||||
#define cudaGraphExecDestroy hipGraphExecDestroy
|
||||
#define cudaGraphLaunch hipGraphLaunch
|
||||
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
|
||||
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
|
||||
#define cudaGraphNodeType hipGraphNodeType
|
||||
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
|
||||
#define cudaGraphInstantiate hipGraphInstantiate
|
||||
#define cudaStreamEndCapture hipStreamEndCapture
|
||||
#define cudaGraphDestroy hipGraphDestroy
|
||||
#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
|
||||
#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
|
||||
#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
|
||||
#define cudaGraphNodeGetType hipGraphNodeGetType
|
||||
#define cudaGraphGetNodes hipGraphGetNodes
|
||||
#define cudaGraphExecUpdate hipGraphExecUpdate
|
||||
#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
|
||||
#define cudaStreamBeginCapture hipStreamBeginCapture
|
||||
#define cudaGraph_t hipGraph_t
|
||||
#define cudaStream_t hipStream_t
|
||||
#define cudaSuccess hipSuccess
|
||||
#define cudaHostFn_t hipHostFn_t
|
||||
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
|
||||
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
||||
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
|
||||
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
|
||||
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
|
||||
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
|
||||
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
|
||||
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
|
||||
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
||||
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
||||
|
||||
#define __CUDA_ARCH__ 1300
|
||||
|
||||
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
|
||||
#define GCN
|
||||
#endif
|
||||
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
|
||||
#define CDNA
|
||||
#endif
|
||||
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
||||
defined(__gfx1150__) || defined(__gfx1151__)
|
||||
#define RDNA3
|
||||
#endif
|
||||
|
||||
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
|
||||
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
|
||||
#define RDNA2
|
||||
#endif
|
||||
|
||||
#if defined(__gfx1010__) || defined(__gfx1012__)
|
||||
#define RDNA1
|
||||
#endif
|
||||
|
||||
#ifndef __has_builtin
|
||||
#define __has_builtin(x) 0
|
||||
#endif
|
||||
|
||||
typedef hip_bfloat16 nv_bfloat16;
|
|
@ -1,9 +1,137 @@
|
|||
#pragma once
|
||||
|
||||
#include <musa_runtime.h>
|
||||
#include <musa.h>
|
||||
#include <mublas.h>
|
||||
#include <musa_bf16.h>
|
||||
|
||||
#include <musa_fp16.h>
|
||||
#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
||||
#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
||||
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
|
||||
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_OP_N MUBLAS_OP_N
|
||||
#define CUBLAS_OP_T MUBLAS_OP_T
|
||||
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
|
||||
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
|
||||
#define CUDA_R_16F MUSA_R_16F
|
||||
#define CUDA_R_32F MUSA_R_32F
|
||||
#define cublasComputeType_t cudaDataType_t
|
||||
#define cublasCreate mublasCreate
|
||||
#define cublasDestroy mublasDestroy
|
||||
#define cublasGemmEx mublasGemmEx
|
||||
#define cublasGemmBatchedEx mublasGemmBatchedEx
|
||||
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
|
||||
#define cublasHandle_t mublasHandle_t
|
||||
#define cublasSetMathMode mublasSetMathMode
|
||||
#define cublasSetStream mublasSetStream
|
||||
#define cublasSgemm mublasSgemm
|
||||
#define cublasStatus_t mublasStatus_t
|
||||
#define cublasOperation_t mublasOperation_t
|
||||
#define cublasGetStatusString mublasStatus_to_string
|
||||
#define cudaDataType_t musaDataType_t
|
||||
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
|
||||
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
|
||||
#define cudaDeviceProp musaDeviceProp
|
||||
#define cudaDeviceSynchronize musaDeviceSynchronize
|
||||
#define cudaError_t musaError_t
|
||||
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
|
||||
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
|
||||
#define cudaEventCreateWithFlags musaEventCreateWithFlags
|
||||
#define cudaEventDisableTiming musaEventDisableTiming
|
||||
#define cudaEventRecord musaEventRecord
|
||||
#define cudaEventSynchronize musaEventSynchronize
|
||||
#define cudaEvent_t musaEvent_t
|
||||
#define cudaEventDestroy musaEventDestroy
|
||||
#define cudaFree musaFree
|
||||
#define cudaFreeHost musaFreeHost
|
||||
#define cudaGetDevice musaGetDevice
|
||||
#define cudaGetDeviceCount musaGetDeviceCount
|
||||
#define cudaGetDeviceProperties musaGetDeviceProperties
|
||||
#define cudaGetErrorString musaGetErrorString
|
||||
#define cudaGetLastError musaGetLastError
|
||||
#define cudaHostRegister musaHostRegister
|
||||
#define cudaHostRegisterPortable musaHostRegisterPortable
|
||||
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
|
||||
#define cudaHostUnregister musaHostUnregister
|
||||
#define cudaLaunchHostFunc musaLaunchHostFunc
|
||||
#define cudaMalloc musaMalloc
|
||||
#define cudaMallocHost musaMallocHost
|
||||
#define cudaMallocManaged musaMallocManaged
|
||||
#define cudaMemcpy musaMemcpy
|
||||
#define cudaMemcpyAsync musaMemcpyAsync
|
||||
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
|
||||
#define cudaMemcpy2DAsync musaMemcpy2DAsync
|
||||
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
|
||||
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
|
||||
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
|
||||
#define cudaMemcpyKind musaMemcpyKind
|
||||
#define cudaMemset musaMemset
|
||||
#define cudaMemsetAsync musaMemsetAsync
|
||||
#define cudaMemGetInfo musaMemGetInfo
|
||||
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
|
||||
#define cudaSetDevice musaSetDevice
|
||||
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
|
||||
#define cudaStreamDestroy musaStreamDestroy
|
||||
#define cudaStreamFireAndForget musaStreamFireAndForget
|
||||
#define cudaStreamNonBlocking musaStreamNonBlocking
|
||||
#define cudaStreamPerThread musaStreamPerThread
|
||||
#define cudaStreamSynchronize musaStreamSynchronize
|
||||
#define cudaStreamWaitEvent musaStreamWaitEvent
|
||||
#define cudaStream_t musaStream_t
|
||||
#define cudaHostFn_t musaHostFn_t
|
||||
#define nv_bfloat16 mt_bfloat16
|
||||
#define cudaSuccess musaSuccess
|
||||
|
||||
// Additional mappings for MUSA virtual memory pool
|
||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
||||
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
|
||||
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
|
||||
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
|
||||
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
|
||||
#define CUdevice MUdevice
|
||||
#define CUdeviceptr MUdeviceptr
|
||||
#define CUmemAccessDesc MUmemAccessDesc
|
||||
#define CUmemAllocationProp MUmemAllocationProp
|
||||
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
|
||||
#define cuDeviceGet muDeviceGet
|
||||
#define cuDeviceGetAttribute muDeviceGetAttribute
|
||||
#define cuMemAddressFree muMemAddressFree
|
||||
#define cuMemAddressReserve muMemAddressReserve
|
||||
#define cuMemCreate muMemCreate
|
||||
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
|
||||
#define cuMemMap muMemMap
|
||||
#define cuMemRelease muMemRelease
|
||||
#define cuMemSetAccess muMemSetAccess
|
||||
#define cuMemUnmap muMemUnmap
|
||||
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
|
||||
#define cudaFuncSetAttribute musaFuncSetAttribute
|
||||
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
|
||||
#define make_cudaExtent make_musaExtent
|
||||
#define make_cudaPitchedPtr make_musaPitchedPtr
|
||||
|
||||
// Additional mappings for MUSA graphs
|
||||
#define CUDA_SUCCESS MUSA_SUCCESS
|
||||
#define CUresult MUresult
|
||||
#define cuGetErrorString muGetErrorString
|
||||
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
|
||||
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
|
||||
#define cudaGraphDestroy musaGraphDestroy
|
||||
#define cudaGraphExecDestroy musaGraphExecDestroy
|
||||
#define cudaGraphExec_t musaGraphExec_t
|
||||
#define cudaGraphExecUpdate musaGraphExecUpdate
|
||||
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
|
||||
#define cudaGraphGetNodes musaGraphGetNodes
|
||||
#define cudaGraphInstantiate musaGraphInstantiate
|
||||
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
|
||||
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
|
||||
#define cudaGraphLaunch musaGraphLaunch
|
||||
#define cudaGraphNodeGetType musaGraphNodeGetType
|
||||
#define cudaGraphNode_t musaGraphNode_t
|
||||
#define cudaGraphNodeType musaGraphNodeType
|
||||
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
|
||||
#define cudaGraph_t musaGraph_t
|
||||
#define cudaKernelNodeParams musaKernelNodeParams
|
||||
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
||||
#define cudaStreamEndCapture musaStreamEndCapture
|
||||
|
||||
typedef mt_bfloat16 nv_bfloat16;
|
||||
|
|
ktransformers/ktransformers_ext/cpu_backend/vendors/vendor.h (new vendored file, 13 lines)
@@ -0,0 +1,13 @@
#ifndef CPUINFER_VENDOR_VENDOR_H
#define CPUINFER_VENDOR_VENDOR_H

#ifdef USE_CUDA
#include "cuda.h"
#elif USE_HIP
#define __HIP_PLATFORM_AMD__
#include "hip.h"
#elif USE_MUSA
#include "musa.h"
#endif

#endif // CPUINFER_VENDOR_VENDOR_H
@ -15,6 +15,7 @@
|
|||
#include <torch/torch.h>
|
||||
#include <cstdint>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
typedef hip_bfloat16 nv_bfloat16;
|
||||
|
||||
__global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
|
||||
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
|
|
@ -36,7 +36,7 @@ inline std::string str(T x) {
|
|||
|
||||
namespace gptq_marlin {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__)
|
||||
|
||||
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
int const* __restrict__ perm_int_ptr,
|
||||
|
|
|
@ -39,7 +39,7 @@ using I4 = Vec<int, 4>;
|
|||
|
||||
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined (__HIP_PLATFORM_AMD__)
|
||||
// No support for async
|
||||
#else
|
||||
|
||||
|
|
|
@ -8,6 +8,11 @@
|
|||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
typedef __hip_bfloat16 nv_bfloat16;
|
||||
typedef __hip_bfloat162 nv_bfloat162;
|
||||
#endif
|
||||
|
||||
namespace gptq_marlin {
|
||||
|
||||
template <typename scalar_t>
|
||||
|
|
|
@ -9,7 +9,6 @@
|
|||
**/
|
||||
// Python bindings
|
||||
#include "cpu_backend/cpuinfer.h"
|
||||
#include "device_launch_parameters.h"
|
||||
#include "llamafile/flags.h"
|
||||
#include "operators/kvcache/kvcache.h"
|
||||
#include "operators/llamafile/linear.h"
|
||||
|
|
ktransformers/ktransformers_ext/vendors/cuda.h (new vendored file, 15 lines)
@@ -0,0 +1,15 @@
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#if CUDART_VERSION < 11020
|
||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
||||
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
|
||||
#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
||||
#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
||||
#define cublasComputeType_t cudaDataType_t
|
||||
#endif // CUDART_VERSION < 11020
|
ktransformers/ktransformers_ext/vendors/hip.h (new vendored file, 172 lines)
@@ -0,0 +1,172 @@
#pragma once
|
||||
|
||||
#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hipblas/hipblas.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_bfloat16.h>
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
// for rocblas_initialize()
|
||||
#include "rocblas/rocblas.h"
|
||||
#endif // __HIP_PLATFORM_AMD__
|
||||
|
||||
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
|
||||
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
|
||||
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
||||
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_OP_N HIPBLAS_OP_N
|
||||
#define CUBLAS_OP_T HIPBLAS_OP_T
|
||||
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
||||
#define CUBLAS_TF32_TENSOR_OP_MATH 0
|
||||
#define CUDA_R_16F HIPBLAS_R_16F
|
||||
#define CUDA_R_32F HIPBLAS_R_32F
|
||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
|
||||
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
|
||||
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
|
||||
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
|
||||
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
|
||||
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
|
||||
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
|
||||
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
||||
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
|
||||
#define cublasCreate hipblasCreate
|
||||
#define cublasDestroy hipblasDestroy
|
||||
#define cublasGemmEx hipblasGemmEx
|
||||
#define cublasGemmBatchedEx hipblasGemmBatchedEx
|
||||
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
|
||||
#define cublasHandle_t hipblasHandle_t
|
||||
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
||||
#define cublasSetStream hipblasSetStream
|
||||
#define cublasSgemm hipblasSgemm
|
||||
#define cublasStatus_t hipblasStatus_t
|
||||
#define cublasOperation_t hipblasOperation_t
|
||||
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
|
||||
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||
#define cudaDeviceProp hipDeviceProp_t
|
||||
#define cudaDeviceSynchronize hipDeviceSynchronize
|
||||
#define cudaError_t hipError_t
|
||||
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
||||
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
|
||||
#define cudaEventCreateWithFlags hipEventCreateWithFlags
|
||||
#define cudaEventDisableTiming hipEventDisableTiming
|
||||
#define cudaEventRecord hipEventRecord
|
||||
#define cudaEventSynchronize hipEventSynchronize
|
||||
#define cudaEvent_t hipEvent_t
|
||||
#define cudaEventDestroy hipEventDestroy
|
||||
#define cudaFree hipFree
|
||||
#define cudaFreeHost hipHostFree
|
||||
#define cudaGetDevice hipGetDevice
|
||||
#define cudaGetDeviceCount hipGetDeviceCount
|
||||
#define cudaGetDeviceProperties hipGetDeviceProperties
|
||||
#define cudaGetErrorString hipGetErrorString
|
||||
#define cudaGetLastError hipGetLastError
|
||||
#define cudaHostRegister hipHostRegister
|
||||
#define cudaHostRegisterPortable hipHostRegisterPortable
|
||||
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
|
||||
#define cudaHostUnregister hipHostUnregister
|
||||
#define cudaLaunchHostFunc hipLaunchHostFunc
|
||||
#define cudaMalloc hipMalloc
|
||||
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
|
||||
#define cudaMemcpy hipMemcpy
|
||||
#define cudaMemcpyAsync hipMemcpyAsync
|
||||
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
|
||||
#define cudaMemcpy2DAsync hipMemcpy2DAsync
|
||||
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
|
||||
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
|
||||
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
|
||||
#define cudaMemcpyKind hipMemcpyKind
|
||||
#define cudaMemset hipMemset
|
||||
#define cudaMemsetAsync hipMemsetAsync
|
||||
#define cudaMemGetInfo hipMemGetInfo
|
||||
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
|
||||
#define cudaSetDevice hipSetDevice
|
||||
#define cuDeviceGet hipDeviceGet
|
||||
#define CUdevice hipDevice_t
|
||||
#define CUdeviceptr hipDeviceptr_t
|
||||
#define cuMemUnmap hipMemUnmap
|
||||
#define CUmemAccessDesc hipMemAccessDesc
|
||||
#define cuMemAddressFree hipMemAddressFree
|
||||
#define cuMemRelease hipMemRelease
|
||||
#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
|
||||
#define cuMemCreate hipMemCreate
|
||||
#define cuMemAddressReserve hipMemAddressReserve
|
||||
#define cuMemMap hipMemMap
|
||||
#define cuMemSetAccess hipMemSetAccess
|
||||
#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
|
||||
#define CUmemAllocationProp hipMemAllocationProp
|
||||
#define cuDeviceGetAttribute hipDeviceGetAttribute
|
||||
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
|
||||
#define cudaStreamDestroy hipStreamDestroy
|
||||
#define cudaStreamFireAndForget hipStreamFireAndForget
|
||||
#define cudaStreamNonBlocking hipStreamNonBlocking
|
||||
#define cudaStreamPerThread hipStreamPerThread
|
||||
#define cudaStreamSynchronize hipStreamSynchronize
|
||||
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
|
||||
#define cudaGraphExec_t hipGraphExec_t
|
||||
#define cudaGraphNode_t hipGraphNode_t
|
||||
#define cudaKernelNodeParams hipKernelNodeParams
|
||||
#define cudaKernelNodeParams hipKernelNodeParams
|
||||
#define cudaGraphExecDestroy hipGraphExecDestroy
|
||||
#define cudaGraphLaunch hipGraphLaunch
|
||||
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
|
||||
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
|
||||
#define cudaGraphNodeType hipGraphNodeType
|
||||
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
|
||||
#define cudaGraphInstantiate hipGraphInstantiate
|
||||
#define cudaStreamEndCapture hipStreamEndCapture
|
||||
#define cudaGraphDestroy hipGraphDestroy
|
||||
#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
|
||||
#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
|
||||
#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
|
||||
#define cudaGraphNodeGetType hipGraphNodeGetType
|
||||
#define cudaGraphGetNodes hipGraphGetNodes
|
||||
#define cudaGraphExecUpdate hipGraphExecUpdate
|
||||
#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
|
||||
#define cudaStreamBeginCapture hipStreamBeginCapture
|
||||
#define cudaGraph_t hipGraph_t
|
||||
#define cudaStream_t hipStream_t
|
||||
#define cudaSuccess hipSuccess
|
||||
#define cudaHostFn_t hipHostFn_t
|
||||
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
|
||||
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
||||
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
|
||||
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
|
||||
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
|
||||
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
|
||||
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
|
||||
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
|
||||
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
||||
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
||||
|
||||
#define __CUDA_ARCH__ 1300
|
||||
|
||||
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
|
||||
#define GCN
|
||||
#endif
|
||||
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
|
||||
#define CDNA
|
||||
#endif
|
||||
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
||||
defined(__gfx1150__) || defined(__gfx1151__)
|
||||
#define RDNA3
|
||||
#endif
|
||||
|
||||
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
|
||||
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
|
||||
#define RDNA2
|
||||
#endif
|
||||
|
||||
#if defined(__gfx1010__) || defined(__gfx1012__)
|
||||
#define RDNA1
|
||||
#endif
|
||||
|
||||
#ifndef __has_builtin
|
||||
#define __has_builtin(x) 0
|
||||
#endif
|
||||
|
||||
typedef hip_bfloat16 nv_bfloat16;
|
ktransformers/ktransformers_ext/vendors/musa.h (new vendored file, 137 lines)
@@ -0,0 +1,137 @@
#pragma once
|
||||
|
||||
#include <musa_runtime.h>
|
||||
#include <musa.h>
|
||||
#include <mublas.h>
|
||||
#include <musa_bf16.h>
|
||||
#include <musa_fp16.h>
|
||||
#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
||||
#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
||||
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
|
||||
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_OP_N MUBLAS_OP_N
|
||||
#define CUBLAS_OP_T MUBLAS_OP_T
|
||||
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
|
||||
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
|
||||
#define CUDA_R_16F MUSA_R_16F
|
||||
#define CUDA_R_32F MUSA_R_32F
|
||||
#define cublasComputeType_t cudaDataType_t
|
||||
#define cublasCreate mublasCreate
|
||||
#define cublasDestroy mublasDestroy
|
||||
#define cublasGemmEx mublasGemmEx
|
||||
#define cublasGemmBatchedEx mublasGemmBatchedEx
|
||||
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
|
||||
#define cublasHandle_t mublasHandle_t
|
||||
#define cublasSetMathMode mublasSetMathMode
|
||||
#define cublasSetStream mublasSetStream
|
||||
#define cublasSgemm mublasSgemm
|
||||
#define cublasStatus_t mublasStatus_t
|
||||
#define cublasOperation_t mublasOperation_t
|
||||
#define cublasGetStatusString mublasStatus_to_string
|
||||
#define cudaDataType_t musaDataType_t
|
||||
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
|
||||
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
|
||||
#define cudaDeviceProp musaDeviceProp
|
||||
#define cudaDeviceSynchronize musaDeviceSynchronize
|
||||
#define cudaError_t musaError_t
|
||||
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
|
||||
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
|
||||
#define cudaEventCreateWithFlags musaEventCreateWithFlags
|
||||
#define cudaEventDisableTiming musaEventDisableTiming
|
||||
#define cudaEventRecord musaEventRecord
|
||||
#define cudaEventSynchronize musaEventSynchronize
|
||||
#define cudaEvent_t musaEvent_t
|
||||
#define cudaEventDestroy musaEventDestroy
|
||||
#define cudaFree musaFree
|
||||
#define cudaFreeHost musaFreeHost
|
||||
#define cudaGetDevice musaGetDevice
|
||||
#define cudaGetDeviceCount musaGetDeviceCount
|
||||
#define cudaGetDeviceProperties musaGetDeviceProperties
|
||||
#define cudaGetErrorString musaGetErrorString
|
||||
#define cudaGetLastError musaGetLastError
|
||||
#define cudaHostRegister musaHostRegister
|
||||
#define cudaHostRegisterPortable musaHostRegisterPortable
|
||||
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
|
||||
#define cudaHostUnregister musaHostUnregister
|
||||
#define cudaLaunchHostFunc musaLaunchHostFunc
|
||||
#define cudaMalloc musaMalloc
|
||||
#define cudaMallocHost musaMallocHost
|
||||
#define cudaMallocManaged musaMallocManaged
|
||||
#define cudaMemcpy musaMemcpy
|
||||
#define cudaMemcpyAsync musaMemcpyAsync
|
||||
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
|
||||
#define cudaMemcpy2DAsync musaMemcpy2DAsync
|
||||
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
|
||||
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
|
||||
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
|
||||
#define cudaMemcpyKind musaMemcpyKind
|
||||
#define cudaMemset musaMemset
|
||||
#define cudaMemsetAsync musaMemsetAsync
|
||||
#define cudaMemGetInfo musaMemGetInfo
|
||||
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
|
||||
#define cudaSetDevice musaSetDevice
|
||||
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
|
||||
#define cudaStreamDestroy musaStreamDestroy
|
||||
#define cudaStreamFireAndForget musaStreamFireAndForget
|
||||
#define cudaStreamNonBlocking musaStreamNonBlocking
|
||||
#define cudaStreamPerThread musaStreamPerThread
|
||||
#define cudaStreamSynchronize musaStreamSynchronize
|
||||
#define cudaStreamWaitEvent musaStreamWaitEvent
|
||||
#define cudaStream_t musaStream_t
|
||||
#define cudaSuccess musaSuccess
|
||||
|
||||
// Additional mappings for MUSA virtual memory pool
|
||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
||||
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
|
||||
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
|
||||
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
|
||||
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
|
||||
#define CUdevice MUdevice
|
||||
#define CUdeviceptr MUdeviceptr
|
||||
#define CUmemAccessDesc MUmemAccessDesc
|
||||
#define CUmemAllocationProp MUmemAllocationProp
|
||||
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
|
||||
#define cuDeviceGet muDeviceGet
|
||||
#define cuDeviceGetAttribute muDeviceGetAttribute
|
||||
#define cuMemAddressFree muMemAddressFree
|
||||
#define cuMemAddressReserve muMemAddressReserve
|
||||
#define cuMemCreate muMemCreate
|
||||
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
|
||||
#define cuMemMap muMemMap
|
||||
#define cuMemRelease muMemRelease
|
||||
#define cuMemSetAccess muMemSetAccess
|
||||
#define cuMemUnmap muMemUnmap
|
||||
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
|
||||
#define cudaFuncSetAttribute musaFuncSetAttribute
|
||||
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
|
||||
#define make_cudaExtent make_musaExtent
|
||||
#define make_cudaPitchedPtr make_musaPitchedPtr
|
||||
|
||||
// Additional mappings for MUSA graphs
|
||||
#define CUDA_SUCCESS MUSA_SUCCESS
|
||||
#define CUresult MUresult
|
||||
#define cuGetErrorString muGetErrorString
|
||||
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
|
||||
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
|
||||
#define cudaGraphDestroy musaGraphDestroy
|
||||
#define cudaGraphExecDestroy musaGraphExecDestroy
|
||||
#define cudaGraphExec_t musaGraphExec_t
|
||||
#define cudaGraphExecUpdate musaGraphExecUpdate
|
||||
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
|
||||
#define cudaGraphGetNodes musaGraphGetNodes
|
||||
#define cudaGraphInstantiate musaGraphInstantiate
|
||||
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
|
||||
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
|
||||
#define cudaGraphLaunch musaGraphLaunch
|
||||
#define cudaGraphNodeGetType musaGraphNodeGetType
|
||||
#define cudaGraphNode_t musaGraphNode_t
|
||||
#define cudaGraphNodeType musaGraphNodeType
|
||||
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
|
||||
#define cudaGraph_t musaGraph_t
|
||||
#define cudaKernelNodeParams musaKernelNodeParams
|
||||
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
||||
#define cudaStreamEndCapture musaStreamEndCapture
|
||||
|
||||
typedef mt_bfloat16 nv_bfloat16;
|
ktransformers/ktransformers_ext/vendors/vendor.h (new vendored file, 13 lines)
@@ -0,0 +1,13 @@
#ifndef CPUINFER_VENDOR_VENDOR_H
#define CPUINFER_VENDOR_VENDOR_H

#ifdef USE_CUDA
#include "cuda.h"
#elif USE_HIP
#define __HIP_PLATFORM_AMD__
#include "hip.h"
#elif USE_MUSA
#include "musa.h"
#endif

#endif // CPUINFER_VENDOR_VENDOR_H
|
@ -31,6 +31,7 @@ from ktransformers.models.modeling_mixtral import MixtralForCausalLM
|
|||
from ktransformers.util.utils import prefill_and_generate, get_compute_capability
|
||||
from ktransformers.server.config.config import Config
|
||||
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
|
||||
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
|
||||
|
||||
custom_models = {
|
||||
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
|
||||
|
@ -169,7 +170,7 @@ def local_chat(
|
|||
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
|
||||
"please change max_seq_len in ~/.ktransformers/config.yaml"
|
||||
|
||||
if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
|
||||
if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
|
||||
generated = prefill_and_generate(
|
||||
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
|
||||
use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
|
||||
|
|
|
@ -20,8 +20,14 @@ from ktransformers.util.utils import get_compute_capability
|
|||
import logging
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.cache_utils import Cache
|
||||
from flash_attn import flash_attn_func
|
||||
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
|
||||
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
except:
|
||||
pass
|
||||
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
|
||||
from ktransformers.operators.triton_attention_prefill import context_attention_fwd
|
||||
import os
|
||||
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
|
||||
if flashinfer_enabled:
|
||||
|
@ -319,18 +325,27 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
|
||||
|
||||
value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
|
||||
value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
|
||||
|
||||
attn_output = flash_attn_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states_padded,
|
||||
softmax_scale=self.softmax_scale,
|
||||
causal=True,
|
||||
# for bsz = 1
|
||||
attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device)
|
||||
b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device)
|
||||
b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device)
|
||||
|
||||
max_input_len = q_len
|
||||
|
||||
context_attention_fwd(
|
||||
q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
|
||||
k=key_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
|
||||
v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim),
|
||||
o=attn_output,
|
||||
b_start_loc=b_start_loc,
|
||||
b_seq_len=b_seq_len,
|
||||
max_input_len=max_input_len,
|
||||
is_causal=True
|
||||
)
|
||||
|
||||
if self.q_head_dim != self.v_head_dim:
|
||||
attn_output = attn_output[:, :, :, : self.v_head_dim]
|
||||
attn_output = attn_output[:, :, : self.v_head_dim]
|
||||
|
||||
attn_output = attn_output.reshape(
|
||||
bsz, q_len, self.num_heads * self.v_head_dim
|
||||
|
@ -589,8 +604,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if os.name == 'nt' or get_compute_capability()<8:
|
||||
print("for Windows or GPU before ampere, use forward_windows")
|
||||
if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
|
||||
return self.forward_windows(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
|
|
|
@ -17,7 +17,10 @@ import logging
|
|||
logger = logging.getLogger("dynamic_attention")
|
||||
sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend")
|
||||
from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache
|
||||
from flash_attn import flash_attn_func, flash_attn_with_kvcache
|
||||
try:
|
||||
from flash_attn import flash_attn_func, flash_attn_with_kvcache
|
||||
except:
|
||||
print("falsh attn not found")
|
||||
|
||||
|
||||
import math
|
||||
|
|
|
@ -35,6 +35,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext
|
|||
import cpuinfer_ext
|
||||
from ktransformers.operators.cpuinfer import CPUInfer
|
||||
from ktransformers.server.config.config import Config
|
||||
from typing import Dict, Tuple, Optional, Union
|
||||
import numpy as np
|
||||
|
||||
#class KLinearBase(BaseInjectedModule, ABC):
|
||||
class KLinearBase(ABC):
|
||||
|
@ -176,16 +178,182 @@ class KLinearTorch(KLinearBase):
|
|||
if self.has_bias:
|
||||
self.bias = None
|
||||
|
||||
|
||||
class KLinearQ8(KLinearBase):
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
gguf_loader: GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module = None,
|
||||
device: str = "cuda",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
|
||||
self.has_bias = False
|
||||
self.compute_dtype = torch.float32
|
||||
self.weight = None
|
||||
self.weight_scale = None
|
||||
self.weight_zero_point = None
|
||||
self.bias = None
|
||||
self.loaded = False
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
orig_dtype = x.dtype
|
||||
out_device = x.device
|
||||
|
||||
x = x.to(device=self.device, dtype=self.compute_dtype)
|
||||
|
||||
# Use the original weights for the matrix multiplication, mimicking the original behavior

# Dequantize the weights and do the matrix multiplication
|
||||
weight_dequant = self._dequantize_weight(self.weight, self.weight_scale, bits=8)
|
||||
out = x @ weight_dequant.T
|
||||
|
||||
if self.has_bias:
|
||||
out = out + self.bias
|
||||
|
||||
return out.to(dtype=orig_dtype, device=out_device)
|
||||
|
||||
def _dequantize_weight(self, q_matrix, scales, bits=8):
|
||||
"""
|
||||
Dequantize a low-precision matrix back to floating-point
|
||||
|
||||
Args:
|
||||
q_matrix (torch.Tensor): Quantized int matrix
|
||||
scales (torch.Tensor): Scale factors for each column
|
||||
bits (int): Quantization bits used (8 or 4)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Dequantized floating-point matrix
|
||||
"""
|
||||
# Ensure inputs are torch tensors
|
||||
if not isinstance(q_matrix, torch.Tensor):
|
||||
q_matrix = torch.tensor(q_matrix, dtype=torch.int8)
|
||||
if not isinstance(scales, torch.Tensor):
|
||||
scales = torch.tensor(scales, dtype=torch.float32)
|
||||
|
||||
# Convert to correct dtype if needed
|
||||
if q_matrix.dtype != torch.int8:
|
||||
q_matrix = q_matrix.to(torch.int8)
|
||||
if scales.dtype != torch.float32:
|
||||
scales = scales.to(torch.float32)
|
||||
|
||||
# For Q4, ensure the values stay within 4-bit range
|
||||
if bits == 4:
|
||||
q_matrix = torch.clamp(q_matrix, -7, 7)
|
||||
rows, cols = q_matrix.shape
|
||||
dequant_matrix = q_matrix.to(torch.float32)
|
||||
scales_broadcast = scales.view(1, cols)
|
||||
# Apply dequantization to all columns at once using matrix multiplication
|
||||
dequant_matrix = dequant_matrix * scales_broadcast
|
||||
|
||||
return dequant_matrix
|
||||
|
||||
|
||||
def _quantize_weight(self, matrix, bits=8):
|
||||
"""
|
||||
Quantize a floating-point matrix to lower precision (Q8 or Q4)
|
||||
|
||||
Args:
|
||||
matrix (torch.Tensor): Input matrix in floating-point format
|
||||
bits (int): Quantization bits, either 8 or 4
|
||||
|
||||
Returns:
|
||||
tuple: (quantized int matrix, scale factors for each column)
|
||||
"""
|
||||
if not isinstance(matrix, torch.Tensor):
|
||||
matrix = torch.tensor(matrix, dtype=torch.float32)
|
||||
|
||||
# Convert to float32 if needed
|
||||
if matrix.dtype != torch.float32:
|
||||
matrix = matrix.to(torch.float32)
|
||||
|
||||
# Get matrix shape
|
||||
rows, cols = matrix.shape
|
||||
|
||||
# Determine quantization parameters based on bits
|
||||
if bits == 8:
|
||||
max_int = 127
|
||||
qtype = torch.int8
|
||||
elif bits == 4:
|
||||
max_int = 7
|
||||
qtype = torch.int8 # We'll still use int8 storage but limit to 4-bit range, wait for native support
|
||||
else:
|
||||
raise ValueError("Quantization bits must be either 8 or 4")
|
||||
|
||||
scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)
|
||||
|
||||
# Calculate max absolute value for each column
|
||||
max_abs_vals, _ = torch.max(torch.abs(matrix), dim=0)
|
||||
|
||||
# Handle zero columns (avoid division by zero)
|
||||
zero_cols = max_abs_vals == 0
|
||||
max_abs_vals[zero_cols] = 1.0
|
||||
|
||||
# Calculate scale factors for all columns at once
|
||||
scales = max_abs_vals / max_int
|
||||
|
||||
# Prepare the scales for broadcasting [1, cols]
|
||||
scales_broadcast = scales.view(1, cols)
|
||||
|
||||
# Apply quantization to the entire matrix at once
|
||||
q_matrix = torch.round(matrix / scales_broadcast).to(qtype)
|
||||
|
||||
# For Q4, clamp values to ensure they stay within 4-bit range
|
||||
if bits == 4:
|
||||
q_matrix = torch.clamp(q_matrix, -max_int, max_int)
|
||||
|
||||
return q_matrix, scales
|
||||
|
||||
def load(self, w: Union[Dict, nn.Parameter, Tuple, None] = None, device: Optional[str] = None):
|
||||
if self.loaded: return
|
||||
if device is None: device = self.device
|
||||
if w is None: w = self.load_weight(device=device)
|
||||
|
||||
if isinstance(w, nn.Parameter):
|
||||
try:
|
||||
weight = w.to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
|
||||
except:
|
||||
weight = w.to(dtype=self.compute_dtype)
|
||||
self.has_bias = False
|
||||
elif isinstance(w, tuple):
|
||||
try:
|
||||
weight = w[0].to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
|
||||
except:
|
||||
weight = w[0].to(dtype=self.compute_dtype)
|
||||
self.bias = w[1].to(dtype=self.compute_dtype).to(device)
|
||||
self.has_bias = True
|
||||
else:
|
||||
raise ValueError("Invalid weight type")
|
||||
|
||||
self.weight, self.weight_scale = self._quantize_weight(weight, bits=8)
|
||||
|
||||
self.weight = self.weight.to(device)
|
||||
self.weight_scale = self.weight_scale.to(device)
|
||||
|
||||
if self.has_bias:
|
||||
self.bias = self.bias.to(device)
|
||||
|
||||
self.loaded = True
|
||||
|
||||
def unload(self):
|
||||
self.weight = None
|
||||
self.weight_scale = None
|
||||
self.weight_zero_point = None
|
||||
self._orig_weight = None
|
||||
|
||||
if self.has_bias:
|
||||
self.bias = None
|
||||
|
||||
self.loaded = False
|
||||
|
||||
|
||||
class KLinearFP8(KLinearBase):
|
||||
# this kernel requires special handling for weight
|
||||
# Please load the weight file downloaded from KVCache.AI
|
||||
marlin_q_w: torch.Tensor
|
||||
marlin_s: torch.Tensor
|
||||
g_idx: torch.Tensor
|
||||
sort_indices: torch.Tensor
|
||||
has_bias: bool
|
||||
weight: torch.Tensor
|
||||
scale_w: torch.Tensor
|
||||
bias: torch.Tensor
|
||||
def __init__(
|
||||
self,
|
||||
|
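For orientation, the following is a minimal, self-contained sketch of the per-column symmetric int8 round-trip that `KLinearQ8._quantize_weight` and `_dequantize_weight` implement above. The tensor shape and helper names are illustrative assumptions, not code from the repository.

```python
import torch

def quantize_per_column(matrix: torch.Tensor, max_int: int = 127):
    # One scale per column: scale = max(|column|) / 127, with all-zero columns guarded.
    max_abs = matrix.abs().amax(dim=0)
    max_abs = torch.where(max_abs == 0, torch.ones_like(max_abs), max_abs)
    scales = max_abs / max_int
    q = torch.round(matrix / scales.view(1, -1)).to(torch.int8)
    return q, scales

def dequantize_per_column(q: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # Inverse mapping: broadcast the per-column scales back over the int8 matrix.
    return q.to(torch.float32) * scales.view(1, -1)

weight = torch.randn(256, 128)              # illustrative [out_features, in_features]
q, scales = quantize_per_column(weight)
recovered = dequantize_per_column(q, scales)
print((weight - recovered).abs().max())     # per-element error is bounded by roughly scale / 2
```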
@@ -468,6 +636,7 @@ LINEAR_MAP = {
    "KLinearTorch": KLinearTorch,
    "KLinearCPUInfer": KLinearCPUInfer,
    "KLinearFP8": KLinearFP8,
    "KLinearQ8": KLinearQ8,
}

class KTransformersLinear(BaseInjectedModule, KLinearBase):
@@ -53,6 +53,7 @@ from ktransformers.models.modeling_deepseek import (
    DeepseekV2DecoderLayer,
    DeepseekV2MoE,
)
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.operators.base_operator import BaseInjectedModule
@@ -649,8 +650,8 @@ class KDeepseekV2Model(BaseInjectedModule):
        if per_layer_prefill_flag:
            causal_mask = None
        else:
            if os.name == 'nt' or get_compute_capability()<8:
                print("for Windows or GPU before ampere, use forward_windows")
            if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
                # print("for Windows or GPU before ampere, use forward_windows")
                # only use mask in forward windows or can't flash attn
                causal_mask = self._update_causal_mask(
                    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
@@ -673,6 +674,7 @@ class KDeepseekV2Model(BaseInjectedModule):
        t_f = 0

        for i, decoder_layer in enumerate(self.layers):
            # print(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n")
            if self.transfer_map is not None and i in self.transfer_map:
                prev_stream = torch.cuda.current_stream()
                cur_device = self.transfer_map[i]
@@ -6,7 +6,7 @@

import triton
import triton.language as tl

from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
@@ -181,8 +181,8 @@ def _decode_grouped_att_m_fwd(
    # [TODO] work around shmem limit on MI3xx

    # TODO: support hip
    #if is_hip_ and Lk >= 576:
    # BLOCK = 16
    if device_manager.gpu_vendor == GPUVendor.AMD and Lk >= 576:
        BLOCK = 16

    if Lk == 576:
        BLOCK_DMODEL = 512
ktransformers/operators/triton_attention_prefill.py (new file, 206 lines)
@@ -0,0 +1,206 @@
# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1

"""
Memory-efficient attention for prefill.
It supports page size = 1.
"""

# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
import torch
import triton
import triton.language as tl

is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
    CUDA_CAPABILITY = torch.cuda.get_device_capability()


@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    Out,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_obs,
    stride_oh,
    kv_group_num: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    Lk: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
    off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]

    mask_d = offs_d < Lk

    q = tl.load(
        Q + off_q,
        mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
        other=0.0,
    )

    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

    end_n = (
        cur_batch_seq_len
        if not IS_CAUSAL
        else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)
    )
    for start_n in range(0, block_mask * end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
            other=0.0,
        )
        # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale

        if IS_CAUSAL:
            qk += tl.where(
                (start_n + offs_n[None, :] < cur_batch_seq_len)
                & (offs_m[:, None] >= (start_n + offs_n[None, :])),
                0,
                float("-inf"),
            )
        else:
            qk += tl.where(
                (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
            )

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
            other=0.0,
        )

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_d[None, :]
    )
    out_ptrs = Out + off_o
    tl.store(
        out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])
    )


def context_attention_fwd(
    q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
):
    """
    q, k, v: [b * s, head, head_dim]
    b_start_loc: [b]
    b_seq_len: [b]
    out: [b * s, head, head_dim]
    """
    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
        BLOCK = 128
    else:
        BLOCK = 64

    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]

    sm_scale = 1.0 / (Lq**0.5)
    batch, head = b_seq_len.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k.shape[1]

    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
    num_warps = 4 if Lk <= 64 else 8

    _fwd_kernel[grid](
        q,
        k,
        v,
        sm_scale,
        b_start_loc,
        b_seq_len,
        o,
        q.stride(0),
        q.stride(1),
        k.stride(0),
        k.stride(1),
        v.stride(0),
        v.stride(1),
        o.stride(0),
        o.stride(1),
        kv_group_num=kv_group_num,
        BLOCK_M=BLOCK,
        BLOCK_DMODEL=triton.next_power_of_2(Lk),
        BLOCK_N=BLOCK,
        IS_CAUSAL=is_causal,
        num_warps=num_warps,
        num_stages=1,
        Lk=Lk,
    )
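For readers unfamiliar with the packed (varlen) layout the new prefill kernel expects, here is a small invocation sketch of `context_attention_fwd`; the head count, head dimension, dtype and sequence lengths are illustrative assumptions.

```python
import torch
from ktransformers.operators.triton_attention_prefill import context_attention_fwd

# Two sequences of length 3 and 5, packed along the token dimension (b * s = 8).
num_heads, head_dim, total_tokens = 16, 64, 8
q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
o = torch.empty_like(q)

b_start_loc = torch.tensor([0, 3], device="cuda", dtype=torch.int32)  # start offset of each sequence
b_seq_len = torch.tensor([3, 5], device="cuda", dtype=torch.int32)    # length of each sequence

context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, int(b_seq_len.max()), is_causal=True)
# o now holds the causal attention output in the same packed layout as q.
```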
@@ -22,7 +22,7 @@
  replace:
    class: ktransformers.operators.linear.KTransformersLinear
    kwargs:
      generate_device: "cuda"
      generate_device: "cpu"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
@@ -0,0 +1,76 @@
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    name: "^lm_head$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearCPUInfer"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cpu"
      prefill_device: "cuda"
      generate_op: "KLinearQ8"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert parallelism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      absorb_for_prefill: False  # change this to True to enable long context (prefill may be slower)
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
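The third rule above applies `KTransformersLinear` with the `KLinearQ8` generate path to every `model.layers` linear module except `self_attn.kv_b_proj`; the negative lookahead in the name regex is what carves out that exception. A quick sketch of how it matches (module names below are illustrative):

```python
import re

pattern = re.compile(r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$")
for name in [
    "model.layers.0.mlp.gate_proj",
    "model.layers.3.self_attn.q_proj",
    "model.layers.3.self_attn.kv_b_proj",
]:
    print(name, bool(pattern.match(name)))
# model.layers.0.mlp.gate_proj True
# model.layers.3.self_attn.q_proj True
# model.layers.3.self_attn.kv_b_proj False   <- excluded, stays on its original path
```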
ktransformers/tests/test_pytorch_q8.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import torch

# Define a floating-point model that contains a linear layer
class LinearModel(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.linear(x)

# Create the floating-point model instance
in_features = 64
out_features = 128
model_fp32 = LinearModel(in_features, out_features)

# Create the quantized model instance
model_int8 = torch.ao.quantization.quantize_dynamic(
    model_fp32,          # the original floating-point model
    {torch.nn.Linear},   # the set of layer types to quantize
    dtype=torch.qint8    # the target data type for quantization
)

# Test the model
batch_size = 32
input_fp32 = torch.randn(1, batch_size, in_features)  # generate random input data
output_int8 = model_int8(input_fp32)  # run the data through the quantized model

# Print the output shapes to verify
print(f"Input shape: {input_fp32.shape}")
print(f"Output shape: {output_int8.shape}")

# Compare the outputs of the original and the quantized model
with torch.no_grad():
    output_fp32 = model_fp32(input_fp32)

print(f"First few values of the FP32 output: {output_fp32[0, :5]}")
print(f"First few values of the INT8 output: {output_int8[0, :5]}")

# Compute the mean error
error = torch.abs(output_fp32 - output_int8).mean().item()
print(f"Mean absolute error: {error}")

# Print the model type information
print(f"Model type before quantization: {type(model_fp32.linear)}")
print(f"Model type after quantization: {type(model_int8.linear)}")
ktransformers/util/vendors.py (new file, 202 lines)
@@ -0,0 +1,202 @@
from __future__ import annotations

from enum import IntEnum, auto
from typing import Optional, Union, List
import torch

class GPUVendor(IntEnum):
    NVIDIA = auto()
    AMD = auto()
    MooreThreads = auto()
    MetaX = auto()
    MUSA = auto()
    Unknown = auto()

class DeviceManager:
    """
    Device manager that provides a unified interface for handling different GPU vendors
    """
    def __init__(self):
        self.gpu_vendor = self._detect_gpu_vendor()
        self.available_devices = self._get_available_devices()

    def _detect_gpu_vendor(self) -> GPUVendor:
        """Detect GPU vendor type"""
        if not torch.cuda.is_available():
            # Check MUSA availability (assuming a musa module exists)
            try:
                import musa
                if musa.is_available():
                    return GPUVendor.MUSA
            except (ImportError, AttributeError):
                pass

            return GPUVendor.Unknown

        device_name = torch.cuda.get_device_name(0).lower()

        if any(name in device_name for name in ["nvidia", "geforce", "quadro", "tesla", "titan", "rtx", "gtx"]):
            return GPUVendor.NVIDIA
        elif any(name in device_name for name in ["amd", "radeon", "rx", "vega", "instinct", "firepro", "mi"]):
            return GPUVendor.AMD
        elif any(name in device_name for name in ["mthreads", "moore", "mtt"]):
            return GPUVendor.MooreThreads
        elif any(name in device_name for name in ["metax", "meta"]):
            return GPUVendor.MetaX
        elif "musa" in device_name:
            return GPUVendor.MUSA

        # Backend check
        try:
            if hasattr(torch.version, 'hip') and torch.version.hip is not None:
                return GPUVendor.AMD
            elif hasattr(torch.version, 'cuda') and torch.version.cuda is not None:
                return GPUVendor.NVIDIA
        except:
            pass

        return GPUVendor.Unknown

    def _get_available_devices(self) -> List[int]:
        """Get list of available device indices"""
        devices = []

        if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:
            devices = list(range(torch.cuda.device_count()))
        elif self.gpu_vendor == GPUVendor.MUSA:
            try:
                import musa
                devices = list(range(musa.device_count()))
            except (ImportError, AttributeError):
                pass

        return devices

    def get_device_str(self, device_id: Union[int, str]) -> str:
        """
        Get device string for the given device ID

        Args:
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            Device string representation (e.g., "cuda:0", "musa:1", "cpu")
        """
        if device_id == -1 or device_id == "cpu":
            return "cpu"

        if isinstance(device_id, int):
            if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:
                if device_id < torch.cuda.device_count():
                    return f"cuda:{device_id}"
            elif self.gpu_vendor == GPUVendor.MUSA:
                try:
                    import musa
                    if device_id < musa.device_count():
                        return f"musa:{device_id}"
                except (ImportError, AttributeError):
                    pass

        return "cpu"

    def to_torch_device(self, device_id: Union[int, str] = 0) -> torch.device:
        """
        Convert device ID to torch.device object

        Args:
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            torch.device object
        """
        device_str = self.get_device_str(device_id)

        # Handle MUSA device
        if device_str.startswith("musa:"):
            try:
                import musa
                index = int(device_str.split(":")[-1])
                return musa.device(index)
            except (ImportError, ValueError, AttributeError):
                return torch.device("cpu")

        # Standard PyTorch device
        return torch.device(device_str)

    def move_tensor_to_device(self, tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
        """
        Move tensor to specified device

        Args:
            tensor: PyTorch tensor to move
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            Tensor moved to the specified device
        """
        device = self.to_torch_device(device_id)
        return tensor.to(device)

    def is_available(self, index: int = 0) -> bool:
        """
        Check if device at specified index is available

        Args:
            index: Device index to check

        Returns:
            True if the device is available, False otherwise
        """
        if index < 0:
            return True  # CPU is always available

        return index in self.available_devices

    def get_all_devices(self) -> List[int]:
        """
        Get all available device indices

        Returns:
            List of available device indices (0, 1, 2, etc.)
        """
        return self.available_devices

# Create global device manager instance
device_manager = DeviceManager()

# Convenience functions
def get_device(device_id: Union[int, str] = 0) -> torch.device:
    """
    Get torch.device object for the specified device ID

    Args:
        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

    Returns:
        torch.device object
    """
    return device_manager.to_torch_device(device_id)

def to_device(tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
    """
    Move tensor to specified device

    Args:
        tensor: PyTorch tensor to move
        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

    Returns:
        Tensor moved to the specified device
    """
    return device_manager.move_tensor_to_device(tensor, device_id)

# Get devices
cpu_device = get_device(-1)      # CPU using index -1
cpu_device2 = get_device("cpu")  # CPU using string "cpu"
gpu0 = get_device(0)             # First GPU

# Move tensors
x = torch.randn(3, 3)
x_gpu = to_device(x, 0)        # Move to first GPU
x_cpu1 = to_device(x, -1)      # Move to CPU using index -1
x_cpu2 = to_device(x, "cpu")   # Move to CPU using string "cpu"
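As a rough illustration of how this new vendor detection is consumed elsewhere in the commit (see the `device_manager.gpu_vendor != GPUVendor.NVIDIA` checks in the attention and model diffs above), here is a hedged sketch; `pick_attention_path` is a hypothetical helper, not a function in the repository:

```python
from ktransformers.util.vendors import device_manager, GPUVendor

def pick_attention_path() -> str:
    # Non-NVIDIA GPUs (for example AMD/ROCm) fall back to the Triton-based path,
    # since the flash-attn kernels are only assumed usable on NVIDIA Ampere or newer.
    if device_manager.gpu_vendor != GPUVendor.NVIDIA:
        return "triton_prefill"
    return "flash_attn"

print(device_manager.gpu_vendor, pick_attention_path())
```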
setup.py
@@ -29,7 +29,7 @@ import torch.version
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
from setuptools import setup, Extension
from cpufeature.extension import CPUFeature
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
try:
    from torch_musa.utils.simple_porting import SimplePorting
    from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
@@ -64,6 +64,70 @@ class VersionInfo:
        musa_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
        return musa_version

    def get_rocm_bare_metal_version(self, rocm_dir):
        """
        Get the ROCm version from the ROCm installation directory.

        Args:
            rocm_dir: Path to the ROCm installation directory

        Returns:
            A string representation of the ROCm version (e.g., "63" for ROCm 6.3)
        """
        try:
            # Try using rocminfo to get version info
            raw_output = subprocess.check_output(
                [rocm_dir + "/bin/rocminfo", "--version"],
                universal_newlines=True,
                stderr=subprocess.STDOUT)
            # Extract version number from output
            match = re.search(r'(\d+\.\d+)', raw_output)
            if match:
                version_str = match.group(1)
                version = parse(version_str)
                rocm_version = f"{version.major}{version.minor}"
                return rocm_version
        except (subprocess.CalledProcessError, FileNotFoundError):
            # If rocminfo --version fails, try alternative methods
            pass

        try:
            # Try reading version from release file
            with open(os.path.join(rocm_dir, "share/doc/hip/version.txt"), "r") as f:
                version_str = f.read().strip()
                version = parse(version_str)
                rocm_version = f"{version.major}{version.minor}"
                return rocm_version
        except (FileNotFoundError, IOError):
            pass

        # If all else fails, try to extract from directory name
        dir_name = os.path.basename(os.path.normpath(rocm_dir))
        match = re.search(r'rocm-(\d+\.\d+)', dir_name)
        if match:
            version_str = match.group(1)
            version = parse(version_str)
            rocm_version = f"{version.major}{version.minor}"
            return rocm_version

        # Fallback to extracting from hipcc version
        try:
            raw_output = subprocess.check_output(
                [rocm_dir + "/bin/hipcc", "--version"],
                universal_newlines=True,
                stderr=subprocess.STDOUT)
            match = re.search(r'HIP version: (\d+\.\d+)', raw_output)
            if match:
                version_str = match.group(1)
                version = parse(version_str)
                rocm_version = f"{version.major}{version.minor}"
                return rocm_version
        except (subprocess.CalledProcessError, FileNotFoundError):
            pass

        # If we still can't determine the version, raise an error
        raise ValueError(f"Could not determine ROCm version from directory: {rocm_dir}")

    def get_cuda_bare_metal_version(self, cuda_dir):
        raw_output = subprocess.check_output(
            [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
@@ -148,11 +212,13 @@ class VersionInfo:
        cpu_instruct = self.get_cpu_instruct()
        backend_version = ""
        if CUDA_HOME is not None:
            backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
            backend_version = f""
        elif MUSA_HOME is not None:
            backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
        elif ROCM_HOME is not None:
            backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}"
        else:
            raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
            raise ValueError("Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set.")
        package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
        if full_version:
            return package_version
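To make the version-tag plumbing concrete, here is an illustrative assembly of the wheel's local version string on a ROCm build; every concrete value below is a stand-in, not taken from the repository:

```python
flash_version = "0.2.4"      # hypothetical ktransformers version
backend_version = "rocm63"   # "rocm" + get_rocm_bare_metal_version(ROCM_HOME) for ROCm 6.3
torch_version = "2.7"        # hypothetical torch major.minor
cpu_instruct = "avx2"        # hypothetical CPU instruction-set tag
print(f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}")
# -> 0.2.4+rocm63torch2.7avx2
```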
@@ -247,9 +313,13 @@ class CMakeBuild(BuildExtension):
            cmake_args += ["-DKTRANSFORMERS_USE_CUDA=ON"]
        elif MUSA_HOME is not None:
            cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
        elif ROCM_HOME is not None:
            cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
        else:
            raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")

        # log cmake_args
        print("CMake args:", cmake_args)

        build_args = []
        if "CMAKE_ARGS" in os.environ:
            cmake_args += [
@@ -328,7 +398,7 @@ class CMakeBuild(BuildExtension):
            ["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
        )

if CUDA_HOME is not None:
if CUDA_HOME is not None or ROCM_HOME is not None:
    ops_module = CUDAExtension('KTransformersOps', [
        'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
        'ktransformers/ktransformers_ext/cuda/binding.cpp',
@@ -338,7 +408,7 @@ if CUDA_HOME is not None:
            'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
            'nvcc': [
                '-O3',
                '--use_fast_math',
                # '--use_fast_math',
                '-Xcompiler', '-fPIC',
                '-DKTRANSFORMERS_USE_CUDA',
            ]
@@ -371,6 +441,7 @@ else:
    raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")

setup(
    name=VersionInfo.PACKAGE_NAME,
    version=VersionInfo().get_package_version(),
    cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
    ext_modules=[