kvcache-ai-ktransformers/kt-kernel/python/experts_base.py
mrhaoxx 9544a8960d
feat(sft): AMX MoE SFT backend with LoRA support (#1936)
* feat(sft): AMX MoE SFT backend with LoRA support

Complete SFT (Supervised Fine-Tuning) backend for MoE models using AMX SIMD:

Core C++ implementation:
- sft_moe.hpp: Forward/backward with fused LoRA operations (~5500 lines)
- moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA
- amx/moe-sft-tp.hpp: AMX-specific TP implementation
- avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM
- amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization
- worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure
- ext_bindings.cpp: SFT MOE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants)

Python sft/ submodule (kt_kernel.sft):
- base.py: BaseSFTMoEWrapper with buffer management (template method pattern)
- amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction)
- autograd.py: KTMoEFunction (torch.autograd.Function for distributed training)
- layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers)
- arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection)
- weights.py: Expert weight extraction and checkpoint loading
- lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter)
- wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map
- config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough)
- dist_utils.py: Distributed gather/scatter, checkpoint-phase detection

Design decisions:
- Rank-0-only expert pattern: only rank 0 holds C++ wrapper and expert weights
- DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (framework
  interaction fields); all logic lives in kt_kernel.sft
- Inference isolation: importing kt_kernel does not load sft/ submodule
- Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically

Verified: Qwen3-235B-A22B 4-GPU AMXBF16 training; loss converges normally.

* refactor(sft): unify KTConfig field names with kt_ prefix, add share_cache_pool, remove dead code

- KTConfig fields all use kt_ prefix matching dict keys — eliminates
  _OLD_TO_NEW mapping and prefix-stripping in wrapper.py
- Add kt_share_cache_pool field, auto-enabled when gradient_checkpointing
  is on (via training_args.py), flows through to C++ cache allocation
- Remove dead checkpoint detection code: in_ckpt_recompute,
  in_ckpt_first_forward vars (assigned but never read), fallback
  _is_in_checkpoint_first_forward() function, unused inspect import
- Remove redundant env var fallbacks in wrapper.py for share_backward_bb
  and share_cache_pool (KTConfig.__post_init__ already handles env vars)
- Simplify layer.py checkpoint logic to single _checkpoint_hook_mode() check

Verified: Qwen3-235B 3-step training on sap4, loss matches baseline
(1.2886 / 1.9824 / 1.377 vs 1.2886 / 1.9766 / 1.3809)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(sft): share_backward_bb default True, share_cache_pool auto-derived

- kt_share_backward_bb defaults to True (always saves memory)
- kt_share_cache_pool no longer reads from env var; defaults False,
  auto-set to True by trainer_config_process when gradient checkpointing
  is enabled

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: add missing gpu_experts_mask=None to KTMoEWrapper call in SFT wrapper

KTMoEWrapper.__new__() requires gpu_experts_mask as a positional argument,
but the SFT wrapper omitted it, causing MoE layer wrapping to fail silently
and FSDP2 to attempt broadcasting all expert weights (OOM/NCCL crash).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): support transformers v5 fused expert format

Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters
(gate_up_proj [E,2I,H], down_proj [E,H,I]) instead of per-expert
nn.Linear modules. PEFT cannot attach LoRA to these, so we create
KT-managed LoRA buffers with Kaiming init, nn.Parameter wrappers
for the optimizer, and pre-assigned .grad for C++ backward.
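
A rough illustrative sketch (not code from this PR; names follow the shapes above,
assuming gate rows come first along the 2I axis) of splitting the fused tensor:

    # gate_up_proj: [E, 2I, H] -> gate_proj, up_proj: [E, I, H] each
    gate_proj, up_proj = gate_up_proj.split(intermediate_size, dim=1)
    # down_proj stays per-expert with shape [E, H, I]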

- arch.py: detect_fused_experts() detection
- weights.py: fused format extraction and weight clearing
- wrapper.py: detect fused at wrap time, store _fused_experts/_lora_rank
- lora.py: _create_fused_expert_lora_buffers, save/load fused LoRA,
  get_kt_lora_params collects fused params, deduplicate wrapper finding
- layer.py: handle v5 TopKRouter tuple output, remove dead code
- autograd.py: sync_forward_sft/submit_forward_sft API rename

Verified: v5 loss/expert-LoRA values match the v4 baseline; v4 remains backward compatible.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): add Qwen3.5 MoE support + fused checkpoint loading

- arch.py: add Qwen3_5Moe arch match, read config from text_config,
  _get_layers_prefix returns model.language_model.layers for Qwen3.5,
  _get_model_container_and_layers searches language_model attr
- weights.py: load_experts_from_checkpoint_files detects fused format
  (gate_up_proj in weight_map) and splits into gate/up/down
- wrapper.py: hidden_size fallback to text_config

Verified: Qwen3.5-35B-A3B (256 experts, fused format) E2E pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(sft): align Python API with C++ backend after v5 refactor

- wrapper.py: pass gpu_experts_mask=None to KTMoEWrapper (required by C++ signature)
- layer.py: rename submit_forward_sft/sync_forward_sft to submit_forward/sync_forward
- autograd.py: rename sync_forward_sft to sync_forward

The sft-v5 refactor (commits 58d7eab, dd1da65) renamed Python-side method
calls but the C++ backend (AMXSFTMoEWrapper) still exposes the original
method names. This caused an AttributeError on Qwen3.5-35B and other models.

* align sft branch with main: revert worker_pool, strip sft_timer, fix inference defaults

- Revert worker_pool.cpp/.h to main (remove RDTSC timer, Chrome Trace,
  sft_timer namespace, ITT API, extended do_work_stealing_job API)
- Strip all sft_timer instrumentation from sft-only files (sft_moe.hpp,
  moe-sft-tp.hpp, avx_kernels.hpp)
- Restore pin_memory=True in KExpertsCPUBuffer (inference path)
- Restore fused tensor transpose logic in convert_cpu_weights.py (main layout)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* revert CMakeLists.txt to main: remove debug flags and cpptrace dep

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* clean up dev artifacts: remove SFT design docs, debug examples, bench scripts

Remove files not needed in the merge:
- docs/SFT+KTWrapper/ (6 Chinese design docs)
- docs/sft_moe_amx/ (21 dev/debug docs)
- 12 debug/test example scripts
- 6 SFT-specific bench scripts and report

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* remove dev version stamps from ext_bindings, sft_moe, moe-sft-tp

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: JimmyPeilinLi <lipeilin@mail.nwpu.edu.cn>
2026-04-22 11:27:01 +08:00


# Base classes for MoE CPU inference operations
# SPDX-License-Identifier: Apache-2.0
"""
Base infrastructure for CPU-based MoE inference.

This module contains base classes and utilities shared across all backend implementations.
"""

from __future__ import annotations

import torch
from typing import Dict, List, Optional, Tuple
from abc import ABC, abstractmethod
import os
import ctypes

from kt_kernel import kt_kernel_ext


def generate_gpu_experts_masks(
    activation_freq: torch.Tensor,
    num_gpu_experts: int,
) -> torch.Tensor:
    """
    Generate GPU experts masks based on activation frequency.

    Selects the top `num_gpu_experts` experts with highest activation frequency
    across all layers to be placed on GPU.

    Args:
        activation_freq: Activation frequency table of shape (num_layers, num_experts).
            Higher values indicate more frequently activated experts.
        num_gpu_experts: Total number of experts to place on GPU across all layers.

    Returns:
        gpu_experts_masks: Boolean mask of shape (num_layers, num_experts) on CPU.
            True means the expert should be on GPU.

    Example:
        >>> activation_freq = torch.tensor([
        ...     [0.1, 0.5, 0.3, 0.8],  # layer 0
        ...     [0.2, 0.4, 0.9, 0.1],  # layer 1
        ... ])
        >>> masks = generate_gpu_experts_masks(activation_freq, num_gpu_experts=3)
        >>> # Top 3: layer0-expert3 (0.8), layer1-expert2 (0.9), layer0-expert1 (0.5)
        >>> masks
        tensor([[False,  True, False,  True],
                [False, False,  True, False]])
    """
    num_layers, num_experts_per_layer = activation_freq.shape
    total_experts = num_layers * num_experts_per_layer

    # Clamp num_gpu_experts to valid range
    num_gpu_experts = min(num_gpu_experts, total_experts)
    num_gpu_experts = max(num_gpu_experts, 0)

    if num_gpu_experts == 0:
        return torch.zeros(num_layers, num_experts_per_layer, dtype=torch.bool, device="cpu")

    # Flatten and find top-k indices
    flat_freq = activation_freq.view(-1).to(device="cpu")
    _, top_indices = torch.topk(flat_freq, k=num_gpu_experts, largest=True, sorted=False)

    # Create mask
    gpu_experts_masks = torch.zeros(total_experts, dtype=torch.bool, device="cpu")
    gpu_experts_masks[top_indices] = True

    # Reshape to (num_layers, num_experts)
    gpu_experts_masks = gpu_experts_masks.view(num_layers, num_experts_per_layer)

    return gpu_experts_masks


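# Illustrative wiring (a sketch, not part of this module's API): concrete backend
# subclasses of BaseMoEWrapper take a per-layer mask of shape [num_experts], so the
# table produced above is typically sliced by layer index. The wrapper name below
# is hypothetical.
#
#   masks = generate_gpu_experts_masks(activation_freq, num_gpu_experts)
#   wrapper = SomeBackendMoEWrapper(layer_idx=i, ..., gpu_experts_mask=masks[i], ...)

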
class KExpertsCPUBuffer:
    """
    CPU buffer management for expert computation.

    Manages pinned memory buffers for efficient GPU-CPU data transfer.
    """

    capture_bs: List = list()
    capture_buffers: Dict = dict()
    temp_bs: int = 0
    temp_buffer: tuple = tuple()
    buffer_depth: int = 2

    @classmethod
    def get_buffer(cls, hidden_states: torch.Tensor, num_experts_per_tok):
        """
        Return (and cache) the pinned CPU / GPU staging buffers for this batch size.

        Each buffer is allocated `buffer_depth` times so that consecutive layers can
        double-buffer their submissions. Batch sizes listed in `capture_bs` are cached
        permanently in `capture_buffers`; any other batch size reuses a single
        temporary buffer (`temp_buffer`).
        """
        hidden_size = hidden_states.shape[-1]
        batch_size = hidden_states.shape[0]
        pin_memory = True

        if batch_size in cls.capture_buffers:
            return cls.capture_buffers[batch_size]
        if batch_size == cls.temp_bs:
            return cls.temp_buffer

        input_tensor_cpu = [
            torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=pin_memory, dtype=torch.bfloat16)
            for _ in range(cls.buffer_depth)
        ]
        immediate_experts_ids_cpu = [
            torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=pin_memory)
            for _ in range(cls.buffer_depth)
        ]
        deferred_experts_ids_cpu = [
            torch.full((batch_size, num_experts_per_tok), -1, device="cpu", dtype=torch.long, pin_memory=pin_memory)
            for _ in range(cls.buffer_depth)
        ]
        weights_cpu = [
            torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=pin_memory)
            for _ in range(cls.buffer_depth)
        ]
        output_cpu = [
            torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=pin_memory, dtype=torch.bfloat16)
            for _ in range(cls.buffer_depth)
        ]
        bsz_tensor_cpu = [
            torch.full((1,), batch_size, device="cpu", dtype=torch.int32, pin_memory=pin_memory)
            for _ in range(cls.buffer_depth)
        ]
        output_gpu = [
            torch.zeros((batch_size, hidden_size), device=hidden_states.device, dtype=hidden_states.dtype)
            for _ in range(cls.buffer_depth)
        ]

        cur_buffer = (
            input_tensor_cpu,
            immediate_experts_ids_cpu,
            deferred_experts_ids_cpu,
            weights_cpu,
            output_cpu,
            bsz_tensor_cpu,
            output_gpu,
        )

        if batch_size in cls.capture_bs:
            cls.capture_buffers[batch_size] = cur_buffer

        cls.temp_bs = batch_size
        cls.temp_buffer = cur_buffer
        return cur_buffer


class _MoEBase:
    """
    Shared base class for inference and SFT MoE wrappers.

    Provides:
    - CPUInfer singleton management
    - Basic configuration validation

    This class is shared between BaseMoEWrapper (inference) and BaseSFTMoEWrapper (SFT).
    """

    _cpu_infer_instance = None

    @classmethod
    def _get_cpu_infer(
        cls,
        cpuinfer_threads: int,
        threadpool_count: int,
        numa_nodes=None,
    ):
        """
        Get or create the CPUInfer singleton instance.

        Args:
            cpuinfer_threads: Total number of CPU inference threads
            threadpool_count: Number of NUMA subpools (TP count)
            numa_nodes: Explicit list of NUMA node IDs. If None, defaults to sequential.

        Returns:
            CPUInfer singleton instance
        """
        if cls._cpu_infer_instance is None:
            worker_config = kt_kernel_ext.WorkerPoolConfig()
            if numa_nodes is not None:
                if len(numa_nodes) != threadpool_count:
                    raise ValueError(
                        f"numa_nodes length ({len(numa_nodes)}) must match "
                        f"threadpool_count ({threadpool_count})"
                    )
                subpool_numa_map = list(numa_nodes)
            else:
                subpool_numa_map = list(range(threadpool_count))
            # Split threads as evenly as possible across subpools; the first
            # (cpuinfer_threads % threadpool_count) subpools get one extra thread.
            subpool_thread_count = [
                cpuinfer_threads // threadpool_count + (1 if i < cpuinfer_threads % threadpool_count else 0)
                for i in range(threadpool_count)
            ]
            worker_config.subpool_count = threadpool_count
            worker_config.subpool_numa_map = subpool_numa_map
            worker_config.subpool_thread_count = subpool_thread_count
            cls._cpu_infer_instance = kt_kernel_ext.CPUInfer(worker_config)
        return cls._cpu_infer_instance

    @staticmethod
    def _validate_base_config(
        num_experts: int,
        hidden_size: int,
        moe_intermediate_size: int,
        num_experts_per_tok: int,
    ) -> None:
        """
        Validate basic configuration parameters.

        Raises:
            ValueError: If parameters are invalid
        """
        if num_experts <= 0:
            raise ValueError(f"num_experts must be positive, got {num_experts}")
        if hidden_size <= 0:
            raise ValueError(f"hidden_size must be positive, got {hidden_size}")
        if moe_intermediate_size <= 0:
            raise ValueError(f"moe_intermediate_size must be positive, got {moe_intermediate_size}")
        if num_experts_per_tok <= 0:
            raise ValueError(f"num_experts_per_tok must be positive, got {num_experts_per_tok}")
        if num_experts_per_tok > num_experts:
            raise ValueError(
                f"num_experts_per_tok ({num_experts_per_tok}) cannot exceed num_experts ({num_experts})"
            )


class BaseMoEWrapper(_MoEBase, ABC):
    """
    Base class for MoE CPU inference operations.

    Provides common functionality for all backend implementations.
    """

    _layer_has_pending_deferred: Dict[int, bool] = {}

    def __init__(
        self,
        layer_idx: int,
        num_experts: int,
        num_experts_per_tok: int,
        hidden_size: int,
        moe_intermediate_size: int,
        gpu_experts_mask: Optional[torch.Tensor],
        cpuinfer_threads: int,
        threadpool_count: int,
        weight_path: str,
        chunked_prefill_size: int,
        cpu_save: bool = False,
        max_deferred_experts_per_token: Optional[int] = None,
        method: str = "AMXINT4",
        numa_nodes: Optional[List[int]] = None,
    ):
        """
        Initialize the base MoE wrapper.

        Args:
            layer_idx: Layer index
            num_experts: Total number of experts
            num_experts_per_tok: Number of experts per token (top-k)
            hidden_size: Hidden dimension size
            moe_intermediate_size: MoE intermediate size
            gpu_experts_mask: Boolean mask indicating which experts are on GPU.
                Shape: [num_experts], dtype: torch.bool.
                mask[i] = True means expert i is on GPU.
                If None, all experts are on CPU.
            cpuinfer_threads: Number of CPU inference threads
            threadpool_count: Number of NUMA subpools
            weight_path: Path to weights
            chunked_prefill_size: Maximum prefill chunk size
            cpu_save: Whether to save weights to CPU memory
            max_deferred_experts_per_token: Number of experts per token to defer on this layer.
                Defaults to 0 (no deferral).
            method: Backend method string
            numa_nodes: Explicit list of NUMA node IDs for subpool mapping.
                If None, defaults to [0, 1, ..., threadpool_count-1].
        """
        self.layer_idx = layer_idx
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.hidden_size = hidden_size
        self.moe_intermediate_size = moe_intermediate_size

        # Process gpu_experts_mask: convert to a bool tensor on CPU in pinned memory for async copy.
        # This mask is shared between C++ and Python (the C++ side reads it as uint8_t*); both can read/write it.
        if gpu_experts_mask is None:
            # No GPU experts - all experts on CPU
            self.gpu_experts_mask = torch.zeros(num_experts, dtype=torch.bool, device="cpu", pin_memory=True)
        else:
            # Create a new pinned tensor and copy data into it
            self.gpu_experts_mask = torch.empty(num_experts, dtype=torch.bool, device="cpu", pin_memory=True)
            self.gpu_experts_mask.copy_(gpu_experts_mask)
        self.num_gpu_experts = int(self.gpu_experts_mask.sum().item())

        # GPU copy for mask operations in the forward pass (e.g., mask_cpu_expert_ids).
        # This will be lazily initialized when needed.
        self._gpu_experts_mask_gpu: Optional[torch.Tensor] = None

        self.weight_path = weight_path
        self.chunked_prefill_size = chunked_prefill_size
        self.cpu_save = cpu_save
        self.max_deferred_experts_per_token = (
            int(max_deferred_experts_per_token) if max_deferred_experts_per_token is not None else 0
        )
        BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False
        self.method = method

        # Initialize CPU inference engine (singleton via shared base class)
        self.cpu_infer = self._get_cpu_infer(cpuinfer_threads, threadpool_count, numa_nodes=numa_nodes)

        # Backend-specific initialization happens in subclasses
        self.moe = None

    @abstractmethod
    def load_weights_from_tensors(
        self,
        gate_proj: torch.Tensor,
        up_proj: torch.Tensor,
        down_proj: torch.Tensor,
        physical_to_logical_map_cpu: torch.Tensor,
    ):
        """
        Load and quantize weights from BF16/FP16 tensors (online quantization).

        Args:
            gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
            up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
            down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
        """
        pass

    @abstractmethod
    def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
        """
        Load weights for this layer and initialize the MoE module.

        Args:
            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
        """
        pass

    def select_deferred_experts(
        self,
        expert_ids: torch.Tensor,
        expert_scores: torch.Tensor,
        protected_k: int,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Split a top-k expert assignment into immediate and deferred sets.

        For each token, the `protected_k` highest-scoring experts are marked as
        protected; any expert ID protected for at least one token in the batch is
        treated as immediate wherever it appears, and the remaining entries are
        deferred. Positions belonging to the other set are filled with -1, so both
        returned tensors keep the original [batch, top_k] shape.

        Returns:
            Tuple of (immediate_ids, deferred_ids).
        """
        batch, topk = expert_ids.shape
        device = expert_ids.device
        protected_k = max(0, min(int(protected_k), topk))
        if protected_k == 0:
            deferred_ids = expert_ids.clone()
            immediate_ids = torch.full_like(expert_ids, -1)
            return immediate_ids, deferred_ids
        topk_result = torch.topk(expert_scores, k=protected_k, dim=-1, largest=True, sorted=False)
        protected_indices = topk_result.indices
        protected_ids = torch.gather(expert_ids, -1, protected_indices)
        # Mark protected expert IDs in a per-expert flag table, then map the flags
        # back onto the [batch, topk] layout to build the immediate/deferred masks.
        protected_flag = torch.zeros((self.num_experts,), dtype=torch.int32, device=device)
        protected_flag.scatter_(0, protected_ids.reshape(-1), 1)
        protected_mask_flat = torch.gather(protected_flag, 0, expert_ids.reshape(-1)).ne(0)
        protected_mask = protected_mask_flat.view(batch, topk)
        immediate_ids = expert_ids.clone().masked_fill(~protected_mask, -1)
        deferred_ids = expert_ids.clone().masked_fill(protected_mask, -1)
        return immediate_ids, deferred_ids

    def submit_forward(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        cuda_stream,
    ):
        """
        Submit forward inference task to CPU (non-blocking).

        Args:
            hidden_states: Input hidden states [batch_size, hidden_size]
            topk_ids: Top-k expert IDs [batch_size, num_experts_per_tok]
            topk_weights: Top-k expert weights [batch_size, num_experts_per_tok]
            cuda_stream: CUDA stream for synchronization
        """
        flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        batch_size = flat_hidden_states.shape[0]
        (
            input_tensor_cpu,
            immediate_experts_ids_cpu,
            deferred_experts_ids_cpu,
            weights_cpu,
            output_cpu,
            bsz_tensor_cpu,
            _output_gpu,
        ) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok)
        current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth
        next_slot = (current_slot + 1) % KExpertsCPUBuffer.buffer_depth
        bsz_slot_tensor = bsz_tensor_cpu[current_slot]

        topk_ids_long = topk_ids.to(torch.long)
        immediate_ids: torch.Tensor
        deferred_ids: Optional[torch.Tensor]
        if self.max_deferred_experts_per_token > 0:
            protected_k = self.num_experts_per_tok - self.max_deferred_experts_per_token
            immediate_ids, deferred_ids = self.select_deferred_experts(topk_ids_long, topk_weights, protected_k)
        else:
            immediate_ids = topk_ids_long
            deferred_ids = None

        input_tensor_cpu[current_slot].copy_(flat_hidden_states, non_blocking=True)
        weights_cpu[current_slot].copy_(topk_weights, non_blocking=True)
        immediate_experts_ids_cpu[current_slot].copy_(immediate_ids, non_blocking=True)

        # Tell the task whether the previous layer still has a deferred batch pending.
        incremental = BaseMoEWrapper._layer_has_pending_deferred.get(self.layer_idx - 1, False)
        self.cpu_infer.submit_with_cuda_stream(
            cuda_stream,
            self.moe.forward_task(
                bsz_slot_tensor.data_ptr(),
                immediate_experts_ids_cpu[current_slot].size(-1),
                immediate_experts_ids_cpu[current_slot].data_ptr(),
                weights_cpu[current_slot].data_ptr(),
                input_tensor_cpu[current_slot].data_ptr(),
                output_cpu[current_slot].data_ptr(),
                incremental,
            ),
        )
        BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False

        if deferred_ids is not None:
            # Submit the deferred experts as a second task writing into the next buffer
            # slot, and record that this layer now has a deferred task pending.
            deferred_experts_ids_cpu[current_slot].copy_(deferred_ids, non_blocking=True)
            self.cpu_infer.submit_with_cuda_stream(
                cuda_stream,
                self.moe.forward_task(
                    bsz_slot_tensor.data_ptr(),
                    deferred_experts_ids_cpu[current_slot].size(-1),
                    deferred_experts_ids_cpu[current_slot].data_ptr(),
                    weights_cpu[current_slot].data_ptr(),
                    input_tensor_cpu[current_slot].data_ptr(),
                    output_cpu[next_slot].data_ptr(),
                    False,
                ),
            )
            BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = True

    def sync_forward(self, hidden_states: torch.Tensor, cuda_stream) -> torch.Tensor:
        """
        Synchronize and retrieve forward inference results.

        Args:
            hidden_states: Original input hidden states (for getting buffer)
            cuda_stream: CUDA stream for synchronization

        Returns:
            output_gpu: Output tensor on GPU
        """
        flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        (
            _input_tensor_cpu,
            _immediate_experts_ids_cpu,
            _deferred_experts_ids_cpu,
            _weights_cpu,
            output_cpu,
            _bsz_tensor_cpu,
            output_gpu,
        ) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok)
        current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth
        # Pass allow_pending=1 when this layer submitted a deferred task in submit_forward.
        allow_pending = 1 if BaseMoEWrapper._layer_has_pending_deferred.get(self.layer_idx, False) else 0
        self.cpu_infer.sync_with_cuda_stream(cuda_stream, allow_pending)
        output_gpu[current_slot].copy_(output_cpu[current_slot], non_blocking=True)
        return output_gpu[current_slot]

    def forward(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        cuda_stream,
    ) -> torch.Tensor:
        """
        Execute forward inference synchronously (submit + sync).

        Args:
            hidden_states: Input hidden states [batch_size, hidden_size]
            topk_ids: Top-k expert IDs [batch_size, num_experts_per_tok]
            topk_weights: Top-k expert weights [batch_size, num_experts_per_tok]
            cuda_stream: CUDA stream for synchronization

        Returns:
            Output tensor on GPU
        """
        self.submit_forward(hidden_states, topk_ids, topk_weights, cuda_stream)
        return self.sync_forward(hidden_states, cuda_stream)
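
    # Pipelined usage sketch (illustrative; `wrappers`, `h`, and `stream` are hypothetical
    # caller-side objects): the two phases can be split to overlap the CPU expert
    # computation with other GPU work instead of calling forward() directly.
    #
    #   wrappers[i].submit_forward(h, topk_ids, topk_weights, stream)
    #   ...  # attention / other GPU work for the layer
    #   out = wrappers[i].sync_forward(h, stream)
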
    @staticmethod
    def set_capture_batch_sizes(capture_bs: List[int]):
        """
        Set batch sizes to capture and cache buffers for.

        This allows pre-allocation of CPU buffers for specific batch sizes,
        improving performance by avoiding buffer re-allocation during inference.

        Args:
            capture_bs: List of batch sizes to capture (e.g., [1, 2, 4, 8, 16])

        Example:
            >>> BaseMoEWrapper.set_capture_batch_sizes([1, 2, 4, 8, 16])
        """
        KExpertsCPUBuffer.capture_bs = capture_bs

    @staticmethod
    def get_capture_batch_sizes() -> List[int]:
        """
        Get currently configured capture batch sizes.

        Returns:
            List of batch sizes that are being captured
        """
        return KExpertsCPUBuffer.capture_bs

    @staticmethod
    def clear_buffer_cache():
        """
        Clear all cached buffers.

        This frees up memory by clearing the buffer cache. Useful when you want
        to reset the buffer state or free memory.
        """
        KExpertsCPUBuffer.capture_buffers.clear()
        KExpertsCPUBuffer.temp_bs = 0
        KExpertsCPUBuffer.temp_buffer = tuple()
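

if __name__ == "__main__":
    # Minimal illustrative smoke test of the helpers defined in this file only
    # (concrete backend wrappers such as the AMX implementation live in separate
    # backend modules and are not exercised here).
    freq = torch.rand(4, 8)  # synthetic activation frequencies: 4 layers x 8 experts
    masks = generate_gpu_experts_masks(freq, num_gpu_experts=6)
    print("GPU experts per layer:", masks.sum(dim=-1).tolist())

    # Pre-register batch sizes whose pinned buffers should be cached permanently.
    BaseMoEWrapper.set_capture_batch_sizes([1, 2, 4, 8])
    print("capture batch sizes:", BaseMoEWrapper.get_capture_batch_sizes())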