Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-09 22:05:30 +00:00)
[ADD] support multi-gpu qlen>1 q5_k

This commit is contained in:
parent f293803156
commit f5f79f5c0e

63 changed files with 3271 additions and 1285 deletions
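The excerpt below shows the `StaticCache` changes: the `device` argument gains a per-layer dict form so each layer's KV cache can be allocated on its own GPU.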
```diff
@@ -22,13 +22,14 @@ class StaticCache(transformers.StaticCache):
             The maximum batch size with which the model will be used.
         max_cache_len (`int`):
             The maximum sequence length with which the model will be used.
-        device (`torch.device`):
+        device (`torch.device` or `dict`):
             The device on which the cache should be initialized. Should be the same as the layer.
+            If a `dict`, it should contain the `device` key with the device name as the value.
         dtype (*optional*, defaults to `torch.float32`):
             The default `dtype` to use when initializing the layer.
     """
 
-    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
+    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device: torch.device| dict, dtype=None) -> None:
         Cache.__init__(self)
         self.max_batch_size = max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
```
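With the signature change above, `device` may be either a single device or a per-layer mapping. Below is a minimal sketch of the dict form; the layer count, the two-GPU split, and the `device_map` name are assumptions for illustration, while the key structure follows the lookup `device[f"blk.{idx}.self_attn"]["generate_device"]` performed in the allocation loop further down the diff.

```python
# Sketch only: the dict form of `device` accepted by StaticCache.__init__
# after this commit. Concrete values here are assumptions, not from the commit.
num_hidden_layers = 60  # e.g. DeepSeek-V2 has 60 decoder layers

# Put the first half of the layers' KV cache on cuda:0 and the rest on cuda:1.
device_map = {
    f"blk.{idx}.self_attn": {"generate_device": "cuda:0" if idx < num_hidden_layers // 2 else "cuda:1"}
    for idx in range(num_hidden_layers)
}

# Both forms are now valid:
#   StaticCache(config, max_batch_size=1, max_cache_len=4096, device="cuda:0")
#   StaticCache(config, max_batch_size=1, max_cache_len=4096, device=device_map)
```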
```diff
@@ -46,6 +47,7 @@ class StaticCache(transformers.StaticCache):
         self.value_cache: List[torch.Tensor] = []
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         if config.architectures[0] == "DeepseekV2ForCausalLM":
+            # TODO: for deepseek, cache_shape differs depending on whether Absorbed MLA is used; check it automatically
             # key_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.qk_rope_head_dim + config.qk_nope_head_dim)
             # value_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.v_head_dim)
             key_shape = (max_batch_size, 1, self.max_cache_len, config.qk_rope_head_dim)
```
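For `DeepseekV2ForCausalLM` the branch above keeps a single-head key cache of width `qk_rope_head_dim` (the absorbed-MLA layout) instead of the full per-head shapes shown in the commented-out lines. A rough, back-of-the-envelope comparison of the per-layer key-cache footprint, using assumed DeepSeek-V2-style config values:

```python
# Sketch only: compares the commented-out full key shape with the
# absorbed-MLA key shape used above. All config numbers are assumptions.
import math

max_batch_size, max_cache_len = 1, 4096
num_key_value_heads = 128  # assumed
qk_rope_head_dim = 64      # assumed
qk_nope_head_dim = 128     # assumed

full_key_shape = (max_batch_size, num_key_value_heads, max_cache_len,
                  qk_rope_head_dim + qk_nope_head_dim)
mla_key_shape = (max_batch_size, 1, max_cache_len, qk_rope_head_dim)

ratio = math.prod(full_key_shape) / math.prod(mla_key_shape)
print(f"full key cache is {ratio:.0f}x larger per layer")  # 384x with these numbers
```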
```diff
@@ -56,11 +58,15 @@ class StaticCache(transformers.StaticCache):
 
         self.past_tokens = []
         self.num_hidden_layers = config.num_hidden_layers
-        for _ in range(self.num_hidden_layers):
+        for idx in range(self.num_hidden_layers):
             # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
             # breaks when updating the cache.
-            new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=device)
-            new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=device)
+            if isinstance(device, dict):
+                target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
+            else:
+                target_device = device
+            new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)
+            new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)
             torch._dynamo.mark_static_address(new_layer_key_cache)
             torch._dynamo.mark_static_address(new_layer_value_cache)
             self.key_cache.append(new_layer_key_cache)
```
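The core of the multi-GPU change is the per-layer device resolution inside the allocation loop. Here is a standalone sketch of just that logic; `resolve_layer_device` is a hypothetical helper introduced only to isolate the behaviour, not part of the commit:

```python
# Sketch only: mirrors the isinstance(device, dict) branch added above.
from typing import Union

def resolve_layer_device(device: Union[str, dict], idx: int) -> str:
    """Pick the device for layer `idx`'s KV-cache tensors."""
    if isinstance(device, dict):
        # Multi-GPU path: each attention module is mapped to its own device.
        return device[f"blk.{idx}.self_attn"]["generate_device"]
    # Single-device path: every layer's cache lives on one device.
    return device

assert resolve_layer_device("cuda:0", 3) == "cuda:0"
assert resolve_layer_device({"blk.3.self_attn": {"generate_device": "cuda:1"}}, 3) == "cuda:1"
```

Allocating each layer's cache directly on its `generate_device` keeps the KV tensors on the same GPU as the attention module that reads them, avoiding cross-device copies during generation.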