Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-09 05:54:06 +00:00
Merge pull request #626 from cyhasuka/main
Feat: Clear cache during weight loading to prevent OOM on GPUs with <=8GB VRAM
This commit is contained in: commit 1d5d5faef6
1 changed file with 1 addition and 2 deletions
@@ -92,8 +92,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
             target_dtype = torch.get_default_dtype()
             device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
             print(f"loading {translated_key} to {device}")
-            torch.cuda.empty_cache() # To fit in 16G VRAM. By "wkGCaSS - 知乎 https://zhuanlan.zhihu.com/p/25491611225"
-            # weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
+            torch.cuda.empty_cache()
             weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
             set_param(module, name, weights)
             del weights
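For context, below is a minimal, self-contained sketch of the per-tensor loading pattern this diff touches: clear the CUDA cache before each transfer, load one dequantized tensor at a time, and drop the staging copy once it has been written into the module. `StubLoader` and `load_module_weights` are hypothetical stand-ins for ktransformers' GGUFLoader / load_cur_state_dict / set_param; only the torch calls are real APIs.

# Minimal sketch of the per-tensor loading pattern shown in the diff.
# StubLoader and load_module_weights are hypothetical stand-ins for
# ktransformers' GGUFLoader / load_cur_state_dict; only torch APIs are real.
import torch
import torch.nn as nn


class StubLoader:
    """Hypothetical loader: returns dequantized CPU tensors by key."""

    def __init__(self, tensors: dict[str, torch.Tensor]):
        self.tensors = tensors

    def load_dequantized_tensor(self, key: str, device: str) -> torch.Tensor:
        # Real GGUF loading would dequantize here; this stub just moves a CPU copy.
        return self.tensors[key].to(device)


def load_module_weights(module: nn.Module, loader: StubLoader, device: str = "cuda"):
    target_dtype = torch.get_default_dtype()
    for name, _ in module.named_parameters():
        print(f"loading {name} to {device}")
        # Release cached allocator blocks before each transfer so the staging
        # copy fits on small-VRAM GPUs (the point of this PR).
        torch.cuda.empty_cache()
        weights = loader.load_dequantized_tensor(name, device=device).to(dtype=target_dtype)
        # Copy into the module parameter, then drop the staging tensor.
        with torch.no_grad():
            module.get_parameter(name).copy_(weights)
        del weights


if __name__ == "__main__" and torch.cuda.is_available():
    layer = nn.Linear(16, 16).cuda()
    cpu_weights = {n: torch.randn_like(p, device="cpu") for n, p in layer.named_parameters()}
    load_module_weights(layer, StubLoader(cpu_weights))

Calling torch.cuda.empty_cache() per tensor returns cached-but-unused blocks to the driver before the next staging copy lands on the GPU, which is what keeps peak VRAM low on small cards; the trade-off is slower loading, since the allocator cache is rebuilt on every iteration.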