mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 06:14:58 +00:00
support qwen3, dont speak human language
This commit is contained in:
parent
f3d842a0ca
commit
3f9bbf1181
30 changed files with 3696 additions and 290 deletions
|
@ -275,3 +275,59 @@ class KDeepSeekV3Cache(nn.Module):
|
|||
|
||||
return page_idx, page_offset
|
||||
|
||||
class KGQACache(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
page_size: int = 256,
|
||||
dtype=torch.bfloat16,
|
||||
device=torch.device("cuda:0"),
|
||||
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.page_size = page_size
|
||||
self.k_caches = []
|
||||
self.v_caches = []
|
||||
|
||||
|
||||
def load(self, inference_context: sched_ext.InferenceContext):
|
||||
print(self.config.num_hidden_layers)
|
||||
for i in range(self.config.num_hidden_layers):
|
||||
self.k_caches.append(
|
||||
inference_context.k_cache[0][i]
|
||||
)
|
||||
self.v_caches.append(
|
||||
inference_context.v_cache[0][i]
|
||||
)
|
||||
|
||||
|
||||
self.max_cache_len = self.k_caches[0].shape[0]*self.k_caches[0].shape[1]
|
||||
|
||||
|
||||
|
||||
def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.tensor):
|
||||
page_offset = cache_position % self.page_size
|
||||
page_idx_local = cache_position // self.page_size
|
||||
query_ids = torch.zeros_like(cache_position)
|
||||
for i in range(len(q_indptr) - 1):
|
||||
start_idx = q_indptr[i]
|
||||
end_idx = q_indptr[i + 1]
|
||||
query_ids[start_idx:end_idx] = i
|
||||
page_idx = torch.zeros_like(page_idx_local)
|
||||
for i in range(bsz_tensors[0]):
|
||||
query_id = query_ids[i]
|
||||
local_block = page_idx_local[i]
|
||||
start_block = kv_indptr[query_id]
|
||||
if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]:
|
||||
page_idx[i] = kv_indices[start_block + local_block]
|
||||
|
||||
return page_idx, page_offset
|
||||
|
||||
def get_k_cache(self, layer_idx):
|
||||
return self.k_caches[layer_idx]
|
||||
|
||||
def get_v_cache(self, layer_idx):
|
||||
return self.v_caches[layer_idx]
|
Loading…
Add table
Add a link
Reference in a new issue