diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py index 3f5ad8e..87bbd2b 100644 --- a/ktransformers/util/utils.py +++ b/ktransformers/util/utils.py @@ -92,8 +92,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str target_dtype = torch.get_default_dtype() device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map) print(f"loading {translated_key} to {device}") - torch.cuda.empty_cache() # To fit in 16G VRAM. By "wkGCaSS - 知乎 https://zhuanlan.zhihu.com/p/25491611225" - # weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype) + torch.cuda.empty_cache() weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype) set_param(module, name, weights) del weights