Enable support for Intel XPU devices, starting with DeepSeek V2/V3

rnwang04 2025-05-14 14:28:22 +00:00
parent 333351c7c8
commit 142fb7ce6c
22 changed files with 673 additions and 81 deletions


@@ -587,6 +587,100 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
        return attn_output, None, past_key_value

    def forward_xpu(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
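        # MLA query path: either a full q_proj, or the low-rank
        # q_a_proj -> q_a_layernorm -> q_b_proj factorization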
        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        query_states = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
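        # Keys/values come from a shared compressed latent: split off the rotary
        # part (k_pe), then decompress the rest into per-head no-RoPE keys and values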
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )
        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )
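        # Effective sequence length includes any tokens already held in the KV cache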
        kv_seq_len = value_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
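        # Two RoPE paths: with precomputed position_embeddings, rotate the RoPE slice
        # of q/k in place via ipex-llm's fused kernel; otherwise use the stock
        # apply_rotary_pos_emb and reassemble full-width query/key tensors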
        position_embeddings = kwargs.get("position_embeddings", None)
        if position_embeddings is not None:
            cos, sin = position_embeddings
            key_states = torch.cat(
                [k_nope, k_pe.expand([-1, self.num_heads, -1, -1])],
                dim=-1
            )
            from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced
            rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim :],
                                           key_states[:, :, :, self.qk_nope_head_dim:],
                                           cos, sin, True)
        else:
            q_nope, q_pe = torch.split(
                query_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
            )
            cos, sin = self.rotary_emb(q_pe, position_ids)
            q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
            query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
            query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
            query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
            key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
            key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
            key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
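        # Keys/values are cached as fp16 (.half()) to match the half-precision
        # attention call below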
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states.half(), value_states.half(), self.layer_idx, cache_kwargs
            )
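        # Fused SDPA from ipex-llm; the q_len == kv_seq_len flag presumably marks the
        # prefill (full-sequence, causal) case as opposed to single-token decode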
        attn_weights = None
        from ipex_llm.transformers.models.common import scaled_dot_product_attention
        attn_output = scaled_dot_product_attention(
            query_states.half(), key_states, value_states,
            attention_mask.half(), q_len == kv_seq_len, self.softmax_scale
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
        attn_output = self.o_proj(attn_output).to(hidden_states.dtype)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -598,10 +692,21 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
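        # Route to the XPU implementation whenever an Intel XPU is available;
        # otherwise keep the original Windows/CPU/non-NVIDIA fallback below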
        if torch.xpu.is_available():
            return self.forward_xpu(
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                output_attentions,
                use_cache,
                cache_position,
                **kwargs,
            )
        elif (os.name == 'nt'
              or get_compute_capability() < 8
              or hidden_states.device.type == 'cpu'
              or device_manager.gpu_vendor != GPUVendor.NVIDIA):
            return self.forward_windows(
                hidden_states,
                attention_mask,