fix: refine backend error message to include ROCM_HOME

Signed-off-by: fishingfly <zhoyuzf@163.com>
This commit is contained in:
fishingfly 2025-04-01 10:50:38 +08:00
parent f142f4dff3
commit 7549ff335a

View file

@ -67,17 +67,17 @@ class VersionInfo:
def get_rocm_bare_metal_version(self, rocm_dir):
"""
Get the ROCm version from the ROCm installation directory.
Args:
rocm_dir: Path to the ROCm installation directory
Returns:
A string representation of the ROCm version (e.g., "63" for ROCm 6.3)
"""
try:
# Try using rocm_agent_enumerator to get version info
raw_output = subprocess.check_output(
[rocm_dir + "/bin/rocminfo", "--version"],
[rocm_dir + "/bin/rocminfo", "--version"],
universal_newlines=True,
stderr=subprocess.STDOUT)
# Extract version number from output
@ -90,7 +90,7 @@ class VersionInfo:
except (subprocess.CalledProcessError, FileNotFoundError):
# If rocminfo --version fails, try alternative methods
pass
try:
# Try reading version from release file
with open(os.path.join(rocm_dir, "share/doc/hip/version.txt"), "r") as f:
@ -100,7 +100,7 @@ class VersionInfo:
return rocm_version
except (FileNotFoundError, IOError):
pass
# If all else fails, try to extract from directory name
dir_name = os.path.basename(os.path.normpath(rocm_dir))
match = re.search(r'rocm-(\d+\.\d+)', dir_name)
@ -109,7 +109,7 @@ class VersionInfo:
version = parse(version_str)
rocm_version = f"{version.major}{version.minor}"
return rocm_version
# Fallback to extracting from hipcc version
try:
raw_output = subprocess.check_output(
@ -124,7 +124,7 @@ class VersionInfo:
return rocm_version
except (subprocess.CalledProcessError, FileNotFoundError):
pass
# If we still can't determine the version, raise an error
raise ValueError(f"Could not determine ROCm version from directory: {rocm_dir}")
@ -316,10 +316,10 @@ class CMakeBuild(BuildExtension):
elif ROCM_HOME is not None:
cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
else:
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set.")
# log cmake_args
print("CMake args:", cmake_args)
build_args = []
if "CMAKE_ARGS" in os.environ:
cmake_args += [