mirror of https://github.com/kvcache-ai/ktransformers.git
[Patch] load DeepSeek-R1-0528
This commit is contained in:
parent ac48a58cca
commit a6b3243a56
2 changed files with 18 additions and 5 deletions
@@ -446,7 +446,11 @@ class GGUFLoader(ModelLoader):
             blocks_begin = i * blocks_per_iter
             blocks_end = min(blocks_begin + blocks_per_iter, num_blocks)
             if "cuda" in device.lower():
-                cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype)
+                try:
+                    cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype)
+                except:
+                    cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])
+                    cur_values = torch.from_numpy(cur_values.copy()).to(device)
             else:
                 cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])
                 cur_values = torch.from_numpy(cur_values.copy())
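This hunk guards the GPU dequantize path: if a quantization type has no working CUDA kernel for the new DeepSeek-R1-0528 tensors, the loader now falls back to the CPU dequantizer and copies the result to the target device afterwards. A minimal sketch of the pattern, with hypothetical gpu_kernels/cpu_kernels dicts standing in for GGML_DEQUANTIZE_GPU and GGML_DEQUANTIZE:

    import torch

    def dequantize_chunk(ggml_name, raw, device, target_dtype,
                         gpu_kernels, cpu_kernels):
        # Hypothetical helper mirroring the patched loop body: prefer the
        # GPU kernel, fall back to CPU dequantization plus a device copy.
        if "cuda" in device.lower():
            try:
                return gpu_kernels[ggml_name](raw, device, target_dtype)
            except Exception:  # the patch uses a bare except
                values = cpu_kernels[ggml_name](raw)  # NumPy array on the host
                return torch.from_numpy(values.copy()).to(device)
        values = cpu_kernels[ggml_name](raw)
        return torch.from_numpy(values.copy())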
@@ -117,7 +117,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
         load_dequantized_tensor = gguf_loader.load_gguf_tensor
         tensor_file_map = gguf_loader.tensor_file_map

-        if gguf_loader.has_tensor(translated_key):
+        if gguf_loader.has_tensor(translated_key) or "kv_b_proj" in translated_key:
             target_dtype = torch.get_default_dtype()
             device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
             print(f"loading {translated_key} to {device}")
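The widened condition matters because DeepSeek-R1-0528 GGUF conversions can ship the MLA projection already split into attn_k_b and attn_v_b tensors instead of a fused kv_b_proj, so the loader must enter this branch even when kv_b_proj itself is absent from the file. A small hypothetical helper showing how the split keys relate to the fused one (the same .replace() calls the next hunk performs):

    def split_mla_keys(kv_key):
        # Derive the split-tensor names from the fused projection's key.
        k_key = kv_key.replace("self_attn.kv_b_proj", "attn_k_b")
        v_key = kv_key.replace("self_attn.kv_b_proj", "attn_v_b")
        return k_key, v_key

    # e.g. "model.layers.0.self_attn.kv_b_proj.weight"
    # ->  ("model.layers.0.attn_k_b.weight", "model.layers.0.attn_v_b.weight")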
@@ -125,9 +125,18 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
                 torch.cuda.empty_cache()
             elif torch.xpu.is_available():
                 torch.xpu.empty_cache()
-            weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
-            set_param(module, name, weights)
-            del weights
+            if "kv_b_proj" in translated_key and not gguf_loader.has_tensor(translated_key):
+                attn_k_b = load_dequantized_tensor(translated_key.replace("self_attn.kv_b_proj", "attn_k_b"), device=device).to(dtype=target_dtype)
+                attn_k_b = attn_k_b.transpose(1, 2).contiguous()
+                attn_v_b = load_dequantized_tensor(translated_key.replace("self_attn.kv_b_proj", "attn_v_b"), device=device).to(dtype=target_dtype)
+                kv_b_proj = torch.cat((attn_k_b, attn_v_b), dim=1)
+                set_param(module, name, kv_b_proj)
+                del attn_k_b
+                del attn_v_b
+            else:
+                weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
+                set_param(module, name, weights)
+                del weights
         else:
             #print(load_config.tensor_file_map.keys())
             raise Exception(f"can't find {translated_key} in GGUF file!")
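When the fused tensor is missing, this hunk rebuilds it from the split pair: attn_k_b is stored transposed relative to attn_v_b, so it is transposed back before the two are concatenated along the per-head output dimension. A shape-level sketch with made-up DeepSeek-like sizes (the dimensions and their semantics here are assumptions; only the transpose-then-cat pattern comes from the patch):

    import torch

    num_heads, kv_lora_rank = 8, 512          # assumed sizes, for illustration
    qk_nope_head_dim, v_head_dim = 128, 128   # assumed sizes, for illustration

    attn_k_b = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim)
    attn_v_b = torch.randn(num_heads, v_head_dim, kv_lora_rank)

    attn_k_b = attn_k_b.transpose(1, 2).contiguous()    # (heads, qk_nope, rank)
    kv_b_proj = torch.cat((attn_k_b, attn_v_b), dim=1)  # (heads, qk_nope + v, rank)
    assert kv_b_proj.shape == (num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)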