kvcache-ai-ktransformers/kt-kernel/python/sft/amx.py
mrhaoxx 9544a8960d
feat(sft): AMX MoE SFT backend with LoRA support (#1936)
* feat(sft): AMX MoE SFT backend with LoRA support

Complete SFT (Supervised Fine-Tuning) backend for MoE models using AMX SIMD:

Core C++ implementation:
- sft_moe.hpp: Forward/backward with LoRA fused operations (~5500 lines)
- moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA
- amx/moe-sft-tp.hpp: AMX-specific TP implementation
- avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM
- amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization
- worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure
- ext_bindings.cpp: SFT MOE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants)

Python sft/ submodule (kt_kernel.sft):
- base.py: BaseSFTMoEWrapper with buffer management (template method pattern)
- amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction)
- autograd.py: KTMoEFunction (torch.autograd.Function for distributed training)
- layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers)
- arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection)
- weights.py: Expert weight extraction and checkpoint loading
- lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter)
- wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map
- config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough)
- dist_utils.py: Distributed gather/scatter, checkpoint-phase detection

Design decisions:
- Rank-0-only expert pattern: only rank 0 holds C++ wrapper and expert weights
- DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (framework
  interaction fields), all logic in kt_kernel.sft
- Inference isolation: importing kt_kernel does not load sft/ submodule
- Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically

Verified: Qwen3-235B-A22B 4GPU AMXBF16 training, loss converges normally.

* refactor(sft): unify KTConfig field names with kt_ prefix, add share_cache_pool, remove dead code

- KTConfig fields all use kt_ prefix matching dict keys — eliminates
  _OLD_TO_NEW mapping and prefix-stripping in wrapper.py
- Add kt_share_cache_pool field, auto-enabled when gradient_checkpointing
  is on (via training_args.py), flows through to C++ cache allocation
- Remove dead checkpoint detection code: in_ckpt_recompute,
  in_ckpt_first_forward vars (assigned but never read), fallback
  _is_in_checkpoint_first_forward() function, unused inspect import
- Remove redundant env var fallbacks in wrapper.py for share_backward_bb
  and share_cache_pool (KTConfig.__post_init__ already handles env vars)
- Simplify layer.py checkpoint logic to single _checkpoint_hook_mode() check

Verified: Qwen3-235B 3-step training on sap4, loss matches baseline
(1.2886 / 1.9824 / 1.377 vs 1.2886 / 1.9766 / 1.3809)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(sft): share_backward_bb default True, share_cache_pool auto-derived

- kt_share_backward_bb defaults to True (always saves memory)
- kt_share_cache_pool no longer reads from env var; defaults False,
  auto-set to True by trainer_config_process when gradient checkpointing
  is enabled
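The config behaviour these two refactors describe can be sketched as follows: `kt_`-prefixed fields, `kt_share_backward_bb` defaulting to True, and `kt_share_cache_pool` derived from gradient checkpointing rather than read from an env var. The sketch assumes a trimmed-down dataclass. `KTConfigSketch`, `config_from_dict` and `derive_share_cache_pool` are hypothetical names. The real `KTConfig` has more fields, and `trainer_config_process`'s signature is not shown in the source. Because field names equal the dict keys, the raw dict can be filtered and passed straight to the dataclass, with no kt_xxx→xxx mapping.

```python
from dataclasses import dataclass, fields


@dataclass
class KTConfigSketch:
    """Hypothetical stand-in for KTConfig, reduced to the two fields discussed above."""

    kt_share_backward_bb: bool = True  # default on: always saves memory
    kt_share_cache_pool: bool = False  # not read from env; derived from training args


def config_from_dict(raw: dict) -> KTConfigSketch:
    # Field names match the dict keys, so no prefix stripping is needed.
    # Silently ignoring unknown keys is an assumption of this sketch.
    known = {f.name for f in fields(KTConfigSketch)}
    return KTConfigSketch(**{k: v for k, v in raw.items() if k in known})


def derive_share_cache_pool(cfg: KTConfigSketch, gradient_checkpointing: bool) -> KTConfigSketch:
    # Hypothetical equivalent of the step trainer_config_process performs:
    # share the cache pool only when gradient checkpointing is enabled.
    if gradient_checkpointing:
        cfg.kt_share_cache_pool = True
    return cfg


cfg = derive_share_cache_pool(config_from_dict({"kt_share_backward_bb": True, "kt_other": 1}), True)
print(cfg)  # KTConfigSketch(kt_share_backward_bb=True, kt_share_cache_pool=True)
```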

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: add missing gpu_experts_mask=None to KTMoEWrapper call in SFT wrapper

KTMoEWrapper.__new__() requires gpu_experts_mask as a positional argument,
but the SFT wrapper omitted it, causing MoE layer wrapping to fail silently
and FSDP2 to attempt broadcasting all expert weights (OOM/NCCL crash).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): support transformers v5 fused expert format

Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters
(gate_up_proj [E,2I,H], down_proj [E,H,I]) instead of per-expert
nn.Linear modules. PEFT cannot attach LoRA to these, so we create
KT-managed LoRA buffers with kaiming init, nn.Parameter wrappers
for the optimizer, and pre-assigned .grad for C++ backward.

- arch.py: detect_fused_experts() detection
- weights.py: fused format extraction and weight clearing
- wrapper.py: detect fused at wrap time, store _fused_experts/_lora_rank
- lora.py: _create_fused_expert_lora_buffers, save/load fused LoRA,
  get_kt_lora_params collects fused params, deduplicate wrapper finding
- layer.py: handle v5 TopKRouter tuple output, remove dead code
- autograd.py: sync_forward_sft/submit_forward_sft API rename

Verified: v5 loss/expert-LoRA values match v4 baseline, v4 backward compat.
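To make the fused layout concrete, here is a dependency-free sketch that splits `gate_up_proj` [E, 2I, H] into per-expert gate and up slabs. Nested lists stand in for tensors. The sketch assumes the gate rows come first and the up rows second along dimension 1, which is the usual `chunk(2, ...)` convention in fused expert modules; confirm this for your model. `split_fused_experts` is a hypothetical helper, not the function in weights.py. With PyTorch, the same split is `gate, up = gate_up_proj.chunk(2, dim=1)`.

```python
from typing import List, Tuple

Slab = List[List[float]]  # one expert's matrix, rows x cols


def split_fused_experts(
    gate_up_proj: List[Slab], down_proj: List[Slab]
) -> Tuple[List[Slab], List[Slab], List[Slab]]:
    """Split gate_up_proj [E, 2I, H] into gate [E, I, H] and up [E, I, H].

    down_proj [E, H, I] is already in per-expert form and is returned unchanged.
    Assumes gate occupies the first I rows of each expert and up the last I.
    """
    gate, up = [], []
    for expert_slab in gate_up_proj:
        if len(expert_slab) % 2:
            raise ValueError("dim 1 of gate_up_proj must be 2 * I")
        i = len(expert_slab) // 2
        gate.append(expert_slab[:i])
        up.append(expert_slab[i:])
    return gate, up, down_proj


# E=2 experts, I=2, H=3: row r of expert e is filled with 10*e + r.
fused = [[[10.0 * e + r] * 3 for r in range(4)] for e in range(2)]
gate, up, down = split_fused_experts(fused, down_proj=[[[0.0] * 2] * 3] * 2)
print(gate[1][0][0], up[1][0][0])  # 10.0 12.0
```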

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): add Qwen3.5 MoE support + fused checkpoint loading

- arch.py: add Qwen3_5Moe arch match, read config from text_config,
  _get_layers_prefix returns model.language_model.layers for Qwen3.5,
  _get_model_container_and_layers searches language_model attr
- weights.py: load_experts_from_checkpoint_files detects fused format
  (gate_up_proj in weight_map) and splits into gate/up/down
- wrapper.py: hidden_size fallback to text_config

Verified: Qwen3.5-35B-A3B (256 experts, fused format) E2E pass.
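The fused-format detection and the `text_config` fallback from this entry can be sketched as follows. The sketch assumes the usual Hugging Face sharded-checkpoint index (`model.safetensors.index.json`, whose `weight_map` maps tensor names to shard files). `is_fused_checkpoint` and `resolve_hidden_size` are hypothetical helpers. The tensor name and the hidden size in the example are illustrative values only.

```python
import json
from types import SimpleNamespace


def is_fused_checkpoint(index: dict) -> bool:
    # Fused format if any tensor name in the index's weight_map contains "gate_up_proj".
    return any("gate_up_proj" in name for name in index.get("weight_map", {}))


def resolve_hidden_size(config) -> int:
    # Configs such as Qwen3.5's nest the language-model settings under text_config.
    hidden = getattr(config, "hidden_size", None)
    if hidden is None:
        hidden = config.text_config.hidden_size
    return hidden


index = json.loads(
    '{"weight_map": {"model.language_model.layers.0.mlp.experts.gate_up_proj": "model-00001.safetensors"}}'
)
cfg = SimpleNamespace(text_config=SimpleNamespace(hidden_size=2048))
print(is_fused_checkpoint(index), resolve_hidden_size(cfg))  # True 2048
```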

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [fix](sft): align Python API with C++ backend after v5 refactor

- wrapper.py: pass gpu_experts_mask=None to KTMoEWrapper (required by C++ signature)
- layer.py: rename submit_forward_sft/sync_forward_sft to submit_forward/sync_forward
- autograd.py: rename sync_forward_sft to sync_forward

The sft-v5 refactor (commits 58d7eab, dd1da65) renamed Python-side method
calls but the C++ backend (AMXSFTMoEWrapper) still exposes the original
method names. This caused AttributeError on Qwen3.5-35B and other models.

* align sft branch with main: revert worker_pool, strip sft_timer, fix inference defaults

- Revert worker_pool.cpp/.h to main (remove RDTSC timer, Chrome Trace,
  sft_timer namespace, ITT API, extended do_work_stealing_job API)
- Strip all sft_timer instrumentation from sft-only files (sft_moe.hpp,
  moe-sft-tp.hpp, avx_kernels.hpp)
- Restore pin_memory=True in KExpertsCPUBuffer (inference path)
- Restore fused tensor transpose logic in convert_cpu_weights.py (main layout)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* revert CMakeLists.txt to main: remove debug flags and cpptrace dep

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* clean up dev artifacts: remove SFT design docs, debug examples, bench scripts

Remove files not needed in the merge:
- docs/SFT+KTWrapper/ (6 Chinese design docs)
- docs/sft_moe_amx/ (21 dev/debug docs)
- 12 debug/test example scripts
- 6 SFT-specific bench scripts and report

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* remove dev version stamps from ext_bindings, sft_moe, moe-sft-tp

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: JimmyPeilinLi <lipeilin@mail.nwpu.edu.cn>
2026-04-22 11:27:01 +08:00

434 lines
17 KiB
Python

# AMX SFT MoE Wrapper implementation
# SPDX-License-Identifier: Apache-2.0
"""
AMX-based SFT MoE Wrapper. Forward/backward buffer management is in base class;
this file handles weight loading, LoRA init, and C++ task construction.
"""

from __future__ import annotations

import ctypes
import os
import glob as _glob
import torch
from typing import Optional, List

from kt_kernel_ext.moe import MOESFTConfig
from ..utils.loader import BF16SafeTensorLoader, SafeTensorLoader

try:
    from kt_kernel_ext.moe import (
        AMXBF16_SFT_MOE,
        AMXInt8_SFT_MOE,
        AMXInt4_SFT_MOE,
        AMXBF16_SFT_MOE_SkipLoRA,
        AMXInt8_SFT_MOE_SkipLoRA,
        AMXInt4_SFT_MOE_SkipLoRA,
    )

    _HAS_AMX_SFT_SUPPORT = True
except (ImportError, AttributeError):
    _HAS_AMX_SFT_SUPPORT = False
    AMXBF16_SFT_MOE = None
    AMXInt8_SFT_MOE = None
    AMXInt4_SFT_MOE = None
    AMXBF16_SFT_MOE_SkipLoRA = None
    AMXInt8_SFT_MOE_SkipLoRA = None
    AMXInt4_SFT_MOE_SkipLoRA = None

from .base import BaseSFTMoEWrapper, KExpertsSFTBuffer

# Mapping from method string to C++ SFT MOE class
_SFT_METHOD_TO_CLASS = {
    "AMXBF16_SFT": AMXBF16_SFT_MOE,
    "AMXINT8_SFT": AMXInt8_SFT_MOE,
    "AMXINT4_SFT": AMXInt4_SFT_MOE,
    "AMXBF16_SFT_SkipLoRA": AMXBF16_SFT_MOE_SkipLoRA,
    "AMXINT8_SFT_SkipLoRA": AMXInt8_SFT_MOE_SkipLoRA,
    "AMXINT4_SFT_SkipLoRA": AMXInt4_SFT_MOE_SkipLoRA,
}


class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
    """
    AMX-based SFT MoE wrapper.

    Supports BF16, INT8, INT4, and SkipLoRA variants.
    Forward/backward buffer management is in BaseSFTMoEWrapper;
    this class implements weight loading and C++ task construction.
    """

    def __init__(
        self,
        layer_idx: int,
        num_experts: int,
        num_experts_per_tok: int,
        hidden_size: int,
        moe_intermediate_size: int,
        num_gpu_experts: int,
        cpuinfer_threads: int,
        threadpool_count: int,
        weight_path: str,
        chunked_prefill_size: int,
        lora_rank: int = 16,
        lora_alpha: float = 32.0,
        max_cache_depth: int = 1,
        method: str = "AMXBF16_SFT",
        group_size: int = 128,
        zero_point: bool = True,
    ):
        if not _HAS_AMX_SFT_SUPPORT:
            raise RuntimeError(
                "AMX SFT backend not available. kt_kernel_ext was not compiled with AMX SFT support.\n"
                "Please recompile with AMX SFT enabled."
            )
        super().__init__(
            layer_idx=layer_idx,
            num_experts=num_experts,
            num_experts_per_tok=num_experts_per_tok,
            hidden_size=hidden_size,
            moe_intermediate_size=moe_intermediate_size,
            num_gpu_experts=num_gpu_experts,
            cpuinfer_threads=cpuinfer_threads,
            threadpool_count=threadpool_count,
            weight_path=weight_path,
            chunked_prefill_size=chunked_prefill_size,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            max_cache_depth=max_cache_depth,
        )
        self.method = method
        self._is_skip_lora = "SkipLoRA" in method
        self.group_size = group_size
        self.zero_point = zero_point
        if method not in _SFT_METHOD_TO_CLASS:
            raise ValueError(f"Unknown SFT method: {method}. Supported: {list(_SFT_METHOD_TO_CLASS.keys())}")
        moe_class = _SFT_METHOD_TO_CLASS[method]
        if moe_class is None:
            raise RuntimeError(f"AMX SFT method '{method}' not available in current build.")
        self.gate_proj: Optional[torch.Tensor] = None
        self.up_proj: Optional[torch.Tensor] = None
        self.down_proj: Optional[torch.Tensor] = None
        self._moe_class = moe_class

    # ========== Template method: C++ task construction ==========

    def _make_forward_task(self, buffer: KExpertsSFTBuffer, save_for_backward: bool):
        return self.moe.forward_sft_task(
            buffer.bsz_tensor.data_ptr(),
            self.num_experts_per_tok,
            buffer.expert_ids_cpu.data_ptr(),
            buffer.weights_cpu.data_ptr(),
            buffer.input_cpu.data_ptr(),
            buffer.output_cpu.data_ptr(),
            save_for_backward,
        )

    def _make_backward_task(self, buffer: KExpertsSFTBuffer):
        if self._is_skip_lora:
            # SkipLoRA variants have no LoRA parameters: pass null (0) pointers
            # for the six LoRA gradient buffers.
            return self.moe.backward_task(
                buffer.grad_output_cpu.data_ptr(),
                buffer.grad_input_cpu.data_ptr(),
                0, 0, 0, 0, 0, 0,
                buffer.grad_weights.data_ptr(),
            )
        # LoRA gradients are written by the C++ backward into the grad_* buffers
        # registered in init_lora_weights().
        return self.moe.backward_task(
            buffer.grad_output_cpu.data_ptr(),
            buffer.grad_input_cpu.data_ptr(),
            self.grad_gate_lora_a.data_ptr(),
            self.grad_gate_lora_b.data_ptr(),
            self.grad_up_lora_a.data_ptr(),
            self.grad_up_lora_b.data_ptr(),
            self.grad_down_lora_a.data_ptr(),
            self.grad_down_lora_b.data_ptr(),
            buffer.grad_weights.data_ptr(),
        )

    # ========== Weight loading ==========

    def load_weights(self, physical_to_logical_map_cpu: torch.Tensor) -> None:
        if self._weights_loaded:
            return
        if self.gate_proj is None and not getattr(self, "_use_projs_path", False):
            self._load_base_weights_from_file()

        config = MOESFTConfig()
        config.expert_num = self.num_experts
        config.num_experts_per_tok = self.num_experts_per_tok
        config.hidden_size = self.hidden_size
        config.intermediate_size = self.moe_intermediate_size
        config.lora_rank = self.lora_rank
        config.lora_alpha = self.lora_alpha
        config.max_cache_depth = self.max_cache_depth
        config.max_len = self.chunked_prefill_size
        config.layer_idx = self.layer_idx
        config.share_backward_bb = getattr(self, "share_backward_bb", False)
        config.share_cache_pool = getattr(self, "share_cache_pool", False)

        # Weight sources, checked in order:
        #   1. pre-converted .kt files, loaded directly by C++ from weight_path
        #   2. per-NUMA quantized weight/scale pointers (non-BF16 safetensors)
        #   3. stacked per-expert tensors held on the Python side
        if getattr(self, "_use_kt_direct_load", False):
            config.load = True
            config.path = self.weight_path
        elif getattr(self, "_use_projs_path", False):
            config.gate_projs = self._gate_projs_ptrs
            config.up_projs = self._up_projs_ptrs
            config.down_projs = self._down_projs_ptrs
            config.gate_scales = self._gate_scale_ptrs
            config.up_scales = self._up_scale_ptrs
            config.down_scales = self._down_scale_ptrs
            if getattr(self, "_bf16_gate_proj", None) is not None:
                config.gate_proj = self._bf16_gate_proj.data_ptr()
                config.up_proj = self._bf16_up_proj.data_ptr()
                config.down_proj = self._bf16_down_proj.data_ptr()
            if getattr(self, "_has_bwd_projs", False):
                config.gate_bwd_projs = self._gate_bwd_projs_ptrs
                config.up_bwd_projs = self._up_bwd_projs_ptrs
                config.down_bwd_projs = self._down_bwd_projs_ptrs
                config.gate_bwd_scales = self._gate_bwd_scale_ptrs
                config.up_bwd_scales = self._up_bwd_scale_ptrs
                config.down_bwd_scales = self._down_bwd_scale_ptrs
        else:
            config.gate_proj = self.gate_proj.data_ptr()
            config.up_proj = self.up_proj.data_ptr()
            config.down_proj = self.down_proj.data_ptr()

        if self._lora_initialized:
            config.gate_lora_a = self.gate_lora_a.data_ptr()
            config.gate_lora_b = self.gate_lora_b.data_ptr()
            config.up_lora_a = self.up_lora_a.data_ptr()
            config.up_lora_b = self.up_lora_b.data_ptr()
            config.down_lora_a = self.down_lora_a.data_ptr()
            config.down_lora_b = self.down_lora_b.data_ptr()
        config.pool = self.cpu_infer.backend_
        # Note: neither method name below is a key of _SFT_METHOD_TO_CLASS, so
        # __init__ currently rejects them and this branch is not reached.
        if self.method in ("AMXINT4_KGroup_SFT", "AMXINT4_1KGroup_SFT"):
            config.quant_config.group_size = self.group_size
            config.quant_config.zero_point = self.zero_point

        self.moe = self._moe_class(config)
        self.cpu_infer.submit(self.moe.load_weights_task())
        self.cpu_infer.sync()
        self.cpu_infer.submit(self.moe.warm_up_task())
        self.cpu_infer.sync()

        # Release Python-side weight tensors (C++ copied them)
        self.gate_proj = None
        self.up_proj = None
        self.down_proj = None
        if getattr(self, "_bf16_gate_proj", None) is not None:
            self._bf16_gate_proj = None
            self._bf16_up_proj = None
            self._bf16_down_proj = None
        if getattr(self, "_use_projs_path", False):
            for attr in [
                "_gate_weights_per_numa", "_up_weights_per_numa", "_down_weights_per_numa",
                "_gate_scales_per_numa", "_up_scales_per_numa", "_down_scales_per_numa",
                "_gate_projs_ptrs", "_up_projs_ptrs", "_down_projs_ptrs",
                "_gate_scale_ptrs", "_up_scale_ptrs", "_down_scale_ptrs",
            ]:
                setattr(self, attr, None)
        if getattr(self, "_has_bwd_projs", False):
            for attr in [
                "_gate_bwd_weights_per_numa", "_up_bwd_weights_per_numa", "_down_bwd_weights_per_numa",
                "_gate_bwd_scales_per_numa", "_up_bwd_scales_per_numa", "_down_bwd_scales_per_numa",
                "_gate_bwd_projs_ptrs", "_up_bwd_projs_ptrs", "_down_bwd_projs_ptrs",
                "_gate_bwd_scale_ptrs", "_up_bwd_scale_ptrs", "_down_bwd_scale_ptrs",
            ]:
                setattr(self, attr, None)
        self._weights_loaded = True

    def load_weights_from_tensors(
        self,
        gate_proj: torch.Tensor,
        up_proj: torch.Tensor,
        down_proj: torch.Tensor,
        physical_to_logical_map_cpu: torch.Tensor,
    ) -> None:
        self.gate_proj = gate_proj.contiguous()
        self.up_proj = up_proj.contiguous()
        self.down_proj = down_proj.contiguous()
        self.load_weights(physical_to_logical_map_cpu)
        del gate_proj, up_proj, down_proj

    def _load_base_weights_from_file(self) -> None:
        if not hasattr(self, "weight_path") or self.weight_path is None:
            raise RuntimeError(
                "weight_path not set. Cannot load weights from file. "
                "Either set weight_path or call load_weights_from_tensors() instead."
            )
        # Prefer pre-converted files: <weight_path>/_layer_<idx>/_numa_0/*.kt
        kt_layer_dir = os.path.join(self.weight_path, f"_layer_{self.layer_idx}")
        if os.path.isdir(kt_layer_dir):
            kt_files = _glob.glob(os.path.join(kt_layer_dir, "_numa_0", "*.kt"))
            if kt_files:
                self._use_kt_direct_load = True
                return
        if "BF16" in self.method:
            loader = BF16SafeTensorLoader(self.weight_path)
            base_key = f"model.layers.{self.layer_idx}"
        else:
            loader = SafeTensorLoader(self.weight_path)
            base_key = f"blk.{self.layer_idx}"
        experts_data = loader.load_experts(base_key, device="cpu")
        gate_weights: List[torch.Tensor] = experts_data["gate"]
        up_weights: List[torch.Tensor] = experts_data["up"]
        down_weights: List[torch.Tensor] = experts_data["down"]
        if "BF16" in self.method:
            # BF16: stack the per-expert matrices into [num_experts, ...] tensors.
            self.gate_proj = torch.stack(gate_weights, dim=0).contiguous()
            self.up_proj = torch.stack(up_weights, dim=0).contiguous()
            self.down_proj = torch.stack(down_weights, dim=0).contiguous()
        else:
            # Quantized: weights and scales stay split per NUMA node, and C++
            # receives a [numa][expert] table of raw data addresses.
            def _make_ptrs(arrays_per_numa):
                return [
                    [
                        ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
                        for et in numa_array
                    ]
                    for numa_array in arrays_per_numa
                ]

            self._gate_weights_per_numa = gate_weights
            self._up_weights_per_numa = up_weights
            self._down_weights_per_numa = down_weights
            self._gate_scales_per_numa = experts_data["gate_scale"]
            self._up_scales_per_numa = experts_data["up_scale"]
            self._down_scales_per_numa = experts_data["down_scale"]
            self._gate_projs_ptrs = _make_ptrs(gate_weights)
            self._up_projs_ptrs = _make_ptrs(up_weights)
            self._down_projs_ptrs = _make_ptrs(down_weights)
            self._gate_scale_ptrs = _make_ptrs(experts_data["gate_scale"])
            self._up_scale_ptrs = _make_ptrs(experts_data["up_scale"])
            self._down_scale_ptrs = _make_ptrs(experts_data["down_scale"])
            if "gate_bwd" in experts_data:
                self._gate_bwd_weights_per_numa = experts_data["gate_bwd"]
                self._up_bwd_weights_per_numa = experts_data["up_bwd"]
                self._down_bwd_weights_per_numa = experts_data["down_bwd"]
                self._gate_bwd_scales_per_numa = experts_data["gate_bwd_scale"]
                self._up_bwd_scales_per_numa = experts_data["up_bwd_scale"]
                self._down_bwd_scales_per_numa = experts_data["down_bwd_scale"]
                self._gate_bwd_projs_ptrs = _make_ptrs(experts_data["gate_bwd"])
                self._up_bwd_projs_ptrs = _make_ptrs(experts_data["up_bwd"])
                self._down_bwd_projs_ptrs = _make_ptrs(experts_data["down_bwd"])
                self._gate_bwd_scale_ptrs = _make_ptrs(experts_data["gate_bwd_scale"])
                self._up_bwd_scale_ptrs = _make_ptrs(experts_data["up_bwd_scale"])
                self._down_bwd_scale_ptrs = _make_ptrs(experts_data["down_bwd_scale"])
                self._has_bwd_projs = True
            else:
                self._has_bwd_projs = False
            self.gate_proj = None
            self.up_proj = None
            self.down_proj = None
            self._use_projs_path = True
        loader.close_all_handles()

    # ========== LoRA ==========

    def init_lora_weights(
        self,
        gate_lora_a: torch.Tensor, gate_lora_b: torch.Tensor,
        up_lora_a: torch.Tensor, up_lora_b: torch.Tensor,
        down_lora_a: torch.Tensor, down_lora_b: torch.Tensor,
        grad_gate_lora_a: torch.Tensor, grad_gate_lora_b: torch.Tensor,
        grad_up_lora_a: torch.Tensor, grad_up_lora_b: torch.Tensor,
        grad_down_lora_a: torch.Tensor, grad_down_lora_b: torch.Tensor,
    ) -> None:
        # A maps a projection's input to rank r; B maps rank r to its output.
        expected_shapes = {
            "gate_lora_a": (self.num_experts, self.lora_rank, self.hidden_size),
            "gate_lora_b": (self.num_experts, self.moe_intermediate_size, self.lora_rank),
            "up_lora_a": (self.num_experts, self.lora_rank, self.hidden_size),
            "up_lora_b": (self.num_experts, self.moe_intermediate_size, self.lora_rank),
            "down_lora_a": (self.num_experts, self.lora_rank, self.moe_intermediate_size),
            "down_lora_b": (self.num_experts, self.hidden_size, self.lora_rank),
        }
        provided = {
            "gate_lora_a": gate_lora_a, "gate_lora_b": gate_lora_b,
            "up_lora_a": up_lora_a, "up_lora_b": up_lora_b,
            "down_lora_a": down_lora_a, "down_lora_b": down_lora_b,
        }
        for name, tensor in provided.items():
            expected = expected_shapes[name]
            if tensor.shape != expected:
                raise ValueError(f"{name} shape mismatch: expected {expected}, got {tuple(tensor.shape)}")
        # .contiguous() returns a copy for non-contiguous inputs; C++ then reads and
        # writes that copy, so callers should pass contiguous buffers.
        self.gate_lora_a = gate_lora_a.contiguous()
        self.gate_lora_b = gate_lora_b.contiguous()
        self.up_lora_a = up_lora_a.contiguous()
        self.up_lora_b = up_lora_b.contiguous()
        self.down_lora_a = down_lora_a.contiguous()
        self.down_lora_b = down_lora_b.contiguous()
        self.grad_gate_lora_a = grad_gate_lora_a.contiguous()
        self.grad_gate_lora_b = grad_gate_lora_b.contiguous()
        self.grad_up_lora_a = grad_up_lora_a.contiguous()
        self.grad_up_lora_b = grad_up_lora_b.contiguous()
        self.grad_down_lora_a = grad_down_lora_a.contiguous()
        self.grad_down_lora_b = grad_down_lora_b.contiguous()
        self._lora_initialized = True
        if self._weights_loaded and self.moe is not None:
            self.update_lora_weights()

    def update_lora_weights(self) -> None:
        if not self._weights_loaded:
            raise RuntimeError("Weights not loaded. Call load_weights() first.")
        if self._is_skip_lora:
            return
        if not self._lora_initialized:
            raise RuntimeError("LoRA weights not initialized. Call init_lora_weights() first.")
        self.cpu_infer.submit(
            self.moe.update_lora_weights_task(
                self.gate_lora_a.data_ptr(),
                self.gate_lora_b.data_ptr(),
                self.up_lora_a.data_ptr(),
                self.up_lora_b.data_ptr(),
                self.down_lora_a.data_ptr(),
                self.down_lora_b.data_ptr(),
            )
        )
        self.cpu_infer.sync()

    def save_backward_weights_from_tensors(
        self,
        gate_proj: torch.Tensor,
        up_proj: torch.Tensor,
        down_proj: torch.Tensor,
        physical_to_logical_map: torch.Tensor,
        output_path: str,
    ) -> None:
        if not self._weights_loaded:
            raise RuntimeError("Weights not loaded. Call load_weights() first.")
        gate_proj = gate_proj.contiguous()
        up_proj = up_proj.contiguous()
        down_proj = down_proj.contiguous()
        self.moe.prepare_and_save_bwd(
            gate_proj.data_ptr(),
            up_proj.data_ptr(),
            down_proj.data_ptr(),
            output_path,
        )
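The `expected_shapes` table in `init_lora_weights` fixes the LoRA layout. For each expert, A maps the projection's input dimension down to rank r, and B maps r back up to the output dimension. The sketch below reproduces that table in plain Python and applies one expert's update, scale * B(Ax). The scale alpha / r follows the PEFT convention. Whether the C++ kernel derives the same factor from `lora_alpha` and `lora_rank` is an assumption. `expected_lora_shapes`, `matvec` and `lora_delta` are hypothetical helpers.

```python
from typing import Dict, List, Tuple

Matrix = List[List[float]]


def expected_lora_shapes(E: int, H: int, I: int, r: int) -> Dict[str, Tuple[int, int, int]]:
    # Same layout as expected_shapes in init_lora_weights (E experts, hidden H,
    # MoE intermediate I, LoRA rank r). gate/up map H -> I; down maps I -> H.
    return {
        "gate_lora_a": (E, r, H), "gate_lora_b": (E, I, r),
        "up_lora_a": (E, r, H), "up_lora_b": (E, I, r),
        "down_lora_a": (E, r, I), "down_lora_b": (E, H, r),
    }


def matvec(m: Matrix, v: List[float]) -> List[float]:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_delta(a: Matrix, b: Matrix, x: List[float], alpha: float) -> List[float]:
    # Low-rank correction added to one expert's projection output:
    # (alpha / r) * B @ (A @ x), with A of shape (r, in) and B of shape (out, r).
    # The alpha / r scaling is the PEFT convention, assumed here.
    r = len(a)
    return [(alpha / r) * y for y in matvec(b, matvec(a, x))]


shapes = expected_lora_shapes(E=8, H=16, I=32, r=4)
print(shapes["down_lora_a"])  # (8, 4, 32)
# Gate projection of one expert with H=2, I=3, r=1, alpha=2.0.
print(lora_delta([[1.0, 2.0]], [[1.0], [0.0], [-1.0]], [3.0, 4.0], alpha=2.0))  # [22.0, 0.0, -22.0]
```

With B initialised to zero, as PEFT does, the correction starts at zero, so training begins from the base expert outputs.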