Mirror of https://github.com/kvcache-ai/ktransformers.git
init support for MLA using Attention kernel
commit bb35dc5b0d (parent 62011fd63e)
5 changed files with 551 additions and 262 deletions. Only the hunks touching StaticCache are shown below.
In StaticCache.__init__, the cache shapes and the placeholder page table are set up as follows:

@@ -51,13 +51,33 @@ class StaticCache(transformers.StaticCache):
        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
        if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM":
            # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
            # key_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.qk_rope_head_dim + config.qk_nope_head_dim)
            # value_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.v_head_dim)
            key_shape = (max_batch_size, 1, self.max_cache_len, config.qk_rope_head_dim)
            value_shape = (max_batch_size, 1, self.max_cache_len, config.kv_lora_rank)
            self.page_size = 64
            self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
            # Paged layout (num_pages, page_size, num_heads, head_dim) overrides the contiguous shapes above.
            key_shape = (self.max_pages, self.page_size, 1, config.qk_rope_head_dim)
            value_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank)
            # TODO: support real page table
            self.page_table_map = dict()
            self.page_table_list = []
            for idx in range(config.num_hidden_layers):
                if isinstance(device, dict):
                    target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
                else:
                    target_device = device

                if target_device not in self.page_table_map:
                    # One placeholder page table per device: sequence seq_id owns the contiguous
                    # page ids [seq_id * max_pages, (seq_id + 1) * max_pages).
                    page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)
                    for seq_id in range(max_batch_size):
                        page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)
                    self.page_table_map[target_device] = page_table

                self.page_table_list.append(self.page_table_map[target_device])

            self.is_MLA = True
            self.is_page = True
        else:
            key_shape = cache_shape
            value_shape = cache_shape
            self.is_MLA = False

        self.past_tokens = []
        self.num_hidden_layers = config.num_hidden_layers
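For the MLA branch, the cache keeps only the RoPE key component (qk_rope_head_dim wide) and the compressed KV latent (kv_lora_rank wide), laid out in fixed 64-token pages, and the placeholder page table hands each sequence a private, contiguous range of page ids. Below is a minimal sketch of that bookkeeping with made-up batch and context sizes; only page_size = 64 is taken from the hunk above.

    import torch

    # Hypothetical sizes for illustration.
    max_batch_size, max_cache_len, page_size = 2, 256, 64
    max_pages = (max_cache_len + page_size - 1) // page_size  # ceil division -> 4 pages per sequence

    # Sequence seq_id owns the contiguous page ids [seq_id * max_pages, (seq_id + 1) * max_pages).
    page_table = torch.zeros((max_batch_size, max_pages), dtype=torch.int32)
    for seq_id in range(max_batch_size):
        page_table[seq_id, :] = torch.arange(seq_id * max_pages,
                                             seq_id * max_pages + max_pages,
                                             dtype=torch.int32)

    print(page_table)
    # tensor([[0, 1, 2, 3],
    #         [4, 5, 6, 7]], dtype=torch.int32)

A real page allocator would instead map each row to arbitrary free pages, which is what the "TODO: support real page table" comment points at.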
In StaticCache.update(), the write path now branches on the paged MLA layout:

@@ -104,11 +124,20 @@ class StaticCache(transformers.StaticCache):
         cache_position = cache_kwargs.get("cache_position")
         k_out = self.key_cache[layer_idx]
         v_out = self.value_cache[layer_idx]
-        #print(cache_position)
-        k_out[:, :, cache_position] = key_states
-        v_out[:, :, cache_position] = value_states
         self.past_tokens[layer_idx] += cache_position.size(0)
-        return k_out, v_out
+        #print(cache_position)
+        if self.is_MLA:
+            page_idx = cache_position // self.page_size
+            page_offset = cache_position % self.page_size
+            #print("page_idx", page_idx)
+            #print("page_offset", page_offset)
+            k_out[page_idx, page_offset, ...] = key_states
+            v_out[page_idx, page_offset, ...] = value_states
+            return k_out, v_out, self.page_table_list[layer_idx]
+        else:
+            k_out[:, :, cache_position] = key_states
+            v_out[:, :, cache_position] = value_states
+            return k_out, v_out
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states that were seen by the model."""
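The paged write in update() is ordinary advanced indexing: an absolute cache position p lands in page p // page_size at slot p % page_size, and the layer's page table is returned alongside the caches, presumably so the attention kernel can resolve each sequence's pages the same way. The following is a small self-contained check of that indexing; all sizes and positions are made up, and only the (pages, page_size, heads, dim) layout mirrors the hunk above.

    import torch

    # Hypothetical cache dimensions.
    max_pages, page_size, heads, head_dim = 8, 64, 1, 64
    k_cache = torch.zeros(max_pages, page_size, heads, head_dim)

    # Suppose this step writes tokens at absolute positions 126..129.
    cache_position = torch.arange(126, 130)
    key_states = torch.randn(cache_position.size(0), heads, head_dim)

    # Same arithmetic as the MLA branch of StaticCache.update().
    page_idx = cache_position // page_size    # tensor([1, 1, 2, 2])
    page_offset = cache_position % page_size  # tensor([62, 63, 0, 1])

    # Advanced indexing scatters token i into k_cache[page_idx[i], page_offset[i]].
    k_cache[page_idx, page_offset, ...] = key_states

    # Gathering the same cells back returns exactly what was written.
    assert torch.equal(k_cache[page_idx, page_offset, ...], key_states)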