From bb35dc5b0dd11a3d580207f19f5a749c380452c8 Mon Sep 17 00:00:00 2001 From: Atream Date: Thu, 13 Feb 2025 15:01:14 +0000 Subject: [PATCH 1/2] init support for MLA using Attention kernel --- ktransformers/models/custom_cache.py | 45 ++- ktransformers/operators/attention.py | 376 +++++++------------ ktransformers/operators/triton_attention.py | 379 ++++++++++++++++++++ ktransformers/util/utils.py | 4 +- test_prompt.txt | 9 + 5 files changed, 551 insertions(+), 262 deletions(-) create mode 100644 ktransformers/operators/triton_attention.py create mode 100644 test_prompt.txt diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py index e402506..0849d55 100644 --- a/ktransformers/models/custom_cache.py +++ b/ktransformers/models/custom_cache.py @@ -51,13 +51,33 @@ class StaticCache(transformers.StaticCache): cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM": # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically - # key_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.qk_rope_head_dim + config.qk_nope_head_dim) - # value_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.v_head_dim) - key_shape = (max_batch_size, 1, self.max_cache_len, config.qk_rope_head_dim) - value_shape = (max_batch_size, 1, self.max_cache_len, config.kv_lora_rank) + self.page_size = 64 + self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size + key_shape = (self.max_pages, self.page_size, 1, config.qk_rope_head_dim) + value_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank) + # TODO: support real page table + self.page_table_map = dict() + self.page_table_list = [] + for idx in range(config.num_hidden_layers): + if isinstance(device, dict): + target_device = device[f"blk.{idx}.self_attn"]["generate_device"] + else: + target_device = device + + if target_device not in self.page_table_map: + page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device) + for seq_id in range(max_batch_size): + page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device) + self.page_table_map[target_device] = page_table + + self.page_table_list.append(self.page_table_map[target_device]) + + self.is_MLA = True + self.is_page = True else: key_shape = cache_shape value_shape = cache_shape + self.is_MLA = False self.past_tokens = [] self.num_hidden_layers = config.num_hidden_layers @@ -104,11 +124,20 @@ class StaticCache(transformers.StaticCache): cache_position = cache_kwargs.get("cache_position") k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] - #print(cache_position) - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states self.past_tokens[layer_idx] += cache_position.size(0) - return k_out, v_out + #print(cache_position) + if self.is_MLA: + page_idx = cache_position // self.page_size + page_offset = cache_position % self.page_size + #print("page_idx", page_idx) + #print("page_offset", page_offset) + k_out[page_idx, page_offset, ...] = key_states + v_out[page_idx, page_offset, ...] = value_states + return k_out, v_out, self.page_table_list[layer_idx] + else: + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + return k_out, v_out def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states that were seen by the model.""" diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py index 9b47b89..f487773 100644 --- a/ktransformers/operators/attention.py +++ b/ktransformers/operators/attention.py @@ -21,207 +21,12 @@ from ktransformers.util.custom_gguf import GGUFLoader import logging from transformers.configuration_utils import PretrainedConfig from transformers.cache_utils import Cache +from flash_attn import flash_attn_with_kvcache, flash_attn_func +from ktransformers.operators.triton_attention import decode_attention_fwd_grouped logger = logging.getLogger("attention") -class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention): - """Multi-headed attention from 'Attention Is All You Need' paper""" - attn_mask: Optional[torch.Tensor] = None - - def __init__(self, - key: str, - gguf_loader : GGUFLoader, - config: PretrainedConfig, - orig_module: nn.Module, - device: str = "cuda", - chunck_size: int = 1000, - **kwargs): - BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) - self.orig_module.__init__(orig_module.config, - orig_module.layer_idx) - self.chunck_size = chunck_size # TODO, generate chunck_size automatically. - self.softmax_scale = self.q_head_dim ** (-0.5) - - def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]: - if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')): - kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) - q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank) - out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank) - self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, - bias=False, dtype=q_absorb.dtype, device=q_absorb.device) - self.q_absorb.weight.data = q_absorb - self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim, - bias=False, dtype=out_absorb.dtype, device=out_absorb.device) - self.out_absorb.weight.data = out_absorb - del self.orig_module.kv_b_proj - q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank) - out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank) - return q_absorb, out_absorb - - def forward_chunck( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - if self.q_lora_rank is None: - q = self.q_proj(hidden_states) - else: - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - q_nope, q_pe = torch.split( - q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 - ) - - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv, k_pe = torch.split( - compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 - ) - compressed_kv = self.kv_a_layernorm(compressed_kv) - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) - - kv_seq_len = k_pe.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - cos, sin = self.rotary_emb(q_pe, position_ids) - q_pe, k_pe = apply_rotary_pos_emb_v3(q_pe, k_pe, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - compressed_kv = compressed_kv.unsqueeze(1) - k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs) - compressed_kv = compressed_kv.squeeze(1) - #if cache_position is not None: - # compressed_kv = compressed_kv[:,: cache_position[-1] + 1,:] - # k_pe = k_pe[:,:,: cache_position[-1] + 1,:] - q_absorb, out_absorb = self.get_absorbed() - - q_nope = torch.matmul(q_nope, q_absorb) - attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale - """ - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - assert attention_mask is not None - """ - if attention_mask is not None: - """ - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - """ - #causal_mask = attention_mask[:, :, :, : kv_seq_len] - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(q_pe.dtype) - attn_weights = nn.functional.dropout( - attn_weights, p=self.attention_dropout, training=self.training - ) - attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv) - - attn_output = torch.matmul(attn_output, out_absorb.mT) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) - - attn_output = self.o_proj(attn_output) - - return attn_output, attn_weights, past_key_value - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - - if q_len <= self.chunck_size: - return self.forward_chunck( - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - use_cache, - cache_position, - **kwargs - ) - - assert output_attentions == False, "output_attentions is not supported when using chunked attention" - attn_output = None - attn_weight = None - cur_idx = 0 - while cur_idx < q_len: - if attention_mask is not None: - chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...] - else: - # generate chunk_mask automatically. - self.attn_mask = \ - torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \ - if self.attn_mask is None \ - else self.attn_mask - self.attn_mask[:, :, :, cur_idx:min(cur_idx+self.chunck_size, past_key_value.max_cache_len)] = \ - -1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1)\ - [:,:min(self.chunck_size, min(past_key_value.max_cache_len-cur_idx, self.chunck_size))] - self.attn_mask[:, :, :, cur_idx+self.chunck_size:] = -1e+38 - self.attn_mask[:, :, :, :cur_idx] = 0 - chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len-cur_idx)) - - cur_output, cur_attn_weight = self.forward_chunck( - hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...], - chunk_mask, - position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)], - past_key_value, - output_attentions, - use_cache, - cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)], - **kwargs - ) - cur_idx += self.chunck_size - if attn_output is None: - attn_output = cur_output - attn_weight = cur_attn_weight - else: - attn_output = torch.cat((attn_output, cur_output), dim=-2) - attn_weight = torch.cat((attn_weight, cur_attn_weight), dim=-2) - - return attn_output, attn_weight, past_key_value +# V3 MLA is same to V2 class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): """Multi-headed attention from 'Attention Is All You Need' paper""" attn_mask: Optional[torch.Tensor] = None @@ -250,7 +55,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim, bias=False, dtype=out_absorb.dtype, device=out_absorb.device) self.out_absorb.weight.data = out_absorb - del self.orig_module.kv_b_proj + #del self.orig_module.kv_b_proj q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank) out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank) return q_absorb, out_absorb @@ -362,60 +167,129 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - if q_len <= self.chunck_size: - return self.forward_chunck( - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - use_cache, - cache_position, - **kwargs - ) + bsz, q_len, _ = hidden_states.size() - assert output_attentions == False, "output_attentions is not supported when using chunked attention" - attn_output = None - cur_idx = 0 - while cur_idx < q_len: - if attention_mask is not None: - chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...] - else: - # generate chunk_mask automatically. - self.attn_mask = \ - torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \ - if self.attn_mask is None \ - else self.attn_mask - self.attn_mask[:, :, :, cur_idx:min(cur_idx+self.chunck_size, past_key_value.max_cache_len)] = \ - -1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1)\ - [:,:min(self.chunck_size, min(past_key_value.max_cache_len-cur_idx, self.chunck_size))] - self.attn_mask[:, :, :, cur_idx+self.chunck_size:] = -1e+38 - self.attn_mask[:, :, :, :cur_idx] = 0 - chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len-cur_idx)) + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + compressed_kv = self.kv_a_layernorm(compressed_kv) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank) + + cos, sin = self.rotary_emb(q_pe, position_ids) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin) + k_pe = k_pe.transpose(1, 2) # [bsz, q_len, 1, self.qk_rope_head_dim] + + # decode + if q_len == 1: + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + k_pe, compressed_kv, page_table = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs) + + # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim] + # q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank] + q_absorb, out_absorb = self.get_absorbed() + q_nope = torch.matmul(q_nope, q_absorb) # batched MM + # q_nope [bsz, self.num_heads, q_len, self.kv_lora_rank] + # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim] + query_states = torch.cat([q_nope, q_pe], dim=-1) + # k_pe [bsz, q_len, 1, self.qk_rope_head_dim] + # compressed_kv [bsz, q_len, 1, self.kv_lora_rank] + key_states = torch.cat([compressed_kv, k_pe], dim=-1) + + query_states = query_states.squeeze(2) + attn_output = torch.zeros_like(q_nope) + + attn_logits = torch.empty( + ( + bsz, + self.num_heads, + 1, #num_kv_splits # follow vLLM, fix it TODO + self.kv_lora_rank + 1, + ), + dtype=torch.float32, + device = attn_output.device + ) + + """ + print("query_states", torch.isnan(query_states).any()) + print("key_states", torch.isnan(key_states[:,:,0,:]).any()) + print("compressed_kv", torch.isnan(compressed_kv[:,:,0,:]).any()) + print("position_ids", torch.isnan(position_ids).any()) + """ + + # flash attn doesn't support head_dim bigger than 256 + # use vLLM triton attention kernel for MQA + decode_attention_fwd_grouped(query_states, key_states, compressed_kv, attn_output, + page_table, + position_ids.squeeze(0).to(torch.int32), attn_logits, + 1, #num_kv_splits # follow vLLM, fix it TODO + self.softmax_scale, + past_key_value.page_size) + + attn_output = torch.matmul(attn_output, out_absorb.mT) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = self.o_proj(attn_output) + + #print("attn_output", torch.isnan(attn_output).any()) + return attn_output, None, past_key_value + else: + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + k_pe.squeeze(0) + compressed_kv.squeeze(0) + past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs) + k_pe.unsqueeze(0) + compressed_kv.unsqueeze(0) + + k_pe = k_pe[:, :q_len] + compressed_kv = compressed_kv[:, :q_len] + kv = ( + self.kv_b_proj(compressed_kv) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + ) + k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim) + key_states[:, :, :, :self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim:] = k_pe + + query_states = query_states.transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim) + value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0) + + attn_output = flash_attn_func( + query_states, + key_states, + value_states_padded, + softmax_scale=self.softmax_scale, + causal=True, + ) + + if self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape( + bsz, q_len, self.num_heads * self.v_head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, None, past_key_value - cur_output, _, _ = self.forward_chunck( - hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...], - chunk_mask, - position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)], - past_key_value, - output_attentions, - use_cache, - cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)], - **kwargs - ) - cur_idx += self.chunck_size - if attn_output is None: - attn_output = cur_output - else: - attn_output = torch.cat((attn_output, cur_output), dim=-2) - - return attn_output, None, past_key_value def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -423,8 +297,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) - - class KLlamaAttention(BaseInjectedModule): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/ktransformers/operators/triton_attention.py b/ktransformers/operators/triton_attention.py new file mode 100644 index 0000000..d622be1 --- /dev/null +++ b/ktransformers/operators/triton_attention.py @@ -0,0 +1,379 @@ +import triton +import triton.language as tl + +@triton.jit +def tanh(x): + # Tanh is just a scaled sigmoid + return 2 * tl.sigmoid(2 * x) - 1 + +@triton.jit +def _fwd_grouped_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + Att_Out, + stride_req_to_tokens_b, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + q_head_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head_id = tl.program_id(1) + cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) + split_kv_id = tl.program_id(2) + + if kv_group_num > BLOCK_H: + VALID_BLOCK_H: tl.constexpr = BLOCK_H + else: + VALID_BLOCK_H: tl.constexpr = kv_group_num + cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H + mask_h = mask_h & (cur_head < q_head_num) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = cur_batch + + offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[ + None, :] + q = tl.load(Q + offs_q, + mask=(mask_h[:, None]) & (mask_d[None, :]), + other=0.0) + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + mask_dpe = offs_dpe < Lk + off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + + offs_dpe[None, :]) + qpe = tl.load(Q + off_qpe, + mask=(mask_h[:, None]) & (mask_dpe[None, :]), + other=0.0) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, + cur_batch_seq_len) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_buf_k = (kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + offs_d[:, None]) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), + other=0.0, + ) + qk = tl.dot(q, k.to(q.dtype)) + + if BLOCK_DPE > 0: + offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None]) + kpe = tl.load( + K_Buffer + offs_buf_kpe, + mask=(offs_n[None, :] < split_kv_end) & + (mask_dpe[:, None]), + other=0.0, + ) + qk += tl.dot(qpe, kpe.to(qpe.dtype)) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), + qk, float("-inf")) + + offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + offs_dv[None, :]) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v.dtype), v) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + offs_mid_o = (cur_batch * stride_mid_ob + + cur_head[:, None] * stride_mid_oh + + split_kv_id * stride_mid_os + offs_dv[None, :]) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum[:, None], + mask=(mask_h[:, None]) & (mask_dv[None, :]), + ) + + offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + Lv) + + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + mask=mask_h, + ) + +def _decode_grouped_att_m_fwd( + q, + k_buffer, + v_buffer, + att_out, + Req_to_tokens, + B_Seqlen, + num_kv_splits, + sm_scale, + page_size, + logit_cap, +): + BLOCK = 32 + Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + # [TODO] work around shmem limit on MI3xx + + # TODO: support hip + #if is_hip_ and Lk >= 576: + # BLOCK = 16 + + if Lk == 576: + BLOCK_DMODEL = 512 + BLOCK_DPE = 64 + elif Lk == 288: + BLOCK_DMODEL = 256 + BLOCK_DPE = 32 + else: + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DPE = 0 + BLOCK_DV = triton.next_power_of_2(Lv) + + batch, head_num = q.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k_buffer.shape[-2] + + BLOCK_H = 16 + NUM_KV_SPLITS = num_kv_splits + grid = ( + batch, + triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), + NUM_KV_SPLITS, + ) + + extra_kargs = {} + # TODO: support hip + """ + if is_hip_: + # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html + # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py + extra_kargs = { + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } + """ + + _fwd_grouped_kernel_stage1[grid]( + q, + k_buffer, + v_buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + att_out, + Req_to_tokens.stride(0), + q.stride(0), + q.stride(1), + k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + q_head_num=head_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DPE=BLOCK_DPE, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK, + BLOCK_H=BLOCK_H, + NUM_KV_SPLITS=NUM_KV_SPLITS, + PAGE_SIZE=page_size, + logit_cap=logit_cap, + num_warps=4, + num_stages=2, + Lk=Lk, + Lv=Lv, + **extra_kargs, + ) + +@triton.jit +def _fwd_kernel_stage2( + Mid_O, + o, + B_Seqlen, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + NUM_KV_SPLITS: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, + cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, + mask=mask_d, + other=0.0) + tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + o + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + +def _decode_softmax_reducev_fwd( + logits, + q, + o, + v_buffer, + b_seq_len, + num_kv_splits, +): + batch, head_num = q.shape[0], q.shape[1] + Lv = v_buffer.shape[-1] + BLOCK_DV = triton.next_power_of_2(Lv) + + NUM_KV_SPLITS = num_kv_splits + + extra_kargs = {} + # TODO: support hip + """ + if is_hip_: + # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html + # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py + extra_kargs = { + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } + """ + + grid = (batch, head_num) + _fwd_kernel_stage2[grid]( + logits, + o, + b_seq_len, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + NUM_KV_SPLITS=NUM_KV_SPLITS, + BLOCK_DV=BLOCK_DV, + Lv=Lv, + num_warps=4, + num_stages=2, + **extra_kargs, + ) + +def decode_attention_fwd_grouped( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap=0.0, +): + _decode_grouped_att_m_fwd( + q, + k_buffer, + v_buffer, + attn_logits, + req_to_token, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) + + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, + num_kv_splits) \ No newline at end of file diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py index 88c33fd..b2bc9d6 100644 --- a/ktransformers/util/utils.py +++ b/ktransformers/util/utils.py @@ -133,7 +133,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud ) else: past_key_values = None - cache_position = torch.arange(seq_length, device=torch_device) + cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.long) generated_ids = torch.zeros( batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device ) @@ -178,7 +178,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud generated_ids[:, seq_length] = next_token tokens.append(int(next_token)) inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1) - cache_position = torch.tensor([seq_length], device=torch_device) + cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.long) position_ids = cache_position.unsqueeze(0) seq_length += 1 diff --git a/test_prompt.txt b/test_prompt.txt new file mode 100644 index 0000000..69fd23b --- /dev/null +++ b/test_prompt.txt @@ -0,0 +1,9 @@ +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +阅读以上文字,并概括大意 From 1084d4e4b48570840f0087bb6f99c4c7389b19bd Mon Sep 17 00:00:00 2001 From: Atream Date: Fri, 14 Feb 2025 11:38:55 +0000 Subject: [PATCH 2/2] linux support triton MLA kernel --- ktransformers/models/custom_cache.py | 27 ++-- ktransformers/operators/attention.py | 232 +++++++++++++++++++++------ 2 files changed, 198 insertions(+), 61 deletions(-) diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py index 0849d55..8a11f1f 100644 --- a/ktransformers/models/custom_cache.py +++ b/ktransformers/models/custom_cache.py @@ -53,8 +53,9 @@ class StaticCache(transformers.StaticCache): # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically self.page_size = 64 self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size - key_shape = (self.max_pages, self.page_size, 1, config.qk_rope_head_dim) - value_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank) + latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim) + self.kv_lora_rank = config.kv_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim # TODO: support real page table self.page_table_map = dict() self.page_table_list = [] @@ -88,10 +89,17 @@ class StaticCache(transformers.StaticCache): target_device = device[f"blk.{idx}.self_attn"]["generate_device"] else: target_device = device - new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device) - new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) + + if self.is_MLA: + new_layer_key_cache = torch.zeros(latent_shape, dtype=self.dtype, device=target_device) + new_layer_value_cache = None + torch._dynamo.mark_static_address(new_layer_key_cache) + else: + new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device) + new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device) + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) self.value_cache.append(new_layer_value_cache) self.past_tokens.append(0) @@ -129,11 +137,12 @@ class StaticCache(transformers.StaticCache): if self.is_MLA: page_idx = cache_position // self.page_size page_offset = cache_position % self.page_size + # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim) #print("page_idx", page_idx) #print("page_offset", page_offset) - k_out[page_idx, page_offset, ...] = key_states - v_out[page_idx, page_offset, ...] = value_states - return k_out, v_out, self.page_table_list[layer_idx] + k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states + k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states + return k_out, self.page_table_list[layer_idx] else: k_out[:, :, cache_position] = key_states v_out[:, :, cache_position] = value_states diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py index f487773..9f73d48 100644 --- a/ktransformers/operators/attention.py +++ b/ktransformers/operators/attention.py @@ -13,8 +13,6 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config from ktransformers.models.configuration_llama import LlamaConfig from ktransformers.models.modeling_llama import LlamaRotaryEmbedding from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb -from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Attention -from ktransformers.models.modeling_deepseek_v3 import apply_rotary_pos_emb as apply_rotary_pos_emb_v3 from typing import Optional, Tuple from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_gguf import GGUFLoader @@ -23,8 +21,15 @@ from transformers.configuration_utils import PretrainedConfig from transformers.cache_utils import Cache from flash_attn import flash_attn_with_kvcache, flash_attn_func from ktransformers.operators.triton_attention import decode_attention_fwd_grouped +import os logger = logging.getLogger("attention") +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) # V3 MLA is same to V2 class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): @@ -80,6 +85,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) + # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim] + # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim] compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, k_pe = torch.split( @@ -103,16 +110,37 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - compressed_kv = compressed_kv.unsqueeze(1) - k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs) - compressed_kv = compressed_kv.squeeze(1) - #if cache_position is not None: - # compressed_kv = compressed_kv[:,: cache_position[-1] + 1,:] - # k_pe = k_pe[:,:,: cache_position[-1] + 1,:] + + # compressed_kv [bsz, q_len, self.kv_lora_rank] + # k_pe [bsz, 1, q_len, self.qk_rope_head_dim] + k_pe = k_pe.transpose(1,2) + compressed_kv = compressed_kv.unsqueeze(2) + compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs) + compressed_kv, k_pe = torch.split( + compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + # k_pe [pages, page_size, 1, self.qk_rope_head_dim] + # compressed_kv [pages, page_size, 1, self.kv_lora_rank] + q_absorb, out_absorb = self.get_absorbed() + if hasattr(self.orig_module, 'kv_b_proj'): + del self.orig_module.kv_b_proj + # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim] + # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim] + k_pe = k_pe.view(bsz, 1, -1, self.qk_rope_head_dim)[:,:,:attention_mask.size(-1),:] + compressed_kv = compressed_kv.view(bsz, 1, -1, self.kv_lora_rank)[:,:,:attention_mask.size(-1),:] + # k_pe [bsz, 1, cache_len, self.qk_rope_head_dim] + # compressed_kv [bsz, 1, cache_len,self.kv_lora_rank] q_nope = torch.matmul(q_nope, q_absorb) - attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale + #print(q_pe.shape) + #print(k_pe.shape) + #print(q_nope.shape) + #print(compressed_kv.shape) + + attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale + #attn_weights [bsz, self.num_heads, q_len, kv_seq_len] + compressed_kv = compressed_kv.squeeze(1) """ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( @@ -156,25 +184,25 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): return attn_output, None, past_key_value - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - + def forward_linux( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() if self.q_lora_rank is None: q = self.q_proj(hidden_states) else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim) q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) @@ -184,38 +212,42 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) compressed_kv = self.kv_a_layernorm(compressed_kv) - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim) compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank) cos, sin = self.rotary_emb(q_pe, position_ids) - q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin) - k_pe = k_pe.transpose(1, 2) # [bsz, q_len, 1, self.qk_rope_head_dim] - + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2) + # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim] + # decode if q_len == 1: if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - k_pe, compressed_kv, page_table = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs) - - # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim] + compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs) + compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank] # for speed + # compressed_kv_with_k_pe [bsz, q_len, 1, self.kv_lora_rank + self.qk_rope_head_dim] + # compressed_kv [bsz, q_len, 1, self.kv_lora_rank] + + # q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim] # q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank] q_absorb, out_absorb = self.get_absorbed() + q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below q_nope = torch.matmul(q_nope, q_absorb) # batched MM - # q_nope [bsz, self.num_heads, q_len, self.kv_lora_rank] - # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim] - query_states = torch.cat([q_nope, q_pe], dim=-1) - # k_pe [bsz, q_len, 1, self.qk_rope_head_dim] - # compressed_kv [bsz, q_len, 1, self.kv_lora_rank] - key_states = torch.cat([compressed_kv, k_pe], dim=-1) + q_nope = q_nope.transpose(1, 2) + assert q_nope.is_contiguous() - query_states = query_states.squeeze(2) - attn_output = torch.zeros_like(q_nope) + # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank] + # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] + query_states = torch.cat([q_nope, q_pe], dim=-1) + + query_states = query_states.squeeze(1) + attn_output = torch.zeros_like(q_nope) # [bsz, q_len, self.num_heads, self.kv_lora_rank] attn_logits = torch.empty( ( bsz, self.num_heads, - 1, #num_kv_splits # follow vLLM, fix it TODO + 4, #num_kv_splits # follow vLLM, fix it TODO self.kv_lora_rank + 1, ), dtype=torch.float32, @@ -224,22 +256,25 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): """ print("query_states", torch.isnan(query_states).any()) - print("key_states", torch.isnan(key_states[:,:,0,:]).any()) + print("compressed_kv_with_k_pe", torch.isnan(compressed_kv_with_k_pe[:,:,0,:]).any()) print("compressed_kv", torch.isnan(compressed_kv[:,:,0,:]).any()) print("position_ids", torch.isnan(position_ids).any()) """ - # flash attn doesn't support head_dim bigger than 256 + # flash attn doesn't support head_dim bigger than 256 # use vLLM triton attention kernel for MQA - decode_attention_fwd_grouped(query_states, key_states, compressed_kv, attn_output, + decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output, page_table, position_ids.squeeze(0).to(torch.int32), attn_logits, - 1, #num_kv_splits # follow vLLM, fix it TODO + 4, #num_kv_splits # follow vLLM, fix it TODO self.softmax_scale, past_key_value.page_size) - attn_output = torch.matmul(attn_output, out_absorb.mT) - attn_output = attn_output.transpose(1, 2).contiguous() + # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank] + # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank] + attn_output = attn_output.transpose(1, 2) + attn_output = torch.matmul(attn_output, out_absorb.mT) + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) @@ -250,7 +285,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models k_pe.squeeze(0) compressed_kv.squeeze(0) - past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs) + past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs) k_pe.unsqueeze(0) compressed_kv.unsqueeze(0) @@ -261,7 +296,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) ) k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe @@ -269,7 +304,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): key_states[:, :, :, :self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim:] = k_pe - query_states = query_states.transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim) value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0) @@ -289,12 +323,106 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): ).contiguous() attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value + + def forward_windows( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) + if q_len <= self.chunck_size: + return self.forward_chunck( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + **kwargs + ) + + assert output_attentions == False, "output_attentions is not supported when using chunked attention" + attn_output = None + cur_idx = 0 + while cur_idx < q_len: + if attention_mask is not None: + chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...] + else: + # generate chunk_mask automatically. + self.attn_mask = \ + torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \ + if self.attn_mask is None \ + else self.attn_mask + self.attn_mask[:, :, :, cur_idx:min(cur_idx+self.chunck_size, past_key_value.max_cache_len)] = \ + -1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1)\ + [:,:min(self.chunck_size, min(past_key_value.max_cache_len-cur_idx, self.chunck_size))] + self.attn_mask[:, :, :, cur_idx+self.chunck_size:] = -1e+38 + self.attn_mask[:, :, :, :cur_idx] = 0 + chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len-cur_idx)) + + cur_output, _, _ = self.forward_chunck( + hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...], + chunk_mask, + position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)], + past_key_value, + output_attentions, + use_cache, + cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)], + **kwargs + ) + cur_idx += self.chunck_size + if attn_output is None: + attn_output = cur_output + else: + attn_output = torch.cat((attn_output, cur_output), dim=-2) + + return attn_output, None, past_key_value + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if os.name == 'nt': + return self.forward_windows( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + **kwargs, + ) + else: + return self.forward_linux( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + **kwargs, + ) class KLlamaAttention(BaseInjectedModule):