Merge pull request #333 from kvcache-ai/feat_experts_gpu

toy support for experts on GPU, no CUDA Graph
This commit is contained in:
Atream 2025-02-15 23:30:24 +08:00 committed by GitHub
commit c5f036e8a4
7 changed files with 202 additions and 66 deletions


@@ -18,6 +18,7 @@ import torch.nn.functional as F
import torch
import sys, os
from ktransformers.operators.base_operator import BaseInjectedModule
from tqdm import tqdm
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
@@ -225,6 +226,7 @@ class KExpertsCPU(KExpertsBase):
return
def load_weights(self, override_key: str | None = None, device: str = "cpu"):
# TODO: support Bias
res = {}
if override_key is not None:
keys = override_key
@@ -288,6 +290,8 @@ class KExpertsMarlin(KExpertsBase):
self.act_fn = ACT2FN[config.hidden_act]
assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
self.device = device
self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size
# create empty marlin experts according to the number of experts per token
# up
self.up_projs = [KLinearMarlin(key+ "." + "ffn_up_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]
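Note: elements_per_tensor records the flat size of one expert's projection (moe_intermediate_size × hidden_size); load_expert_tensor below presumably uses it as the per-expert offset into the shared ffn_*_exps blob. A toy version of that arithmetic (sizes are made up, not from this repo):

    # Illustrative sizes only; the real ones come from the model config.
    moe_intermediate_size, hidden_size, expert_idx = 1408, 5120, 3
    elements_per_tensor = moe_intermediate_size * hidden_size
    start = expert_idx * elements_per_tensor  # first element of expert 3
    end = start + elements_per_tensor         # one past its last element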
@@ -299,17 +303,34 @@
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
if device is None: device = self.device
assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
if w is None: w = self.load_weights()[self.key]
if w is None:
    w = self.load_weights()
    load_by_experts = True
else:
    load_by_experts = False
if isinstance(w, dict):
self.gate = w["gate"]
self.up = (w["up"])
self.down = (w["down"])
for i in range(self.expert_num):
self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
self.loaded_experts_idx.append(i)
if load_by_experts:
if isinstance(w, dict):
self.gate = w["gate"]
self.up = (w["up"])
self.down = (w["down"])
for i in tqdm(range(self.expert_num), desc=f"Dequanting and quanting for KExpertsMarlin {self.key}"):
up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", self.up, i, self.elements_per_tensor, device=self.device)
gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", self.gate, i, self.elements_per_tensor, device=self.device)
down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", self.down, i, self.elements_per_tensor, device=self.device)
self.up_projs[i].load(nn.Parameter(up_weights), device=device)
self.gate_projs[i].load(nn.Parameter(gate_weights), device=device)
self.down_projs[i].load(nn.Parameter(down_weights), device=device)
self.loaded_experts_idx.append(i)
else:
if isinstance(w, dict):
self.gate = w["gate"]
self.up = (w["up"])
self.down = (w["down"])
for i in range(self.expert_num):
self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
self.loaded_experts_idx.append(i)
return
def unload(self):
@@ -329,20 +350,13 @@ class KExpertsMarlin(KExpertsBase):
gate = None
up = None
down = None
gate_type = None
up_type = None
down_type = None
for key in keys:
if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
gate = self.gguf_loader.load_gguf_tensor(key + ".ffn_gate_exps.weight")
up = self.gguf_loader.load_gguf_tensor(key + ".ffn_up_exps.weight")
down = self.gguf_loader.load_gguf_tensor(key + ".ffn_down_exps.weight")
gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"]
up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
# tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"])
res = {key:{"gate": nn.Parameter(gate), "up": nn.Parameter(up), "down": nn.Parameter(down), "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
res = {"gate": gate, "up": up, "down": down}
return res
def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
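Note: together with the load() change above, loading is now two-stage: load_weights only hands back the raw mmap'd blobs, and load() materializes one expert at a time through load_expert_tensor, so peak host memory stays at a single expert rather than the whole tensor. A self-contained sketch of the pattern, with a plain tensor standing in for the GGUF blob (the slice helper is illustrative, not the real loader API):

    import torch

    num_experts, elements_per_tensor = 4, 8
    # stand-in for the mmap'd, still-quantized all-experts blob
    blob = torch.arange(num_experts * elements_per_tensor, dtype=torch.float32)

    def load_expert_slice(blob: torch.Tensor, i: int, n: int) -> torch.Tensor:
        # mirrors the role of load_expert_tensor: materialize one expert's weights
        return blob[i * n : (i + 1) * n].clone()

    for i in range(num_experts):
        w = load_expert_slice(blob, i, elements_per_tensor)
        # dequantize / Marlin-quantize w here, then drop it before the next expert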
@@ -381,6 +395,7 @@ class KExpertsMarlin(KExpertsBase):
return final_hidden_states.to(dtype=org_dtype, device=org_device)
# untested, CUDA OOM
class KExpertsTorch(KExpertsBase):
expert_num: int
loaded_experts_idx: list[int]
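Note: the CUDA OOM caveat is plausible from simple arithmetic; a dense copy of all experts costs num_experts × intermediate × hidden elements per projection, times three projections (gate, up, down), per MoE layer. With illustrative sizes (not taken from this repo):

    E, I, H, bytes_per_elem = 64, 1408, 5120, 2  # experts, intermediate, hidden, fp16
    per_proj_gib = E * I * H * bytes_per_elem / 2**30
    print(f"{3 * per_proj_gib:.1f} GiB per MoE layer")  # ~2.6 GiB, times dozens of layers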
@@ -402,19 +417,39 @@ class KExpertsTorch(KExpertsBase):
# self.loaded_experts_idx = []
self.act_fn = ACT2FN[config.hidden_act]
self.device = device
self.gate = None
self.up = None
self.donw = None
self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size
self.gate = [None for _ in range(self.expert_num)]
self.up = [None for _ in range(self.expert_num)]
self.down = [None for _ in range(self.expert_num)]
self.dtype = torch.get_default_dtype()
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
if device is None: device = self.device
if w is None: w = self.load_weights(device=device)[self.key]
if w is None:
    w = self.load_weights()
    load_by_experts = True
else:
    load_by_experts = False
if isinstance(w, dict):
self.gate = w["gate"].to(device=device, dtype=self.dtype)
self.up = w["up"].to(device=device, dtype=self.dtype)
self.down = w["down"].to(device=device, dtype=self.dtype)
if load_by_experts:
if isinstance(w, dict):
for i in tqdm(range(self.expert_num), desc=f"Dequanting for KExpertsTorch {self.key}"):
up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", w["up"], i, self.elements_per_tensor, device=self.device)
gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", w["gate"], i, self.elements_per_tensor, device=self.device)
down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", w["down"], i, self.elements_per_tensor, device=self.device)
self.up[i] = up_weights
self.gate[i] = gate_weights
self.down[i] = down_weights
else:
if isinstance(w, dict):
for i in range(self.expert_num):
self.gate[i] = w["gate"][i, ...].to(device=device, dtype=self.dtype)
self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype)
self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype)
self.up = torch.cat(self.up, dim=0)
self.gate = torch.cat(self.gate, dim=0)
self.down = torch.cat(self.down, dim=0)
return
def unload(self):
if self.gate is not None:
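Note: torch.cat fuses the experts along dim 0, while torch.stack would keep a separate leading expert dimension; which one is appropriate depends on how forward indexes the merged tensor. A quick shape check:

    import torch

    E, I, H = 4, 3, 5  # toy expert count and projection sizes
    per_expert = [torch.randn(I, H) for _ in range(E)]
    print(torch.stack(per_expert, dim=0).shape)  # [E, I, H]: w[i] is one expert
    print(torch.cat(per_expert, dim=0).shape)    # [E*I, H]: experts fused along dim 0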
@@ -422,6 +457,25 @@ class KExpertsTorch(KExpertsBase):
self.up = None
self.down = None
def load_weights(self, override_key: str | None = None):
res = {}
if override_key is not None:
keys = override_key
else:
keys = [self.key]
gate = None
up = None
down = None
for key in keys:
if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
res = {"gate": gate, "up": up, "down": down}
return res
def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
org_device = hidden_states_cpu.device
@@ -582,7 +636,7 @@ class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
if isinstance(self.experts, KExpertsBase):
y = (
self.moe_on_cpuinfer(
self.moe_kexperts(
hidden_states_expert, selected_experts_expert, routing_weights_expert
)
.view(*orig_shape)
@@ -601,8 +655,7 @@ class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
return y, router_logits
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = self.experts(x, topk_ids, topk_weight)
return outs
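Note: the same rename is applied to the DeepseekV2, DeepseekV3, and Mixtral blocks below. The helper is now backend-agnostic: self.experts may be CPU- or GPU-resident, and it returns its own output tensor, which is why the old outs = torch.empty_like(x) preallocation is gone. In isolation, with a toy stand-in for the experts backend:

    import torch

    @torch.no_grad()
    def moe_kexperts(experts, x, topk_ids, topk_weight):
        # experts is any KExpertsBase-style callable; it owns its output buffer
        return experts(x, topk_ids, topk_weight)

    identity_experts = lambda x, ids, w: x  # toy backend for illustration
    y = moe_kexperts(identity_experts, torch.randn(2, 8),
                     torch.zeros(2, 2, dtype=torch.long), torch.ones(2, 2))
    assert y.shape == (2, 8)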
@@ -672,7 +725,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
y_ = self.shared_experts(identity).squeeze(0)
if isinstance(self.experts, KExpertsBase):
y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
elif hidden_states.size(0) > 10:
# TODO may bugs here
y = (
@@ -692,8 +745,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
return y
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = self.experts(x, topk_ids, topk_weight)
return outs
@@ -773,7 +825,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
y_ = self.shared_experts(identity).squeeze(0)
if isinstance(self.experts, KExpertsBase):
y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
elif hidden_states.size(0) > 10:
# TODO may bugs here
y = (
@@ -793,8 +845,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
return y
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = self.experts(x, topk_ids, topk_weight)
return outs
@@ -881,7 +932,7 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
if isinstance(self.experts, KExpertsBase):
y = (
self.moe_on_cpuinfer(
self.moe_kexperts(
hidden_states_expert, selected_experts_expert, routing_weights_expert
)
.view(*orig_shape)
@@ -900,8 +951,7 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
return y, router_logits
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = self.experts(x, topk_ids, topk_weight)
return outs


@@ -119,7 +119,7 @@ class KLinearTorch(KLinearBase):
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
self.has_bias = False
self.dtype = torch.get_default_dtype()
self.w = None
self.weight = None
self.has_bias = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -127,7 +127,7 @@ class KLinearTorch(KLinearBase):
out_device = x.device
# TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
x = x.to(device=self.device, dtype=self.dtype)
x = x @ self.w
x = x @ self.weight
if self.has_bias:
x = x + self.bias
x = x.to(dtype=dtype, device=out_device)
@@ -140,27 +140,27 @@ class KLinearTorch(KLinearBase):
if isinstance(w, nn.Parameter):
try:
self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
self.weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
except:
self.w = w.to(dtype=self.dtype).T
self.weight = w.to(dtype=self.dtype).T
self.has_bias = False
elif isinstance(w, tuple):
try:
self.w = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
self.weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
except:
self.w = w[0].to(dtype=self.dtype).T
self.weight = w[0].to(dtype=self.dtype).T
self.bias = w[1].to(dtype=self.dtype)
self.has_bias = True
else:
raise ValueError("Invalid weight type")
# self.linear = self.linear.to(device)
self.w = self.w.to(device)
self.weight = self.weight.to(device)
if self.has_bias:
self.bias = self.bias.to(device)
def unload(self):
if self.w is not None:
self.w = None
if self.weight is not None:
self.weight = None
if self.has_bias:
self.bias = None
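Note: external modeling code conventionally reads module.weight, which the old self.w name hid. The stored tensor is also pre-transposed at load time, so forward is a single x @ self.weight with no per-call transpose. A minimal check of that layout:

    import torch

    out_features, in_features = 8, 16
    w = torch.randn(out_features, in_features)    # GGUF layout: [out, in]
    weight = w.view(out_features, in_features).T  # stored as [in, out], as in load()
    x = torch.randn(2, in_features)
    assert (x @ weight).shape == (2, out_features)  # forward: x @ self.weight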
@@ -218,6 +218,7 @@ class KLinearMarlin(KLinearBase):
self.workspace = MarlinWorkspace(
self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device
)
self.weight = marlin_q_w # modeling_xxx.py may use linear.weight
self.marlin_q_w = marlin_q_w
self.marlin_s = marlin_s
self.g_idx = g_idx
@@ -424,11 +425,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
if mode == InferenceState.PREFILL:
self.generate_linear.unload()
self.prefill_linear.load(w=w)
self.device = self.prefill_linear.device
self.device = self.prefill_linear.device
self.weight = self.prefill_linear.weight # modeling_xxx.py may use linear.weight
elif mode == InferenceState.GENERATE:
self.prefill_linear.unload()
self.generate_linear.load(w=w)
self.device = self.generate_linear.device
self.weight = self.generate_linear.weight # modeling_xxx.py may use linear.weight
elif mode == InferenceState.UNLOAD:
self.prefill_linear.unload()
self.generate_linear.unload()
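Note: the two added self.weight assignments keep the wrapper's attribute pointed at whichever backend is live, so callers that read linear.weight after a mode switch see the active tensor. A stripped-down sketch of the pattern (class names here are illustrative, mirroring the diff):

    import types
    from enum import Enum, auto

    class InferenceState(Enum):
        PREFILL = auto()
        GENERATE = auto()

    class LinearFacade:
        def __init__(self, prefill_linear, generate_linear):
            self.prefill_linear = prefill_linear
            self.generate_linear = generate_linear
            self.weight = None
        def set_inference_mode(self, mode: InferenceState):
            live = (self.prefill_linear if mode is InferenceState.PREFILL
                    else self.generate_linear)
            self.weight = live.weight  # re-point the facade at the active backend

    a = types.SimpleNamespace(weight="marlin-packed")
    b = types.SimpleNamespace(weight="torch-dense")
    f = LinearFacade(a, b)
    f.set_inference_mode(InferenceState.GENERATE)
    assert f.weight == "torch-dense"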