mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-05 20:19:51 +00:00
202 lines
No EOL
6.6 KiB
Python
202 lines
No EOL
6.6 KiB
Python
from __future__ import annotations
|
|
|
|
from enum import IntEnum, auto
|
|
from typing import Optional, Union, List
|
|
import torch
|
|
|
|
class GPUVendor(IntEnum):
    """Identifiers for the accelerator vendors/backends this module distinguishes.

    Values are consecutive integers starting at 1, in declaration order.
    """

    NVIDIA = 1
    AMD = 2
    MooreThreads = 3
    MetaX = 4
    MUSA = 5
    # Fallback when no supported vendor could be identified.
    Unknown = 6
|
|
|
|
class DeviceManager:
    """
    Device manager that provides a unified interface for handling different GPU vendors.

    The vendor is detected once at construction time and the list of usable
    device indices is cached in ``available_devices``. Device strings are
    ``"cuda:N"`` for NVIDIA/AMD, ``"musa:N"`` for MUSA and ``"cpu"`` otherwise.
    """

    def __init__(self):
        # Order matters: _get_available_devices() reads self.gpu_vendor.
        self.gpu_vendor = self._detect_gpu_vendor()
        self.available_devices = self._get_available_devices()

    def _detect_gpu_vendor(self) -> GPUVendor:
        """
        Detect GPU vendor type.

        Returns:
            The vendor inferred from the name of device 0, falling back to the
            PyTorch build flavour (HIP / CUDA), or ``GPUVendor.Unknown``.
        """
        if not torch.cuda.is_available():
            # Check MUSA availability (assuming a musa module exists)
            try:
                import musa

                if musa.is_available():
                    return GPUVendor.MUSA
            except (ImportError, AttributeError):
                pass
            return GPUVendor.Unknown

        device_name = torch.cuda.get_device_name(0).lower()

        # Checked in order; the first group with a substring hit wins.
        # NOTE(review): these are plain substring tests, so short tokens such as
        # "mi", "rx" or "meta" can match unrelated device names -- the ordering
        # below is part of the behaviour and must be preserved.
        name_hints = (
            (("nvidia", "geforce", "quadro", "tesla", "titan", "rtx", "gtx"), GPUVendor.NVIDIA),
            (("amd", "radeon", "rx", "vega", "instinct", "firepro", "mi"), GPUVendor.AMD),
            (("mthreads", "moore", "mtt"), GPUVendor.MooreThreads),
            (("metax", "meta"), GPUVendor.MetaX),
            (("musa",), GPUVendor.MUSA),
        )
        for keywords, vendor in name_hints:
            if any(keyword in device_name for keyword in keywords):
                return vendor

        # Backend check: fall back to how this PyTorch build was compiled.
        # getattr with a default cannot raise, so no blanket except is needed.
        if getattr(torch.version, "hip", None) is not None:
            return GPUVendor.AMD
        if getattr(torch.version, "cuda", None) is not None:
            return GPUVendor.NVIDIA

        return GPUVendor.Unknown

    def _device_count(self) -> int:
        """Return how many devices the detected vendor's runtime reports (0 if unsupported)."""
        if self.gpu_vendor in (GPUVendor.NVIDIA, GPUVendor.AMD):
            return torch.cuda.device_count()
        if self.gpu_vendor == GPUVendor.MUSA:
            try:
                import musa

                return musa.device_count()
            except (ImportError, AttributeError):
                return 0
        # MooreThreads, MetaX and Unknown have no device enumeration here.
        return 0

    def _get_available_devices(self) -> List[int]:
        """Get list of available device indices"""
        return list(range(self._device_count()))

    def get_device_str(self, device_id: Union[int, str]) -> str:
        """
        Get device string for the given device ID

        Args:
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            Device string representation (e.g., "cuda:0", "musa:1", "cpu").
            Anything that is not a valid, in-range device index (including any
            negative index or an unrecognised string) resolves to "cpu".
        """
        # Negative indices must never reach the f-string below: "cuda:-2" is not
        # a valid device string. Treat them as CPU, consistent with is_available().
        if not isinstance(device_id, int) or device_id < 0:
            return "cpu"

        if device_id < self._device_count():
            backend = "musa" if self.gpu_vendor == GPUVendor.MUSA else "cuda"
            return f"{backend}:{device_id}"

        # Index out of range for the detected vendor, or vendor unsupported.
        return "cpu"

    def to_torch_device(self, device_id: Union[int, str] = 0) -> torch.device:
        """
        Convert device ID to torch.device object

        Args:
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            torch.device object
        """
        device_str = self.get_device_str(device_id)

        # Handle MUSA device
        if device_str.startswith("musa:"):
            try:
                import musa

                index = int(device_str.split(":")[-1])
                # NOTE(review): the result is whatever `musa.device` returns;
                # confirm it is accepted wherever a torch.device is expected.
                return musa.device(index)
            except (ImportError, ValueError, AttributeError):
                return torch.device("cpu")

        # Standard PyTorch device
        return torch.device(device_str)

    def move_tensor_to_device(self, tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
        """
        Move tensor to specified device

        Args:
            tensor: PyTorch tensor to move
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            Tensor moved to the specified device
        """
        device = self.to_torch_device(device_id)
        return tensor.to(device)

    def is_available(self, index: int = 0) -> bool:
        """
        Check if device at specified index is available

        Args:
            index: Device index to check

        Returns:
            True if the device is available, False otherwise
        """
        if index < 0:
            return True  # CPU is always available

        return index in self.available_devices

    def get_all_devices(self) -> List[int]:
        """
        Get all available device indices

        Returns:
            List of available device indices (0, 1, 2, etc.), as cached at
            construction time.
        """
        return self.available_devices
|
|
|
|
# Create global device manager instance
|
|
# Shared module-level instance: constructing it runs vendor detection and device
# enumeration once, when this module is imported. The convenience functions
# below delegate to this instance.
device_manager = DeviceManager()
|
|
|
|
# Convenience functions
|
|
def get_device(device_id: Union[int, str] = 0) -> torch.device:
    """
    Resolve a device ID through the shared ``device_manager``.

    Args:
        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

    Returns:
        The device object produced by ``device_manager.to_torch_device``
    """
    resolved = device_manager.to_torch_device(device_id)
    return resolved
|
|
|
|
def to_device(tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
    """
    Place a tensor on the requested device via the shared ``device_manager``.

    Args:
        tensor: PyTorch tensor to move
        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

    Returns:
        The tensor returned by ``device_manager.move_tensor_to_device``
    """
    moved = device_manager.move_tensor_to_device(tensor, device_id)
    return moved
|
|
|
|
# Example usage of the convenience functions.
# NOTE(review): these statements are at module top level, so they execute on
# import and leave the names below defined in the module namespace -- consider
# whether they belong behind an `if __name__ == "__main__":` guard.

# Resolve devices
cpu_device = get_device(device_id=-1)  # CPU selected with the -1 index
cpu_device2 = get_device(device_id="cpu")  # CPU selected with the "cpu" string
gpu0 = get_device(device_id=0)  # First GPU

# Relocate a sample tensor
x = torch.randn(3, 3)
x_gpu = to_device(tensor=x, device_id=0)  # Move to first GPU
x_cpu1 = to_device(tensor=x, device_id=-1)  # Move to CPU using index -1
x_cpu2 = to_device(tensor=x, device_id="cpu")  # Move to CPU using string "cpu"