Fix kv_b_proj shape for unsloth quantized models

This commit is contained in:
Ye Zhou 2025-06-05 17:33:11 +08:00
parent 7071970339
commit 255c0fcf3b

View file

@ -130,6 +130,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
# NOTE(review): interior of load_cur_state_dict (function header is outside this
# view). Rebuilds the combined kv_b_proj weight from the two separately stored
# tensors attn_k_b / attn_v_b, which some quantized checkpoints (e.g. unsloth
# GGUF exports) ship split apart.
# Swap dims 1 and 2 of attn_k_b so its layout lines up with attn_v_b for the
# concatenation below; .contiguous() materializes the transposed memory layout.
attn_k_b = attn_k_b.transpose(1, 2).contiguous()
# Load the matching value-projection half by rewriting the checkpoint key
# (".self_attn.kv_b_proj" -> ".attn_v_b") and dequantizing on the target device.
attn_v_b = load_dequantized_tensor(translated_key.replace("self_attn.kv_b_proj", "attn_v_b"), device=device).to(dtype=target_dtype)
# Fuse K and V halves along dim 1 to reconstruct the single kv_b_proj weight.
kv_b_proj = torch.cat((attn_k_b, attn_v_b), dim=1)
# A 2-D result is already in the expected (out, in) weight layout; a higher-rank
# result (per-head leading dims — presumably (heads, per_head_out, in); TODO
# confirm against the checkpoint format) is collapsed over dims 0 and 1 first.
kv_b_proj = kv_b_proj.contiguous() if kv_b_proj.ndim == 2 else kv_b_proj.flatten(0, 1).contiguous()
# Install the reassembled weight on the module under the original param name.
set_param(module, name, kv_b_proj)
# Drop the dequantized intermediates promptly to cut peak memory while loading.
del attn_k_b
del attn_v_b