diff --git a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu index 54e538a..87f4581 100644 --- a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu +++ b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu @@ -36,7 +36,7 @@ inline std::string str(T x) { namespace gptq_marlin { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__) __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr, diff --git a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh index 66a5920..ccf9cfd 100644 --- a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh +++ b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh @@ -39,7 +39,7 @@ using I4 = Vec; constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined (__HIP_PLATFORM_AMD__) // No support for async #else diff --git a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh index b8babfb..80f6ea4 100644 --- a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh +++ b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh @@ -8,6 +8,11 @@ #include #include +#ifdef __HIP_PLATFORM_AMD__ +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; +#endif + namespace gptq_marlin { template