mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-06 20:49:55 +00:00
Merge pull request #1005 from fishingfly/improve/backend-error-msg
fix: refine backend error message to include ROCM_HOME
This commit is contained in:
commit
ec12429c46
1 changed files with 9 additions and 9 deletions
18
setup.py
18
setup.py
|
@ -69,17 +69,17 @@ class VersionInfo:
|
|||
def get_rocm_bare_metal_version(self, rocm_dir):
|
||||
"""
|
||||
Get the ROCm version from the ROCm installation directory.
|
||||
|
||||
|
||||
Args:
|
||||
rocm_dir: Path to the ROCm installation directory
|
||||
|
||||
|
||||
Returns:
|
||||
A string representation of the ROCm version (e.g., "63" for ROCm 6.3)
|
||||
"""
|
||||
try:
|
||||
# Try using rocm_agent_enumerator to get version info
|
||||
raw_output = subprocess.check_output(
|
||||
[rocm_dir + "/bin/rocminfo", "--version"],
|
||||
[rocm_dir + "/bin/rocminfo", "--version"],
|
||||
universal_newlines=True,
|
||||
stderr=subprocess.STDOUT)
|
||||
# Extract version number from output
|
||||
|
@ -92,7 +92,7 @@ class VersionInfo:
|
|||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
# If rocminfo --version fails, try alternative methods
|
||||
pass
|
||||
|
||||
|
||||
try:
|
||||
# Try reading version from release file
|
||||
with open(os.path.join(rocm_dir, "share/doc/hip/version.txt"), "r") as f:
|
||||
|
@ -102,7 +102,7 @@ class VersionInfo:
|
|||
return rocm_version
|
||||
except (FileNotFoundError, IOError):
|
||||
pass
|
||||
|
||||
|
||||
# If all else fails, try to extract from directory name
|
||||
dir_name = os.path.basename(os.path.normpath(rocm_dir))
|
||||
match = re.search(r'rocm-(\d+\.\d+)', dir_name)
|
||||
|
@ -111,7 +111,7 @@ class VersionInfo:
|
|||
version = parse(version_str)
|
||||
rocm_version = f"{version.major}{version.minor}"
|
||||
return rocm_version
|
||||
|
||||
|
||||
# Fallback to extracting from hipcc version
|
||||
try:
|
||||
raw_output = subprocess.check_output(
|
||||
|
@ -126,7 +126,7 @@ class VersionInfo:
|
|||
return rocm_version
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
|
||||
# If we still can't determine the version, raise an error
|
||||
raise ValueError(f"Could not determine ROCm version from directory: {rocm_dir}")
|
||||
|
||||
|
@ -317,10 +317,10 @@ class CMakeBuild(BuildExtension):
|
|||
elif ROCM_HOME is not None:
|
||||
cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
|
||||
else:
|
||||
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
|
||||
raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set.")
|
||||
# log cmake_args
|
||||
print("CMake args:", cmake_args)
|
||||
|
||||
|
||||
build_args = []
|
||||
if "CMAKE_ARGS" in os.environ:
|
||||
cmake_args += [
|
||||
|
|
Loading…
Add table
Reference in a new issue