support AMX

chenht2022 2025-04-25 14:47:16 +00:00
parent b90362b5e6
commit f3d842a0ca
15 changed files with 1799 additions and 62 deletions


@@ -25,8 +25,9 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
 import cpuinfer_ext
 from cpuinfer_ext.moe import MOEConfig, MOE
+from cpuinfer_ext.moe import AMX_MOEConfig, AMXBF16_MOE, AMXInt8_MOE
 import ctypes
-from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.custom_gguf import GGMLQuantizationType, GGUFLoader
 from ktransformers.util.utils import InferenceState
 from ktransformers.server.config.config import Config
 from transformers.activations import ACT2FN
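The import hunk above adds the AMX entry points (AMX_MOEConfig, AMXBF16_MOE, AMXInt8_MOE) from cpuinfer_ext.moe and pulls in GGMLQuantizationType, which the load path below uses to check that the expert tensors are stored as BF16. The new symbols are imported unconditionally; a minimal defensive variant might look like the sketch below (illustrative only, not part of this commit, and it assumes a cpuinfer_ext build without AMX support simply omits these classes):

```python
# Hypothetical guard, not in this commit: fall back to the llamafile-only
# backend when the cpuinfer_ext build does not expose the AMX classes.
try:
    from cpuinfer_ext.moe import AMX_MOEConfig, AMXBF16_MOE, AMXInt8_MOE
    HAS_AMX = True
except ImportError:
    HAS_AMX = False
```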
@@ -141,6 +142,7 @@ class KExpertsCPU(KExpertsBase):
         assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU"
         self.n_routed_experts = n_routed_experts
         self.out_device = out_device
+        self.backend = kwargs.get("backend", "llamafile")
 
     def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = None, warmup:bool = False):
         if device:
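The new self.backend attribute is read from **kwargs and defaults to "llamafile", so existing setups keep the previous behavior unless a backend is requested explicitly. A hedged sketch of opting into an AMX backend at construction time (the positional arguments are placeholders; the full KExpertsCPU signature is not shown in this diff):

```python
# Illustrative only: the exact KExpertsCPU constructor is not part of this hunk.
# The relevant point is that `backend` is forwarded through **kwargs.
experts = KExpertsCPU(
    key, gguf_loader, config, n_routed_experts,  # placeholders
    backend="AMXInt8",  # or "AMXBF16"; omit to keep the default "llamafile"
)
```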
@@ -163,27 +165,62 @@ class KExpertsCPU(KExpertsBase):
         )
         # print(self.gate_qtype, self.up_qtype, self.down_qtype)
         n_routed_experts = self.n_routed_experts
+        self.cpu_infer = KExpertsCPU.CPU_INFER
         # n_routed_experts = len(self.orig_module)
-        moe_config = MOEConfig(
-            n_routed_experts,
-            self.config.num_experts_per_tok,
-            self.config.hidden_size,
-            self.config.moe_intermediate_size,
-            64,
-            10,
-            1024,
-            gate_ptr,
-            up_ptr,
-            down_ptr,
-            self.gate_type,
-            self.up_type,
-            self.down_type,
-            30, # TODO: get from model.dtype
-        )
+        if self.backend == "llamafile":
+            moe_config = MOEConfig(
+                n_routed_experts,
+                self.config.num_experts_per_tok,
+                self.config.hidden_size,
+                self.config.moe_intermediate_size,
+                64,
+                10,
+                1024,
+                gate_ptr,
+                up_ptr,
+                down_ptr,
+                self.gate_type,
+                self.up_type,
+                self.down_type,
+                30, # TODO: get from model.dtype
+            )
+            self.moe = MOE(moe_config)
+        elif self.backend == "AMXBF16":
+            assert self.gate_type == GGMLQuantizationType.BF16
+            assert self.up_type == GGMLQuantizationType.BF16
+            assert self.down_type == GGMLQuantizationType.BF16
+            moe_config = AMX_MOEConfig(
+                n_routed_experts,
+                self.config.num_experts_per_tok,
+                self.config.hidden_size,
+                self.config.moe_intermediate_size,
+                25600,
+                gate_ptr,
+                up_ptr,
+                down_ptr,
+            )
+            self.moe = AMXBF16_MOE(moe_config)
+            self.cpu_infer.submit(self.moe.load_weights())
+            self.cpu_infer.sync()
+        elif self.backend == "AMXInt8":
+            assert self.gate_type == GGMLQuantizationType.BF16
+            assert self.up_type == GGMLQuantizationType.BF16
+            assert self.down_type == GGMLQuantizationType.BF16
+            moe_config = AMX_MOEConfig(
+                n_routed_experts,
+                self.config.num_experts_per_tok,
+                self.config.hidden_size,
+                self.config.moe_intermediate_size,
+                25600,
+                gate_ptr,
+                up_ptr,
+                down_ptr,
+            )
+            self.moe = AMXInt8_MOE(moe_config)
+            self.cpu_infer.submit(self.moe.load_weights())
+            self.cpu_infer.sync()
         # print(n_routed_experts, hidden_size, moe_intermediate_size)
         num_experts_per_tok = self.config.num_experts_per_tok
-        self.moe = MOE(moe_config)
-        self.cpu_infer = KExpertsCPU.CPU_INFER
         if warmup:
             self.cpu_infer.submit(self.moe.warm_up())
             self.cpu_infer.sync()
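Taken together, KExpertsCPU.load() now dispatches on self.backend: the llamafile branch keeps the original MOEConfig/MOE construction, while both AMX branches require BF16 GGUF tensors, build an AMX_MOEConfig, and immediately run load_weights() on the shared CPUInfer instance before any warm-up. A condensed sketch of the resulting control flow (argument values copied from the hunk above; the two AMX branches are folded into one for brevity):

```python
# Condensed view of the new dispatch inside KExpertsCPU.load(); not standalone code,
# it reuses the pointers and types prepared earlier in the method.
if self.backend == "llamafile":
    self.moe = MOE(MOEConfig(
        n_routed_experts, self.config.num_experts_per_tok,
        self.config.hidden_size, self.config.moe_intermediate_size,
        64, 10, 1024, gate_ptr, up_ptr, down_ptr,
        self.gate_type, self.up_type, self.down_type, 30,
    ))
elif self.backend in ("AMXBF16", "AMXInt8"):
    # AMX kernels consume BF16 expert weights from the GGUF.
    assert self.gate_type == self.up_type == self.down_type == GGMLQuantizationType.BF16
    moe_config = AMX_MOEConfig(
        n_routed_experts, self.config.num_experts_per_tok,
        self.config.hidden_size, self.config.moe_intermediate_size,
        25600, gate_ptr, up_ptr, down_ptr,
    )
    moe_cls = AMXBF16_MOE if self.backend == "AMXBF16" else AMXInt8_MOE
    self.moe = moe_cls(moe_config)
    self.cpu_infer.submit(self.moe.load_weights())  # prepare weights up front
    self.cpu_infer.sync()
```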