[fix](kt-sft-refactor): fix AMXInt4_KGroup mode for SFT_MOE

This commit is contained in:
JimmyPeilinLi 2026-01-06 06:40:18 +00:00
parent 84935a22a6
commit 244b82eaa4
6 changed files with 722 additions and 34 deletions

View file

@ -2965,3 +2965,391 @@ CPUInfer.sync()
---
## Bug #23: INT4_1KGROUP (AWQ/K2) SFT MOE crash [Fixed]
### Problem
Running the SFT MOE tests in INT4_1KGROUP (AWQ) and INT4_KGROUP (K2) mode crashes the program:
```
SIGFPE, Arithmetic exception
amx_buffers.hpp:233: k % k_group_size (k_group_size=0)
```
### Call stack
```
BufferAWithSumKGroupImpl::BufferAWithSumKGroupImpl(max_m=25600, k=7168, k_group_size=0)
← AMX_AWQ_MOE_TP::make_buffer_a_impl()
← AMX_MOE_BASE::make_buffer_a()
← AMX_MOE_BASE::init()
← AMX_AWQ_MOE_TP::AMX_AWQ_MOE_TP()
← AMX_SFT_MOE_TP::AMX_SFT_MOE_TP()
```
### Root cause
1. The `QuantConfig` struct defaults to `group_size = 0` (`operators/common.hpp:225`).
2. The AWQ/K2 modes require `group_size > 0` (128 is the standard value).
3. The Python tests never set `quant_config.group_size` when creating the `MOESFTConfig`.
4. The sanity check in the AWQ constructor (`awq-moe.hpp:399-401`) only runs after `AMX_MOE_BASE::init()`.
5. But `init()` already uses `group_size` when it calls `make_buffer_a()`, so the division by zero fires before the check.
### Fix
In the Python test files, set `quant_config.group_size = 128` and `quant_config.zero_point = True` for the AWQ/K2 modes:
```python
config = kt_kernel_ext.moe.MOESFTConfig(...)
# ... other config fields ...
config.pool = CPUInfer.backend_
# Bug #23 fix: Set quant_config for AWQ/K2 modes
if quant_mode in ("int4_1kgroup", "int4_kgroup"):
    config.quant_config.group_size = 128
    config.quant_config.zero_point = True
# Create MOE SFT instance
MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
moe = MOE_SFT_CLASS(config)
```
### Files changed
| File | Location | Status |
|------|----------|--------|
| `examples/test_moe_sft_amx_no_tp.py` | 4 places after config creation | ✓ Fixed |
| `examples/test_moe_sft_amx.py` | 4 places after config creation | ✓ Fixed |
### Lessons learned
1. **Config defaults**: the `group_size = 0` default in `QuantConfig` is unsafe for the AWQ/K2 modes.
2. **Construction order**: the CRTP base class runs `init()` before any derived-class check, so the derived constructor cannot validate early enough.
3. **Test configuration**: when adding a new quantization mode, make sure the test config sets every required parameter (see the sketch below).
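A test-side guard would have turned this into a readable Python error instead of a SIGFPE. A minimal sketch, assuming the `config.quant_config` attributes used above; `validate_quant_config` is a hypothetical helper, not part of the kt_kernel API:
```python
# Hypothetical test-side guard (not part of kt_kernel): raise a clear error in
# Python instead of hitting the SIGFPE inside BufferAWithSumKGroupImpl.
GROUP_QUANT_MODES = ("int4_1kgroup", "int4_kgroup")

def validate_quant_config(quant_mode: str, quant_config) -> None:
    if quant_mode in GROUP_QUANT_MODES and quant_config.group_size <= 0:
        raise ValueError(
            f"{quant_mode} requires quant_config.group_size > 0, "
            f"got {quant_config.group_size}"
        )

# Usage, right before constructing the MOE SFT instance:
# validate_quant_config(quant_mode, config.quant_config)
```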
---
## Bug #24: INT4_1KGROUP Training Loop SIGSEGV crash [Fixed]
### Problem
After the Bug #23 fix, the INT4_1KGROUP Forward/Backward/Sync tests all pass, but the Training Loop test crashes:
```
SIGSEGV, Segmentation fault
awq-moe.hpp:554: convert_or_copy() with garbage pointer
```
### Call stack
```
ggml_fp16_to_fp32_row(x=0x880e881608fb2308, ...) ← garbage pointer!
← convert_or_copy(gate_bb_[expert_idx]->d, (ggml_fp16_t*)config_.gate_scale + offset, ...)
← AMX_AWQ_MOE_TP::load_weights() [awq-moe.hpp:554]
← AMX_SFT_MOE_TP::load_weights_without_lora()
```
### Root cause
1. The `void*` pointers in the `GeneralMOEConfig` struct are never initialized to `nullptr`.
2. They live in `operators/common.hpp:243-253`:
```cpp
void* gate_proj;   // uninitialized!
void* gate_scale;  // uninitialized!  ← causes the crash
// ... the other void* pointers
```
3. In `load_weights()` in `awq-moe.hpp:507-586`:
```cpp
else if (config_.gate_scale != nullptr) {  // garbage value is treated as true!
  // pre-quantized weight path - taken by mistake
  convert_or_copy(..., (ggml_fp16_t*)config_.gate_scale + offset, ...);  // CRASH!
} else {
  // online quantization from BF16 - the path SFT should actually take!
}
```
### Why didn't BF16/INT8 crash?
**Key difference**:
- **BF16/INT8** (`moe.hpp:219`): checks a `std::vector`:
```cpp
if (config_.gate_projs.size()) {  // a std::vector default-constructs to empty!
```
- **AWQ/K2** (`awq-moe.hpp:507`): checks a raw pointer:
```cpp
else if (config_.gate_scale != nullptr)  // uninitialized = garbage value!
```
A `std::vector` member is correctly default-constructed to empty; a raw `void*` member is never initialized automatically.
### Why did the Forward/Backward/Sync tests pass?
Most likely an accident of memory allocation and layout:
- In the Forward/Backward/Sync tests the memory happened to be zeroed.
- The Training Loop test has a different memory layout that contains garbage values.
### Fix
Give every `void*` pointer in `operators/common.hpp` a default value of `nullptr`:
```cpp
// Before:
void* gate_proj;
void* gate_scale;
// ...
// After:
void* gate_proj = nullptr;
void* gate_scale = nullptr;
// ...
```
### Files changed
| File | Location | Status |
|------|----------|--------|
| `operators/common.hpp` | lines 243-253 | ✓ Fixed |
### Lessons learned
1. **Pointer initialization**: raw pointers in C++ structs should always be explicitly initialized to `nullptr`.
2. **Undefined behavior**: uninitialized pointers cause hard-to-reproduce bugs that depend on the memory state.
3. **std::vector vs void\***: a `std::vector` member is initialized automatically; a `void*` member is not.
---
## Bug #25: INT4_KGROUP test configures zero_point incorrectly
**Date**: 2026-01-05
**Status**: ✅ Fixed
### Problem
The SFT test for the INT4_KGROUP mode crashes in the constructor:
```
terminate called after throwing an instance of 'std::runtime_error'
what(): Kimi-K2 MoE only support KGroup Int4
Aborted (core dumped)
```
### Root cause
1. The validation logic in K2 MOE (`k2-moe.hpp:51-52`):
```cpp
if (quant_config.group_size == 0 || quant_config.zero_point) {
  throw std::runtime_error("Kimi-K2 MoE only support KGroup Int4");
}
```
2. The test file wrongly sets `zero_point = True` for the K2 mode as well:
```python
if quant_mode in ("int4_1kgroup", "int4_kgroup"):
    config.quant_config.group_size = 128
    config.quant_config.zero_point = True  # K2 does not support this!
```
3. **Technical difference between AWQ and K2** (see the sketch after this list):
   - AWQ (`int4_1kgroup`): uses scales + zero_points
   - K2 (`int4_kgroup`): uses scales only; zero_points are not supported
4. Evidence: `k2-moe.hpp:175-180` loads only `gate_scale`, `up_scale`, and `down_scale`; there is no zero_point handling anywhere.
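To make the difference concrete, here is a minimal sketch of the two dequantization rules in PyTorch. It is illustrative only (the function names are made up and the `scale`/`zero` tensors are assumed to broadcast per group); it is not the kernel implementation:
```python
import torch

def dequant_awq(q: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor) -> torch.Tensor:
    # AWQ (int4_1kgroup): asymmetric, unsigned int4 in [0, 15], per-group zero_point
    return (q.to(torch.float32) - zero.to(torch.float32)) * scale.to(torch.float32)

def dequant_k2(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # K2 (int4_kgroup): symmetric, signed int4 in [-8, 7], scale only
    return q.to(torch.float32) * scale.to(torch.float32)
```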
### Fix
Configure `zero_point` separately for AWQ and K2:
```python
# Before:
if quant_mode in ("int4_1kgroup", "int4_kgroup"):
    config.quant_config.group_size = 128
    config.quant_config.zero_point = True
# After:
if quant_mode == "int4_1kgroup":  # AWQ supports zero_point
    config.quant_config.group_size = 128
    config.quant_config.zero_point = True
elif quant_mode == "int4_kgroup":  # K2 does NOT support zero_point
    config.quant_config.group_size = 128
    config.quant_config.zero_point = False
```
### Files changed
| File | Location | Status |
|------|----------|--------|
| `examples/test_moe_sft_amx_no_tp.py` | 4 quant_config blocks | ✓ Fixed |
### Lessons learned
1. **Quantization schemes differ**: different schemes (AWQ vs K2) have different configuration requirements.
2. **Keep configs separate**: do not merge the configuration of different modes into one conditional branch.
3. **Error messages help**: K2's message "Kimi-K2 MoE only support KGroup Int4" pointed straight at the problem.
---
## Bug #26: K2 MOE SFT tests require pre-quantized weights
**Date**: 2026-01-06
**Status**: ✅ Fixed
### Problem
After the Bug #25 fix, the INT4_KGROUP SFT test still crashes in `load_weights()`:
```
what(): Kimi AVX MOE only support load native weight.
```
An attempt to add an online-quantization path then failed to compile:
```
'amx::BufferBInt4KGroupImpl' has no member named 'from_mat'; did you mean 'from_raw_mat'?
```
### Root cause
Architectural differences between K2 and AWQ:
| Property | K2 (BufferBInt4KGroupImpl) | AWQ (BufferBInt4WithZeroKGroupImpl) |
|----------|----------------------------|-------------------------------------|
| Buffer contents | weights + scales | weights + scales + zero_points |
| Supported methods | `from_raw_mat()` only | `from_raw_mat()` + `from_mat()` |
| Quantization | signed int4 (symmetric) | unsigned int4 + zero_point |
| Online quantization | ❌ not supported | ✅ supported |
**Conclusion**: by design, K2 MOE only supports offline pre-quantized weights; there is no online-quantization path.
### Fix
Have the SFT tests feed the K2 mode **pre-quantized int4 weights + scales** instead of BF16 weights.
```python
# K2 quantization helper added to the test
def quantize_k2_tensor(weights: torch.Tensor, group_size: int):
    """K2 symmetric quantization: BF16 → signed int4 (range -8 to 7)."""
    e, rows, cols = weights.shape
    reshaped = weights.view(e, rows, cols // group_size, group_size)
    max_abs = reshaped.abs().amax(dim=-1, keepdim=True)
    scales = (max_abs / 7.0).squeeze(-1)
    q = torch.round(reshaped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8)
    packed = pack_tensor_per_row(q, num_bits=4)
    return packed, scales.to(torch.bfloat16)

# Test configuration change
if quant_mode == "int4_kgroup":
    k2_weights = init_base_weights_for_k2(expert_num, hidden_size, intermediate_size)
    config.gate_proj = k2_weights["gate_qweight"].data_ptr()
    config.gate_scale = k2_weights["gate_scales"].data_ptr()  # the key part!
```
### Key changes
1. Revert the online-path attempt in `k2-moe.hpp`.
2. Add `quantize_k2_tensor()`, `pack_to_int32()`, and `pack_tensor_per_row()` to the test (see the packing sketch after this list).
3. Add an `init_base_weights_for_k2()` initialization helper.
4. Update the configuration logic in the 4 test functions.
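For intuition about the packed layout, the sketch below packs eight int4 values into one int32 the same way `pack_to_int32()` does: offset the signed values into unsigned nibbles, then place value i at bits 4*i..4*i+3. It is self-contained and re-implements the packing inline rather than importing the helper:
```python
import torch

vals = torch.tensor([7, 4, 3, 2, 1, 0, -1, -8], dtype=torch.int8)  # eight int4 values
offset = 1 << 3                                   # map the signed range [-8, 7] to [0, 15]
nibbles = (vals + offset).to(torch.int32)
shifts = torch.arange(8, dtype=torch.int32) * 4   # value i goes to bits 4*i .. 4*i+3
packed = int((nibbles << shifts).sum())

# Each nibble can be read back with a shift and a mask.
assert [(packed >> (4 * i)) & 0xF for i in range(8)] == (vals + offset).tolist()
```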
### K2 quantization format
| Property | K2 | AWQ |
|----------|----|-----|
| Type | symmetric | asymmetric |
| Range | -8 to 7 (signed) | 0 to 15 (unsigned) |
| Parameters | scale only | scale + zero_point |
| Formula | q = round(w / scale) | q = round(w / scale) + zero |
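The formula row can be checked with a quick round trip in PyTorch. This is a standalone sketch with made-up shapes, not the actual test tensors; with `scale = max_abs / 7`, the per-element reconstruction error is bounded by `scale / 2`:
```python
import torch

w = torch.randn(2, 4, 256, dtype=torch.bfloat16)    # [experts, rows, cols], arbitrary sizes
group_size = 128
g = w.to(torch.float32).view(2, 4, 256 // group_size, group_size)
scales = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0
q = torch.round(g / scales).clamp(-8, 7)             # K2: q = round(w / scale), signed int4
recon = q * scales                                   # dequantize: w ≈ q * scale
assert (recon - g).abs().max() <= scales.max() / 2 + 1e-6
```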
### Files changed
| File | Change | Status |
|------|--------|--------|
| `operators/amx/k2-moe.hpp` | revert the online path, keep the original design | ✓ Fixed |
| `examples/test_moe_sft_amx_no_tp.py` | add K2 quantization helpers and pre-quantized weights | ✓ Fixed |
### Lessons learned
1. **Understand the architecture**: K2 and AWQ use different buffer types; code cannot simply be copied between them.
2. **Respect the design**: K2 is designed to take offline pre-quantized weights only, and the tests have to follow that design.
3. **Verify before compiling**: check interface compatibility before modifying the C++ code.
---
## Bug #27: K2 MOE SFT load_weights selects the wrong path
**Date**: 2026-01-06
**Status**: ✅ Fixed
### Problem
After the Bug #26 fix, the INT4_KGROUP No-TP test crashes in `load_weights()`:
```
Thread "numa_0_t_50" received signal SIGSEGV, Segmentation fault.
__memmove_avx512_unaligned_erms ()
#1 TP_MOE_SFT::load_weights()::{lambda}::operator()(int) at moe-sft-tp.hpp:103
```
The log shows that the wrong path is selected:
```
TP_MOE_SFT: From BF16 with partitioning ← wrong! It should take the K2 pre-quantized path
```
### Root cause
**The branch condition in `moe-sft-tp.hpp:77`**:
```cpp
if (config.gate_proj != nullptr) {
  // assumes gate_proj points at BF16 data
  memcpy(..., (ggml_bf16_t*)config.gate_proj + ..., sizeof(ggml_bf16_t) * ...);
}
```
The test sets `config.gate_proj = k2_weights["gate_qweight"].data_ptr()` (int4-packed data), but the C++ code sees `gate_proj != nullptr`, assumes it is BF16, and does the memcpy with BF16 offsets, which triggers the SIGSEGV.
### Fix
Add K2 pre-quantized mode detection to `moe-sft-tp.hpp::load_weights()`:
```cpp
// K2 pre-quantized mode: gate_scale != nullptr && !zero_point
bool is_k2_prequantized = (config.gate_scale != nullptr && !config.quant_config.zero_point);
if (is_k2_prequantized) {
  printf("TP_MOE_SFT: K2 pre-quantized mode (no BF16 partitioning)\n");
  if (tp_count == 1) {
    // No-TP: call load_weights directly; tp_configs[i] already holds all the pointers
    pool->dispense_backend()->do_numa_job([this](int numa_id) {
      tps[numa_id]->load_weights();
    });
  } else {
    throw std::runtime_error("K2 pre-quantized mode does not support TP > 1 yet");
  }
} else if (config.gate_proj != nullptr) {
  // BF16 partitioning path...
}
```
### Detection conditions
| Mode | gate_scale | zero_point | Result |
|------|-----------|------------|--------|
| K2 | != nullptr | false | is_k2_prequantized = true |
| AWQ | != nullptr | true | is_k2_prequantized = false |
| BF16 | nullptr | - | falls through to the gate_proj check |
### Files changed
| File | Change | Status |
|------|--------|--------|
| `operators/moe-sft-tp.hpp` | add the K2 pre-quantized branch | ✓ Fixed |
### Lessons learned
1. **Distinguish data formats**: the same pointer (gate_proj) can point at data in different formats (BF16 vs packed int4).
2. **Detect explicitly**: combine multiple conditions (gate_scale + zero_point) to tell the quantization modes apart.
3. **Support incrementally**: implement the No-TP case first; K2 support for TP > 1 can be added later.
---

View file

@ -675,6 +675,11 @@ def test_moe_sft_forward(quant_mode: str = "bf16"):
config.down_lora_b = down_lora_b.data_ptr()
config.pool = CPUInfer.backend_
# Bug #23 fix: Set quant_config for AWQ/K2 modes
if quant_mode in ("int4_1kgroup", "int4_kgroup"):
config.quant_config.group_size = 128
config.quant_config.zero_point = True
# Create MOE SFT instance based on quant_mode
MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
moe = MOE_SFT_CLASS(config)
@ -821,6 +826,11 @@ def test_moe_sft_backward(quant_mode: str = "bf16"):
config.down_lora_b = down_lora_b.data_ptr()
config.pool = CPUInfer.backend_
# Bug #23 fix: Set quant_config for AWQ/K2 modes
if quant_mode in ("int4_1kgroup", "int4_kgroup"):
config.quant_config.group_size = 128
config.quant_config.zero_point = True
# Create MOE SFT instance based on quant_mode
MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
moe = MOE_SFT_CLASS(config)
@ -1031,6 +1041,11 @@ def test_moe_sft_lora_weight_sync(quant_mode: str = "bf16"):
config.down_lora_b = down_lora_b.data_ptr()
config.pool = CPUInfer.backend_
# Bug #23 fix: Set quant_config for AWQ/K2 modes
if quant_mode in ("int4_1kgroup", "int4_kgroup"):
config.quant_config.group_size = 128
config.quant_config.zero_point = True
# Create MOE SFT instance based on quant_mode
MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
moe = MOE_SFT_CLASS(config)
@ -1268,6 +1283,11 @@ def test_moe_sft_training_loop(quant_mode: str = "bf16"):
config.down_lora_b = down_lora_b_param.data.data_ptr()
config.pool = CPUInfer.backend_
# Bug #23 fix: Set quant_config for AWQ/K2 modes
if quant_mode in ("int4_1kgroup", "int4_kgroup"):
config.quant_config.group_size = 128
config.quant_config.zero_point = True
# Create MOE SFT instance based on quant_mode
MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
moe = MOE_SFT_CLASS(config)
@ -1458,8 +1478,8 @@ def run_all_tests():
print("=" * 70)
# Quantization modes to test
quant_modes = ["bf16", "int8", "int4", "int4_1", "int4_1kgroup", "int4_kgroup"]
# quant_modes = ["int4_1kgroup", "int4_kgroup"]
# quant_modes = ["bf16", "int8", "int4", "int4_1"]
quant_modes = ["int4_1kgroup", "int4_kgroup"]
try:
for quant_mode in quant_modes:

View file

@ -13,6 +13,8 @@ Key difference from test_moe_sft_amx.py:
import os
import sys
import math
from typing import Literal, Dict
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
print("sys.path:", sys.path)
@ -115,6 +117,141 @@ def get_threshold(quant_mode: str, is_backward: bool = False) -> float:
return BF16_FORWARD_THRESHOLD # 0.05
# =============================================================================
# K2 Quantization Utilities (for INT4_KGROUP mode)
# =============================================================================
def pack_to_int32(value: torch.Tensor, num_bits: int, packed_dim: Literal[0, 1] = 1) -> torch.Tensor:
"""Pack int4 values into int32 tensor.
Args:
value: int8 tensor to pack
num_bits: number of bits per value (4 for int4)
packed_dim: dimension to pack along
Returns:
int32 tensor with packed values
"""
if value.dtype is not torch.int8:
raise ValueError("Tensor must be torch.int8 before packing")
if not (1 <= num_bits <= 8):
raise ValueError(f"num_bits must be in [1, 8], got {num_bits}")
offset = 1 << (num_bits - 1)
value = (value + offset).to(torch.uint8)
device = value.device
pack_factor = 32 // num_bits
if packed_dim == 0:
value = value.transpose(0, 1)
rows, cols = value.shape
padded_cols = math.ceil(cols / pack_factor) * pack_factor
pad_len = padded_cols - cols
if pad_len > 0:
value = torch.nn.functional.pad(value, (0, pad_len))
num_groups = padded_cols // pack_factor
# Use int32 here
reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32)
bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits
packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32)
if packed_dim == 0:
packed = packed.transpose(0, 1)
return packed
def pack_tensor_per_row(q: torch.Tensor, num_bits: int) -> torch.Tensor:
"""Pack tensor per row for K2 quantization.
Args:
q: [expert_num, rows, cols] int8 tensor
num_bits: number of bits per value
Returns:
Packed int32 tensor
"""
e, rows, cols = q.shape
flat = q.view(e * rows, cols)
packed = pack_to_int32(flat, num_bits)
return packed.view(e, rows, -1).contiguous()
def quantize_k2_tensor(weights: torch.Tensor, group_size: int):
"""
K2 symmetric max-abs/7 quantization per k-group.
Args:
weights: [expert_num, rows (N), cols (K)] bfloat16 tensor
Returns:
packed_q: int32 tensor storing 8 int4s per element with shape [expert_num, rows * (cols // 8)]
scales: bfloat16 tensor with shape [expert_num, rows * (cols // group_size)]
"""
weights_f32 = weights.to(torch.float32)
e, rows, cols = weights_f32.shape
if cols % group_size != 0 or cols % 2 != 0:
raise ValueError(f"cols ({cols}) must be divisible by group_size ({group_size}) and 2")
reshaped = weights_f32.view(e, rows, cols // group_size, group_size)
max_abs = reshaped.abs().amax(dim=-1, keepdim=True)
max_abs = torch.clamp(max_abs, min=1e-8)
scales = (max_abs / 7.0).squeeze(-1)
q = torch.round(reshaped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8)
q = q.view(e, rows, cols)
packed = pack_tensor_per_row(q, num_bits=4).view(e, rows, cols // 8).contiguous()
scales = scales.to(torch.bfloat16).contiguous().view(e, rows, cols // group_size).contiguous()
return packed, scales
def init_base_weights_for_k2(
expert_num: int, hidden_size: int, intermediate_size: int, group_size: int = 128
) -> Dict[str, torch.Tensor]:
"""Initialize pre-quantized K2 weights for INT4_KGROUP mode.
Args:
expert_num: number of experts
hidden_size: hidden dimension
intermediate_size: intermediate dimension
group_size: quantization group size
Returns:
Dictionary containing:
- gate_qweight, up_qweight, down_qweight: packed int4 weights
- gate_scales, up_scales, down_scales: bf16 scales
- gate_proj_bf16, up_proj_bf16, down_proj_bf16: original bf16 weights for reference
"""
# Create random BF16 weights
gate_proj_bf16 = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16)
up_proj_bf16 = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16)
down_proj_bf16 = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.bfloat16)
# Quantize to int4
gate_qweight, gate_scales = quantize_k2_tensor(gate_proj_bf16, group_size)
up_qweight, up_scales = quantize_k2_tensor(up_proj_bf16, group_size)
down_qweight, down_scales = quantize_k2_tensor(down_proj_bf16, group_size)
return {
"gate_qweight": gate_qweight.contiguous(),
"up_qweight": up_qweight.contiguous(),
"down_qweight": down_qweight.contiguous(),
"gate_scales": gate_scales.contiguous(),
"up_scales": up_scales.contiguous(),
"down_scales": down_scales.contiguous(),
# Keep original BF16 for gradient verification
"gate_proj_bf16": gate_proj_bf16.contiguous(),
"up_proj_bf16": up_proj_bf16.contiguous(),
"down_proj_bf16": down_proj_bf16.contiguous(),
}
# =============================================================================
# Activation Functions
# =============================================================================
@ -620,8 +757,19 @@ def test_moe_sft_forward_no_tp(quant_mode: str = "bf16"):
# Set random seed for reproducibility
torch.manual_seed(42)
# Initialize weights
gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)
# Initialize weights based on quant_mode
k2_weights = None # Will be set for K2 mode
if quant_mode == "int4_kgroup":
# K2 needs pre-quantized int4 weights
k2_weights = init_base_weights_for_k2(expert_num, hidden_size, intermediate_size, group_size=128)
# Use original BF16 for reference computation
gate_proj = k2_weights["gate_proj_bf16"]
up_proj = k2_weights["up_proj_bf16"]
down_proj = k2_weights["down_proj_bf16"]
else:
# Other modes use BF16 weights
gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)
lora_weights = init_lora_weights(expert_num, hidden_size, intermediate_size, lora_rank)
gate_lora_a, gate_lora_b, up_lora_a, up_lora_b, down_lora_a, down_lora_b = lora_weights
@ -655,9 +803,20 @@ def test_moe_sft_forward_no_tp(quant_mode: str = "bf16"):
config.max_cache_depth = 1
config.max_len = max_len
config.layer_idx = 0
config.gate_proj = gate_proj.data_ptr()
config.up_proj = up_proj.data_ptr()
config.down_proj = down_proj.data_ptr()
# Bug #26 fix: K2 uses pre-quantized weights with scales
if quant_mode == "int4_kgroup" and k2_weights is not None:
config.gate_proj = k2_weights["gate_qweight"].data_ptr()
config.up_proj = k2_weights["up_qweight"].data_ptr()
config.down_proj = k2_weights["down_qweight"].data_ptr()
config.gate_scale = k2_weights["gate_scales"].data_ptr()
config.up_scale = k2_weights["up_scales"].data_ptr()
config.down_scale = k2_weights["down_scales"].data_ptr()
else:
config.gate_proj = gate_proj.data_ptr()
config.up_proj = up_proj.data_ptr()
config.down_proj = down_proj.data_ptr()
# Set LoRA weight pointers directly in config (zero-copy)
config.gate_lora_a = gate_lora_a.data_ptr()
config.gate_lora_b = gate_lora_b.data_ptr()
@ -667,6 +826,15 @@ def test_moe_sft_forward_no_tp(quant_mode: str = "bf16"):
config.down_lora_b = down_lora_b.data_ptr()
config.pool = CPUInfer.backend_
# Bug #23 fix: Set quant_config for AWQ/K2 modes
# Bug #25 fix: AWQ (int4_1kgroup) uses zero_point, K2 (int4_kgroup) does NOT
if quant_mode == "int4_1kgroup": # AWQ supports zero_point
config.quant_config.group_size = 128
config.quant_config.zero_point = True
elif quant_mode == "int4_kgroup": # K2 does NOT support zero_point
config.quant_config.group_size = 128
config.quant_config.zero_point = False
# Create MOE SFT instance based on quant_mode
MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
moe = MOE_SFT_CLASS(config)
@ -772,8 +940,19 @@ def test_moe_sft_backward_no_tp(quant_mode: str = "bf16"):
# Set random seed for reproducibility
torch.manual_seed(42)
# Initialize weights
gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)
# Initialize weights based on quant_mode
k2_weights = None # Will be set for K2 mode
if quant_mode == "int4_kgroup":
# K2 needs pre-quantized int4 weights
k2_weights = init_base_weights_for_k2(expert_num, hidden_size, intermediate_size, group_size=128)
# Use original BF16 for reference computation
gate_proj = k2_weights["gate_proj_bf16"]
up_proj = k2_weights["up_proj_bf16"]
down_proj = k2_weights["down_proj_bf16"]
else:
# Other modes use BF16 weights
gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)
lora_weights = init_lora_weights(expert_num, hidden_size, intermediate_size, lora_rank)
gate_lora_a, gate_lora_b, up_lora_a, up_lora_b, down_lora_a, down_lora_b = lora_weights
@ -806,9 +985,20 @@ def test_moe_sft_backward_no_tp(quant_mode: str = "bf16"):
config.max_cache_depth = validation_iter # Need cache for backward
config.max_len = max_len
config.layer_idx = 0
config.gate_proj = gate_proj.data_ptr()
config.up_proj = up_proj.data_ptr()
config.down_proj = down_proj.data_ptr()
# Bug #26 fix: K2 uses pre-quantized weights with scales
if quant_mode == "int4_kgroup" and k2_weights is not None:
config.gate_proj = k2_weights["gate_qweight"].data_ptr()
config.up_proj = k2_weights["up_qweight"].data_ptr()
config.down_proj = k2_weights["down_qweight"].data_ptr()
config.gate_scale = k2_weights["gate_scales"].data_ptr()
config.up_scale = k2_weights["up_scales"].data_ptr()
config.down_scale = k2_weights["down_scales"].data_ptr()
else:
config.gate_proj = gate_proj.data_ptr()
config.up_proj = up_proj.data_ptr()
config.down_proj = down_proj.data_ptr()
config.gate_lora_a = gate_lora_a.data_ptr()
config.gate_lora_b = gate_lora_b.data_ptr()
config.up_lora_a = up_lora_a.data_ptr()
@ -817,6 +1007,15 @@ def test_moe_sft_backward_no_tp(quant_mode: str = "bf16"):
config.down_lora_b = down_lora_b.data_ptr()
config.pool = CPUInfer.backend_
# Bug #23 fix: Set quant_config for AWQ/K2 modes
# Bug #25 fix: AWQ (int4_1kgroup) uses zero_point, K2 (int4_kgroup) does NOT
if quant_mode == "int4_1kgroup": # AWQ supports zero_point
config.quant_config.group_size = 128
config.quant_config.zero_point = True
elif quant_mode == "int4_kgroup": # K2 does NOT support zero_point
config.quant_config.group_size = 128
config.quant_config.zero_point = False
# Create MOE SFT instance based on quant_mode
MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
moe = MOE_SFT_CLASS(config)
@ -997,8 +1196,19 @@ def test_moe_sft_lora_weight_sync_no_tp(quant_mode: str = "bf16"):
torch.manual_seed(42)
# Initialize weights
gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)
# Initialize weights based on quant_mode
k2_weights = None # Will be set for K2 mode
if quant_mode == "int4_kgroup":
# K2 needs pre-quantized int4 weights
k2_weights = init_base_weights_for_k2(expert_num, hidden_size, intermediate_size, group_size=128)
# Use original BF16 for reference computation
gate_proj = k2_weights["gate_proj_bf16"]
up_proj = k2_weights["up_proj_bf16"]
down_proj = k2_weights["down_proj_bf16"]
else:
# Other modes use BF16 weights
gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)
lora_weights = init_lora_weights(expert_num, hidden_size, intermediate_size, lora_rank)
gate_lora_a, gate_lora_b, up_lora_a, up_lora_b, down_lora_a, down_lora_b = lora_weights
@ -1020,9 +1230,20 @@ def test_moe_sft_lora_weight_sync_no_tp(quant_mode: str = "bf16"):
config.max_cache_depth = 1
config.max_len = max_len
config.layer_idx = 0
config.gate_proj = gate_proj.data_ptr()
config.up_proj = up_proj.data_ptr()
config.down_proj = down_proj.data_ptr()
# Bug #26 fix: K2 uses pre-quantized weights with scales
if quant_mode == "int4_kgroup" and k2_weights is not None:
config.gate_proj = k2_weights["gate_qweight"].data_ptr()
config.up_proj = k2_weights["up_qweight"].data_ptr()
config.down_proj = k2_weights["down_qweight"].data_ptr()
config.gate_scale = k2_weights["gate_scales"].data_ptr()
config.up_scale = k2_weights["up_scales"].data_ptr()
config.down_scale = k2_weights["down_scales"].data_ptr()
else:
config.gate_proj = gate_proj.data_ptr()
config.up_proj = up_proj.data_ptr()
config.down_proj = down_proj.data_ptr()
config.gate_lora_a = gate_lora_a.data_ptr()
config.gate_lora_b = gate_lora_b.data_ptr()
config.up_lora_a = up_lora_a.data_ptr()
@ -1031,6 +1252,15 @@ def test_moe_sft_lora_weight_sync_no_tp(quant_mode: str = "bf16"):
config.down_lora_b = down_lora_b.data_ptr()
config.pool = CPUInfer.backend_
# Bug #23 fix: Set quant_config for AWQ/K2 modes
# Bug #25 fix: AWQ (int4_1kgroup) uses zero_point, K2 (int4_kgroup) does NOT
if quant_mode == "int4_1kgroup": # AWQ supports zero_point
config.quant_config.group_size = 128
config.quant_config.zero_point = True
elif quant_mode == "int4_kgroup": # K2 does NOT support zero_point
config.quant_config.group_size = 128
config.quant_config.zero_point = False
# Create MOE SFT instance based on quant_mode
MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
moe = MOE_SFT_CLASS(config)
@ -1193,8 +1423,18 @@ def test_moe_sft_training_loop_no_tp(quant_mode: str = "bf16"):
torch.manual_seed(42)
# Initialize base weights
gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)
# Initialize base weights based on quant_mode
k2_weights = None # Will be set for K2 mode
if quant_mode == "int4_kgroup":
# K2 needs pre-quantized int4 weights
k2_weights = init_base_weights_for_k2(expert_num, hidden_size, intermediate_size, group_size=128)
# Use original BF16 for reference computation
gate_proj = k2_weights["gate_proj_bf16"]
up_proj = k2_weights["up_proj_bf16"]
down_proj = k2_weights["down_proj_bf16"]
else:
# Other modes use BF16 weights
gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)
# Initialize LoRA weights as contiguous tensors
gate_lora_a = (
@ -1261,9 +1501,20 @@ def test_moe_sft_training_loop_no_tp(quant_mode: str = "bf16"):
config.max_cache_depth = 1 # One forward-backward pair at a time
config.max_len = max_len
config.layer_idx = 0
config.gate_proj = gate_proj.data_ptr()
config.up_proj = up_proj.data_ptr()
config.down_proj = down_proj.data_ptr()
# Bug #26 fix: K2 uses pre-quantized weights with scales
if quant_mode == "int4_kgroup" and k2_weights is not None:
config.gate_proj = k2_weights["gate_qweight"].data_ptr()
config.up_proj = k2_weights["up_qweight"].data_ptr()
config.down_proj = k2_weights["down_qweight"].data_ptr()
config.gate_scale = k2_weights["gate_scales"].data_ptr()
config.up_scale = k2_weights["up_scales"].data_ptr()
config.down_scale = k2_weights["down_scales"].data_ptr()
else:
config.gate_proj = gate_proj.data_ptr()
config.up_proj = up_proj.data_ptr()
config.down_proj = down_proj.data_ptr()
config.gate_lora_a = gate_lora_a_param.data.data_ptr()
config.gate_lora_b = gate_lora_b_param.data.data_ptr()
config.up_lora_a = up_lora_a_param.data.data_ptr()
@ -1272,6 +1523,15 @@ def test_moe_sft_training_loop_no_tp(quant_mode: str = "bf16"):
config.down_lora_b = down_lora_b_param.data.data_ptr()
config.pool = CPUInfer.backend_
# Bug #23 fix: Set quant_config for AWQ/K2 modes
# Bug #25 fix: AWQ (int4_1kgroup) uses zero_point, K2 (int4_kgroup) does NOT
if quant_mode == "int4_1kgroup": # AWQ supports zero_point
config.quant_config.group_size = 128
config.quant_config.zero_point = True
elif quant_mode == "int4_kgroup": # K2 does NOT support zero_point
config.quant_config.group_size = 128
config.quant_config.zero_point = False
# Create MOE SFT instance based on quant_mode
MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
moe = MOE_SFT_CLASS(config)
@ -1463,8 +1723,9 @@ def run_all_tests():
print("=" * 70)
# Quantization modes to test
# quant_modes = ["bf16", "int8", "int4", "int4_1", "int4_1kgroup", "int4_kgroup"]
quant_modes = ["int4_1kgroup", "int4_kgroup"]
# quant_modes = ["bf16", "int8", "int4", "int4_1"]
# quant_modes = ["int4_1kgroup", "int4_kgroup"]
quant_modes = ["int4_kgroup"]
try:
for quant_mode in quant_modes:

View file

@ -115,6 +115,9 @@ class AMX_K2_MOE_TP : public AMX_MOE_BASE<T, AMX_K2_MOE_TP<T>> {
*
* Loads weights from config_.gate_proj, up_proj, down_proj with scales
* from config_.gate_scale, up_scale, down_scale.
*
* Note: K2 MOE only supports offline pre-quantized weights (gate_scale must be set).
* For online quantization, use AWQ MOE instead.
*/
void load_weights() {
auto& quant_config = config_.quant_config;

View file

@ -240,17 +240,17 @@ struct GeneralMOEConfig {
int num_gpu_experts = 0;
void* physical_to_logical_map = nullptr;
void* gate_proj;
void* up_proj;
void* down_proj;
void* gate_proj = nullptr;
void* up_proj = nullptr;
void* down_proj = nullptr;
void* gate_scale;
void* up_scale;
void* down_scale;
void* gate_scale = nullptr;
void* up_scale = nullptr;
void* down_scale = nullptr;
void* gate_zero;
void* up_zero;
void* down_zero;
void* gate_zero = nullptr;
void* up_zero = nullptr;
void* down_zero = nullptr;
QuantConfig quant_config;

View file

@ -74,7 +74,23 @@ class TP_MOE_SFT : public TP_MOE<T> {
auto pool = config.pool;
const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
if (config.gate_proj != nullptr) {
// Bug #27 fix: K2 pre-quantized mode detection
// K2 uses gate_scale != nullptr and zero_point = false
// AWQ also has gate_scale but has zero_point = true
bool is_k2_prequantized = (config.gate_scale != nullptr && !config.quant_config.zero_point);
if (is_k2_prequantized) {
printf("TP_MOE_SFT: K2 pre-quantized mode (no BF16 partitioning)\n");
// For K2, weights are already int4-packed with scales
// tp_configs[i] already has all pointers from config (copied in TP_MOE constructor)
if (tp_count == 1) {
// No-TP: just call load_weights directly
pool->dispense_backend()->do_numa_job([this](int numa_id) { tps[numa_id]->load_weights(); });
} else {
// TP mode with K2 would need int4-aware partitioning (not implemented yet)
throw std::runtime_error("K2 pre-quantized mode does not support TP > 1 yet");
}
} else if (config.gate_proj != nullptr) {
printf("TP_MOE_SFT: From BF16 with partitioning\n");
// Temporary storage for partitioned weights