[Patch] lload DeepSeek-R1-0528

This commit is contained in:
qiyuxinlin 2025-05-31 14:19:20 +00:00
parent ac48a58cca
commit a6b3243a56
2 changed files with 18 additions and 5 deletions

View file

@ -446,7 +446,11 @@ class GGUFLoader(ModelLoader):
blocks_begin = i * blocks_per_iter
blocks_end = min(blocks_begin + blocks_per_iter, num_blocks)
if "cuda" in device.lower():
cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype)
try:
cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype)
except:
cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])
cur_values = torch.from_numpy(cur_values.copy()).to(device)
else:
cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])
cur_values = torch.from_numpy(cur_values.copy())

View file

@ -117,7 +117,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
load_dequantized_tensor = gguf_loader.load_gguf_tensor
tensor_file_map = gguf_loader.tensor_file_map
if gguf_loader.has_tensor(translated_key):
if gguf_loader.has_tensor(translated_key) or "kv_b_proj" in translated_key:
target_dtype = torch.get_default_dtype()
device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
print(f"loading {translated_key} to {device}")
@ -125,9 +125,18 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
torch.cuda.empty_cache()
elif torch.xpu.is_available():
torch.xpu.empty_cache()
weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
set_param(module, name, weights)
del weights
if "kv_b_proj" in translated_key and not gguf_loader.has_tensor(translated_key):
attn_k_b = load_dequantized_tensor(translated_key.replace("self_attn.kv_b_proj", "attn_k_b"), device=device).to(dtype=target_dtype)
attn_k_b = attn_k_b.transpose(1, 2).contiguous()
attn_v_b = load_dequantized_tensor(translated_key.replace("self_attn.kv_b_proj", "attn_v_b"), device=device).to(dtype=target_dtype)
kv_b_proj = torch.cat((attn_k_b, attn_v_b), dim=1)
set_param(module, name, kv_b_proj)
del attn_k_b
del attn_v_b
else:
weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
set_param(module, name, weights)
del weights
else:
#print(load_config.tensor_file_map.keys())
raise Exception(f"can't find {translated_key} in GGUF file!")