Enable support for Intel XPU devices; add DeepSeek V2/V3 support first

rnwang04 2025-05-14 14:28:22 +00:00
parent 333351c7c8
commit 142fb7ce6c
22 changed files with 673 additions and 81 deletions
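
For reviewers new to the XPU path: the sketch below is not part of the diff (the helper names pick_hidden_type and pick_cuda_graphs are illustrative only) and just summarizes the dispatch pattern the hunks repeat, assuming a PyTorch build that exposes torch.xpu. CUDA graphs are used only when CUDA is present, and the llamafile MoE backend's hidden dtype id switches to fp16 when running on an Intel XPU with a float16 default dtype.

import torch

def pick_hidden_type() -> int:
    # Illustrative helper: mirrors the diff's dtype selection for the CPU MoE backend.
    # 1 = fp16, 30 = bf16 (the ids used in the hunks below).
    if torch.xpu.is_available() and torch.get_default_dtype() == torch.float16:
        return 1   # fp16
    return 30      # bf16

def pick_cuda_graphs(generate_cuda_graphs, chunk_size):
    # Illustrative helper: CUDA graphs only apply on CUDA; use a single buffer slot otherwise.
    if torch.cuda.is_available():
        return generate_cuda_graphs(chunk_size)
    return 1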


@@ -51,7 +51,10 @@ def generate_cuda_graphs(chunk_size: int) -> list:
     return deduplicate_and_sort(base_list + multiples)
 #cuda_graphs = [Config().chunk_size]
-cuda_graphs = generate_cuda_graphs(Config().chunk_size)
+if torch.cuda.is_available():
+    cuda_graphs = generate_cuda_graphs(Config().chunk_size)
+else:
+    cuda_graphs = 1
 # class Base(BaseInjectedModule, ABC):
 class KExpertsBase(ABC):
     def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs):
@@ -177,6 +180,11 @@ class KExpertsCPU(KExpertsBase):
         n_routed_experts = self.n_routed_experts
         self.cpu_infer = KExpertsCPU.CPU_INFER
         # n_routed_experts = len(self.orig_module)
+        model_dtype = torch.get_default_dtype()
+        if torch.xpu.is_available() and model_dtype == torch.float16:
+            hidden_type = 1 # fp16
+        else:
+            hidden_type = 30 # bf16
         if self.backend == "llamafile":
             moe_config = MOEConfig(
                 n_routed_experts,
@@ -192,7 +200,7 @@ class KExpertsCPU(KExpertsBase):
                 self.gate_type,
                 self.up_type,
                 self.down_type,
-                30, # TODO: get from model.dtype
+                hidden_type, # TODO: get from model.dtype
             )
             self.moe = MOE(moe_config)
         elif self.backend == "AMXBF16":
@@ -252,8 +260,12 @@ class KExpertsCPU(KExpertsBase):
             KExpertsCPU.input_tensor_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True)
             KExpertsCPU.expert_ids_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
             KExpertsCPU.weights_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
-            KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
-            KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True)
+            if torch.xpu.is_available():
+                KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=model_dtype)
+                KExpertsCPU.bsz_tensor_cpu = torch.ones((1), device="cpu", dtype=torch.int32, pin_memory=True)
+            else:
+                KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
+                KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True)
 
     def submit_for_one_decode(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
         if bsz_tensor is None:
@@ -285,9 +297,9 @@ class KExpertsCPU(KExpertsBase):
     def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
         # generate, capture and run cuda graph
         # print(expert_ids)
-        if bsz_tensor is None:
+        if bsz_tensor is None and (not torch.xpu.is_available() or input_tensor.size(0) > 1):
             bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32)
-        if torch.cuda.is_current_stream_capturing():
+        if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
             if cuda_graph_idx != -1:
                 KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)
                 KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)
@@ -307,6 +319,15 @@ class KExpertsCPU(KExpertsBase):
             self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
             KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
             return KExpertsCPU.output_gpu_map[self.out_device]
+        elif input_tensor.size(0)==1 and torch.xpu.is_available():
+            KExpertsCPU.input_tensor_cpu.copy_(input_tensor.view(-1), non_blocking=True)
+            KExpertsCPU.expert_ids_cpu.copy_(expert_ids.view(-1), non_blocking=True)
+            KExpertsCPU.weights_cpu.copy_(weights.view(-1), non_blocking=True)
+            # KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor.view(-1), non_blocking=True)
+            self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))
+            self.cpu_infer.sync()
+            KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
+            return KExpertsCPU.output_gpu_map[self.out_device].view(1, -1)
         else:
             input_tensor = input_tensor.contiguous().cpu()
             expert_ids = expert_ids.contiguous().cpu()
@@ -822,7 +843,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
         topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-        if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
+        if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
             self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
             if self.config.n_shared_experts is not None:
                 y_ = self.shared_experts(identity).squeeze(0)
@@ -922,7 +943,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         # only for generate phase
-        if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
+        if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
             self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
             if self.config.n_shared_experts is not None:
                 y_ = self.shared_experts(identity).squeeze(0)
@@ -1122,7 +1143,7 @@ class KDeepseekV3MoEV2(BaseInjectedModule, DeepseekV3MoE):
         # only for generate phase
-        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
+        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
             self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
             if self.config.n_shared_experts is not None:
                 y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)
@@ -1304,7 +1325,7 @@ class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
         routing_weights = routing_weights.to(hidden_states.dtype)
         # only for generate phase
-        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
+        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
             self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
             y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
             y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
@@ -1417,7 +1438,7 @@ class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
         routing_weights = routing_weights.to(hidden_states.dtype)
         # only for generate phase
-        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
+        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
             self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
             # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
             # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
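
A quick sanity check (hypothetical snippet, not in the repository) for which of the new dispatch branches a given machine will take after this commit:

import torch

# Hypothetical check, not part of this commit: report which branch of the
# patched dispatch logic would be active on the current machine.
if torch.cuda.is_available():
    print("CUDA available: CUDA-graph capture path, bf16 pinned output buffer")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    print(f"XPU available ({torch.xpu.device_count()} device(s)): "
          "single-token submit/sync path, output buffer uses the default dtype")
else:
    print("Neither CUDA nor XPU: generic contiguous CPU fallback")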