Mirror of https://github.com/kvcache-ai/ktransformers.git
[fix] format classes and files name
commit 67043b4b5c
parent 1db4a67dca
15 changed files with 212 additions and 212 deletions
@@ -5,8 +5,8 @@ Description :
 Author : Azure-Tang, Boxin Zhang, chenht2022
 Date : 2024-07-25 11:25:24
 Version : 0.1.0
-LastEditors : kkk1nak0
-LastEditTime : 2024-08-11 12:14:39
+LastEditors : Azure
+LastEditTime : 2024-08-15 02:36:29
 Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 '''

@@ -31,13 +31,13 @@ from ktransformers.server.config.config import Config
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from abc import ABC, abstractmethod
-from ktransformers.operators.linear import QuantizedLinearMarlin, QuantizedLinearTorch, KTransformerLinear
+from ktransformers.operators.linear import KLinearMarlin, KLinearTorch, KTransformersLinear
 import time
 from ktransformers.operators.cpuinfer import CPUInfer


 # class Base(BaseInjectedModule, ABC):
-class MLPExpertsBase(ABC):
+class KExpertsBase(ABC):
     def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs):
         # super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
         self.key = key
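For orientation, a minimal sketch of the interface the renamed KExpertsBase is used with in the rest of this diff. The load/forward signatures are inferred from the KExpertsCPU and KExpertsMarlin subclasses below; they are not spelled out by this hunk, so treat this as an assumption rather than the repository's exact base class:

# Illustrative sketch only -- not the actual ktransformers source.
from abc import ABC, abstractmethod

class KExpertsBase(ABC):
    def __init__(self, key, gguf_loader, config, orig_module, device="cuda", **kwargs):
        self.key = key
        self.gguf_loader = gguf_loader   # assumed attribute, mirroring the subclasses
        self.config = config
        self.device = device

    @abstractmethod
    def load(self, w=None, device=None, warmup=False):
        """Materialize expert weights on the target device."""

    @abstractmethod
    def forward(self, input_tensor, expert_ids, weights):
        """Run the selected experts and return the combined output."""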
@@ -111,7 +111,7 @@ class MLPExpertsBase(ABC):
             tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device)
         return tensors

-class MLPCPUExperts(MLPExpertsBase):
+class KExpertsCPU(KExpertsBase):
     input_tensor_cpu:Tensor = None
     expert_ids_cpu:Tensor = None
     weights_cpu:Tensor = None
@@ -131,13 +131,13 @@ class MLPCPUExperts(MLPExpertsBase):
         **kwargs
     ):
         super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
-        assert device.lower() == "cpu", "MLPCPUExperts can only be loaded on CPU"
+        assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU"
         self.n_routed_experts = n_routed_experts
         self.out_device = out_device

     def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = None, warmup:bool = False):
         if device:
-            assert device.lower() == "cpu", "MLPCPUExperts can only be loaded on CPU, Parameter \"device\" can be cpu or None."
+            assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU, Parameter \"device\" can be cpu or None."
         if w is None: w = self.load_weights()[self.key]
         self.gate = w["gate"]
         self.up = w["up"]
@@ -176,28 +176,28 @@ class MLPCPUExperts(MLPExpertsBase):
         # print(n_routed_experts, hidden_size, moe_intermediate_size)
         num_experts_per_tok = self.config.num_experts_per_tok
         self.moe = MOE(moe_config)
-        self.cpu_infer = MLPCPUExperts.CPU_INFER
+        self.cpu_infer = KExpertsCPU.CPU_INFER
         if warmup:
             self.cpu_infer.submit(self.moe.warm_up())
             self.cpu_infer.sync()
-        if self.out_device not in MLPCPUExperts.output_gpu_map:
-            MLPCPUExperts.output_gpu_map[self.out_device] = torch.zeros((self.config.hidden_size), device=self.out_device)
-        if MLPCPUExperts.input_tensor_cpu == None:
-            MLPCPUExperts.input_tensor_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True)
-            MLPCPUExperts.expert_ids_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
-            MLPCPUExperts.weights_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
-            MLPCPUExperts.output_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
+        if self.out_device not in KExpertsCPU.output_gpu_map:
+            KExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((self.config.hidden_size), device=self.out_device)
+        if KExpertsCPU.input_tensor_cpu == None:
+            KExpertsCPU.input_tensor_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True)
+            KExpertsCPU.expert_ids_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
+            KExpertsCPU.weights_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
+            KExpertsCPU.output_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)

     def submit_for_one_decode(self, input_tensor, expert_ids, weights):
-        MLPCPUExperts.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
-        MLPCPUExperts.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
-        MLPCPUExperts.weights_cpu.copy_(weights, non_blocking=True)
-        self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(0), MLPCPUExperts.expert_ids_cpu.data_ptr(), MLPCPUExperts.weights_cpu.data_ptr(), MLPCPUExperts.input_tensor_cpu.data_ptr(), MLPCPUExperts.output_cpu.data_ptr()))
+        KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
+        KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
+        KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
+        self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(0), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr()))

     def sync_for_one_decode(self):
         self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
-        MLPCPUExperts.output_gpu_map[self.out_device].copy_(MLPCPUExperts.output_cpu, non_blocking=True)
-        return MLPCPUExperts.output_gpu_map[self.out_device]
+        KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
+        return KExpertsCPU.output_gpu_map[self.out_device]

     def forward(self, input_tensor, expert_ids, weights):
         # generate, capture and run cuda graph
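The single-token decode path above stages activations in class-level pinned host buffers so the GPU-to-CPU handoff can be driven from the current CUDA stream. A stripped-down sketch of that pattern in generic PyTorch follows; the backend object and its submit/sync methods are stand-ins, not the real CPUInfer/MOE API, and the buffer sizes are assumptions:

# Generic illustration of the pinned-buffer handoff; not the ktransformers API.
import torch

hidden_size, top_k = 4096, 6                       # assumed sizes
in_cpu  = torch.zeros(hidden_size, device="cpu", pin_memory=True)
ids_cpu = torch.zeros(top_k, dtype=torch.long, device="cpu", pin_memory=True)
w_cpu   = torch.zeros(top_k, dtype=torch.float32, device="cpu", pin_memory=True)
out_cpu = torch.zeros(hidden_size, dtype=torch.bfloat16, device="cpu", pin_memory=True)
out_gpu = torch.zeros(hidden_size, device="cuda")

def decode_one_token(backend, x_gpu, ids_gpu, weights_gpu):
    # 1. asynchronously copy activations and router outputs into pinned staging buffers
    in_cpu.copy_(x_gpu, non_blocking=True)
    ids_cpu.copy_(ids_gpu, non_blocking=True)
    w_cpu.copy_(weights_gpu, non_blocking=True)
    stream = torch.cuda.current_stream().cuda_stream
    # 2. enqueue the CPU expert computation against the same CUDA stream (stand-in call)
    backend.submit_with_cuda_stream(stream, in_cpu, ids_cpu, w_cpu, out_cpu)
    # 3. wait for the CPU side, then move the result back onto the GPU output buffer
    backend.sync_with_cuda_stream(stream)
    out_gpu.copy_(out_cpu, non_blocking=True)
    return out_gpu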
@@ -205,13 +205,13 @@ class MLPCPUExperts(MLPExpertsBase):
         if input_tensor.size(0)==1:
             # TODO: this branch is unreachable, but the shape of input_tensor([1,hidden_size]) and input_tensor_cpu([hidden_size]) is not compatible
             #print("capturing experts")
-            MLPCPUExperts.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
-            MLPCPUExperts.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
-            MLPCPUExperts.weights_cpu.copy_(weights, non_blocking=True)
-            self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(1, expert_ids.size(1), MLPCPUExperts.expert_ids_cpu.data_ptr(), MLPCPUExperts.weights_cpu.data_ptr(), MLPCPUExperts.input_tensor_cpu.data_ptr(), MLPCPUExperts.output_cpu.data_ptr()))
+            KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
+            KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
+            KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
+            self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(1, expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr()))
             self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
-            MLPCPUExperts.output_gpu_map[self.out_device].copy_(MLPCPUExperts.output_cpu, non_blocking=True)
-            return MLPCPUExperts.output_gpu_map[self.out_device]
+            KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
+            return KExpertsCPU.output_gpu_map[self.out_device]
         else:
             input_tensor = input_tensor.contiguous().cpu()
             expert_ids = expert_ids.contiguous().cpu()
@@ -269,7 +269,7 @@ class MLPCPUExperts(MLPExpertsBase):
         res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
         return res

-class MLPExpertsMarlin(MLPExpertsBase):
+class KExpertsMarlin(KExpertsBase):
     expert_num: int
     loaded_experts_idx: list[int]
     def __init__(
@@ -290,11 +290,11 @@ class MLPExpertsMarlin(MLPExpertsBase):
         self.device = device
         # create empty marlin experts according to the number of experts per token
         # up
-        self.up_projs = [QuantizedLinearMarlin(key+ "." + "ffn_up_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]
+        self.up_projs = [KLinearMarlin(key+ "." + "ffn_up_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]
         # gate
-        self.gate_projs = [QuantizedLinearMarlin(key+ "." + "ffn_gate_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]
+        self.gate_projs = [KLinearMarlin(key+ "." + "ffn_gate_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]
         # down
-        self.down_projs = [QuantizedLinearMarlin(key+ "." + "ffn_down_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]
+        self.down_projs = [KLinearMarlin(key+ "." + "ffn_down_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]

     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
         if device is None: device = self.device
@@ -359,7 +359,7 @@ class MLPExpertsMarlin(MLPExpertsBase):
         outs = outs.to(device)
         return outs

-class MLPExpertsTorch(MLPExpertsBase):
+class KExpertsTorch(KExpertsBase):
     expert_num: int
     loaded_experts_idx: list[int]
     gate: torch.Tensor
@@ -439,12 +439,12 @@ class MLPExpertsTorch(MLPExpertsBase):
         return final_hidden_states.to(org_dtype, device=org_device)

 EXPERTS_MAP = {
-    "MLPCPUExperts": MLPCPUExperts,
-    "MLPExpertsTorch": MLPExpertsTorch,
-    "MLPExpertsMarlin": MLPExpertsMarlin,
+    "KExpertsCPU": KExpertsCPU,
+    "KExpertsTorch": KExpertsTorch,
+    "KExpertsMarlin": KExpertsMarlin,
 }

-class KTransformersMLPExpert(BaseInjectedModule, MLPExpertsBase):
+class KTransformersExperts(BaseInjectedModule, KExpertsBase):
     def __init__(self,
                  key: str,
                  gguf_loader: GGUFLoader,
@@ -452,22 +452,22 @@ class KTransformersMLPExpert(BaseInjectedModule, MLPExpertsBase):
                  orig_module: nn.Module,
                  # device: str = "cuda",
                  prefill_device:str = "cuda",
-                 prefill_mlp_type: str | None = "MLPExpertsTorch",
+                 prefill_op: str | None = "KExpertsTorch",
                  generate_device: str = "cpu",
-                 generate_mlp_type: str | None = "MLPCPUExperts",
+                 generate_op: str | None = "KExpertsCPU",
                  **kwargs):
         BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
-        MLPExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
-        if generate_mlp_type is not None:
-            self.generate_experts = EXPERTS_MAP[generate_mlp_type](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
+        KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        if generate_op is not None:
+            self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
         else:
             self.generate_experts = None
-        if prefill_mlp_type is not None:
-            self.prefill_experts = EXPERTS_MAP[prefill_mlp_type](key, gguf_loader, config, len(orig_module), device=prefill_device, **kwargs)
+        if prefill_op is not None:
+            self.prefill_experts = EXPERTS_MAP[prefill_op](key, gguf_loader, config, len(orig_module), device=prefill_device, **kwargs)
         else:
             self.prefill_experts = None
-        self.gpu_mlp_type = prefill_mlp_type
-        self.cpu_mlp_type = generate_mlp_type
+        self.gpu_mlp_type = prefill_op
+        self.cpu_mlp_type = generate_op
         self.mode = InferenceState.UNLOAD

     def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True):
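A hedged usage sketch of the renamed wrapper, showing how the new prefill_op/generate_op strings select backends from EXPERTS_MAP. Only the keyword names and the map keys come from this diff; every argument value below is illustrative:

# Illustrative only: variables such as gguf_loader, model_config and
# hf_experts_module are assumed to exist in the caller's scope.
experts = KTransformersExperts(
    key="blk.0.ffn",               # hypothetical GGUF key prefix
    gguf_loader=gguf_loader,
    config=model_config,
    orig_module=hf_experts_module,
    prefill_device="cuda",
    prefill_op="KExpertsTorch",    # looked up in EXPERTS_MAP for the prefill path
    generate_device="cpu",
    generate_op="KExpertsCPU",     # looked up in EXPERTS_MAP for the decode path
)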
@@ -523,7 +523,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
 from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock


-class Qwen2MoeSparseMoeBlockInjected(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
+class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """ """
         orig_shape = hidden_states.shape
@@ -548,16 +548,16 @@ class Qwen2MoeSparseMoeBlockInjected(BaseInjectedModule, Qwen2MoeSparseMoeBlock)
             y.resize_(*orig_shape)
             return y, router_logits

-        hidden_states_expert = hidden_states.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else hidden_states_expert.cpu()
-        selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else selected_experts_expert.cpu()
-        routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else routing_weights_expert.cpu()
+        hidden_states_expert = hidden_states.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else hidden_states_expert.cpu()
+        selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else selected_experts_expert.cpu()
+        routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else routing_weights_expert.cpu()

         shared_expert_output = self.shared_expert(hidden_states)
         shared_expert_output = (
             F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output
         )

-        if isinstance(self.experts, MLPExpertsBase):
+        if isinstance(self.experts, KExpertsBase):
             y = (
                 self.moe_on_cpuinfer(
                     hidden_states_expert, selected_experts_expert, routing_weights_expert
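The router_logits, selected_experts and routing_weights used in the hunk above are produced by the gate earlier in forward, outside this diff. For reference, a generic top-k MoE gate that yields tensors with those roles; the shapes and the renormalization step are assumptions, not the exact ktransformers gate:

# Generic top-k gate sketch, illustrative only.
import torch
import torch.nn.functional as F

def topk_gate(hidden_states: torch.Tensor, gate_weight: torch.Tensor, top_k: int):
    """hidden_states: [tokens, hidden_size]; gate_weight: [num_experts, hidden_size]."""
    router_logits = F.linear(hidden_states, gate_weight)                   # [tokens, num_experts]
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)  # renormalize over top-k
    return router_logits, selected_experts, routing_weights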
@@ -628,7 +628,7 @@ class Qwen2MoeSparseMoeBlockInjected(BaseInjectedModule, Qwen2MoeSparseMoeBlock)

         return final_hidden_states

-class DeepseekV2MoEInjected(BaseInjectedModule, DeepseekV2MoE):
+class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
     def forward(self, hidden_states):
         identity = hidden_states
         orig_shape = hidden_states.shape
@@ -648,7 +648,7 @@ class DeepseekV2MoEInjected(BaseInjectedModule, DeepseekV2MoE):
         if self.config.n_shared_experts is not None:
             y_ = self.shared_experts(identity).squeeze(0)

-        if isinstance(self.experts, MLPExpertsBase):
+        if isinstance(self.experts, KExpertsBase):
             y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
         elif hidden_states.size(0) > 10:
             # TODO may bugs here
@@ -727,7 +727,7 @@ class DeepseekV2MoEInjected(BaseInjectedModule, DeepseekV2MoE):
         )
         return final_out

-class MisrtalSparseMoEBlockInjected(BaseInjectedModule, MixtralSparseMoeBlock):
+class KMisrtalSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """ """
@@ -751,11 +751,11 @@ class MisrtalSparseMoEBlockInjected(BaseInjectedModule, MixtralSparseMoeBlock):
             y.resize_(*orig_shape)
             return y, router_logits

-        hidden_states_expert = hidden_states.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else hidden_states_expert.cpu()
-        selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else selected_experts_expert.cpu()
-        routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else routing_weights_expert.cpu()
+        hidden_states_expert = hidden_states.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else hidden_states_expert.cpu()
+        selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else selected_experts_expert.cpu()
+        routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else routing_weights_expert.cpu()

-        if isinstance(self.experts, MLPExpertsBase):
+        if isinstance(self.experts, KExpertsBase):
             y = (
                 self.moe_on_cpuinfer(
                     hidden_states_expert, selected_experts_expert, routing_weights_expert