kvcache-ai-ktransformers/ktransformers/util/vendors.py

from __future__ import annotations

from enum import IntEnum, auto
from typing import Optional, Union, List

import torch


class GPUVendor(IntEnum):
    NVIDIA = auto()
    AMD = auto()
    MooreThreads = auto()
    MetaX = auto()
    MUSA = auto()
    Unknown = auto()


class DeviceManager:
    """
    Device manager that provides a unified interface for handling different GPU vendors
    """

    def __init__(self):
        self.gpu_vendor = self._detect_gpu_vendor()
        self.available_devices = self._get_available_devices()

    def _detect_gpu_vendor(self) -> GPUVendor:
        """Detect GPU vendor type"""
        if not torch.cuda.is_available():
            # Check MUSA availability (assuming a musa module exists)
            try:
                import musa
                if musa.is_available():
                    return GPUVendor.MUSA
            except (ImportError, AttributeError):
                pass
            return GPUVendor.Unknown

        device_name = torch.cuda.get_device_name(0).lower()

        if any(name in device_name for name in ["nvidia", "geforce", "quadro", "tesla", "titan", "rtx", "gtx"]):
            return GPUVendor.NVIDIA
        elif any(name in device_name for name in ["amd", "radeon", "rx", "vega", "instinct", "firepro", "mi"]):
            return GPUVendor.AMD
        elif any(name in device_name for name in ["mthreads", "moore", "mtt"]):
            return GPUVendor.MooreThreads
        elif any(name in device_name for name in ["metax", "meta"]):
            return GPUVendor.MetaX
        elif "musa" in device_name:
            return GPUVendor.MUSA

        # Fall back to checking the compiled backend when the device name is inconclusive
        try:
            if hasattr(torch.version, 'hip') and torch.version.hip is not None:
                return GPUVendor.AMD
            elif hasattr(torch.version, 'cuda') and torch.version.cuda is not None:
                return GPUVendor.NVIDIA
        except Exception:
            pass

        return GPUVendor.Unknown
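
    # Illustrative example of the name-based matching above (device names are
    # hypothetical): "NVIDIA GeForce RTX 4090" hits the "nvidia"/"rtx"
    # substrings and maps to GPUVendor.NVIDIA, while "AMD Instinct MI300X"
    # hits "amd"/"instinct"/"mi" and maps to GPUVendor.AMD.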

    def _get_available_devices(self) -> List[int]:
        """Get list of available device indices"""
        devices = []

        if self.gpu_vendor in (GPUVendor.NVIDIA, GPUVendor.AMD):
            devices = list(range(torch.cuda.device_count()))
        elif self.gpu_vendor == GPUVendor.MUSA:
            try:
                import musa
                devices = list(range(musa.device_count()))
            except (ImportError, AttributeError):
                pass

        return devices

    def get_device_str(self, device_id: Union[int, str]) -> str:
        """
        Get device string for the given device ID

        Args:
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or the string "cpu"

        Returns:
            Device string representation (e.g., "cuda:0", "musa:1", "cpu")
        """
        if device_id == -1 or device_id == "cpu":
            return "cpu"

        if isinstance(device_id, int):
            if self.gpu_vendor in (GPUVendor.NVIDIA, GPUVendor.AMD):
                if device_id < torch.cuda.device_count():
                    return f"cuda:{device_id}"
            elif self.gpu_vendor == GPUVendor.MUSA:
                try:
                    import musa
                    if device_id < musa.device_count():
                        return f"musa:{device_id}"
                except (ImportError, AttributeError):
                    pass

        # Fall back to CPU for unknown vendors or out-of-range indices
        return "cpu"

    def to_torch_device(self, device_id: Union[int, str] = 0) -> torch.device:
        """
        Convert device ID to torch.device object

        Args:
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or the string "cpu"

        Returns:
            torch.device object
        """
        device_str = self.get_device_str(device_id)

        # Handle MUSA device (assumes the musa module exposes a device() constructor)
        if device_str.startswith("musa:"):
            try:
                import musa
                index = int(device_str.split(":")[-1])
                return musa.device(index)
            except (ImportError, ValueError, AttributeError):
                return torch.device("cpu")

        # Standard PyTorch device
        return torch.device(device_str)

    def move_tensor_to_device(self, tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
        """
        Move tensor to specified device

        Args:
            tensor: PyTorch tensor to move
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or the string "cpu"

        Returns:
            Tensor moved to the specified device
        """
        device = self.to_torch_device(device_id)
        return tensor.to(device)

    def is_available(self, index: int = 0) -> bool:
        """
        Check if device at specified index is available

        Args:
            index: Device index to check

        Returns:
            True if the device is available, False otherwise
        """
        if index < 0:
            return True  # CPU is always available
        return index in self.available_devices

    def get_all_devices(self) -> List[int]:
        """
        Get all available device indices

        Returns:
            List of available device indices (0, 1, 2, etc.)
        """
        return self.available_devices


# Create global device manager instance
device_manager = DeviceManager()
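
# Illustrative session (values depend on the hardware; shown for a
# hypothetical two-GPU NVIDIA machine):
#   >>> device_manager.gpu_vendor
#   <GPUVendor.NVIDIA: 1>
#   >>> device_manager.get_all_devices()
#   [0, 1]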


# Convenience functions
def get_device(device_id: Union[int, str] = 0) -> torch.device:
    """
    Get torch.device object for the specified device ID

    Args:
        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or the string "cpu"

    Returns:
        torch.device object
    """
    return device_manager.to_torch_device(device_id)


def to_device(tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
    """
    Move tensor to specified device

    Args:
        tensor: PyTorch tensor to move
        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or the string "cpu"

    Returns:
        Tensor moved to the specified device
    """
    return device_manager.move_tensor_to_device(tensor, device_id)


if __name__ == "__main__":
    # Get devices
    cpu_device = get_device(-1)      # CPU using index -1
    cpu_device2 = get_device("cpu")  # CPU using string "cpu"
    gpu0 = get_device(0)             # First GPU

    # Move tensors
    x = torch.randn(3, 3)
    x_gpu = to_device(x, 0)       # Move to first GPU
    x_cpu1 = to_device(x, -1)     # Move to CPU using index -1
    x_cpu2 = to_device(x, "cpu")  # Move to CPU using string "cpu"
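
    # A minimal enumeration sketch, assuming only the API defined above:
    # report the detected vendor and probe each visible device.
    print(f"Detected vendor: {device_manager.gpu_vendor.name}")
    for idx in device_manager.get_all_devices():
        print(f"device {idx}: {device_manager.get_device_str(idx)} "
              f"(available: {device_manager.is_available(idx)})")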