Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-09 13:55:27 +00:00)
[feature] release 0.1.3
This commit is contained in:
parent 67f8b370c3
commit 4d1d561d28
58 changed files with 11709 additions and 374 deletions
@@ -7,16 +7,22 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import torch
from torch import nn
import warnings
import torch.nn.functional as F
from ktransformers.operators.models import KLlamaModel
from ktransformers.models.configuration_deepseek import DeepseekV2Config
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache

logger = logging.getLogger("attention")

class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
    """Multi-headed attention from 'Attention Is All You Need' paper"""
    attn_mask: Optional[torch.Tensor] = None

    def __init__(self,
                 key: str,
@@ -24,10 +30,12 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 device: str = "cuda",
                 chunck_size: int = 1000,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.layer_idx)
        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.

    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
@@ -157,9 +165,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()
-       chunck_size = 256 # TODO, generate chunck_size automatically.

-       if q_len <= chunck_size:
+       if q_len <= self.chunck_size:
            return self.forward_chunck(
                hidden_states,
                attention_mask,
@@ -176,24 +183,170 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
        cur_idx = 0
        while cur_idx < q_len:
            if attention_mask is not None:
-               chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + chunck_size, q_len), ...]
+               chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...]
            else:
                chunk_mask = None
                # generate chunk_mask automatically.
                self.attn_mask = \
                    torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \
                    if self.attn_mask is None \
                    else self.attn_mask
                self.attn_mask[:, :, :, cur_idx:min(cur_idx+self.chunck_size, past_key_value.max_cache_len)] = \
                    -1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1)\
                    [:,:min(self.chunck_size, min(past_key_value.max_cache_len-cur_idx, self.chunck_size))]
                self.attn_mask[:, :, :, cur_idx+self.chunck_size:] = -1e+38
                self.attn_mask[:, :, :, :cur_idx] = 0
                chunck_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len-cur_idx))

            cur_output, _, _ = self.forward_chunck(
-               hidden_states[:, cur_idx:min(cur_idx + chunck_size, q_len), ...],
-               chunk_mask,
-               position_ids[:, cur_idx:min(cur_idx + chunck_size, q_len)],
+               hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...],
+               chunck_mask,
+               position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)],
                past_key_value,
                output_attentions,
                use_cache,
-               cache_position[cur_idx:min(cur_idx + chunck_size, q_len)],
+               cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)],
                **kwargs
            )
-           cur_idx += chunck_size
+           cur_idx += self.chunck_size
            if attn_output is None:
                attn_output = cur_output
            else:
                attn_output = torch.cat((attn_output, cur_output), dim=-2)

        return attn_output, None, past_key_value
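
# A minimal, self-contained sketch of the chunked-prefill pattern implemented by
# forward() above: the query sequence is processed in fixed-size chunks against all
# keys/values seen so far, and the per-chunk outputs are concatenated along the
# sequence dimension. `toy_chunked_attention`, its single-head [bsz, seq_len, head_dim]
# shapes, and the fixed chunk_size are illustrative assumptions, not ktransformers API.
def toy_chunked_attention(q, k, v, chunk_size=256):
    outputs = []
    cur_idx = 0
    q_len = q.size(1)
    while cur_idx < q_len:
        end = min(cur_idx + chunk_size, q_len)
        # Each chunk row at global position cur_idx + i may attend to positions 0..cur_idx + i.
        scores = q[:, cur_idx:end] @ k[:, :end].transpose(-1, -2) / (q.size(-1) ** 0.5)
        causal = torch.triu(
            torch.full((end - cur_idx, end), float("-inf"), device=q.device),
            diagonal=cur_idx + 1,
        )
        outputs.append(torch.softmax(scores + causal, dim=-1) @ v[:, :end])
        cur_idx = end
    # Concatenating the chunk outputs reproduces full causal attention over the sequence.
    return torch.cat(outputs, dim=1)
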
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


class KLlamaAttention(BaseInjectedModule):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.layer_idx)

    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        """Applies Rotary Position Embedding to the query and key tensors.

        Args:
            q (`torch.Tensor`): The query tensor.
            k (`torch.Tensor`): The key tensor.
            cos (`torch.Tensor`): The cosine part of the rotary embedding.
            sin (`torch.Tensor`): The sine part of the rotary embedding.
            position_ids (`torch.Tensor`, *optional*):
                Deprecated and unused.
            unsqueeze_dim (`int`, *optional*, defaults to 1):
                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
        Returns:
            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
        """
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed
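
# A small standalone check of the RoPE helpers above: with q/k of shape
# [batch_size, heads, seq_len, head_dim] and cos/sin of shape [batch_size, seq_len, head_dim],
# unsqueezing dim 1 broadcasts cos/sin across heads, exactly as the docstring describes for
# unsqueeze_dim=1. The toy sizes and the cos/sin construction here are illustrative
# assumptions for the sketch, not values produced by ktransformers.
def _rope_shape_check():
    bsz, heads, seq_len, head_dim = 1, 2, 5, 8
    q = torch.randn(bsz, heads, seq_len, head_dim)
    k = torch.randn(bsz, heads, seq_len, head_dim)
    # Standard RoPE angles: one frequency per pair of dims, duplicated across both halves.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)   # [seq_len, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)                        # [seq_len, head_dim]
    cos, sin = emb.cos()[None], emb.sin()[None]                    # [1, seq_len, head_dim]
    # Mirrors apply_rotary_pos_emb(..., unsqueeze_dim=1).
    q_embed = (q * cos.unsqueeze(1)) + (rotate_half(q) * sin.unsqueeze(1))
    k_embed = (k * cos.unsqueeze(1)) + (rotate_half(k) * sin.unsqueeze(1))
    assert q_embed.shape == q.shape and k_embed.shape == k.shape
    # Each (x_i, x_{i + head_dim // 2}) pair is rotated, so per-token norms are preserved.
    assert torch.allclose(q_embed.norm(dim=-1), q.norm(dim=-1), atol=1e-5)
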
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin)
        if q_len == 1:
            position_ids = position_ids[0][-1].unsqueeze(0).unsqueeze(0)
            query_states = query_states[:, :, -1:]
            key_states = key_states[:, :, -1:]

        attn_output = KLlamaModel.dynamic_sdpa.apply(
            self.layer_idx,
            bsz,
            position_ids[0][0],
            query_states.transpose(1, 2).to(torch.float16),
            key_states.transpose(1, 2).to(torch.float16),
            value_states.transpose(1, 2).to(torch.float16),
            mode="prefill" if q_len > 1 else "generate",
        )

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, -1)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
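
# A minimal standalone check of the pretraining_tp weight-slicing pattern used in
# forward() above: splitting a projection's weight along the output dimension,
# applying each slice with F.linear, and concatenating the per-slice outputs
# reproduces the full projection. The sizes and tp=2 below are illustrative
# assumptions for the sketch, not values taken from any model config.
def _tp_slicing_check(hidden_size: int = 16, num_heads: int = 4, tp: int = 2):
    head_dim = hidden_size // num_heads
    q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
    hidden_states = torch.randn(1, 3, hidden_size)
    # Slice the weight into `tp` chunks along dim 0 (the output features).
    query_slices = q_proj.weight.split((num_heads * head_dim) // tp, dim=0)
    sliced = torch.cat([F.linear(hidden_states, query_slices[i]) for i in range(tp)], dim=-1)
    # Same result as applying the unsliced projection directly.
    assert torch.allclose(sliced, q_proj(hidden_states), atol=1e-6)
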