mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 13:55:27 +00:00
[feature] experts can be injected using CPUInfer
[fix] fix the ktransformers interface when using the new CUDAGraphRunner [fix] fix YAML and optimize logic; the top rule has the highest priority
This commit is contained in:
parent
80815dbc50
commit
412055d450
13 changed files with 318 additions and 158 deletions
|
@ -33,6 +33,7 @@ from transformers.configuration_utils import PretrainedConfig
|
|||
from abc import ABC, abstractmethod
|
||||
from ktransformers.operators.linear import QuantizedLinearMarlin, QuantizedLinearTorch, KTransformerLinear
|
||||
import time
|
||||
from ktransformers.operators.cpuinfer import CPUInfer
|
||||
|
||||
|
||||
# class Base(BaseInjectedModule, ABC):
|
||||
|
@ -117,7 +118,7 @@ class MLPCPUExperts(MLPExpertsBase):
|
|||
output_cpu:Tensor = None
|
||||
output_gpu_map:dict = {} # Manage output tensor buffer on different gpu
|
||||
#stream_map:dict = {} # Manage cuda stream on different gpu
|
||||
CPU_INFER = cpuinfer_ext.CPUInfer(Config().cpu_infer)
|
||||
CPU_INFER = CPUInfer(Config().cpu_infer)
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
|
@ -126,7 +127,7 @@ class MLPCPUExperts(MLPExpertsBase):
|
|||
n_routed_experts: int,
|
||||
orig_module: nn.Module = None,
|
||||
device: str = "cpu",
|
||||
out_device: str = "cuda", # this device mean which device the output should on
|
||||
out_device: str = "cuda", # this device mean which device the output should on. TODO: support cpu.
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
|
||||
|
@ -135,51 +136,50 @@ class MLPCPUExperts(MLPExpertsBase):
|
|||
self.out_device = out_device
|
||||
|
||||
def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = None, warmup:bool = False):
|
||||
with torch.device(self.out_device):
|
||||
if device:
|
||||
assert device.lower() == "cpu", "MLPCPUExperts can only be loaded on CPU, Parameter \"device\" can be cpu or None."
|
||||
if w is None: w = self.load_weights()[self.key]
|
||||
self.gate = w["gate"]
|
||||
self.up = w["up"]
|
||||
self.down = w["down"]
|
||||
self.gate_type = w["gate_type"]
|
||||
self.up_type = w["up_type"]
|
||||
self.down_type = w["down_type"]
|
||||
gate_ptr = ctypes.addressof(
|
||||
ctypes.cast(self.gate.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
|
||||
)
|
||||
up_ptr = ctypes.addressof(
|
||||
ctypes.cast(self.up.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
|
||||
)
|
||||
down_ptr = ctypes.addressof(
|
||||
ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
|
||||
)
|
||||
# print(self.gate_qtype, self.up_qtype, self.down_qtype)
|
||||
n_routed_experts = self.n_routed_experts
|
||||
# n_routed_experts = len(self.orig_module)
|
||||
moe_config = MOEConfig(
|
||||
n_routed_experts,
|
||||
self.config.num_experts_per_tok,
|
||||
self.config.hidden_size,
|
||||
self.config.moe_intermediate_size,
|
||||
64,
|
||||
10,
|
||||
1024,
|
||||
gate_ptr,
|
||||
up_ptr,
|
||||
down_ptr,
|
||||
self.gate_type,
|
||||
self.up_type,
|
||||
self.down_type,
|
||||
30, # TODO: get from model.dtype
|
||||
)
|
||||
# print(n_routed_experts, hidden_size, moe_intermediate_size)
|
||||
num_experts_per_tok = self.config.num_experts_per_tok
|
||||
self.moe = MOE(moe_config)
|
||||
self.cpu_infer = MLPCPUExperts.CPU_INFER
|
||||
if warmup:
|
||||
self.cpu_infer.submit(self.moe.warm_up())
|
||||
self.cpu_infer.sync()
|
||||
if device:
|
||||
assert device.lower() == "cpu", "MLPCPUExperts can only be loaded on CPU, Parameter \"device\" can be cpu or None."
|
||||
if w is None: w = self.load_weights()[self.key]
|
||||
self.gate = w["gate"]
|
||||
self.up = w["up"]
|
||||
self.down = w["down"]
|
||||
self.gate_type = w["gate_type"]
|
||||
self.up_type = w["up_type"]
|
||||
self.down_type = w["down_type"]
|
||||
gate_ptr = ctypes.addressof(
|
||||
ctypes.cast(self.gate.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
|
||||
)
|
||||
up_ptr = ctypes.addressof(
|
||||
ctypes.cast(self.up.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
|
||||
)
|
||||
down_ptr = ctypes.addressof(
|
||||
ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
|
||||
)
|
||||
# print(self.gate_qtype, self.up_qtype, self.down_qtype)
|
||||
n_routed_experts = self.n_routed_experts
|
||||
# n_routed_experts = len(self.orig_module)
|
||||
moe_config = MOEConfig(
|
||||
n_routed_experts,
|
||||
self.config.num_experts_per_tok,
|
||||
self.config.hidden_size,
|
||||
self.config.moe_intermediate_size,
|
||||
64,
|
||||
10,
|
||||
1024,
|
||||
gate_ptr,
|
||||
up_ptr,
|
||||
down_ptr,
|
||||
self.gate_type,
|
||||
self.up_type,
|
||||
self.down_type,
|
||||
30, # TODO: get from model.dtype
|
||||
)
|
||||
# print(n_routed_experts, hidden_size, moe_intermediate_size)
|
||||
num_experts_per_tok = self.config.num_experts_per_tok
|
||||
self.moe = MOE(moe_config)
|
||||
self.cpu_infer = MLPCPUExperts.CPU_INFER
|
||||
if warmup:
|
||||
self.cpu_infer.submit(self.moe.warm_up())
|
||||
self.cpu_infer.sync()
|
||||
if self.out_device not in MLPCPUExperts.output_gpu_map:
|
||||
MLPCPUExperts.output_gpu_map[self.out_device] = torch.zeros((self.config.hidden_size), device=self.out_device)
|
||||
if MLPCPUExperts.input_tensor_cpu == None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue