From 7549ff335adecb9b840bddf8b77f9de1a05006ea Mon Sep 17 00:00:00 2001 From: fishingfly Date: Tue, 1 Apr 2025 10:50:38 +0800 Subject: [PATCH] fix: refine backend error message to include ROCM_HOME Signed-off-by: fishingfly --- setup.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index 5c29b8f..fc8a7bd 100644 --- a/setup.py +++ b/setup.py @@ -67,17 +67,17 @@ class VersionInfo: def get_rocm_bare_metal_version(self, rocm_dir): """ Get the ROCm version from the ROCm installation directory. - + Args: rocm_dir: Path to the ROCm installation directory - + Returns: A string representation of the ROCm version (e.g., "63" for ROCm 6.3) """ try: # Try using rocm_agent_enumerator to get version info raw_output = subprocess.check_output( - [rocm_dir + "/bin/rocminfo", "--version"], + [rocm_dir + "/bin/rocminfo", "--version"], universal_newlines=True, stderr=subprocess.STDOUT) # Extract version number from output @@ -90,7 +90,7 @@ class VersionInfo: except (subprocess.CalledProcessError, FileNotFoundError): # If rocminfo --version fails, try alternative methods pass - + try: # Try reading version from release file with open(os.path.join(rocm_dir, "share/doc/hip/version.txt"), "r") as f: @@ -100,7 +100,7 @@ class VersionInfo: return rocm_version except (FileNotFoundError, IOError): pass - + # If all else fails, try to extract from directory name dir_name = os.path.basename(os.path.normpath(rocm_dir)) match = re.search(r'rocm-(\d+\.\d+)', dir_name) @@ -109,7 +109,7 @@ class VersionInfo: version = parse(version_str) rocm_version = f"{version.major}{version.minor}" return rocm_version - + # Fallback to extracting from hipcc version try: raw_output = subprocess.check_output( @@ -124,7 +124,7 @@ class VersionInfo: return rocm_version except (subprocess.CalledProcessError, FileNotFoundError): pass - + # If we still can't determine the version, raise an error raise ValueError(f"Could not determine ROCm version from directory: {rocm_dir}") @@ -316,10 +316,10 @@ class CMakeBuild(BuildExtension): elif ROCM_HOME is not None: cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"] else: - raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.") + raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set.") # log cmake_args print("CMake args:", cmake_args) - + build_args = [] if "CMAKE_ARGS" in os.environ: cmake_args += [