mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 11:49:51 +00:00
1762 lines
65 KiB
Python
1762 lines
65 KiB
Python
#!/usr/bin/env python
|
|
# coding=utf-8
|
|
"""
|
|
MOE SFT AMX Test File - Non-TP (Single NUMA Node) Version
|
|
|
|
This file tests the SFT MoE AMX operator with a single NUMA node configuration
|
|
to isolate whether numerical bugs are in the basic SFT logic or TP partitioning.
|
|
|
|
Key difference from test_moe_sft_amx.py:
|
|
- Uses WorkerPoolConfig to force single subpool (tp_count=1)
|
|
- Only tests BF16 forward pass for simplicity
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import math
|
|
from typing import Literal, Dict
|
|
|
|
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
|
|
print("sys.path:", sys.path)
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
# Try to import kt_kernel_ext
|
|
try:
|
|
from kt_kernel import kt_kernel_ext
|
|
|
|
HAS_KT_KERNEL = True
|
|
except ImportError:
|
|
HAS_KT_KERNEL = False
|
|
kt_kernel_ext = None
|
|
|
|
# =============================================================================
|
|
# Test Configuration
|
|
# =============================================================================
|
|
|
|
# Model shape (follows the DeepSeek-V3 MoE layout)
expert_num = 256  # total experts in the layer
hidden_size = 7168  # hidden (model) dimension
intermediate_size = 2048  # per-expert MLP intermediate dimension
max_len = 25_600  # maximum sequence length handed to the operator config
num_experts_per_tok = 8  # top-k experts selected per token
qlen = 4  # sequence length used by the tests
layer_num = 1  # number of layers to test

# LoRA hyper-parameters
lora_rank = 16  # rank r
lora_alpha = 32.0  # scaling numerator alpha
lora_scaling = lora_alpha / lora_rank  # effective multiplier: alpha / r

# Test harness settings
validation_iter = 2  # validation iterations per test
debug_print_count = 8  # number of values shown in debug output
num_threads = 60  # CPU worker threads given to the single subpool

# Maximum accepted relative error per mode and direction
BF16_FORWARD_THRESHOLD = 0.05  # BF16 forward
BF16_BACKWARD_THRESHOLD = 0.10  # BF16 backward
INT4_FORWARD_THRESHOLD = 0.35  # INT4 forward (same as the inference test)
INT4_BACKWARD_THRESHOLD = 0.40  # INT4 backward
|
|
|
|
|
|
# =============================================================================
|
|
# Quantization Mode Utilities
|
|
# =============================================================================
|
|
|
|
|
|
def get_moe_sft_class(quant_mode: str):
    """Return the MOE SFT operator class for a quantization mode.

    Args:
        quant_mode: One of "bf16", "int8", "int4", "int4_1", "int4_1kgroup",
            "int4_kgroup".

    Returns:
        The class found on ``kt_kernel_ext.moe`` for that mode.

    Raises:
        RuntimeError: If ``kt_kernel_ext`` could not be imported.
        ValueError: If ``quant_mode`` is not a supported mode name.
    """
    if not HAS_KT_KERNEL:
        raise RuntimeError("kt_kernel_ext not available")

    # (mode, attribute name) pairs; the attribute is only resolved for the
    # mode that matches, so missing classes for other modes are never touched.
    mode_table = (
        ("bf16", "AMXBF16_SFT_MOE"),
        ("int8", "AMXInt8_SFT_MOE"),
        ("int4", "AMXInt4_SFT_MOE"),
        ("int4_1", "AMXInt4_1_SFT_MOE"),
        ("int4_1kgroup", "AMXInt4_1KGroup_SFT_MOE"),
        ("int4_kgroup", "AMXInt4_KGroup_SFT_MOE"),
    )
    for mode_name, class_attr in mode_table:
        if quant_mode == mode_name:
            return getattr(kt_kernel_ext.moe, class_attr)

    raise ValueError(
        f"Unsupported quant_mode: {quant_mode}. Supported: bf16, int8, int4, int4_1, int4_1kgroup, int4_kgroup"
    )
|
|
|
|
|
|
def get_threshold(quant_mode: str, is_backward: bool = False) -> float:
    """Return the relative-error threshold for a quantization mode.

    The values are kept in line with the inference test.

    Args:
        quant_mode: Quantization mode name.
        is_backward: Select the backward-pass threshold instead of the
            forward-pass one.

    Returns:
        The INT4 threshold for any of the int4 variants; otherwise the BF16
        threshold (BF16 and INT8 share the same values).
    """
    uses_int4_limits = quant_mode in ("int4", "int4_1", "int4_1kgroup", "int4_kgroup")
    if uses_int4_limits:
        # INT4 variants tolerate a larger error: 0.40 backward / 0.35 forward
        return INT4_BACKWARD_THRESHOLD if is_backward else INT4_FORWARD_THRESHOLD
    # BF16 and INT8: 0.10 backward / 0.05 forward
    return BF16_BACKWARD_THRESHOLD if is_backward else BF16_FORWARD_THRESHOLD
|
|
|
|
|
|
# =============================================================================
|
|
# K2 Quantization Utilities (for INT4_KGROUP mode)
|
|
# =============================================================================
|
|
|
|
|
|
def pack_to_int32(value: torch.Tensor, num_bits: int, packed_dim: Literal[0, 1] = 1) -> torch.Tensor:
    """Pack signed low-bit values from an int8 tensor into int32 words.

    Each value is first shifted by ``2 ** (num_bits - 1)`` so it becomes
    unsigned, then ``32 // num_bits`` consecutive values are combined into one
    int32, the first value occupying the lowest bits. When the packed axis is
    not a multiple of ``32 // num_bits`` it is right-padded with zero lanes.

    Args:
        value: 2-D int8 tensor to pack.
        num_bits: Bits per value (4 for int4); must be in [1, 8].
        packed_dim: Axis to pack along (0 or 1).

    Returns:
        int32 tensor whose packed axis has ``ceil(size / (32 // num_bits))``
        entries.

    Raises:
        ValueError: If ``value`` is not int8 or ``num_bits`` is out of range.
    """
    if value.dtype is not torch.int8:
        raise ValueError("Tensor must be torch.int8 before packing")
    if not (1 <= num_bits <= 8):
        raise ValueError(f"num_bits must be in [1, 8], got {num_bits}")

    pack_along_rows = packed_dim == 0
    lanes_per_word = 32 // num_bits

    # Bias into the unsigned range so every lane is a plain bit field.
    unsigned = (value + (1 << (num_bits - 1))).to(torch.uint8)
    if pack_along_rows:
        # Always pack along the last axis; undo the transpose at the end.
        unsigned = unsigned.transpose(0, 1)

    n_rows, n_cols = unsigned.shape
    n_words = math.ceil(n_cols / lanes_per_word)
    missing = n_words * lanes_per_word - n_cols
    if missing > 0:
        unsigned = torch.nn.functional.pad(unsigned, (0, missing))

    # Widen to int32 before shifting so the lanes do not overflow uint8.
    lanes = unsigned.view(n_rows, n_words, lanes_per_word).to(torch.int32)
    shifts = torch.arange(lanes_per_word, device=unsigned.device, dtype=torch.int32) * num_bits
    words = (lanes << shifts).sum(dim=2, dtype=torch.int32)

    return words.transpose(0, 1) if pack_along_rows else words
|
|
|
|
|
|
def pack_tensor_per_row(q: torch.Tensor, num_bits: int) -> torch.Tensor:
    """Pack every row of a per-expert tensor into int32 words (K2 quantization).

    The expert and row axes are flattened together, each flattened row is
    packed with ``pack_to_int32`` along its columns, and the expert/row axes
    are then restored.

    Args:
        q: ``[expert_num, rows, cols]`` int8 tensor.
        num_bits: Bits per value.

    Returns:
        Contiguous int32 tensor of shape ``[expert_num, rows, packed_cols]``.
    """
    num_experts, num_rows, num_cols = q.shape
    rows_2d = q.view(num_experts * num_rows, num_cols)
    packed_2d = pack_to_int32(rows_2d, num_bits)
    return packed_2d.view(num_experts, num_rows, -1).contiguous()
|
|
|
|
|
|
def quantize_k2_tensor(weights: torch.Tensor, group_size: int):
    """Quantize weights to int4 with K2 symmetric max-abs/7 scaling per k-group.

    Every run of ``group_size`` consecutive columns shares one scale
    ``max(|w|) / 7`` (the max is floored at 1e-8 so all-zero groups do not
    divide by zero). Values are rounded, clamped to [-8, 7] and packed eight
    per int32 word.

    Args:
        weights: ``[expert_num, rows (N), cols (K)]`` tensor (bfloat16 in this
            file's callers); it is converted to float32 for the computation.
        group_size: Number of consecutive columns that share one scale.

    Returns:
        Tuple ``(packed_q, scales)``:
            packed_q: int32 tensor of shape ``[expert_num, rows, cols // 8]``
                holding 8 int4 values per element.
            scales: bfloat16 tensor of shape
                ``[expert_num, rows, cols // group_size]``.

    Raises:
        ValueError: If ``cols`` is not divisible by ``group_size``, by 2, or
            by 8 (the number of int4 values stored in one int32).
    """
    weights_f32 = weights.to(torch.float32)
    e, rows, cols = weights_f32.shape
    if cols % group_size != 0 or cols % 2 != 0:
        raise ValueError(f"cols ({cols}) must be divisible by group_size ({group_size}) and 2")
    if cols % 8 != 0:
        # Packing pads each row up to a multiple of 8 values, which would no
        # longer match the [e, rows, cols // 8] layout produced below.
        raise ValueError(f"cols ({cols}) must be divisible by 8 to pack int4 values into int32 words")

    # reshape (not view) so non-contiguous inputs are accepted as well.
    reshaped = weights_f32.reshape(e, rows, cols // group_size, group_size)
    max_abs = reshaped.abs().amax(dim=-1, keepdim=True)
    max_abs = torch.clamp(max_abs, min=1e-8)
    scales = (max_abs / 7.0).squeeze(-1)
    q = torch.round(reshaped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8)
    q = q.view(e, rows, cols)
    packed = pack_tensor_per_row(q, num_bits=4).view(e, rows, cols // 8).contiguous()
    scales = scales.to(torch.bfloat16).contiguous().view(e, rows, cols // group_size).contiguous()

    return packed, scales
|
|
|
|
|
|
def init_base_weights_for_k2(
    expert_num: int, hidden_size: int, intermediate_size: int, group_size: int = 128
) -> Dict[str, torch.Tensor]:
    """Create random BF16 expert weights and their pre-quantized K2 form.

    Used for the INT4_KGROUP mode.

    Args:
        expert_num: Number of experts.
        hidden_size: Hidden dimension.
        intermediate_size: Intermediate dimension.
        group_size: Quantization group size.

    Returns:
        Dictionary with, for each of ``gate``, ``up`` and ``down``:
            - ``<name>_qweight``: packed int4 weights
            - ``<name>_scales``: bf16 scales
            - ``<name>_proj_bf16``: the original bf16 weights, kept as the
              reference for verification
    """
    # Insertion order (gate, up, down) fixes the order of the random draws.
    shapes = {
        "gate": (expert_num, intermediate_size, hidden_size),
        "up": (expert_num, intermediate_size, hidden_size),
        "down": (expert_num, hidden_size, intermediate_size),
    }
    reference = {name: torch.randn(shape, dtype=torch.bfloat16) for name, shape in shapes.items()}
    quantized = {name: quantize_k2_tensor(w, group_size) for name, w in reference.items()}

    result = {}
    for name in shapes:
        result[f"{name}_qweight"] = quantized[name][0].contiguous()
    for name in shapes:
        result[f"{name}_scales"] = quantized[name][1].contiguous()
    for name in shapes:
        result[f"{name}_proj_bf16"] = reference[name].contiguous()
    return result
|
|
|
|
|
|
# =============================================================================
|
|
# Activation Functions
|
|
# =============================================================================
|
|
|
|
|
|
def silu(x: torch.Tensor) -> torch.Tensor:
    """Apply SiLU (Swish) element-wise: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
|
|
|
|
|
|
def act_fn(x: torch.Tensor) -> torch.Tensor:
    """MoE MLP activation (SiLU/Swish), written as ``x / (1 + exp(-x))``."""
    denominator = 1.0 + torch.exp(-x)
    return x / denominator
|
|
|
|
|
|
# =============================================================================
|
|
# LoRA Linear Layer Reference Implementation
|
|
# =============================================================================
|
|
|
|
|
|
def lora_linear_forward(
    x: torch.Tensor, weight: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor, scaling: float
) -> torch.Tensor:
    """Forward pass of a linear layer with a LoRA adapter.

    Computes ``y = x @ W^T + (x @ A^T @ B^T) * scaling``.

    Args:
        x: 2-D input.
        weight: Base weight ``W``.
        lora_a: LoRA ``A`` matrix.
        lora_b: LoRA ``B`` matrix.
        scaling: Multiplier applied to the low-rank term.

    Returns:
        Sum of the base output and the scaled low-rank output.
    """
    # Low-rank path: project down with A, back up with B, then scale.
    low_rank = torch.mm(x, lora_a.t())
    adapter_out = torch.mm(low_rank, lora_b.t()) * scaling

    # Base path through W.
    frozen_out = torch.mm(x, weight.t())

    return frozen_out + adapter_out
|
|
|
|
|
|
def lora_linear_backward(
    grad_output: torch.Tensor,
    x: torch.Tensor,
    weight: torch.Tensor,
    lora_a: torch.Tensor,
    lora_b: torch.Tensor,
    scaling: float,
) -> tuple:
    """Backward pass of a linear layer with a LoRA adapter.

    Produces gradients for the input and for both LoRA matrices. No gradient
    is computed for the base weight ``W``, which is treated as frozen.

    Args:
        grad_output: Upstream gradient ``[batch, out_features]``.
        x: Input from the forward pass ``[batch, in_features]``.
        weight: Base weight ``[out_features, in_features]`` (frozen).
        lora_a: LoRA A matrix ``[rank, in_features]``.
        lora_b: LoRA B matrix ``[out_features, rank]``.
        scaling: LoRA scaling factor (alpha / rank).

    Returns:
        Tuple ``(grad_input, grad_lora_a, grad_lora_b)``.
    """
    # Input gradient: grad_output @ W + (grad_output @ B @ A) * scaling
    adapter_grad = torch.mm(torch.mm(grad_output, lora_b), lora_a) * scaling
    grad_input = torch.mm(grad_output, weight)
    grad_input += adapter_grad

    grad_output_t = grad_output.t()

    # B gradient, [out_features, rank]: grad_output^T @ (x @ A^T), scaled
    x_projected = torch.mm(x, lora_a.t())  # [batch, rank]
    grad_lora_b = torch.mm(grad_output_t, x_projected) * scaling

    # A gradient, [rank, in_features]: (B^T @ grad_output^T) @ x, scaled
    grad_projected = torch.mm(lora_b.t(), grad_output_t)
    grad_lora_a = torch.mm(grad_projected, x) * scaling

    return grad_input, grad_lora_a, grad_lora_b
|
|
|
|
|
|
# =============================================================================
|
|
# MLP Reference Implementation (Single Expert with LoRA)
|
|
# =============================================================================
|
|
|
|
|
|
def mlp_lora_forward(
    x: torch.Tensor,
    gate_proj: torch.Tensor,
    up_proj: torch.Tensor,
    down_proj: torch.Tensor,
    gate_lora_a: torch.Tensor,
    gate_lora_b: torch.Tensor,
    up_lora_a: torch.Tensor,
    up_lora_b: torch.Tensor,
    down_lora_a: torch.Tensor,
    down_lora_b: torch.Tensor,
    scaling: float,
    debug_print: bool = False,
) -> tuple:
    """Forward pass of one expert MLP with LoRA on all three projections.

    Computes ``down(act_fn(gate(x)) * up(x))`` where every projection is
    ``linear(x) = x @ W^T + (x @ A^T @ B^T) * scaling``.

    Args:
        x: Input tokens for this expert.
        gate_proj, up_proj, down_proj: Base projection weights.
        gate_lora_a/b, up_lora_a/b, down_lora_a/b: LoRA weights.
        scaling: LoRA scaling factor.
        debug_print: Print the first 8 values of each stage when True.

    Returns:
        Tuple ``(output, saved_tensors)``; ``saved_tensors`` holds ``x``,
        ``gate_out``, ``up_out``, ``gate_activated`` and ``intermediate`` for
        the backward pass.
    """

    def project(inp, base, lora_a, lora_b):
        # All three projections share the same scaling factor.
        return lora_linear_forward(inp, base, lora_a, lora_b, scaling)

    gate_out = project(x, gate_proj, gate_lora_a, gate_lora_b)
    up_out = project(x, up_proj, up_lora_a, up_lora_b)

    # Gated activation: act_fn(gate) * up
    gate_activated = act_fn(gate_out)
    intermediate = gate_activated * up_out

    output = project(intermediate, down_proj, down_lora_a, down_lora_b)

    if debug_print:
        stages = (
            ("gate_out", gate_out),
            ("up_out", up_out),
            ("intermediate", intermediate),
            ("output", output),
        )
        for label, tensor in stages:
            print(f" {label}[:8] = {tensor.flatten()[:8]}")

    saved_tensors = dict(
        x=x,
        gate_out=gate_out,
        up_out=up_out,
        gate_activated=gate_activated,
        intermediate=intermediate,
    )
    return output, saved_tensors
|
|
|
|
|
|
def mlp_lora_backward(
    grad_output: torch.Tensor,
    saved_tensors: dict,
    gate_proj: torch.Tensor,
    up_proj: torch.Tensor,
    down_proj: torch.Tensor,
    gate_lora_a: torch.Tensor,
    gate_lora_b: torch.Tensor,
    up_lora_a: torch.Tensor,
    up_lora_b: torch.Tensor,
    down_lora_a: torch.Tensor,
    down_lora_b: torch.Tensor,
    scaling: float,
) -> dict:
    """Backward pass of one expert MLP with LoRA adapters.

    Args:
        grad_output: Upstream gradient ``[batch, hidden_size]``.
        saved_tensors: Tensors stored by the forward pass (``x``, ``gate_out``,
            ``up_out``, ``gate_activated``, ``intermediate``).
        gate_proj, up_proj, down_proj: Base projection weights (frozen).
        gate_lora_a/b, up_lora_a/b, down_lora_a/b: LoRA weights.
        scaling: LoRA scaling factor.

    Returns:
        Dictionary with ``grad_input`` and ``grad_{gate,up,down}_lora_{a,b}``.
    """
    x, gate_out, up_out, gate_activated, intermediate = (
        saved_tensors[key] for key in ("x", "gate_out", "up_out", "gate_activated", "intermediate")
    )

    # Down projection first: it yields the gradient of the gated product.
    grad_intermediate, grad_down_lora_a, grad_down_lora_b = lora_linear_backward(
        grad_output, intermediate, down_proj, down_lora_a, down_lora_b, scaling
    )

    # Product rule for intermediate = gate_activated * up_out
    grad_gate_activated = grad_intermediate * up_out
    grad_up_out = grad_intermediate * gate_activated

    # d silu(g) / dg = sigmoid(g) * (1 + g * (1 - sigmoid(g)))
    sigmoid_gate = torch.sigmoid(gate_out)
    silu_slope_term = 1 + gate_out * (1 - sigmoid_gate)
    grad_gate_out = grad_gate_activated * sigmoid_gate * silu_slope_term

    grad_x_up, grad_up_lora_a, grad_up_lora_b = lora_linear_backward(
        grad_up_out, x, up_proj, up_lora_a, up_lora_b, scaling
    )
    grad_x_gate, grad_gate_lora_a, grad_gate_lora_b = lora_linear_backward(
        grad_gate_out, x, gate_proj, gate_lora_a, gate_lora_b, scaling
    )

    # x feeds both the up and the gate branch, so their input gradients add.
    return {
        "grad_input": grad_x_up + grad_x_gate,
        "grad_gate_lora_a": grad_gate_lora_a,
        "grad_gate_lora_b": grad_gate_lora_b,
        "grad_up_lora_a": grad_up_lora_a,
        "grad_up_lora_b": grad_up_lora_b,
        "grad_down_lora_a": grad_down_lora_a,
        "grad_down_lora_b": grad_down_lora_b,
    }
|
|
|
|
|
|
# =============================================================================
|
|
# MOE SFT Reference Implementation (PyTorch)
|
|
# =============================================================================
|
|
|
|
|
|
def moe_sft_torch_forward(
    input: torch.Tensor,
    expert_ids: torch.Tensor,
    weights: torch.Tensor,
    gate_proj: torch.Tensor,
    up_proj: torch.Tensor,
    down_proj: torch.Tensor,
    gate_lora_a: torch.Tensor,
    gate_lora_b: torch.Tensor,
    up_lora_a: torch.Tensor,
    up_lora_b: torch.Tensor,
    down_lora_a: torch.Tensor,
    down_lora_b: torch.Tensor,
    scaling: float,
    debug_print: bool = False,
) -> tuple:
    """PyTorch reference for the MoE SFT forward pass with LoRA adapters.

    Tokens are grouped by selected expert, each group is run through that
    expert's LoRA MLP, and the per-expert outputs are scattered back and
    combined with the routing weights.

    Args:
        input: Token activations, one row per token.
        expert_ids: ``[qlen, k]`` integer tensor of selected experts. Each row
            is expected to hold distinct experts: the per-expert count marks
            an expert once per token, while the sort keeps every entry.
        weights: ``[qlen, k]`` routing weights.
        gate_proj, up_proj, down_proj: Per-expert base weights, indexed by
            expert on the first axis.
        gate_lora_a/b, up_lora_a/b, down_lora_a/b: Per-expert LoRA weights.
        scaling: LoRA scaling factor.
        debug_print: Print activated experts, the stages of the expert picked
            first for token 0, and the final output.

    Returns:
        Tuple ``(output, moe_saved)``; ``moe_saved`` contains ``input``,
        ``expert_ids``, ``weights``, ``idxs``, ``tokens_per_expert`` and
        ``expert_saved_tensors`` (one entry per expert, ``None`` for experts
        that received no tokens).
    """
    qlen = input.shape[0]
    k = expert_ids.shape[1]  # num_experts_per_tok
    # Size the routing table from the weights that were passed in instead of
    # relying on a module-level constant.
    num_experts = gate_proj.shape[0]

    # Count tokens per expert
    cnts = expert_ids.new_zeros((qlen, num_experts))
    cnts.scatter_(1, expert_ids, 1)
    tokens_per_expert = cnts.sum(dim=0)
    token_counts = tokens_per_expert.tolist()

    # Sort the (token, slot) pairs by expert; idxs // k recovers the token row.
    idxs = expert_ids.view(-1).argsort()
    sorted_tokens = input[idxs // k]

    if debug_print:
        activated_experts = [i for i, n in enumerate(token_counts) if n > 0]
        print(f"[MOE SFT DEBUG] Activated experts: {activated_experts}")

    outputs = []
    saved_tensors_list = []
    start_idx = 0

    for i, num_tokens in enumerate(token_counts):
        if num_tokens == 0:
            saved_tensors_list.append(None)
            continue

        end_idx = start_idx + int(num_tokens)
        tokens_for_expert = sorted_tokens[start_idx:end_idx]

        # Forward through this expert's MLP with LoRA
        expert_out, saved = mlp_lora_forward(
            tokens_for_expert,
            gate_proj[i],
            up_proj[i],
            down_proj[i],
            gate_lora_a[i],
            gate_lora_b[i],
            up_lora_a[i],
            up_lora_b[i],
            down_lora_a[i],
            down_lora_b[i],
            scaling,
            debug_print=(debug_print and i == expert_ids[0, 0].item()),
        )

        outputs.append(expert_out)
        # Remember where this expert's rows sit in the sorted layout.
        saved["expert_id"] = i
        saved["start_idx"] = start_idx
        saved["end_idx"] = end_idx
        saved_tensors_list.append(saved)
        start_idx = end_idx

    # Concatenate outputs
    if outputs:
        outs = torch.cat(outputs, dim=0)
    else:
        outs = sorted_tokens.new_empty(0)

    # Undo the sort so row j again belongs to (token j // k, slot j % k).
    new_x = torch.empty_like(outs)
    new_x[idxs] = outs

    # Apply routing weights in the weights' dtype, sum over the k slots, then
    # cast back to the activation dtype.
    output = new_x.view(qlen, k, -1).type(weights.dtype).mul_(weights.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype)

    if debug_print:
        print(f"[MOE SFT DEBUG] Final output[:8] = {output.flatten()[:8]}")

    # Save additional tensors for backward
    moe_saved = {
        "input": input,
        "expert_ids": expert_ids,
        "weights": weights,
        "idxs": idxs,
        "tokens_per_expert": tokens_per_expert,
        "expert_saved_tensors": saved_tensors_list,
    }

    return output, moe_saved
|
|
|
|
|
|
def moe_sft_torch_backward(
    grad_output: torch.Tensor,
    moe_saved: dict,
    gate_proj: torch.Tensor,
    up_proj: torch.Tensor,
    down_proj: torch.Tensor,
    gate_lora_a: torch.Tensor,
    gate_lora_b: torch.Tensor,
    up_lora_a: torch.Tensor,
    up_lora_b: torch.Tensor,
    down_lora_a: torch.Tensor,
    down_lora_b: torch.Tensor,
    scaling: float,
) -> dict:
    """PyTorch reference for the MoE SFT backward pass.

    Computes gradients for the input and for all LoRA weights of every expert.

    Args:
        grad_output: Upstream gradient ``[qlen, hidden_size]``.
        moe_saved: Dictionary saved by the forward pass; ``expert_ids``,
            ``weights``, ``idxs`` and ``expert_saved_tensors`` are read.
        gate_proj, up_proj, down_proj: Base projection weights (frozen).
        gate_lora_a/b, up_lora_a/b, down_lora_a/b: LoRA weights.
        scaling: LoRA scaling factor.

    Returns:
        Dictionary containing:
            - grad_input: Gradient for the input ``[qlen, hidden_size]``.
            - grad_gate_lora_a/b, grad_up_lora_a/b, grad_down_lora_a/b:
              Gradients shaped like the corresponding LoRA tensors; experts
              that received no tokens keep zero gradients.
    """
    expert_ids = moe_saved["expert_ids"]
    weights = moe_saved["weights"]
    idxs = moe_saved["idxs"]
    expert_saved_list = moe_saved["expert_saved_tensors"]

    qlen, k = expert_ids.shape

    # Expand grad_output per selected expert: [qlen, hidden] -> [qlen, k, hidden].
    # Multiplying by the routing weights can promote the dtype (e.g. float32
    # weights with bf16 gradients), so cast back to grad_output's dtype to
    # match the weight dtypes used in the matrix products below.
    grad_output_expanded = grad_output.unsqueeze(1) * weights.unsqueeze(-1)
    grad_output_expanded = grad_output_expanded.view(-1, grad_output.shape[-1]).to(grad_output.dtype)

    # Reorder to the expert-sorted layout used in the forward pass
    sorted_grad_output = grad_output_expanded[idxs]

    grad_input_sorted = torch.zeros_like(sorted_grad_output)

    # LoRA gradient buffers, zero for experts that are never visited
    grad_gate_lora_a = torch.zeros_like(gate_lora_a)
    grad_gate_lora_b = torch.zeros_like(gate_lora_b)
    grad_up_lora_a = torch.zeros_like(up_lora_a)
    grad_up_lora_b = torch.zeros_like(up_lora_b)
    grad_down_lora_a = torch.zeros_like(down_lora_a)
    grad_down_lora_b = torch.zeros_like(down_lora_b)

    # Backward through each expert that processed tokens
    for i, saved in enumerate(expert_saved_list):
        if saved is None:
            continue

        start_idx = saved["start_idx"]
        end_idx = saved["end_idx"]
        grad_out_expert = sorted_grad_output[start_idx:end_idx]

        grads = mlp_lora_backward(
            grad_out_expert,
            saved,
            gate_proj[i],
            up_proj[i],
            down_proj[i],
            gate_lora_a[i],
            gate_lora_b[i],
            up_lora_a[i],
            up_lora_b[i],
            down_lora_a[i],
            down_lora_b[i],
            scaling,
        )

        grad_input_sorted[start_idx:end_idx] = grads["grad_input"]
        grad_gate_lora_a[i] = grads["grad_gate_lora_a"]
        grad_gate_lora_b[i] = grads["grad_gate_lora_b"]
        grad_up_lora_a[i] = grads["grad_up_lora_a"]
        grad_up_lora_b[i] = grads["grad_up_lora_b"]
        grad_down_lora_a[i] = grads["grad_down_lora_a"]
        grad_down_lora_b[i] = grads["grad_down_lora_b"]

    # Reorder gradients back to the original (token, slot) order
    grad_input_flat = torch.zeros_like(grad_input_sorted)
    grad_input_flat[idxs] = grad_input_sorted

    # Each token went to k experts; its input gradient is the sum over them.
    grad_input = grad_input_flat.view(qlen, k, -1).sum(dim=1)

    return {
        "grad_input": grad_input,
        "grad_gate_lora_a": grad_gate_lora_a,
        "grad_gate_lora_b": grad_gate_lora_b,
        "grad_up_lora_a": grad_up_lora_a,
        "grad_up_lora_b": grad_up_lora_b,
        "grad_down_lora_a": grad_down_lora_a,
        "grad_down_lora_b": grad_down_lora_b,
    }
|
|
|
|
|
|
# =============================================================================
|
|
# Weight Initialization Utilities
|
|
# =============================================================================
|
|
|
|
|
|
def init_base_weights(expert_num: int, hidden_size: int, intermediate_size: int, dtype=torch.bfloat16):
    """Initialize base MoE weights (frozen during fine-tuning).

    NOTE: Weights are NOT divided by 100 (matching the inference test), which
    keeps output values in a normal range for bf16 precision.

    Random values are generated on CUDA when it is available (faster) and
    moved to CPU; without CUDA they are generated directly on CPU.

    Args:
        expert_num: Number of experts.
        hidden_size: Hidden dimension.
        intermediate_size: Intermediate dimension.
        dtype: Data type of the returned tensors.

    Returns:
        Tuple ``(gate_proj, up_proj, down_proj)`` of contiguous CPU tensors
        with shapes ``[expert_num, intermediate_size, hidden_size]`` (gate and
        up) and ``[expert_num, hidden_size, intermediate_size]`` (down).
    """
    # CUDA only accelerates generation; fall back so CPU-only hosts still work.
    gen_device = "cuda" if torch.cuda.is_available() else "cpu"

    def _randn_on_cpu(shape):
        return torch.randn(shape, dtype=dtype, device=gen_device).to("cpu").contiguous()

    gate_proj = _randn_on_cpu((expert_num, intermediate_size, hidden_size))
    up_proj = _randn_on_cpu((expert_num, intermediate_size, hidden_size))
    down_proj = _randn_on_cpu((expert_num, hidden_size, intermediate_size))

    return gate_proj, up_proj, down_proj
|
|
|
|
|
|
def init_lora_weights(expert_num: int, hidden_size: int, intermediate_size: int, rank: int, dtype=torch.bfloat16):
    """Initialize LoRA weights for the gate, up and down projections.

    LoRA A matrices get small random values (standard normal / 100). LoRA B
    matrices are zero, so the initial output equals the base model.

    Random values are generated on CUDA when it is available (faster) and
    moved to CPU; without CUDA they are generated directly on CPU.

    Args:
        expert_num: Number of experts.
        hidden_size: Hidden dimension.
        intermediate_size: Intermediate dimension.
        rank: LoRA rank.
        dtype: Data type of the returned tensors.

    Returns:
        Tuple ``(gate_lora_a, gate_lora_b, up_lora_a, up_lora_b, down_lora_a,
        down_lora_b)`` of CPU tensors.
    """
    # CUDA only accelerates generation; fall back so CPU-only hosts still work.
    gen_device = "cuda" if torch.cuda.is_available() else "cpu"

    def _small_randn(shape):
        return torch.randn(shape, dtype=dtype, device=gen_device).to("cpu").contiguous() / 100

    def _zeros(shape):
        return torch.zeros(shape, dtype=dtype, device="cpu").contiguous()

    # Gate projection LoRA
    gate_lora_a = _small_randn((expert_num, rank, hidden_size))
    gate_lora_b = _zeros((expert_num, intermediate_size, rank))

    # Up projection LoRA
    up_lora_a = _small_randn((expert_num, rank, hidden_size))
    up_lora_b = _zeros((expert_num, intermediate_size, rank))

    # Down projection LoRA
    down_lora_a = _small_randn((expert_num, rank, intermediate_size))
    down_lora_b = _zeros((expert_num, hidden_size, rank))

    return (gate_lora_a, gate_lora_b, up_lora_a, up_lora_b, down_lora_a, down_lora_b)
|
|
|
|
|
|
# =============================================================================
|
|
# Test Functions
|
|
# =============================================================================
|
|
|
|
|
|
def test_moe_sft_forward_no_tp(quant_mode: str = "bf16"):
    """Test MOE SFT forward accuracy with a single NUMA node (no TP).

    Compares the AMX implementation against the PyTorch reference over
    ``validation_iter`` random batches. A ``WorkerPoolConfig`` with one
    subpool is used to force ``tp_count=1``. The process exits with status 1
    if the extension is missing or if ANY iteration exceeds the threshold.

    Args:
        quant_mode: Quantization mode; any name accepted by
            ``get_moe_sft_class`` ("bf16", "int8", "int4", "int4_1",
            "int4_1kgroup", "int4_kgroup").
    """
    print(f"\n{'='*60}")
    print(f"Testing MOE SFT Forward Pass - {quant_mode.upper()} mode (NO TP)")
    print(f"{'='*60}")

    # Check for the extension before allocating any weights.
    if not HAS_KT_KERNEL:
        print("ERROR: kt_kernel_ext not available, cannot run test")
        sys.exit(1)

    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Initialize weights based on quant_mode
    k2_weights = None  # only set for K2 (int4_kgroup) mode
    if quant_mode == "int4_kgroup":
        # K2 needs pre-quantized int4 weights
        k2_weights = init_base_weights_for_k2(expert_num, hidden_size, intermediate_size, group_size=128)
        # The reference computation uses the original BF16 weights
        gate_proj = k2_weights["gate_proj_bf16"]
        up_proj = k2_weights["up_proj_bf16"]
        down_proj = k2_weights["down_proj_bf16"]
    else:
        # Other modes use BF16 weights
        gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)

    lora_weights = init_lora_weights(expert_num, hidden_size, intermediate_size, lora_rank)
    gate_lora_a, gate_lora_b, up_lora_a, up_lora_b, down_lora_a, down_lora_b = lora_weights

    # Make LoRA B non-zero so the adapter path actually contributes
    gate_lora_b.normal_().div_(100)
    up_lora_b.normal_().div_(100)
    down_lora_b.normal_().div_(100)

    # Initialize CPUInfer with single NUMA node configuration.
    # This forces tp_count=1, bypassing TP partitioning.
    print("\n[INFO] Creating CPUInfer with single NUMA node (NO TP)...")
    pool_config = kt_kernel_ext.WorkerPoolConfig()
    pool_config.subpool_count = 1
    pool_config.subpool_numa_map = [0]
    pool_config.subpool_thread_count = [num_threads]
    cpu_infer = kt_kernel_ext.CPUInfer(pool_config)
    print("[INFO] CPUInfer created with single subpool (tp_count=1)")

    # Create MOE SFT config using the new API
    config = kt_kernel_ext.moe.MOESFTConfig()
    config.expert_num = expert_num
    config.num_experts_per_tok = num_experts_per_tok
    config.hidden_size = hidden_size
    config.intermediate_size = intermediate_size
    config.lora_rank = lora_rank
    config.lora_alpha = lora_alpha
    config.max_cache_depth = 1
    config.max_len = max_len
    config.layer_idx = 0

    # Bug #26 fix: K2 uses pre-quantized weights with scales
    if quant_mode == "int4_kgroup" and k2_weights is not None:
        config.gate_proj = k2_weights["gate_qweight"].data_ptr()
        config.up_proj = k2_weights["up_qweight"].data_ptr()
        config.down_proj = k2_weights["down_qweight"].data_ptr()
        config.gate_scale = k2_weights["gate_scales"].data_ptr()
        config.up_scale = k2_weights["up_scales"].data_ptr()
        config.down_scale = k2_weights["down_scales"].data_ptr()
    else:
        config.gate_proj = gate_proj.data_ptr()
        config.up_proj = up_proj.data_ptr()
        config.down_proj = down_proj.data_ptr()

    # Set LoRA weight pointers directly in config (zero-copy); the tensors
    # stay referenced by this function for as long as the operator is used.
    config.gate_lora_a = gate_lora_a.data_ptr()
    config.gate_lora_b = gate_lora_b.data_ptr()
    config.up_lora_a = up_lora_a.data_ptr()
    config.up_lora_b = up_lora_b.data_ptr()
    config.down_lora_a = down_lora_a.data_ptr()
    config.down_lora_b = down_lora_b.data_ptr()
    config.pool = cpu_infer.backend_

    # Bug #23 fix: Set quant_config for AWQ/K2 modes
    # Bug #25 fix: AWQ (int4_1kgroup) uses zero_point, K2 (int4_kgroup) does NOT
    if quant_mode == "int4_1kgroup":  # AWQ supports zero_point
        config.quant_config.group_size = 128
        config.quant_config.zero_point = True
    elif quant_mode == "int4_kgroup":  # K2 does NOT support zero_point
        config.quant_config.group_size = 128
        config.quant_config.zero_point = False

    # Create MOE SFT instance based on quant_mode
    MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
    moe = MOE_SFT_CLASS(config)
    print(f"[INFO] Using {quant_mode.upper()} MOE SFT class: {MOE_SFT_CLASS.__name__}")

    # Load base weights
    cpu_infer.submit(moe.load_weights_task())
    cpu_infer.sync()

    # Warm up
    cpu_infer.submit(moe.warm_up_task())
    cpu_infer.sync()

    # Get threshold for this quant_mode
    threshold = get_threshold(quant_mode)

    # Remember failures from every iteration so the final verdict does not
    # depend on the last iteration only.
    all_passed = True

    # Run validation iterations
    for iter_idx in range(validation_iter):
        print(f"\n--- Iteration {iter_idx} ---")

        # Generate random inputs
        bsz_tensor = torch.tensor([qlen], device="cpu")
        expert_ids = (
            torch.stack([torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)])
            .to(torch.int64)
            .contiguous()
        )
        weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()
        weights = weights / weights.sum(dim=-1, keepdim=True)  # Normalize
        input_data = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() / 100

        # PyTorch reference forward
        torch_output, _ = moe_sft_torch_forward(
            input_data,
            expert_ids,
            weights,
            gate_proj,
            up_proj,
            down_proj,
            gate_lora_a,
            gate_lora_b,
            up_lora_a,
            up_lora_b,
            down_lora_a,
            down_lora_b,
            lora_scaling,
            debug_print=(iter_idx == 0),
        )

        # AMX forward using forward_sft_task
        output = torch.zeros((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
        cpu_infer.submit(
            moe.forward_sft_task(
                bsz_tensor.data_ptr(),
                num_experts_per_tok,
                expert_ids.data_ptr(),
                weights.data_ptr(),
                input_data.data_ptr(),
                output.data_ptr(),
                False,  # save_for_backward=False to avoid cache overflow
            )
        )
        cpu_infer.sync()

        # Debug: print AMX output
        print(f"[AMX SFT DEBUG] AMX output[:8] = {output.flatten()[:8]}")
        print(f"[AMX SFT DEBUG] AMX output mean abs = {torch.mean(torch.abs(output)):.6e}")
        print(f"[AMX SFT DEBUG] Torch output mean abs = {torch.mean(torch.abs(torch_output)):.6e}")

        # Relative error: mean |amx - ref| / mean |ref| (epsilon avoids 0/0)
        diff = torch.mean(torch.abs(output - torch_output)) / (torch.mean(torch.abs(torch_output)) + 1e-8)
        print(f"Relative difference: {diff:.6f}")

        if diff < threshold:
            print(f"PASSED (threshold: {threshold})")
        else:
            print(f"FAILED: diff={diff:.6f} >= {threshold}")
            # Don't exit immediately, continue to show all iterations
            all_passed = False

    print(f"\n--- Final Result ---")
    if all_passed:
        print(f"[OK] MOE SFT Forward Pass Test - {quant_mode.upper()} mode (NO TP) PASSED")
    else:
        print(f"[FAILED] MOE SFT Forward Pass Test - {quant_mode.upper()} mode (NO TP) FAILED")
        print(f" This means the bug is in the basic SFT forward logic, not TP partitioning.")
        sys.exit(1)
|
|
|
|
|
|
def test_moe_sft_backward_no_tp(quant_mode: str = "bf16"):
    """
    Test MOE SFT backward pass accuracy with single NUMA node (no TP).

    For ``validation_iter`` iterations, random routing/inputs/upstream
    gradients are generated, the PyTorch reference forward + backward is run,
    and the AMX forward (with ``save_for_backward=True``) + backward is run on
    the same data. The input gradient and the six LoRA gradients (restricted
    to experts that received at least one token) are compared with a relative
    mean-absolute-error metric against ``get_threshold(quant_mode,
    is_backward=True)``.

    Uses WorkerPoolConfig to force a single subpool.

    Args:
        quant_mode: Quantization mode, forwarded to ``get_moe_sft_class`` and
            ``get_threshold``. ``"int4_kgroup"`` (K2) feeds the kernel
            pre-quantized int4 weights plus scales and disables zero-point;
            ``"int4_1kgroup"`` (AWQ) enables zero-point; any other mode passes
            the tensors returned by ``init_base_weights`` to the kernel.

    Raises:
        AssertionError: If any gradient's relative difference is not below
            the threshold.
        SystemExit: With exit code 1 when ``kt_kernel_ext`` is unavailable.
    """
    print(f"\n{'='*60}")
    print(f"Testing MOE SFT Backward Pass - {quant_mode.upper()} mode (NO TP)")
    print(f"{'='*60}")

    # Fail fast: this test cannot run without the extension, so check before
    # allocating the full set of expert weights below.
    if not HAS_KT_KERNEL:
        print("ERROR: kt_kernel_ext not available, cannot run test")
        sys.exit(1)

    def _relative_diff(actual, expected):
        """Mean absolute error of *actual* vs *expected*, normalized by mean |expected|."""
        return torch.mean(torch.abs(actual - expected)) / (torch.mean(torch.abs(expected)) + 1e-8)

    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Initialize weights based on quant_mode
    k2_weights = None  # Will be set for K2 mode
    if quant_mode == "int4_kgroup":
        # K2 needs pre-quantized int4 weights
        k2_weights = init_base_weights_for_k2(expert_num, hidden_size, intermediate_size, group_size=128)
        # Use original BF16 for reference computation
        gate_proj = k2_weights["gate_proj_bf16"]
        up_proj = k2_weights["up_proj_bf16"]
        down_proj = k2_weights["down_proj_bf16"]
    else:
        # Other modes use BF16 weights
        gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)

    lora_weights = init_lora_weights(expert_num, hidden_size, intermediate_size, lora_rank)
    gate_lora_a, gate_lora_b, up_lora_a, up_lora_b, down_lora_a, down_lora_b = lora_weights

    # Make LoRA B non-zero for testing
    gate_lora_b.normal_().div_(100)
    up_lora_b.normal_().div_(100)
    down_lora_b.normal_().div_(100)

    # Initialize CPUInfer with single NUMA node configuration
    print("\n[INFO] Creating CPUInfer with single NUMA node (NO TP)...")
    pool_config = kt_kernel_ext.WorkerPoolConfig()
    pool_config.subpool_count = 1
    pool_config.subpool_numa_map = [0]
    pool_config.subpool_thread_count = [num_threads]
    CPUInfer = kt_kernel_ext.CPUInfer(pool_config)
    print("[INFO] CPUInfer created with single subpool (tp_count=1)")

    # Create MOE SFT config - max_cache_depth must match validation_iter for backward
    config = kt_kernel_ext.moe.MOESFTConfig()
    config.expert_num = expert_num
    config.num_experts_per_tok = num_experts_per_tok
    config.hidden_size = hidden_size
    config.intermediate_size = intermediate_size
    config.lora_rank = lora_rank
    config.lora_alpha = lora_alpha
    config.max_cache_depth = validation_iter  # Need cache for backward
    config.max_len = max_len
    config.layer_idx = 0

    # Bug #26 fix: K2 uses pre-quantized weights with scales
    if quant_mode == "int4_kgroup" and k2_weights is not None:
        config.gate_proj = k2_weights["gate_qweight"].data_ptr()
        config.up_proj = k2_weights["up_qweight"].data_ptr()
        config.down_proj = k2_weights["down_qweight"].data_ptr()
        config.gate_scale = k2_weights["gate_scales"].data_ptr()
        config.up_scale = k2_weights["up_scales"].data_ptr()
        config.down_scale = k2_weights["down_scales"].data_ptr()
    else:
        config.gate_proj = gate_proj.data_ptr()
        config.up_proj = up_proj.data_ptr()
        config.down_proj = down_proj.data_ptr()

    config.gate_lora_a = gate_lora_a.data_ptr()
    config.gate_lora_b = gate_lora_b.data_ptr()
    config.up_lora_a = up_lora_a.data_ptr()
    config.up_lora_b = up_lora_b.data_ptr()
    config.down_lora_a = down_lora_a.data_ptr()
    config.down_lora_b = down_lora_b.data_ptr()
    config.pool = CPUInfer.backend_

    # Bug #23 fix: Set quant_config for AWQ/K2 modes
    # Bug #25 fix: AWQ (int4_1kgroup) uses zero_point, K2 (int4_kgroup) does NOT
    if quant_mode == "int4_1kgroup":  # AWQ supports zero_point
        config.quant_config.group_size = 128
        config.quant_config.zero_point = True
    elif quant_mode == "int4_kgroup":  # K2 does NOT support zero_point
        config.quant_config.group_size = 128
        config.quant_config.zero_point = False

    # Create MOE SFT instance based on quant_mode
    MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
    moe = MOE_SFT_CLASS(config)
    print(f"[INFO] Using {quant_mode.upper()} MOE SFT class: {MOE_SFT_CLASS.__name__}")

    # Load base weights
    CPUInfer.submit(moe.load_weights_task())
    CPUInfer.sync()

    # Warm up
    CPUInfer.submit(moe.warm_up_task())
    CPUInfer.sync()

    # Get threshold for this quant_mode
    threshold = get_threshold(quant_mode, is_backward=True)

    # Run validation iterations
    for iter_idx in range(validation_iter):
        print(f"\n--- Iteration {iter_idx} ---")

        # Generate random inputs
        bsz_tensor = torch.tensor([qlen], device="cpu")
        expert_ids = (
            torch.stack([torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)])
            .to(torch.int64)
            .contiguous()
        )
        weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()
        weights = weights / weights.sum(dim=-1, keepdim=True)
        input_data = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() / 100

        # Random gradient from upstream
        grad_output = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() / 100

        # PyTorch reference forward + backward
        _, moe_saved = moe_sft_torch_forward(
            input_data,
            expert_ids,
            weights,
            gate_proj,
            up_proj,
            down_proj,
            gate_lora_a,
            gate_lora_b,
            up_lora_a,
            up_lora_b,
            down_lora_a,
            down_lora_b,
            lora_scaling,
        )

        torch_grads = moe_sft_torch_backward(
            grad_output,
            moe_saved,
            gate_proj,
            up_proj,
            down_proj,
            gate_lora_a,
            gate_lora_b,
            up_lora_a,
            up_lora_b,
            down_lora_a,
            down_lora_b,
            lora_scaling,
        )

        # AMX forward (with save_for_backward=True)
        output = torch.zeros((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
        CPUInfer.submit(
            moe.forward_sft_task(
                bsz_tensor.data_ptr(),
                num_experts_per_tok,
                expert_ids.data_ptr(),
                weights.data_ptr(),
                input_data.data_ptr(),
                output.data_ptr(),
                True,  # save_for_backward
            )
        )
        CPUInfer.sync()

        # Allocate gradient buffers
        grad_input = torch.zeros((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
        grad_gate_lora_a = torch.zeros_like(gate_lora_a)
        grad_gate_lora_b = torch.zeros_like(gate_lora_b)
        grad_up_lora_a = torch.zeros_like(up_lora_a)
        grad_up_lora_b = torch.zeros_like(up_lora_b)
        grad_down_lora_a = torch.zeros_like(down_lora_a)
        grad_down_lora_b = torch.zeros_like(down_lora_b)

        # AMX backward
        CPUInfer.submit(
            moe.backward_task(
                grad_output.data_ptr(),
                grad_input.data_ptr(),
                grad_gate_lora_a.data_ptr(),
                grad_gate_lora_b.data_ptr(),
                grad_up_lora_a.data_ptr(),
                grad_up_lora_b.data_ptr(),
                grad_down_lora_a.data_ptr(),
                grad_down_lora_b.data_ptr(),
            )
        )
        CPUInfer.sync()

        # Compare gradients (threshold already set before loop)
        # Input gradient
        diff_input = _relative_diff(grad_input, torch_grads["grad_input"])
        print(f"grad_input diff: {diff_input:.6f}")
        assert diff_input < threshold, f"grad_input accuracy failed: {diff_input:.6f}"

        # LoRA gradients (check activated experts only)
        activated = [i for i, n in enumerate(moe_saved["tokens_per_expert"]) if n > 0]

        # Debug: compare PyTorch and C++ gradient values for Bug #17
        print(f"\n[DEBUG COMPARISON] Activated experts: {activated[:5]}...")  # Only print first 5
        print(f"[DEBUG COMPARISON] First activated expert: {activated[0] if activated else 'None'}")

        if activated:
            first_exp = activated[0]
            print(
                f"\n[TORCH DEBUG] grad_gate_lora_a[{first_exp}][0, 0:8] = {torch_grads['grad_gate_lora_a'][first_exp, 0, :8]}"
            )
            print(f"[AMX DEBUG] grad_gate_lora_a[{first_exp}][0, 0:8] = {grad_gate_lora_a[first_exp, 0, :8]}")
            print(f"[TORCH DEBUG] mean abs = {torch.mean(torch.abs(torch_grads['grad_gate_lora_a'][first_exp])):.6e}")
            print(f"[AMX DEBUG] mean abs = {torch.mean(torch.abs(grad_gate_lora_a[first_exp])):.6e}")

            # Also check up_lora_a and down_lora_a
            print(
                f"\n[TORCH DEBUG] grad_up_lora_a[{first_exp}][0, 0:8] = {torch_grads['grad_up_lora_a'][first_exp, 0, :8]}"
            )
            print(f"[AMX DEBUG] grad_up_lora_a[{first_exp}][0, 0:8] = {grad_up_lora_a[first_exp, 0, :8]}")
            print(
                f"[TORCH DEBUG] grad_down_lora_a[{first_exp}][0, 0:8] = {torch_grads['grad_down_lora_a'][first_exp, 0, :8]}"
            )
            print(f"[AMX DEBUG] grad_down_lora_a[{first_exp}][0, 0:8] = {grad_down_lora_a[first_exp, 0, :8]}")

        for name, amx_grad, torch_grad in [
            ("gate_lora_a", grad_gate_lora_a, torch_grads["grad_gate_lora_a"]),
            ("gate_lora_b", grad_gate_lora_b, torch_grads["grad_gate_lora_b"]),
            ("up_lora_a", grad_up_lora_a, torch_grads["grad_up_lora_a"]),
            ("up_lora_b", grad_up_lora_b, torch_grads["grad_up_lora_b"]),
            ("down_lora_a", grad_down_lora_a, torch_grads["grad_down_lora_a"]),
            ("down_lora_b", grad_down_lora_b, torch_grads["grad_down_lora_b"]),
        ]:
            # Only experts that actually received tokens are compared.
            diff = _relative_diff(amx_grad[activated], torch_grad[activated])
            print(f" {name} diff: {diff:.6f}")
            assert diff < threshold, f"{name} accuracy failed: {diff:.6f}"

        print(f"PASSED (threshold: {threshold})")

    print(f"\n[OK] MOE SFT Backward Pass Test - {quant_mode.upper()} mode (NO TP) PASSED")
|
|
|
|
|
|
def test_moe_sft_lora_weight_sync_no_tp(quant_mode: str = "bf16"):
    """
    Check that LoRA weight changes reach the kernel on a single NUMA node (no TP).

    Three forward passes are executed on the same routing and input data:

    1. with the LoRA weights as initialized;
    2. after adding 0.1 in place to every LoRA tensor and calling
       ``update_lora_weights_task`` with the unchanged pointers - the output
       must differ from pass 1;
    3. after cloning every LoRA tensor and calling
       ``update_lora_weights_task`` with the clones' pointers - the output
       must match pass 2.

    Args:
        quant_mode: Quantization mode, forwarded to ``get_moe_sft_class``.
            ``"int4_kgroup"`` supplies pre-quantized weights and scales;
            ``"int4_1kgroup"`` additionally turns on zero-point.

    Raises:
        AssertionError: If pass 2 does not differ from pass 1, a clone does
            not match its source, or pass 3 does not match pass 2.
        SystemExit: With exit code 1 when ``kt_kernel_ext`` is unavailable.
    """
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"Testing LoRA Weight Synchronization - {quant_mode.upper()} mode (NO TP)")
    print(f"{rule}")

    if not HAS_KT_KERNEL:
        print("ERROR: kt_kernel_ext not available, cannot run test")
        sys.exit(1)

    torch.manual_seed(42)

    # Base weights: K2 mode carries both the quantized tensors and their BF16 originals.
    k2_weights = None
    if quant_mode != "int4_kgroup":
        gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)
    else:
        k2_weights = init_base_weights_for_k2(expert_num, hidden_size, intermediate_size, group_size=128)
        gate_proj, up_proj, down_proj = (
            k2_weights["gate_proj_bf16"],
            k2_weights["up_proj_bf16"],
            k2_weights["down_proj_bf16"],
        )

    # Order matters below: it is the positional order update_lora_weights_task is called with.
    lora_names = ("gate_lora_a", "gate_lora_b", "up_lora_a", "up_lora_b", "down_lora_a", "down_lora_b")
    lora_tensors = tuple(init_lora_weights(expert_num, hidden_size, intermediate_size, lora_rank))
    gate_lora_a, gate_lora_b = lora_tensors[0], lora_tensors[1]

    # Worker pool restricted to one subpool on NUMA node 0.
    pool_config = kt_kernel_ext.WorkerPoolConfig()
    pool_config.subpool_count = 1
    pool_config.subpool_numa_map = [0]
    pool_config.subpool_thread_count = [num_threads]
    cpu_infer = kt_kernel_ext.CPUInfer(pool_config)

    # Populate the MOE SFT config field by field.
    config = kt_kernel_ext.moe.MOESFTConfig()
    for field, value in (
        ("expert_num", expert_num),
        ("num_experts_per_tok", num_experts_per_tok),
        ("hidden_size", hidden_size),
        ("intermediate_size", intermediate_size),
        ("lora_rank", lora_rank),
        ("lora_alpha", lora_alpha),
        ("max_cache_depth", 1),
        ("max_len", max_len),
        ("layer_idx", 0),
    ):
        setattr(config, field, value)

    # k2_weights is only non-None in "int4_kgroup" mode.
    if k2_weights is not None:
        for proj in ("gate", "up", "down"):
            setattr(config, f"{proj}_proj", k2_weights[f"{proj}_qweight"].data_ptr())
        for proj in ("gate", "up", "down"):
            setattr(config, f"{proj}_scale", k2_weights[f"{proj}_scales"].data_ptr())
    else:
        config.gate_proj = gate_proj.data_ptr()
        config.up_proj = up_proj.data_ptr()
        config.down_proj = down_proj.data_ptr()

    for field, tensor in zip(lora_names, lora_tensors):
        setattr(config, field, tensor.data_ptr())
    config.pool = cpu_infer.backend_

    # Both grouped int4 modes use group size 128; only "int4_1kgroup" enables zero-point.
    if quant_mode in ("int4_1kgroup", "int4_kgroup"):
        config.quant_config.group_size = 128
        config.quant_config.zero_point = quant_mode == "int4_1kgroup"

    moe_cls = get_moe_sft_class(quant_mode)
    moe = moe_cls(config)
    print(f"[INFO] Using {quant_mode.upper()} MOE SFT class: {moe_cls.__name__}")

    # Load base weights, then warm up.
    for task in (moe.load_weights_task, moe.warm_up_task):
        cpu_infer.submit(task())
        cpu_infer.sync()

    # Fixed routing and input shared by all three forward passes.
    bsz_tensor = torch.tensor([qlen], device="cpu")
    per_token_experts = [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]
    expert_ids = torch.stack(per_token_experts).to(torch.int64).contiguous()
    weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()
    weights = weights / weights.sum(dim=-1, keepdim=True)
    input_data = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() / 100

    def run_forward():
        """Run one forward pass (save_for_backward=False) into a fresh zeroed buffer and return it."""
        result = torch.zeros((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
        cpu_infer.submit(
            moe.forward_sft_task(
                bsz_tensor.data_ptr(),
                num_experts_per_tok,
                expert_ids.data_ptr(),
                weights.data_ptr(),
                input_data.data_ptr(),
                result.data_ptr(),
                False,
            )
        )
        cpu_infer.sync()
        return result

    def push_lora(tensors):
        """Hand the six LoRA tensor pointers to the kernel and wait for completion."""
        cpu_infer.submit(moe.update_lora_weights_task(*(t.data_ptr() for t in tensors)))
        cpu_infer.sync()

    # Pass 1: initial LoRA weights.
    output1 = run_forward()

    # In-place edit of every LoRA tensor, standing in for optimizer.step().
    for tensor in lora_tensors:
        tensor.add_(0.1)

    # Re-announce the (unchanged) pointers so the kernel picks up the new values.
    push_lora(lora_tensors)

    # Pass 2: must see the modified weights.
    output2 = run_forward()

    change = torch.mean(torch.abs(output1 - output2))
    print(f"Output difference after weight update: {change:.6f}")
    assert change > 1e-6, "Outputs should differ after LoRA weight update"

    print("\n[PYTHON DEBUG] Phase 2 - Original pointers:")
    print(f" gate_lora_a ptr: {hex(gate_lora_a.data_ptr())}")
    print(f" gate_lora_a[0,0,0]: {gate_lora_a[0,0,0].item():.6f}")
    print(f" gate_lora_b ptr: {hex(gate_lora_b.data_ptr())}")

    # Same values, new storage: simulates tensors being reallocated.
    cloned_tensors = tuple(tensor.clone() for tensor in lora_tensors)
    new_gate_lora_a, new_gate_lora_b = cloned_tensors[0], cloned_tensors[1]

    print("\n[PYTHON DEBUG] Phase 3 - Cloned pointers:")
    print(f" new_gate_lora_a ptr: {hex(new_gate_lora_a.data_ptr())}")
    print(f" new_gate_lora_a[0,0,0]: {new_gate_lora_a[0,0,0].item():.6f}")
    print(f" new_gate_lora_b ptr: {hex(new_gate_lora_b.data_ptr())}")
    assert torch.allclose(gate_lora_a, new_gate_lora_a), "Clone failed for gate_lora_a!"
    assert torch.allclose(gate_lora_b, new_gate_lora_b), "Clone failed for gate_lora_b!"
    print(" Clone verification: PASSED")

    print("\n[PYTHON DEBUG] Calling update_lora_weights_task...")
    push_lora(cloned_tensors)
    print("[PYTHON DEBUG] update_lora_weights_task completed")

    # Pass 3: same values read through the new pointers.
    print("\n[PYTHON DEBUG] Phase 3 - Running forward with new pointers...")
    output3 = run_forward()

    mismatch = torch.mean(torch.abs(output2 - output3))
    print(f"Output difference after pointer update (should be ~0): {mismatch:.6f}")
    assert mismatch < 1e-5, f"Outputs should match after pointer update: {mismatch:.6f}"

    print(f"[OK] LoRA Weight Synchronization Test - {quant_mode.upper()} mode (NO TP) PASSED")
|
|
|
|
|
|
def test_moe_sft_training_loop_no_tp(quant_mode: str = "bf16"):
    """
    Test complete training loop with single NUMA node (no TP).

    This simulates a real training scenario where, for each of three steps:
    1. Forward pass computes output and saves activations
    2. An MSE loss against a random target yields the upstream gradient
    3. Backward pass computes gradients for LoRA weights
    4. AdamW updates the LoRA weights in place
    5. The updated weights are re-announced to the kernel via
       ``update_lora_weights_task`` before the next forward

    When ``kt_kernel_ext`` is unavailable the same loop runs against the
    PyTorch reference forward/backward instead of the AMX kernel.

    Args:
        quant_mode: Quantization mode, forwarded to ``get_moe_sft_class``.
            ``"int4_kgroup"`` (K2) feeds the kernel pre-quantized int4 weights
            plus scales and disables zero-point; ``"int4_1kgroup"`` (AWQ)
            enables zero-point; any other mode passes the tensors returned by
            ``init_base_weights`` to the kernel.

    Raises:
        AssertionError: If ``gate_lora_a`` or ``gate_lora_b`` is unchanged
            after an optimizer step.
    """
    print(f"\n{'='*60}")
    print(f"Testing Complete Training Loop - {quant_mode.upper()} mode (NO TP)")
    print(f"{'='*60}")

    torch.manual_seed(42)

    # Initialize base weights based on quant_mode
    k2_weights = None  # Will be set for K2 mode
    if quant_mode == "int4_kgroup":
        # K2 needs pre-quantized int4 weights
        k2_weights = init_base_weights_for_k2(expert_num, hidden_size, intermediate_size, group_size=128)
        # Use original BF16 for reference computation
        gate_proj = k2_weights["gate_proj_bf16"]
        up_proj = k2_weights["up_proj_bf16"]
        down_proj = k2_weights["down_proj_bf16"]
    else:
        # Other modes use BF16 weights
        gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)

    # The LoRA A matrices were originally drawn on CUDA unconditionally, which
    # crashes on hosts without a GPU. Keep CUDA when present (so seeded values
    # are unchanged there) and fall back to CPU otherwise.
    rand_device = "cuda" if torch.cuda.is_available() else "cpu"

    # Initialize LoRA weights as contiguous tensors
    gate_lora_a = (
        torch.randn(expert_num, lora_rank, hidden_size, dtype=torch.bfloat16, device=rand_device)
        .to("cpu")
        .contiguous()
        / 100
    )
    gate_lora_b = torch.zeros(expert_num, intermediate_size, lora_rank, dtype=torch.bfloat16).contiguous()
    up_lora_a = (
        torch.randn(expert_num, lora_rank, hidden_size, dtype=torch.bfloat16, device=rand_device)
        .to("cpu")
        .contiguous()
        / 100
    )
    up_lora_b = torch.zeros(expert_num, intermediate_size, lora_rank, dtype=torch.bfloat16).contiguous()
    down_lora_a = (
        torch.randn(expert_num, lora_rank, intermediate_size, dtype=torch.bfloat16, device=rand_device)
        .to("cpu")
        .contiguous()
        / 100
    )
    down_lora_b = torch.zeros(expert_num, hidden_size, lora_rank, dtype=torch.bfloat16).contiguous()

    # Make LoRA B non-zero for testing
    gate_lora_b.normal_().div_(100)
    up_lora_b.normal_().div_(100)
    down_lora_b.normal_().div_(100)

    # Wrap tensors as nn.Parameters for optimizer
    gate_lora_a_param = torch.nn.Parameter(gate_lora_a)
    gate_lora_b_param = torch.nn.Parameter(gate_lora_b)
    up_lora_a_param = torch.nn.Parameter(up_lora_a)
    up_lora_b_param = torch.nn.Parameter(up_lora_b)
    down_lora_a_param = torch.nn.Parameter(down_lora_a)
    down_lora_b_param = torch.nn.Parameter(down_lora_b)

    lora_params = [
        gate_lora_a_param,
        gate_lora_b_param,
        up_lora_a_param,
        up_lora_b_param,
        down_lora_a_param,
        down_lora_b_param,
    ]

    # Create optimizer
    optimizer = torch.optim.AdamW(lora_params, lr=1e-4)

    # Initialize kt_kernel
    moe = None
    CPUInfer = None
    if HAS_KT_KERNEL:
        pool_config = kt_kernel_ext.WorkerPoolConfig()
        pool_config.subpool_count = 1
        pool_config.subpool_numa_map = [0]
        pool_config.subpool_thread_count = [num_threads]
        CPUInfer = kt_kernel_ext.CPUInfer(pool_config)

        # Create MOE SFT config
        config = kt_kernel_ext.moe.MOESFTConfig()
        config.expert_num = expert_num
        config.num_experts_per_tok = num_experts_per_tok
        config.hidden_size = hidden_size
        config.intermediate_size = intermediate_size
        config.lora_rank = lora_rank
        config.lora_alpha = lora_alpha
        config.max_cache_depth = 1  # One forward-backward pair at a time
        config.max_len = max_len
        config.layer_idx = 0

        # Bug #26 fix: K2 uses pre-quantized weights with scales
        if quant_mode == "int4_kgroup" and k2_weights is not None:
            config.gate_proj = k2_weights["gate_qweight"].data_ptr()
            config.up_proj = k2_weights["up_qweight"].data_ptr()
            config.down_proj = k2_weights["down_qweight"].data_ptr()
            config.gate_scale = k2_weights["gate_scales"].data_ptr()
            config.up_scale = k2_weights["up_scales"].data_ptr()
            config.down_scale = k2_weights["down_scales"].data_ptr()
        else:
            config.gate_proj = gate_proj.data_ptr()
            config.up_proj = up_proj.data_ptr()
            config.down_proj = down_proj.data_ptr()

        config.gate_lora_a = gate_lora_a_param.data.data_ptr()
        config.gate_lora_b = gate_lora_b_param.data.data_ptr()
        config.up_lora_a = up_lora_a_param.data.data_ptr()
        config.up_lora_b = up_lora_b_param.data.data_ptr()
        config.down_lora_a = down_lora_a_param.data.data_ptr()
        config.down_lora_b = down_lora_b_param.data.data_ptr()
        config.pool = CPUInfer.backend_

        # Bug #23 fix: Set quant_config for AWQ/K2 modes
        # Bug #25 fix: AWQ (int4_1kgroup) uses zero_point, K2 (int4_kgroup) does NOT
        if quant_mode == "int4_1kgroup":  # AWQ supports zero_point
            config.quant_config.group_size = 128
            config.quant_config.zero_point = True
        elif quant_mode == "int4_kgroup":  # K2 does NOT support zero_point
            config.quant_config.group_size = 128
            config.quant_config.zero_point = False

        # Create MOE SFT instance based on quant_mode
        MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
        moe = MOE_SFT_CLASS(config)
        print(f"[INFO] Using {quant_mode.upper()} MOE SFT class: {MOE_SFT_CLASS.__name__}")

        # Load base weights
        CPUInfer.submit(moe.load_weights_task())
        CPUInfer.sync()

        # Warm up
        CPUInfer.submit(moe.warm_up_task())
        CPUInfer.sync()
    else:
        print("WARNING: kt_kernel_ext not available, running PyTorch-only training loop")

    num_training_steps = 3

    for step in range(num_training_steps):
        print(f"\n--- Training Step {step} ---")

        # Generate batch
        expert_ids = (
            torch.stack([torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)])
            .to(torch.int64)
            .contiguous()
        )
        weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()
        weights = weights / weights.sum(dim=-1, keepdim=True)
        input_data = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() / 100
        target = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() / 100

        if HAS_KT_KERNEL and moe is not None:
            bsz_tensor = torch.tensor([qlen], device="cpu")

            # Forward pass (with save_for_backward=True)
            output = torch.zeros((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
            CPUInfer.submit(
                moe.forward_sft_task(
                    bsz_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids.data_ptr(),
                    weights.data_ptr(),
                    input_data.data_ptr(),
                    output.data_ptr(),
                    True,  # save_for_backward
                )
            )
            CPUInfer.sync()

            # Simple MSE loss
            loss = torch.mean((output.float() - target.float()) ** 2)
            print(f" Loss (AMX): {loss.item():.6f}")

            # Compute gradient of loss w.r.t. output
            grad_output = 2 * (output.float() - target.float()) / output.numel()
            grad_output = grad_output.to(torch.bfloat16).contiguous()

            # Allocate gradient buffers
            grad_input = torch.zeros((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
            grad_gate_lora_a = torch.zeros_like(gate_lora_a_param.data)
            grad_gate_lora_b = torch.zeros_like(gate_lora_b_param.data)
            grad_up_lora_a = torch.zeros_like(up_lora_a_param.data)
            grad_up_lora_b = torch.zeros_like(up_lora_b_param.data)
            grad_down_lora_a = torch.zeros_like(down_lora_a_param.data)
            grad_down_lora_b = torch.zeros_like(down_lora_b_param.data)

            # Backward pass
            CPUInfer.submit(
                moe.backward_task(
                    grad_output.data_ptr(),
                    grad_input.data_ptr(),
                    grad_gate_lora_a.data_ptr(),
                    grad_gate_lora_b.data_ptr(),
                    grad_up_lora_a.data_ptr(),
                    grad_up_lora_b.data_ptr(),
                    grad_down_lora_a.data_ptr(),
                    grad_down_lora_b.data_ptr(),
                )
            )
            CPUInfer.sync()

            # Copy gradients to parameters
            gate_lora_a_param.grad = grad_gate_lora_a
            gate_lora_b_param.grad = grad_gate_lora_b
            up_lora_a_param.grad = grad_up_lora_a
            up_lora_b_param.grad = grad_up_lora_b
            down_lora_a_param.grad = grad_down_lora_a
            down_lora_b_param.grad = grad_down_lora_b

        else:
            # PyTorch reference forward + backward
            output, moe_saved = moe_sft_torch_forward(
                input_data.detach(),
                expert_ids,
                weights,
                gate_proj,
                up_proj,
                down_proj,
                gate_lora_a_param.data.contiguous(),
                gate_lora_b_param.data.contiguous(),
                up_lora_a_param.data.contiguous(),
                up_lora_b_param.data.contiguous(),
                down_lora_a_param.data.contiguous(),
                down_lora_b_param.data.contiguous(),
                lora_scaling,
            )

            # Simple MSE loss
            loss = torch.mean((output.float() - target.float()) ** 2)
            print(f" Loss (PyTorch): {loss.item():.6f}")

            # Compute gradient of loss w.r.t. output
            grad_output = 2 * (output.float() - target.float()) / output.numel()
            grad_output = grad_output.to(torch.bfloat16).contiguous()

            # Backward pass
            grads = moe_sft_torch_backward(
                grad_output,
                moe_saved,
                gate_proj,
                up_proj,
                down_proj,
                gate_lora_a_param.data.contiguous(),
                gate_lora_b_param.data.contiguous(),
                up_lora_a_param.data.contiguous(),
                up_lora_b_param.data.contiguous(),
                down_lora_a_param.data.contiguous(),
                down_lora_b_param.data.contiguous(),
                lora_scaling,
            )

            # Copy gradients to parameters
            gate_lora_a_param.grad = grads["grad_gate_lora_a"]
            gate_lora_b_param.grad = grads["grad_gate_lora_b"]
            up_lora_a_param.grad = grads["grad_up_lora_a"]
            up_lora_b_param.grad = grads["grad_up_lora_b"]
            down_lora_a_param.grad = grads["grad_down_lora_a"]
            down_lora_b_param.grad = grads["grad_down_lora_b"]

        # Print gradient norms to verify gradients are computed
        print(f" gate_lora_a grad norm: {gate_lora_a_param.grad.norm().item():.6e}")
        print(f" gate_lora_b grad norm: {gate_lora_b_param.grad.norm().item():.6e}")

        # Save weight snapshots before optimizer step
        gate_lora_a_before = gate_lora_a_param.data.clone()
        gate_lora_b_before = gate_lora_b_param.data.clone()

        # Optimizer step
        optimizer.step()
        optimizer.zero_grad()

        if HAS_KT_KERNEL and moe is not None:
            # Re-announce the LoRA tensors after the in-place optimizer update.
            # NOTE(review): the weight-sync test in this file does the same after
            # an in-place edit because the kernel may hold copies rather than
            # zero-copy views; confirm whether this is required with one subpool.
            CPUInfer.submit(
                moe.update_lora_weights_task(
                    gate_lora_a_param.data.data_ptr(),
                    gate_lora_b_param.data.data_ptr(),
                    up_lora_a_param.data.data_ptr(),
                    up_lora_b_param.data.data_ptr(),
                    down_lora_a_param.data.data_ptr(),
                    down_lora_b_param.data.data_ptr(),
                )
            )
            CPUInfer.sync()

        # Calculate weight changes
        gate_a_diff = (gate_lora_a_param.data - gate_lora_a_before).abs().mean().item()
        gate_b_diff = (gate_lora_b_param.data - gate_lora_b_before).abs().mean().item()

        # Print weight norms with higher precision
        print(f" gate_lora_a norm: {gate_lora_a_param.data.norm().item():.10f}")
        print(f" gate_lora_b norm: {gate_lora_b_param.data.norm().item():.10f}")
        print(f" gate_lora_a weight change (mean abs): {gate_a_diff:.10e}")
        print(f" gate_lora_b weight change (mean abs): {gate_b_diff:.10e}")

        # Verify weights are actually being updated
        assert gate_a_diff > 0, "gate_lora_a weights should change after optimizer step"
        assert gate_b_diff > 0, "gate_lora_b weights should change after optimizer step"

    print(f"\n[OK] Training Loop Test - {quant_mode.upper()} mode (NO TP) PASSED")
|
|
|
|
|
|
# =============================================================================
|
|
# Main Entry Point
|
|
# =============================================================================
|
|
|
|
|
|
def run_all_tests(quant_modes=None):
    """Run the MOE SFT non-TP tests for each requested quantization mode.

    For every mode the forward, backward, LoRA weight-sync and training-loop
    tests are run in that order. Any exception raised by a test is reported
    with a traceback and the process exits with status 1.

    Args:
        quant_modes: Iterable of quantization mode names to test. When omitted
            (``None``), only ``"int4_kgroup"`` is tested. Other mode names
            seen in this script are "bf16", "int8", "int4", "int4_1" and
            "int4_1kgroup" - presumably all accepted by ``get_moe_sft_class``;
            verify before relying on them.

    Raises:
        SystemExit: With exit code 1 if any test raises an exception.
    """
    import traceback

    # Materialize once so the summary line can iterate the modes again.
    quant_modes = ["int4_kgroup"] if quant_modes is None else list(quant_modes)

    print("\n" + "=" * 70)
    print(" MOE SFT AMX Test Suite - Non-TP Version (Single NUMA Node)")
    print("=" * 70)
    print("Configuration:")
    print(f" expert_num: {expert_num}")
    print(f" hidden_size: {hidden_size}")
    print(f" intermediate_size: {intermediate_size}")
    print(f" num_experts_per_tok: {num_experts_per_tok}")
    print(f" lora_rank: {lora_rank}")
    print(f" lora_alpha: {lora_alpha}")
    print(f" qlen: {qlen}")
    print(f" num_threads: {num_threads}")
    print(" TP mode: DISABLED (single NUMA node)")
    print("=" * 70)

    try:
        for quant_mode in quant_modes:
            print(f"\n{'='*70}")
            print(f" Testing MOE SFT AMX - {quant_mode.upper()} Mode (NO TP)")
            print(f"{'='*70}")

            # Forward pass test
            test_moe_sft_forward_no_tp(quant_mode)

            # Backward pass test
            test_moe_sft_backward_no_tp(quant_mode)

            # Weight sync test
            test_moe_sft_lora_weight_sync_no_tp(quant_mode)

            # Full training loop test
            test_moe_sft_training_loop_no_tp(quant_mode)

        print("\n" + "=" * 70)
        print(" ALL TESTS PASSED!")
        print(f" Tested quantization modes: {', '.join(m.upper() for m in quant_modes)}")
        print("=" * 70)

    except Exception as e:
        # Top-level boundary: report the failure (assertion errors included)
        # and signal it through the exit status.
        print(f"\n[FAILED] Test failed with error: {e}")
        traceback.print_exc()
        sys.exit(1)
|
|
|
|
|
|
# Script entry point: run the test suite only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    run_all_tests()
|