mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 15:29:39 +00:00)

support npu

parent dd0e41b3b8
commit 7d51a13c9b
34 changed files with 14004 additions and 5626 deletions
@@ -16,6 +16,16 @@ try:
     from ktransformers.server.balance_serve.settings import sched_ext
 except:
     print("no balance_serve")
+
+
+try:
+    import torch_npu
+    from ktransformers.util import utils
+
+    use_torch_npu = torch_npu.npu.is_available()
+except:
+    use_torch_npu = False
+
 class StaticCache(transformers.StaticCache):
     """
     Static Cache class to be used with `torch.compile(model)`.
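For readers without an Ascend environment, the hunk above is the usual optional-dependency probe: try to import torch_npu, record whether an NPU is actually available, and gate every NPU-specific branch on a module-level flag. A minimal standalone sketch of the same pattern; only the torch_npu.npu.is_available() call is taken from the diff, and default_page_size() is a hypothetical helper added for illustration:

# Sketch of the capability probe above; default_page_size() is a
# hypothetical helper, not part of the commit.
try:
    import torch_npu  # Ascend NPU extension for PyTorch

    use_torch_npu = torch_npu.npu.is_available()
except Exception:  # also covers ImportError when torch_npu is not installed
    use_torch_npu = False

def default_page_size() -> int:
    # The commit uses 128-token KV-cache pages on NPU and keeps 64 elsewhere.
    return 128 if use_torch_npu else 64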
@@ -37,6 +47,10 @@
     def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device: torch.device| dict, dtype=None) -> None:
         Cache.__init__(self)
         self.max_batch_size = max_batch_size
+
+        if use_torch_npu:
+            self.position = [0]
+
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
         if config.architectures[0] == "DeepseekV3ForCausalLM":
@@ -56,8 +70,18 @@
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM":
             # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
-            self.page_size = 64
-            self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
+
+            if use_torch_npu:
+                self.page_size = 128
+                self.page_size_tensor = torch.tensor(
+                    self.page_size,
+                    dtype=torch.int32,
+                ).npu()
+                self.max_pages_per_batch = (self.max_cache_len + self.page_size - 1) // self.page_size
+                self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size * self.max_batch_size
+            else:
+                self.page_size = 64
+                self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
             latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
             self.kv_lora_rank = config.kv_lora_rank
             self.qk_rope_head_dim = config.qk_rope_head_dim
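To make the page geometry above concrete, a small worked example with assumed sizes (not taken from the commit): with max_cache_len = 4096 and max_batch_size = 4, the NPU path uses 128-token pages and gives every sequence its own range of pages, while the default path keeps 64-token pages in a single shared range.

# Illustrative arithmetic only; the sizes are assumptions, the formulas are
# the ones from the hunk above.
max_cache_len, max_batch_size = 4096, 4

# NPU path: 128-token pages, a private range of pages per sequence.
page_size_npu = 128
max_pages_per_batch = (max_cache_len + page_size_npu - 1) // page_size_npu   # 32
max_pages_npu = max_pages_per_batch * max_batch_size                         # 128

# Default path: 64-token pages, one shared range.
page_size = 64
max_pages = (max_cache_len + page_size - 1) // page_size                     # 64

print(max_pages_per_batch, max_pages_npu, max_pages)                         # 32 128 64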
@@ -71,9 +95,14 @@
                 target_device = device

             if target_device not in self.page_table_map:
-                page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)
-                for seq_id in range(max_batch_size):
-                    page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)
+                if use_torch_npu:
+                    page_table = torch.zeros((max_batch_size, self.max_pages_per_batch), dtype=torch.int32, device=target_device)
+                    for seq_id in range(max_batch_size):
+                        page_table[seq_id, :] = torch.arange(seq_id * self.max_pages_per_batch, seq_id * self.max_pages_per_batch + self.max_pages_per_batch, dtype=torch.int32, device=target_device)
+                else:
+                    page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)
+                    for seq_id in range(max_batch_size):
+                        page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)
                 self.page_table_map[target_device] = page_table

             self.page_table_list.append(self.page_table_map[target_device])
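The NPU branch above ends up giving each sequence in the batch a contiguous, non-overlapping block of page indices inside the enlarged page pool. A standalone CPU sketch of the resulting table, with assumed sizes:

import torch

# Assumed sizes for illustration; the construction mirrors the NPU branch above.
max_batch_size, max_pages_per_batch = 4, 8

page_table = torch.zeros((max_batch_size, max_pages_per_batch), dtype=torch.int32)
for seq_id in range(max_batch_size):
    page_table[seq_id, :] = torch.arange(
        seq_id * max_pages_per_batch,
        seq_id * max_pages_per_batch + max_pages_per_batch,
        dtype=torch.int32,
    )

# Row i owns pages [i * max_pages_per_batch, (i + 1) * max_pages_per_batch):
# tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
#         [ 8,  9, 10, 11, 12, 13, 14, 15],
#         [16, 17, 18, 19, 20, 21, 22, 23],
#         [24, 25, 26, 27, 28, 29, 30, 31]], dtype=torch.int32)
print(page_table)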
@@ -140,11 +169,24 @@
         self.past_tokens[layer_idx] += cache_position.size(0)
         #print(cache_position)
         if self.is_MLA:
-            page_idx = cache_position // self.page_size
-            page_offset = cache_position % self.page_size
-            # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
-            k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
-            k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
+            if use_torch_npu:
+                page_idx = cache_position // self.page_size_tensor
+                page_offset = cache_position % self.page_size_tensor
+
+                page_idx = page_idx.unsqueeze(0).expand(self.max_batch_size, -1)
+                page_offset = page_offset.unsqueeze(0).expand(self.max_batch_size, -1)
+
+                page_idx_offset = torch.arange(self.max_batch_size, device=page_idx.device) * self.max_pages_per_batch
+                page_idx = page_idx + page_idx_offset.unsqueeze(1)
+
+                combined = torch.cat([key_states, value_states], dim=-1)
+                combined = combined.contiguous()
+            else:
+                page_idx = cache_position // self.page_size
+                page_offset = cache_position % self.page_size
+                # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
+                k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
+                k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
             return k_out, self.page_table_list[layer_idx]
         else:
             k_out[:, :, cache_position] = key_states
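In the NPU branch of this update hunk, the same cache_position is applied to every sequence in the batch: page_idx and page_offset are broadcast to (max_batch_size, n) and each row is shifted into that sequence's private page range, while key_states and value_states are concatenated into one combined tensor for the paged write (the write of combined into the cache itself is not among the lines shown in this hunk). A CPU-only sketch of just the index arithmetic, with assumed sizes:

import torch

# Assumed sizes for illustration; the arithmetic follows the NPU branch above.
max_batch_size, max_pages_per_batch, page_size = 2, 4, 128

cache_position = torch.tensor([254, 255, 256])        # token positions being written

page_idx = cache_position // page_size                # tensor([1, 1, 2])
page_offset = cache_position % page_size              # tensor([126, 127, 0])

# Broadcast across the batch, then shift each row into its own page range.
page_idx = page_idx.unsqueeze(0).expand(max_batch_size, -1)
page_offset = page_offset.unsqueeze(0).expand(max_batch_size, -1)
page_idx_offset = torch.arange(max_batch_size) * max_pages_per_batch
page_idx = page_idx + page_idx_offset.unsqueeze(1)

print(page_idx)      # tensor([[1, 1, 2],
                     #         [5, 5, 6]])
print(page_offset)   # tensor([[126, 127,   0],
                     #         [126, 127,   0]])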
@@ -178,6 +220,9 @@
             if self.value_cache[layer_idx] is not None:
                 self.value_cache[layer_idx].zero_()
             self.past_tokens[layer_idx] = 0
+
+        if use_torch_npu:
+            self.position = [0]

     def remove_suffix(self, start_pos):
         for layer_idx in range(len(self.key_cache)):