mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 15:29:39 +00:00
support KExpertsMarlin backend
This commit is contained in:
parent
0262f954c7
commit
c4d9bc6670
5 changed files with 214 additions and 46 deletions
|
@ -302,13 +302,13 @@ class KExpertsMarlin(KExpertsBase):
|
|||
if w is None: w = self.load_weights()[self.key]
|
||||
|
||||
if isinstance(w, dict):
|
||||
self.gate = nn.Parameter(torch.from_numpy(w["gate"]))
|
||||
self.up = nn.Parameter(torch.from_numpy(w["up"]))
|
||||
self.down = nn.Parameter(torch.from_numpy(w["down"]))
|
||||
self.gate = w["gate"]
|
||||
self.up = (w["up"])
|
||||
self.down = (w["down"])
|
||||
for i in range(self.expert_num):
|
||||
self.up_projs[i].load(self.up[i,...], device=device)
|
||||
self.gate_projs[i].load(self.gate[i,...], device=device)
|
||||
self.down_projs[i].load(self.down[i,...], device=device)
|
||||
self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
|
||||
self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
|
||||
self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
|
||||
self.loaded_experts_idx.append(i)
|
||||
return
|
||||
|
||||
|
@ -342,23 +342,45 @@ class KExpertsMarlin(KExpertsBase):
|
|||
up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
|
||||
down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
|
||||
# tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"])
|
||||
res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
|
||||
res = {key:{"gate": nn.Parameter(gate), "up": nn.Parameter(up), "down": nn.Parameter(down), "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
|
||||
return res
|
||||
|
||||
def forward(self, input_tensor:torch.Tensor, expert_ids, weights):
|
||||
# forward
|
||||
device = input_tensor.device
|
||||
input_tensor = input_tensor.to("cuda")
|
||||
outs = torch.zeros_like(input_tensor)
|
||||
for expert_idx in range(expert_ids.size(0)):
|
||||
down_proj = self.down_projs[expert_idx]
|
||||
gate_proj = self.gate_projs[expert_idx]
|
||||
up_proj = self.up_projs[expert_idx]
|
||||
def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
|
||||
org_dtype = hidden_states_cpu.dtype
|
||||
org_device = hidden_states_cpu.device
|
||||
hidden_states_cpu = hidden_states_cpu.to(self.device)
|
||||
selected_experts_cpu = selected_experts_cpu.to(self.device)
|
||||
routing_weights_cpu = routing_weights_cpu.to(self.device).to(org_dtype)
|
||||
|
||||
batch_sequence_length, hidden_dim = hidden_states_cpu.size()
|
||||
|
||||
outs += down_proj(self.act_fn(gate_proj(input_tensor)) * up_proj(input_tensor)) * weights[expert_idx]
|
||||
outs = outs.to(device)
|
||||
return outs
|
||||
final_hidden_states = torch.zeros(
|
||||
(batch_sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device
|
||||
)
|
||||
# One hot encode the selected experts to create an expert mask
|
||||
# this will be used to easily index which expert is going to be sollicitated
|
||||
expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.expert_num).permute(2, 1, 0)
|
||||
|
||||
# Loop over all available experts in the model and perform the computation on each expert
|
||||
for expert_idx in range(self.expert_num):
|
||||
if not expert_mask[expert_idx].any():
|
||||
continue
|
||||
idx, top_x = torch.where(expert_mask[expert_idx])
|
||||
# Index the correct hidden states and compute the expert hidden state for
|
||||
# the current expert. We need to make sure to multiply the output hidden
|
||||
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
|
||||
current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim)
|
||||
G = self.gate_projs[expert_idx].forward(current_state)
|
||||
A = self.act_fn(G)
|
||||
U = self.up_projs[expert_idx].forward(current_state)
|
||||
H = A * U # Element-wise multiplication
|
||||
current_hidden_states = self.down_projs[expert_idx].forward(H) * routing_weights_cpu[top_x, idx, None]
|
||||
# However `index_add_` only support torch tensors for indexing so we'll use
|
||||
# the `top_x` tensor here.
|
||||
final_hidden_states.index_add_(0, top_x, current_hidden_states)
|
||||
|
||||
return final_hidden_states.to(dtype=org_dtype, device=org_device)
|
||||
|
||||
class KExpertsTorch(KExpertsBase):
|
||||
expert_num: int
|
||||
loaded_experts_idx: list[int]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue