Add data loader to read special weights for fp8; Add special weight process script

This commit is contained in:
Azure 2025-02-24 11:16:23 +00:00
parent 7b7c6a657d
commit 581a524f65
10 changed files with 481 additions and 26 deletions

View file

@ -66,12 +66,23 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
for name, param in local_state.items():
key = prefix + name
translated_key = translate_name_to_gguf(key)
if translated_key in gguf_loader.tensor_file_map:
# TODO: Merge all loader.
# I know this is ugly but lets do it for now.
if gguf_loader.safetensor_loader is not None:
load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor
tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map
else:
load_dequantized_tensor = gguf_loader.load_gguf_tensor
tensor_file_map = gguf_loader.tensor_file_map
if translated_key in tensor_file_map:
target_dtype = torch.get_default_dtype()
device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
print(f"loading {translated_key} to {device}")
torch.cuda.empty_cache() # To fit in 16G VRAM. By "wkGCaSS - 知乎 https://zhuanlan.zhihu.com/p/25491611225"
weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
# weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
set_param(module, name, weights)
del weights
else: