From c70b6f4d5b4021267daeb2acbae8ac36d6de2260 Mon Sep 17 00:00:00 2001 From: ceerrep Date: Tue, 18 Feb 2025 11:15:17 +0800 Subject: [PATCH] fix: use 'cuda:0' by default if torch_device is 'cuda' --- ktransformers/server/backend/interfaces/ktransformers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py index 86b97d7..7ecc637 100644 --- a/ktransformers/server/backend/interfaces/ktransformers.py +++ b/ktransformers/server/backend/interfaces/ktransformers.py @@ -130,6 +130,7 @@ class KTransformersInterface(TransformersInterface): logger.debug(f"input_ids: {input_ids.shape}") device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0") + device = "cuda:0" if device == "cuda" else device if is_new: self.ever_generated_ids.clear()