Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2026-04-28 11:49:51 +00:00
support GLM 4.7 (#1791)
Some checks failed
Book-CI / test-2 (push) Has been cancelled
Book-CI / test (push) Has been cancelled
Book-CI / test-1 (push) Has been cancelled
Deploy / deploy (macos-latest) (push) Has been cancelled
Deploy / deploy (ubuntu-latest) (push) Has been cancelled
Deploy / deploy (windows-latest) (push) Has been cancelled
support GLM 4.7
This commit is contained in:
parent 667030d6e6
commit 6277da4c2b
14 changed files with 2336 additions and 144 deletions
@@ -15,6 +15,14 @@ except (ImportError, AttributeError):
     _HAS_AMX_SUPPORT = False
     AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE, AMXBF16_MOE = None, None, None, None, None
 
+try:
+    from kt_kernel_ext.moe import AMXFP8PerChannel_MOE
+
+    _HAS_FP8_PERCHANNEL_SUPPORT = True
+except (ImportError, AttributeError):
+    _HAS_FP8_PERCHANNEL_SUPPORT = False
+    AMXFP8PerChannel_MOE = None
+
 from typing import Optional
 
 
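The new import follows the same optional-dependency pattern as the existing AMX kernels: attempt the import, record availability in a module-level flag, and bind the class name to None on failure so later `is None` checks stay valid. A minimal standalone sketch of the pattern (only the import and flag mirror the diff; the prints are illustrative):

try:
    from kt_kernel_ext.moe import AMXFP8PerChannel_MOE
    _HAS_FP8_PERCHANNEL_SUPPORT = True
except (ImportError, AttributeError):
    # AttributeError additionally guards against a kt_kernel_ext build
    # that is present but missing this kernel symbol.
    _HAS_FP8_PERCHANNEL_SUPPORT = False
    AMXFP8PerChannel_MOE = None

if _HAS_FP8_PERCHANNEL_SUPPORT:
    print("per-channel FP8 AMX kernel available")
else:
    print("per-channel FP8 AMX kernel unavailable in this build")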
@@ -304,7 +312,7 @@ class AMXMoEWrapper(BaseMoEWrapper):
 
 
 class NativeMoEWrapper(BaseMoEWrapper):
-    """Wrapper for RAWINT4/FP8/BF16 experts stored in compressed SafeTensor format."""
+    """Wrapper for RAWINT4/FP8/FP8_PERCHANNEL/BF16 experts stored in compressed SafeTensor format."""
 
     _native_loader_instance = None
 
@@ -330,6 +338,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
             raise RuntimeError("AMX backend with RAWINT4 support is not available.")
         if method == "FP8" and AMXFP8_MOE is None:
             raise RuntimeError("AMX backend with FP8 support is not available.")
+        if method == "FP8_PERCHANNEL" and not _HAS_FP8_PERCHANNEL_SUPPORT:
+            raise RuntimeError("AMX backend with FP8 per-channel support is not available.")
         if method == "BF16" and AMXBF16_MOE is None:
             raise RuntimeError("AMX backend with BF16 support is not available.")
 
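These constructor-time checks turn a missing optional kernel into an immediate, descriptive RuntimeError instead of a confusing failure later when a None class would be called. A standalone sketch of the behavior, with make_moe as a hypothetical stand-in for the wrapper's guard (not the actual ktransformers API):

_HAS_FP8_PERCHANNEL_SUPPORT = False  # simulate a build without the kernel

def make_moe(method: str) -> str:
    # Hypothetical stand-in for NativeMoEWrapper's constructor-time guard.
    if method == "FP8_PERCHANNEL" and not _HAS_FP8_PERCHANNEL_SUPPORT:
        raise RuntimeError("AMX backend with FP8 per-channel support is not available.")
    return "backend-ready"

try:
    make_moe("FP8_PERCHANNEL")
except RuntimeError as err:
    print(err)  # AMX backend with FP8 per-channel support is not available.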
@@ -354,6 +364,9 @@ class NativeMoEWrapper(BaseMoEWrapper):
                 NativeMoEWrapper._native_loader_instance = CompressedSafeTensorLoader(weight_path)
             elif method == "FP8":
                 NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path)
+            elif method == "FP8_PERCHANNEL":
+                # Use FP8SafeTensorLoader with per-channel scale format
+                NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path, scale_suffix="weight_scale")
             elif method == "BF16":
                 NativeMoEWrapper._native_loader_instance = BF16SafeTensorLoader(weight_path)
             else:
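Rather than adding a new loader class, FP8_PERCHANNEL reuses FP8SafeTensorLoader and only changes where scales are read from, via scale_suffix="weight_scale". A hedged sketch of how such a suffix parameter could resolve scale tensor keys (the key layout, default suffix, and resolver class are assumptions for illustration, not FP8SafeTensorLoader internals):

class ScaleKeyResolver:
    """Illustrative only: derive a scale tensor key from a weight key."""

    def __init__(self, scale_suffix: str = "weight_scale_inv"):  # default is a guess
        self.scale_suffix = scale_suffix

    def scale_key(self, weight_key: str) -> str:
        # "experts.0.gate_proj.weight" -> "experts.0.gate_proj.weight_scale"
        base = weight_key.rsplit(".", 1)[0]
        return f"{base}.{self.scale_suffix}"

resolver = ScaleKeyResolver(scale_suffix="weight_scale")
print(resolver.scale_key("experts.0.gate_proj.weight"))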
@@ -408,6 +421,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
             assert self.gate_scales[0].dtype == torch.bfloat16, "Expected bf16 scales for RAWINT4"
         elif self.method == "FP8":
             assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8"
+        elif self.method == "FP8_PERCHANNEL":
+            assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8_PERCHANNEL"
 
         t2 = time.time()
 
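The new branch pins per-channel scales to float32, the same dtype the grouped FP8 path expects. For context, here is a generic sketch of per-output-channel FP8-style quantization in plain PyTorch (not kt_kernel_ext code); 448.0 is the maximum normal value of FP8 E4M3, and the scales naturally come out as one float32 per channel:

import torch

w = torch.randn(8, 16, dtype=torch.bfloat16)       # [out_channels, in_features]
amax = w.abs().amax(dim=1, keepdim=True).float()   # per-output-channel max magnitude
scales = amax / 448.0                              # FP8 E4M3 max normal value is 448
q = (w.float() / scales).clamp(-448.0, 448.0)      # values that would be stored as FP8
deq = q * scales                                   # dequantize before/inside the matmul
assert scales.dtype == torch.float32               # matches the wrapper's assertion
print(scales.shape)                                # torch.Size([8, 1])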
@@ -462,6 +477,11 @@ class NativeMoEWrapper(BaseMoEWrapper):
             moe_config.quant_config.group_size = 128
             moe_config.quant_config.zero_point = False
             self.moe = AMXFP8_MOE(moe_config)
+        elif self.method == "FP8_PERCHANNEL":
+            moe_config.quant_config.bits = 8
+            moe_config.quant_config.per_channel = True
+            moe_config.quant_config.zero_point = False
+            self.moe = AMXFP8PerChannel_MOE(moe_config)
         elif self.method == "BF16":
             # BF16 has no quantization config needed
             self.moe = AMXBF16_MOE(moe_config)
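Note the config difference from the grouped FP8 path: FP8 quantizes along the input dimension in groups of 128, while FP8_PERCHANNEL sets per_channel = True and needs no group_size because there is exactly one scale per output channel. A quick sketch of the resulting scale-tensor shapes under the standard layouts for these schemes (the kernel's internal layout may differ):

out_features, in_features, group_size = 4096, 11008, 128

# Grouped FP8 (method == "FP8"): one scale per 128-wide input group per row.
per_group_scale_shape = (out_features, in_features // group_size)

# Per-channel FP8 (method == "FP8_PERCHANNEL"): one scale per output channel.
per_channel_scale_shape = (out_features, 1)

print(per_group_scale_shape)    # (4096, 86)
print(per_channel_scale_shape)  # (4096, 1)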