Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 15:29:39 +00:00)
Enable support for Intel XPU devices, add support for DeepSeek V2/V3 first
parent 333351c7c8
commit 142fb7ce6c
22 changed files with 673 additions and 81 deletions
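The changes below share one pattern: code that previously assumed a CUDA device now checks which accelerator is actually present, falling back from CUDA to Intel XPU to plain CPU. A minimal sketch of that dispatch, using only the torch.cuda / torch.xpu availability checks that appear in the diff (the helper name is illustrative, not part of the commit):

import torch

def pick_device() -> str:
    # Prefer CUDA, then Intel XPU, then CPU. torch.xpu only exists in
    # PyTorch builds with XPU support, so guard the attribute lookup too.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"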
@@ -51,7 +51,10 @@ def generate_cuda_graphs(chunk_size: int) -> list:
     return deduplicate_and_sort(base_list + multiples)
 #cuda_graphs = [Config().chunk_size]
-cuda_graphs = generate_cuda_graphs(Config().chunk_size)
+if torch.cuda.is_available():
+    cuda_graphs = generate_cuda_graphs(Config().chunk_size)
+else:
+    cuda_graphs = 1

 # class Base(BaseInjectedModule, ABC):
 class KExpertsBase(ABC):
     def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs):
@@ -177,6 +180,11 @@ class KExpertsCPU(KExpertsBase):
         n_routed_experts = self.n_routed_experts
         self.cpu_infer = KExpertsCPU.CPU_INFER
         # n_routed_experts = len(self.orig_module)
+        model_dtype = torch.get_default_dtype()
+        if torch.xpu.is_available() and model_dtype == torch.float16:
+            hidden_type = 1 # fp16
+        else:
+            hidden_type = 30 # bf16
         if self.backend == "llamafile":
             moe_config = MOEConfig(
                 n_routed_experts,
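Per the inline comments, 1 and 30 are the numeric GGML tensor-type IDs for fp16 and bf16. A standalone sketch of the same selection (the constant names are illustrative; the diff uses the raw integers):

import torch

GGML_TYPE_F16 = 1    # "hidden_type = 1 # fp16" in the diff
GGML_TYPE_BF16 = 30  # "hidden_type = 30 # bf16" in the diff

def select_hidden_type() -> int:
    # XPU runs this path in fp16 when the default dtype is fp16;
    # every other configuration keeps the previous bf16 behaviour.
    if hasattr(torch, "xpu") and torch.xpu.is_available() and torch.get_default_dtype() == torch.float16:
        return GGML_TYPE_F16
    return GGML_TYPE_BF16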
@@ -192,7 +200,7 @@ class KExpertsCPU(KExpertsBase):
                 self.gate_type,
                 self.up_type,
                 self.down_type,
-                30, # TODO: get from model.dtype
+                hidden_type, # TODO: get from model.dtype
             )
             self.moe = MOE(moe_config)
         elif self.backend == "AMXBF16":
@@ -252,8 +260,12 @@ class KExpertsCPU(KExpertsBase):
             KExpertsCPU.input_tensor_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True)
             KExpertsCPU.expert_ids_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
             KExpertsCPU.weights_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
-            KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
-            KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True)
+            if torch.xpu.is_available():
+                KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=model_dtype)
+                KExpertsCPU.bsz_tensor_cpu = torch.ones((1), device="cpu", dtype=torch.int32, pin_memory=True)
+            else:
+                KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
+                KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True)

     def submit_for_one_decode(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
         if bsz_tensor is None:
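The class-level *_cpu tensors above are pinned (page-locked) host staging buffers: the accelerator writes activations into them, the CPU MoE backend reads and fills them in place, and results are copied back with non_blocking=True. On XPU the output buffer keeps the model dtype (fp16) instead of bf16, and bsz_tensor_cpu starts at one, apparently because the single-token decode path added below never rewrites it. A minimal sketch of the allocation idea, with illustrative sizes:

import torch

hidden_size = 7168        # illustrative; taken from the model config in practice
num_experts_per_tok = 8   # illustrative top-k
slots = 1                 # one staging slot when CUDA graphs are not in use

# pin_memory=True requests page-locked host memory, so device<->host copies
# can be issued with non_blocking=True and overlapped with compute.
input_cpu   = torch.zeros((slots, hidden_size), device="cpu", pin_memory=True)
ids_cpu     = torch.zeros((slots, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
weights_cpu = torch.zeros((slots, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
output_cpu  = torch.zeros((slots, hidden_size), device="cpu", dtype=torch.float16, pin_memory=True)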
@@ -285,9 +297,9 @@ class KExpertsCPU(KExpertsBase):
     def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
         # generate, capture and run cuda graph
         # print(expert_ids)
-        if bsz_tensor is None:
+        if bsz_tensor is None and (not torch.xpu.is_available() or input_tensor.size(0) > 1):
             bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32)
-        if torch.cuda.is_current_stream_capturing():
+        if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
             if cuda_graph_idx != -1:
                 KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)
                 KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)
@@ -307,6 +319,15 @@ class KExpertsCPU(KExpertsBase):
             self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
             KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
             return KExpertsCPU.output_gpu_map[self.out_device]
+        elif input_tensor.size(0)==1 and torch.xpu.is_available():
+            KExpertsCPU.input_tensor_cpu.copy_(input_tensor.view(-1), non_blocking=True)
+            KExpertsCPU.expert_ids_cpu.copy_(expert_ids.view(-1), non_blocking=True)
+            KExpertsCPU.weights_cpu.copy_(weights.view(-1), non_blocking=True)
+            # KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor.view(-1), non_blocking=True)
+            self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))
+            self.cpu_infer.sync()
+            KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
+            return KExpertsCPU.output_gpu_map[self.out_device].view(1, -1)
         else:
             input_tensor = input_tensor.contiguous().cpu()
             expert_ids = expert_ids.contiguous().cpu()
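On XPU there is no CUDA-graph capture, so the new single-token decode branch is synchronous: stage the token into the pinned buffers, submit one MoE job to the CPU backend, block on sync(), then copy the result back to the device. A rough sketch of that call sequence; cpu_infer, moe, and the *_cpu buffers stand in for the corresponding attributes in the diff, and the function itself is hypothetical:

def xpu_single_token_moe(cpu_infer, moe, input_cpu, ids_cpu, weights_cpu, output_cpu, bsz_cpu,
                         out_xpu, hidden_xpu, expert_ids, weights):
    # 1. Stage device tensors into the pinned host buffers (asynchronous copies).
    input_cpu.copy_(hidden_xpu.view(-1), non_blocking=True)
    ids_cpu.copy_(expert_ids.view(-1), non_blocking=True)
    weights_cpu.copy_(weights.view(-1), non_blocking=True)
    # 2. Run the CPU MoE kernel over the staged token; bsz_cpu already holds 1.
    cpu_infer.submit(moe.forward(expert_ids.size(0), expert_ids.size(1),
                                 ids_cpu.data_ptr(), weights_cpu.data_ptr(),
                                 input_cpu.data_ptr(), output_cpu.data_ptr(),
                                 bsz_cpu.data_ptr()))
    cpu_infer.sync()  # no CUDA stream to chain onto, so block here
    # 3. Copy the result back to the device and restore the (1, hidden_size) shape.
    out_xpu.copy_(output_cpu, non_blocking=True)
    return out_xpu.view(1, -1)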
@@ -822,7 +843,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
         topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

-        if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
+        if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
             self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
             if self.config.n_shared_experts is not None:
                 y_ = self.shared_experts(identity).squeeze(0)
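This hunk and the remaining ones (KDeepseekV3MoE, KDeepseekV3MoEV2, KQwen2MoeSparseMoeBlockV2, KQwen3MoeSparseMoeBlockV2) apply the same one-line guard: torch.cuda.is_current_stream_capturing() can fail on systems without a CUDA device, so it is short-circuited behind torch.cuda.is_available(). A tiny illustrative helper capturing the pattern (not part of the commit):

import torch

def capturing_cuda_graph() -> bool:
    # Safe on CUDA-less (e.g. XPU-only) systems: the availability check
    # short-circuits before any CUDA runtime call is made.
    return torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()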
@@ -922,7 +943,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

         # only for generate phase
-        if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
+        if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
             self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
             if self.config.n_shared_experts is not None:
                 y_ = self.shared_experts(identity).squeeze(0)
@@ -1122,7 +1143,7 @@ class KDeepseekV3MoEV2(BaseInjectedModule, DeepseekV3MoE):


         # only for generate phase
-        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
+        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
             self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
             if self.config.n_shared_experts is not None:
                 y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)
@@ -1304,7 +1325,7 @@ class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
         routing_weights = routing_weights.to(hidden_states.dtype)

         # only for generate phase
-        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
+        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
             self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
             y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
             y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
@@ -1417,7 +1438,7 @@ class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
         routing_weights = routing_weights.to(hidden_states.dtype)

         # only for generate phase
-        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
+        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
             self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
             # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
             # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_