mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-05 20:19:51 +00:00
Merge pull request #612 from kvcache-ai/fix-bf16-load
fix bf16 load, TODO: refactor cpu dequant
This commit is contained in:
commit
cdb6f896bb
1 changed file with 6 additions and 2 deletions
|
@ -314,10 +314,12 @@ class GGUFLoader:
|
||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = torch.get_default_dtype())->torch.Tensor:
|
def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)->torch.Tensor:
|
||||||
t = self.tensor_info[name]
|
t = self.tensor_info[name]
|
||||||
if device.lower() == "cpu":
|
if device.lower() == "cpu":
|
||||||
print(f"loading {name} with CPU")
|
print(f"loading {name} with CPU")
|
||||||
|
if target_dtype == None:
|
||||||
|
target_dtype = torch.get_default_dtype()
|
||||||
|
|
||||||
shape = t["shape"]
|
shape = t["shape"]
|
||||||
ggml_type = t["ggml_type"]
|
ggml_type = t["ggml_type"]
|
||||||
|
@ -336,7 +338,7 @@ class GGUFLoader:
|
||||||
|
|
||||||
blocks_per_iter = 16384
|
blocks_per_iter = 16384
|
||||||
if num_blocks > blocks_per_iter: # dequant large tensor
|
if num_blocks > blocks_per_iter: # dequant large tensor
|
||||||
values = torch.empty((num_blocks, elements_per_block), dtype=torch.float, device=device)
|
values = torch.empty((num_blocks, elements_per_block), dtype=target_dtype, device=device)
|
||||||
for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter):
|
for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter):
|
||||||
blocks_begin = i * blocks_per_iter
|
blocks_begin = i * blocks_per_iter
|
||||||
blocks_end = min(blocks_begin + blocks_per_iter, num_blocks)
|
blocks_end = min(blocks_begin + blocks_per_iter, num_blocks)
|
||||||
|
@ -347,6 +349,8 @@ class GGUFLoader:
|
||||||
cur_values = torch.from_numpy(cur_values.copy())
|
cur_values = torch.from_numpy(cur_values.copy())
|
||||||
|
|
||||||
cur_values = cur_values.view(-1, elements_per_block)
|
cur_values = cur_values.view(-1, elements_per_block)
|
||||||
|
if ggml_name == "BF16":
|
||||||
|
cur_values = cur_values.view(torch.bfloat16)
|
||||||
values[blocks_begin : blocks_end] = cur_values
|
values[blocks_begin : blocks_end] = cur_values
|
||||||
else:
|
else:
|
||||||
if "cuda" in device.lower():
|
if "cuda" in device.lower():
|
||||||
|
|
Loading…
Add table
Reference in a new issue