From 21fca5a326097de6629098ede47357b899868010 Mon Sep 17 00:00:00 2001 From: fxzjshm Date: Thu, 13 Feb 2025 00:58:59 +0800 Subject: [PATCH 1/6] Add compat layer from llama.cpp Signed-off-by: fxzjshm --- .../ktransformers_ext/CMakeLists.txt | 41 ++++- .../ktransformers_ext/cpu_backend/cpuinfer.h | 2 +- .../ktransformers_ext/ext_bindings.cpp | 1 - .../ktransformers_ext/vendors/cuda.h | 15 ++ ktransformers/ktransformers_ext/vendors/hip.h | 172 ++++++++++++++++++ .../ktransformers_ext/vendors/musa.h | 137 ++++++++++++++ .../ktransformers_ext/vendors/vendor.h | 13 ++ 7 files changed, 376 insertions(+), 5 deletions(-) create mode 100644 ktransformers/ktransformers_ext/vendors/cuda.h create mode 100644 ktransformers/ktransformers_ext/vendors/hip.h create mode 100644 ktransformers/ktransformers_ext/vendors/musa.h create mode 100644 ktransformers/ktransformers_ext/vendors/vendor.h diff --git a/ktransformers/ktransformers_ext/CMakeLists.txt b/ktransformers/ktransformers_ext/CMakeLists.txt index d9ecd7a..30e1e6e 100644 --- a/ktransformers/ktransformers_ext/CMakeLists.txt +++ b/ktransformers/ktransformers_ext/CMakeLists.txt @@ -198,6 +198,19 @@ endif() # message(STATUS "Can't found CUDA lib") # endif() +if (NOT EXISTS $ENV{ROCM_PATH}) + if (NOT EXISTS /opt/rocm) + set(ROCM_PATH /usr) + else() + set(ROCM_PATH /opt/rocm) + endif() +else() + set(ROCM_PATH $ENV{ROCM_PATH}) +endif() + +list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) +list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake") + add_compile_options("$<$:${ARCH_FLAGS}>") add_compile_options("$<$:${ARCH_FLAGS}>") @@ -208,8 +221,18 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party) if (WIN32) include_directories("$ENV{CUDA_PATH}/include") elseif (UNIX) - find_package(CUDA REQUIRED) - include_directories("${CUDA_INCLUDE_DIRS}") + find_package(CUDA) + find_package(HIP) + find_package(MUSA) + if(CUDA_FOUND) + include_directories("${CUDA_INCLUDE_DIRS}") + endif() + if(HIP_FOUND) + 
include_directories("${HIP_INCLUDE_DIRS}") + endif() + if(MUSA_FOUND) + include_directories("${MUSA_INCLUDE_DIRS}") + endif() endif() aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1) @@ -228,7 +251,19 @@ elseif(UNIX) if(NOT DEFINED ENV{CUDA_HOME} OR "$ENV{CUDA_HOME}" STREQUAL "") set(ENV{CUDA_HOME} "/usr/local/cuda") endif() - target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so") + if(CUDA_FOUND) + add_compile_definitions(USE_CUDA=1) + target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so") + message(STATUS "Building for CUDA") + endif() + if(HIP_FOUND) + add_compile_definitions(USE_HIP=1) + message(STATUS "Building for HIP") + endif() + if(MUSA_FOUND) + add_compile_definitions(USE_MUSA=1) + message(STATUS "Building for MUSA") + endif() endif() # Define the USE_NUMA option diff --git a/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h b/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h index 9618e6b..180eb1d 100644 --- a/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h +++ b/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h @@ -17,10 +17,10 @@ #include #include #include -#include "cuda_runtime.h" #include "backend.h" #include "task_queue.h" +#include "../vendors/vendor.h" #include "llama.cpp/ggml-impl.h" diff --git a/ktransformers/ktransformers_ext/ext_bindings.cpp b/ktransformers/ktransformers_ext/ext_bindings.cpp index 902d427..0078a79 100644 --- a/ktransformers/ktransformers_ext/ext_bindings.cpp +++ b/ktransformers/ktransformers_ext/ext_bindings.cpp @@ -9,7 +9,6 @@ **/ // Python bindings #include "cpu_backend/cpuinfer.h" -#include "device_launch_parameters.h" #include "llamafile/flags.h" #include "operators/kvcache/kvcache.h" #include "operators/llamafile/linear.h" diff --git a/ktransformers/ktransformers_ext/vendors/cuda.h b/ktransformers/ktransformers_ext/vendors/cuda.h new file mode 100644 index 0000000..1746b07 --- /dev/null +++ 
b/ktransformers/ktransformers_ext/vendors/cuda.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include +#include +#include +#include + +#if CUDART_VERSION < 11020 +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED +#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH +#define CUBLAS_COMPUTE_16F CUDA_R_16F +#define CUBLAS_COMPUTE_32F CUDA_R_32F +#define cublasComputeType_t cudaDataType_t +#endif // CUDART_VERSION < 11020 diff --git a/ktransformers/ktransformers_ext/vendors/hip.h b/ktransformers/ktransformers_ext/vendors/hip.h new file mode 100644 index 0000000..abbc1e8 --- /dev/null +++ b/ktransformers/ktransformers_ext/vendors/hip.h @@ -0,0 +1,172 @@ +#pragma once + +#define HIP_ENABLE_WARP_SYNC_BUILTINS 1 +#include +#include +#include +#include +#ifdef __HIP_PLATFORM_AMD__ +// for rocblas_initialize() +#include "rocblas/rocblas.h" +#endif // __HIP_PLATFORM_AMD__ + +#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F +#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F +#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F +#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT +#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT +#define CUBLAS_OP_N HIPBLAS_OP_N +#define CUBLAS_OP_T HIPBLAS_OP_T +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define CUBLAS_TF32_TENSOR_OP_MATH 0 +#define CUDA_R_16F HIPBLAS_R_16F +#define CUDA_R_32F HIPBLAS_R_32F +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported +#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended +#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned +#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice +#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite +#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} +#define __shfl_sync(mask, var, 
laneMask, width) __shfl(var, laneMask, width) +#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) +#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6 +#define cublasCreate hipblasCreate +#define cublasDestroy hipblasDestroy +#define cublasGemmEx hipblasGemmEx +#define cublasGemmBatchedEx hipblasGemmBatchedEx +#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx +#define cublasHandle_t hipblasHandle_t +#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS +#define cublasSetStream hipblasSetStream +#define cublasSgemm hipblasSgemm +#define cublasStatus_t hipblasStatus_t +#define cublasOperation_t hipblasOperation_t +#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6 +#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer +#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess +#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess +#define cudaDeviceProp hipDeviceProp_t +#define cudaDeviceSynchronize hipDeviceSynchronize +#define cudaError_t hipError_t +#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled +#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled +#define cudaEventCreateWithFlags hipEventCreateWithFlags +#define cudaEventDisableTiming hipEventDisableTiming +#define cudaEventRecord hipEventRecord +#define cudaEventSynchronize hipEventSynchronize +#define cudaEvent_t hipEvent_t +#define cudaEventDestroy hipEventDestroy +#define cudaFree hipFree +#define cudaFreeHost hipHostFree +#define cudaGetDevice hipGetDevice +#define cudaGetDeviceCount hipGetDeviceCount +#define cudaGetDeviceProperties hipGetDeviceProperties +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaHostRegister hipHostRegister +#define cudaHostRegisterPortable hipHostRegisterPortable +#define cudaHostRegisterReadOnly hipHostRegisterReadOnly +#define 
cudaHostUnregister hipHostUnregister +#define cudaLaunchHostFunc hipLaunchHostFunc +#define cudaMalloc hipMalloc +#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) +#define cudaMemcpy hipMemcpy +#define cudaMemcpyAsync hipMemcpyAsync +#define cudaMemcpyPeerAsync hipMemcpyPeerAsync +#define cudaMemcpy2DAsync hipMemcpy2DAsync +#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyKind hipMemcpyKind +#define cudaMemset hipMemset +#define cudaMemsetAsync hipMemsetAsync +#define cudaMemGetInfo hipMemGetInfo +#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize +#define cudaSetDevice hipSetDevice +#define cuDeviceGet hipDeviceGet +#define CUdevice hipDevice_t +#define CUdeviceptr hipDeviceptr_t +#define cuMemUnmap hipMemUnmap +#define CUmemAccessDesc hipMemAccessDesc +#define cuMemAddressFree hipMemAddressFree +#define cuMemRelease hipMemRelease +#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t +#define cuMemCreate hipMemCreate +#define cuMemAddressReserve hipMemAddressReserve +#define cuMemMap hipMemMap +#define cuMemSetAccess hipMemSetAccess +#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity +#define CUmemAllocationProp hipMemAllocationProp +#define cuDeviceGetAttribute hipDeviceGetAttribute +#define cudaStreamCreateWithFlags hipStreamCreateWithFlags +#define cudaStreamDestroy hipStreamDestroy +#define cudaStreamFireAndForget hipStreamFireAndForget +#define cudaStreamNonBlocking hipStreamNonBlocking +#define cudaStreamPerThread hipStreamPerThread +#define cudaStreamSynchronize hipStreamSynchronize +#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) +#define cudaGraphExec_t hipGraphExec_t +#define cudaGraphNode_t hipGraphNode_t +#define cudaKernelNodeParams hipKernelNodeParams +#define cudaKernelNodeParams 
hipKernelNodeParams +#define cudaGraphExecDestroy hipGraphExecDestroy +#define cudaGraphLaunch hipGraphLaunch +#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure +#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult +#define cudaGraphNodeType hipGraphNodeType +#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel +#define cudaGraphInstantiate hipGraphInstantiate +#define cudaStreamEndCapture hipStreamEndCapture +#define cudaGraphDestroy hipGraphDestroy +#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams +#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction +#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams +#define cudaGraphNodeGetType hipGraphNodeGetType +#define cudaGraphGetNodes hipGraphGetNodes +#define cudaGraphExecUpdate hipGraphExecUpdate +#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed +#define cudaStreamBeginCapture hipStreamBeginCapture +#define cudaGraph_t hipGraph_t +#define cudaStream_t hipStream_t +#define cudaSuccess hipSuccess +#define cudaHostFn_t hipHostFn_t +#define __trap() do { abort(); __builtin_unreachable(); } while(0) +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED +#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED +#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE +#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH +#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR +#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED +#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR +#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED + +#define __CUDA_ARCH__ 1300 + +#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) +#define GCN +#endif + +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) +#define CDNA +#endif + +#if defined(__gfx1100__) || 
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ + defined(__gfx1150__) || defined(__gfx1151__) +#define RDNA3 +#endif + +#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \ + defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__) +#define RDNA2 +#endif + +#if defined(__gfx1010__) || defined(__gfx1012__) +#define RDNA1 +#endif + +#ifndef __has_builtin + #define __has_builtin(x) 0 +#endif + +typedef hip_bfloat16 nv_bfloat16; diff --git a/ktransformers/ktransformers_ext/vendors/musa.h b/ktransformers/ktransformers_ext/vendors/musa.h new file mode 100644 index 0000000..6cc1b69 --- /dev/null +++ b/ktransformers/ktransformers_ext/vendors/musa.h @@ -0,0 +1,137 @@ +#pragma once + +#include +#include +#include +#include +#include +#define CUBLAS_COMPUTE_16F CUDA_R_16F +#define CUBLAS_COMPUTE_32F CUDA_R_32F +#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F +#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT +#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT +#define CUBLAS_OP_N MUBLAS_OP_N +#define CUBLAS_OP_T MUBLAS_OP_T +#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS +#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT +#define CUDA_R_16F MUSA_R_16F +#define CUDA_R_32F MUSA_R_32F +#define cublasComputeType_t cudaDataType_t +#define cublasCreate mublasCreate +#define cublasDestroy mublasDestroy +#define cublasGemmEx mublasGemmEx +#define cublasGemmBatchedEx mublasGemmBatchedEx +#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx +#define cublasHandle_t mublasHandle_t +#define cublasSetMathMode mublasSetMathMode +#define cublasSetStream mublasSetStream +#define cublasSgemm mublasSgemm +#define cublasStatus_t mublasStatus_t +#define cublasOperation_t mublasOperation_t +#define cublasGetStatusString mublasStatus_to_string +#define cudaDataType_t musaDataType_t +#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer 
+#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess +#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess +#define cudaDeviceProp musaDeviceProp +#define cudaDeviceSynchronize musaDeviceSynchronize +#define cudaError_t musaError_t +#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled +#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled +#define cudaEventCreateWithFlags musaEventCreateWithFlags +#define cudaEventDisableTiming musaEventDisableTiming +#define cudaEventRecord musaEventRecord +#define cudaEventSynchronize musaEventSynchronize +#define cudaEvent_t musaEvent_t +#define cudaEventDestroy musaEventDestroy +#define cudaFree musaFree +#define cudaFreeHost musaFreeHost +#define cudaGetDevice musaGetDevice +#define cudaGetDeviceCount musaGetDeviceCount +#define cudaGetDeviceProperties musaGetDeviceProperties +#define cudaGetErrorString musaGetErrorString +#define cudaGetLastError musaGetLastError +#define cudaHostRegister musaHostRegister +#define cudaHostRegisterPortable musaHostRegisterPortable +#define cudaHostRegisterReadOnly musaHostRegisterReadOnly +#define cudaHostUnregister musaHostUnregister +#define cudaLaunchHostFunc musaLaunchHostFunc +#define cudaMalloc musaMalloc +#define cudaMallocHost musaMallocHost +#define cudaMallocManaged musaMallocManaged +#define cudaMemcpy musaMemcpy +#define cudaMemcpyAsync musaMemcpyAsync +#define cudaMemcpyPeerAsync musaMemcpyPeerAsync +#define cudaMemcpy2DAsync musaMemcpy2DAsync +#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost +#define cudaMemcpyHostToDevice musaMemcpyHostToDevice +#define cudaMemcpyKind musaMemcpyKind +#define cudaMemset musaMemset +#define cudaMemsetAsync musaMemsetAsync +#define cudaMemGetInfo musaMemGetInfo +#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize +#define cudaSetDevice musaSetDevice +#define cudaStreamCreateWithFlags 
musaStreamCreateWithFlags +#define cudaStreamDestroy musaStreamDestroy +#define cudaStreamFireAndForget musaStreamFireAndForget +#define cudaStreamNonBlocking musaStreamNonBlocking +#define cudaStreamPerThread musaStreamPerThread +#define cudaStreamSynchronize musaStreamSynchronize +#define cudaStreamWaitEvent musaStreamWaitEvent +#define cudaStream_t musaStream_t +#define cudaSuccess musaSuccess + +// Additional mappings for MUSA virtual memory pool +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED +#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE +#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED +#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED +#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE +#define CUdevice MUdevice +#define CUdeviceptr MUdeviceptr +#define CUmemAccessDesc MUmemAccessDesc +#define CUmemAllocationProp MUmemAllocationProp +#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle +#define cuDeviceGet muDeviceGet +#define cuDeviceGetAttribute muDeviceGetAttribute +#define cuMemAddressFree muMemAddressFree +#define cuMemAddressReserve muMemAddressReserve +#define cuMemCreate muMemCreate +#define cuMemGetAllocationGranularity muMemGetAllocationGranularity +#define cuMemMap muMemMap +#define cuMemRelease muMemRelease +#define cuMemSetAccess muMemSetAccess +#define cuMemUnmap muMemUnmap +#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize +#define cudaFuncSetAttribute musaFuncSetAttribute +#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms +#define make_cudaExtent make_musaExtent +#define make_cudaPitchedPtr make_musaPitchedPtr + +// Additional mappings for MUSA graphs +#define CUDA_SUCCESS MUSA_SUCCESS +#define CUresult MUresult +#define cuGetErrorString muGetErrorString +#define cudaErrorGraphExecUpdateFailure 
musaErrorGraphExecUpdateFailure +#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction +#define cudaGraphDestroy musaGraphDestroy +#define cudaGraphExecDestroy musaGraphExecDestroy +#define cudaGraphExec_t musaGraphExec_t +#define cudaGraphExecUpdate musaGraphExecUpdate +#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult +#define cudaGraphGetNodes musaGraphGetNodes +#define cudaGraphInstantiate musaGraphInstantiate +#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams +#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams +#define cudaGraphLaunch musaGraphLaunch +#define cudaGraphNodeGetType musaGraphNodeGetType +#define cudaGraphNode_t musaGraphNode_t +#define cudaGraphNodeType musaGraphNodeType +#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel +#define cudaGraph_t musaGraph_t +#define cudaKernelNodeParams musaKernelNodeParams +#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed +#define cudaStreamEndCapture musaStreamEndCapture + +typedef mt_bfloat16 nv_bfloat16; diff --git a/ktransformers/ktransformers_ext/vendors/vendor.h b/ktransformers/ktransformers_ext/vendors/vendor.h new file mode 100644 index 0000000..8470438 --- /dev/null +++ b/ktransformers/ktransformers_ext/vendors/vendor.h @@ -0,0 +1,13 @@ +#ifndef CPUINFER_VENDOR_VENDOR_H +#define CPUINFER_VENDOR_VENDOR_H + +#ifdef USE_CUDA +#include "cuda.h" +#elif USE_HIP +#define __HIP_PLATFORM_AMD__ +#include "hip.h" +#elif USE_MUSA +#include "musa.h" +#endif + +#endif // CPUINFER_VENDOR_VENDOR_H \ No newline at end of file From 4cda45433f5d464fcddba0eebb69175acf71eee1 Mon Sep 17 00:00:00 2001 From: fxzjshm Date: Thu, 13 Feb 2025 00:59:28 +0800 Subject: [PATCH 2/6] Don't add CUDA version to version in case not for CUDA Signed-off-by: fxzjshm --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index d24db14..390fec6 100644 --- a/setup.py +++ b/setup.py @@ -129,7 +129,7 @@ class VersionInfo: 
def get_package_version(self, full_version=False): flash_version = self.get_flash_version() - package_version = f"{str(flash_version)}+cu{self.get_cuda_bare_metal_version(CUDA_HOME)}torch{self.get_torch_version()}{self.get_cpu_instruct()}" + package_version = f"{str(flash_version)}+torch{self.get_torch_version()}{self.get_cpu_instruct()}" if full_version: return package_version if not VersionInfo.FORCE_BUILD: @@ -306,7 +306,7 @@ setup( 'cxx': ['-O3'], 'nvcc': [ '-O3', - '--use_fast_math', + # '--use_fast_math', '-Xcompiler', '-fPIC', ] } From ae76a729d872ece70952fb63c8fe497fa6787863 Mon Sep 17 00:00:00 2001 From: fxzjshm Date: Thu, 13 Feb 2025 02:03:22 +0800 Subject: [PATCH 3/6] gptq_marlin: temporarily disable on AMD ROCm Signed-off-by: fxzjshm --- .../ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu | 2 +- .../ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh | 2 +- .../cuda/gptq_marlin/gptq_marlin_dtypes.cuh | 5 +++++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu index 54e538a..87f4581 100644 --- a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu +++ b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu @@ -36,7 +36,7 @@ inline std::string str(T x) { namespace gptq_marlin { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__) __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr, diff --git a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh index 66a5920..ccf9cfd 100644 --- a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh +++ b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh @@ -39,7 +39,7 @@ using I4 = Vec; constexpr int div_ceil(int a, 
int b) { return (a + b - 1) / b; } -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined (__HIP_PLATFORM_AMD__) // No support for async #else diff --git a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh index b8babfb..80f6ea4 100644 --- a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh +++ b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh @@ -8,6 +8,11 @@ #include #include +#ifdef __HIP_PLATFORM_AMD__ +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; +#endif + namespace gptq_marlin { template From 38e5dbc8955116e899a16f632b319468c970d05e Mon Sep 17 00:00:00 2001 From: fxzjshm Date: Thu, 13 Feb 2025 03:14:35 +0800 Subject: [PATCH 4/6] Fix symbol lookup Signed-off-by: fxzjshm --- ktransformers/ktransformers_ext/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/ktransformers/ktransformers_ext/CMakeLists.txt b/ktransformers/ktransformers_ext/CMakeLists.txt index 30e1e6e..8f7c6bc 100644 --- a/ktransformers/ktransformers_ext/CMakeLists.txt +++ b/ktransformers/ktransformers_ext/CMakeLists.txt @@ -258,6 +258,7 @@ elseif(UNIX) endif() if(HIP_FOUND) add_compile_definitions(USE_HIP=1) + target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so") message(STATUS "Building for HIP") endif() if(MUSA_FOUND) From c1f13a69ed535c0c5e0b8960c3066fd4ca4badb8 Mon Sep 17 00:00:00 2001 From: fxzjshm Date: Thu, 13 Feb 2025 03:15:22 +0800 Subject: [PATCH 5/6] Correctly import compat layer from llama.cpp Signed-off-by: fxzjshm --- ktransformers/ktransformers_ext/CMakeLists.txt | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/ktransformers/ktransformers_ext/CMakeLists.txt b/ktransformers/ktransformers_ext/CMakeLists.txt index 8f7c6bc..60cf721 100644 --- 
a/ktransformers/ktransformers_ext/CMakeLists.txt +++ b/ktransformers/ktransformers_ext/CMakeLists.txt @@ -211,6 +211,18 @@ endif() list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake") +if (NOT EXISTS $ENV{MUSA_PATH}) + if (NOT EXISTS /opt/musa) + set(MUSA_PATH /usr/local/musa) + else() + set(MUSA_PATH /opt/musa) + endif() +else() + set(MUSA_PATH $ENV{MUSA_PATH}) +endif() + +list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake") + add_compile_options("$<$:${ARCH_FLAGS}>") add_compile_options("$<$:${ARCH_FLAGS}>") @@ -223,14 +235,14 @@ if (WIN32) elseif (UNIX) find_package(CUDA) find_package(HIP) - find_package(MUSA) + find_package(MUSAToolkit) if(CUDA_FOUND) include_directories("${CUDA_INCLUDE_DIRS}") endif() if(HIP_FOUND) include_directories("${HIP_INCLUDE_DIRS}") endif() - if(MUSA_FOUND) + if(MUSAToolkit_FOUND) include_directories("${MUSA_INCLUDE_DIRS}") endif() endif() @@ -261,7 +273,7 @@ elseif(UNIX) target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so") message(STATUS "Building for HIP") endif() - if(MUSA_FOUND) + if(MUSAToolkit_FOUND) add_compile_definitions(USE_MUSA=1) message(STATUS "Building for MUSA") endif() From 086a9d1cceff60f709d6236262e44aa7b6b6aae8 Mon Sep 17 00:00:00 2001 From: Azure-Tang Date: Thu, 13 Mar 2025 07:10:26 -0400 Subject: [PATCH 6/6] Add vendor control --- .../cuda/custom_gguf/dequant.cu | 1 + ktransformers/local_chat.py | 3 +- ktransformers/operators/attention.py | 37 +++- ktransformers/operators/dynamic_attention.py | 5 +- ktransformers/operators/models.py | 3 +- ktransformers/operators/triton_attention.py | 6 +- .../operators/triton_attention_prefill.py | 206 ++++++++++++++++++ ktransformers/util/vendors.py | 202 +++++++++++++++++ 8 files changed, 446 insertions(+), 17 deletions(-) create mode 100644 ktransformers/operators/triton_attention_prefill.py create mode 100644 ktransformers/util/vendors.py diff --git 
a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu index e80efc4..11100e3 100644 --- a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu +++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu @@ -15,6 +15,7 @@ #include #include #include +typedef hip_bfloat16 nv_bfloat16; __global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index 4acaf86..386bbe7 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -31,6 +31,7 @@ from ktransformers.models.modeling_mixtral import MixtralForCausalLM from ktransformers.util.utils import prefill_and_generate, get_compute_capability from ktransformers.server.config.config import Config from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled +from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor custom_models = { "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM, @@ -169,7 +170,7 @@ def local_chat( assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \ "please change max_seq_len in ~/.ktransformers/config.yaml" - if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8: + if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA: generated = prefill_and_generate( model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = 
chunk_prefill_size, use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py index a9bbea6..eae3cf0 100644 --- a/ktransformers/operators/attention.py +++ b/ktransformers/operators/attention.py @@ -20,8 +20,14 @@ from ktransformers.util.utils import get_compute_capability import logging from transformers.configuration_utils import PretrainedConfig from transformers.cache_utils import Cache -from flash_attn import flash_attn_func -from ktransformers.operators.triton_attention import decode_attention_fwd_grouped +from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor + +try: + from flash_attn import flash_attn_func +except: + pass +from ktransformers.operators.triton_attention import decode_attention_fwd_grouped +from ktransformers.operators.triton_attention_prefill import context_attention_fwd import os from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled if flashinfer_enabled: @@ -319,18 +325,27 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1) value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim) - value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0) - attn_output = flash_attn_func( - query_states, - key_states, - value_states_padded, - softmax_scale=self.softmax_scale, - causal=True, + # for bsz = 1 + attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device) + b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device) + b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device) + + 
max_input_len = q_len + + context_attention_fwd( + q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim), + k=key_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim), + v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim), + o=attn_output, + b_start_loc=b_start_loc, + b_seq_len=b_seq_len, + max_input_len=max_input_len, + is_causal=True ) if self.q_head_dim != self.v_head_dim: - attn_output = attn_output[:, :, :, : self.v_head_dim] + attn_output = attn_output[:, :, : self.v_head_dim] attn_output = attn_output.reshape( bsz, q_len, self.num_heads * self.v_head_dim @@ -589,7 +604,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if os.name == 'nt' or get_compute_capability()<8: + if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA: print("for Windows or GPU before ampere, use forward_windows") return self.forward_windows( hidden_states, diff --git a/ktransformers/operators/dynamic_attention.py b/ktransformers/operators/dynamic_attention.py index 13a74b4..09ccb6c 100644 --- a/ktransformers/operators/dynamic_attention.py +++ b/ktransformers/operators/dynamic_attention.py @@ -17,7 +17,10 @@ import logging logger = logging.getLogger("dynamic_attention") sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend") from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache -from flash_attn import flash_attn_func, flash_attn_with_kvcache +try: + from flash_attn import flash_attn_func, flash_attn_with_kvcache +except: + print("falsh attn not found") import math diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py index 57d4bea..59e89bd 100644 --- a/ktransformers/operators/models.py +++ b/ktransformers/operators/models.py @@ -53,6 +53,7 @@ from 
ktransformers.models.modeling_deepseek import ( DeepseekV2DecoderLayer, DeepseekV2MoE, ) +from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig from ktransformers.models.configuration_llama import LlamaConfig from ktransformers.operators.base_operator import BaseInjectedModule @@ -649,7 +650,7 @@ class KDeepseekV2Model(BaseInjectedModule): if per_layer_prefill_flag: causal_mask = None else: - if os.name == 'nt' or get_compute_capability()<8: + if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA: print("for Windows or GPU before ampere, use forward_windows") # only use mask in forward windows or can't flash attn causal_mask = self._update_causal_mask( diff --git a/ktransformers/operators/triton_attention.py b/ktransformers/operators/triton_attention.py index 4437520..aafdea0 100644 --- a/ktransformers/operators/triton_attention.py +++ b/ktransformers/operators/triton_attention.py @@ -6,7 +6,7 @@ import triton import triton.language as tl - +from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor @triton.jit def tanh(x): # Tanh is just a scaled sigmoid @@ -181,8 +181,8 @@ def _decode_grouped_att_m_fwd( # [TODO] work around shmem limit on MI3xx # TODO: support hip - #if is_hip_ and Lk >= 576: - # BLOCK = 16 + if device_manager.gpu_vendor == GPUVendor.AMD and Lk >= 576: + BLOCK = 16 if Lk == 576: BLOCK_DMODEL = 512 diff --git a/ktransformers/operators/triton_attention_prefill.py b/ktransformers/operators/triton_attention_prefill.py new file mode 100644 index 0000000..a807ef3 --- /dev/null +++ b/ktransformers/operators/triton_attention_prefill.py @@ -0,0 +1,206 @@ + +# Adapted from +# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py +# which was originally adapted from +# 
+It supports page size = 1.
dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + end_n = ( + cur_batch_seq_len + if not IS_CAUSAL + else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len) + ) + for start_n in range(0, block_mask * end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]), + other=0.0, + ) + # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + if IS_CAUSAL: + qk += tl.where( + (start_n + offs_n[None, :] < cur_batch_seq_len) + & (offs_m[:, None] >= (start_n + offs_n[None, :])), + 0, + float("-inf"), + ) + else: + qk += tl.where( + (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf") + ) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]), + other=0.0, + ) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] + ) + out_ptrs = Out + off_o + tl.store( + out_ptrs, acc, mask=(offs_m[:, None] 
< cur_batch_seq_len) & (mask_d[None, :]) + ) + + +def context_attention_fwd( + q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True +): + """ + q, k, v: [b * s, head, head_dim] + b_start_loc: [b] + b_seq_len: [b] + out: [b * s, head, head_dim] + """ + if is_cuda_available and CUDA_CAPABILITY[0] > 8: + BLOCK = 128 + else: + BLOCK = 64 + + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + num_warps = 4 if Lk <= 64 else 8 + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, + o, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + o.stride(0), + o.stride(1), + kv_group_num=kv_group_num, + BLOCK_M=BLOCK, + BLOCK_DMODEL=triton.next_power_of_2(Lk), + BLOCK_N=BLOCK, + IS_CAUSAL=is_causal, + num_warps=num_warps, + num_stages=1, + Lk=Lk, + ) \ No newline at end of file diff --git a/ktransformers/util/vendors.py b/ktransformers/util/vendors.py new file mode 100644 index 0000000..c9a709e --- /dev/null +++ b/ktransformers/util/vendors.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +from enum import IntEnum, auto +from typing import Optional, Union, List +import torch + +class GPUVendor(IntEnum): + NVIDIA = auto() + AMD = auto() + MooreThreads = auto() + MetaX = auto() + MUSA = auto() + Unknown = auto() + +class DeviceManager: + """ + Device manager that provides a unified interface for handling different GPU vendors + """ + def __init__(self): + self.gpu_vendor = self._detect_gpu_vendor() + self.available_devices = self._get_available_devices() + + def _detect_gpu_vendor(self) -> GPUVendor: + """Detect GPU vendor type""" + if not torch.cuda.is_available(): + # Check MUSA availability (assuming a musa module exists) + try: + import musa + if musa.is_available(): + return GPUVendor.MUSA + except (ImportError, 
AttributeError): + pass + + return GPUVendor.Unknown + + device_name = torch.cuda.get_device_name(0).lower() + + if any(name in device_name for name in ["nvidia", "geforce", "quadro", "tesla", "titan", "rtx", "gtx"]): + return GPUVendor.NVIDIA + elif any(name in device_name for name in ["amd", "radeon", "rx", "vega", "instinct", "firepro", "mi"]): + return GPUVendor.AMD + elif any(name in device_name for name in ["mthreads", "moore", "mtt"]): + return GPUVendor.MooreThreads + elif any(name in device_name for name in ["metax", "meta"]): + return GPUVendor.MetaX + elif "musa" in device_name: + return GPUVendor.MUSA + + # Backend check + try: + if hasattr(torch.version, 'hip') and torch.version.hip is not None: + return GPUVendor.AMD + elif hasattr(torch.version, 'cuda') and torch.version.cuda is not None: + return GPUVendor.NVIDIA + except: + pass + + return GPUVendor.Unknown + + def _get_available_devices(self) -> List[int]: + """Get list of available device indices""" + devices = [] + + if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD: + devices = list(range(torch.cuda.device_count())) + elif self.gpu_vendor == GPUVendor.MUSA: + try: + import musa + devices = list(range(musa.device_count())) + except (ImportError, AttributeError): + pass + + return devices + + def get_device_str(self, device_id: Union[int, str]) -> str: + """ + Get device string for the given device ID + + Args: + device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string + + Returns: + Device string representation (e.g., "cuda:0", "musa:1", "cpu") + """ + if device_id == -1 or device_id == "cpu": + return "cpu" + + if isinstance(device_id, int): + if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD: + if device_id < torch.cuda.device_count(): + return f"cuda:{device_id}" + elif self.gpu_vendor == GPUVendor.MUSA: + try: + import musa + if device_id < musa.device_count(): + return f"musa:{device_id}" + except (ImportError, AttributeError): + 
pass + + return "cpu" + + def to_torch_device(self, device_id: Union[int, str] = 0) -> torch.device: + """ + Convert device ID to torch.device object + + Args: + device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string + + Returns: + torch.device object + """ + device_str = self.get_device_str(device_id) + + # Handle MUSA device + if device_str.startswith("musa:"): + try: + import musa + index = int(device_str.split(":")[-1]) + return musa.device(index) + except (ImportError, ValueError, AttributeError): + return torch.device("cpu") + + # Standard PyTorch device + return torch.device(device_str) + + def move_tensor_to_device(self, tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor: + """ + Move tensor to specified device + + Args: + tensor: PyTorch tensor to move + device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string + + Returns: + Tensor moved to the specified device + """ + device = self.to_torch_device(device_id) + return tensor.to(device) + + def is_available(self, index: int = 0) -> bool: + """ + Check if device at specified index is available + + Args: + index: Device index to check + + Returns: + True if the device is available, False otherwise + """ + if index < 0: + return True # CPU is always available + + return index in self.available_devices + + def get_all_devices(self) -> List[int]: + """ + Get all available device indices + + Returns: + List of available device indices (0, 1, 2, etc.) 
+ """ + return self.available_devices + +# Create global device manager instance +device_manager = DeviceManager() + +# Convenience functions +def get_device(device_id: Union[int, str] = 0) -> torch.device: + """ + Get torch.device object for the specified device ID + + Args: + device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string + + Returns: + torch.device object + """ + return device_manager.to_torch_device(device_id) + +def to_device(tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor: + """ + Move tensor to specified device + + Args: + tensor: PyTorch tensor to move + device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string + + Returns: + Tensor moved to the specified device + """ + return device_manager.move_tensor_to_device(tensor, device_id) + +# Get devices +cpu_device = get_device(-1) # CPU using index -1 +cpu_device2 = get_device("cpu") # CPU using string "cpu" +gpu0 = get_device(0) # First GPU + +# Move tensors +x = torch.randn(3, 3) +x_gpu = to_device(x, 0) # Move to first GPU +x_cpu1 = to_device(x, -1) # Move to CPU using index -1 +x_cpu2 = to_device(x, "cpu") # Move to CPU using string "cpu" \ No newline at end of file