From b84524622ef0cbfefb412f4c86eb2b6ec032bfec Mon Sep 17 00:00:00 2001
From: Atream
Date: Sun, 16 Feb 2025 06:43:27 +0000
Subject: [PATCH] support bf16 read

---
 ktransformers/util/custom_gguf.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py
index 62059a7..eaa1a7d 100644
--- a/ktransformers/util/custom_gguf.py
+++ b/ktransformers/util/custom_gguf.py
@@ -109,6 +109,7 @@ GGML_TYPES = {
     "Q5_K": 13,
     "Q6_K": 14,
     "IQ4_XS": 23,
+    "BF16": 30,
 }
 
 GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()}
@@ -116,6 +117,7 @@ GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()}
 GGML_BLOCK_SIZES = {
     "F32": 4,
     "F16": 2,
+    "BF16": 2,
     "Q4_0": 2 + 16,
     "Q5_0": 2 + 4 + 16,
     "Q8_0": 2 + 32,
@@ -130,6 +132,7 @@ GGML_BLOCK_SIZES = {
 GGML_ELEMENTS_PER_BLOCK = {
     "F32": 1,
     "F16": 1,
+    "BF16": 1,
     "Q4_0": 32,
     "Q5_0": 32,
     "Q8_0": 32,
@@ -333,6 +336,8 @@ class GGUFLoader:
         else:
             values = GGML_DEQUANTIZE[ggml_name](data)
             values = torch.from_numpy(values)
+        if ggml_name == "BF16":
+            values = values.view(torch.bfloat16)
         values = values.view(shape[::-1])
         if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
             n_head = self.gguf_file_meta['llama.attention.head_count']
@@ -764,6 +769,7 @@ def dequantize_f16_gpu(data, device):
 GGML_DEQUANTIZE = {
     "F32": dequantize_f32,
     "F16": dequantize_f16,
+    "BF16": dequantize_f16,
     "Q4_0": dequantize_q4_0,
     "Q5_0": dequantize_q5_0,
     "Q8_0": dequantize_q8_0,
@@ -778,6 +784,7 @@ GGML_DEQUANTIZE = {
 GGML_DEQUANTIZE_GPU = {
     "F32": dequantize_f32_gpu,
     "F16": dequantize_f16_gpu,
+    "BF16": dequantize_f16_gpu,
     "Q4_0": dequantize_q4_0_gpu,
     "Q5_0": dequantize_q5_0_gpu,
     "Q8_0": dequantize_q8_0_gpu,
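
Why the patch can reuse the F16 readers: GGUF stores BF16 like F16, two bytes per scalar with one element per "block", so dequantize_f16/dequantize_f16_gpu already pull the right bytes. numpy has no bfloat16 dtype, however, so the bits come back labeled float16, and the loader fixes that up afterwards with a bitwise torch .view(torch.bfloat16). A minimal, self-contained sketch of that read path follows; the table subset and the dequantize_f16 body are simplified stand-ins for the real ones in custom_gguf.py, not copies of them:

    import numpy as np
    import torch

    # For unquantized types, bytes-per-block equals bytes-per-scalar and each
    # "block" holds a single element (illustrative subset of the patched tables).
    GGML_BLOCK_SIZES = {"F16": 2, "BF16": 2}
    GGML_ELEMENTS_PER_BLOCK = {"F16": 1, "BF16": 1}

    def dequantize_f16(data: bytes) -> np.ndarray:
        # numpy cannot label the bits as bfloat16, so both F16 and BF16
        # payloads are read as float16; BF16 is corrected on the torch side.
        return np.frombuffer(data, dtype=np.float16)

    def read_tensor(data: bytes, ggml_name: str, shape: tuple) -> torch.Tensor:
        n_elems = int(np.prod(shape))
        n_bytes = (n_elems * GGML_BLOCK_SIZES[ggml_name]
                   // GGML_ELEMENTS_PER_BLOCK[ggml_name])
        values = torch.from_numpy(dequantize_f16(data[:n_bytes]).copy())
        if ggml_name == "BF16":
            values = values.view(torch.bfloat16)  # bitwise reinterpret, no conversion
        return values.view(shape[::-1])  # GGUF lists dims in reverse order

    # Round trip: serialize a bfloat16 tensor and read it back unchanged.
    t = torch.randn(2, 3, dtype=torch.bfloat16)
    blob = t.view(torch.float16).numpy().tobytes()
    assert torch.equal(read_tensor(blob, "BF16", (3, 2)), t)

The round trip at the end checks that a bfloat16 tensor pushed through the float16 read path comes back bit-identical: the two .view(dtype) calls only relabel the same 16-bit patterns, so no numeric conversion ever happens.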