fix: use 'cuda:0' by default if torch_device is 'cuda'

This commit is contained in:
ceerrep 2025-02-18 11:15:17 +08:00
parent ee24eb8dc3
commit c70b6f4d5b

View file

@ -130,6 +130,7 @@ class KTransformersInterface(TransformersInterface):
logger.debug(f"input_ids: {input_ids.shape}")
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
device = "cuda:0" if device == "cuda" else device
if is_new:
self.ever_generated_ids.clear()