support glm4moe

This commit is contained in:
djw 2025-07-25 17:22:20 +00:00
parent 1677e90092
commit d03d92ba53
31 changed files with 2265 additions and 74 deletions

View file

@ -194,6 +194,7 @@ class KExpertsCPU(KExpertsBase):
64,
10,
1024,
self.config.hidden_act == 'silu',
gate_ptr,
up_ptr,
down_ptr,
@ -214,6 +215,7 @@ class KExpertsCPU(KExpertsBase):
self.config.hidden_size,
self.config.moe_intermediate_size,
max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,
self.config.hidden_act == 'silu',
gate_ptr,
up_ptr,
down_ptr,
@ -232,6 +234,7 @@ class KExpertsCPU(KExpertsBase):
self.config.hidden_size,
self.config.moe_intermediate_size,
max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,
self.config.hidden_act == 'silu',
gate_ptr,
up_ptr,
down_ptr,
@ -729,6 +732,8 @@ from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
from ktransformers.models.modeling_smallthinker import SmallthinkerMoeBlock
from ktransformers.models.modeling_glm4_moe import Glm4MoeMoE
class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
@ -1248,6 +1253,12 @@ class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
if prefill_op == 'None':
prefill_op = None
if generate_op == 'None':
generate_op = None
if generate_op is not None:
self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
else:
@ -1307,6 +1318,152 @@ class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
else:
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
class KSmallthinkerExperts(BaseInjectedModule, KExpertsBase):
def __init__(self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
prefill_device:str = "cuda",
prefill_op: str | None = "KExpertsTorch",
generate_device: str = "cpu",
generate_op: str | None = "KExpertsCPU",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
if generate_op is not None:
self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
else:
self.generate_experts = None
if prefill_op is not None:
self.prefill_experts = None
self.gpu_mlp_type = prefill_op
self.cpu_mlp_type = generate_op
self.mode = InferenceState.UNLOAD
def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True):
# TODO support w as input
if not mode: mode = InferenceState.GENERATE
if mode == InferenceState.GENERATE:
# self.prefill_experts.unload()
self.generate_experts.load(w, warmup=warmup)
self.device = self.generate_experts.device
self.mode = mode
elif mode == InferenceState.PREFILL:
self.generate_experts.unload()
self.prefill_experts.load(w, warmup=warmup)
self.device = self.prefill_experts.device
self.mode = mode
elif mode == InferenceState.UNLOAD:
self.unload()
self.mode = mode
self.device = self.generate_experts.device
else:
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
def unload(self):
if self.generate_experts is not None:
self.generate_experts.unload()
if self.prefill_experts is not None:
self.prefill_experts.unload()
self.device = self.generate_experts.device
def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):
if self.mode == InferenceState.GENERATE:
assert self.generate_experts is not None, "generate_experts is None"
return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
elif self.mode == InferenceState.PREFILL:
assert self.prefill_experts is not None, "prefill_experts is None"
return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
else:
raise ValueError("load or set_inference_mode before forward")
def set_inference_mode(self, mode: InferenceState):
if mode == InferenceState.GENERATE:
self.load(mode=InferenceState.GENERATE, warmup=False)
elif mode == InferenceState.PREFILL:
self.load(mode=InferenceState.PREFILL, warmup=False)
elif mode == InferenceState.UNLOAD:
self.unload()
else:
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
class KGlm4Experts(BaseInjectedModule, KExpertsBase):
def __init__(self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
prefill_device:str = "cuda",
prefill_op: str | None = "KExpertsTorch",
generate_device: str = "cpu",
generate_op: str | None = "KExpertsCPU",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
if generate_op is not None:
self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
else:
self.generate_experts = None
if prefill_op is not None:
self.prefill_experts = None
self.gpu_mlp_type = prefill_op
self.cpu_mlp_type = generate_op
self.mode = InferenceState.UNLOAD
def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True):
# TODO support w as input
if not mode: mode = InferenceState.GENERATE
if mode == InferenceState.GENERATE:
# self.prefill_experts.unload()
self.generate_experts.load(w, warmup=warmup)
self.device = self.generate_experts.device
self.mode = mode
elif mode == InferenceState.PREFILL:
self.generate_experts.unload()
self.prefill_experts.load(w, warmup=warmup)
self.device = self.prefill_experts.device
self.mode = mode
elif mode == InferenceState.UNLOAD:
self.unload()
self.mode = mode
self.device = self.generate_experts.device
else:
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
def unload(self):
if self.generate_experts is not None:
self.generate_experts.unload()
if self.prefill_experts is not None:
self.prefill_experts.unload()
self.device = self.generate_experts.device
def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):
if self.mode == InferenceState.GENERATE:
assert self.generate_experts is not None, "generate_experts is None"
return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
elif self.mode == InferenceState.PREFILL:
assert self.prefill_experts is not None, "prefill_experts is None"
return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
else:
raise ValueError("load or set_inference_mode before forward")
def set_inference_mode(self, mode: InferenceState):
if mode == InferenceState.GENERATE:
self.load(mode=InferenceState.GENERATE, warmup=False)
elif mode == InferenceState.PREFILL:
self.load(mode=InferenceState.PREFILL, warmup=False)
elif mode == InferenceState.UNLOAD:
self.unload()
else:
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
@ -1507,6 +1664,246 @@ class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
)
return outs
@torch.no_grad()
# TODO may bugs here
def moe_infer(self, x, topk_ids, topk_weight):
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = expert.forward(tokens_for_this_expert)
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.type(topk_weight.dtype)
.mul_(topk_weight.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return final_out
class KSmallthinkerMoeBlock(BaseInjectedModule, SmallthinkerMoeBlock):
def forward(self, router_input: torch.Tensor, hidden_states: torch.Tensor, bsz_tensor=None, cuda_graph_idx=0):
orig_shape = hidden_states.shape
sequence_length = orig_shape[1]
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
if bsz_tensor is None:
if self.enable_early_router:
router_logits = self.primary_router(router_input)
else:
router_logits = self.primary_router(hidden_states)
else:
if self.enable_early_router:
router_logits = self.primary_router(router_input, bsz_tensor)
else:
router_logits = self.primary_router(hidden_states, bsz_tensor)
router_logits, selected_experts = torch.topk(router_logits, self.num_active_primary_experts, dim=-1)
if router_logits.device.type == "xpu":
# TODO: support self.moe_primary_router_apply_softmax False case
from ipex_llm.transformers.models.common import moe_softmax_topk
selected_experts, routing_weights = moe_softmax_topk(
router_logits.half(), self.top_k, self.norm_topk_prob
)
else:
if self.moe_primary_router_apply_softmax:
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
else:
routing_weights = F.sigmoid(router_logits)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
# only for generate phase
if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
# y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
# y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
# y += y_
y.resize_(*orig_shape)
return y
# y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
# y_ = (
# F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
# )
if isinstance(self.experts, KExpertsBase):
y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
elif hidden_states.size(0) > 10:
# TODO may bugs here
y = (
self.moe_infer(hidden_states, selected_experts, routing_weights)
.view(*orig_shape)
.to(device=hidden_states.device)
)
else:
# TODO may bugs here
y = (
self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
.view(*orig_shape)
.to(device=hidden_states.device)
)
# y += y_
return y
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
outs = torch.empty_like(x)
outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
return outs
@torch.no_grad()
# TODO may bugs here
def moe_infer_simple(
self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
) -> torch.Tensor:
"""
x: [num_tokens, hidden_size]
topk_ids, topk_weight: [num_tokens, num_selected_experts]
"""
outs = torch.zeros_like(x)
for token_idx in range(topk_ids.size(0)):
for expert_idx in range(topk_ids.size(1)):
expert = self.experts[topk_ids[token_idx, expert_idx]]
outs[token_idx] += (
expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
)
return outs
@torch.no_grad()
# TODO may bugs here
def moe_infer(self, x, topk_ids, topk_weight):
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = expert.forward(tokens_for_this_expert)
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.type(topk_weight.dtype)
.mul_(topk_weight.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return final_out
class KGlm4MoeMoE(BaseInjectedModule, Glm4MoeMoE):
def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0):
orig_shape = hidden_states.shape
sequence_length = orig_shape[1]
topk_idx, topk_weight = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
# only for generate phase
if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0)
# y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
y += y_
y.resize_(*orig_shape)
return y
y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0)
# y_ = (
# F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
# )
if isinstance(self.experts, KExpertsBase):
y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
elif hidden_states.size(0) > 10:
# TODO may bugs here
y = (
self.moe_infer(hidden_states, topk_idx, topk_weight)
.view(*orig_shape)
.to(device=hidden_states.device)
)
else:
# TODO may bugs here
y = (
self.moe_infer_simple(hidden_states, topk_idx, topk_weight)
.view(*orig_shape)
.to(device=hidden_states.device)
)
y += y_
return y
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
outs = torch.empty_like(x)
outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
return outs
@torch.no_grad()
# TODO may bugs here
def moe_infer_simple(
self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
) -> torch.Tensor:
"""
x: [num_tokens, hidden_size]
topk_ids, topk_weight: [num_tokens, num_selected_experts]
"""
outs = torch.zeros_like(x)
for token_idx in range(topk_ids.size(0)):
for expert_idx in range(topk_ids.size(1)):
expert = self.experts[topk_ids[token_idx, expert_idx]]
outs[token_idx] += (
expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
)
return outs
@torch.no_grad()
# TODO may bugs here
def moe_infer(self, x, topk_ids, topk_weight):