mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-05 20:19:51 +00:00
linux support triton MLA kernel
This commit is contained in:
parent
bb35dc5b0d
commit
1084d4e4b4
2 changed files with 198 additions and 61 deletions
|
@ -53,8 +53,9 @@ class StaticCache(transformers.StaticCache):
|
|||
# TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
|
||||
self.page_size = 64
|
||||
self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
|
||||
key_shape = (self.max_pages, self.page_size, 1, config.qk_rope_head_dim)
|
||||
value_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank)
|
||||
latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
|
||||
self.kv_lora_rank = config.kv_lora_rank
|
||||
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||
# TODO: support real page table
|
||||
self.page_table_map = dict()
|
||||
self.page_table_list = []
|
||||
|
@ -88,10 +89,17 @@ class StaticCache(transformers.StaticCache):
|
|||
target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
|
||||
else:
|
||||
target_device = device
|
||||
new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)
|
||||
new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)
|
||||
torch._dynamo.mark_static_address(new_layer_key_cache)
|
||||
torch._dynamo.mark_static_address(new_layer_value_cache)
|
||||
|
||||
if self.is_MLA:
|
||||
new_layer_key_cache = torch.zeros(latent_shape, dtype=self.dtype, device=target_device)
|
||||
new_layer_value_cache = None
|
||||
torch._dynamo.mark_static_address(new_layer_key_cache)
|
||||
else:
|
||||
new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)
|
||||
new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)
|
||||
torch._dynamo.mark_static_address(new_layer_key_cache)
|
||||
torch._dynamo.mark_static_address(new_layer_value_cache)
|
||||
|
||||
self.key_cache.append(new_layer_key_cache)
|
||||
self.value_cache.append(new_layer_value_cache)
|
||||
self.past_tokens.append(0)
|
||||
|
@ -129,11 +137,12 @@ class StaticCache(transformers.StaticCache):
|
|||
if self.is_MLA:
|
||||
page_idx = cache_position // self.page_size
|
||||
page_offset = cache_position % self.page_size
|
||||
# key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
|
||||
#print("page_idx", page_idx)
|
||||
#print("page_offset", page_offset)
|
||||
k_out[page_idx, page_offset, ...] = key_states
|
||||
v_out[page_idx, page_offset, ...] = value_states
|
||||
return k_out, v_out, self.page_table_list[layer_idx]
|
||||
k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
|
||||
k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
|
||||
return k_out, self.page_table_list[layer_idx]
|
||||
else:
|
||||
k_out[:, :, cache_position] = key_states
|
||||
v_out[:, :, cache_position] = value_states
|
||||
|
|
|
@ -13,8 +13,6 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
|
|||
from ktransformers.models.configuration_llama import LlamaConfig
|
||||
from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
|
||||
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
|
||||
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Attention
|
||||
from ktransformers.models.modeling_deepseek_v3 import apply_rotary_pos_emb as apply_rotary_pos_emb_v3
|
||||
from typing import Optional, Tuple
|
||||
from ktransformers.operators.base_operator import BaseInjectedModule
|
||||
from ktransformers.util.custom_gguf import GGUFLoader
|
||||
|
@ -23,8 +21,15 @@ from transformers.configuration_utils import PretrainedConfig
|
|||
from transformers.cache_utils import Cache
|
||||
from flash_attn import flash_attn_with_kvcache, flash_attn_func
|
||||
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
|
||||
import os
|
||||
logger = logging.getLogger("attention")
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
# V3 MLA is same to V2
|
||||
class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
||||
|
@ -80,6 +85,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
q_nope, q_pe = torch.split(
|
||||
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
# q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
|
||||
# q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
|
||||
|
||||
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
||||
compressed_kv, k_pe = torch.split(
|
||||
|
@ -103,16 +110,37 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
compressed_kv = compressed_kv.unsqueeze(1)
|
||||
k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
|
||||
compressed_kv = compressed_kv.squeeze(1)
|
||||
#if cache_position is not None:
|
||||
# compressed_kv = compressed_kv[:,: cache_position[-1] + 1,:]
|
||||
# k_pe = k_pe[:,:,: cache_position[-1] + 1,:]
|
||||
|
||||
# compressed_kv [bsz, q_len, self.kv_lora_rank]
|
||||
# k_pe [bsz, 1, q_len, self.qk_rope_head_dim]
|
||||
k_pe = k_pe.transpose(1,2)
|
||||
compressed_kv = compressed_kv.unsqueeze(2)
|
||||
compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
|
||||
compressed_kv, k_pe = torch.split(
|
||||
compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
# k_pe [pages, page_size, 1, self.qk_rope_head_dim]
|
||||
# compressed_kv [pages, page_size, 1, self.kv_lora_rank]
|
||||
|
||||
q_absorb, out_absorb = self.get_absorbed()
|
||||
if hasattr(self.orig_module, 'kv_b_proj'):
|
||||
del self.orig_module.kv_b_proj
|
||||
|
||||
# q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
|
||||
# q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
|
||||
k_pe = k_pe.view(bsz, 1, -1, self.qk_rope_head_dim)[:,:,:attention_mask.size(-1),:]
|
||||
compressed_kv = compressed_kv.view(bsz, 1, -1, self.kv_lora_rank)[:,:,:attention_mask.size(-1),:]
|
||||
# k_pe [bsz, 1, cache_len, self.qk_rope_head_dim]
|
||||
# compressed_kv [bsz, 1, cache_len,self.kv_lora_rank]
|
||||
q_nope = torch.matmul(q_nope, q_absorb)
|
||||
attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale
|
||||
#print(q_pe.shape)
|
||||
#print(k_pe.shape)
|
||||
#print(q_nope.shape)
|
||||
#print(compressed_kv.shape)
|
||||
|
||||
attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale
|
||||
#attn_weights [bsz, self.num_heads, q_len, kv_seq_len]
|
||||
compressed_kv = compressed_kv.squeeze(1)
|
||||
"""
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
|
@ -156,25 +184,25 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
|
||||
def forward_linux(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
q = self.q_proj(hidden_states)
|
||||
else:
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
|
||||
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
|
||||
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim)
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
|
@ -184,38 +212,42 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
compressed_kv = self.kv_a_layernorm(compressed_kv)
|
||||
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
|
||||
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
|
||||
compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)
|
||||
|
||||
cos, sin = self.rotary_emb(q_pe, position_ids)
|
||||
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
|
||||
k_pe = k_pe.transpose(1, 2) # [bsz, q_len, 1, self.qk_rope_head_dim]
|
||||
|
||||
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
|
||||
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
|
||||
|
||||
# decode
|
||||
if q_len == 1:
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
k_pe, compressed_kv, page_table = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
|
||||
|
||||
# q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
|
||||
compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
|
||||
compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank] # for speed
|
||||
# compressed_kv_with_k_pe [bsz, q_len, 1, self.kv_lora_rank + self.qk_rope_head_dim]
|
||||
# compressed_kv [bsz, q_len, 1, self.kv_lora_rank]
|
||||
|
||||
# q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim]
|
||||
# q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank]
|
||||
q_absorb, out_absorb = self.get_absorbed()
|
||||
q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
|
||||
q_nope = torch.matmul(q_nope, q_absorb) # batched MM
|
||||
# q_nope [bsz, self.num_heads, q_len, self.kv_lora_rank]
|
||||
# q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
|
||||
query_states = torch.cat([q_nope, q_pe], dim=-1)
|
||||
# k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
|
||||
# compressed_kv [bsz, q_len, 1, self.kv_lora_rank]
|
||||
key_states = torch.cat([compressed_kv, k_pe], dim=-1)
|
||||
q_nope = q_nope.transpose(1, 2)
|
||||
assert q_nope.is_contiguous()
|
||||
|
||||
query_states = query_states.squeeze(2)
|
||||
attn_output = torch.zeros_like(q_nope)
|
||||
# q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
|
||||
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
|
||||
query_states = torch.cat([q_nope, q_pe], dim=-1)
|
||||
|
||||
query_states = query_states.squeeze(1)
|
||||
attn_output = torch.zeros_like(q_nope) # [bsz, q_len, self.num_heads, self.kv_lora_rank]
|
||||
|
||||
attn_logits = torch.empty(
|
||||
(
|
||||
bsz,
|
||||
self.num_heads,
|
||||
1, #num_kv_splits # follow vLLM, fix it TODO
|
||||
4, #num_kv_splits # follow vLLM, fix it TODO
|
||||
self.kv_lora_rank + 1,
|
||||
),
|
||||
dtype=torch.float32,
|
||||
|
@ -224,22 +256,25 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
|
||||
"""
|
||||
print("query_states", torch.isnan(query_states).any())
|
||||
print("key_states", torch.isnan(key_states[:,:,0,:]).any())
|
||||
print("compressed_kv_with_k_pe", torch.isnan(compressed_kv_with_k_pe[:,:,0,:]).any())
|
||||
print("compressed_kv", torch.isnan(compressed_kv[:,:,0,:]).any())
|
||||
print("position_ids", torch.isnan(position_ids).any())
|
||||
"""
|
||||
|
||||
# flash attn doesn't support head_dim bigger than 256
|
||||
# flash attn doesn't support head_dim bigger than 256
|
||||
# use vLLM triton attention kernel for MQA
|
||||
decode_attention_fwd_grouped(query_states, key_states, compressed_kv, attn_output,
|
||||
decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
|
||||
page_table,
|
||||
position_ids.squeeze(0).to(torch.int32), attn_logits,
|
||||
1, #num_kv_splits # follow vLLM, fix it TODO
|
||||
4, #num_kv_splits # follow vLLM, fix it TODO
|
||||
self.softmax_scale,
|
||||
past_key_value.page_size)
|
||||
|
||||
attn_output = torch.matmul(attn_output, out_absorb.mT)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
# attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
|
||||
# out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = torch.matmul(attn_output, out_absorb.mT)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
|
@ -250,7 +285,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
k_pe.squeeze(0)
|
||||
compressed_kv.squeeze(0)
|
||||
past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
|
||||
past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
|
||||
k_pe.unsqueeze(0)
|
||||
compressed_kv.unsqueeze(0)
|
||||
|
||||
|
@ -261,7 +296,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
)
|
||||
k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
|
||||
query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
|
||||
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
|
||||
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
|
||||
|
||||
|
@ -269,7 +304,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
|
||||
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
|
||||
value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
|
||||
|
||||
|
@ -289,12 +323,106 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
def forward_windows(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
if q_len <= self.chunck_size:
|
||||
return self.forward_chunck(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
assert output_attentions == False, "output_attentions is not supported when using chunked attention"
|
||||
attn_output = None
|
||||
cur_idx = 0
|
||||
while cur_idx < q_len:
|
||||
if attention_mask is not None:
|
||||
chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...]
|
||||
else:
|
||||
# generate chunk_mask automatically.
|
||||
self.attn_mask = \
|
||||
torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \
|
||||
if self.attn_mask is None \
|
||||
else self.attn_mask
|
||||
self.attn_mask[:, :, :, cur_idx:min(cur_idx+self.chunck_size, past_key_value.max_cache_len)] = \
|
||||
-1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1)\
|
||||
[:,:min(self.chunck_size, min(past_key_value.max_cache_len-cur_idx, self.chunck_size))]
|
||||
self.attn_mask[:, :, :, cur_idx+self.chunck_size:] = -1e+38
|
||||
self.attn_mask[:, :, :, :cur_idx] = 0
|
||||
chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len-cur_idx))
|
||||
|
||||
cur_output, _, _ = self.forward_chunck(
|
||||
hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...],
|
||||
chunk_mask,
|
||||
position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)],
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)],
|
||||
**kwargs
|
||||
)
|
||||
cur_idx += self.chunck_size
|
||||
if attn_output is None:
|
||||
attn_output = cur_output
|
||||
else:
|
||||
attn_output = torch.cat((attn_output, cur_output), dim=-2)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if os.name == 'nt':
|
||||
return self.forward_windows(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
return self.forward_linux(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class KLlamaAttention(BaseInjectedModule):
|
||||
|
|
Loading…
Add table
Reference in a new issue