diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py index 0c42c9c..919f432 100644 --- a/ktransformers/util/custom_gguf.py +++ b/ktransformers/util/custom_gguf.py @@ -314,10 +314,12 @@ class GGUFLoader: return values - def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = torch.get_default_dtype())->torch.Tensor: + def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)->torch.Tensor: t = self.tensor_info[name] if device.lower() == "cpu": print(f"loading {name} with CPU") + if target_dtype == None: + target_dtype = torch.get_default_dtype() shape = t["shape"] ggml_type = t["ggml_type"] @@ -336,7 +338,7 @@ class GGUFLoader: blocks_per_iter = 16384 if num_blocks > blocks_per_iter: # dequant large tensor - values = torch.empty((num_blocks, elements_per_block), dtype=torch.float, device=device) + values = torch.empty((num_blocks, elements_per_block), dtype=target_dtype, device=device) for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter): blocks_begin = i * blocks_per_iter blocks_end = min(blocks_begin + blocks_per_iter, num_blocks) @@ -347,6 +349,8 @@ class GGUFLoader: cur_values = torch.from_numpy(cur_values.copy()) cur_values = cur_values.view(-1, elements_per_block) + if ggml_name == "BF16": + cur_values = cur_values.view(torch.bfloat16) values[blocks_begin : blocks_end] = cur_values else: if "cuda" in device.lower():