diff --git a/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h b/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h index 7c94102..1892221 100644 --- a/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h +++ b/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h @@ -1,7 +1,9 @@ #pragma once #include +#include #define cudaLaunchHostFunc musaLaunchHostFunc #define cudaStream_t musaStream_t -#define cudaHostFn_t musaHostFn_t \ No newline at end of file +#define cudaHostFn_t musaHostFn_t +#define nv_bfloat16 mt_bfloat16 \ No newline at end of file diff --git a/setup.py b/setup.py index 345fdb1..ea15482 100644 --- a/setup.py +++ b/setup.py @@ -350,6 +350,7 @@ elif MUSA_HOME is not None: "at::cuda": "at::musa", "#include ": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"", "#include ": "#include \"torch_musa/csrc/core/MUSAGuard.h\"", + "nv_bfloat16": "mt_bfloat16", }).run() ops_module = MUSAExtension('KTransformersOps', [ 'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',