// Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-07 21:19:51 +00:00) - 172 lines, 7.5 KiB, C
#pragma once
|
|
|
|
// Defined ahead of the <hip/hip_runtime.h> include below so the setting is visible when that header is processed.
// NOTE(review): presumably this asks HIP to expose its *_sync warp builtins; confirm against the HIP headers in use.
#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
|
|
#include <hip/hip_runtime.h>
|
|
#include <hipblas/hipblas.h>
|
|
#include <hip/hip_fp16.h>
|
|
#include <hip/hip_bfloat16.h>
|
|
#ifdef __HIP_PLATFORM_AMD__
|
|
// for rocblas_initialize()
|
|
#include "rocblas/rocblas.h"
|
|
#endif // __HIP_PLATFORM_AMD__
|
|
|
|
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
|
|
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
|
|
// Maps to the same hipBLAS value as CUBLAS_COMPUTE_32F; no separate "fast 16F" mode is selected here.
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
|
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
|
// Maps to the same hipBLAS value as CUBLAS_GEMM_DEFAULT; no tensor-op-specific algorithm is selected here.
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
|
|
#define CUBLAS_OP_N HIPBLAS_OP_N
|
|
#define CUBLAS_OP_T HIPBLAS_OP_T
|
|
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
|
// Placeholder value: the cublasSetMathMode macro in this file discards its arguments,
// so passing this constant to it has no effect.
#define CUBLAS_TF32_TENSOR_OP_MATH 0
|
|
#define CUDA_R_16F HIPBLAS_R_16F
|
|
#define CUDA_R_32F HIPBLAS_R_32F
|
|
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
|
|
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
|
|
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
|
|
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
|
|
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
|
|
// Evaluates `fn` once and, if the result is not hipSuccess, calls GGML_ABORT with the
// HIP error string. GGML_ABORT is not defined in this part of the file and must be in
// scope wherever the macro is used.
// The do { ... } while (0) wrapper makes the expansion a single statement, so
// `if (c) CU_CHECK(x); else ...` parses as intended. The argument is parenthesized and
// the local is named hip_vmm_err_ so an argument expression that mentions `err` is not
// accidentally bound to the macro's own variable.
#define CU_CHECK(fn) do { hipError_t hip_vmm_err_ = (fn); if (hip_vmm_err_ != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(hip_vmm_err_)); } } while (0)
|
|
// Rewrites the four-argument CUDA shuffle as a call to __shfl; the `mask` argument is dropped from the expansion.
// All four arguments must be supplied at the call site (this macro has no three-argument form).
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
|
|
// Rewrites the four-argument CUDA XOR shuffle as a call to __shfl_xor; the `mask` argument is dropped from the expansion.
// All four arguments must be supplied at the call site (this macro has no three-argument form).
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
|
#define cublasComputeType_t hipblasDatatype_t // hipblasDatatype_t is deprecated, but the newer hipblasComputeType_t is not in 5.6
|
|
#define cublasCreate hipblasCreate
|
|
#define cublasDestroy hipblasDestroy
|
|
#define cublasGemmEx hipblasGemmEx
|
|
#define cublasGemmBatchedEx hipblasGemmBatchedEx
|
|
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
|
|
#define cublasHandle_t hipblasHandle_t
|
|
// Expands to CUBLAS_STATUS_SUCCESS without calling anything; `handle` and `mode` are not evaluated.
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
|
#define cublasSetStream hipblasSetStream
|
|
#define cublasSgemm hipblasSgemm
|
|
#define cublasStatus_t hipblasStatus_t
|
|
#define cublasOperation_t hipblasOperation_t
|
|
#define cudaDataType_t hipblasDatatype_t // hipblasDatatype_t is deprecated; per the original note, its newer replacement is not in 5.6
|
|
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
|
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
|
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
|
#define cudaDeviceProp hipDeviceProp_t
|
|
#define cudaDeviceSynchronize hipDeviceSynchronize
|
|
#define cudaError_t hipError_t
|
|
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
|
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
|
|
#define cudaEventCreateWithFlags hipEventCreateWithFlags
|
|
#define cudaEventDisableTiming hipEventDisableTiming
|
|
#define cudaEventRecord hipEventRecord
|
|
#define cudaEventSynchronize hipEventSynchronize
|
|
#define cudaEvent_t hipEvent_t
|
|
#define cudaEventDestroy hipEventDestroy
|
|
#define cudaFree hipFree
|
|
#define cudaFreeHost hipHostFree
|
|
#define cudaGetDevice hipGetDevice
|
|
#define cudaGetDeviceCount hipGetDeviceCount
|
|
#define cudaGetDeviceProperties hipGetDeviceProperties
|
|
#define cudaGetErrorString hipGetErrorString
|
|
#define cudaGetLastError hipGetLastError
|
|
#define cudaHostRegister hipHostRegister
|
|
#define cudaHostRegisterPortable hipHostRegisterPortable
|
|
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
|
|
#define cudaHostUnregister hipHostUnregister
|
|
#define cudaLaunchHostFunc hipLaunchHostFunc
|
|
#define cudaMalloc hipMalloc
|
|
// Two-argument CUDA form forwarded to hipHostMalloc, with hipHostMallocDefault supplied as the flags argument.
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
|
|
#define cudaMemcpy hipMemcpy
|
|
#define cudaMemcpyAsync hipMemcpyAsync
|
|
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
|
|
#define cudaMemcpy2DAsync hipMemcpy2DAsync
|
|
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
|
|
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
|
|
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
|
|
#define cudaMemcpyKind hipMemcpyKind
|
|
#define cudaMemset hipMemset
|
|
#define cudaMemsetAsync hipMemsetAsync
|
|
#define cudaMemGetInfo hipMemGetInfo
|
|
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
|
|
#define cudaSetDevice hipSetDevice
|
|
#define cuDeviceGet hipDeviceGet
|
|
#define CUdevice hipDevice_t
|
|
#define CUdeviceptr hipDeviceptr_t
|
|
#define cuMemUnmap hipMemUnmap
|
|
#define CUmemAccessDesc hipMemAccessDesc
|
|
#define cuMemAddressFree hipMemAddressFree
|
|
#define cuMemRelease hipMemRelease
|
|
#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
|
|
#define cuMemCreate hipMemCreate
|
|
#define cuMemAddressReserve hipMemAddressReserve
|
|
#define cuMemMap hipMemMap
|
|
#define cuMemSetAccess hipMemSetAccess
|
|
#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
|
|
#define CUmemAllocationProp hipMemAllocationProp
|
|
#define cuDeviceGetAttribute hipDeviceGetAttribute
|
|
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
|
|
#define cudaStreamDestroy hipStreamDestroy
|
|
#define cudaStreamFireAndForget hipStreamFireAndForget
|
|
#define cudaStreamNonBlocking hipStreamNonBlocking
|
|
#define cudaStreamPerThread hipStreamPerThread
|
|
#define cudaStreamSynchronize hipStreamSynchronize
|
|
// Function-like alias: call sites must pass all three arguments (stream, event, flags) explicitly.
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
|
|
#define cudaGraphExec_t hipGraphExec_t
|
|
#define cudaGraphNode_t hipGraphNode_t
|
|
// CUDA graph kernel-node parameter struct name mapped to its HIP counterpart.
// (Defined once; the second, identical #define that followed was redundant.)
#define cudaKernelNodeParams hipKernelNodeParams
|
|
#define cudaGraphExecDestroy hipGraphExecDestroy
|
|
#define cudaGraphLaunch hipGraphLaunch
|
|
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
|
|
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
|
|
#define cudaGraphNodeType hipGraphNodeType
|
|
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
|
|
#define cudaGraphInstantiate hipGraphInstantiate
|
|
#define cudaStreamEndCapture hipStreamEndCapture
|
|
#define cudaGraphDestroy hipGraphDestroy
|
|
#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
|
|
#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
|
|
#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
|
|
#define cudaGraphNodeGetType hipGraphNodeGetType
|
|
#define cudaGraphGetNodes hipGraphGetNodes
|
|
#define cudaGraphExecUpdate hipGraphExecUpdate
|
|
#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
|
|
#define cudaStreamBeginCapture hipStreamBeginCapture
|
|
#define cudaGraph_t hipGraph_t
|
|
#define cudaStream_t hipStream_t
|
|
#define cudaSuccess hipSuccess
|
|
#define cudaHostFn_t hipHostFn_t
|
|
// Replaces __trap() with abort() followed by __builtin_unreachable(); the do/while(0) wrapper
// makes the expansion behave as a single statement.
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
|
|
// Also defined earlier in this file with the same replacement text (a benign identical redefinition);
// this occurrence starts the run of CUBLAS_STATUS_* to HIPBLAS_STATUS_* mappings.
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
|
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
|
|
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
|
|
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
|
|
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
|
|
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
|
|
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
|
|
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
|
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
|
|
|
// Hard-codes __CUDA_ARCH__ to 1300 for this HIP build.
// NOTE(review): presumably chosen to be higher than any real CUDA arch value so that arch-gated
// code paths are compiled in; confirm against the checks that read it.
#define __CUDA_ARCH__ 1300
|
|
|
|
// Defines GCN when the compile target is gfx803, gfx900 or gfx906.
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
#define GCN
#endif
|
|
|
|
// Defines CDNA when the compile target is gfx908, gfx90a or gfx942.
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
#define CDNA
#endif
|
|
|
|
// Defines RDNA3 when the compile target is gfx1100, gfx1101, gfx1102, gfx1103, gfx1150 or gfx1151.
// The backslash continuation must be followed directly by the rest of the condition.
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
    defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3
#endif
|
|
|
|
// Defines RDNA2 when the compile target is one of gfx1030 through gfx1037.
// The backslash continuation must be followed directly by the rest of the condition.
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
#define RDNA2
#endif
|
|
|
|
// Defines RDNA1 when the compile target is gfx1010 or gfx1012.
#if defined(__gfx1010__) || defined(__gfx1012__)
#define RDNA1
#endif
|
|
|
|
// Fallback for compilers that do not provide __has_builtin: every query evaluates to 0,
// so `#if __has_builtin(...)` checks remain valid preprocessor expressions.
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif
|
|
|
|
// Makes the CUDA type name nv_bfloat16 an alias for HIP's hip_bfloat16.
typedef hip_bfloat16 nv_bfloat16;
|