mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-09 22:05:30 +00:00)
support qwen3, don't speak human language
This commit is contained in:
parent f3d842a0ca · commit 3f9bbf1181
30 changed files with 3696 additions and 290 deletions
@@ -762,92 +762,3 @@ class KLlamaAttention(BaseInjectedModule):

        attn_weights = None

        return attn_output, attn_weights, past_key_value


class flashinfer_attn(BaseInjectedModule, DeepseekV2Attention):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 chunck_size: int = 1000,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.layer_idx)
        self.chunck_size = chunck_size  # TODO, generate chunck_size automatically.

    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
        # Split kv_b_proj into per-head key/value up-projections and cache them as
        # "absorbed" matrices, so attention can run over the compressed KV latent
        # instead of materializing full keys and values per head.
        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
            kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
            q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
            out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
            self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
                                      bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
            self.q_absorb.weight.data = q_absorb
            self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
                                        bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
            self.out_absorb.weight.data = out_absorb
            # del self.orig_module.kv_b_proj
        q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
        out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
        return q_absorb, out_absorb

    def forward(self,
                hidden_states: torch.Tensor,
                kv_cache: KDeepSeekV3Cache,
                position_ids: torch.Tensor,
                wrapper: BatchMLAPagedAttentionWrapper,
                num_tokens_tensors: torch.Tensor,
                page_idx: torch.Tensor,
                page_offset: torch.Tensor,
                ):
        q_len, _ = hidden_states.size()

        # Query path: optional low-rank (LoRA-style) projection, then split into
        # the non-positional (nope) and rotary (pe) parts.
        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states, num_tokens_tensors)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states, num_tokens_tensors), num_tokens_tensors), num_tokens_tensors)
        q = q.view(q_len, self.num_heads, self.q_head_dim)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        # KV path: one compressed latent shared by all heads, plus a small
        # rotary key component.
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states, num_tokens_tensors)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        compressed_kv = compressed_kv.contiguous()
        compressed_kv = self.kv_a_layernorm(compressed_kv, num_tokens_tensors)
        k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim)
        compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank)

        cos, sin = self.rotary_emb(q_pe, position_ids.unsqueeze(0))
        q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2)
        q_pe = q_pe.squeeze(0)
        if kv_cache is not None:
            # page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices)
            cache_kwargs = {"sin": sin, "cos": cos, "page_idx": page_idx, "page_offset": page_offset}  # Specific to RoPE models
            compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, page_idx, page_offset, cache_kwargs)
            compressed_kv = compressed_kv_with_k_pe[:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank)
            k_pe = compressed_kv_with_k_pe[:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim)

        q_absorb, out_absorb = self.get_absorbed()
        q_nope = q_nope.transpose(0, 1)  # q_len is 1, no GPU overhead, same below
        q_nope = torch.matmul(q_nope, q_absorb)  # batched MM
        q_nope = q_nope.transpose(0, 1)
        # q_nope.squeeze_(1)
        # q_pe.squeeze_(1)

        attn_output = wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(q_len, self.num_heads, self.kv_lora_rank)
        attn_output = attn_output.transpose(0, 1)
        attn_output = torch.matmul(attn_output, out_absorb.mT)  # [self.num_heads, q_len, self.v_head_dim]
        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
        attn_output = self.o_proj(attn_output, num_tokens_tensors)
        return attn_output
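apply_rotary_pos_emb in the forward pass is the HuggingFace transformers helper; it rotates channel pairs of q_pe / k_pe by a position-dependent angle, and the unsqueeze_dim=2 in the call above only aligns cos/sin for broadcasting over the head dimension. A minimal standalone sketch of the half-rotation formula that helper applies, assuming made-up sizes and the standard 10000 base (names are local to this sketch):

import torch

def rotate_half(x):
    # Turn (x1, x2) into (-x2, x1) along the channel dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin):
    # Elementwise rotation: out = x * cos + rotate_half(x) * sin.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

head_dim, seq_len = 64, 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(seq_len, dtype=torch.float), inv_freq)
emb = torch.cat((angles, angles), dim=-1)
cos, sin = emb.cos(), emb.sin()          # [seq_len, head_dim]

q = torch.randn(seq_len, head_dim)
k = torch.randn(seq_len, head_dim)
q_rot, k_rot = apply_rope(q, k, cos, sin)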
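kv_cache.update() scatters each new token's fused (compressed_kv, k_pe) entry into a paged pool addressed by page_idx / page_offset, and the two slices after it split that pool back apart for the wrapper. A hypothetical sketch of that addressing, assuming a trivially contiguous page table (the real KDeepSeekV3Cache resolves pages through its own page table) and DeepSeek-like example sizes:

import torch

page_size, num_pages = 64, 8
kv_lora_rank, qk_rope_head_dim = 512, 64
width = kv_lora_rank + qk_rope_head_dim

cache = torch.zeros(num_pages, page_size, width)   # one layer's paged pool
positions = torch.tensor([0, 1, 2, 130])           # absolute token positions
page_idx = positions // page_size                  # which page holds each token
page_offset = positions % page_size                # slot inside that page

entries = torch.randn(len(positions), width)       # cat(compressed_kv, k_pe) per token
cache[page_idx, page_offset] = entries             # the scatter kv_cache.update performs

# Reading back splits the fused entry the same way the forward pass does:
compressed_kv = cache[..., :kv_lora_rank]          # [num_pages, page_size, kv_lora_rank]
k_pe = cache[..., kv_lora_rank:]                   # [num_pages, page_size, qk_rope_head_dim]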
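The get_absorbed() / out_absorb.mT pair implements the standard MLA matrix-absorption trick: rather than decompressing per-head keys and values from the shared latent cache, the key up-projection is folded into the query before attention and the value up-projection is applied once afterwards. A minimal standalone sketch of the equivalence, assuming random weights and made-up sizes; w_uk / w_uv stand in for the two halves of kv_b_proj, the softmax replaces wrapper.run, and the rope half of the score and the paged layout are omitted:

import torch

num_heads, q_len, kv_len = 4, 1, 16
kv_lora_rank, qk_nope_head_dim, v_head_dim = 32, 8, 8

w_uk = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)  # key up-projection
w_uv = torch.randn(num_heads, v_head_dim, kv_lora_rank)        # value up-projection
q_nope = torch.randn(num_heads, q_len, qk_nope_head_dim)
compressed_kv = torch.randn(kv_len, kv_lora_rank)              # shared latent cache

# Naive path: decompress keys/values for every head, then attend.
k = torch.einsum('hdr,tr->htd', w_uk, compressed_kv)
v = torch.einsum('hvr,tr->htv', w_uv, compressed_kv)
probs = torch.softmax(torch.einsum('hqd,htd->hqt', q_nope, k), dim=-1)
out_naive = torch.einsum('hqt,htv->hqv', probs, v)

# Absorbed path: fold w_uk into the query (q_absorb), attend directly over the
# latent, and only project to v_head_dim at the very end (out_absorb.mT).
q_abs = torch.matmul(q_nope, w_uk)                               # [h, q, r]
probs2 = torch.softmax(torch.einsum('hqr,tr->hqt', q_abs, compressed_kv), dim=-1)
latent_out = torch.einsum('hqt,tr->hqr', probs2, compressed_kv)  # [h, q, r]
out_absorbed = torch.matmul(latent_out, w_uv.mT)                 # [h, q, v]

assert torch.allclose(out_naive, out_absorbed, atol=1e-4)

The payoff is that the cache stores only the kv_lora_rank latent per token (shared across heads), and the per-head decompression never happens at all.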