Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 15:29:39 +00:00)
Enable support for Intel XPU devices, add support for DeepSeek V2/V3 first

commit 142fb7ce6c
parent 333351c7c8

22 changed files with 673 additions and 81 deletions
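The commit keys the new attention path off runtime XPU detection rather than a build flag. As a minimal sketch of the probe the patched forward() relies on (assuming a PyTorch build that ships the torch.xpu backend, e.g. recent PyTorch or intel-extension-for-pytorch; the getattr guard is an extra safety net for older builds and is not in the diff):

import torch

# Hypothetical standalone probe: mirrors the torch.xpu.is_available() check
# that routes attention to forward_xpu in the second hunk below.
if getattr(torch, "xpu", None) is not None and torch.xpu.is_available():
    device = torch.device("xpu")
else:
    device = torch.device("cpu")

x = torch.randn(2, 3, device=device)  # lands on the Intel GPU when present
print(x.device)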
@@ -587,6 +587,100 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         return attn_output, None, past_key_value
 
+    def forward_xpu(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
+        bsz, q_len, _ = hidden_states.size()
+
+        if self.q_lora_rank is None:
+            q = self.q_proj(hidden_states)
+        else:
+            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+        query_states = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
+
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        compressed_kv, k_pe = torch.split(
+            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )
+        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
+        kv = (
+            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            .transpose(1, 2)
+        )
+
+        k_nope, value_states = torch.split(
+            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+        )
+        kv_seq_len = value_states.shape[-2]
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+        position_embeddings = kwargs.get("position_embeddings", None)
+        if position_embeddings is not None:
+            cos, sin = position_embeddings
+            key_states = torch.cat(
+                [k_nope, k_pe.expand([-1, self.num_heads, -1, -1])],
+                dim=-1
+            )
+            from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced
+            rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim:],
+                                           key_states[:, :, :, self.qk_nope_head_dim:],
+                                           cos, sin, True)
+        else:
+            q_nope, q_pe = torch.split(
+                query_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+            )
+            cos, sin = self.rotary_emb(q_pe, position_ids)
+            q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
+            query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+            query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
+            query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+
+            key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+            key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
+            key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(
+                key_states.half(), value_states.half(), self.layer_idx, cache_kwargs
+            )
+
+        attn_weights = None
+        from ipex_llm.transformers.models.common import scaled_dot_product_attention
+        attn_output = scaled_dot_product_attention(
+            query_states.half(), key_states, value_states,
+            attention_mask.half(), q_len == kv_seq_len, self.softmax_scale
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
+        attn_output = self.o_proj(attn_output).to(hidden_states.dtype)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -598,10 +692,21 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if (os.name == 'nt'
-            or get_compute_capability() < 8
-            or hidden_states.device.type == 'cpu'
-            or device_manager.gpu_vendor != GPUVendor.NVIDIA):
+        if torch.xpu.is_available():
+            return self.forward_xpu(
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                use_cache,
+                cache_position,
+                **kwargs,
+            )
+        elif (os.name == 'nt'
+            or get_compute_capability() < 8
+            or hidden_states.device.type == 'cpu'
+            or device_manager.gpu_vendor != GPUVendor.NVIDIA):
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
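The net effect of this second hunk is a three-way dispatch: the new XPU path takes priority whenever an Intel GPU is visible; the existing forward_windows fallback keeps covering Windows, CPU, pre-Ampere NVIDIA (compute capability below 8), and non-NVIDIA devices; everything else stays on the unchanged default path. A hedged restatement as a standalone function, where the function name, its inputs, and the "default" label are illustrative rather than the repo's API:

import os
import torch

def pick_attention_path(device_type: str, compute_capability: int, vendor: str) -> str:
    # Mirrors the dispatch order of the patched forward() above; the arguments
    # stand in for hidden_states.device.type, get_compute_capability(), and
    # device_manager.gpu_vendor.
    if getattr(torch, "xpu", None) is not None and torch.xpu.is_available():
        return "forward_xpu"
    if (os.name == "nt"
            or compute_capability < 8
            or device_type == "cpu"
            or vendor != "NVIDIA"):
        return "forward_windows"
    return "default"  # the pre-existing CUDA path, unchanged by this commit

# "forward_xpu" if an Intel GPU is visible on this host, else "default" here:
print(pick_attention_path("cuda", 9, "NVIDIA"))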