Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-16 10:09:42 +00:00
[ADD] support multi-gpu qlen>1 q5_k
Parent: f293803156
Commit: f5f79f5c0e
63 changed files with 3271 additions and 1285 deletions
@@ -10,6 +10,7 @@ from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import InferenceState
from transformers.configuration_utils import PretrainedConfig

# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
def __init__(self,
@@ -17,12 +18,16 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
# device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
self.orig_module.__init__(orig_module.dim,
orig_module.max_position_embeddings,
orig_module.base)
self.generate_device = generate_device
self.prefill_device = prefill_device

def load(self):
self.orig_module.__init__(self.orig_module.dim,
@@ -36,9 +41,11 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
# device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
self.orig_module.__init__(orig_module.dim,
orig_module.max_position_embeddings,
orig_module.base,
@@ -49,13 +56,15 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
orig_module.beta_slow,
orig_module.mscale,
orig_module.mscale_all_dim)
self.generate_device = generate_device
self.prefill_device = prefill_device

def load(self):
self.orig_module.__init__(self.orig_module.dim,
self.orig_module.max_position_embeddings,
self.orig_module.base,
self.device,
self.generate_device,
self.orig_module.scaling_factor,
self.orig_module.original_max_position_embeddings,
self.orig_module.beta_fast,
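The two RotaryEmbedding/YarnRotaryEmbedding hunks above replace the single `device` argument with a `generate_device`/`prefill_device` pair, so one injected RoPE module can serve decoding and prefill on different GPUs, and `load()` re-initializes the wrapped module with `self.generate_device` instead of `self.device`. A minimal sketch of that split outside the ktransformers class hierarchy (the wrapped module and method names here are illustrative assumptions, not the real API):

    import torch.nn as nn

    class PhaseSplitModule(nn.Module):
        """Sketch: keep a device per inference phase and (re)load onto the right one."""
        def __init__(self, orig_module: nn.Module,
                     generate_device: str = "cuda:0",
                     prefill_device: str = "cuda:1"):
            super().__init__()
            self.orig_module = orig_module
            self.generate_device = generate_device
            self.prefill_device = prefill_device

        def load(self, prefill: bool = False):
            # Re-materialize the wrapped module on the device used by the current phase.
            target = self.prefill_device if prefill else self.generate_device
            self.orig_module.to(target)
            return self.orig_module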
@@ -5,8 +5,8 @@ Description :
Author : Azure-Tang, Boxin Zhang, chenht2022
Date : 2024-07-25 11:25:24
Version : 0.1.0
LastEditors : Azure
LastEditTime : 2024-07-26 09:27:41
LastEditors : kkk1nak0
LastEditTime : 2024-08-11 12:14:39
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
@@ -19,7 +19,9 @@ import torch
import sys, os
from ktransformers.operators.base_operator import BaseInjectedModule

sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/build")
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from cpuinfer_ext.moe import MOEConfig, MOE
import ctypes
@@ -78,6 +80,25 @@ class MLPExpertsBase(ABC):
gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"]
up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
elif key + ".ffn_down.0.weight" in self.gguf_loader.tensor_info:
# for supporting Mixtral-8x7B-Instuct
gate = []
up = []
down = []
for i in range(8):
gatei, upi, downi = f".ffn_gate.{i}.weight", f".ffn_up.{i}.weight", f".ffn_down.{i}.weight"
targets = [gatei, upi, downi]
tensors = self.load_multi(key, targets, device=device)
gate_it, up_it, down_it = tensors[gatei], tensors[upi], tensors[downi]
gate.append(gate_it)
up.append(up_it)
down.append(down_it)
gate = torch.stack(gate)
up = torch.stack(up)
down = torch.stack(down)
gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate.0.weight"]["ggml_type"]
up_type = self.gguf_loader.tensor_info[key + ".ffn_up.0.weight"]["ggml_type"]
down_type = self.gguf_loader.tensor_info[key + ".ffn_down.0.weight"]["ggml_type"]
else:
raise ValueError(f"Experts {key} not found in gguf_loader")
res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
@@ -94,7 +115,8 @@ class MLPCPUExperts(MLPExpertsBase):
expert_ids_cpu:Tensor = None
weights_cpu:Tensor = None
output_cpu:Tensor = None
output_gpu:Tensor = None
output_gpu_map:dict = {} # Manage output tensor buffer on different gpu
#stream_map:dict = {} # Manage cuda stream on different gpu
CPU_INFER = cpuinfer_ext.CPUInfer(Config().cpu_infer)
def __init__(
self,
@@ -113,81 +135,83 @@ class MLPCPUExperts(MLPExpertsBase):
self.out_device = out_device

def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = None, warmup:bool = False):
if device:
assert device.lower() == "cpu", "MLPCPUExperts can only be loaded on CPU, Parameter \"device\" can be cpu or None."
if w is None: w = self.load_weights()[self.key]
self.gate = w["gate"]
self.up = w["up"]
self.down = w["down"]
self.gate_type = w["gate_type"]
self.up_type = w["up_type"]
self.down_type = w["down_type"]
gate_ptr = ctypes.addressof(
ctypes.cast(self.gate.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
)
up_ptr = ctypes.addressof(
ctypes.cast(self.up.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
)
down_ptr = ctypes.addressof(
ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
)
# print(self.gate_qtype, self.up_qtype, self.down_qtype)
n_routed_experts = self.n_routed_experts
# n_routed_experts = len(self.orig_module)
moe_config = MOEConfig(
n_routed_experts,
self.config.num_experts_per_tok,
self.config.hidden_size,
self.config.moe_intermediate_size,
64,
10,
1024,
gate_ptr,
up_ptr,
down_ptr,
self.gate_type,
self.up_type,
self.down_type,
30, # TODO: get from model.dtype
)
# print(n_routed_experts, hidden_size, moe_intermediate_size)
num_experts_per_tok = self.config.num_experts_per_tok
self.moe = MOE(moe_config)
self.cpu_infer = MLPCPUExperts.CPU_INFER
if warmup:
self.cpu_infer.submit(self.moe.warm_up())
self.cpu_infer.sync()
if MLPCPUExperts.output_gpu == None:
MLPCPUExperts.input_tensor_cpu = torch.empty((self.config.hidden_size), device="cpu", pin_memory=True)
MLPCPUExperts.expert_ids_cpu = torch.empty((num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
MLPCPUExperts.weights_cpu = torch.empty((num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
MLPCPUExperts.output_cpu = torch.empty((self.config.hidden_size), device="cpu", pin_memory=True)
MLPCPUExperts.output_gpu = torch.empty((self.config.hidden_size), device=self.out_device)
with torch.device(self.out_device):
if device:
assert device.lower() == "cpu", "MLPCPUExperts can only be loaded on CPU, Parameter \"device\" can be cpu or None."
if w is None: w = self.load_weights()[self.key]
self.gate = w["gate"]
self.up = w["up"]
self.down = w["down"]
self.gate_type = w["gate_type"]
self.up_type = w["up_type"]
self.down_type = w["down_type"]
gate_ptr = ctypes.addressof(
ctypes.cast(self.gate.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
)
up_ptr = ctypes.addressof(
ctypes.cast(self.up.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
)
down_ptr = ctypes.addressof(
ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
)
# print(self.gate_qtype, self.up_qtype, self.down_qtype)
n_routed_experts = self.n_routed_experts
# n_routed_experts = len(self.orig_module)
moe_config = MOEConfig(
n_routed_experts,
self.config.num_experts_per_tok,
self.config.hidden_size,
self.config.moe_intermediate_size,
64,
10,
1024,
gate_ptr,
up_ptr,
down_ptr,
self.gate_type,
self.up_type,
self.down_type,
30, # TODO: get from model.dtype
)
# print(n_routed_experts, hidden_size, moe_intermediate_size)
num_experts_per_tok = self.config.num_experts_per_tok
self.moe = MOE(moe_config)
self.cpu_infer = MLPCPUExperts.CPU_INFER
if warmup:
self.cpu_infer.submit(self.moe.warm_up())
self.cpu_infer.sync()
if self.out_device not in MLPCPUExperts.output_gpu_map:
MLPCPUExperts.output_gpu_map[self.out_device] = torch.zeros((self.config.hidden_size), device=self.out_device)
if MLPCPUExperts.input_tensor_cpu == None:
MLPCPUExperts.input_tensor_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True)
MLPCPUExperts.expert_ids_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
MLPCPUExperts.weights_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
MLPCPUExperts.output_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
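In the rewritten `load`, initialization now runs under `with torch.device(self.out_device)`, and the single class-level `output_gpu` buffer becomes `output_gpu_map` with one persistent output buffer per output device, while the pinned CPU staging buffers stay shared. A standalone sketch of that buffer bookkeeping (class and field names here are illustrative, not the ktransformers API):

    import torch

    class ExpertBuffers:
        """Sketch: shared pinned CPU staging buffers plus one GPU output buffer per device."""
        output_gpu_map: dict = {}
        input_cpu = None
        ids_cpu = None
        weights_cpu = None
        output_cpu = None

        @classmethod
        def ensure(cls, hidden_size: int, top_k: int, out_device: str) -> torch.Tensor:
            if out_device not in cls.output_gpu_map:
                cls.output_gpu_map[out_device] = torch.zeros(hidden_size, device=out_device)
            if cls.input_cpu is None:
                # Pinned host memory so copies to and from the GPU can be asynchronous.
                cls.input_cpu = torch.zeros(hidden_size, pin_memory=True)
                cls.ids_cpu = torch.zeros(top_k, dtype=torch.long, pin_memory=True)
                cls.weights_cpu = torch.zeros(top_k, dtype=torch.float32, pin_memory=True)
                cls.output_cpu = torch.zeros(hidden_size, dtype=torch.bfloat16, pin_memory=True)
            return cls.output_gpu_map[out_device]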
def submit_for_one_decode(self, input_tensor, expert_ids, weights):
MLPCPUExperts.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
MLPCPUExperts.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
MLPCPUExperts.weights_cpu.copy_(weights, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(1, expert_ids.size(0), MLPCPUExperts.expert_ids_cpu.data_ptr(), MLPCPUExperts.weights_cpu.data_ptr(), MLPCPUExperts.input_tensor_cpu.data_ptr(), MLPCPUExperts.output_cpu.data_ptr()))

self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(0), MLPCPUExperts.expert_ids_cpu.data_ptr(), MLPCPUExperts.weights_cpu.data_ptr(), MLPCPUExperts.input_tensor_cpu.data_ptr(), MLPCPUExperts.output_cpu.data_ptr()))

def sync_for_one_decode(self):
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
MLPCPUExperts.output_gpu.copy_(MLPCPUExperts.output_cpu, non_blocking=True)
#print("capturing experts finish")
return MLPCPUExperts.output_gpu
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
MLPCPUExperts.output_gpu_map[self.out_device].copy_(MLPCPUExperts.output_cpu, non_blocking=True)
return MLPCPUExperts.output_gpu_map[self.out_device]
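`submit_for_one_decode` and `sync_for_one_decode` now take the CUDA stream from `self.out_device` and write the result into `output_gpu_map[self.out_device]`, so a single-token decode can be queued from whichever GPU currently owns the layer. Roughly, the flow looks like the following sketch (`cpu_backend` is a placeholder object, not the real cpuinfer interface):

    import torch

    def submit_one_decode(x_gpu, ids_gpu, w_gpu, bufs, cpu_backend, out_device):
        # Stage inputs into pinned CPU buffers on the stream of the layer's device.
        bufs.input_cpu.copy_(x_gpu, non_blocking=True)
        bufs.ids_cpu.copy_(ids_gpu, non_blocking=True)
        bufs.weights_cpu.copy_(w_gpu, non_blocking=True)
        stream = torch.cuda.current_stream(out_device).cuda_stream
        cpu_backend.submit(stream, bufs)  # CPU MoE work is ordered against this stream

    def sync_one_decode(bufs, cpu_backend, out_device):
        cpu_backend.sync(torch.cuda.current_stream(out_device).cuda_stream)
        out = bufs.output_gpu_map[out_device]
        out.copy_(bufs.output_cpu, non_blocking=True)  # pinned CPU -> per-device GPU buffer
        return out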
def forward(self, input_tensor, expert_ids, weights):
# generate, capture and run cuda graph
# print(expert_ids)
if input_tensor.size(0)==1:
# TODO: this branch is unreachable, but the shape of input_tensor([1,hidden_size]) and input_tensor_cpu([hidden_size]) is not compatible
#print("capturing experts")
MLPCPUExperts.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
MLPCPUExperts.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
MLPCPUExperts.weights_cpu.copy_(weights, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(1, expert_ids.size(1), MLPCPUExperts.expert_ids_cpu.data_ptr(), MLPCPUExperts.weights_cpu.data_ptr(), MLPCPUExperts.input_tensor_cpu.data_ptr(), MLPCPUExperts.output_cpu.data_ptr()))
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
MLPCPUExperts.output_gpu.copy_(MLPCPUExperts.output_cpu, non_blocking=True)
#print("capturing experts finish")
return MLPCPUExperts.output_gpu
MLPCPUExperts.output_gpu_map[self.out_device].copy_(MLPCPUExperts.output_cpu, non_blocking=True)
return MLPCPUExperts.output_gpu_map[self.out_device]
else:
input_tensor = input_tensor.contiguous().cpu()
expert_ids = expert_ids.contiguous().cpu()
@@ -195,7 +219,7 @@ class MLPCPUExperts(MLPExpertsBase):
output = torch.empty_like(input_tensor).contiguous()
self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr()))
self.cpu_infer.sync()
return output.to(device=object.__getattribute__(self, "device"))
return output.to(device=object.__getattribute__(self, "out_device"))

def unload(self):
return
@@ -222,6 +246,24 @@ class MLPCPUExperts(MLPExpertsBase):
gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"]
up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
elif key + ".ffn_down.0.weight" in self.gguf_loader.tensor_info:
# for supporting Mixtral-8x7B-Instuct
gate = []
up = []
down = []
for i in range(8):
gate_it = self.gguf_loader.get_mmap_tensor(f"{key}.ffn_gate.{i}.weight")
up_it = self.gguf_loader.get_mmap_tensor(f"{key}.ffn_up.{i}.weight")
down_it = self.gguf_loader.get_mmap_tensor(f"{key}.ffn_down.{i}.weight")
gate.append(gate_it)
up.append(up_it)
down.append(down_it)
gate = np.stack(gate)
up = np.stack(up)
down = np.stack(down)
gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate.0.weight"]["ggml_type"]
up_type = self.gguf_loader.tensor_info[key + ".ffn_up.0.weight"]["ggml_type"]
down_type = self.gguf_loader.tensor_info[key + ".ffn_down.0.weight"]["ggml_type"]
else:
raise ValueError(f"Experts {key} not found in gguf_loader")
res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
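`load_weights` now accepts both GGUF expert layouts: a fused `ffn_*_exps.weight` tensor per projection, or, as in Mixtral-8x7B GGUF files, eight per-expert tensors named `ffn_gate.{i}.weight` and friends that are memory-mapped and stacked. A hedged sketch of that detection for one projection (the `loader` object stands in for the GGUF loader used above):

    import numpy as np

    def load_expert_stack(loader, key: str, proj: str, n_experts: int = 8):
        """Return an [n_experts, ...] weight stack for one projection ('gate', 'up' or 'down')."""
        fused = f"{key}.ffn_{proj}_exps.weight"
        if fused in loader.tensor_info:
            # Fused layout: one tensor already holding all experts.
            return loader.get_mmap_tensor(fused)
        per_expert = [f"{key}.ffn_{proj}.{i}.weight" for i in range(n_experts)]
        if per_expert[0] in loader.tensor_info:
            # Per-expert layout (e.g. Mixtral GGUFs): stack the individual tensors.
            return np.stack([loader.get_mmap_tensor(name) for name in per_expert])
        raise ValueError(f"Experts {key} not found in gguf_loader")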
@@ -299,7 +341,7 @@ class MLPExpertsMarlin(MLPExpertsBase):
gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"]
up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
# tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"])
# tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"])
res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
return res
@@ -359,6 +401,11 @@ class MLPExpertsTorch(MLPExpertsBase):
self.down = None

def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:

org_device = hidden_states_cpu.device
hidden_states_cpu = hidden_states_cpu.to(self.device)
selected_experts_cpu = selected_experts_cpu.to(self.device)
routing_weights_cpu = routing_weights_cpu.to(self.device)

batch_sequence_length, hidden_dim = hidden_states_cpu.size()
@@ -388,27 +435,29 @@ class MLPExpertsTorch(MLPExpertsBase):
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states)

return final_hidden_states.to(org_dtype)

return final_hidden_states.to(org_dtype, device=org_device)

EXPERTS_MAP = {
"MLPCPUExperts": MLPCPUExperts,
"MLPExpertsTorch": MLPExpertsTorch,
"MLPExpertsMarlin": MLPExpertsMarlin,
}

class KTransformersMLPExpert(BaseInjectedModule, MLPExpertsBase):
def __init__(self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
# device: str = "cuda",
prefill_device:str = "cuda",
prefill_mlp_type: str | None = "MLPExpertsTorch",
generate_device: str = "cpu",
generate_mlp_type: str | None = "MLPCPUExperts",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
MLPExpertsBase.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
MLPExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
if generate_mlp_type is not None:
self.generate_experts = EXPERTS_MAP[generate_mlp_type](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
else:
@@ -471,6 +520,7 @@ class KTransformersMLPExpert(BaseInjectedModule, MLPExpertsBase):

from ktransformers.models.modeling_deepseek import DeepseekV2MoE
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock


class Qwen2MoeSparseMoeBlockInjected(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
@@ -578,7 +628,6 @@ class Qwen2MoeSparseMoeBlockInjected(BaseInjectedModule, Qwen2MoeSparseMoeBlock)

return final_hidden_states


class DeepseekV2MoEInjected(BaseInjectedModule, DeepseekV2MoE):
def forward(self, hidden_states):
identity = hidden_states
@@ -587,7 +636,7 @@ class DeepseekV2MoEInjected(BaseInjectedModule, DeepseekV2MoE):
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

if sequence_length == 1:
if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode"):
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity).squeeze(0)
@@ -677,3 +726,102 @@ class DeepseekV2MoEInjected(BaseInjectedModule, DeepseekV2MoE):
.type(new_x.dtype)
)
return final_out
class MisrtalSparseMoEBlockInjected(BaseInjectedModule, MixtralSparseMoeBlock):

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
orig_shape = hidden_states.shape
batch_size, sequence_length, hidden_dim = hidden_states.shape
if self.training and self.jitter_noise > 0:
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode"):
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], selected_experts[0], routing_weights[0])
y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
y.resize_(*orig_shape)
return y, router_logits

hidden_states_expert = hidden_states.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else hidden_states_expert.cpu()
selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else selected_experts_expert.cpu()
routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else routing_weights_expert.cpu()

if isinstance(self.experts, MLPExpertsBase):
y = (
self.moe_on_cpuinfer(
hidden_states_expert, selected_experts_expert, routing_weights_expert
)
.view(*orig_shape)
.to(device=hidden_states.device)
)
elif hidden_states_expert.size(0) > 10:
y = self.moe_infer(
hidden_states_expert, selected_experts_expert, routing_weights_expert, orig_shape
).to(device=hidden_states.device)
else:
y = self.moe_infer_simple(
hidden_states_expert, selected_experts_expert, routing_weights_expert
).to(device=hidden_states.device)

y.resize_(*orig_shape)
return y, router_logits
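The forward path above follows the usual Mixtral routing before picking a backend: softmax over the router logits, keep the top-k experts per token, renormalize the kept weights, then either hand a single decode token to the CPU expert backend or dispatch the batch to `moe_on_cpuinfer` / `moe_infer` / `moe_infer_simple`. The routing step on its own, as a tiny self-contained example:

    import torch
    import torch.nn.functional as F

    router_logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])         # one token, four experts
    weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    weights, experts = torch.topk(weights, k=2, dim=-1)            # keep the two best experts
    weights = weights / weights.sum(dim=-1, keepdim=True)          # renormalize so they sum to 1
    print(experts)  # tensor([[0, 2]])
    print(weights)  # approximately tensor([[0.7311, 0.2689]])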
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
outs = self.experts(x, topk_ids, topk_weight)
return outs

@torch.no_grad()
# TODO may bugs here
def moe_infer_simple(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
'''
hidden_states_cpu: [num_tokens, hidden_size]
topk_ids, topk_weight: [num_tokens, num_selected_experts]
'''
outs = torch.zeros_like(hidden_states_cpu)
for token_idx in range(selected_experts_cpu.size(0)):
for expert_idx in range(selected_experts_cpu.size(1)):
expert = self.experts[selected_experts_cpu[token_idx, expert_idx]]
outs[token_idx] += expert.forward(hidden_states_cpu[token_idx]) * routing_weights_cpu[token_idx, expert_idx]
return outs
@torch.no_grad()
# TODO may bugs here
def moe_infer(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor, orig_shape: tuple) -> torch.Tensor:

batch_size, sequence_length, hidden_dim = orig_shape

final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device
)

# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be sollicitated
expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.num_experts).permute(2, 1, 0)

# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])

# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer.forward(current_state) * routing_weights_cpu[top_x, idx, None]

# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states_cpu.dtype))

return final_hidden_states
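`moe_infer` batches work per expert rather than per token: a one-hot mask over `selected_experts` is permuted to [num_experts, top_k, num_tokens], `torch.where` recovers which tokens (and which top-k slot) hit each expert, and `index_add_` scatters the weighted expert outputs back into place. A small runnable illustration of just the masking step (the expert count and routing values are made up):

    import torch

    selected_experts = torch.tensor([[0, 2],   # token 0 routed to experts 0 and 2
                                     [2, 1]])  # token 1 routed to experts 2 and 1
    num_experts = 4
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
    # expert_mask[e][slot][token] == 1 when `token` picked expert `e` as its (slot+1)-th choice.
    idx, top_x = torch.where(expert_mask[2])
    print(top_x)  # tensor([1, 0]) -> expert 2 serves token 1 (first choice) and token 0 (second choice)
    print(idx)    # tensor([0, 1]) -> the matching top-k slot, used to pick each routing weight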
@@ -6,7 +6,7 @@ Author : Azure-Tang
Date : 2024-07-25 11:25:24
Version : 1.0.0
LastEditors : Azure
LastEditTime : 2024-07-26 09:27:48
LastEditTime : 2024-08-08 10:09:14
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
@@ -45,6 +45,8 @@ from ktransformers.models.modeling_deepseek import BaseModelOutputWithPast, Deep
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.utils import InferenceState
from ktransformers.util.custom_gguf import GGUFLoader
from transformers.configuration_utils import PretrainedConfig

if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -73,34 +75,6 @@ QWEN2MOE_START_DOCSTRING = r"""
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
"The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.",
QWEN2MOE_START_DOCSTRING,
)
class Qwen2MoePreTrainedModel(PreTrainedModel):
config_class = Qwen2MoeConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Qwen2MoeDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()


QWEN2MOE_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -177,13 +151,11 @@ QWEN2MOE_INPUTS_DOCSTRING = r"""
the complete sequence length.
"""

from ktransformers.util.custom_gguf import GGUFLoader
from transformers.configuration_utils import PretrainedConfig
@add_start_docstrings(
"The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.",
QWEN2MOE_START_DOCSTRING,
)
class Qwen2MoeModelPerLayerPrefill(BaseInjectedModule):
class Qwen2MoeModelKTransformers(BaseInjectedModule):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2MoeDecoderLayer`]
@@ -198,10 +170,13 @@ class Qwen2MoeModelPerLayerPrefill(BaseInjectedModule):
orig_module: nn.Module,
device: str = "cuda",
per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
transfer_map: dict = None,
**kwargs,
):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
self.transfer_map = transfer_map
self.stream_device_map = dict()

@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
def forward(
@@ -287,7 +262,20 @@ class Qwen2MoeModelPerLayerPrefill(BaseInjectedModule):
all_router_logits = () if output_router_logits else None
next_decoder_cache = None

for decoder_layer in self.layers:
for i, decoder_layer in enumerate(self.layers):
if self.transfer_map is not None and i in self.transfer_map:
prev_stream = torch.cuda.current_stream()
cur_device = self.transfer_map[i]
if cur_device not in self.stream_device_map:
self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
torch.cuda.set_device(cur_device)
self.stream_device_map[cur_device].wait_stream(prev_stream)
torch.cuda.set_stream(self.stream_device_map[cur_device])
hidden_states = hidden_states.to(self.transfer_map[i], non_blocking = True)
causal_mask = causal_mask.to(self.transfer_map[i], non_blocking = True) if causal_mask is not None else None
position_ids = position_ids.to(self.transfer_map[i], non_blocking = True) if position_ids is not None else None
cache_position = cache_position.to(self.transfer_map[i], non_blocking = True) if cache_position is not None else None

if output_hidden_states:
all_hidden_states += (hidden_states,)
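This hunk is the heart of the multi-GPU change: the injected model takes a `transfer_map` from layer index to device, and when the layer loop reaches a mapped index it switches to a per-device CUDA stream, makes the new stream wait on the previous one, and moves the hidden states, mask and position tensors over with non-blocking copies. The same pattern in isolation (a sketch; `run_layer` and the map values are placeholders):

    import torch

    def run_layers(layers, hidden_states, transfer_map, run_layer):
        """Sketch: hand decoder layers off across GPUs according to transfer_map."""
        streams = {}
        for i, layer in enumerate(layers):
            if transfer_map is not None and i in transfer_map:
                prev_stream = torch.cuda.current_stream()
                device = transfer_map[i]
                if device not in streams:
                    streams[device] = torch.cuda.Stream(device)
                torch.cuda.set_device(device)
                streams[device].wait_stream(prev_stream)   # keep cross-device ordering
                torch.cuda.set_stream(streams[device])
                hidden_states = hidden_states.to(device, non_blocking=True)
            hidden_states = run_layer(layer, hidden_states)
        return hidden_states

With something like transfer_map = {15: "cuda:1"} (example values only), layers before index 15 stay on the starting GPU and the rest run on the second one.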
@@ -463,7 +451,7 @@ DeepseekV2_INPUTS_DOCSTRING = r"""
"""


class DeepseekV2ModelPerLayerPrefill(BaseInjectedModule):
class DeepseekV2ModelKTransformers(BaseInjectedModule):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`]
@@ -478,10 +466,13 @@ class DeepseekV2ModelPerLayerPrefill(BaseInjectedModule):
orig_module: nn.Module,
device: str = "cuda",
per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
transfer_map: dict = None,
**kwargs,
):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
self.transfer_map = transfer_map
self.stream_device_map = dict()

@add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
def forward(
@@ -584,7 +575,20 @@ class DeepseekV2ModelPerLayerPrefill(BaseInjectedModule):
t_cpu = 0
t_f = 0

for decoder_layer in self.layers:
for i, decoder_layer in enumerate(self.layers):
if self.transfer_map is not None and i in self.transfer_map:
prev_stream = torch.cuda.current_stream()
cur_device = self.transfer_map[i]
if cur_device not in self.stream_device_map:
self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
torch.cuda.set_device(cur_device)
self.stream_device_map[cur_device].wait_stream(prev_stream)
torch.cuda.set_stream(self.stream_device_map[cur_device])
hidden_states = hidden_states.to(self.transfer_map[i], non_blocking = True)
causal_mask = causal_mask.to(self.transfer_map[i], non_blocking = True) if causal_mask is not None else None
position_ids = position_ids.to(self.transfer_map[i], non_blocking = True) if position_ids is not None else None
cache_position = cache_position.to(self.transfer_map[i], non_blocking = True) if cache_position is not None else None

if output_hidden_states:
all_hidden_states += (hidden_states,)
@@ -176,7 +176,7 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
self.act_order = act_order
self.is_k_full = is_k_full

def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = "cuda"):
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if device is None: device = self.device
assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
if w is None: w = self.load_weight(device=device)
@@ -200,7 +200,7 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
weight, self.num_bits, self.group_size, self.act_order
)
self.workspace = MarlinWorkspace(
self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device
)
self.marlin_q_w = marlin_q_w
self.marlin_s = marlin_s
@@ -247,7 +247,6 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
LINEAR_MAP = {
"QuantizedLinearMarlin": QuantizedLinearMarlin,
"QuantizedLinearTorch": QuantizedLinearTorch,
"QuantizedLinearTorch": QuantizedLinearTorch,
}

class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
|
@ -257,15 +256,15 @@ class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
|
|||
gguf_loader: GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
device: str = "cuda",
|
||||
# device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
generate_op: str| None = "QuantizedLinearMarlin",
|
||||
prefill_device: str = "cuda",
|
||||
prefill_op: str| None = "QuantizedLinearTorch",
|
||||
**kwargs,
|
||||
):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
|
||||
QuantizedLinearBase.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
||||
QuantizedLinearBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
||||
# build all the linear operators
|
||||
if prefill_op is not None:
|
||||
assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported"
|
||||
|
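`KTransformerLinear` now builds its two backends from `generate_device`/`generate_op` and `prefill_device`/`prefill_op` rather than a single `device`, and tracks an `InferenceState` mode so the active backend can be switched between prefill and decoding. A reduced sketch of that dual-backend shape (the enum and backend modules are stand-ins for the real operator classes):

    import enum
    import torch.nn as nn

    class Mode(enum.Enum):
        UNLOAD = 0
        PREFILL = 1
        GENERATE = 2

    class DualBackendLinear(nn.Module):
        """Sketch: one linear backend per phase, each loaded on its own device."""
        def __init__(self, prefill_linear: nn.Module, generate_linear: nn.Module):
            super().__init__()
            self.prefill_linear = prefill_linear
            self.generate_linear = generate_linear
            self.mode = Mode.UNLOAD

        def forward(self, x):
            if self.mode == Mode.PREFILL:
                return self.prefill_linear(x)
            if self.mode == Mode.GENERATE:
                return self.generate_linear(x)
            raise RuntimeError("select a backend (load) before calling forward()")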
@@ -289,7 +288,6 @@ class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)
else:
self.generate_linear = None
self.device = device
self.mode = InferenceState.UNLOAD

def forward(self, x):