Merge pull request #612 from kvcache-ai/fix-bf16-load

fix bf16 load, TODO: refactor cpu dequant
This commit is contained in:
Atream 2025-02-23 15:37:23 +08:00 committed by GitHub
commit cdb6f896bb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -314,10 +314,12 @@ class GGUFLoader:
return values return values
def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = torch.get_default_dtype())->torch.Tensor: def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)->torch.Tensor:
t = self.tensor_info[name] t = self.tensor_info[name]
if device.lower() == "cpu": if device.lower() == "cpu":
print(f"loading {name} with CPU") print(f"loading {name} with CPU")
if target_dtype == None:
target_dtype = torch.get_default_dtype()
shape = t["shape"] shape = t["shape"]
ggml_type = t["ggml_type"] ggml_type = t["ggml_type"]
@@ -336,7 +338,7 @@ class GGUFLoader:
blocks_per_iter = 16384 blocks_per_iter = 16384
if num_blocks > blocks_per_iter: # dequant large tensor if num_blocks > blocks_per_iter: # dequant large tensor
values = torch.empty((num_blocks, elements_per_block), dtype=torch.float, device=device) values = torch.empty((num_blocks, elements_per_block), dtype=target_dtype, device=device)
for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter): for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter):
blocks_begin = i * blocks_per_iter blocks_begin = i * blocks_per_iter
blocks_end = min(blocks_begin + blocks_per_iter, num_blocks) blocks_end = min(blocks_begin + blocks_per_iter, num_blocks)
@@ -347,6 +349,8 @@ class GGUFLoader:
cur_values = torch.from_numpy(cur_values.copy()) cur_values = torch.from_numpy(cur_values.copy())
cur_values = cur_values.view(-1, elements_per_block) cur_values = cur_values.view(-1, elements_per_block)
if ggml_name == "BF16":
cur_values = cur_values.view(torch.bfloat16)
values[blocks_begin : blocks_end] = cur_values values[blocks_begin : blocks_end] = cur_values
else: else:
if "cuda" in device.lower(): if "cuda" in device.lower():