Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-10 15:29:39 +00:00
update rope calculation; update modeling.py; update gate for moe
parent: 5a50b34627
commit: f873558a89
11 changed files with 402 additions and 412 deletions
@@ -519,7 +519,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
 from ktransformers.models.modeling_deepseek import DeepseekV2MoE
-from ktransformers.models.modeling_deepseekv3 import DeepseekV3MoE
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
 from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
 from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
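The only substantive change above is the module rename from modeling_deepseekv3 to modeling_deepseek_v3. Code that has to import DeepseekV3MoE on both sides of this commit could guard the import; a minimal sketch (this shim is an assumption for illustration, not part of the commit):

```python
# Hypothetical compatibility shim for code spanning this rename; not part of the commit.
try:
    # Module path after this commit.
    from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
except ImportError:
    # Older module path, removed by this commit.
    from ktransformers.models.modeling_deepseekv3 import DeepseekV3MoE
```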
@@ -734,9 +734,10 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
         identity = hidden_states
         orig_shape = hidden_states.shape
-        topk_idx, topk_weight= self.gate(hidden_states)
+        sequence_length = orig_shape[1]
+        topk_idx, topk_weight, router_logits= self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
 
         # only for generate phase
         if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
             self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
             if self.config.n_shared_experts is not None:
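The gate call now unpacks three values: this commit has the MoE gate return the raw router_logits alongside the top-k expert indices and weights, and the later hunks thread that extra value through every return path. For orientation, here is a minimal softmax top-k gate with the same three-value contract; it is a generic sketch, not DeepseekV3's actual MoEGate (which layers scoring functions, bias terms, and group-limited routing on top of this):

```python
import torch
import torch.nn as nn

class SimpleTopkGate(nn.Module):
    """Minimal sketch of a gate matching the new (topk_idx, topk_weight, router_logits) contract."""

    def __init__(self, hidden_size: int, n_experts: int, top_k: int):
        super().__init__()
        self.top_k = top_k
        self.proj = nn.Linear(hidden_size, n_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        # [batch, seq, hidden] -> [batch * seq, n_experts]
        router_logits = self.proj(hidden_states.reshape(-1, hidden_states.shape[-1]))
        scores = router_logits.softmax(dim=-1)
        topk_weight, topk_idx = scores.topk(self.top_k, dim=-1)
        # Return the logits as a third value so callers can, e.g., compute auxiliary losses.
        return topk_idx, topk_weight, router_logits
```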
@@ -744,7 +745,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
             y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
             y += y_
             y.resize_(*orig_shape)
-            return y
+            return y, router_logits
 
         if self.config.n_shared_experts is not None:
             y_ = self.shared_experts(identity).squeeze(0)
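The single-token decode path above relies on a submit/sync split: submit_for_one_decode enqueues the expert computation and sync_for_one_decode later blocks for its result, so the shared-expert work can run in between. A rough sketch of that pattern with a worker thread, using hypothetical names as a stand-in for the real CPUInfer-backed generate_experts:

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional

import torch

class AsyncDecodeExperts:
    """Sketch of the submit/sync split on the one-token decode path (hypothetical stand-in)."""

    def __init__(self, expert_fn):
        self._expert_fn = expert_fn  # e.g. a CPU expert kernel taking (x, topk_idx, topk_weight)
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._pending: Optional[Future] = None

    def submit_for_one_decode(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weight: torch.Tensor) -> None:
        # Enqueue the expert computation without blocking the caller.
        self._pending = self._pool.submit(self._expert_fn, x, topk_idx, topk_weight)

    def sync_for_one_decode(self) -> torch.Tensor:
        # Block until the previously submitted computation finishes.
        y = self._pending.result()
        self._pending = None
        return y
```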
|
@ -767,7 +768,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
|
|||
)
|
||||
if self.config.n_shared_experts is not None:
|
||||
y += y_
|
||||
return y
|
||||
return y, router_logits
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
|
||||
|
|
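After this commit, both return statements of KDeepseekV3MoE.forward yield a (y, router_logits) tuple, so any call site written against the old single-tensor return must unpack two values. A hypothetical adjusted call site (the names are illustrative, not taken from this diff):

```python
import torch

def run_moe_block(moe: torch.nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
    """Hypothetical call site updated for the new tuple return."""
    # Before this commit: y = moe(hidden_states)
    y, router_logits = moe(hidden_states)
    # router_logits is now available for, e.g., an auxiliary load-balancing
    # loss; this sketch simply discards it.
    return y
```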