gptq_marlin: temporarily disable on AMD ROCm

Signed-off-by: fxzjshm <fxzjshm@163.com>
This commit is contained in:
fxzjshm 2025-02-13 02:03:22 +08:00
parent 4cda45433f
commit ae76a729d8
3 changed files with 7 additions and 2 deletions

View file

@ -36,7 +36,7 @@ inline std::string str(T x) {
namespace gptq_marlin {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__)
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,

View file

@ -39,7 +39,7 @@ using I4 = Vec<int, 4>;
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined (__HIP_PLATFORM_AMD__)
// No support for async
#else

View file

@ -8,6 +8,11 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#ifdef __HIP_PLATFORM_AMD__
typedef __hip_bfloat16 nv_bfloat16;
typedef __hip_bfloat162 nv_bfloat162;
#endif
namespace gptq_marlin {
template <typename scalar_t>