mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 13:55:27 +00:00
Merge pull request #506 from makllama/musa
feat: Support Moore Threads GPU
This commit is contained in:
commit
25c5bddd08
8 changed files with 145 additions and 34 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -28,3 +28,4 @@ ktransformers/tests/chat_txt.txt
|
||||||
mmlu_result_q4km.json
|
mmlu_result_q4km.json
|
||||||
mmlu_result_q4km.log
|
mmlu_result_q4km.log
|
||||||
ktransformers/tests/mmlu_result_silicon.log
|
ktransformers/tests/mmlu_result_silicon.log
|
||||||
|
ktransformers/ktransformers_ext/cuda_musa/
|
||||||
|
|
|
@ -30,6 +30,8 @@ if (NOT MSVC)
|
||||||
option(LLAMA_F16C "llama: enable F16C" OFF)
|
option(LLAMA_F16C "llama: enable F16C" OFF)
|
||||||
endif()
|
endif()
|
||||||
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
|
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
|
||||||
|
option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF)
|
||||||
|
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
|
||||||
|
|
||||||
# Architecture specific
|
# Architecture specific
|
||||||
# TODO: probably these flags need to be tweaked on some architectures
|
# TODO: probably these flags need to be tweaked on some architectures
|
||||||
|
@ -208,8 +210,31 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
|
||||||
if (WIN32)
|
if (WIN32)
|
||||||
include_directories("$ENV{CUDA_PATH}/include")
|
include_directories("$ENV{CUDA_PATH}/include")
|
||||||
elseif (UNIX)
|
elseif (UNIX)
|
||||||
find_package(CUDA REQUIRED)
|
if (KTRANSFORMERS_USE_CUDA)
|
||||||
include_directories("${CUDA_INCLUDE_DIRS}")
|
find_package(CUDA REQUIRED)
|
||||||
|
include_directories("${CUDA_INCLUDE_DIRS}")
|
||||||
|
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (KTRANSFORMERS_USE_MUSA)
|
||||||
|
if (NOT EXISTS $ENV{MUSA_PATH})
|
||||||
|
if (NOT EXISTS /opt/musa)
|
||||||
|
set(MUSA_PATH /usr/local/musa)
|
||||||
|
else()
|
||||||
|
set(MUSA_PATH /opt/musa)
|
||||||
|
endif()
|
||||||
|
else()
|
||||||
|
set(MUSA_PATH $ENV{MUSA_PATH})
|
||||||
|
endif()
|
||||||
|
|
||||||
|
list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
|
||||||
|
|
||||||
|
find_package(MUSAToolkit)
|
||||||
|
if (MUSAToolkit_FOUND)
|
||||||
|
message(STATUS "MUSA Toolkit found")
|
||||||
|
add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
|
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
|
||||||
|
@ -225,10 +250,15 @@ target_link_libraries(${PROJECT_NAME} PRIVATE llama)
|
||||||
if(WIN32)
|
if(WIN32)
|
||||||
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart
|
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart
|
||||||
elseif(UNIX)
|
elseif(UNIX)
|
||||||
if(NOT DEFINED ENV{CUDA_HOME} OR "$ENV{CUDA_HOME}" STREQUAL "")
|
if(KTRANSFORMERS_USE_CUDA)
|
||||||
set(ENV{CUDA_HOME} "/usr/local/cuda")
|
if(NOT DEFINED ENV{CUDA_HOME} OR "$ENV{CUDA_HOME}" STREQUAL "")
|
||||||
|
set(ENV{CUDA_HOME} "/usr/local/cuda")
|
||||||
|
endif()
|
||||||
|
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
|
||||||
|
endif()
|
||||||
|
if(KTRANSFORMERS_USE_MUSA)
|
||||||
|
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
|
||||||
endif()
|
endif()
|
||||||
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Define the USE_NUMA option
|
# Define the USE_NUMA option
|
||||||
|
|
|
@ -17,7 +17,11 @@
|
||||||
#include <queue>
|
#include <queue>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "cuda_runtime.h"
|
#ifdef KTRANSFORMERS_USE_CUDA
|
||||||
|
#include "vendors/cuda.h"
|
||||||
|
#elif KTRANSFORMERS_USE_MUSA
|
||||||
|
#include "vendors/musa.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "backend.h"
|
#include "backend.h"
|
||||||
#include "task_queue.h"
|
#include "task_queue.h"
|
||||||
|
|
3
ktransformers/ktransformers_ext/cpu_backend/vendors/README.md
vendored
Normal file
3
ktransformers/ktransformers_ext/cpu_backend/vendors/README.md
vendored
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
## TODO
|
||||||
|
|
||||||
|
This directory can be removed after updating the version of `llama.cpp`.
|
3
ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h
vendored
Normal file
3
ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h
vendored
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cuda_runtime.h>
|
7
ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h
vendored
Normal file
7
ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h
vendored
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <musa_runtime.h>
|
||||||
|
|
||||||
|
#define cudaLaunchHostFunc musaLaunchHostFunc
|
||||||
|
#define cudaStream_t musaStream_t
|
||||||
|
#define cudaHostFn_t musaHostFn_t
|
|
@ -1,15 +1,17 @@
|
||||||
/**
|
/**
|
||||||
* @Description :
|
* @Description :
|
||||||
* @Author : Azure-Tang
|
* @Author : Azure-Tang
|
||||||
* @Date : 2024-07-25 13:38:30
|
* @Date : 2024-07-25 13:38:30
|
||||||
* @Version : 1.0.0
|
* @Version : 1.0.0
|
||||||
* @LastEditors : kkk1nak0
|
* @LastEditors : kkk1nak0
|
||||||
* @LastEditTime : 2024-08-12 03:05:04
|
* @LastEditTime : 2024-08-12 03:05:04
|
||||||
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||||
**/
|
**/
|
||||||
|
|
||||||
#include "custom_gguf/ops.h"
|
#include "custom_gguf/ops.h"
|
||||||
|
#ifdef KTRANSFORMERS_USE_CUDA
|
||||||
#include "gptq_marlin/ops.h"
|
#include "gptq_marlin/ops.h"
|
||||||
|
#endif
|
||||||
// Python bindings
|
// Python bindings
|
||||||
#include <pybind11/pybind11.h>
|
#include <pybind11/pybind11.h>
|
||||||
#include <pybind11/stl.h>
|
#include <pybind11/stl.h>
|
||||||
|
@ -33,8 +35,10 @@ PYBIND11_MODULE(KTransformersOps, m) {
|
||||||
py::arg("data"), py::arg("blk_size"), py::arg("device"));
|
py::arg("data"), py::arg("blk_size"), py::arg("device"));
|
||||||
m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
|
m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
|
||||||
py::arg("data"), py::arg("blk_size"), py::arg("device"));
|
py::arg("data"), py::arg("blk_size"), py::arg("device"));
|
||||||
|
#ifdef KTRANSFORMERS_USE_CUDA
|
||||||
m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
|
m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
|
||||||
py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
|
py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
|
||||||
py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
|
py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
|
||||||
py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"));
|
py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
111
setup.py
111
setup.py
|
@ -1,16 +1,16 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
'''
|
'''
|
||||||
Description :
|
Description :
|
||||||
Author : chenxl
|
Author : chenxl
|
||||||
Date : 2024-07-27 16:15:27
|
Date : 2024-07-27 16:15:27
|
||||||
Version : 1.0.0
|
Version : 1.0.0
|
||||||
LastEditors : chenxl
|
LastEditors : chenxl
|
||||||
LastEditTime : 2024-08-14 16:36:19
|
LastEditTime : 2024-08-14 16:36:19
|
||||||
Adapted from:
|
Adapted from:
|
||||||
https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
|
https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
|
||||||
Copyright (c) 2023, Tri Dao.
|
Copyright (c) 2023, Tri Dao.
|
||||||
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
@ -30,6 +30,11 @@ from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
||||||
from setuptools import setup, Extension
|
from setuptools import setup, Extension
|
||||||
from cpufeature.extension import CPUFeature
|
from cpufeature.extension import CPUFeature
|
||||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
|
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
|
||||||
|
try:
|
||||||
|
from torch_musa.utils.simple_porting import SimplePorting
|
||||||
|
from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
|
||||||
|
except ImportError:
|
||||||
|
MUSA_HOME=None
|
||||||
|
|
||||||
class CpuInstructInfo:
|
class CpuInstructInfo:
|
||||||
CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
|
CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
|
||||||
|
@ -40,7 +45,7 @@ class CpuInstructInfo:
|
||||||
CMAKE_FANCY = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON -DLLAMA_AVX512_FANCY_SIMD=ON"
|
CMAKE_FANCY = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON -DLLAMA_AVX512_FANCY_SIMD=ON"
|
||||||
CMAKE_AVX512 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON"
|
CMAKE_AVX512 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON"
|
||||||
CMAKE_AVX2 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON"
|
CMAKE_AVX2 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON"
|
||||||
|
|
||||||
class VersionInfo:
|
class VersionInfo:
|
||||||
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
PACKAGE_NAME = "ktransformers"
|
PACKAGE_NAME = "ktransformers"
|
||||||
|
@ -49,6 +54,16 @@ class VersionInfo:
|
||||||
)
|
)
|
||||||
FORCE_BUILD = os.getenv("KTRANSFORMERS_FORCE_BUILD", "FALSE") == "TRUE"
|
FORCE_BUILD = os.getenv("KTRANSFORMERS_FORCE_BUILD", "FALSE") == "TRUE"
|
||||||
|
|
||||||
|
def get_musa_bare_metal_version(self, musa_dir):
|
||||||
|
raw_output = subprocess.run(
|
||||||
|
[musa_dir + "/bin/mcc", "-v"], check=True,
|
||||||
|
stdout=subprocess.PIPE, stderr=subprocess.STDOUT).stdout.decode("utf-8")
|
||||||
|
output = raw_output.split()
|
||||||
|
release_idx = output.index("version") + 1
|
||||||
|
bare_metal_version = parse(output[release_idx].split(",")[0])
|
||||||
|
musa_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
|
||||||
|
return musa_version
|
||||||
|
|
||||||
def get_cuda_bare_metal_version(self, cuda_dir):
|
def get_cuda_bare_metal_version(self, cuda_dir):
|
||||||
raw_output = subprocess.check_output(
|
raw_output = subprocess.check_output(
|
||||||
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
|
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
|
||||||
|
@ -58,7 +73,7 @@ class VersionInfo:
|
||||||
cuda_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
|
cuda_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
|
||||||
return cuda_version
|
return cuda_version
|
||||||
|
|
||||||
def get_cuda_version_of_torch(self,):
|
def get_cuda_version_of_torch(self):
|
||||||
torch_cuda_version = parse(torch.version.cuda)
|
torch_cuda_version = parse(torch.version.cuda)
|
||||||
cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
|
cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
|
||||||
return cuda_version
|
return cuda_version
|
||||||
|
@ -117,7 +132,7 @@ class VersionInfo:
|
||||||
torch_version_raw = parse(torch.__version__)
|
torch_version_raw = parse(torch.__version__)
|
||||||
torch_version = f"{torch_version_raw.major}{torch_version_raw.minor}"
|
torch_version = f"{torch_version_raw.major}{torch_version_raw.minor}"
|
||||||
return torch_version
|
return torch_version
|
||||||
|
|
||||||
def get_flash_version(self,):
|
def get_flash_version(self,):
|
||||||
version_file = os.path.join(
|
version_file = os.path.join(
|
||||||
Path(VersionInfo.THIS_DIR), VersionInfo.PACKAGE_NAME, "__init__.py")
|
Path(VersionInfo.THIS_DIR), VersionInfo.PACKAGE_NAME, "__init__.py")
|
||||||
|
@ -128,12 +143,21 @@ class VersionInfo:
|
||||||
return flash_version
|
return flash_version
|
||||||
|
|
||||||
def get_package_version(self, full_version=False):
|
def get_package_version(self, full_version=False):
|
||||||
flash_version = self.get_flash_version()
|
flash_version = str(self.get_flash_version())
|
||||||
package_version = f"{str(flash_version)}+cu{self.get_cuda_bare_metal_version(CUDA_HOME)}torch{self.get_torch_version()}{self.get_cpu_instruct()}"
|
torch_version = self.get_torch_version()
|
||||||
|
cpu_instruct = self.get_cpu_instruct()
|
||||||
|
backend_version = ""
|
||||||
|
if CUDA_HOME is not None:
|
||||||
|
backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
|
||||||
|
elif MUSA_HOME is not None:
|
||||||
|
backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
|
||||||
|
package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
|
||||||
if full_version:
|
if full_version:
|
||||||
return package_version
|
return package_version
|
||||||
if not VersionInfo.FORCE_BUILD:
|
if not VersionInfo.FORCE_BUILD:
|
||||||
return str(flash_version)
|
return flash_version
|
||||||
return package_version
|
return package_version
|
||||||
|
|
||||||
|
|
||||||
|
@ -218,11 +242,19 @@ class CMakeBuild(BuildExtension):
|
||||||
f"-DPYTHON_EXECUTABLE={sys.executable}",
|
f"-DPYTHON_EXECUTABLE={sys.executable}",
|
||||||
f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm
|
f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if CUDA_HOME is not None:
|
||||||
|
cmake_args += ["-DKTRANSFORMERS_USE_CUDA=ON"]
|
||||||
|
elif MUSA_HOME is not None:
|
||||||
|
cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
|
||||||
|
|
||||||
build_args = []
|
build_args = []
|
||||||
if "CMAKE_ARGS" in os.environ:
|
if "CMAKE_ARGS" in os.environ:
|
||||||
cmake_args += [
|
cmake_args += [
|
||||||
item for item in os.environ["CMAKE_ARGS"].split(" ") if item]
|
item for item in os.environ["CMAKE_ARGS"].split(" ") if item]
|
||||||
|
|
||||||
if CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.FANCY:
|
if CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.FANCY:
|
||||||
cpu_args = CpuInstructInfo.CMAKE_FANCY
|
cpu_args = CpuInstructInfo.CMAKE_FANCY
|
||||||
elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX512:
|
elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX512:
|
||||||
|
@ -231,7 +263,7 @@ class CMakeBuild(BuildExtension):
|
||||||
cpu_args = CpuInstructInfo.CMAKE_AVX2
|
cpu_args = CpuInstructInfo.CMAKE_AVX2
|
||||||
else:
|
else:
|
||||||
cpu_args = CpuInstructInfo.CMAKE_NATIVE
|
cpu_args = CpuInstructInfo.CMAKE_NATIVE
|
||||||
|
|
||||||
cmake_args += [
|
cmake_args += [
|
||||||
item for item in cpu_args.split(" ") if item
|
item for item in cpu_args.split(" ") if item
|
||||||
]
|
]
|
||||||
|
@ -288,28 +320,55 @@ class CMakeBuild(BuildExtension):
|
||||||
print("Standard output:", result.stdout)
|
print("Standard output:", result.stdout)
|
||||||
print("Standard error:", result.stderr)
|
print("Standard error:", result.stderr)
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
|
["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if CUDA_HOME is not None:
|
||||||
|
ops_module = CUDAExtension('KTransformersOps', [
|
||||||
|
'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
|
||||||
|
'ktransformers/ktransformers_ext/cuda/binding.cpp',
|
||||||
|
'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
|
||||||
|
],
|
||||||
|
extra_compile_args={
|
||||||
|
'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
|
||||||
|
'nvcc': [
|
||||||
|
'-O3',
|
||||||
|
'--use_fast_math',
|
||||||
|
'-Xcompiler', '-fPIC',
|
||||||
|
'-DKTRANSFORMERS_USE_CUDA',
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif MUSA_HOME is not None:
|
||||||
|
SimplePorting(cuda_dir_path="ktransformers/ktransformers_ext/cuda", mapping_rule={
|
||||||
|
# Common rules
|
||||||
|
"at::cuda": "at::musa",
|
||||||
|
"#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
|
||||||
|
"#include <c10/cuda/CUDAGuard.h>": "#include \"torch_musa/csrc/core/MUSAGuard.h\"",
|
||||||
|
}).run()
|
||||||
|
ops_module = MUSAExtension('KTransformersOps', [
|
||||||
|
'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
|
||||||
|
'ktransformers/ktransformers_ext/cuda_musa/binding.cpp',
|
||||||
|
# TODO: Add Marlin support for MUSA.
|
||||||
|
# 'ktransformers/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'
|
||||||
|
],
|
||||||
|
extra_compile_args={
|
||||||
|
'cxx': ['force_mcc'],
|
||||||
|
'mcc': [
|
||||||
|
'-O3',
|
||||||
|
'-DKTRANSFORMERS_USE_MUSA',
|
||||||
|
'-DTHRUST_IGNORE_CUB_VERSION_CHECK',
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
version=VersionInfo().get_package_version(),
|
version=VersionInfo().get_package_version(),
|
||||||
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
|
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
|
||||||
ext_modules=[
|
ext_modules=[
|
||||||
CMakeExtension("cpuinfer_ext"),
|
CMakeExtension("cpuinfer_ext"),
|
||||||
CUDAExtension('KTransformersOps', [
|
ops_module,
|
||||||
'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
|
|
||||||
'ktransformers/ktransformers_ext/cuda/binding.cpp',
|
|
||||||
'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
|
|
||||||
],
|
|
||||||
extra_compile_args={
|
|
||||||
'cxx': ['-O3'],
|
|
||||||
'nvcc': [
|
|
||||||
'-O3',
|
|
||||||
'--use_fast_math',
|
|
||||||
'-Xcompiler', '-fPIC',
|
|
||||||
]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue