mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-11 07:44:35 +00:00
support glm4moe
This commit is contained in:
parent
1677e90092
commit
d03d92ba53
31 changed files with 2265 additions and 74 deletions
|
@ -194,6 +194,7 @@ class KExpertsCPU(KExpertsBase):
|
|||
64,
|
||||
10,
|
||||
1024,
|
||||
self.config.hidden_act == 'silu',
|
||||
gate_ptr,
|
||||
up_ptr,
|
||||
down_ptr,
|
||||
|
@ -214,6 +215,7 @@ class KExpertsCPU(KExpertsBase):
|
|||
self.config.hidden_size,
|
||||
self.config.moe_intermediate_size,
|
||||
max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,
|
||||
self.config.hidden_act == 'silu',
|
||||
gate_ptr,
|
||||
up_ptr,
|
||||
down_ptr,
|
||||
|
@ -232,6 +234,7 @@ class KExpertsCPU(KExpertsBase):
|
|||
self.config.hidden_size,
|
||||
self.config.moe_intermediate_size,
|
||||
max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,
|
||||
self.config.hidden_act == 'silu',
|
||||
gate_ptr,
|
||||
up_ptr,
|
||||
down_ptr,
|
||||
|
@ -729,6 +732,8 @@ from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
|
|||
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
|
||||
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
|
||||
from ktransformers.models.modeling_smallthinker import SmallthinkerMoeBlock
|
||||
from ktransformers.models.modeling_glm4_moe import Glm4MoeMoE
|
||||
|
||||
|
||||
class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
|
||||
|
@ -1248,6 +1253,12 @@ class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
|
|||
**kwargs):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
||||
KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
||||
|
||||
if prefill_op == 'None':
|
||||
prefill_op = None
|
||||
if generate_op == 'None':
|
||||
generate_op = None
|
||||
|
||||
if generate_op is not None:
|
||||
self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
|
||||
else:
|
||||
|
@ -1307,6 +1318,152 @@ class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
|
|||
else:
|
||||
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
|
||||
|
||||
|
||||
class KSmallthinkerExperts(BaseInjectedModule, KExpertsBase):
|
||||
def __init__(self,
|
||||
key: str,
|
||||
gguf_loader: GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
# device: str = "cuda",
|
||||
prefill_device:str = "cuda",
|
||||
prefill_op: str | None = "KExpertsTorch",
|
||||
generate_device: str = "cpu",
|
||||
generate_op: str | None = "KExpertsCPU",
|
||||
**kwargs):
|
||||
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
||||
KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
||||
if generate_op is not None:
|
||||
self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
|
||||
else:
|
||||
self.generate_experts = None
|
||||
if prefill_op is not None:
|
||||
self.prefill_experts = None
|
||||
self.gpu_mlp_type = prefill_op
|
||||
self.cpu_mlp_type = generate_op
|
||||
self.mode = InferenceState.UNLOAD
|
||||
|
||||
def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True):
|
||||
# TODO support w as input
|
||||
if not mode: mode = InferenceState.GENERATE
|
||||
if mode == InferenceState.GENERATE:
|
||||
# self.prefill_experts.unload()
|
||||
self.generate_experts.load(w, warmup=warmup)
|
||||
self.device = self.generate_experts.device
|
||||
self.mode = mode
|
||||
elif mode == InferenceState.PREFILL:
|
||||
self.generate_experts.unload()
|
||||
self.prefill_experts.load(w, warmup=warmup)
|
||||
self.device = self.prefill_experts.device
|
||||
self.mode = mode
|
||||
elif mode == InferenceState.UNLOAD:
|
||||
self.unload()
|
||||
self.mode = mode
|
||||
self.device = self.generate_experts.device
|
||||
else:
|
||||
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
|
||||
|
||||
def unload(self):
|
||||
if self.generate_experts is not None:
|
||||
self.generate_experts.unload()
|
||||
if self.prefill_experts is not None:
|
||||
self.prefill_experts.unload()
|
||||
self.device = self.generate_experts.device
|
||||
|
||||
def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):
|
||||
if self.mode == InferenceState.GENERATE:
|
||||
assert self.generate_experts is not None, "generate_experts is None"
|
||||
return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
|
||||
elif self.mode == InferenceState.PREFILL:
|
||||
assert self.prefill_experts is not None, "prefill_experts is None"
|
||||
return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
|
||||
else:
|
||||
raise ValueError("load or set_inference_mode before forward")
|
||||
|
||||
def set_inference_mode(self, mode: InferenceState):
|
||||
if mode == InferenceState.GENERATE:
|
||||
self.load(mode=InferenceState.GENERATE, warmup=False)
|
||||
elif mode == InferenceState.PREFILL:
|
||||
self.load(mode=InferenceState.PREFILL, warmup=False)
|
||||
elif mode == InferenceState.UNLOAD:
|
||||
self.unload()
|
||||
else:
|
||||
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
|
||||
|
||||
class KGlm4Experts(BaseInjectedModule, KExpertsBase):
|
||||
def __init__(self,
|
||||
key: str,
|
||||
gguf_loader: GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
# device: str = "cuda",
|
||||
prefill_device:str = "cuda",
|
||||
prefill_op: str | None = "KExpertsTorch",
|
||||
generate_device: str = "cpu",
|
||||
generate_op: str | None = "KExpertsCPU",
|
||||
**kwargs):
|
||||
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
||||
KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
||||
if generate_op is not None:
|
||||
self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
|
||||
else:
|
||||
self.generate_experts = None
|
||||
if prefill_op is not None:
|
||||
self.prefill_experts = None
|
||||
self.gpu_mlp_type = prefill_op
|
||||
self.cpu_mlp_type = generate_op
|
||||
self.mode = InferenceState.UNLOAD
|
||||
|
||||
def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True):
|
||||
# TODO support w as input
|
||||
if not mode: mode = InferenceState.GENERATE
|
||||
if mode == InferenceState.GENERATE:
|
||||
# self.prefill_experts.unload()
|
||||
self.generate_experts.load(w, warmup=warmup)
|
||||
self.device = self.generate_experts.device
|
||||
self.mode = mode
|
||||
elif mode == InferenceState.PREFILL:
|
||||
self.generate_experts.unload()
|
||||
self.prefill_experts.load(w, warmup=warmup)
|
||||
self.device = self.prefill_experts.device
|
||||
self.mode = mode
|
||||
elif mode == InferenceState.UNLOAD:
|
||||
self.unload()
|
||||
self.mode = mode
|
||||
self.device = self.generate_experts.device
|
||||
else:
|
||||
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
|
||||
|
||||
def unload(self):
|
||||
if self.generate_experts is not None:
|
||||
self.generate_experts.unload()
|
||||
if self.prefill_experts is not None:
|
||||
self.prefill_experts.unload()
|
||||
self.device = self.generate_experts.device
|
||||
|
||||
def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):
|
||||
if self.mode == InferenceState.GENERATE:
|
||||
assert self.generate_experts is not None, "generate_experts is None"
|
||||
return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
|
||||
elif self.mode == InferenceState.PREFILL:
|
||||
assert self.prefill_experts is not None, "prefill_experts is None"
|
||||
return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
|
||||
else:
|
||||
raise ValueError("load or set_inference_mode before forward")
|
||||
|
||||
def set_inference_mode(self, mode: InferenceState):
|
||||
if mode == InferenceState.GENERATE:
|
||||
self.load(mode=InferenceState.GENERATE, warmup=False)
|
||||
elif mode == InferenceState.PREFILL:
|
||||
self.load(mode=InferenceState.PREFILL, warmup=False)
|
||||
elif mode == InferenceState.UNLOAD:
|
||||
self.unload()
|
||||
else:
|
||||
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
|
||||
|
||||
|
||||
class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
|
||||
def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
|
||||
|
||||
|
@ -1507,6 +1664,246 @@ class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
|
|||
)
|
||||
return outs
|
||||
|
||||
@torch.no_grad()
|
||||
# TODO may bugs here
|
||||
def moe_infer(self, x, topk_ids, topk_weight):
|
||||
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
|
||||
cnts.scatter_(1, topk_ids, 1)
|
||||
tokens_per_expert = cnts.sum(dim=0)
|
||||
idxs = topk_ids.view(-1).argsort()
|
||||
sorted_tokens = x[idxs // topk_ids.shape[1]]
|
||||
tokens_per_expert = tokens_per_expert.cpu().numpy()
|
||||
|
||||
outputs = []
|
||||
start_idx = 0
|
||||
for i, num_tokens in enumerate(tokens_per_expert):
|
||||
end_idx = start_idx + num_tokens
|
||||
if num_tokens == 0:
|
||||
continue
|
||||
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
|
||||
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
|
||||
expert_out = expert.forward(tokens_for_this_expert)
|
||||
outputs.append(expert_out)
|
||||
start_idx = end_idx
|
||||
|
||||
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
|
||||
|
||||
new_x = torch.empty_like(outs)
|
||||
new_x[idxs] = outs
|
||||
final_out = (
|
||||
new_x.view(*topk_ids.shape, -1)
|
||||
.type(topk_weight.dtype)
|
||||
.mul_(topk_weight.unsqueeze(dim=-1))
|
||||
.sum(dim=1)
|
||||
.type(new_x.dtype)
|
||||
)
|
||||
return final_out
|
||||
|
||||
|
||||
class KSmallthinkerMoeBlock(BaseInjectedModule, SmallthinkerMoeBlock):
|
||||
def forward(self, router_input: torch.Tensor, hidden_states: torch.Tensor, bsz_tensor=None, cuda_graph_idx=0):
|
||||
|
||||
orig_shape = hidden_states.shape
|
||||
sequence_length = orig_shape[1]
|
||||
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
|
||||
if bsz_tensor is None:
|
||||
if self.enable_early_router:
|
||||
router_logits = self.primary_router(router_input)
|
||||
else:
|
||||
router_logits = self.primary_router(hidden_states)
|
||||
else:
|
||||
if self.enable_early_router:
|
||||
router_logits = self.primary_router(router_input, bsz_tensor)
|
||||
else:
|
||||
router_logits = self.primary_router(hidden_states, bsz_tensor)
|
||||
|
||||
router_logits, selected_experts = torch.topk(router_logits, self.num_active_primary_experts, dim=-1)
|
||||
|
||||
|
||||
if router_logits.device.type == "xpu":
|
||||
# TODO: support self.moe_primary_router_apply_softmax False case
|
||||
from ipex_llm.transformers.models.common import moe_softmax_topk
|
||||
selected_experts, routing_weights = moe_softmax_topk(
|
||||
router_logits.half(), self.top_k, self.norm_topk_prob
|
||||
)
|
||||
else:
|
||||
if self.moe_primary_router_apply_softmax:
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
else:
|
||||
routing_weights = F.sigmoid(router_logits)
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
# we cast back to the input dtype
|
||||
routing_weights = routing_weights.to(hidden_states.dtype)
|
||||
|
||||
# only for generate phase
|
||||
if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
|
||||
self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
|
||||
# y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
|
||||
# y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
|
||||
|
||||
y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
|
||||
|
||||
# y += y_
|
||||
y.resize_(*orig_shape)
|
||||
return y
|
||||
|
||||
# y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
|
||||
# y_ = (
|
||||
# F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
|
||||
# )
|
||||
|
||||
|
||||
if isinstance(self.experts, KExpertsBase):
|
||||
y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
|
||||
elif hidden_states.size(0) > 10:
|
||||
# TODO may bugs here
|
||||
y = (
|
||||
self.moe_infer(hidden_states, selected_experts, routing_weights)
|
||||
.view(*orig_shape)
|
||||
.to(device=hidden_states.device)
|
||||
)
|
||||
else:
|
||||
# TODO may bugs here
|
||||
y = (
|
||||
self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
|
||||
.view(*orig_shape)
|
||||
.to(device=hidden_states.device)
|
||||
)
|
||||
# y += y_
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
|
||||
outs = torch.empty_like(x)
|
||||
outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
|
||||
return outs
|
||||
|
||||
@torch.no_grad()
|
||||
# TODO may bugs here
|
||||
def moe_infer_simple(
|
||||
self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
x: [num_tokens, hidden_size]
|
||||
topk_ids, topk_weight: [num_tokens, num_selected_experts]
|
||||
"""
|
||||
outs = torch.zeros_like(x)
|
||||
for token_idx in range(topk_ids.size(0)):
|
||||
for expert_idx in range(topk_ids.size(1)):
|
||||
expert = self.experts[topk_ids[token_idx, expert_idx]]
|
||||
outs[token_idx] += (
|
||||
expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
|
||||
)
|
||||
return outs
|
||||
|
||||
@torch.no_grad()
|
||||
# TODO may bugs here
|
||||
def moe_infer(self, x, topk_ids, topk_weight):
|
||||
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
|
||||
cnts.scatter_(1, topk_ids, 1)
|
||||
tokens_per_expert = cnts.sum(dim=0)
|
||||
idxs = topk_ids.view(-1).argsort()
|
||||
sorted_tokens = x[idxs // topk_ids.shape[1]]
|
||||
tokens_per_expert = tokens_per_expert.cpu().numpy()
|
||||
|
||||
outputs = []
|
||||
start_idx = 0
|
||||
for i, num_tokens in enumerate(tokens_per_expert):
|
||||
end_idx = start_idx + num_tokens
|
||||
if num_tokens == 0:
|
||||
continue
|
||||
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
|
||||
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
|
||||
expert_out = expert.forward(tokens_for_this_expert)
|
||||
outputs.append(expert_out)
|
||||
start_idx = end_idx
|
||||
|
||||
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
|
||||
|
||||
new_x = torch.empty_like(outs)
|
||||
new_x[idxs] = outs
|
||||
final_out = (
|
||||
new_x.view(*topk_ids.shape, -1)
|
||||
.type(topk_weight.dtype)
|
||||
.mul_(topk_weight.unsqueeze(dim=-1))
|
||||
.sum(dim=1)
|
||||
.type(new_x.dtype)
|
||||
)
|
||||
return final_out
|
||||
|
||||
|
||||
class KGlm4MoeMoE(BaseInjectedModule, Glm4MoeMoE):
|
||||
def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0):
|
||||
|
||||
orig_shape = hidden_states.shape
|
||||
sequence_length = orig_shape[1]
|
||||
|
||||
topk_idx, topk_weight = self.gate(hidden_states)
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
|
||||
# only for generate phase
|
||||
if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
|
||||
self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
|
||||
y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0)
|
||||
# y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
|
||||
|
||||
y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
|
||||
|
||||
y += y_
|
||||
y.resize_(*orig_shape)
|
||||
return y
|
||||
|
||||
y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0)
|
||||
# y_ = (
|
||||
# F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
|
||||
# )
|
||||
|
||||
|
||||
if isinstance(self.experts, KExpertsBase):
|
||||
y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
|
||||
elif hidden_states.size(0) > 10:
|
||||
# TODO may bugs here
|
||||
y = (
|
||||
self.moe_infer(hidden_states, topk_idx, topk_weight)
|
||||
.view(*orig_shape)
|
||||
.to(device=hidden_states.device)
|
||||
)
|
||||
else:
|
||||
# TODO may bugs here
|
||||
y = (
|
||||
self.moe_infer_simple(hidden_states, topk_idx, topk_weight)
|
||||
.view(*orig_shape)
|
||||
.to(device=hidden_states.device)
|
||||
)
|
||||
y += y_
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
|
||||
outs = torch.empty_like(x)
|
||||
outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
|
||||
return outs
|
||||
|
||||
@torch.no_grad()
|
||||
# TODO may bugs here
|
||||
def moe_infer_simple(
|
||||
self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
x: [num_tokens, hidden_size]
|
||||
topk_ids, topk_weight: [num_tokens, num_selected_experts]
|
||||
"""
|
||||
outs = torch.zeros_like(x)
|
||||
for token_idx in range(topk_ids.size(0)):
|
||||
for expert_idx in range(topk_ids.size(1)):
|
||||
expert = self.experts[topk_ids[token_idx, expert_idx]]
|
||||
outs[token_idx] += (
|
||||
expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
|
||||
)
|
||||
return outs
|
||||
|
||||
@torch.no_grad()
|
||||
# TODO may bugs here
|
||||
def moe_infer(self, x, topk_ids, topk_weight):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue