Merge pull request #1005 from fishingfly/improve/backend-error-msg

fix: refine backend error message to include ROCM_HOME
This commit is contained in:
Atream 2025-04-02 14:54:23 +08:00 committed by GitHub
commit ec12429c46
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -69,17 +69,17 @@ class VersionInfo:
def get_rocm_bare_metal_version(self, rocm_dir):
"""
Get the ROCm version from the ROCm installation directory.
Args:
rocm_dir: Path to the ROCm installation directory
Returns:
A string representation of the ROCm version (e.g., "63" for ROCm 6.3)
"""
try:
# Try using rocm_agent_enumerator to get version info
raw_output = subprocess.check_output(
[rocm_dir + "/bin/rocminfo", "--version"],
[rocm_dir + "/bin/rocminfo", "--version"],
universal_newlines=True,
stderr=subprocess.STDOUT)
# Extract version number from output
@ -92,7 +92,7 @@ class VersionInfo:
except (subprocess.CalledProcessError, FileNotFoundError):
# If rocminfo --version fails, try alternative methods
pass
try:
# Try reading version from release file
with open(os.path.join(rocm_dir, "share/doc/hip/version.txt"), "r") as f:
@ -102,7 +102,7 @@ class VersionInfo:
return rocm_version
except (FileNotFoundError, IOError):
pass
# If all else fails, try to extract from directory name
dir_name = os.path.basename(os.path.normpath(rocm_dir))
match = re.search(r'rocm-(\d+\.\d+)', dir_name)
@ -111,7 +111,7 @@ class VersionInfo:
version = parse(version_str)
rocm_version = f"{version.major}{version.minor}"
return rocm_version
# Fallback to extracting from hipcc version
try:
raw_output = subprocess.check_output(
@ -126,7 +126,7 @@ class VersionInfo:
return rocm_version
except (subprocess.CalledProcessError, FileNotFoundError):
pass
# If we still can't determine the version, raise an error
raise ValueError(f"Could not determine ROCm version from directory: {rocm_dir}")
@ -317,10 +317,10 @@ class CMakeBuild(BuildExtension):
elif ROCM_HOME is not None:
cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
else:
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set.")
# log cmake_args
print("CMake args:", cmake_args)
build_args = []
if "CMAKE_ARGS" in os.environ:
cmake_args += [