add balance-serve, support concurrence

This commit is contained in:
Atream 2025-03-31 22:55:32 +08:00
parent 8d0292aa44
commit 25cee5810e
196 changed files with 22077 additions and 565 deletions

View file

@ -8,9 +8,11 @@ Version : 0.1.0
# Copyright 2018- The Hugging Face team. All rights reserved.
# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import torch
import torch.nn as nn
import transformers
from transformers import Cache, PretrainedConfig
from typing import List, Optional, Dict, Any, Tuple
from ktransformers.server.balance_serve.settings import sched_ext
class StaticCache(transformers.StaticCache):
"""
Static Cache class to be used with `torch.compile(model)`.
@ -188,3 +190,85 @@ class StaticCache(transformers.StaticCache):
def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
"""Returns the maximum shape of the cache."""
return self.max_cache_len
class KDeepSeekV3Cache(nn.Module):
def __init__(
self,
config: PretrainedConfig,
page_size: int = 256,
dtype=torch.bfloat16,
device=torch.device("cuda:0"),
):
super().__init__()
self.config = config
self.dtype = dtype
self.device = device
self.kv_lora_rank = config.kv_lora_rank
self.page_size = page_size
self.k_caches = []
self.v_caches = []
def load(self, inference_context: sched_ext.InferenceContext):
for i in range(self.config.num_hidden_layers):
self.k_caches.append(
inference_context.k_cache[0][i]
)
self.max_cache_len = self.k_caches[0].shape[0]*self.k_caches[0].shape[1]
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
page_idx: torch.Tensor,
page_offset: torch.Tensor,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
to know how where to write in the cache.
Return:
A tuple containing the updated key and value states.
"""
k_out = self.k_caches[layer_idx]
k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states.reshape(-1, *key_states.shape[2:])
k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states.reshape(-1, *value_states.shape[2:])
return k_out
def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.tensor):
page_offset = cache_position % self.page_size
page_idx_local = cache_position // self.page_size
query_ids = torch.zeros_like(cache_position)
for i in range(len(q_indptr) - 1):
start_idx = q_indptr[i]
end_idx = q_indptr[i + 1]
query_ids[start_idx:end_idx] = i
page_idx = torch.zeros_like(page_idx_local)
for i in range(bsz_tensors[0]):
query_id = query_ids[i]
local_block = page_idx_local[i]
start_block = kv_indptr[query_id]
if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]:
page_idx[i] = kv_indices[start_block + local_block]
return page_idx, page_offset