mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 03:39:48 +00:00

[fix](kt-sft-refactor): fix AMXInt4_KGroup mode for SFT_MOE

parent 84935a22a6, commit 244b82eaa4
6 changed files with 722 additions and 34 deletions
@@ -2965,3 +2965,391 @@ CPUInfer.sync()

---

## Bug #23: INT4_1KGROUP (AWQ/K2) SFT MOE crash [FIXED]

### Symptom

Testing SFT MOE in the INT4_1KGROUP (AWQ) and INT4_KGROUP (K2) modes crashes:

```
SIGFPE, Arithmetic exception
amx_buffers.hpp:233: k % k_group_size (k_group_size=0)
```

### Call stack

```
BufferAWithSumKGroupImpl::BufferAWithSumKGroupImpl(max_m=25600, k=7168, k_group_size=0)
← AMX_AWQ_MOE_TP::make_buffer_a_impl()
← AMX_MOE_BASE::make_buffer_a()
← AMX_MOE_BASE::init()
← AMX_AWQ_MOE_TP::AMX_AWQ_MOE_TP()
← AMX_SFT_MOE_TP::AMX_SFT_MOE_TP()
```

### Root cause

1. The `QuantConfig` struct defaults to `group_size = 0` (`operators/common.hpp:225`).
2. The AWQ/K2 modes require `group_size > 0` (128 is the standard value).
3. The Python tests created `MOESFTConfig` without setting `quant_config.group_size`.
4. The check in the AWQ constructor (`awq-moe.hpp:399-401`) only runs after `AMX_MOE_BASE::init()`.
5. But `init()` already uses `group_size` when it calls `make_buffer_a()`, so the division by zero happens first.

### Fix

In the Python test files, set `quant_config.group_size = 128` and `quant_config.zero_point = True` for the AWQ/K2 modes:

```python
config = kt_kernel_ext.moe.MOESFTConfig(...)
# ... other configuration ...
config.pool = CPUInfer.backend_

# Bug #23 fix: Set quant_config for AWQ/K2 modes
if quant_mode in ("int4_1kgroup", "int4_kgroup"):
    config.quant_config.group_size = 128
    config.quant_config.zero_point = True

# Create MOE SFT instance
MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
moe = MOE_SFT_CLASS(config)
```

### Files changed

| File | Location | Status |
|------|----------|--------|
| `examples/test_moe_sft_amx_no_tp.py` | 4 places, after config creation | ✓ fixed |
| `examples/test_moe_sft_amx.py` | 4 places, after config creation | ✓ fixed |

### Lessons

1. **Config defaults**: the `group_size = 0` default in `QuantConfig` is unsafe for the AWQ/K2 modes.
2. **Construction order**: the CRTP base class runs `init()` before any derived-class check, so the derived constructor cannot validate early enough.
3. **Test configuration**: when adding a new quantization mode, make sure the test config sets every required parameter.
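
Because the C++-side check runs too late, a defensive check on the Python side fails fast before the constructor can divide by `group_size`. A minimal sketch (`validate_group_size` is a hypothetical helper, not part of the test files):

```python
def validate_group_size(quant_mode: str, group_size: int) -> None:
    """Fail fast: the AWQ/K2 k-group modes divide by group_size inside init()."""
    if quant_mode in ("int4_1kgroup", "int4_kgroup") and group_size <= 0:
        raise ValueError(
            f"{quant_mode} requires quant_config.group_size > 0 "
            f"(128 is the standard value), got {group_size}"
        )
```

Calling this right before constructing the MOE instance turns the SIGFPE into a readable Python exception.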
---

## Bug #24: INT4_1KGROUP training-loop SIGSEGV [FIXED]

### Symptom

After the Bug #23 fix, the INT4_1KGROUP Forward/Backward/Sync tests all pass, but the Training Loop test crashes:

```
SIGSEGV, Segmentation fault
awq-moe.hpp:554: convert_or_copy() with garbage pointer
```

### Call stack

```
ggml_fp16_to_fp32_row(x=0x880e881608fb2308, ...) ← garbage pointer!
← convert_or_copy(gate_bb_[expert_idx]->d, (ggml_fp16_t*)config_.gate_scale + offset, ...)
← AMX_AWQ_MOE_TP::load_weights() [awq-moe.hpp:554]
← AMX_SFT_MOE_TP::load_weights_without_lora()
```

### Root cause

1. The `void*` pointers in the `GeneralMOEConfig` struct were never initialized to `nullptr`.
2. At `operators/common.hpp:243-253`:

```cpp
void* gate_proj;   // uninitialized!
void* gate_scale;  // uninitialized! ← causes the crash
// ... other void* pointers
```

3. In `load_weights()` at `awq-moe.hpp:507-586`:

```cpp
else if (config_.gate_scale != nullptr) {  // garbage value misread as true!
  // pre-quantized weight path, taken by mistake
  convert_or_copy(..., (ggml_fp16_t*)config_.gate_scale + offset, ...);  // CRASH!
} else {
  // online quantization from BF16, the path SFT should take!
}
```

### Why don't BF16/INT8 crash?

**Key difference**:

- **BF16/INT8** (`moe.hpp:219`) check a `std::vector`:

  ```cpp
  if (config_.gate_projs.size()) {  // a std::vector default-constructs empty!
  ```

- **AWQ/K2** (`awq-moe.hpp:507`) check a raw pointer:

  ```cpp
  else if (config_.gate_scale != nullptr)  // uninitialized = garbage!
  ```

A `std::vector` is properly default-constructed empty; a raw `void*` is never initialized automatically.

### Why did Forward/Backward/Sync pass?

Most likely an accident of memory allocation and layout:

- during the Forward/Backward/Sync tests the memory happened to be zeroed
- the Training Loop test has a different memory layout that contains garbage values

### Fix

Give every `void*` pointer in `operators/common.hpp` a `= nullptr` default:

```cpp
// before:
void* gate_proj;
void* gate_scale;
// ...

// after:
void* gate_proj = nullptr;
void* gate_scale = nullptr;
// ...
```

### Files changed

| File | Location | Status |
|------|----------|--------|
| `operators/common.hpp` | lines 243-253 | ✓ fixed |

### Lessons

1. **Pointer initialization**: raw pointers in C++ structs should always be explicitly initialized to `nullptr`.
2. **Undefined behavior**: uninitialized pointers cause hard-to-reproduce bugs that depend on memory state.
3. **std::vector vs void\***: `std::vector` initializes itself; a raw `void*` does not.
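
On the Python side, a config wrapper that defaults every weight pointer to `None` gives the same guarantee the `= nullptr` defaults give on the C++ side. A pure-Python sketch (`SafeMOEConfig` is hypothetical; the field names mirror `GeneralMOEConfig`):

```python
class SafeMOEConfig:
    """All weight pointers default to None, mirroring the `= nullptr` fix."""

    _PTR_FIELDS = ("gate_proj", "up_proj", "down_proj",
                   "gate_scale", "up_scale", "down_scale")

    def __init__(self):
        # Explicit defaults: no field can carry a garbage value.
        for field in self._PTR_FIELDS:
            setattr(self, field, None)

    def unset_fields(self):
        """Report which pointers were never set, before handing off to C++."""
        return [f for f in self._PTR_FIELDS if getattr(self, f) is None]
```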
---

## Bug #25: Wrong zero_point configuration in the INT4_KGROUP test

**Date**: 2026-01-05

**Status**: ✅ fixed

### Symptom

The INT4_KGROUP SFT test crashes in the constructor:

```
terminate called after throwing an instance of 'std::runtime_error'
  what():  Kimi-K2 MoE only support KGroup Int4
Aborted (core dumped)
```

### Root cause

1. The K2 MOE validation logic (`k2-moe.hpp:51-52`):

```cpp
if (quant_config.group_size == 0 || quant_config.zero_point) {
  throw std::runtime_error("Kimi-K2 MoE only support KGroup Int4");
}
```

2. The test file wrongly set `zero_point = True` for the K2 mode:

```python
if quant_mode in ("int4_1kgroup", "int4_kgroup"):
    config.quant_config.group_size = 128
    config.quant_config.zero_point = True  # K2 does not support this!
```

3. **Technical difference, AWQ vs K2**:
   - AWQ (`int4_1kgroup`): uses scales + zero_points
   - K2 (`int4_kgroup`): uses scales only, no zero_points

4. Evidence: `k2-moe.hpp:175-180` loads only `gate_scale`, `up_scale`, `down_scale`; there is no zero_point code at all.

### Fix

Configure `zero_point` differently for AWQ and K2:

```python
# before:
if quant_mode in ("int4_1kgroup", "int4_kgroup"):
    config.quant_config.group_size = 128
    config.quant_config.zero_point = True

# after:
if quant_mode == "int4_1kgroup":  # AWQ supports zero_point
    config.quant_config.group_size = 128
    config.quant_config.zero_point = True
elif quant_mode == "int4_kgroup":  # K2 does NOT support zero_point
    config.quant_config.group_size = 128
    config.quant_config.zero_point = False
```

### Files changed

| File | Location | Status |
|------|----------|--------|
| `examples/test_moe_sft_amx_no_tp.py` | 4 quant_config blocks | ✓ fixed |

### Lessons

1. **Mode differences**: different quantization schemes (AWQ vs K2) have different configuration requirements.
2. **Separate configuration**: don't merge the configuration of different modes into one conditional branch.
3. **Error messages help**: the K2 error "Kimi-K2 MoE only support KGroup Int4" pointed straight at the problem.
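
The per-mode split can be centralized in one small helper so each test function does not repeat the branch. A sketch (`quant_settings` is a hypothetical helper; the values follow the fix above):

```python
def quant_settings(quant_mode: str):
    """Return (group_size, zero_point) for the k-group int4 modes, else None."""
    if quant_mode == "int4_1kgroup":   # AWQ: asymmetric, uses zero points
        return 128, True
    if quant_mode == "int4_kgroup":    # K2: symmetric, scales only
        return 128, False
    return None                        # bf16/int8/... need no quant_config
```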
---

## Bug #26: The K2 MOE SFT test needs pre-quantized weights

**Date**: 2026-01-06

**Status**: ✅ fixed

### Symptom

After the Bug #25 fix, the INT4_KGROUP SFT test still crashes in `load_weights()`:

```
  what():  Kimi AVX MOE only support load native weight.
```

An attempt to add an online-quantization path fails to compile:

```
'amx::BufferBInt4KGroupImpl' has no member named 'from_mat'; did you mean 'from_raw_mat'?
```

### Root cause

Architectural differences between K2 and AWQ:

| Feature | K2 (BufferBInt4KGroupImpl) | AWQ (BufferBInt4WithZeroKGroupImpl) |
|------|---------------------------|-------------------------------------|
| Buffer storage | weights + scales | weights + scales + zero_points |
| Supported methods | `from_raw_mat()` only | `from_raw_mat()` + `from_mat()` |
| Quantization | signed int4 (symmetric) | unsigned int4 + zero_point |
| Online quantization | ❌ unsupported | ✅ supported |

**Conclusion**: K2 MOE is designed to load only offline pre-quantized weights; it does not support online quantization.

### Fix

Have the SFT tests feed the K2 mode **pre-quantized int4 weights + scales** instead of BF16 weights.

```python
# Add a K2 quantization function
def quantize_k2_tensor(weights: torch.Tensor, group_size: int):
    """K2 symmetric quantization: BF16 → signed int4 (range -8 to 7)"""
    reshaped = weights.view(e, rows, cols // group_size, group_size)
    max_abs = reshaped.abs().amax(dim=-1, keepdim=True)
    scales = (max_abs / 7.0).squeeze(-1)
    q = torch.round(reshaped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8)
    packed = pack_tensor_per_row(q, num_bits=4)
    return packed, scales.to(torch.bfloat16)

# Test configuration change
if quant_mode == "int4_kgroup":
    k2_weights = init_base_weights_for_k2(expert_num, hidden_size, intermediate_size)
    config.gate_proj = k2_weights["gate_qweight"].data_ptr()
    config.gate_scale = k2_weights["gate_scales"].data_ptr()  # the key part!
```

### Key changes

1. Revert the online-path attempt in `k2-moe.hpp`.
2. Add `quantize_k2_tensor()`, `pack_to_int32()`, `pack_tensor_per_row()` to the test.
3. Add the `init_base_weights_for_k2()` initialization function.
4. Update the configuration logic in 4 test functions.

### K2 quantization format

| Feature | K2 | AWQ |
|------|----|----|
| Quantization type | symmetric | asymmetric |
| Range | -8 to 7 (signed) | 0 to 15 (unsigned) |
| Parameters | scale only | scale + zero_point |
| Formula | q = round(w / scale) | q = round(w / scale) + zero |

### Files changed

| File | Change | Status |
|------|---------|------|
| `operators/amx/k2-moe.hpp` | revert the online path, keep the original design | ✓ fixed |
| `examples/test_moe_sft_amx_no_tp.py` | add K2 quantization functions and pre-quantized weights | ✓ fixed |

### Lessons

1. **Understand the architecture**: K2 and AWQ use different buffer types; code cannot simply be copied between them.
2. **Design consistency**: K2 supports only offline pre-quantized weights by design; the tests must follow that design.
3. **Verify before compiling**: check interface compatibility before modifying the C++ code.
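
The symmetric max-abs/7 scheme can be sanity-checked with a round trip. A pure-Python analogue of one k-group of `quantize_k2_tensor` (no torch; the reconstruction error per element should stay within about half a quantization step):

```python
def k2_quantize_group(group):
    """Symmetric int4: scale = max|w| / 7, q = round(w / scale) clamped to [-8, 7]."""
    scale = max(abs(w) for w in group) / 7.0
    if scale == 0.0:
        scale = 1e-8  # mirror the torch.clamp(max_abs, min=1e-8) guard
    q = [max(-8, min(7, round(w / scale))) for w in group]
    return q, scale

def k2_dequantize_group(q, scale):
    """Dequantize: w ≈ q * scale (no zero point in the K2 scheme)."""
    return [v * scale for v in q]
```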
---

## Bug #27: Wrong load_weights path selected for K2 MOE SFT

**Date**: 2026-01-06

**Status**: ✅ fixed

### Symptom

After the Bug #26 fix, the INT4_KGROUP no-TP test crashes in `load_weights()`:

```
Thread "numa_0_t_50" received signal SIGSEGV, Segmentation fault.
__memmove_avx512_unaligned_erms ()
#1 TP_MOE_SFT::load_weights()::{lambda}::operator()(int) at moe-sft-tp.hpp:103
```

The log shows the wrong path being selected:

```
TP_MOE_SFT: From BF16 with partitioning   ← wrong! should use the K2 pre-quantized path
```

### Root cause

**The decision logic at `moe-sft-tp.hpp:77`**:

```cpp
if (config.gate_proj != nullptr) {
  // assumes gate_proj is BF16 data
  memcpy(..., (ggml_bf16_t*)config.gate_proj + ..., sizeof(ggml_bf16_t) * ...);
}
```

The test sets `config.gate_proj = k2_weights["gate_qweight"].data_ptr()` (int4 packed), but the C++ code sees `gate_proj != nullptr`, assumes BF16 data, and does the memcpy with BF16 offsets, causing the SIGSEGV.

### Fix

Add K2 pre-quantized mode detection to `moe-sft-tp.hpp::load_weights()`:

```cpp
// K2 pre-quantized mode: gate_scale != nullptr && !zero_point
bool is_k2_prequantized = (config.gate_scale != nullptr && !config.quant_config.zero_point);

if (is_k2_prequantized) {
  printf("TP_MOE_SFT: K2 pre-quantized mode (no BF16 partitioning)\n");
  if (tp_count == 1) {
    // No-TP: call load_weights directly; tp_configs[i] already holds all pointers
    pool->dispense_backend()->do_numa_job([this](int numa_id) {
      tps[numa_id]->load_weights();
    });
  } else {
    throw std::runtime_error("K2 pre-quantized mode does not support TP > 1 yet");
  }
} else if (config.gate_proj != nullptr) {
  // BF16 partitioning path...
}
```

### Detection conditions

| Mode | gate_scale | zero_point | Result |
|------|-----------|------------|----------|
| K2 | != nullptr | false | is_k2_prequantized = true |
| AWQ | != nullptr | true | is_k2_prequantized = false |
| BF16 | nullptr | - | falls through to the gate_proj check |

### Files changed

| File | Change | Status |
|------|---------|------|
| `operators/moe-sft-tp.hpp` | add the K2 pre-quantized branch | ✓ fixed |

### Lessons

1. **Distinguish data formats**: the same pointer (gate_proj) can point at data in different formats (BF16 vs int4 packed).
2. **Explicit detection**: combine multiple conditions (gate_scale + zero_point) to tell quantization modes apart.
3. **Incremental support**: implement the no-TP mode first; K2 support for TP > 1 can be added later.
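
The detection table can be expressed as a small selector. A pure-Python sketch of the branch order (names are illustrative; the real logic lives in `moe-sft-tp.hpp::load_weights()`):

```python
def select_load_path(gate_proj, gate_scale, zero_point: bool) -> str:
    """Mirror the load_weights() branch order: the K2 check must come first."""
    if gate_scale is not None and not zero_point:
        return "k2_prequantized"      # int4-packed weights + scales, no partitioning
    if gate_proj is not None:
        return "bf16_partitioning"    # BF16 weights, partitioned across NUMA nodes
    raise ValueError("no weights provided")
```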
---

@@ -675,6 +675,11 @@ def test_moe_sft_forward(quant_mode: str = "bf16"):
    config.down_lora_b = down_lora_b.data_ptr()
    config.pool = CPUInfer.backend_

    # Bug #23 fix: Set quant_config for AWQ/K2 modes
    if quant_mode in ("int4_1kgroup", "int4_kgroup"):
        config.quant_config.group_size = 128
        config.quant_config.zero_point = True

    # Create MOE SFT instance based on quant_mode
    MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
    moe = MOE_SFT_CLASS(config)

@@ -821,6 +826,11 @@ def test_moe_sft_backward(quant_mode: str = "bf16"):
    config.down_lora_b = down_lora_b.data_ptr()
    config.pool = CPUInfer.backend_

    # Bug #23 fix: Set quant_config for AWQ/K2 modes
    if quant_mode in ("int4_1kgroup", "int4_kgroup"):
        config.quant_config.group_size = 128
        config.quant_config.zero_point = True

    # Create MOE SFT instance based on quant_mode
    MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
    moe = MOE_SFT_CLASS(config)

@@ -1031,6 +1041,11 @@ def test_moe_sft_lora_weight_sync(quant_mode: str = "bf16"):
    config.down_lora_b = down_lora_b.data_ptr()
    config.pool = CPUInfer.backend_

    # Bug #23 fix: Set quant_config for AWQ/K2 modes
    if quant_mode in ("int4_1kgroup", "int4_kgroup"):
        config.quant_config.group_size = 128
        config.quant_config.zero_point = True

    # Create MOE SFT instance based on quant_mode
    MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
    moe = MOE_SFT_CLASS(config)

@@ -1268,6 +1283,11 @@ def test_moe_sft_training_loop(quant_mode: str = "bf16"):
    config.down_lora_b = down_lora_b_param.data.data_ptr()
    config.pool = CPUInfer.backend_

    # Bug #23 fix: Set quant_config for AWQ/K2 modes
    if quant_mode in ("int4_1kgroup", "int4_kgroup"):
        config.quant_config.group_size = 128
        config.quant_config.zero_point = True

    # Create MOE SFT instance based on quant_mode
    MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
    moe = MOE_SFT_CLASS(config)

@@ -1458,8 +1478,8 @@ def run_all_tests():
    print("=" * 70)

    # Quantization modes to test
    quant_modes = ["bf16", "int8", "int4", "int4_1", "int4_1kgroup", "int4_kgroup"]
    # quant_modes = ["int4_1kgroup", "int4_kgroup"]
    # quant_modes = ["bf16", "int8", "int4", "int4_1"]
    quant_modes = ["int4_1kgroup", "int4_kgroup"]

    try:
        for quant_mode in quant_modes:

@@ -13,6 +13,8 @@ Key difference from test_moe_sft_amx.py:

import os
import sys
import math
from typing import Literal, Dict

sys.path.insert(0, os.path.dirname(__file__) + "/../build")
print("sys.path:", sys.path)

@@ -115,6 +117,141 @@ def get_threshold(quant_mode: str, is_backward: bool = False) -> float:
    return BF16_FORWARD_THRESHOLD  # 0.05


# =============================================================================
# K2 Quantization Utilities (for INT4_KGROUP mode)
# =============================================================================


def pack_to_int32(value: torch.Tensor, num_bits: int, packed_dim: Literal[0, 1] = 1) -> torch.Tensor:
    """Pack int4 values into int32 tensor.

    Args:
        value: int8 tensor to pack
        num_bits: number of bits per value (4 for int4)
        packed_dim: dimension to pack along

    Returns:
        int32 tensor with packed values
    """
    if value.dtype is not torch.int8:
        raise ValueError("Tensor must be torch.int8 before packing")
    if not (1 <= num_bits <= 8):
        raise ValueError(f"num_bits must be in [1, 8], got {num_bits}")

    offset = 1 << (num_bits - 1)
    value = (value + offset).to(torch.uint8)
    device = value.device

    pack_factor = 32 // num_bits

    if packed_dim == 0:
        value = value.transpose(0, 1)

    rows, cols = value.shape
    padded_cols = math.ceil(cols / pack_factor) * pack_factor
    pad_len = padded_cols - cols

    if pad_len > 0:
        value = torch.nn.functional.pad(value, (0, pad_len))

    num_groups = padded_cols // pack_factor

    # Use int32 here
    reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32)
    bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits
    packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32)

    if packed_dim == 0:
        packed = packed.transpose(0, 1)

    return packed


def pack_tensor_per_row(q: torch.Tensor, num_bits: int) -> torch.Tensor:
    """Pack tensor per row for K2 quantization.

    Args:
        q: [expert_num, rows, cols] int8 tensor
        num_bits: number of bits per value

    Returns:
        Packed int32 tensor
    """
    e, rows, cols = q.shape
    flat = q.view(e * rows, cols)
    packed = pack_to_int32(flat, num_bits)
    return packed.view(e, rows, -1).contiguous()


def quantize_k2_tensor(weights: torch.Tensor, group_size: int):
    """
    K2 symmetric max-abs/7 quantization per k-group.

    Args:
        weights: [expert_num, rows (N), cols (K)] bfloat16 tensor

    Returns:
        packed_q: int32 tensor storing 8 int4s per element with shape [expert_num, rows * (cols // 8)]
        scales: bfloat16 tensor with shape [expert_num, rows * (cols // group_size)]
    """
    weights_f32 = weights.to(torch.float32)
    e, rows, cols = weights_f32.shape
    if cols % group_size != 0 or cols % 2 != 0:
        raise ValueError(f"cols ({cols}) must be divisible by group_size ({group_size}) and 2")

    reshaped = weights_f32.view(e, rows, cols // group_size, group_size)
    max_abs = reshaped.abs().amax(dim=-1, keepdim=True)
    max_abs = torch.clamp(max_abs, min=1e-8)
    scales = (max_abs / 7.0).squeeze(-1)
    q = torch.round(reshaped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8)
    q = q.view(e, rows, cols)
    packed = pack_tensor_per_row(q, num_bits=4).view(e, rows, cols // 8).contiguous()
    scales = scales.to(torch.bfloat16).contiguous().view(e, rows, cols // group_size).contiguous()

    return packed, scales


def init_base_weights_for_k2(
    expert_num: int, hidden_size: int, intermediate_size: int, group_size: int = 128
) -> Dict[str, torch.Tensor]:
    """Initialize pre-quantized K2 weights for INT4_KGROUP mode.

    Args:
        expert_num: number of experts
        hidden_size: hidden dimension
        intermediate_size: intermediate dimension
        group_size: quantization group size

    Returns:
        Dictionary containing:
        - gate_qweight, up_qweight, down_qweight: packed int4 weights
        - gate_scales, up_scales, down_scales: bf16 scales
        - gate_proj_bf16, up_proj_bf16, down_proj_bf16: original bf16 weights for reference
    """
    # Create random BF16 weights
    gate_proj_bf16 = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16)
    up_proj_bf16 = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16)
    down_proj_bf16 = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.bfloat16)

    # Quantize to int4
    gate_qweight, gate_scales = quantize_k2_tensor(gate_proj_bf16, group_size)
    up_qweight, up_scales = quantize_k2_tensor(up_proj_bf16, group_size)
    down_qweight, down_scales = quantize_k2_tensor(down_proj_bf16, group_size)

    return {
        "gate_qweight": gate_qweight.contiguous(),
        "up_qweight": up_qweight.contiguous(),
        "down_qweight": down_qweight.contiguous(),
        "gate_scales": gate_scales.contiguous(),
        "up_scales": up_scales.contiguous(),
        "down_scales": down_scales.contiguous(),
        # Keep original BF16 for gradient verification
        "gate_proj_bf16": gate_proj_bf16.contiguous(),
        "up_proj_bf16": up_proj_bf16.contiguous(),
        "down_proj_bf16": down_proj_bf16.contiguous(),
    }


# =============================================================================
# Activation Functions
# =============================================================================
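
The int4 packing above can be sanity-checked with a round trip. A pure-Python analogue of `pack_to_int32` for a single row (no torch; the `+8` offset maps signed int4 to unsigned nibbles, eight per 32-bit word, low nibble first; `unpack_int4_row` is a hypothetical helper that does not exist in the test file):

```python
def pack_int4_row(values):
    """Pack signed int4 values, eight per word, low nibble first."""
    assert len(values) % 8 == 0
    words = []
    for i in range(0, len(values), 8):
        word = 0
        for j, v in enumerate(values[i:i + 8]):
            word |= ((v + 8) & 0xF) << (4 * j)   # offset-binary nibble
        words.append(word)
    return words

def unpack_int4_row(words, n):
    """Inverse of pack_int4_row: extract nibbles, undo the +8 offset."""
    vals = []
    for w in words:
        for j in range(8):
            vals.append(((w >> (4 * j)) & 0xF) - 8)
    return vals[:n]
```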

@@ -620,8 +757,19 @@ def test_moe_sft_forward_no_tp(quant_mode: str = "bf16"):
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Initialize weights
    gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)
    # Initialize weights based on quant_mode
    k2_weights = None  # Will be set for K2 mode
    if quant_mode == "int4_kgroup":
        # K2 needs pre-quantized int4 weights
        k2_weights = init_base_weights_for_k2(expert_num, hidden_size, intermediate_size, group_size=128)
        # Use original BF16 for reference computation
        gate_proj = k2_weights["gate_proj_bf16"]
        up_proj = k2_weights["up_proj_bf16"]
        down_proj = k2_weights["down_proj_bf16"]
    else:
        # Other modes use BF16 weights
        gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)

    lora_weights = init_lora_weights(expert_num, hidden_size, intermediate_size, lora_rank)
    gate_lora_a, gate_lora_b, up_lora_a, up_lora_b, down_lora_a, down_lora_b = lora_weights

@@ -655,9 +803,20 @@ def test_moe_sft_forward_no_tp(quant_mode: str = "bf16"):
    config.max_cache_depth = 1
    config.max_len = max_len
    config.layer_idx = 0
    config.gate_proj = gate_proj.data_ptr()
    config.up_proj = up_proj.data_ptr()
    config.down_proj = down_proj.data_ptr()

    # Bug #26 fix: K2 uses pre-quantized weights with scales
    if quant_mode == "int4_kgroup" and k2_weights is not None:
        config.gate_proj = k2_weights["gate_qweight"].data_ptr()
        config.up_proj = k2_weights["up_qweight"].data_ptr()
        config.down_proj = k2_weights["down_qweight"].data_ptr()
        config.gate_scale = k2_weights["gate_scales"].data_ptr()
        config.up_scale = k2_weights["up_scales"].data_ptr()
        config.down_scale = k2_weights["down_scales"].data_ptr()
    else:
        config.gate_proj = gate_proj.data_ptr()
        config.up_proj = up_proj.data_ptr()
        config.down_proj = down_proj.data_ptr()

    # Set LoRA weight pointers directly in config (zero-copy)
    config.gate_lora_a = gate_lora_a.data_ptr()
    config.gate_lora_b = gate_lora_b.data_ptr()

@@ -667,6 +826,15 @@ def test_moe_sft_forward_no_tp(quant_mode: str = "bf16"):
    config.down_lora_b = down_lora_b.data_ptr()
    config.pool = CPUInfer.backend_

    # Bug #23 fix: Set quant_config for AWQ/K2 modes
    # Bug #25 fix: AWQ (int4_1kgroup) uses zero_point, K2 (int4_kgroup) does NOT
    if quant_mode == "int4_1kgroup":  # AWQ supports zero_point
        config.quant_config.group_size = 128
        config.quant_config.zero_point = True
    elif quant_mode == "int4_kgroup":  # K2 does NOT support zero_point
        config.quant_config.group_size = 128
        config.quant_config.zero_point = False

    # Create MOE SFT instance based on quant_mode
    MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
    moe = MOE_SFT_CLASS(config)

@@ -772,8 +940,19 @@ def test_moe_sft_backward_no_tp(quant_mode: str = "bf16"):
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Initialize weights
    gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)
    # Initialize weights based on quant_mode
    k2_weights = None  # Will be set for K2 mode
    if quant_mode == "int4_kgroup":
        # K2 needs pre-quantized int4 weights
        k2_weights = init_base_weights_for_k2(expert_num, hidden_size, intermediate_size, group_size=128)
        # Use original BF16 for reference computation
        gate_proj = k2_weights["gate_proj_bf16"]
        up_proj = k2_weights["up_proj_bf16"]
        down_proj = k2_weights["down_proj_bf16"]
    else:
        # Other modes use BF16 weights
        gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)

    lora_weights = init_lora_weights(expert_num, hidden_size, intermediate_size, lora_rank)
    gate_lora_a, gate_lora_b, up_lora_a, up_lora_b, down_lora_a, down_lora_b = lora_weights

@@ -806,9 +985,20 @@ def test_moe_sft_backward_no_tp(quant_mode: str = "bf16"):
    config.max_cache_depth = validation_iter  # Need cache for backward
    config.max_len = max_len
    config.layer_idx = 0
    config.gate_proj = gate_proj.data_ptr()
    config.up_proj = up_proj.data_ptr()
    config.down_proj = down_proj.data_ptr()

    # Bug #26 fix: K2 uses pre-quantized weights with scales
    if quant_mode == "int4_kgroup" and k2_weights is not None:
        config.gate_proj = k2_weights["gate_qweight"].data_ptr()
        config.up_proj = k2_weights["up_qweight"].data_ptr()
        config.down_proj = k2_weights["down_qweight"].data_ptr()
        config.gate_scale = k2_weights["gate_scales"].data_ptr()
        config.up_scale = k2_weights["up_scales"].data_ptr()
        config.down_scale = k2_weights["down_scales"].data_ptr()
    else:
        config.gate_proj = gate_proj.data_ptr()
        config.up_proj = up_proj.data_ptr()
        config.down_proj = down_proj.data_ptr()

    config.gate_lora_a = gate_lora_a.data_ptr()
    config.gate_lora_b = gate_lora_b.data_ptr()
    config.up_lora_a = up_lora_a.data_ptr()

@@ -817,6 +1007,15 @@ def test_moe_sft_backward_no_tp(quant_mode: str = "bf16"):
    config.down_lora_b = down_lora_b.data_ptr()
    config.pool = CPUInfer.backend_

    # Bug #23 fix: Set quant_config for AWQ/K2 modes
    # Bug #25 fix: AWQ (int4_1kgroup) uses zero_point, K2 (int4_kgroup) does NOT
    if quant_mode == "int4_1kgroup":  # AWQ supports zero_point
        config.quant_config.group_size = 128
        config.quant_config.zero_point = True
    elif quant_mode == "int4_kgroup":  # K2 does NOT support zero_point
        config.quant_config.group_size = 128
        config.quant_config.zero_point = False

    # Create MOE SFT instance based on quant_mode
    MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
    moe = MOE_SFT_CLASS(config)

@@ -997,8 +1196,19 @@ def test_moe_sft_lora_weight_sync_no_tp(quant_mode: str = "bf16"):

    torch.manual_seed(42)

    # Initialize weights
    gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)
    # Initialize weights based on quant_mode
    k2_weights = None  # Will be set for K2 mode
    if quant_mode == "int4_kgroup":
        # K2 needs pre-quantized int4 weights
        k2_weights = init_base_weights_for_k2(expert_num, hidden_size, intermediate_size, group_size=128)
        # Use original BF16 for reference computation
        gate_proj = k2_weights["gate_proj_bf16"]
        up_proj = k2_weights["up_proj_bf16"]
        down_proj = k2_weights["down_proj_bf16"]
    else:
        # Other modes use BF16 weights
        gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)

    lora_weights = init_lora_weights(expert_num, hidden_size, intermediate_size, lora_rank)
    gate_lora_a, gate_lora_b, up_lora_a, up_lora_b, down_lora_a, down_lora_b = lora_weights

@@ -1020,9 +1230,20 @@ def test_moe_sft_lora_weight_sync_no_tp(quant_mode: str = "bf16"):
    config.max_cache_depth = 1
    config.max_len = max_len
    config.layer_idx = 0
    config.gate_proj = gate_proj.data_ptr()
    config.up_proj = up_proj.data_ptr()
    config.down_proj = down_proj.data_ptr()

    # Bug #26 fix: K2 uses pre-quantized weights with scales
    if quant_mode == "int4_kgroup" and k2_weights is not None:
        config.gate_proj = k2_weights["gate_qweight"].data_ptr()
        config.up_proj = k2_weights["up_qweight"].data_ptr()
        config.down_proj = k2_weights["down_qweight"].data_ptr()
        config.gate_scale = k2_weights["gate_scales"].data_ptr()
        config.up_scale = k2_weights["up_scales"].data_ptr()
        config.down_scale = k2_weights["down_scales"].data_ptr()
    else:
        config.gate_proj = gate_proj.data_ptr()
        config.up_proj = up_proj.data_ptr()
        config.down_proj = down_proj.data_ptr()

    config.gate_lora_a = gate_lora_a.data_ptr()
    config.gate_lora_b = gate_lora_b.data_ptr()
    config.up_lora_a = up_lora_a.data_ptr()

@@ -1031,6 +1252,15 @@ def test_moe_sft_lora_weight_sync_no_tp(quant_mode: str = "bf16"):
    config.down_lora_b = down_lora_b.data_ptr()
    config.pool = CPUInfer.backend_

    # Bug #23 fix: Set quant_config for AWQ/K2 modes
    # Bug #25 fix: AWQ (int4_1kgroup) uses zero_point, K2 (int4_kgroup) does NOT
    if quant_mode == "int4_1kgroup":  # AWQ supports zero_point
        config.quant_config.group_size = 128
        config.quant_config.zero_point = True
    elif quant_mode == "int4_kgroup":  # K2 does NOT support zero_point
        config.quant_config.group_size = 128
        config.quant_config.zero_point = False

    # Create MOE SFT instance based on quant_mode
    MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
    moe = MOE_SFT_CLASS(config)

@@ -1193,8 +1423,18 @@ def test_moe_sft_training_loop_no_tp(quant_mode: str = "bf16"):

    torch.manual_seed(42)

    # Initialize base weights
    gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)
    # Initialize base weights based on quant_mode
    k2_weights = None  # Will be set for K2 mode
    if quant_mode == "int4_kgroup":
        # K2 needs pre-quantized int4 weights
        k2_weights = init_base_weights_for_k2(expert_num, hidden_size, intermediate_size, group_size=128)
        # Use original BF16 for reference computation
        gate_proj = k2_weights["gate_proj_bf16"]
        up_proj = k2_weights["up_proj_bf16"]
        down_proj = k2_weights["down_proj_bf16"]
    else:
        # Other modes use BF16 weights
        gate_proj, up_proj, down_proj = init_base_weights(expert_num, hidden_size, intermediate_size)

    # Initialize LoRA weights as contiguous tensors
    gate_lora_a = (

@@ -1261,9 +1501,20 @@ def test_moe_sft_training_loop_no_tp(quant_mode: str = "bf16"):
    config.max_cache_depth = 1  # One forward-backward pair at a time
    config.max_len = max_len
    config.layer_idx = 0
    config.gate_proj = gate_proj.data_ptr()
    config.up_proj = up_proj.data_ptr()
    config.down_proj = down_proj.data_ptr()

    # Bug #26 fix: K2 uses pre-quantized weights with scales
    if quant_mode == "int4_kgroup" and k2_weights is not None:
        config.gate_proj = k2_weights["gate_qweight"].data_ptr()
        config.up_proj = k2_weights["up_qweight"].data_ptr()
        config.down_proj = k2_weights["down_qweight"].data_ptr()
        config.gate_scale = k2_weights["gate_scales"].data_ptr()
        config.up_scale = k2_weights["up_scales"].data_ptr()
        config.down_scale = k2_weights["down_scales"].data_ptr()
    else:
        config.gate_proj = gate_proj.data_ptr()
        config.up_proj = up_proj.data_ptr()
        config.down_proj = down_proj.data_ptr()

    config.gate_lora_a = gate_lora_a_param.data.data_ptr()
    config.gate_lora_b = gate_lora_b_param.data.data_ptr()
    config.up_lora_a = up_lora_a_param.data.data_ptr()
```diff
@@ -1272,6 +1523,15 @@ def test_moe_sft_training_loop_no_tp(quant_mode: str = "bf16"):
     config.down_lora_b = down_lora_b_param.data.data_ptr()
     config.pool = CPUInfer.backend_

+    # Bug #23 fix: Set quant_config for AWQ/K2 modes
+    # Bug #25 fix: AWQ (int4_1kgroup) uses zero_point, K2 (int4_kgroup) does NOT
+    if quant_mode == "int4_1kgroup":  # AWQ supports zero_point
+        config.quant_config.group_size = 128
+        config.quant_config.zero_point = True
+    elif quant_mode == "int4_kgroup":  # K2 does NOT support zero_point
+        config.quant_config.group_size = 128
+        config.quant_config.zero_point = False
+
     # Create MOE SFT instance based on quant_mode
     MOE_SFT_CLASS = get_moe_sft_class(quant_mode)
     moe = MOE_SFT_CLASS(config)
```
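This is the Bug #23 fix itself: the crash happens because the C++ side computes `k % k_group_size` with the default `group_size = 0`, and integer modulo by zero traps with SIGFPE. A hypothetical Python mirror of that guard (the `num_k_groups` helper name is ours, not the real API) makes the failure mode and the fix explicit:

```python
def num_k_groups(k: int, k_group_size: int) -> int:
    """Sketch of the buffer-size check in amx_buffers.hpp (names illustrative).

    In C++, `k % 0` is undefined behavior and traps with SIGFPE on x86;
    in Python the same expression raises ZeroDivisionError.
    """
    if k_group_size <= 0:
        raise ValueError(
            "k_group_size must be > 0 -- set config.quant_config.group_size "
            "(e.g. 128) before constructing the MOE"
        )
    if k % k_group_size != 0:
        raise ValueError(f"k={k} must be divisible by k_group_size={k_group_size}")
    return k // k_group_size
```

With `k=7168` and `group_size=128` this yields 56 groups; with the unset default of 0, the error now surfaces as a clear configuration error instead of a SIGFPE deep inside `make_buffer_a()`.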
```diff
@@ -1463,8 +1723,9 @@ def run_all_tests():
     print("=" * 70)

     # Quantization modes to test
-    # quant_modes = ["bf16", "int8", "int4", "int4_1", "int4_1kgroup", "int4_kgroup"]
-    quant_modes = ["int4_1kgroup", "int4_kgroup"]
+    # quant_modes = ["bf16", "int8", "int4", "int4_1"]
+    # quant_modes = ["int4_1kgroup", "int4_kgroup"]
+    quant_modes = ["int4_kgroup"]

     try:
         for quant_mode in quant_modes:
```
```diff
@@ -115,6 +115,9 @@ class AMX_K2_MOE_TP : public AMX_MOE_BASE<T, AMX_K2_MOE_TP<T>> {
    *
    * Loads weights from config_.gate_proj, up_proj, down_proj with scales
    * from config_.gate_scale, up_scale, down_scale.
+   *
+   * Note: K2 MOE only supports offline pre-quantized weights (gate_scale must be set).
+   * For online quantization, use AWQ MOE instead.
    */
   void load_weights() {
     auto& quant_config = config_.quant_config;
```
```diff
@@ -240,17 +240,17 @@ struct GeneralMOEConfig {
   int num_gpu_experts = 0;
   void* physical_to_logical_map = nullptr;

-  void* gate_proj;
-  void* up_proj;
-  void* down_proj;
+  void* gate_proj = nullptr;
+  void* up_proj = nullptr;
+  void* down_proj = nullptr;

-  void* gate_scale;
-  void* up_scale;
-  void* down_scale;
+  void* gate_scale = nullptr;
+  void* up_scale = nullptr;
+  void* down_scale = nullptr;

-  void* gate_zero;
-  void* up_zero;
-  void* down_zero;
+  void* gate_zero = nullptr;
+  void* up_zero = nullptr;
+  void* down_zero = nullptr;

   QuantConfig quant_config;
```
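The `= nullptr` defaults matter because the SFT dispatch branches on `config.gate_scale != nullptr`; with an indeterminate (never-initialized) pointer that test could misfire on garbage. A minimal Python analogue of the idea, with `None` standing in for `nullptr` (an illustrative subset, not the real binding):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class GeneralMOEConfigSketch:
    """Illustrative subset of GeneralMOEConfig; None stands in for nullptr."""
    gate_proj: Optional[int] = None   # raw data_ptr() values in the real tests
    up_proj: Optional[int] = None
    down_proj: Optional[int] = None
    gate_scale: Optional[int] = None
    up_scale: Optional[int] = None
    down_scale: Optional[int] = None

# A default-constructed config now has a well-defined "never set" state,
# so checks like `gate_scale is not None` (gate_scale != nullptr in C++)
# cannot misfire on indeterminate memory.
default_cfg = GeneralMOEConfigSketch()
```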
```diff
@@ -74,7 +74,23 @@ class TP_MOE_SFT : public TP_MOE<T> {
     auto pool = config.pool;
     const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;

-    if (config.gate_proj != nullptr) {
+    // Bug #27 fix: K2 pre-quantized mode detection
+    // K2 uses gate_scale != nullptr and zero_point = false
+    // AWQ also has gate_scale but has zero_point = true
+    bool is_k2_prequantized = (config.gate_scale != nullptr && !config.quant_config.zero_point);
+
+    if (is_k2_prequantized) {
+      printf("TP_MOE_SFT: K2 pre-quantized mode (no BF16 partitioning)\n");
+      // For K2, weights are already int4-packed with scales
+      // tp_configs[i] already has all pointers from config (copied in TP_MOE constructor)
+      if (tp_count == 1) {
+        // No-TP: just call load_weights directly
+        pool->dispense_backend()->do_numa_job([this](int numa_id) { tps[numa_id]->load_weights(); });
+      } else {
+        // TP mode with K2 would need int4-aware partitioning (not implemented yet)
+        throw std::runtime_error("K2 pre-quantized mode does not support TP > 1 yet");
+      }
+    } else if (config.gate_proj != nullptr) {
       printf("TP_MOE_SFT: From BF16 with partitioning\n");

       // Temporary storage for partitioned weights
```
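The branch order added by this hunk can be summarized as a small decision function (a sketch under our own labels; the string names are not real API values):

```python
from typing import Optional

def select_sft_weight_path(gate_proj: Optional[int],
                           gate_scale: Optional[int],
                           zero_point: bool,
                           tp_count: int) -> str:
    """Mirror of the TP_MOE_SFT dispatch above (labels illustrative)."""
    if gate_proj is None:
        return "no-weights"  # handled elsewhere in the original code
    if gate_scale is not None and not zero_point:
        # K2: int4-packed weights + scales are loaded as-is, no BF16 partitioning
        if tp_count > 1:
            raise RuntimeError("K2 pre-quantized mode does not support TP > 1 yet")
        return "k2-prequantized"
    # AWQ (zero_point=True) and plain BF16 both take the partitioning path
    return "bf16-partition"
```

Note the discriminator: K2 and AWQ both carry `gate_scale`, so `zero_point` is what separates them, which is exactly why Bug #25 requires setting `zero_point = False` for `int4_kgroup`.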