[feat](kt-kernel): support avx2 only inference for bf16 fp8 and gptq int4 (#1892)
* feat: support avx2 bf16 fp8 inference
* feat: support avx2 gptq int4 inference
* fix: numeric issues in fp8 dequant
* Tutorial avx2 (#1900)
* fix: prevent injecting -DLLAMA_AVX512=ON on AVX2-only machines
* docs: add AVX2 tutorial for running KTransformers on AVX2-only CPUs
* Tutorial avx2 (#1901)
* fix: prevent injecting -DLLAMA_AVX512=ON on AVX2-only machines
* docs: add AVX2 tutorial for running KTransformers on AVX2-only CPUs
* docs: update README.md

---------

Co-authored-by: Benjamin F <159887351+yyj6666667@users.noreply.github.com>
This commit is contained in: parent 8561a71dd1, commit 7a9daf0cd4
19 changed files with 3472 additions and 12 deletions
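To tell whether a given machine will use the AMX/AVX512 backends or the new AVX2 fallback, the CPU feature flags can be inspected directly. A Linux-only diagnostic sketch (illustrative, not part of the patch):

    # Read the CPU feature flags and report which kt-kernel backend family applies.
    def cpu_flags() -> set:
        with open("/proc/cpuinfo") as f:
            for line in f:
                if line.startswith("flags"):
                    return set(line.split(":", 1)[1].split())
        return set()

    flags = cpu_flags()
    if "avx512f" in flags:
        print("AVX512 present: the AMX/AVX512 backends can be compiled and used")
    elif "avx2" in flags and "fma" in flags:
        print("AVX2-only machine: the AVX2 fallback backends apply")
    else:
        print("Neither AVX512 nor AVX2+FMA detected: no supported MoE backend")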
@@ -5,7 +5,7 @@ from typing import Optional
 # Use relative imports for package structure
 from ..experts_base import BaseMoEWrapper
-from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader, BF16SafeTensorLoader
+from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader, BF16SafeTensorLoader, GPTQSafeTensorLoader
 from kt_kernel_ext.moe import MOEConfig
 import kt_kernel_ext.moe as _moe_mod
 
@@ -15,6 +15,9 @@ AMXInt4_KGroup_MOE = getattr(_moe_mod, "AMXInt4_KGroup_MOE", None)
 AMXFP8_MOE = getattr(_moe_mod, "AMXFP8_MOE", None)
 AMXBF16_MOE = getattr(_moe_mod, "AMXBF16_MOE", None)
 AMXFP8PerChannel_MOE = getattr(_moe_mod, "AMXFP8PerChannel_MOE", None)
+AVX2BF16_MOE = getattr(_moe_mod, "AVX2BF16_MOE", None)
+AVX2FP8_MOE = getattr(_moe_mod, "AVX2FP8_MOE", None)
+AVX2GPTQInt4_MOE = getattr(_moe_mod, "AVX2GPTQInt4_MOE", None)
 
 _HAS_AMXINT4_SUPPORT = AMXInt4_MOE is not None
 _HAS_AMXINT8_SUPPORT = AMXInt8_MOE is not None
@@ -22,6 +25,9 @@ _HAS_RAWINT4_SUPPORT = AMXInt4_KGroup_MOE is not None
 _HAS_FP8_SUPPORT = AMXFP8_MOE is not None
 _HAS_BF16_SUPPORT = AMXBF16_MOE is not None
 _HAS_FP8_PERCHANNEL_SUPPORT = AMXFP8PerChannel_MOE is not None
+_HAS_AVX2_BF16_SUPPORT = AVX2BF16_MOE is not None
+_HAS_AVX2_FP8_SUPPORT = AVX2FP8_MOE is not None
+_HAS_AVX2_GPTQ_INT4_SUPPORT = AVX2GPTQInt4_MOE is not None
 
 
 class AMXMoEWrapper(BaseMoEWrapper):
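The getattr(..., None) probing above is what makes AVX2-only builds work: a build compiled without AVX512 simply does not export the AMX classes, so the missing bindings resolve to None instead of raising at import time. A small diagnostic sketch (not part of the patch) to list what a given build exposes:

    import kt_kernel_ext.moe as _moe_mod

    # Probe the optional MoE bindings the same way the wrapper does.
    for name in ("AMXBF16_MOE", "AMXFP8_MOE", "AVX2BF16_MOE",
                 "AVX2FP8_MOE", "AVX2GPTQInt4_MOE"):
        status = "available" if getattr(_moe_mod, name, None) is not None else "missing"
        print(f"{name}: {status}")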
@@ -346,10 +352,11 @@ class NativeMoEWrapper(BaseMoEWrapper):
                 " - AVX512F + AVX512BW (VNNI optional)\n"
                 "Please recompile kt_kernel_ext with AVX512 enabled."
             )
-        if method == "FP8" and not _HAS_FP8_SUPPORT:
+        if method == "FP8" and not _HAS_FP8_SUPPORT and not _HAS_AVX2_FP8_SUPPORT:
             raise RuntimeError(
                 "FP8 backend not available. Required ISA:\n"
-                " - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI\n"
+                " - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI (for AMX), or\n"
+                " - AVX2 + FMA (for AVX2 fallback)\n"
                 "Please recompile kt_kernel_ext with AVX512 + BF16 + VBMI enabled."
             )
         if method == "FP8_PERCHANNEL" and not _HAS_FP8_PERCHANNEL_SUPPORT:
@@ -358,11 +365,17 @@ class NativeMoEWrapper(BaseMoEWrapper):
                 " - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI\n"
                 "Please recompile kt_kernel_ext with AVX512 + BF16 + VBMI enabled."
             )
-        if method == "BF16" and not _HAS_BF16_SUPPORT:
+        if method == "BF16" and not _HAS_BF16_SUPPORT and not _HAS_AVX2_BF16_SUPPORT:
             raise RuntimeError(
                 "BF16 backend not available. Required ISA:\n"
-                " - AVX512F + AVX512BW + AVX512_BF16\n"
-                "Please recompile kt_kernel_ext with AVX512 + BF16 enabled."
+                " - AVX512F + AVX512BW + AVX512_BF16 (for AMX backend), or\n"
+                " - AVX2 + FMA (for AVX2 fallback backend)\n"
+                "Please recompile kt_kernel_ext with AVX512+BF16 or AVX2 enabled."
             )
+        if method == "GPTQ_INT4" and not _HAS_AVX2_GPTQ_INT4_SUPPORT:
+            raise RuntimeError(
+                "GPTQ_INT4 backend not available.\n"
+                "Please recompile kt_kernel_ext with AVX2 enabled."
+            )
 
         super().__init__(
@@ -391,6 +404,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
                 NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path, scale_suffix="weight_scale")
             elif method == "BF16":
                 NativeMoEWrapper._native_loader_instance = BF16SafeTensorLoader(weight_path)
+            elif method == "GPTQ_INT4":
+                NativeMoEWrapper._native_loader_instance = GPTQSafeTensorLoader(weight_path)
             else:
                 raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
         self.loader = NativeMoEWrapper._native_loader_instance
@@ -506,15 +521,31 @@ class NativeMoEWrapper(BaseMoEWrapper):
             moe_config.quant_config.bits = 8
             moe_config.quant_config.group_size = 128
             moe_config.quant_config.zero_point = False
-            self.moe = AMXFP8_MOE(moe_config)
+            if _HAS_FP8_SUPPORT:
+                self.moe = AMXFP8_MOE(moe_config)
+            else:
+                self.moe = AVX2FP8_MOE(moe_config)
         elif self.method == "FP8_PERCHANNEL":
             moe_config.quant_config.bits = 8
             moe_config.quant_config.per_channel = True
             moe_config.quant_config.zero_point = False
             self.moe = AMXFP8PerChannel_MOE(moe_config)
+        elif self.method == "GPTQ_INT4":
+            # GPTQ symmetric INT4: qweight (int32) + scales (fp32)
+            num_groups = self.gate_scales[0].shape[0]  # scales shape [K/gs, N]; first dim = number of groups
+            # group_size = hidden_size / num_groups
+            actual_gs = self.hidden_size // num_groups
+            moe_config.quant_config.bits = 4
+            moe_config.quant_config.group_size = actual_gs
+            moe_config.quant_config.zero_point = False
+            self.moe = AVX2GPTQInt4_MOE(moe_config)
         elif self.method == "BF16":
             # BF16 needs no quantization config
-            self.moe = AMXBF16_MOE(moe_config)
+            # Prefer the AMX backend, fall back to AVX2
+            if _HAS_BF16_SUPPORT:
+                self.moe = AMXBF16_MOE(moe_config)
+            else:
+                self.moe = AVX2BF16_MOE(moe_config)
         t4 = time.time()
 
         self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
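A worked example of the group-size arithmetic in the GPTQ_INT4 branch above (numbers illustrative, not from the patch): the scales tensor is laid out [K/gs, N], so its first dimension is the group count, and dividing hidden_size by it recovers the GPTQ group size.

    hidden_size = 7168                      # K (illustrative)
    scales_shape = (56, 2048)               # [K // gs, N] for gs = 128
    num_groups = scales_shape[0]            # 56 groups
    actual_gs = hidden_size // num_groups   # 7168 // 56 == 128
    assert actual_gs == 128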
@@ -961,3 +961,120 @@ class GGUFLoader:
         data = torch.from_numpy(np.frombuffer(data_bytes, dtype=np.uint8).copy())
 
         return data, ggml_type
+
+
+class GPTQSafeTensorLoader(FP8SafeTensorLoader):
+    """Loader for symmetric GPTQ-Int4 expert weights (qweight + scales, no qzeros).
+
+    Only supports sym=true, desc_act=false GPTQ models.
+
+    Tensor keys:
+    - qweight: {prefix}.{id}.{proj}.qweight (int32, packed 8x4-bit along K)
+    - scales:  {prefix}.{id}.{proj}.scales (fp16 -> converted to fp32)
+    """
+
+    def __init__(self, file_path: str):
+        # Call FP8SafeTensorLoader init (which runs SafeTensorLoader init + format detection)
+        super().__init__(file_path, scale_suffix="scales")
+        # Verify the GPTQ config
+        self._verify_gptq_config(file_path)
+
+    def _detect_format(self):
+        """Override FP8 format detection to look for .qweight instead of .weight."""
+        sample_keys = list(self.tensor_file_map.keys())[:2000]
+
+        for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items():
+            for key in sample_keys:
+                if ".experts." in key and f".{gate}.qweight" in key:
+                    if "block_sparse_moe.experts" in key and fmt_name == "mixtral":
+                        self._detected_format = fmt_name
+                        break
+                    elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek":
+                        self._detected_format = fmt_name
+                        # Check for a VL model (language_model prefix)
+                        if "language_model." in key:
+                            self._is_vl_model = True
+                        break
+                    elif fmt_name == "mistral" and "block_sparse_moe" not in key and "mlp" not in key:
+                        self._detected_format = fmt_name
+                        break
+            if self._detected_format is not None:
+                break
+
+        if self._detected_format is None:
+            self._detected_format = "deepseek"
+
+        vl_str = " (VL model)" if self._is_vl_model else ""
+        print(f"[GPTQSafeTensorLoader] Detected format: {self._detected_format}{vl_str}")
+
+    def _verify_gptq_config(self, file_path):
+        """Check that the model uses sym=true, desc_act=false."""
+        import json
+        import os
+
+        config_path = os.path.join(os.path.dirname(file_path), "config.json")
+        if not os.path.exists(config_path):
+            # Fall back to treating file_path itself as the model directory
+            config_path = os.path.join(file_path, "config.json")
+        if os.path.exists(config_path):
+            with open(config_path) as f:
+                config = json.load(f)
+            qc = config.get("quantization_config", {})
+            if qc.get("quant_method") == "gptq":
+                if qc.get("desc_act", False):
+                    raise NotImplementedError(
+                        "GPTQ desc_act=true is not supported. Only desc_act=false models are supported."
+                    )
+                if not qc.get("sym", True):
+                    raise NotImplementedError(
+                        "GPTQ sym=false (asymmetric) is not supported. Only sym=true models are supported."
+                    )
+                print(f"[GPTQSafeTensorLoader] Verified: sym={qc.get('sym')}, desc_act={qc.get('desc_act')}, "
+                      f"bits={qc.get('bits')}, group_size={qc.get('group_size')}")
+
+    def load_experts(self, base_key: str, device: str = "cpu"):
+        """Load GPTQ expert qweight and scales.
+
+        Returns a dict with keys: gate, up, down (qweight int32) and gate_scale, up_scale, down_scale (fp32).
+        """
+        experts_prefix_candidates = self._get_experts_prefix_candidates(base_key)
+        gate_name, up_name, down_name = self._get_proj_names()
+
+        expert_count = 0
+        experts_prefix = None
+        for prefix in experts_prefix_candidates:
+            expert_count = 0
+            while self.has_tensor(f"{prefix}.{expert_count}.{gate_name}.qweight"):
+                expert_count += 1
+            if expert_count > 0:
+                experts_prefix = prefix
+                break
+
+        if expert_count == 0 or experts_prefix is None:
+            raise ValueError(f"No GPTQ experts found for keys: {experts_prefix_candidates}")
+
+        gate_weights = [None] * expert_count
+        up_weights = [None] * expert_count
+        down_weights = [None] * expert_count
+        gate_scales = [None] * expert_count
+        up_scales = [None] * expert_count
+        down_scales = [None] * expert_count
+
+        for exp_id in range(expert_count):
+            gate_weights[exp_id] = self.load_tensor(f"{experts_prefix}.{exp_id}.{gate_name}.qweight", device).contiguous()
+            up_weights[exp_id] = self.load_tensor(f"{experts_prefix}.{exp_id}.{up_name}.qweight", device).contiguous()
+            down_weights[exp_id] = self.load_tensor(f"{experts_prefix}.{exp_id}.{down_name}.qweight", device).contiguous()
+
+            gate_scales[exp_id] = self.load_tensor(f"{experts_prefix}.{exp_id}.{gate_name}.scales", device).float().contiguous()
+            up_scales[exp_id] = self.load_tensor(f"{experts_prefix}.{exp_id}.{up_name}.scales", device).float().contiguous()
+            down_scales[exp_id] = self.load_tensor(f"{experts_prefix}.{exp_id}.{down_name}.scales", device).float().contiguous()
+
+        print(f"[GPTQSafeTensorLoader] Loaded {expert_count} experts from {experts_prefix}")
+        return {
+            "gate": gate_weights,
+            "up": up_weights,
+            "down": down_weights,
+            "gate_scale": gate_scales,
+            "up_scale": up_scales,
+            "down_scale": down_scales,
+        }
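A hedged usage sketch for the new loader (the directory path and the expert base key are illustrative, not from the patch; the real key template comes from MOE_FORMATS and the detected format):

    # Directory containing *.safetensors shards plus a config.json with quantization_config.
    loader = GPTQSafeTensorLoader("/models/DeepSeek-V3-GPTQ-Int4")
    experts = loader.load_experts("model.layers.3.mlp.experts")  # hypothetical base key
    qweight0 = experts["gate"][0]       # int32, eight 4-bit values per word along K
    scales0 = experts["gate_scale"][0]  # fp32, one scale per (group, output channel)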
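The loader's docstring notes that qweight packs eight 4-bit values per int32 along K. A minimal dequantization sketch for the symmetric case, assuming the common AutoGPTQ layout (low nibble first, implicit zero point of 8); the AVX2 kernel's actual unpack order is an implementation detail and may differ:

    import torch

    def dequant_gptq_int4_sym(qweight: torch.Tensor, scales: torch.Tensor, group_size: int) -> torch.Tensor:
        """qweight: [K // 8, N] int32; scales: [K // group_size, N] fp32 -> returns [K, N] fp32."""
        k8, n = qweight.shape
        shifts = torch.arange(0, 32, 4, dtype=torch.int32).view(1, 8, 1)  # one shift per nibble
        nibbles = (qweight.unsqueeze(1) >> shifts) & 0xF                  # [K // 8, 8, N]
        w = nibbles.reshape(k8 * 8, n).float() - 8.0                      # symmetric zero point = 8
        return w * scales.repeat_interleave(group_size, dim=0)            # per-group scaling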