mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 20:00:06 +00:00
support qwen3.5 (#1846)
Some checks failed
Book-CI / test (push) Has been cancelled
Book-CI / test-1 (push) Has been cancelled
Book-CI / test-2 (push) Has been cancelled
Deploy / deploy (macos-latest) (push) Has been cancelled
Deploy / deploy (ubuntu-latest) (push) Has been cancelled
Deploy / deploy (windows-latest) (push) Has been cancelled
Some checks failed
Book-CI / test (push) Has been cancelled
Book-CI / test-1 (push) Has been cancelled
Book-CI / test-2 (push) Has been cancelled
Deploy / deploy (macos-latest) (push) Has been cancelled
Deploy / deploy (ubuntu-latest) (push) Has been cancelled
Deploy / deploy (windows-latest) (push) Has been cancelled
This commit is contained in:
parent
411b69bec0
commit
16a8b98f3e
2 changed files with 218 additions and 0 deletions
|
|
@ -440,6 +440,13 @@ class BF16SafeTensorLoader(SafeTensorLoader):
|
|||
"""Auto-detect the MoE naming format by checking tensor keys."""
|
||||
sample_keys = list(self.tensor_file_map.keys())[:1000]
|
||||
|
||||
# Check for packed format first (Qwen3.5 MoE style: all experts in one 3D tensor)
|
||||
for key in sample_keys:
|
||||
if key.endswith(".mlp.experts.gate_up_proj"):
|
||||
self._detected_format = "packed"
|
||||
print("[BF16SafeTensorLoader] Detected format: packed (Qwen3.5 MoE style)")
|
||||
return
|
||||
|
||||
for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items():
|
||||
for key in sample_keys:
|
||||
if ".experts." in key and f".{gate}.weight" in key:
|
||||
|
|
@ -479,6 +486,9 @@ class BF16SafeTensorLoader(SafeTensorLoader):
|
|||
|
||||
def load_experts(self, base_key: str, device: str = "cpu"):
|
||||
"""Load BF16 expert weights (no scales needed)."""
|
||||
if self._detected_format == "packed":
|
||||
return self._load_experts_packed(base_key, device)
|
||||
|
||||
experts_prefix = self._get_experts_prefix(base_key)
|
||||
gate_name, up_name, down_name = self._get_proj_names()
|
||||
|
||||
|
|
@ -533,6 +543,13 @@ class BF16SafeTensorLoader(SafeTensorLoader):
|
|||
"""Auto-detect the MoE naming format by checking tensor keys."""
|
||||
sample_keys = list(self.tensor_file_map.keys())[:1000]
|
||||
|
||||
# Check for packed format first (Qwen3.5 MoE style: all experts in one 3D tensor)
|
||||
for key in sample_keys:
|
||||
if key.endswith(".mlp.experts.gate_up_proj"):
|
||||
self._detected_format = "packed"
|
||||
print("[BF16SafeTensorLoader] Detected format: packed (Qwen3.5 MoE style)")
|
||||
return
|
||||
|
||||
for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items():
|
||||
for key in sample_keys:
|
||||
if ".experts." in key and f".{gate}.weight" in key:
|
||||
|
|
@ -572,6 +589,9 @@ class BF16SafeTensorLoader(SafeTensorLoader):
|
|||
|
||||
def load_experts(self, base_key: str, device: str = "cpu"):
|
||||
"""Load BF16 expert weights (no scales needed)."""
|
||||
if self._detected_format == "packed":
|
||||
return self._load_experts_packed(base_key, device)
|
||||
|
||||
experts_prefix = self._get_experts_prefix(base_key)
|
||||
gate_name, up_name, down_name = self._get_proj_names()
|
||||
|
||||
|
|
@ -601,6 +621,49 @@ class BF16SafeTensorLoader(SafeTensorLoader):
|
|||
"down": down_weights,
|
||||
}
|
||||
|
||||
def _resolve_packed_experts_prefix(self, base_key: str) -> str:
|
||||
"""Resolve the experts prefix for packed format, trying fallbacks."""
|
||||
# Direct: model.layers.{N}.mlp.experts
|
||||
experts_prefix = f"{base_key}.mlp.experts"
|
||||
if self.has_tensor(f"{experts_prefix}.gate_up_proj"):
|
||||
return experts_prefix
|
||||
|
||||
# VL models: model.layers.{N} -> model.language_model.layers.{N}
|
||||
parts = base_key.split(".", 1)
|
||||
if len(parts) == 2:
|
||||
alt_base = f"{parts[0]}.language_model.{parts[1]}"
|
||||
experts_prefix = f"{alt_base}.mlp.experts"
|
||||
if self.has_tensor(f"{experts_prefix}.gate_up_proj"):
|
||||
return experts_prefix
|
||||
|
||||
raise ValueError(f"No packed experts found for base_key '{base_key}'.")
|
||||
|
||||
def _load_experts_packed(self, base_key: str, device: str = "cpu"):
|
||||
"""Load packed expert weights (Qwen3.5 MoE style).
|
||||
|
||||
Packed format stores all experts in stacked 3D tensors:
|
||||
- gate_up_proj: [num_experts, 2 * intermediate_size, hidden_size]
|
||||
- down_proj: [num_experts, hidden_size, intermediate_size]
|
||||
"""
|
||||
experts_prefix = self._resolve_packed_experts_prefix(base_key)
|
||||
|
||||
gate_up_key = f"{experts_prefix}.gate_up_proj"
|
||||
down_key = f"{experts_prefix}.down_proj"
|
||||
|
||||
gate_up = self.load_tensor(gate_up_key, device) # [E, 2*I, H]
|
||||
down = self.load_tensor(down_key, device) # [E, H, I]
|
||||
|
||||
mid = gate_up.shape[1] // 2
|
||||
gate_list = [gate_up[i, :mid, :].contiguous() for i in range(gate_up.shape[0])]
|
||||
up_list = [gate_up[i, mid:, :].contiguous() for i in range(gate_up.shape[0])]
|
||||
down_list = [down[i].contiguous() for i in range(down.shape[0])]
|
||||
|
||||
return {
|
||||
"gate": gate_list,
|
||||
"up": up_list,
|
||||
"down": down_list,
|
||||
}
|
||||
|
||||
|
||||
class CompressedSafeTensorLoader(SafeTensorLoader):
|
||||
"""Loader for compressed SafeTensor layouts (RAWINT4 weights)."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue