diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py index 86b97d7..7ecc637 100644 --- a/ktransformers/server/backend/interfaces/ktransformers.py +++ b/ktransformers/server/backend/interfaces/ktransformers.py @@ -130,6 +130,7 @@ class KTransformersInterface(TransformersInterface): logger.debug(f"input_ids: {input_ids.shape}") device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0") + device = "cuda:0" if device == "cuda" else device if is_new: self.ever_generated_ids.clear()