mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 23:34:35 +00:00
merge main; Add torch q8 linear
This commit is contained in:
parent
6c4ed59175
commit
ed8437413b
27 changed files with 1561 additions and 114 deletions
83
setup.py
83
setup.py
|
@ -29,7 +29,7 @@ import torch.version
|
|||
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
||||
from setuptools import setup, Extension
|
||||
from cpufeature.extension import CPUFeature
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
|
||||
try:
|
||||
from torch_musa.utils.simple_porting import SimplePorting
|
||||
from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
|
||||
|
@ -64,6 +64,70 @@ class VersionInfo:
|
|||
musa_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
|
||||
return musa_version
|
||||
|
||||
def get_rocm_bare_metal_version(self, rocm_dir):
    """
    Get the ROCm version from the ROCm installation directory.

    The following probes are tried in order and the first one that yields
    a ``major.minor`` pair wins:

      1. output of ``<rocm_dir>/bin/rocminfo --version``
      2. contents of ``<rocm_dir>/share/doc/hip/version.txt``
      3. a directory name of the form ``rocm-X.Y`` (checked on the given
         path and on the path with symlinks resolved)
      4. the ``HIP version: X.Y`` line of ``<rocm_dir>/bin/hipcc --version``

    Args:
        rocm_dir: Path to the ROCm installation directory

    Returns:
        A string representation of the ROCm version (e.g., "63" for ROCm 6.3)

    Raises:
        ValueError: If none of the probes yields a version.
    """
    def compact_version(pattern, text):
        # `pattern` must capture major and minor as groups 1 and 2.
        # A regex is used instead of a strict version parser so that
        # build-suffixed strings such as "6.3.42131-fa1d09cbd" do not
        # raise and abort the remaining probes.
        found = re.search(pattern, text)
        if found is None:
            return None
        return f"{int(found.group(1))}{int(found.group(2))}"

    def tool_output(tool):
        # Best-effort: a missing, non-executable or failing tool simply
        # means this probe produced nothing.
        try:
            return subprocess.check_output(
                [os.path.join(rocm_dir, "bin", tool), "--version"],
                universal_newlines=True,
                stderr=subprocess.STDOUT)
        except (subprocess.CalledProcessError, OSError):
            return ""

    # 1. rocminfo --version
    rocm_version = compact_version(r'(\d+)\.(\d+)', tool_output("rocminfo"))
    if rocm_version is not None:
        return rocm_version

    # 2. HIP version file shipped with the installation
    try:
        with open(os.path.join(rocm_dir, "share/doc/hip/version.txt"), "r") as f:
            rocm_version = compact_version(r'(\d+)\.(\d+)', f.read())
    except OSError:
        rocm_version = None
    if rocm_version is not None:
        return rocm_version

    # 3. Directory name, e.g. /opt/rocm-6.3.0. The resolved path is also
    #    checked because the install root may be a symlink to a versioned
    #    directory.
    for candidate in (os.path.normpath(rocm_dir), os.path.realpath(rocm_dir)):
        rocm_version = compact_version(
            r'rocm-(\d+)\.(\d+)', os.path.basename(candidate))
        if rocm_version is not None:
            return rocm_version

    # 4. hipcc --version
    rocm_version = compact_version(
        r'HIP version: (\d+)\.(\d+)', tool_output("hipcc"))
    if rocm_version is not None:
        return rocm_version

    # If we still can't determine the version, raise an error
    raise ValueError(f"Could not determine ROCm version from directory: {rocm_dir}")
|
||||
|
||||
def get_cuda_bare_metal_version(self, cuda_dir):
|
||||
raw_output = subprocess.check_output(
|
||||
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
|
||||
|
@ -148,11 +212,13 @@ class VersionInfo:
|
|||
cpu_instruct = self.get_cpu_instruct()
|
||||
backend_version = ""
|
||||
if CUDA_HOME is not None:
|
||||
backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
|
||||
backend_version = f""
|
||||
elif MUSA_HOME is not None:
|
||||
backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
|
||||
elif ROCM_HOME is not None:
|
||||
backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}"
|
||||
else:
|
||||
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
|
||||
raise ValueError("Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set.")
|
||||
package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
|
||||
if full_version:
|
||||
return package_version
|
||||
|
@ -247,9 +313,13 @@ class CMakeBuild(BuildExtension):
|
|||
cmake_args += ["-DKTRANSFORMERS_USE_CUDA=ON"]
|
||||
elif MUSA_HOME is not None:
|
||||
cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
|
||||
elif ROCM_HOME is not None:
|
||||
cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
|
||||
else:
|
||||
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
|
||||
|
||||
# log cmake_args
|
||||
print("CMake args:", cmake_args)
|
||||
|
||||
build_args = []
|
||||
if "CMAKE_ARGS" in os.environ:
|
||||
cmake_args += [
|
||||
|
@ -328,7 +398,7 @@ class CMakeBuild(BuildExtension):
|
|||
["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
|
||||
)
|
||||
|
||||
if CUDA_HOME is not None:
|
||||
if CUDA_HOME is not None or ROCM_HOME is not None:
|
||||
ops_module = CUDAExtension('KTransformersOps', [
|
||||
'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
|
||||
'ktransformers/ktransformers_ext/cuda/binding.cpp',
|
||||
|
@ -338,7 +408,7 @@ if CUDA_HOME is not None:
|
|||
'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
|
||||
'nvcc': [
|
||||
'-O3',
|
||||
'--use_fast_math',
|
||||
# '--use_fast_math',
|
||||
'-Xcompiler', '-fPIC',
|
||||
'-DKTRANSFORMERS_USE_CUDA',
|
||||
]
|
||||
|
@ -371,6 +441,7 @@ else:
|
|||
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
|
||||
|
||||
setup(
|
||||
name=VersionInfo.PACKAGE_NAME,
|
||||
version=VersionInfo().get_package_version(),
|
||||
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
|
||||
ext_modules=[
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue