From 18b1d1836719f8cb3b443b0ddeb683575d0c0a54 Mon Sep 17 00:00:00 2001 From: Xiaodong Ye Date: Sun, 23 Feb 2025 10:19:19 +0800 Subject: [PATCH] musa: support bf16 Signed-off-by: Xiaodong Ye --- ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h | 4 +++- setup.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h b/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h index 7c94102..1892221 100644 --- a/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h +++ b/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h @@ -1,7 +1,9 @@ #pragma once #include +#include #define cudaLaunchHostFunc musaLaunchHostFunc #define cudaStream_t musaStream_t -#define cudaHostFn_t musaHostFn_t \ No newline at end of file +#define cudaHostFn_t musaHostFn_t +#define nv_bfloat16 mt_bfloat16 \ No newline at end of file diff --git a/setup.py b/setup.py index 345fdb1..ea15482 100644 --- a/setup.py +++ b/setup.py @@ -350,6 +350,7 @@ elif MUSA_HOME is not None: "at::cuda": "at::musa", "#include ": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"", "#include ": "#include \"torch_musa/csrc/core/MUSAGuard.h\"", + "nv_bfloat16": "mt_bfloat16", }).run() ops_module = MUSAExtension('KTransformersOps', [ 'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',