Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-09 13:55:27 +00:00
Commit 1084d4e4b4: linux support triton MLA kernel
Parent: bb35dc5b0d
2 changed files with 198 additions and 61 deletions
```diff
@@ -53,8 +53,9 @@ class StaticCache(transformers.StaticCache):
             # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
             self.page_size = 64
             self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
-            key_shape = (self.max_pages, self.page_size, 1, config.qk_rope_head_dim)
-            value_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank)
+            latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
+            self.kv_lora_rank = config.kv_lora_rank
+            self.qk_rope_head_dim = config.qk_rope_head_dim
             # TODO: support real page table
             self.page_table_map = dict()
             self.page_table_list = []
```
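This hunk replaces the two separate paged buffers (rotary keys and compressed KV) with one fused latent buffer per layer. A minimal sketch of the sizing math follows; kv_lora_rank=512 and qk_rope_head_dim=64 are assumed DeepSeek-V2-style values, and max_cache_len=4096 is an arbitrary example, none of them read from this repo's config:

```python
kv_lora_rank = 512       # assumed width of the compressed KV latent
qk_rope_head_dim = 64    # assumed width of the rotary key part
max_cache_len = 4096     # illustrative maximum sequence length
page_size = 64

# Ceiling division, as in the diff: enough pages to cover max_cache_len tokens.
max_pages = (max_cache_len + page_size - 1) // page_size   # -> 64

# One fused buffer replaces the separate rope/lora buffers.
latent_shape = (max_pages, page_size, 1, kv_lora_rank + qk_rope_head_dim)
print(latent_shape)  # (64, 64, 1, 576)
```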
```diff
@@ -88,10 +89,17 @@ class StaticCache(transformers.StaticCache):
                 target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
             else:
                 target_device = device
-            new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)
-            new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)
-            torch._dynamo.mark_static_address(new_layer_key_cache)
-            torch._dynamo.mark_static_address(new_layer_value_cache)
+
+            if self.is_MLA:
+                new_layer_key_cache = torch.zeros(latent_shape, dtype=self.dtype, device=target_device)
+                new_layer_value_cache = None
+                torch._dynamo.mark_static_address(new_layer_key_cache)
+            else:
+                new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)
+                new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)
+                torch._dynamo.mark_static_address(new_layer_key_cache)
+                torch._dynamo.mark_static_address(new_layer_value_cache)
+
             self.key_cache.append(new_layer_key_cache)
             self.value_cache.append(new_layer_value_cache)
             self.past_tokens.append(0)
```
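Pulled out of the class, the new per-layer allocation branch reads roughly as below. This is a sketch under the diff's assumptions, not the repository's exact code: the helper name is hypothetical, the shape tuples come from the previous hunk, and torch._dynamo.mark_static_address marks the buffer's address as static so torch.compile can reuse it across calls.

```python
import torch

def alloc_layer_cache(is_mla, latent_shape, key_shape, value_shape,
                      dtype=torch.float16, device="cpu"):
    # Hypothetical helper mirroring the diff's branch; the real code runs
    # inline in StaticCache.__init__ once per layer.
    if is_mla:
        # MLA: one fused latent buffer per layer, no separate value cache.
        key_cache = torch.zeros(latent_shape, dtype=dtype, device=device)
        value_cache = None
        torch._dynamo.mark_static_address(key_cache)
    else:
        key_cache = torch.zeros(key_shape, dtype=dtype, device=device)
        value_cache = torch.zeros(value_shape, dtype=dtype, device=device)
        torch._dynamo.mark_static_address(key_cache)
        torch._dynamo.mark_static_address(value_cache)
    return key_cache, value_cache

# Example: a 64-page MLA cache for one layer holds 64*64 token slots of
# width 576 in half precision, about 4.5 MiB.
k, v = alloc_layer_cache(True, (64, 64, 1, 576), None, None)
assert v is None and k.shape == (64, 64, 1, 576)
```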
```diff
@@ -129,11 +137,12 @@ class StaticCache(transformers.StaticCache):
         if self.is_MLA:
             page_idx = cache_position // self.page_size
             page_offset = cache_position % self.page_size
+            # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
             #print("page_idx", page_idx)
             #print("page_offset", page_offset)
-            k_out[page_idx, page_offset, ...] = key_states
-            v_out[page_idx, page_offset, ...] = value_states
-            return k_out, v_out, self.page_table_list[layer_idx]
+            k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
+            k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
+            return k_out, self.page_table_list[layer_idx]
         else:
             k_out[:, :, cache_position] = key_states
             v_out[:, :, cache_position] = value_states
```
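Judging by the slice widths, key_states here carries the compressed KV latent (width kv_lora_rank) and value_states the rotary key part (width qk_rope_head_dim); both now land in a single paged buffer and are recovered by slicing the last dimension. A self-contained round-trip sketch with illustrative sizes:

```python
import torch

# Round-trip sketch of the fused MLA update; all sizes are illustrative.
kv_lora_rank, qk_rope_head_dim = 512, 64
max_pages, page_size = 2, 64
k_out = torch.zeros(max_pages, page_size, 1, kv_lora_rank + qk_rope_head_dim)

cache_position = torch.arange(62, 65)   # three tokens crossing a page boundary
key_states = torch.randn(3, 1, kv_lora_rank)        # compressed KV latent
value_states = torch.randn(3, 1, qk_rope_head_dim)  # rotary key part

page_idx = cache_position // page_size    # -> tensor([0, 0, 1])
page_offset = cache_position % page_size  # -> tensor([62, 63, 0])

# Write both halves into the single paged buffer, as the diff does.
k_out[page_idx, page_offset, :, :kv_lora_rank] = key_states
k_out[page_idx, page_offset, :, kv_lora_rank:] = value_states

# Read back: slicing the last dim splits the latent and rope halves again.
assert torch.equal(k_out[page_idx, page_offset, :, :kv_lora_rank], key_states)
assert torch.equal(k_out[page_idx, page_offset, :, kv_lora_rank:], value_states)
```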