mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 22:05:30 +00:00
support deepseekv3; runable but have precition problem
This commit is contained in:
parent
de7e892f72
commit
476b1d8dc6
13 changed files with 2178 additions and 24 deletions
|
@ -34,9 +34,12 @@ class StaticCache(transformers.StaticCache):
|
|||
self.max_batch_size = max_batch_size
|
||||
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
||||
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
||||
self.head_dim = (
|
||||
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||
)
|
||||
if config.architectures[0] == "DeepseekV3ForCausalLM":
|
||||
self.head_dim = config.qk_rope_head_dim
|
||||
else:
|
||||
self.head_dim = (
|
||||
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||
)
|
||||
|
||||
self.dtype = dtype if dtype is not None else torch.float32
|
||||
self.num_key_value_heads = (
|
||||
|
@ -46,7 +49,7 @@ class StaticCache(transformers.StaticCache):
|
|||
self.key_cache: List[torch.Tensor] = []
|
||||
self.value_cache: List[torch.Tensor] = []
|
||||
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
|
||||
if config.architectures[0] == "DeepseekV2ForCausalLM":
|
||||
if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM":
|
||||
# TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
|
||||
# key_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.qk_rope_head_dim + config.qk_nope_head_dim)
|
||||
# value_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.v_head_dim)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue