diff --git a/setup.py b/setup.py index e13ceb7..6eccef2 100644 --- a/setup.py +++ b/setup.py @@ -69,17 +69,17 @@ class VersionInfo: def get_rocm_bare_metal_version(self, rocm_dir): """ Get the ROCm version from the ROCm installation directory. - + Args: rocm_dir: Path to the ROCm installation directory - + Returns: A string representation of the ROCm version (e.g., "63" for ROCm 6.3) """ try: # Try using rocm_agent_enumerator to get version info raw_output = subprocess.check_output( - [rocm_dir + "/bin/rocminfo", "--version"], + [rocm_dir + "/bin/rocminfo", "--version"], universal_newlines=True, stderr=subprocess.STDOUT) # Extract version number from output @@ -92,7 +92,7 @@ class VersionInfo: except (subprocess.CalledProcessError, FileNotFoundError): # If rocminfo --version fails, try alternative methods pass - + try: # Try reading version from release file with open(os.path.join(rocm_dir, "share/doc/hip/version.txt"), "r") as f: @@ -102,7 +102,7 @@ class VersionInfo: return rocm_version except (FileNotFoundError, IOError): pass - + # If all else fails, try to extract from directory name dir_name = os.path.basename(os.path.normpath(rocm_dir)) match = re.search(r'rocm-(\d+\.\d+)', dir_name) @@ -111,7 +111,7 @@ class VersionInfo: version = parse(version_str) rocm_version = f"{version.major}{version.minor}" return rocm_version - + # Fallback to extracting from hipcc version try: raw_output = subprocess.check_output( @@ -126,7 +126,7 @@ class VersionInfo: return rocm_version except (subprocess.CalledProcessError, FileNotFoundError): pass - + # If we still can't determine the version, raise an error raise ValueError(f"Could not determine ROCm version from directory: {rocm_dir}") @@ -317,10 +317,10 @@ class CMakeBuild(BuildExtension): elif ROCM_HOME is not None: cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"] else: - raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.") + raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set.") # log cmake_args print("CMake args:", cmake_args) - + build_args = [] if "CMAKE_ARGS" in os.environ: cmake_args += [