diff --git a/ktransformers/util/custom_loader.py b/ktransformers/util/custom_loader.py index edb92de..003f93c 100644 --- a/ktransformers/util/custom_loader.py +++ b/ktransformers/util/custom_loader.py @@ -446,7 +446,11 @@ class GGUFLoader(ModelLoader): blocks_begin = i * blocks_per_iter blocks_end = min(blocks_begin + blocks_per_iter, num_blocks) if "cuda" in device.lower(): - cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype) + try: + cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype) + except: + cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size]) + cur_values = torch.from_numpy(cur_values.copy()).to(device) else: cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size]) cur_values = torch.from_numpy(cur_values.copy()) diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py index 308def1..891d032 100644 --- a/ktransformers/util/utils.py +++ b/ktransformers/util/utils.py @@ -117,7 +117,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str load_dequantized_tensor = gguf_loader.load_gguf_tensor tensor_file_map = gguf_loader.tensor_file_map - if gguf_loader.has_tensor(translated_key): + if gguf_loader.has_tensor(translated_key) or "kv_b_proj" in translated_key: target_dtype = torch.get_default_dtype() device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map) print(f"loading {translated_key} to {device}") @@ -125,9 +125,18 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str torch.cuda.empty_cache() elif torch.xpu.is_available(): torch.xpu.empty_cache() - weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype) - set_param(module, name, weights) - del weights + if "kv_b_proj" in translated_key and not gguf_loader.has_tensor(translated_key): + attn_k_b = load_dequantized_tensor(translated_key.replace("self_attn.kv_b_proj", "attn_k_b"), device=device).to(dtype=target_dtype) + attn_k_b = attn_k_b.transpose(1, 2).contiguous() + attn_v_b = load_dequantized_tensor(translated_key.replace("self_attn.kv_b_proj", "attn_v_b"), device=device).to(dtype=target_dtype) + kv_b_proj = torch.cat((attn_k_b, attn_v_b), dim=1) + set_param(module, name, kv_b_proj) + del attn_k_b + del attn_v_b + else: + weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype) + set_param(module, name, weights) + del weights else: #print(load_config.tensor_file_map.keys()) raise Exception(f"can't find {translated_key} in GGUF file!")