fix bf16 load, TODO: refactor cpu dequant

This commit is contained in:
Atream 2025-02-23 15:37:09 +08:00 committed by GitHub
parent 94ab2de3b9
commit 036ae25a89
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -314,10 +314,12 @@ class GGUFLoader:
return values
def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = torch.get_default_dtype())->torch.Tensor:
def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)->torch.Tensor:
t = self.tensor_info[name]
if device.lower() == "cpu":
print(f"loading {name} with CPU")
if target_dtype == None:
target_dtype = torch.get_default_dtype()
shape = t["shape"]
ggml_type = t["ggml_type"]
@ -336,7 +338,7 @@ class GGUFLoader:
blocks_per_iter = 16384
if num_blocks > blocks_per_iter: # dequant large tensor
values = torch.empty((num_blocks, elements_per_block), dtype=torch.float, device=device)
values = torch.empty((num_blocks, elements_per_block), dtype=target_dtype, device=device)
for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter):
blocks_begin = i * blocks_per_iter
blocks_end = min(blocks_begin + blocks_per_iter, num_blocks)
@ -347,6 +349,8 @@ class GGUFLoader:
cur_values = torch.from_numpy(cur_values.copy())
cur_values = cur_values.view(-1, elements_per_block)
if ggml_name == "BF16":
cur_values = cur_values.view(torch.bfloat16)
values[blocks_begin : blocks_end] = cur_values
else:
if "cuda" in device.lower():