mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 22:05:30 +00:00
add balance-serve, support concurrence
This commit is contained in:
parent
8d0292aa44
commit
25cee5810e
196 changed files with 22077 additions and 565 deletions
|
@ -99,6 +99,7 @@ class DeepseekV3RMSNorm(nn.Module):
|
|||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
|
@ -398,7 +399,6 @@ class MoEGate(nn.Module):
|
|||
self.n_routed_experts = config.n_routed_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.scoring_func = config.scoring_func
|
||||
self.seq_aux = config.seq_aux
|
||||
self.topk_method = config.topk_method
|
||||
self.n_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
|
@ -436,6 +436,7 @@ class MoEGate(nn.Module):
|
|||
|
||||
### select top-k experts
|
||||
if self.topk_method == "noaux_tc":
|
||||
assert not self.training
|
||||
scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
|
||||
group_scores = (
|
||||
scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim = -1)
|
||||
|
@ -454,7 +455,7 @@ class MoEGate(nn.Module):
|
|||
)
|
||||
.reshape(bsz * seq_len, -1)
|
||||
) # [n, e]
|
||||
tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||
tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
|
||||
_, topk_idx = torch.topk(
|
||||
tmp_scores, k=self.top_k, dim=-1, sorted=False
|
||||
)
|
||||
|
@ -1933,4 +1934,4 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
|
|||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue