mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 13:55:27 +00:00
add balance-serve, support concurrency
commit 25cee5810e
parent 8d0292aa44
196 changed files with 22077 additions and 565 deletions
@@ -8,9 +8,11 @@ Version : 0.1.0
# Copyright 2018- The Hugging Face team. All rights reserved.
# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import torch
import torch.nn as nn
import transformers
from transformers import Cache, PretrainedConfig
from typing import List, Optional, Dict, Any, Tuple
from ktransformers.server.balance_serve.settings import sched_ext

class StaticCache(transformers.StaticCache):
    """
    Static Cache class to be used with `torch.compile(model)`.
@@ -188,3 +190,85 @@ class StaticCache(transformers.StaticCache):
    def get_max_cache_shape(self) -> int:
        """Returns the maximum length of the cache."""
        return self.max_cache_len

class KDeepSeekV3Cache(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        page_size: int = 256,
        dtype=torch.bfloat16,
        device=torch.device("cuda:0"),
    ):
        super().__init__()
        self.config = config
        self.dtype = dtype
        self.device = device
        self.kv_lora_rank = config.kv_lora_rank
        self.page_size = page_size
        self.k_caches = []
        self.v_caches = []

    def load(self, inference_context: sched_ext.InferenceContext):
        # Borrow the per-layer packed KV buffers allocated by the scheduler.
        for i in range(self.config.num_hidden_layers):
            self.k_caches.append(inference_context.k_cache[0][i])
        self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1]
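
# A minimal standalone sketch (separate from the class above, not part of this
# diff) of the per-layer buffer `load` consumes. It assumes sched_ext hands
# back one packed tensor per layer shaped [num_pages, page_size, num_kv_heads,
# head_dim]; the MLA dimensions below are illustrative, not read from the repo.
_kv_lora_rank, _rope_dim = 512, 64                     # assumed MLA dims
_layer_cache = torch.zeros(4, 256, 1, _kv_lora_rank + _rope_dim,
                           dtype=torch.bfloat16)
# `max_cache_len` is derived the same way: pages * tokens-per-page.
assert _layer_cache.shape[0] * _layer_cache.shape[1] == 1024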

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        page_idx: torch.Tensor,
        page_offset: torch.Tensor,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            page_idx (`torch.Tensor`):
                The physical page index for each token being written.
            page_offset (`torch.Tensor`):
                The offset of each token within its page.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass (unused here).

        Return:
            The updated packed cache tensor for `layer_idx`; keys and values share
            one buffer, split along the last dimension at `kv_lora_rank`.
        """
        k_out = self.k_caches[layer_idx]
        k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states.reshape(-1, *key_states.shape[2:])
        k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states.reshape(-1, *value_states.shape[2:])
        return k_out
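
# A toy-scale sketch (invented sizes, not the repo's tensors) of why `update`
# indexes with tensors: one advanced-indexing assignment scatters every token
# into its (page, offset) slot in place, with no Python-level copy loop.
_page_size, _rank, _rope = 4, 8, 2
_cache = torch.zeros(3, _page_size, 1, _rank + _rope)  # [pages, page_size, 1, dim]
_page_idx = torch.tensor([0, 0, 2])                    # scattered target pages
_page_offset = torch.tensor([1, 3, 0])                 # offsets within each page
_kv = torch.randn(3, 1, _rank)                         # stand-in for key_states
_pe = torch.randn(3, 1, _rope)                         # stand-in for value_states
_cache[_page_idx, _page_offset, :, :_rank] = _kv       # keys: first _rank dims
_cache[_page_idx, _page_offset, :, _rank:] = _pe       # values: remaining dims
assert torch.allclose(_cache[2, 0, 0, :_rank], _kv[2, 0])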

    def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.Tensor):
        # Split each global cache position into (logical page, offset within page).
        page_offset = cache_position % self.page_size
        page_idx_local = cache_position // self.page_size
        # Map each token back to the request it belongs to via q_indptr.
        query_ids = torch.zeros_like(cache_position)
        for i in range(len(q_indptr) - 1):
            start_idx = q_indptr[i]
            end_idx = q_indptr[i + 1]
            query_ids[start_idx:end_idx] = i
        # Translate each logical page into the physical page assigned to that request.
        page_idx = torch.zeros_like(page_idx_local)
        for i in range(bsz_tensors[0]):
            query_id = query_ids[i]
            local_block = page_idx_local[i]
            start_block = kv_indptr[query_id]
            if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]:
                page_idx[i] = kv_indices[start_block + local_block]
        return page_idx, page_offset
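
# A toy walkthrough (invented indptr/indices; real ones come from the
# balance-serve scheduler) of the logical-to-physical translation above:
# request 0 owns tokens 0-2 and pages [7, 2]; request 1 owns tokens 3-4 and page [5].
_q_indptr = torch.tensor([0, 3, 5])
_kv_indptr = torch.tensor([0, 2, 3])
_kv_indices = torch.tensor([7, 2, 5])
_cache_position = torch.tensor([0, 1, 5, 0, 1])        # per-request positions
_ps = 4                                                # page_size for this sketch
_offset = _cache_position % _ps                        # -> [0, 1, 1, 0, 0+1] = [0, 1, 1, 0, 1]
_local = _cache_position // _ps                        # logical pages [0, 0, 1, 0, 0]
# get_page_table looks each logical page up in kv_indices per request, so
# position 5 of request 0 (logical page 1) lands on physical page 2 and the
# full result would be page_idx [7, 7, 2, 5, 5], page_offset [0, 1, 1, 0, 1].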