mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 20:00:06 +00:00
* feat(sft): AMX MoE SFT backend with LoRA support

Complete SFT (Supervised Fine-Tuning) backend for MoE models using AMX SIMD:

Core C++ implementation:
- sft_moe.hpp: Forward/backward with LoRA fused operations (~5500 lines)
- moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA
- amx/moe-sft-tp.hpp: AMX-specific TP implementation
- avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM
- amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization
- worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure
- ext_bindings.cpp: SFT MOE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants)

Python sft/ submodule (kt_kernel.sft):
- base.py: BaseSFTMoEWrapper with buffer management (template method pattern)
- amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction)
- autograd.py: KTMoEFunction (torch.autograd.Function for distributed training)
- layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers)
- arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection)
- weights.py: Expert weight extraction and checkpoint loading
- lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter)
- wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map
- config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough)
- dist_utils.py: Distributed gather/scatter, checkpoint-phase detection

Design decisions:
- Rank-0-only expert pattern: only rank 0 holds C++ wrapper and expert weights
- DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (framework interaction fields), all logic in kt_kernel.sft
- Inference isolation: importing kt_kernel does not load sft/ submodule
- Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically

Verified: Qwen3-235B-A22B 4GPU AMXBF16 training, loss converges normally.
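The `base.py` entry above describes `BaseSFTMoEWrapper` as buffer management in a template-method pattern. The following is a minimal stdlib-only sketch of that idea and makes no claims about the real C++ tasks. `SharedBuffer`, `BaseWrapper`, and `DoublingBackend` are hypothetical names. Plain lists stand in for CPU tensors, and calling the task directly stands in for `cpu_infer.submit()` followed by `cpu_infer.sync()`. The base class owns the grow-only buffer and the copy-in/copy-out steps, while a backend supplies only the task factory.

```python
from abc import ABC, abstractmethod
from typing import Callable, List, Optional


class SharedBuffer:
    """Hypothetical grow-only buffer shared by every layer (cf. KExpertsSFTBuffer)."""

    _shared: Optional["SharedBuffer"] = None
    allocations = 0  # illustration only: counts how often a new buffer is built

    def __init__(self, qlen: int, hidden_size: int) -> None:
        self.qlen = qlen
        self.input = [[0.0] * hidden_size for _ in range(qlen)]
        self.output = [[0.0] * hidden_size for _ in range(qlen)]
        SharedBuffer.allocations += 1

    @classmethod
    def get(cls, qlen: int, hidden_size: int) -> "SharedBuffer":
        buf = cls._shared
        if buf is not None and qlen <= buf.qlen:
            return buf  # reuse; callers only touch rows [:qlen]
        cls._shared = cls(qlen, hidden_size)  # grow; never shrinks
        return cls._shared


class BaseWrapper(ABC):
    """Template method: forward() is fixed here, the task factory is backend-specific."""

    def __init__(self, hidden_size: int) -> None:
        self.hidden_size = hidden_size

    @abstractmethod
    def _make_forward_task(self, buf: SharedBuffer, qlen: int) -> Callable[[], None]:
        """Return a callable that fills buf.output[:qlen] from buf.input[:qlen]."""

    def forward(self, rows: List[List[float]]) -> List[List[float]]:
        qlen = len(rows)
        buf = SharedBuffer.get(qlen, self.hidden_size)
        for i, row in enumerate(rows):
            buf.input[i][:] = row  # copy inputs into the shared buffer
        task = self._make_forward_task(buf, qlen)
        task()  # stands in for cpu_infer.submit(task) followed by cpu_infer.sync()
        return [list(r) for r in buf.output[:qlen]]  # copy out so the result is not aliased


class DoublingBackend(BaseWrapper):
    """Toy backend whose 'expert computation' doubles every input value."""

    def _make_forward_task(self, buf: SharedBuffer, qlen: int) -> Callable[[], None]:
        def run() -> None:
            for i in range(qlen):
                buf.output[i][:] = [2.0 * x for x in buf.input[i]]

        return run


if __name__ == "__main__":
    layer = DoublingBackend(hidden_size=2)
    print(layer.forward([[1.0, 2.0], [3.0, 4.0]]))  # [[2.0, 4.0], [6.0, 8.0]]
    print(layer.forward([[5.0, 6.0]]))  # [[10.0, 12.0]], reusing the qlen=2 buffer
    print(SharedBuffer.allocations)  # 1
```

The copy-out step matters because the buffer is shared. If the method returned a slice of it directly, the next call would overwrite a result the caller still holds. The listing's `_return_output()` avoids this by either calling `.clone()` or copying the slice to another device.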
* refactor(sft): unify KTConfig field names with kt_ prefix, add share_cache_pool, remove dead code

- KTConfig fields all use kt_ prefix matching dict keys; eliminates _OLD_TO_NEW mapping and prefix-stripping in wrapper.py
- Add kt_share_cache_pool field, auto-enabled when gradient_checkpointing is on (via training_args.py), flows through to C++ cache allocation
- Remove dead checkpoint detection code: in_ckpt_recompute, in_ckpt_first_forward vars (assigned but never read), fallback _is_in_checkpoint_first_forward() function, unused inspect import
- Remove redundant env var fallbacks in wrapper.py for share_backward_bb and share_cache_pool (KTConfig.__post_init__ already handles env vars)
- Simplify layer.py checkpoint logic to single _checkpoint_hook_mode() check

Verified: Qwen3-235B 3-step training on sap4, loss matches baseline (1.2886 / 1.9824 / 1.377 vs 1.2886 / 1.9766 / 1.3809)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(sft): share_backward_bb default True, share_cache_pool auto-derived

- kt_share_backward_bb defaults to True (always saves memory)
- kt_share_cache_pool no longer reads from env var; defaults False, auto-set to True by trainer_config_process when gradient checkpointing is enabled

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: add missing gpu_experts_mask=None to KTMoEWrapper call in SFT wrapper

KTMoEWrapper.__new__() requires gpu_experts_mask as a positional argument, but the SFT wrapper omitted it, causing MoE layer wrapping to fail silently and FSDP2 to attempt broadcasting all expert weights (OOM/NCCL crash).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): support transformers v5 fused expert format

Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters (gate_up_proj [E,2I,H], down_proj [E,H,I]) instead of per-expert nn.Linear modules. PEFT cannot attach LoRA to these, so we create KT-managed LoRA buffers with kaiming init, nn.Parameter wrappers for the optimizer, and pre-assigned .grad for C++ backward.

- arch.py: detect_fused_experts() detection
- weights.py: fused format extraction and weight clearing
- wrapper.py: detect fused at wrap time, store _fused_experts/_lora_rank
- lora.py: _create_fused_expert_lora_buffers, save/load fused LoRA, get_kt_lora_params collects fused params, deduplicate wrapper finding
- layer.py: handle v5 TopKRouter tuple output, remove dead code
- autograd.py: sync_forward_sft/submit_forward_sft API rename

Verified: v5 loss/expert-LoRA values match v4 baseline, v4 backward compat.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): add Qwen3.5 MoE support + fused checkpoint loading

- arch.py: add Qwen3_5Moe arch match, read config from text_config, _get_layers_prefix returns model.language_model.layers for Qwen3.5, _get_model_container_and_layers searches language_model attr
- weights.py: load_experts_from_checkpoint_files detects fused format (gate_up_proj in weight_map) and splits into gate/up/down
- wrapper.py: hidden_size fallback to text_config

Verified: Qwen3.5-35B-A3B (256 experts, fused format) E2E pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [fix](sft): align Python API with C++ backend after v5 refactor

- wrapper.py: pass gpu_experts_mask=None to KTMoEWrapper (required by C++ signature)
- layer.py: rename submit_forward_sft/sync_forward_sft to submit_forward/sync_forward
- autograd.py: rename sync_forward_sft to sync_forward

The sft-v5 refactor (commits 58d7eab, dd1da65) renamed Python-side method calls but the C++ backend (AMXSFTMoEWrapper) still exposes the original method names. This caused AttributeError on Qwen3.5-35B and other models.
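In the fused layout described above, `gate_up_proj` is stored as `[E, 2I, H]` and `down_proj` as `[E, H, I]`. The sketch below covers the two steps the commits mention: detecting the fused format from a checkpoint `weight_map` and splitting it into gate and up. It also shows a PEFT-style LoRA initialization. It uses only the stdlib, with nested lists standing in for tensors, and it rests on several assumptions:

- Gate occupies the first `I` rows of each expert and up occupies the last `I`.
- Each expert's LoRA A is `[r, in]` and B is `[out, r]`, as in `nn.Linear`-style PEFT adapters.
- The init uses PEFT's default `kaiming_uniform_(a=sqrt(5))` bound of `1/sqrt(in_features)` and zeroes B.

The function names and the example checkpoint key are hypothetical.

```python
import math
import random
from typing import Dict, List, Tuple

Matrix = List[List[float]]


def is_fused_checkpoint(weight_map: Dict[str, str]) -> bool:
    """Hypothetical check: treat the checkpoint as fused if any tensor name contains gate_up_proj."""
    return any("gate_up_proj" in name for name in weight_map)


def split_gate_up(gate_up: List[Matrix]) -> Tuple[List[Matrix], List[Matrix]]:
    """Split fused [E, 2I, H] weights into gate [E, I, H] and up [E, I, H].

    Assumes gate occupies the first I rows of each expert and up the last I rows.
    """
    gate: List[Matrix] = []
    up: List[Matrix] = []
    for expert in gate_up:
        if len(expert) % 2:
            raise ValueError(f"expected an even 2I dimension, got {len(expert)}")
        i = len(expert) // 2
        gate.append([row[:] for row in expert[:i]])  # copy rows so the fused source can be freed
        up.append([row[:] for row in expert[i:]])
    return gate, up


def init_expert_lora(
    num_experts: int, in_features: int, out_features: int, rank: int, rng: random.Random
) -> Tuple[List[Matrix], List[Matrix]]:
    """PEFT-style init per expert: A is [rank, in] uniform, B is [out, rank] zeros.

    kaiming_uniform_(A, a=sqrt(5)) has bound sqrt(6 / ((1 + 5) * fan_in)) = 1 / sqrt(fan_in).
    """
    bound = 1.0 / math.sqrt(in_features)
    lora_a = [
        [[rng.uniform(-bound, bound) for _ in range(in_features)] for _ in range(rank)]
        for _ in range(num_experts)
    ]
    lora_b = [[[0.0] * rank for _ in range(out_features)] for _ in range(num_experts)]
    return lora_a, lora_b


if __name__ == "__main__":
    E, I, H, r = 2, 3, 4, 2
    weight_map = {"model.layers.0.mlp.experts.gate_up_proj": "model-00001.safetensors"}  # hypothetical key
    fused = [[[float(100 * e + 10 * row + col) for col in range(H)] for row in range(2 * I)] for e in range(E)]
    if is_fused_checkpoint(weight_map):
        gate, up = split_gate_up(fused)
        print(len(gate[0]), len(up[0]), len(gate[0][0]))  # 3 3 4
    gate_a, gate_b = init_expert_lora(E, in_features=H, out_features=I, rank=r, rng=random.Random(0))
    down_a, down_b = init_expert_lora(E, in_features=I, out_features=H, rank=r, rng=random.Random(1))
```

According to the commit text, the real code wraps these buffers in `nn.Parameter` for the optimizer and pre-assigns `.grad` so that the C++ backward can write into it. The sketch does not model that part.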
* align sft branch with main: revert worker_pool, strip sft_timer, fix inference defaults

- Revert worker_pool.cpp/.h to main (remove RDTSC timer, Chrome Trace, sft_timer namespace, ITT API, extended do_work_stealing_job API)
- Strip all sft_timer instrumentation from sft-only files (sft_moe.hpp, moe-sft-tp.hpp, avx_kernels.hpp)
- Restore pin_memory=True in KExpertsCPUBuffer (inference path)
- Restore fused tensor transpose logic in convert_cpu_weights.py (main layout)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* revert CMakeLists.txt to main: remove debug flags and cpptrace dep

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* clean up dev artifacts: remove SFT design docs, debug examples, bench scripts

Remove files not needed in the merge:
- docs/SFT+KTWrapper/ (6 Chinese design docs)
- docs/sft_moe_amx/ (21 dev/debug docs)
- 12 debug/test example scripts
- 6 SFT-specific bench scripts and report

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* remove dev version stamps from ext_bindings, sft_moe, moe-sft-tp

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: JimmyPeilinLi <lipeilin@mail.nwpu.edu.cn>
402 lines
15 KiB
Python
# Base classes for SFT MoE operations
# SPDX-License-Identifier: Apache-2.0

"""
SFT (Supervised Fine-Tuning) MoE base classes and buffer management.

Provides:
- KExpertsSFTBuffer: Grow-only shared buffer for forward/backward passes
- BaseSFTMoEWrapper: Abstract base with concrete buffer management (template method pattern)
"""

from __future__ import annotations

import torch
from typing import Optional, Tuple
from abc import ABC, abstractmethod

from ..experts_base import _MoEBase


class KExpertsSFTBuffer:
    """
    CPU buffer management for SFT expert computation.

    Single grow-only buffer (never shrinks). Callers must use [:qlen] slicing
    since the buffer may be larger than the current batch.
    """

    _shared_buffer: Optional["KExpertsSFTBuffer"] = None

    def __init__(
        self,
        qlen: int,
        hidden_size: int,
        moe_intermediate_size: int,
        num_experts: int,
        num_experts_per_tok: int,
        lora_rank: int,
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.qlen = qlen
        self.hidden_size = hidden_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.lora_rank = lora_rank
        self.dtype = dtype

        pin_memory = False

        # Forward buffers
        self.input_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory)
        self.expert_ids_cpu = torch.empty(
            (qlen, num_experts_per_tok), dtype=torch.int64, device="cpu", pin_memory=pin_memory
        )
        self.weights_cpu = torch.empty(
            (qlen, num_experts_per_tok), dtype=torch.float32, device="cpu", pin_memory=pin_memory
        )
        self.output_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory)

        # Backward buffers
        self.grad_output_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory)
        self.grad_input_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory)
        self.grad_weights = torch.empty((qlen, num_experts_per_tok), dtype=torch.float32, device="cpu")

        # Batch size tensor for C++ interface
        self.bsz_tensor = torch.tensor([qlen], dtype=torch.int32, device="cpu")

    @classmethod
    def get_buffer(
        cls,
        qlen: int,
        hidden_size: int,
        moe_intermediate_size: int,
        num_experts: int,
        num_experts_per_tok: int,
        lora_rank: int,
        dtype: torch.dtype = torch.bfloat16,
    ) -> "KExpertsSFTBuffer":
        """Get or grow the single shared buffer. Only reallocates when qlen exceeds capacity."""
        buf = cls._shared_buffer
        if buf is not None and qlen <= buf.qlen:
            return buf
        cls._shared_buffer = cls(
            qlen=qlen,
            hidden_size=hidden_size,
            moe_intermediate_size=moe_intermediate_size,
            num_experts=num_experts,
            num_experts_per_tok=num_experts_per_tok,
            lora_rank=lora_rank,
            dtype=dtype,
        )
        return cls._shared_buffer

    @classmethod
    def clear_cache(cls) -> None:
        """Clear the shared buffer."""
        cls._shared_buffer = None


class BaseSFTMoEWrapper(_MoEBase, ABC):
    """
    Base class for SFT MoE CPU operations with concrete buffer management.

    Subclasses implement:
    - _make_forward_task(buffer, save_for_backward) -> C++ task object
    - _make_backward_task(buffer) -> C++ task object
    - load_weights(physical_to_logical_map_cpu)
    - init_lora_weights(...)
    - update_lora_weights()
    """

    def __init__(
        self,
        layer_idx: int,
        num_experts: int,
        num_experts_per_tok: int,
        hidden_size: int,
        moe_intermediate_size: int,
        num_gpu_experts: int,
        cpuinfer_threads: int,
        threadpool_count: int,
        weight_path: str,
        chunked_prefill_size: int,
        lora_rank: int = 16,
        lora_alpha: float = 32.0,
        max_cache_depth: int = 1,
    ):
        self.cpu_infer = self._get_cpu_infer(cpuinfer_threads, threadpool_count)

        self._validate_base_config(
            num_experts=num_experts,
            hidden_size=hidden_size,
            moe_intermediate_size=moe_intermediate_size,
            num_experts_per_tok=num_experts_per_tok,
        )
        self._validate_sft_config(lora_rank, lora_alpha, max_cache_depth)

        self.layer_idx = layer_idx
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.hidden_size = hidden_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_gpu_experts = num_gpu_experts
        self.weight_path = weight_path
        self.chunked_prefill_size = chunked_prefill_size
        self.threadpool_count = threadpool_count

        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.lora_scaling = lora_alpha / lora_rank
        self.max_cache_depth = max_cache_depth

        self.gate_lora_a: Optional[torch.Tensor] = None
        self.gate_lora_b: Optional[torch.Tensor] = None
        self.up_lora_a: Optional[torch.Tensor] = None
        self.up_lora_b: Optional[torch.Tensor] = None
        self.down_lora_a: Optional[torch.Tensor] = None
        self.down_lora_b: Optional[torch.Tensor] = None

        self._weights_loaded: bool = False
        self._lora_initialized: bool = False
        self._cache_depth: int = 0
        self._is_skip_lora: bool = False

        self.moe = None

    @staticmethod
    def _validate_sft_config(lora_rank: int, lora_alpha: float, max_cache_depth: int) -> None:
        if lora_rank <= 0:
            raise ValueError(f"lora_rank must be positive, got {lora_rank}")
        if lora_alpha <= 0:
            raise ValueError(f"lora_alpha must be positive, got {lora_alpha}")
        if max_cache_depth <= 0:
            raise ValueError(f"max_cache_depth must be positive, got {max_cache_depth}")

    # ========== Abstract methods for subclasses ==========

    @abstractmethod
    def _make_forward_task(self, buffer: KExpertsSFTBuffer, save_for_backward: bool):
        """Construct the C++ forward task object. Backend-specific."""
        ...

    @abstractmethod
    def _make_backward_task(self, buffer: KExpertsSFTBuffer):
        """Construct the C++ backward task object. Backend-specific."""
        ...

    @abstractmethod
    def load_weights(self, physical_to_logical_map_cpu: torch.Tensor) -> None:
        ...

    @abstractmethod
    def init_lora_weights(
        self,
        gate_lora_a: torch.Tensor, gate_lora_b: torch.Tensor,
        up_lora_a: torch.Tensor, up_lora_b: torch.Tensor,
        down_lora_a: torch.Tensor, down_lora_b: torch.Tensor,
        grad_gate_lora_a: torch.Tensor, grad_gate_lora_b: torch.Tensor,
        grad_up_lora_a: torch.Tensor, grad_up_lora_b: torch.Tensor,
        grad_down_lora_a: torch.Tensor, grad_down_lora_b: torch.Tensor,
    ) -> None:
        ...

    @abstractmethod
    def update_lora_weights(self) -> None:
        ...

    # ========== Buffer helpers ==========

    def _get_buffer(self, qlen: int) -> KExpertsSFTBuffer:
        return KExpertsSFTBuffer.get_buffer(
            qlen=qlen,
            hidden_size=self.hidden_size,
            moe_intermediate_size=self.moe_intermediate_size,
            num_experts=self.num_experts,
            num_experts_per_tok=self.num_experts_per_tok,
            lora_rank=self.lora_rank,
            dtype=torch.bfloat16,
        )

    def _validate_forward_inputs(self, hidden_states: torch.Tensor, expert_ids: torch.Tensor, weights: torch.Tensor):
        if not self._weights_loaded:
            raise RuntimeError("Weights not loaded. Call load_weights() or load_weights_from_tensors() first.")
        if not self._lora_initialized and not self._is_skip_lora:
            raise RuntimeError("LoRA weights not initialized. Call init_lora_weights() first.")
        qlen = hidden_states.shape[0]
        if qlen > self.chunked_prefill_size:
            raise ValueError(
                f"qlen ({qlen}) exceeds chunked_prefill_size ({self.chunked_prefill_size}). "
                "Increase chunked_prefill_size or reduce qlen to avoid buffer overrun."
            )
        if expert_ids.shape[0] != qlen or expert_ids.shape[1] != self.num_experts_per_tok:
            raise ValueError(
                f"expert_ids shape {tuple(expert_ids.shape)} must be ({qlen}, {self.num_experts_per_tok})."
            )
        if weights.shape[0] != qlen or weights.shape[1] != self.num_experts_per_tok:
            raise ValueError(
                f"weights shape {tuple(weights.shape)} must be ({qlen}, {self.num_experts_per_tok})."
            )

    def _copy_inputs_to_buffer(self, buffer: KExpertsSFTBuffer, hidden_states: torch.Tensor,
                               expert_ids: torch.Tensor, weights: torch.Tensor, qlen: int) -> torch.device:
        """Copy inputs to CPU buffer, return input device."""
        input_device = hidden_states.device
        buffer.input_cpu[:qlen].copy_(hidden_states.to(torch.bfloat16), non_blocking=True)
        buffer.expert_ids_cpu[:qlen].copy_(expert_ids.to(torch.int64), non_blocking=True)
        buffer.weights_cpu[:qlen].copy_(weights.to(torch.float32), non_blocking=True)
        buffer.bsz_tensor[0] = qlen
        if input_device.type == "cuda":
            torch.cuda.synchronize(input_device)
        return input_device

    def _copy_grad_output_to_cpu(self, buffer: KExpertsSFTBuffer, grad_output: torch.Tensor, qlen: int):
        """Copy grad_output to CPU buffer."""
        input_device = grad_output.device
        if input_device.type == "cuda":
            torch.cuda.synchronize(input_device)
        buffer.grad_output_cpu[:qlen].copy_(grad_output.to(torch.bfloat16))

    def _return_output(self, buffer: KExpertsSFTBuffer, qlen: int, output_device: Optional[torch.device]):
        if output_device is not None:
            return buffer.output_cpu[:qlen].to(device=output_device, non_blocking=True)
        else:
            return buffer.output_cpu[:qlen].clone()

    def _return_grads(self, buffer: KExpertsSFTBuffer, qlen: int, output_device: Optional[torch.device]):
        if output_device is not None:
            grad_input = buffer.grad_input_cpu[:qlen].to(device=output_device, non_blocking=True)
            grad_weights = buffer.grad_weights[:qlen].to(device=output_device, non_blocking=True)
        else:
            grad_input = buffer.grad_input_cpu[:qlen].clone()
            grad_weights = buffer.grad_weights[:qlen].clone()
        return grad_input, grad_weights

    # ========== Concrete forward/backward ==========

    def forward(
        self,
        hidden_states: torch.Tensor,
        expert_ids: torch.Tensor,
        weights: torch.Tensor,
        save_for_backward: bool = True,
        output_device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Synchronous forward pass with optional gradient caching."""
        self._validate_forward_inputs(hidden_states, expert_ids, weights)
        qlen = hidden_states.shape[0]
        buffer = self._get_buffer(qlen)
        self._copy_inputs_to_buffer(buffer, hidden_states, expert_ids, weights, qlen)

        self.cpu_infer.submit(self._make_forward_task(buffer, save_for_backward))
        self.cpu_infer.sync()

        if save_for_backward and self._cache_depth == 0:
            self._cache_depth += 1

        return self._return_output(buffer, qlen, output_device)

    def backward(
        self,
        grad_output: torch.Tensor,
        output_device: Optional[torch.device] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Backward pass computing grad_input and grad_weights."""
        if self._cache_depth <= 0:
            raise RuntimeError("No forward cache available. Call forward(save_for_backward=True) first.")

        qlen = grad_output.shape[0]
        buffer = self._get_buffer(qlen)
        self._copy_grad_output_to_cpu(buffer, grad_output, qlen)

        self.cpu_infer.submit(self._make_backward_task(buffer))
        self.cpu_infer.sync()

        self._cache_depth -= 1
        return self._return_grads(buffer, qlen, output_device)

    # ========== Async forward ==========

    def submit_forward(
        self,
        hidden_states: torch.Tensor,
        expert_ids: torch.Tensor,
        weights: torch.Tensor,
        save_for_backward: bool = True,
    ) -> None:
        """Submit forward pass asynchronously (non-blocking). Call sync_forward() to get results."""
        self._validate_forward_inputs(hidden_states, expert_ids, weights)
        qlen = hidden_states.shape[0]
        buffer = self._get_buffer(qlen)
        self._copy_inputs_to_buffer(buffer, hidden_states, expert_ids, weights, qlen)

        self._pending_buffer = buffer
        self._pending_save_for_backward = save_for_backward
        self._pending_qlen = qlen

        self.cpu_infer.submit(self._make_forward_task(buffer, save_for_backward))

    def sync_forward(self, output_device: Optional[torch.device] = None) -> torch.Tensor:
        """Synchronize and retrieve forward results. Must be called after submit_forward()."""
        if not hasattr(self, "_pending_buffer") or self._pending_buffer is None:
            raise RuntimeError("No pending forward. Call submit_forward() first.")

        self.cpu_infer.sync()

        buffer = self._pending_buffer
        save_for_backward = self._pending_save_for_backward
        qlen = self._pending_qlen

        if save_for_backward and self._cache_depth == 0:
            self._cache_depth += 1

        self._pending_buffer = None
        self._pending_save_for_backward = None
        self._pending_qlen = None

        return self._return_output(buffer, qlen, output_device)

    # ========== Async backward ==========

    def submit_backward_async(
        self,
        grad_output: torch.Tensor,
        output_device: Optional[torch.device] = None,
    ) -> None:
        """Submit backward task without waiting. Call sync_backward() for results."""
        if self._cache_depth <= 0:
            raise RuntimeError("No forward cache available. Call forward(save_for_backward=True) first.")

        qlen = grad_output.shape[0]
        buffer = self._get_buffer(qlen)
        self._copy_grad_output_to_cpu(buffer, grad_output, qlen)

        self.cpu_infer.submit(self._make_backward_task(buffer))
        self._async_bwd_qlen = qlen
        self._async_bwd_output_device = output_device

    def sync_backward(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Wait for async backward and return results."""
        self.cpu_infer.sync()

        qlen = self._async_bwd_qlen
        output_device = self._async_bwd_output_device
        buffer = self._get_buffer(qlen)

        self._cache_depth -= 1
        return self._return_grads(buffer, qlen, output_device)

    # ========== Backward repack (optional, subclasses may override) ==========

    def submit_backward_repack(self):
        if not self._weights_loaded or self.moe is None:
            return
        if hasattr(self.moe, 'submit_backward_repack'):
            self.moe.submit_backward_repack()

    def wait_backward_repack(self):
        if not self._weights_loaded or self.moe is None:
            return
        if hasattr(self.moe, 'wait_backward_repack'):
            self.moe.wait_backward_repack()
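`submit_forward()` and `sync_forward()` split one forward call into a non-blocking submit and a blocking sync. The likely purpose is to let the caller run other work, such as GPU-side computation, while the CPU experts execute. Below is a stdlib-only sketch of that protocol. `ToyCPUInfer` is a hypothetical stand-in for the C++ `cpu_infer` object, built on `concurrent.futures.ThreadPoolExecutor`, and `AsyncLayer` is also hypothetical. The sketch follows the listing in three ways:

- Inputs are snapshotted before submission.
- `sync_forward()` refuses to run without a pending submit.
- The cache depth is incremented only from zero.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, List, Optional


class ToyCPUInfer:
    """Hypothetical stand-in for cpu_infer: submit() queues work, sync() waits for all of it."""

    def __init__(self) -> None:
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._pending: List[Future] = []

    def submit(self, task: Callable[[], None]) -> None:
        self._pending.append(self._pool.submit(task))

    def sync(self) -> None:
        for fut in self._pending:
            fut.result()  # blocks; re-raises any exception raised by the task
        self._pending.clear()

    def close(self) -> None:
        self._pool.shutdown(wait=True)


class AsyncLayer:
    """Hypothetical layer whose 'expert computation' triples each input value."""

    def __init__(self, cpu_infer: ToyCPUInfer) -> None:
        self.cpu_infer = cpu_infer
        self._cache_depth = 0
        self._pending_out: Optional[List[float]] = None
        self._pending_save = False

    def submit_forward(self, x: List[float], save_for_backward: bool = True) -> None:
        snapshot = list(x)  # copy inputs before submitting, like _copy_inputs_to_buffer()
        out = [0.0] * len(snapshot)

        def task() -> None:
            for i, v in enumerate(snapshot):
                out[i] = 3.0 * v

        self._pending_out = out
        self._pending_save = save_for_backward
        self.cpu_infer.submit(task)

    def sync_forward(self) -> List[float]:
        if self._pending_out is None:
            raise RuntimeError("No pending forward. Call submit_forward() first.")
        self.cpu_infer.sync()
        if self._pending_save and self._cache_depth == 0:
            self._cache_depth += 1  # at most one cached forward, as in the listing
        out, self._pending_out = self._pending_out, None
        return out


if __name__ == "__main__":
    infer = ToyCPUInfer()
    layer = AsyncLayer(infer)
    layer.submit_forward([1.0, 2.0])
    # ... other work could run here while the CPU task executes ...
    print(layer.sync_forward())  # [3.0, 6.0]
    infer.close()
```

Because the snapshot is taken at submit time, the caller may reuse or mutate its input list right after `submit_forward()` returns. The listing gets the same guarantee by copying into the shared CPU buffer before the task is submitted.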