Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2026-04-28 11:49:51 +00:00
feat(sft): support transformers v5 fused expert format
Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters
(gate_up_proj [E,2I,H], down_proj [E,H,I]) instead of per-expert nn.Linear
modules. PEFT cannot attach LoRA to these, so we create KT-managed LoRA
buffers with kaiming init, nn.Parameter wrappers for the optimizer, and
pre-assigned .grad for the C++ backward.

- arch.py: detect_fused_experts() detection helper
- weights.py: fused-format extraction and weight clearing
- wrapper.py: detect fused format at wrap time, store _fused_experts/_lora_rank
- lora.py: _create_fused_expert_lora_buffers, save/load fused LoRA,
  get_kt_lora_params collects fused params, deduplicate wrapper finding
- layer.py: handle v5 TopKRouter tuple output, remove dead code
- autograd.py: sync_forward_sft/submit_forward_sft API rename

Verified: v5 loss/expert-LoRA values match the v4 baseline; v4 backward
compatibility is preserved.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Parent: 6d4632b8c7
Commit: 58d7eabb9b
6 changed files with 249 additions and 69 deletions
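To make the buffer scheme concrete, here is a minimal sketch of what the commit message describes, assuming hypothetical names and shapes for anything the message does not spell out. The [E,2I,H]/[E,H,I] layouts, kaiming-initialized A / zero B, nn.Parameter wrappers, and pre-assigned .grad come from the message; the function name and everything else below is illustrative, not the repo's actual code:

import math
import torch
import torch.nn as nn

def make_fused_expert_lora_buffers(E, H, I, rank, dtype=torch.bfloat16):
    """Sketch: KT-managed LoRA buffers for a fused expert module whose
    weights are 3D Parameters (gate_up_proj [E, 2I, H], down_proj [E, H, I]).
    PEFT targets per-expert nn.Linear modules, which no longer exist here."""
    buffers = {}
    for name, (out_f, in_f) in {"gate_up_proj": (2 * I, H), "down_proj": (H, I)}.items():
        lora_A = torch.empty(E, rank, in_f, dtype=dtype)
        lora_B = torch.zeros(E, out_f, rank, dtype=dtype)  # zero B: adapter starts as a no-op
        for e in range(E):  # kaiming init per expert, as standard LoRA does for A
            nn.init.kaiming_uniform_(lora_A[e], a=math.sqrt(5))
        # nn.Parameter wrappers so the optimizer can see and update the buffers,
        # with .grad pre-assigned so a C++ backward can accumulate in place.
        pA, pB = nn.Parameter(lora_A), nn.Parameter(lora_B)
        pA.grad = torch.zeros_like(pA)
        pB.grad = torch.zeros_like(pB)
        buffers[name] = (pA, pB)
    return buffers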
@@ -82,10 +82,6 @@ class KTMoELayerWrapper(nn.Module):
         # PEFT LoRA tracking (set by kt_adapt_peft_lora)
-        # _peft_lora_modules: {expert_idx: {proj_name: (lora_A, lora_B)}}
-        self._peft_lora_modules: dict[int, dict[str, tuple[nn.Module, nn.Module]]] | None = None
-        self._peft_lora_rank: int = 0
-        self._peft_lora_alpha: float = 0.0
         self._skip_lora: bool = False  # True when using SkipLoRA backend (no LoRA on experts)

         self._lora_pointers_dirty = False

     def _apply(self, fn, recurse=True):
@@ -210,7 +206,7 @@ class KTMoELayerWrapper(nn.Module):
         if rank == 0:
             if self.wrapper is None:
                 raise RuntimeError("Rank0 wrapper is required in distributed KT overlap path.")
-            cpu_output = self.wrapper.sync_forward(output_device=original_device)
+            cpu_output = self.wrapper.sync_forward_sft(output_device=original_device)
             cpu_output = cpu_output.to(dtype=original_dtype).view(total_qlen, self.hidden_size)
             offsets = _qlen_offsets(all_qlens_list)
             scatter_list = [cpu_output[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)]
@@ -231,7 +227,7 @@ class KTMoELayerWrapper(nn.Module):
             return output

         if self.wrapper is not None:
-            cpu_output = self.wrapper.sync_forward(output_device=original_device)
+            cpu_output = self.wrapper.sync_forward_sft(output_device=original_device)
             output = cpu_output.view(batch_size, seq_len, self.hidden_size).to(dtype=original_dtype)
             return output
@@ -263,7 +259,18 @@ class KTMoELayerWrapper(nn.Module):
                 topk_weights = topk_weights.to(torch.bfloat16)
             return topk_ids, topk_weights

-        router_logits = router(hidden_states.view(-1, self.hidden_size))
+        router_output = router(hidden_states.view(-1, self.hidden_size))
+        # transformers v5 TopKRouter returns (router_logits, router_scores, router_indices)
+        # directly — scores/indices are already topk-normalized.
+        if isinstance(router_output, tuple):
+            if len(router_output) >= 3:
+                _logits, topk_weights, topk_ids = router_output[0], router_output[1], router_output[2]
+                if topk_weights.is_floating_point():
+                    topk_weights = topk_weights.to(torch.bfloat16)
+                return topk_ids, topk_weights
+            router_output = router_output[0]
+
+        router_logits = router_output
         routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
         topk_weights, topk_ids = torch.topk(routing_weights, self.moe_config.num_experts_per_tok, dim=-1)
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
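For reference, a self-contained toy of the two routing paths in the hunk above. Only the tuple layout (logits, scores, indices) and the softmax/topk/renormalize fallback come from the diff; the function name and the tensors below are made up:

import torch
import torch.nn.functional as F

def route(router_output, num_experts_per_tok):
    # v5 path: TopKRouter already returns (logits, topk scores, topk indices).
    if isinstance(router_output, tuple) and len(router_output) >= 3:
        _logits, topk_weights, topk_ids = router_output[:3]
        return topk_ids, topk_weights.to(torch.bfloat16)
    # v4 path: raw logits -> softmax, topk, renormalize over the chosen experts.
    logits = router_output[0] if isinstance(router_output, tuple) else router_output
    weights = F.softmax(logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(weights, num_experts_per_tok, dim=-1)
    return topk_ids, topk_weights / topk_weights.sum(dim=-1, keepdim=True)

logits = torch.randn(4, 8)                       # 4 tokens, 8 experts
ids, wts = route(logits, 2)                      # v4-style: raw logits
v5_out = (logits, torch.rand(4, 2), torch.randint(0, 8, (4, 2)))
ids, wts = route(v5_out, 2)                      # v5-style: tuple passes through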
@@ -328,7 +335,7 @@ class KTMoELayerWrapper(nn.Module):
             all_hs = torch.cat(gathered_hs, dim=0)
             all_ids = torch.cat(gathered_ids, dim=0)
             all_wts = torch.cat(gathered_wts, dim=0)
-            self.wrapper.submit_forward(
+            self.wrapper.submit_forward_sft(
                 all_hs,
                 all_ids,
                 all_wts,
@@ -357,7 +364,7 @@ class KTMoELayerWrapper(nn.Module):
         submit_hs = input_flat.detach()
         submit_ids = expert_ids.detach()
         submit_wts = weights.detach()
-        self.wrapper.submit_forward(
+        self.wrapper.submit_forward_sft(
             submit_hs,
             submit_ids,
             submit_wts,
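All four call-site changes above are the same rename: the SFT path first enqueues the CPU expert work and later blocks for the result. A rough sketch of that calling pattern; only the submit_forward_sft/sync_forward_sft names and the (hidden states, ids, weights) argument order come from the diff, while the wrapper object, shapes, and the overlap window are assumed:

import torch

def moe_forward_sft(wrapper, hidden_flat, topk_ids, topk_weights,
                    batch, seq, hidden, dtype=torch.bfloat16):
    """Sketch of the split submit/sync SFT pattern. `wrapper` stands in
    for the KT C++ expert wrapper; everything here is illustrative."""
    # 1) Enqueue the CPU expert computation; the call returns immediately,
    #    so GPU work (attention, norms) can overlap with the CPU experts.
    wrapper.submit_forward_sft(hidden_flat.detach(),
                               topk_ids.detach(),
                               topk_weights.detach())
    # ... overlapped GPU work would run here ...
    # 2) Block until the CPU experts finish, then reshape to [B, S, H].
    cpu_output = wrapper.sync_forward_sft(output_device=hidden_flat.device)
    return cpu_output.view(batch, seq, hidden).to(dtype=dtype)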