Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-09 22:05:30 +00:00
support qwen3, don't speak human language
parent f3d842a0ca
commit 3f9bbf1181
30 changed files with 3696 additions and 290 deletions
@@ -122,3 +122,72 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
             self.e_score_correction_bias = None
+
+
+class KMoEGateQwen2Moe(BaseInjectedModule, KMoEGateBase):
+    def __init__(
+        self,
+        key: str,
+        gguf_loader: GGUFLoader,
+        config: PretrainedConfig,
+        orig_module: nn.Module = None,
+        generate_device: str = "cuda",
+        generate_op: str | None = "KLinearMarlin",
+        prefill_device: str = "cuda",
+        prefill_op: str | None = "KLinearMarlin",
+        use_quant: bool = False,
+        **kwargs,
+    ):
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
+        KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        self.generate_device = generate_device
+        self.prefill_device = prefill_device
+        self.generate_op = generate_op
+        self.prefill_op = prefill_op
+        self.is_windows = os.name == 'nt'
+        self.use_quant = use_quant
+        if not self.is_windows and use_quant:
+            self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)
+            self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp",
+                                                   gguf_loader, config, self.gate_linear,  # orig_module
+                                                   generate_device, generate_op, prefill_device, prefill_op)
+        else:
+            self.gate_linear = None
+
+    def forward(self, hidden_states) -> torch.Tensor:
+        # Quantized gating is not supported on Windows; fall back to the original module.
+        if self.is_windows:
+            return self.orig_module.forward(hidden_states)
+
+        bsz, seq_len, h = hidden_states.shape
+        ### compute gating score
+        hidden_states = hidden_states.view(-1, h)
+        if self.use_quant:
+            logits = self.gate_linear.forward(hidden_states)
+        else:
+            logits = F.linear(
+                hidden_states.type(torch.float32), self.weight.type(torch.float32), None
+            )
+
+        return grouped_topk(hidden_states, logits,
+                            self.top_k, self.norm_topk_prob,
+                            self.n_group, self.topk_group)
+
+    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
+        if device is None: device = self.device
+        if w is None: w = self.load_weights(device=device)
+
+        if isinstance(w, dict):
+            self.weight_type = w["weight_type"]
+            self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
+            self.orig_module.weight = nn.Parameter(w["weight"])
+            self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
+        else:
+            raise ValueError("Invalid weight type")
+        self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
+        self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))
+        if not self.is_windows and self.use_quant:
+            self.gate_linear.load(self.orig_module.weight)
+
+    def unload(self):
+        if self.weight is not None:
+            self.weight = None
+        if self.e_score_correction_bias is not None:
+            self.e_score_correction_bias = None
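
For reference, the forward path in this hunk hands expert selection to grouped_topk(hidden_states, logits, top_k, norm_topk_prob, n_group, topk_group). Below is a minimal, self-contained sketch of the DeepSeek-style grouped routing that this call signature suggests; the name grouped_topk_sketch and details such as tie-breaking and return order are illustrative assumptions, not the repo's actual implementation.

    import torch

    def grouped_topk_sketch(hidden_states: torch.Tensor,
                            gating_logits: torch.Tensor,
                            top_k: int, renormalize: bool,
                            n_group: int, topk_group: int):
        # hidden_states is accepted only to mirror the call site above;
        # routing here depends solely on the gating logits.
        scores = torch.softmax(gating_logits.float(), dim=-1)  # [num_tokens, n_experts]
        num_tokens, n_experts = scores.shape
        assert n_experts % n_group == 0, "experts must split evenly into groups"
        # Rate each expert group by its strongest expert, keep the best topk_group groups.
        group_scores = scores.view(num_tokens, n_group, -1).max(dim=-1).values
        group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False).indices
        group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
        # Expand the group mask back to per-expert granularity.
        expert_mask = (group_mask.unsqueeze(-1)
                       .expand(num_tokens, n_group, n_experts // n_group)
                       .reshape(num_tokens, n_experts))
        # Zero out experts outside the selected groups, then take the final top_k.
        masked = scores.masked_fill(expert_mask == 0, 0.0)
        topk_weights, topk_ids = torch.topk(masked, k=top_k, dim=-1, sorted=False)
        if renormalize:
            # norm_topk_prob: rescale the selected weights to sum to 1 per token.
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_ids, topk_weights

Restricting the final top-k to the best-scoring expert groups is what distinguishes grouped routing from plain top-k gating: it bounds how many groups each token touches, which matters when groups of experts are placed on different devices.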