#!/usr/bin/env python
# coding=utf-8
"""
Description  :
Author       : Jianwei Dong
Date         : 2024-08-26 23:25:24
Version      : 1.0.0
LastEditors  : Jianwei Dong
LastEditTime : 2024-08-26 23:25:24
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""

import math
import json
import sys, os
import logging

import torch
from transformers import AutoConfig

logger = logging.getLogger("dynamic_attention")

sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend")
from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache

try:
    from flash_attn import flash_attn_func, flash_attn_with_kvcache
except ImportError:
    print("flash_attn not found")

class DynamicScaledDotProductAttention:
    remaining_length: int
    cpu_infer = None

    def __init__(
        self,
        max_seq_len: int,
        block_size: int,
        config: AutoConfig,
        device: torch.device,
        local_windows_len: int,
        topk: int,
        threads_num: int,
        anchor_type: str = "DYNAMIC",
        kv_type: str = "FP16",
        dense_layer_num: int = 0,
        anchor_num: int = 1,
        block_selection_mode: str = "SHARED",
        layer_step: int = 1,
        token_step: int = 1,
        preselect_block: bool = False,
        preselect_block_count: int = 96,
        prefill_chunk_size: int = 20480,
        use_attn_sparsity: bool = False,
    ):
        self.remaining_length = 0

        valid_anchor_types = ["DYNAMIC", "FIXED", "BLOCK_MEAN", "BLOCK_MAX", "QUEST"]
        assert anchor_type in valid_anchor_types
        if anchor_type == "QUEST":
            # QUEST keeps a per-block min/max pair as its two anchors
            assert anchor_num == 2
        elif anchor_type != "FIXED" and anchor_type != "DYNAMIC":
            assert anchor_num == 1

        valid_kv_types = ["FP16", "FP32", "Q4_0", "Q8_0"]
        assert kv_type in valid_kv_types
        if kv_type != "FP16" and kv_type != "FP32":
            # the quantized formats pack values in groups of 32
            assert block_size % 32 == 0

        valid_block_selection_modes = ["SHARED", "SEPARATE"]  # shared across heads, or per-head
        assert block_selection_mode in valid_block_selection_modes

        self.max_seq_len = max_seq_len
        self.block_num = max_seq_len // block_size
        self.block_size = block_size
        self.anchor_type = anchor_type
        self.kv_type = kv_type
        self.anchor_num = anchor_num
        self.threads_num = threads_num
        self.layer_step = layer_step
        self.token_step = token_step
        self.preselect_block = preselect_block
        self.preselect_block_count = preselect_block_count
        self.block_selection_mode = block_selection_mode
        self.use_attn_sparsity = use_attn_sparsity

        # model config
        self.kv_head_num = config.num_key_value_heads
        self.q_head_num = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.layer_num = config.num_hidden_layers

        self.device = device
        self.local_windows_len = local_windows_len
        self.local_block_num = self.local_windows_len // self.block_size + 1
        self.prefill_chunk_size = prefill_chunk_size

        self.topk = topk
        self.dense_layer_num = dense_layer_num

        # GPU-side staging for the block K/V caches:
        # [block_num, block_size, kv_head_num, head_dim]
        self.cache_key_states = torch.zeros(
            (self.block_num, block_size, self.kv_head_num, self.head_dim),
            device=device,
            dtype=torch.float16,
        )
        self.cache_value_states = torch.zeros(
            (self.block_num, block_size, self.kv_head_num, self.head_dim),
            device=device,
            dtype=torch.float16,
        )
        # accumulated attention scores per cached token:
        # [block_num, block_size, q_head_num]
        self.cache_importance = torch.zeros(
            (self.block_num, block_size, self.q_head_num),
            device=device,
            dtype=torch.float16,
        )

        # key_states:   [bsz, q_len, kv_head_num, head_dim]
        # value_states: [bsz, q_len, kv_head_num, head_dim]
        # query_states: [bsz, q_len, q_head_num, head_dim]
        # pinned host buffers for the single-token decode path
        self.q_in_cpu = torch.zeros(
            (1, 1, self.q_head_num, self.head_dim),
            device="cpu",
            dtype=torch.float16,
            pin_memory=True,
        )
        self.k_in_cpu = torch.zeros(
            (1, 1, self.kv_head_num, self.head_dim),
            device="cpu",
            dtype=torch.float16,
            pin_memory=True,
        )
        self.v_in_cpu = torch.zeros(
            (1, 1, self.kv_head_num, self.head_dim),
            device="cpu",
            dtype=torch.float16,
            pin_memory=True,
        )

        self.cache_seqlens_cpu = torch.empty(
            (1,), device="cpu", dtype=torch.int32, pin_memory=True
        )

        self.cache_seqlens_cuda = torch.empty((1,), device=device, dtype=torch.int32)

        self.prefix_block_table = torch.arange(
            self.block_num, device="cpu", dtype=torch.int32, pin_memory=True
        ).view(1, -1)

        self.block_table_cpu = torch.arange(
            self.block_num, device="cpu", dtype=torch.int32, pin_memory=True
        ).view(1, -1)

        # assert (
        #     self.local_windows_len // self.block_size + 1 + self.preselect_block_count
        #     <= self.block_num
        # )

        self.output_cpu = torch.empty(
            (1, 1, self.q_head_num, self.head_dim),
            device="cpu",
            dtype=torch.float16,
            pin_memory=True,
        )
        self.lse_cpu = torch.empty(
            (1, 1, self.q_head_num), device="cpu", dtype=torch.float32, pin_memory=True
        )

        self.output_cuda = torch.empty(
            (1, 1, self.q_head_num, self.head_dim), device=device, dtype=torch.float16
        )

        self.attn_sparsity = torch.zeros(
            (1, 1, self.q_head_num), device="cpu", dtype=torch.float32, pin_memory=True
        )

        if preselect_block:
            self.preselect_block_table = torch.zeros(
                self.layer_num,
                self.preselect_block_count,
                device=device,
                dtype=torch.int32,
            )
            # number of prefilled blocks eligible for preselection; set during
            # prefill and read on the decode path
            self.prefill_block_num = 0
            self.evict_tokens = 0

        if DynamicScaledDotProductAttention.cpu_infer is None:
            DynamicScaledDotProductAttention.cpu_infer = CPUInfer(threads_num)
        self.cpu_infer = DynamicScaledDotProductAttention.cpu_infer
        self.local_thread = CPUInferKVCache(
            self.layer_num,
            self.kv_head_num,
            self.q_head_num,
            self.head_dim,
            self.block_size,
            anchor_num=self.anchor_num,
            anchor_type=anchor_type,
            kv_type=self.kv_type,
            retrieval_type=self.block_selection_mode,
            layer_step=self.layer_step,
            token_step=self.token_step,
            layer_offset=self.dense_layer_num % self.layer_step,
            max_batch_size=1,
            max_block_num=self.block_num,
            max_thread_num=self.threads_num,
        )

        print(
            f"local_windows_len: {local_windows_len}, topk: {topk}, "
            f"dense_layer_num: {dense_layer_num}, kv_type: {self.kv_type}, "
            f"anchor_type: {self.anchor_type}, preselect_block: {self.preselect_block}, "
            f"preselect_block_count: {self.preselect_block_count}, "
            f"token_step: {self.token_step}, layer_step: {self.layer_step}"
        )

        self.shape_mask = (
            self.q_head_num,
            self.block_size,
            self.block_size,
        )

        # Build a per-head causal band mask: row i keeps key columns 0..i.
        mask = torch.zeros(
            self.shape_mask, dtype=torch.uint8, device=device
        ).contiguous()
        elm_idx = torch.arange(self.block_size, device=device)

        for i in range(mask.size(-2)):
            idx = i + mask.size(-1) - mask.size(-2) - elm_idx
            idx = idx[idx >= 0]
            mask[..., i, idx] = 1

        self.tril_mask = mask
        self.triu_mask = mask ^ 1
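
        # Illustration (comments only, nothing executed): for block_size = 4 the
        # loop above produces, for every head, the lower-triangular pattern
        #
        #     tril_mask[h] = [[1, 0, 0, 0],
        #                     [1, 1, 0, 0],
        #                     [1, 1, 1, 0],
        #                     [1, 1, 1, 1]]
        #
        # and triu_mask is its bitwise complement. get_attn_score_one_block
        # slices the bottom-right corner of these masks when the query and key
        # blocks are not the same size.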

        self.generate_token_idx = 0

    def get_attn_score_one_block(
        self,
        batch_idx: int,
        max_block_num: int,
        query: torch.Tensor,
        key: torch.Tensor,
        offset: int,
        width: int,
        mask_mode: str | None = None,
        use_softmax: bool = True,
    ):
        importance = self.cache_importance.view(-1, self.q_head_num)
        importance = importance.narrow(0, batch_idx * max_block_num + offset, width)
        n_gqa_ = self.q_head_num // self.kv_head_num
        for head_idx in range(self.q_head_num):
            # each group of n_gqa_ query heads shares one KV head (GQA)
            key_item = key[..., head_idx // n_gqa_, :].view(key.size(0), -1)
            qk = torch.einsum(
                "qd,kd->qk", query[:, head_idx, :], key_item
            )  # (len_q, len_k)

            if mask_mode == "tril":
                mask = self.tril_mask
                mask = mask[0, -qk.size(-2) :, -qk.size(-1) :]
                qk = qk * mask
            elif mask_mode == "triu":
                mask = self.triu_mask
                mask = mask[0, -qk.size(-2) :, -qk.size(-1) :]
                qk = qk * mask

            if use_softmax:
                qk = torch.nn.functional.softmax(
                    qk / math.sqrt(self.head_dim), dim=-1, dtype=torch.float32
                ).to(torch.float16)

            # accumulate per-key importance by summing over the query dimension
            qk = torch.sum(qk, dim=-2)
            importance[..., head_idx] += qk
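
    # Worked example of the GQA mapping above (comments only; the head counts
    # are assumptions for illustration, not values fixed by this module): with
    # q_head_num = 32 and kv_head_num = 8, n_gqa_ = 4, so head_idx // n_gqa_
    # maps query heads 0-3 to kv head 0, query heads 4-7 to kv head 1, and so on.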
    def get_preselect_block_table_and_attn_score(
        self,
        layer_idx: int,
        batch_size: int,
        offset: torch.Tensor,
        width: int,
        query: torch.Tensor,
        key: torch.Tensor,
        union_with_last_layer: bool = True,
    ):
        max_seqs_len = offset.max().item() + width
        max_block_num = (max_seqs_len + self.block_size - 1) // self.block_size

        # score every cached key against the last 128 queries of the chunk
        for batch_idx in range(batch_size):
            query_cur = query[batch_idx][-128:]
            self.get_attn_score_one_block(
                batch_idx,
                max_block_num,
                query_cur,
                key[batch_idx][: offset[batch_idx].item() + width],
                0,
                offset[batch_idx].item() + width,
                mask_mode=None,
            )

        if self.preselect_block:
            # blocks beyond the local window are candidates for preselection
            self.prefill_block_num = max(
                0, max_block_num - self.local_windows_len // self.block_size
            )
            self.evict_tokens = (
                max(self.prefill_block_num - self.preselect_block_count, 0)
                * self.block_size
            )

            if self.prefill_block_num != 0:
                importance_cache = self.cache_importance.narrow(
                    0, 0, self.prefill_block_num * batch_size
                ).view(
                    batch_size, self.prefill_block_num, self.block_size, self.q_head_num
                )

                # smooth block scores with quarter-block halos borrowed from the
                # neighbouring blocks, then average over heads and tokens
                importance_r = importance_cache[:, 1:, : self.block_size // 4]
                pad_r = torch.zeros_like(importance_r[:, :1])
                importance_r = torch.cat((importance_r, pad_r), dim=1)
                importance_l = importance_cache[:, :-1, -self.block_size // 4 :]
                pad_l = torch.zeros_like(importance_l[:, :1])
                importance_l = torch.cat((pad_l, importance_l), dim=1)
                importance = torch.cat(
                    (importance_l, importance_cache, importance_r), dim=2
                )
                importance = importance.mean(dim=-1)  # over heads
                importance = importance.mean(dim=-1)  # over tokens
                # importance: (batch_size, prefill_block_num)
                topk = min(self.preselect_block_count, self.prefill_block_num)
                _, indices = torch.topk(
                    importance,
                    k=topk,
                    dim=1,
                )

                self.preselect_block_table[
                    layer_idx : layer_idx + 1,
                    :topk,
                ].copy_(indices)

                # on the last layer, merge a few of its top blocks into every
                # earlier layer's table so all layers keep the globally hottest blocks
                if union_with_last_layer and layer_idx == self.layer_num - 1:
                    for tmp_layer_idx in range(self.layer_num - 1):
                        for i in range(1, min(topk, 6)):
                            x = self.preselect_block_table[-1, i]
                            if x not in self.preselect_block_table[tmp_layer_idx]:
                                self.preselect_block_table[tmp_layer_idx, topk - i] = x

        if self.anchor_type == "DYNAMIC":
            importance_cache = self.cache_importance.narrow(
                0, 0, max_block_num * batch_size
            ).view(batch_size, max_block_num * self.block_size, self.q_head_num)
            importance_cache_cpu = torch.empty_like(
                importance_cache, device="cpu", pin_memory=True
            )

            importance_cache_cpu.copy_(importance_cache)

            block_table_cpu = self.prefix_block_table[:, :max_block_num].to("cpu")
            offset_cpu = offset.contiguous().to("cpu")

            self.cpu_infer.submit(
                self.local_thread.update_importance(
                    importance_cache_cpu,
                    layer_idx,
                    block_table_cpu,
                    max_block_num,
                    offset_cpu,
                    width,
                )
            )
            self.cpu_infer.sync()

        importance_cache = self.cache_importance.narrow(
            0, 0, max_block_num * batch_size
        ).view(batch_size, max_block_num * self.block_size, self.q_head_num)
        importance_cache.zero_()
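
    # A toy version of the neighbour-halo smoothing used above (an illustrative
    # sketch, assuming batch = 1, three blocks of size 4, one head; not part of
    # the module):
    #
    #   imp = torch.arange(12.0).view(1, 3, 4, 1)   # (bsz, blocks, tokens, heads)
    #   halo = 4 // 4                                # quarter-block = 1 token
    #   r = torch.cat((imp[:, 1:, :halo], torch.zeros_like(imp[:, :1, :halo])), dim=1)
    #   l = torch.cat((torch.zeros_like(imp[:, :1, -halo:]), imp[:, :-1, -halo:]), dim=1)
    #   scores = torch.cat((l, imp, r), dim=2).mean(dim=-1).mean(dim=-1)  # (1, 3)
    #
    # Each block is ranked by its own scores plus a sliver from each neighbour,
    # which damps single-token spikes at block boundaries.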
    # key:   [bsz, past_len, head_num, head_dim] float16
    # query: [bsz, q_len, q_head_num, head_dim] float16
    def get_attn_score(
        self,
        layer_idx: int,
        batch_size: int,
        offset: torch.Tensor,
        width: int,
        query: torch.Tensor,
        key: torch.Tensor,
    ):
        max_seqs_len = offset.max().item() + width
        max_block_num = (max_seqs_len + self.block_size - 1) // self.block_size

        # width is processed block by block; for each query block, three key
        # regions are scored: the query block itself (causal band), the block
        # exactly local_windows_len behind it (complementary band), and the
        # stretch in between (dense, no mask).
        for batch_idx in range(batch_size):
            for idx in range(width // self.block_size):
                offset_cur = idx * self.block_size
                query_cur = query[batch_idx, offset_cur : offset_cur + self.block_size]
                self.get_attn_score_one_block(
                    batch_idx,
                    max_block_num,
                    query_cur,
                    key[
                        batch_idx,
                        offset[batch_idx]
                        + offset_cur : offset[batch_idx]
                        + offset_cur
                        + self.block_size,
                    ],
                    offset[batch_idx].item() + offset_cur,
                    self.block_size,
                    mask_mode="tril",
                    use_softmax=False,
                )

                offset_key = (
                    offset[batch_idx].item()
                    + idx * self.block_size
                    - self.local_windows_len
                )
                if offset_key >= 0:
                    self.get_attn_score_one_block(
                        batch_idx,
                        max_block_num,
                        query_cur,
                        key[batch_idx, offset_key : offset_key + self.block_size],
                        offset_key,
                        self.block_size,
                        mask_mode="triu",
                        use_softmax=False,
                    )

                offset_key = max(0, offset_key + self.block_size)
                width_key = (
                    offset[batch_idx].item() + idx * self.block_size - offset_key
                )
                if width_key > 0:
                    self.get_attn_score_one_block(
                        batch_idx,
                        max_block_num,
                        query_cur,
                        key[batch_idx, offset_key : offset_key + width_key],
                        offset_key,
                        width_key,
                        mask_mode=None,
                        use_softmax=False,
                    )

        importance_cache = self.cache_importance.narrow(
            0, 0, max_block_num * batch_size
        ).view(batch_size, max_block_num * self.block_size, self.q_head_num)
        importance_cache_cpu = torch.empty_like(
            importance_cache, device="cpu", pin_memory=True
        )

        importance_cache_cpu.copy_(importance_cache)

        block_table_cpu = self.prefix_block_table[:, :max_block_num].to("cpu")
        offset_cpu = offset.contiguous().to("cpu")

        self.cpu_infer.submit(
            self.local_thread.update_importance(
                importance_cache_cpu,
                layer_idx,
                block_table_cpu,
                max_block_num,
                offset_cpu,
                width,
            )
        )
        self.cpu_infer.sync()
        importance_cache.zero_()
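
    # Numeric example of the three regions above (comments only; the sizes are
    # assumptions for illustration): with local_windows_len = 1024,
    # block_size = 128, past offset = 2048 and idx = 0, the query block covers
    # absolute positions [2048, 2176), and the calls score:
    #   keys [2048, 2176) with mask_mode="tril"  (the query block itself),
    #   keys [1024, 1152) with mask_mode="triu"  (offset_key = 2048 - 1024),
    #   keys [1152, 2048) with mask_mode=None    (width_key = 896, dense).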
    # key:   [bsz, q_len, head_num, head_dim] float16
    # value: [bsz, q_len, head_num, head_dim] float16
    def swap_in_and_swap_out(self, layer_idx, past_len, q_len, key, value):
        batch_size = 1
        max_seqs_len = past_len.max().item() + q_len
        max_block_num = (max_seqs_len + self.block_size - 1) // self.block_size
        k_cache = self.cache_key_states.narrow(0, 0, max_block_num * batch_size).view(
            batch_size, max_block_num * self.block_size, self.kv_head_num, self.head_dim
        )
        v_cache = self.cache_value_states.narrow(0, 0, max_block_num * batch_size).view(
            batch_size, max_block_num * self.block_size, self.kv_head_num, self.head_dim
        )

        # append the chunk's new K/V to the GPU-side cache
        for batch_idx in range(batch_size):
            offset = past_len[batch_idx]
            width = q_len
            k_cache[batch_idx][offset : offset + width].copy_(
                key[batch_idx].view(-1, self.kv_head_num, self.head_dim)
            )
            v_cache[batch_idx][offset : offset + width].copy_(
                value[batch_idx].view(-1, self.kv_head_num, self.head_dim)
            )

        # round-trip through pinned host memory: the CPU backend ingests the new
        # tokens and fills these buffers with the full per-layer cache
        k_cache_cpu = torch.empty_like(k_cache, device="cpu", pin_memory=True)
        v_cache_cpu = torch.empty_like(v_cache, device="cpu", pin_memory=True)

        k_cache_cpu.copy_(k_cache)
        v_cache_cpu.copy_(v_cache)

        cur_block_num = (
            q_len + past_len[0].item() + self.block_size - 1
        ) // self.block_size
        block_table_cpu = self.prefix_block_table[:, :cur_block_num].to("cpu")
        past_len_cpu = past_len.contiguous().to("cpu")

        self.cpu_infer.submit(
            self.local_thread.get_and_update_kvcache_fp16(
                k_cache_cpu,
                v_cache_cpu,
                layer_idx,
                block_table_cpu,
                max_block_num,
                past_len_cpu,
                q_len,
            )
        )

        self.cpu_infer.sync()
        k_cache.copy_(k_cache_cpu)
        v_cache.copy_(v_cache_cpu)

        return k_cache, v_cache
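
    # The staging pattern above in isolation (a minimal sketch; the buffer name
    # is illustrative): pinned host buffers make the host<->device copies
    # DMA-friendly, so they run fast and their non-blocking variants can overlap
    # with other work on the CUDA stream.
    #
    #   staged = torch.empty_like(gpu_tensor, device="cpu", pin_memory=True)
    #   staged.copy_(gpu_tensor)      # device -> pinned host
    #   ...                           # CPU backend reads/writes `staged`
    #   gpu_tensor.copy_(staged)      # pinned host -> device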
    def calc_anchor(self, cache_seqlens: int):
        cur_block_num = (cache_seqlens + self.block_size - 1) // self.block_size
        block_table_cpu = self.prefix_block_table[:, :cur_block_num].to("cpu")
        cache_seqlens_cpu = torch.tensor(
            [cache_seqlens], device="cpu", dtype=torch.int32
        )

        self.cpu_infer.submit(
            self.local_thread.calc_anchor_all_layers(
                block_table_cpu,
                cache_seqlens_cpu,
            )
        )
        self.cpu_infer.sync()
    def clear_importance(self, cache_seqlens: int):
        print(f"clear importance: {cache_seqlens}")
        cur_block_num = (cache_seqlens + self.block_size - 1) // self.block_size
        block_table_cpu = self.prefix_block_table[:, :cur_block_num].to("cpu")
        cache_seqlens_cpu = torch.tensor(
            [cache_seqlens], device="cpu", dtype=torch.int32
        )

        self.cpu_infer.submit(
            self.local_thread.clear_importance_all_layers(
                block_table_cpu,
                cache_seqlens_cpu,
            )
        )
        self.cpu_infer.sync()
    def clear_kvcache(self, cache_seqlens: int):
        cur_block_num = (cache_seqlens + self.block_size - 1) // self.block_size
        block_table_cpu = self.prefix_block_table[:, :cur_block_num].to("cpu")
        cache_seqlens_cpu = torch.tensor(
            [cache_seqlens], device="cpu", dtype=torch.int32
        )

        self.cpu_infer.submit(
            self.local_thread.clear_kvcache_all_layers(
                block_table_cpu,
                cache_seqlens_cpu,
            )
        )
        self.cpu_infer.sync()
    def get_attn_sparsity(
        self,
        q_in: torch.Tensor,
        layer_idx: int,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        block_table_origin: torch.Tensor,
        cache_seqlens_origin: torch.Tensor,
        generate_token_idx: int = 0,
        topk: int | None = None,
        local: int | None = None,
        output_path: str = "./attn_sparsity.json",
    ):
        self.attn_sparsity.zero_()
        self.cpu_infer.submit(
            self.local_thread.get_attn_sparsity(
                q_in,
                self.attn_sparsity,
                layer_idx,
                block_table,
                cache_seqlens,
                block_table_origin,
                cache_seqlens_origin,
                generate_token_idx,
                topk,
                local,
            )
        )
        self.cpu_infer.sync()
        # append one JSON record per head (JSON Lines format)
        with open(output_path, "a") as file:
            for head_idx in range(self.q_head_num):
                sparsity = self.attn_sparsity[0][0][head_idx].item()
                json_obj = {
                    "token_idx": generate_token_idx,
                    "layer_idx": layer_idx,
                    "head_idx": head_idx,
                    "sparsity": sparsity,
                }
                json.dump(json_obj, file)
                file.write("\n")
    def apply(
        self,
        layer_idx: int,
        bsz: int,
        past_len: int,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        mode: str = "prefill",
        generate_token_idx: int = -1,
    ):
        # key_states:   [bsz, q_len, kv_head_num, head_dim]
        # value_states: [bsz, q_len, kv_head_num, head_dim]
        # query_states: [bsz, q_len, q_head_num, head_dim]
        assert query_states.dtype == torch.float16
        assert key_states.dtype == torch.float16
        assert value_states.dtype == torch.float16

        assert key_states.size(2) == self.kv_head_num
        assert value_states.size(2) == self.kv_head_num
        assert query_states.size(2) == self.q_head_num

        q_len = query_states.size(1)
        batch_size = query_states.size(0)
        self.cache_seqlens_cuda.fill_(past_len)
        last_chunk = False
        if self.remaining_length <= self.prefill_chunk_size and q_len != 1:
            last_chunk = True
        device = query_states.device
        if layer_idx == 0:
            if q_len == 1:
                self.generate_token_idx += 1
            elif last_chunk:
                self.generate_token_idx = -1

        if mode == "prefill":
            key, value = self.swap_in_and_swap_out(
                layer_idx,
                self.cache_seqlens_cuda,
                q_len,
                key_states,
                value_states,
            )

            # on the last prefill chunk, score the cache and preselect blocks
            if last_chunk and (self.anchor_type == "DYNAMIC" or self.preselect_block):
                self.get_preselect_block_table_and_attn_score(
                    layer_idx,
                    bsz,
                    self.cache_seqlens_cuda,
                    q_len,
                    query_states,
                    key,
                )
            output = flash_attn_with_kvcache(
                q=query_states,
                k_cache=key,
                v_cache=value,
                cache_seqlens=self.cache_seqlens_cuda + q_len,
                causal=True,
            )
            return output.transpose(1, 2)
        elif mode == "generate":
            assert self.generate_token_idx >= 0
            self.q_in_cpu.copy_(query_states, non_blocking=True)
            self.k_in_cpu.copy_(key_states, non_blocking=True)
            self.v_in_cpu.copy_(value_states, non_blocking=True)
            self.cache_seqlens_cpu.copy_(self.cache_seqlens_cuda, non_blocking=True)
            if layer_idx < self.dense_layer_num:
                # dense layers attend over the full block table
                self.block_table_cpu.copy_(self.prefix_block_table, non_blocking=True)
                self.cpu_infer.submit_with_cuda_stream(
                    torch.cuda.current_stream("cuda").cuda_stream,
                    self.local_thread.attn_with_kvcache(
                        q_in=self.q_in_cpu,
                        k_in=self.k_in_cpu,
                        v_in=self.v_in_cpu,
                        output=self.output_cpu,
                        attn_lse=self.lse_cpu,
                        layer_idx=layer_idx,
                        block_table=self.block_table_cpu,
                        cache_seqlens=self.cache_seqlens_cpu,
                    ),
                )
            else:
                if self.preselect_block:
                    # account for evicted tokens, then splice the per-layer
                    # preselected blocks in front of the local-window blocks
                    self.cache_seqlens_cpu.copy_(
                        self.cache_seqlens_cuda - self.evict_tokens, non_blocking=True
                    )
                    if self.preselect_block_count < self.prefill_block_num:
                        self.block_table_cpu[:, : self.preselect_block_count].copy_(
                            self.preselect_block_table[layer_idx : layer_idx + 1],
                            non_blocking=True,
                        )

                        self.block_table_cpu[
                            :,
                            self.preselect_block_count : self.preselect_block_count
                            + self.local_block_num,
                        ].copy_(
                            self.prefix_block_table[
                                :,
                                self.prefill_block_num : self.prefill_block_num
                                + self.local_block_num,
                            ],
                            non_blocking=True,
                        )
                    self.cpu_infer.submit_with_cuda_stream(
                        torch.cuda.current_stream("cuda").cuda_stream,
                        self.local_thread.attn_with_kvcache(
                            q_in=self.q_in_cpu,
                            k_in=self.k_in_cpu,
                            v_in=self.v_in_cpu,
                            output=self.output_cpu,
                            attn_lse=self.lse_cpu,
                            layer_idx=layer_idx,
                            generate_token_idx=self.generate_token_idx,
                            block_table=self.block_table_cpu,
                            cache_seqlens=self.cache_seqlens_cpu,
                            topk=(
                                self.topk
                                if self.topk <= self.preselect_block_count
                                else None
                            ),
                            local=self.local_windows_len // self.block_size,
                        ),
                    )
                else:
                    self.block_table_cpu.copy_(
                        self.prefix_block_table, non_blocking=True
                    )
                    self.cpu_infer.submit_with_cuda_stream(
                        torch.cuda.current_stream("cuda").cuda_stream,
                        self.local_thread.attn_with_kvcache(
                            q_in=self.q_in_cpu,
                            k_in=self.k_in_cpu,
                            v_in=self.v_in_cpu,
                            output=self.output_cpu,
                            attn_lse=self.lse_cpu,
                            layer_idx=layer_idx,
                            generate_token_idx=self.generate_token_idx,
                            block_table=self.block_table_cpu,
                            cache_seqlens=self.cache_seqlens_cpu,
                            topk=self.topk,
                            local=self.local_windows_len // self.block_size,
                        ),
                    )
            self.cpu_infer.sync_with_cuda_stream(
                torch.cuda.current_stream("cuda").cuda_stream
            )
            self.output_cuda.copy_(self.output_cpu, non_blocking=True)
            return self.output_cuda.transpose(1, 2)
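
    # Decode-path block-table layout in the preselected case above (comments
    # only; the numbers are assumptions for illustration): with block_size = 128,
    # prefill_block_num = 200, preselect_block_count = 96 and local_block_num = 9,
    # evict_tokens = (200 - 96) * 128, the table holds the 96 preselected block
    # ids followed by the 9 local-window blocks, and cache_seqlens is reduced by
    # evict_tokens so the CPU kernel sees a contiguous, shortened sequence.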
    def save(self, path: str, length: int):
        cur_block_num = (length + self.block_size - 1) // self.block_size
        block_table_cpu = self.prefix_block_table[0, :cur_block_num].to("cpu")
        cache_seqlens_cpu = torch.tensor([length], device="cpu", dtype=torch.int32)
        self.cpu_infer.submit(
            self.local_thread.dump_kvcache(
                block_table_cpu,
                cache_seqlens_cpu,
                path,
            )
        )
        self.cpu_infer.sync()
    def load(self, path: str, length: int):
        # `length` is currently unused here; the signature mirrors save()
        self.cpu_infer.submit(
            self.local_thread.load_kvcache(
                path,
            )
        )
        self.cpu_infer.sync()
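

# --- Usage sketch -------------------------------------------------------------
# Illustrative only; the model name and all sizes below are assumptions, not
# values fixed by this module. One attention object serves all layers: prefill
# chunks go through mode="prefill" (GPU flash-attn plus CPU cache update), then
# each decoded token goes through mode="generate" (CPU sparse attention).
# Callers also maintain attn.remaining_length when prefilling in chunks.
#
#   config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # hypothetical
#   attn = DynamicScaledDotProductAttention(
#       max_seq_len=32768,
#       block_size=128,
#       config=config,
#       device=torch.device("cuda:0"),
#       local_windows_len=1024,
#       topk=16,
#       threads_num=32,
#       preselect_block=True,
#   )
#   # prefill: q/k/v are float16 [bsz, q_len, head_num, head_dim] tensors
#   out = attn.apply(layer_idx, bsz=1, past_len=0, query_states=q,
#                    key_states=k, value_states=v, mode="prefill")
#   # decode: one token at a time, past_len advanced by the caller
#   out = attn.apply(layer_idx, bsz=1, past_len=q_len, query_states=q1,
#                    key_states=k1, value_states=v1, mode="generate")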