support deepseekv3; runable but have precition problem

This commit is contained in:
Azure 2025-01-31 08:27:24 +00:00
parent de7e892f72
commit 476b1d8dc6
13 changed files with 2178 additions and 24 deletions

View file

@ -519,6 +519,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
from ktransformers.models.modeling_deepseek import DeepseekV2MoE
from ktransformers.models.modeling_deepseekv3 import DeepseekV3MoE
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
@ -727,6 +728,106 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
)
return final_out
class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
def forward(self, hidden_states):
identity = hidden_states
orig_shape = hidden_states.shape
sequence_length = orig_shape[1]
topk_idx, topk_weight= self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity).squeeze(0)
y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
y += y_
y.resize_(*orig_shape)
return y
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity).squeeze(0)
if isinstance(self.experts, KExpertsBase):
y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
elif hidden_states.size(0) > 10:
# TODO may bugs here
y = (
self.moe_infer(hidden_states, topk_idx, topk_weight)
.view(*orig_shape)
.to(device=hidden_states.device)
)
else:
# TODO may bugs here
y = (
self.moe_infer_simple(hidden_states, topk_idx, topk_weight)
.view(*orig_shape)
.to(device=hidden_states.device)
)
if self.config.n_shared_experts is not None:
y += y_
return y
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
outs = self.experts(x, topk_ids, topk_weight)
return outs
@torch.no_grad()
# TODO may bugs here
def moe_infer_simple(
self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
) -> torch.Tensor:
"""
x: [num_tokens, hidden_size]
topk_ids, topk_weight: [num_tokens, num_selected_experts]
"""
outs = torch.zeros_like(x)
for token_idx in range(topk_ids.size(0)):
for expert_idx in range(topk_ids.size(1)):
expert = self.experts[topk_ids[token_idx, expert_idx]]
outs[token_idx] += (
expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
)
return outs
@torch.no_grad()
# TODO may bugs here
def moe_infer(self, x, topk_ids, topk_weight):
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = expert.forward(tokens_for_this_expert)
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.type(topk_weight.dtype)
.mul_(topk_weight.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return final_out
class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: