model: add Ernie 4.5 MoE support (#14658)

* Add Ernie4.5 MoE

* Fix Flake errors.

* Properly encode/decode MoE layer step

* Correct tensor mappings (.weight)

* Pass and read n_ff_exp

* n_ff_shexp calculation and further minor changes

* Rope fixes.

* .gitignore fix

* Add unit32 cast for Linux builds

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Further fixes from code review

* Fix trailing whitespace

* Reenable missing experts error

* Code style from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Fix non-MoE regression

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
Piotr Wilkin (ilintar) 2025-07-17 23:15:32 +02:00 committed by GitHub
parent d6fb3f6b49
commit cb887f1bc1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 373 additions and 26 deletions

View file

@ -2861,7 +2861,8 @@ class Ernie4_5Model(TextModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
num_heads = self.hparams["num_attention_heads"]
num_kv_heads = self.hparams["num_key_value_heads"]
head_dim = self.hparams["head_dim"]
if (head_dim := self.hparams.get("head_dim")) is None:
head_dim = self.hparams["hidden_size"] // num_heads
if "ernie." in name:
name = name.replace("ernie.", "model.")
@ -2894,6 +2895,92 @@ class Ernie4_5Model(TextModel):
return [(self.map_tensor_name(name), data_torch)]
@ModelBase.register("Ernie4_5_MoeForCausalLM")
class Ernie4_5MoeModel(Ernie4_5Model):
model_arch = gguf.MODEL_ARCH.ERNIE4_5_MOE
_experts: list[dict[str, Tensor]] | None = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._experts = [{} for _ in range(self.block_count)]
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_expert_count(self.hparams["moe_num_experts"])
self.gguf_writer.add_expert_used_count(self.hparams["moe_k"])
self.gguf_writer.add_interleave_moe_layer_step(self.hparams["moe_layer_interval"])
self.gguf_writer.add_leading_dense_block_count(self.hparams["moe_layer_start_index"])
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
if (shared_expert_intermediate_size := self.hparams.get('intermediate_size')) is not None and (num_key_value_heads := self.hparams.get('num_key_value_heads')) is not None:
self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size // num_key_value_heads)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Modify correction bias name as in DeepseekV2
if name.endswith("e_score_correction_bias"):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
# skip Multi-Token Prediction (MTP) layers (again, same as DeepseekV2)
match = re.match(r"model.mtp_block.(\d+)", name)
if match:
return []
# skip all other MTP tensors for now
match = re.match(r"model.mtp_emb_norm.(\d+)", name)
if match:
return []
match = re.match(r"model.mtp_hidden_norm.(\d+)", name)
if match:
return []
match = re.match(r"model.mtp_linear_proj.(\d+)", name)
if match:
return []
# process the experts separately
if name.find("mlp.experts") != -1:
n_experts = self.hparams["moe_num_experts"]
assert bid is not None
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []
# merge the experts into a single 3d tensor
for w_name in ["gate_proj", "up_proj", "down_proj"]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename_to_retrieve])
del self._experts[bid][ename_to_retrieve]
data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))
return tensors
else:
return []
return [(self.map_tensor_name(name), data_torch)]
def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register(
"Qwen2VLModel",
"Qwen2VLForConditionalGeneration",