kvcache-ai-ktransformers/ktransformers/operators/ascend/ascend_experts.py

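# Ascend NPU backend for W8A8-quantized MoE experts: the routed-expert compute
# runs in a CPU kernel (via KExpertsCPU / cpu_infer) while the rest of the
# model runs on the NPU; pinned host buffers and stream events bridge the two
# when NPU graph mode is enabled.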
import bisect

import acl
import torch
import numpy as np
from torch import nn
from transformers import PretrainedConfig

from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size, get_tensor_parallel_group
from ktransformers.operators.experts import KExpertsCPU, KTransformersExperts, EXPERTS_MAP, KDeepseekV3MoE, cuda_graphs
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import CUR_DEVICE, get_use_npu_graph, InferenceState


class KExpertsCPUW8A8(KExpertsCPU):
    """CPU-offloaded MoE experts for W8A8 (8-bit weight, 8-bit activation)
    checkpoints, with pinned staging buffers for NPU graph replay."""

    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        device: str = "cpu",
        **kwargs
    ):
        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
        # Pinned host buffers with fixed shapes (batch size 1) reused across
        # NPU graph replays: inputs are staged here before the CPU MoE kernel
        # runs, and the result is read back from output_cpu_graph.
        self.input_tensor_cpu_graph = torch.zeros((1, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
        self.expert_ids_cpu_graph = torch.zeros((1, self.config.num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
        self.weights_cpu_graph = torch.zeros((1, self.config.num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
        self.output_cpu_graph = torch.zeros((1, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
        self.bsz_tensor_cpu_graph = torch.ones((1,), device="cpu", dtype=torch.int32, pin_memory=True)
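
    # In NPU-graph mode, forward() reads its inputs from the pinned buffers
    # above (the caller stages them with copy_) and leaves its result in
    # output_cpu_graph: it returns None and the caller reads the buffer back.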
    def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
        if get_use_npu_graph():
            # Graph-replay path: operate entirely on the pre-staged pinned
            # buffers so no allocation happens inside the captured graph.
            self.cpu_infer.submit(self.moe.forward(self.expert_ids_cpu_graph.size(0),
                                                   self.expert_ids_cpu_graph.size(1),
                                                   self.expert_ids_cpu_graph.data_ptr(),
                                                   self.weights_cpu_graph.data_ptr(),
                                                   self.input_tensor_cpu_graph.data_ptr(),
                                                   self.output_cpu_graph.data_ptr(),
                                                   self.bsz_tensor_cpu_graph.data_ptr()))
            self.cpu_infer.sync()
        else:
            # Eager path: move inputs to CPU in the dtypes the CPU MoE kernel
            # expects (bf16 activations, fp32 routing weights).
            if bsz_tensor is None:
                bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32)
            org_type = input_tensor.dtype
            input_tensor = input_tensor.contiguous().cpu().to(torch.bfloat16)
            expert_ids = expert_ids.contiguous().cpu()
            weights = weights.contiguous().to(torch.float32).cpu()
            bsz_tensor = bsz_tensor.contiguous().cpu()
            output = torch.empty_like(input_tensor)
            self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(),
                                                   weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr(),
                                                   bsz_tensor.data_ptr()))
            self.cpu_infer.sync()
            return output.to(org_type).to(device=CUR_DEVICE)
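
# Register this implementation in the experts registry so it can be selected
# by class name.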
EXPERTS_MAP["KExpertsCPUW8A8"] = KExpertsCPUW8A8


class KTransformersExpertsW8A8(KTransformersExperts):
    def forward(self, input_tensor, expert_ids, weights):
        # Dispatch to the experts module that matches the current phase.
        if self.mode == InferenceState.GENERATE:
            assert self.generate_experts is not None, "generate_experts is None"
            return self.generate_experts.forward(input_tensor, expert_ids, weights)
        elif self.mode == InferenceState.PREFILL:
            assert self.prefill_experts is not None, "prefill_experts is None"
            return self.prefill_experts.forward(input_tensor, expert_ids, weights)
        else:
            raise ValueError("call load() or set_inference_mode() before forward()")
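

# DeepseekV3 MoE wrapper for Ascend. forward() handles the single-device case;
# forward_tp() handles the tensor-parallel case, where only rank 0 computes the
# routed experts, the result is all-reduced over the TP group, and each rank
# adds its shared-expert output locally.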
class KDeepseekV3MoEW8A8(KDeepseekV3MoE):
    def forward_tp(self, hidden_states):
        identity = hidden_states
        orig_shape = hidden_states.shape
        rank = torch.distributed.get_rank()

        def share_experts_forward():
            if self.config.n_shared_experts is not None:
                return self.shared_experts(identity).squeeze(0)

        if rank == 0:
            topk_idx, topk_weight = self.gate(hidden_states)
            hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
            cuda_graph_idx = bisect.bisect_left(cuda_graphs, 1)
            if get_use_npu_graph():
                from ktransformers.util.npu_graph_runner import get_or_create_runner
                npu_graph_runner = get_or_create_runner(CUR_DEVICE)
                # Run the shared experts on the update stream while the CPU
                # handles the routed experts; events order the two streams.
                event = torch.npu.Event()
                event.record(npu_graph_runner.main_stream)
                with torch.npu.stream(npu_graph_runner.update_stream):
                    event.wait(npu_graph_runner.update_stream)
                    y_ = share_experts_forward()
                    event.record(npu_graph_runner.update_stream)
                org_type = hidden_states.dtype
                input_tensor = hidden_states.to(torch.bfloat16)
                topk_weight = topk_weight.contiguous().to(torch.float32)
                self.moe_kexperts_param = (hidden_states, topk_idx, topk_weight)
                # Stage inputs into the pinned buffers, then have the graph
                # runner invoke the CPU MoE kernel as a host callback.
                self.experts.generate_experts.input_tensor_cpu_graph.copy_(input_tensor, non_blocking=True)
                self.experts.generate_experts.expert_ids_cpu_graph.copy_(topk_idx, non_blocking=True)
                self.experts.generate_experts.weights_cpu_graph.copy_(topk_weight, non_blocking=True)
                npu_graph_runner.launch_callback(
                    self.cpu_moe_kexperts,
                    self.moe_kexperts_param,
                    1, npu_graph_runner.stream)
                output_npu_graph = self.experts.generate_experts.output_cpu_graph.to(CUR_DEVICE, non_blocking=True)
                y = output_npu_graph.to(org_type)
                event.wait(npu_graph_runner.main_stream)
            else:
                y = self.moe_kexperts(hidden_states, topk_idx, topk_weight)
                y_ = share_experts_forward()
            y = y.view(*orig_shape).to(device=hidden_states.device)
        else:
            # Non-zero ranks contribute zeros for the routed experts and only
            # compute their shared-expert output.
            hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
            y = torch.zeros(orig_shape, dtype=torch.float16, device=CUR_DEVICE)
            y_ = share_experts_forward()
        torch.distributed.all_reduce(y, op=torch.distributed.ReduceOp.SUM, group=get_tensor_parallel_group())
        if self.config.n_shared_experts is not None:
            y += y_
        return y
    def forward(self, hidden_states):
        tp_size = get_tensor_parallel_size()
        world_size = torch.distributed.get_world_size()
        if tp_size > 1 and world_size == tp_size:
            return self.forward_tp(hidden_states)
        identity = hidden_states
        orig_shape = hidden_states.shape
        sequence_length = orig_shape[1]
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        y_ = None
        # Fast path for single-token decode, currently disabled (the original
        # stream-capture condition is commented out and replaced by a hard False):
        # if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
        if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and False:
            self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
            if self.config.n_shared_experts is not None:
                y_ = self.shared_experts(identity).squeeze(0)
            y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
            y += y_
            y.resize_(*orig_shape)
            return y

        def share_experts_forward():
            if self.config.n_shared_experts is not None:
                return self.shared_experts(identity).squeeze(0)

        cuda_graph_idx = bisect.bisect_left(cuda_graphs, 1)
        if get_use_npu_graph():
            from ktransformers.util.npu_graph_runner import get_or_create_runner
            npu_graph_runner = get_or_create_runner(CUR_DEVICE)
            # Same stream/event choreography as in forward_tp: shared experts
            # run on the update stream, routed experts on the CPU via callback.
            event = torch.npu.Event()
            event.record(npu_graph_runner.main_stream)
            with torch.npu.stream(npu_graph_runner.update_stream):
                event.wait(npu_graph_runner.update_stream)
                y_ = share_experts_forward()
                event.record(npu_graph_runner.update_stream)
            org_type = hidden_states.dtype
            input_tensor = hidden_states.to(torch.bfloat16)
            topk_weight = topk_weight.contiguous().to(torch.float32)
            self.moe_kexperts_param = (hidden_states, topk_idx, topk_weight)
            self.experts.generate_experts.input_tensor_cpu_graph.copy_(input_tensor, non_blocking=True)
            self.experts.generate_experts.expert_ids_cpu_graph.copy_(topk_idx, non_blocking=True)
            self.experts.generate_experts.weights_cpu_graph.copy_(topk_weight, non_blocking=True)
            npu_graph_runner.launch_callback(
                self.cpu_moe_kexperts,
                self.moe_kexperts_param,
                1, npu_graph_runner.stream)
            output_npu_graph = self.experts.generate_experts.output_cpu_graph.to(CUR_DEVICE, non_blocking=True)
            y = output_npu_graph.to(org_type)
            event.wait(npu_graph_runner.main_stream)
        else:
            y = self.moe_kexperts(hidden_states, topk_idx, topk_weight)
            y_ = share_experts_forward()
        y = y.view(*orig_shape).to(device=hidden_states.device)
        if self.config.n_shared_experts is not None:
            y += y_
        return y
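
    # Host-side helpers invoked as callbacks by the NPU graph runner.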
    @torch.no_grad()
    def cpu_moe_kexperts(self, moe_kexperts_param) -> None:
        # Callback form: the result lands in the experts' pinned
        # output_cpu_graph buffer, so nothing is returned here.
        x, topk_ids, topk_weight = moe_kexperts_param
        self.moe_kexperts(x, topk_ids, topk_weight)

    @torch.no_grad()
    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        outs = self.experts(x, topk_ids, topk_weight)
        return outs
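

# A minimal usage sketch (illustrative, not part of the original file). It
# assumes a model whose MoE blocks were replaced by KDeepseekV3MoEW8A8 via the
# optimize rules, that `moe` names one such block, that `config` is its
# PretrainedConfig, and that the experts module exposes set_inference_mode()
# as the ValueError message above implies:
#
#     moe.experts.set_inference_mode(InferenceState.GENERATE)
#     hidden = torch.randn(1, 1, config.hidden_size, device=CUR_DEVICE, dtype=torch.bfloat16)
#     out = moe(hidden)  # gate -> routed experts on CPU -> (+ shared experts)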