mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 15:29:39 +00:00
support AMX
This commit is contained in:
parent b90362b5e6
commit f3d842a0ca
15 changed files with 1799 additions and 62 deletions
@@ -25,8 +25,9 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
 import cpuinfer_ext
 from cpuinfer_ext.moe import MOEConfig, MOE
+from cpuinfer_ext.moe import AMX_MOEConfig, AMXBF16_MOE, AMXInt8_MOE
 import ctypes
-from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.custom_gguf import GGMLQuantizationType, GGUFLoader
 from ktransformers.util.utils import InferenceState
 from ktransformers.server.config.config import Config
 from transformers.activations import ACT2FN
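The new imports pull in the AMX kernel classes plus GGMLQuantizationType, which the load() changes below use to require that the expert tensors are stored as BF16 in the GGUF file before they are handed to the AMX kernels. A minimal sketch of that guard, factored into a helper for illustration (the helper name is hypothetical; only the enum and the BF16 requirement come from this commit):

```python
from ktransformers.util.custom_gguf import GGMLQuantizationType

def require_bf16(gate_type: int, up_type: int, down_type: int) -> None:
    # Mirrors the asserts in the AMXBF16/AMXInt8 branches of KExpertsCPU.load().
    for name, t in (("gate", gate_type), ("up", up_type), ("down", down_type)):
        if t != GGMLQuantizationType.BF16:
            raise AssertionError(f"{name} tensors must be BF16 for the AMX backends, got {t}")
```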
@@ -141,6 +142,7 @@ class KExpertsCPU(KExpertsBase):
         assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU"
         self.n_routed_experts = n_routed_experts
         self.out_device = out_device
+        self.backend = kwargs.get("backend", "llamafile")

     def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = None, warmup:bool = False):
         if device:
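The constructor change above is what exposes the new kernels: the backend string is read from **kwargs and defaults to "llamafile", so existing configurations keep the old path. A self-contained sketch of the selection (the accepted values and the default come from this commit; the standalone helper is illustrative):

```python
# Valid values per this commit: "llamafile" (default), "AMXBF16", "AMXInt8".
_BACKENDS = ("llamafile", "AMXBF16", "AMXInt8")

def pick_backend(**kwargs) -> str:
    backend = kwargs.get("backend", "llamafile")  # same default as KExpertsCPU.__init__
    if backend not in _BACKENDS:
        raise ValueError(f"unknown KExpertsCPU backend: {backend!r}")
    return backend

assert pick_backend() == "llamafile"
assert pick_backend(backend="AMXInt8") == "AMXInt8"
```

Note that the module itself does not validate the string up front; an unrecognized value simply falls through the if/elif chain in load() shown in the next hunk.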
@@ -163,27 +165,62 @@
         )
         # print(self.gate_qtype, self.up_qtype, self.down_qtype)
         n_routed_experts = self.n_routed_experts
+        self.cpu_infer = KExpertsCPU.CPU_INFER
         # n_routed_experts = len(self.orig_module)
-        moe_config = MOEConfig(
-            n_routed_experts,
-            self.config.num_experts_per_tok,
-            self.config.hidden_size,
-            self.config.moe_intermediate_size,
-            64,
-            10,
-            1024,
-            gate_ptr,
-            up_ptr,
-            down_ptr,
-            self.gate_type,
-            self.up_type,
-            self.down_type,
-            30, # TODO: get from model.dtype
-        )
+        if self.backend == "llamafile":
+            moe_config = MOEConfig(
+                n_routed_experts,
+                self.config.num_experts_per_tok,
+                self.config.hidden_size,
+                self.config.moe_intermediate_size,
+                64,
+                10,
+                1024,
+                gate_ptr,
+                up_ptr,
+                down_ptr,
+                self.gate_type,
+                self.up_type,
+                self.down_type,
+                30, # TODO: get from model.dtype
+            )
+            self.moe = MOE(moe_config)
+        elif self.backend == "AMXBF16":
+            assert self.gate_type == GGMLQuantizationType.BF16
+            assert self.up_type == GGMLQuantizationType.BF16
+            assert self.down_type == GGMLQuantizationType.BF16
+            moe_config = AMX_MOEConfig(
+                n_routed_experts,
+                self.config.num_experts_per_tok,
+                self.config.hidden_size,
+                self.config.moe_intermediate_size,
+                25600,
+                gate_ptr,
+                up_ptr,
+                down_ptr,
+            )
+            self.moe = AMXBF16_MOE(moe_config)
+            self.cpu_infer.submit(self.moe.load_weights())
+            self.cpu_infer.sync()
+        elif self.backend == "AMXInt8":
+            assert self.gate_type == GGMLQuantizationType.BF16
+            assert self.up_type == GGMLQuantizationType.BF16
+            assert self.down_type == GGMLQuantizationType.BF16
+            moe_config = AMX_MOEConfig(
+                n_routed_experts,
+                self.config.num_experts_per_tok,
+                self.config.hidden_size,
+                self.config.moe_intermediate_size,
+                25600,
+                gate_ptr,
+                up_ptr,
+                down_ptr,
+            )
+            self.moe = AMXInt8_MOE(moe_config)
+            self.cpu_infer.submit(self.moe.load_weights())
+            self.cpu_infer.sync()
         # print(n_routed_experts, hidden_size, moe_intermediate_size)
         num_experts_per_tok = self.config.num_experts_per_tok
-        self.moe = MOE(moe_config)
-        self.cpu_infer = KExpertsCPU.CPU_INFER
         if warmup:
             self.cpu_infer.submit(self.moe.warm_up())
             self.cpu_infer.sync()
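Both AMX branches build an AMX_MOEConfig and immediately repack the weights through the shared CPUInfer queue (submit() followed by sync()), while the llamafile branch does not call load_weights() at this point. A minimal end-to-end sketch of that lifecycle, assuming cpuinfer_ext.CPUInfer takes a worker-thread count as in the project's example scripts and that the config accepts raw integer addresses, as the gate_ptr/up_ptr/down_ptr arguments suggest; tensor shapes, sizes, and the thread count are placeholders, and warm_up() on the AMX classes is inferred from the unchanged warmup path above:

```python
import torch
import cpuinfer_ext
from cpuinfer_ext.moe import AMX_MOEConfig, AMXInt8_MOE

# Placeholder MoE dimensions; in KExpertsCPU these come from self.config.
n_experts, top_k, hidden, inter = 8, 2, 1024, 2048

# BF16 expert tensors, matching the asserts in the AMX branches.
gate = torch.zeros(n_experts, inter, hidden, dtype=torch.bfloat16)
up = torch.zeros(n_experts, inter, hidden, dtype=torch.bfloat16)
down = torch.zeros(n_experts, hidden, inter, dtype=torch.bfloat16)

cpu_infer = cpuinfer_ext.CPUInfer(8)  # worker-thread count is illustrative
moe_config = AMX_MOEConfig(
    n_experts, top_k, hidden, inter,
    25600,  # same constant the commit hardcodes
    gate.data_ptr(), up.data_ptr(), down.data_ptr(),
)
moe = AMXInt8_MOE(moe_config)

# cpuinfer_ext work items are queued with submit() and completed with sync().
cpu_infer.submit(moe.load_weights())  # repack/quantize the BF16 weights for AMX
cpu_infer.sync()
cpu_infer.submit(moe.warm_up())       # mirrors load(..., warmup=True)
cpu_infer.sync()
```

The eager load_weights() call is the main behavioural difference from the llamafile backend, which leaves the GGUF tensors to the existing code path.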