support npu

djw 2025-07-21 12:26:14 +00:00
parent dd0e41b3b8
commit 7d51a13c9b
34 changed files with 14004 additions and 5626 deletions


@@ -16,6 +16,16 @@ try:
     from ktransformers.server.balance_serve.settings import sched_ext
 except:
     print("no balance_serve")
+try:
+    import torch_npu
+    from ktransformers.util import utils
+    use_torch_npu = torch_npu.npu.is_available()
+except:
+    use_torch_npu = False
 class StaticCache(transformers.StaticCache):
     """
     Static Cache class to be used with `torch.compile(model)`.
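The new probe mirrors the existing balance_serve guard: import the vendor package, fall back cleanly if it is absent. For reference, a minimal sketch of the same probe with a narrower exception clause (not part of the commit; the bare `except:` in the diff also swallows unrelated errors raised inside torch_npu):

# Sketch: equivalent availability probe with an explicit exception type.
try:
    import torch_npu
    use_torch_npu = torch_npu.npu.is_available()
except ImportError:
    use_torch_npu = False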
@@ -37,6 +47,10 @@ class StaticCache(transformers.StaticCache):
     def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device: torch.device | dict, dtype=None) -> None:
         Cache.__init__(self)
         self.max_batch_size = max_batch_size
+        if use_torch_npu:
+            self.position = [0]
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
         # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
         if config.architectures[0] == "DeepseekV3ForCausalLM":
@@ -56,8 +70,18 @@ class StaticCache(transformers.StaticCache):
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM":
             # TODO: for deepseek, cache_shape differs depending on whether Absorbed MLA is used; detect this automatically
-            self.page_size = 64
-            self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
+            if use_torch_npu:
+                self.page_size = 128
+                self.page_size_tensor = torch.tensor(
+                    self.page_size,
+                    dtype=torch.int32,
+                ).npu()
+                self.max_pages_per_batch = (self.max_cache_len + self.page_size - 1) // self.page_size
+                self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size * self.max_batch_size
+            else:
+                self.page_size = 64
+                self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
             latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
             self.kv_lora_rank = config.kv_lora_rank
             self.qk_rope_head_dim = config.qk_rope_head_dim
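The hunk above changes the paged-cache sizing on NPU: pages grow from 64 to 128 slots, and instead of one per-request pool, a single global pool holds max_pages_per_batch pages for every sequence in the batch. A standalone sketch of the ceil-division arithmetic, with hypothetical sizes (not taken from the diff):

# Illustrative page math for the NPU layout introduced above.
max_cache_len = 4096   # hypothetical context length
max_batch_size = 4     # hypothetical batch size
page_size = 128        # NPU page size from the diff

max_pages_per_batch = (max_cache_len + page_size - 1) // page_size  # 32
max_pages = max_pages_per_batch * max_batch_size                    # 128

# Each sequence owns a contiguous run of 32 pages inside one shared
# pool of 128 pages, rather than a pool sized for a single request.
assert max_pages_per_batch == 32 and max_pages == 128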
@@ -71,9 +95,14 @@ class StaticCache(transformers.StaticCache):
                 target_device = device
             if target_device not in self.page_table_map:
-                page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)
-                for seq_id in range(max_batch_size):
-                    page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)
+                if use_torch_npu:
+                    page_table = torch.zeros((max_batch_size, self.max_pages_per_batch), dtype=torch.int32, device=target_device)
+                    for seq_id in range(max_batch_size):
+                        page_table[seq_id, :] = torch.arange(seq_id * self.max_pages_per_batch, seq_id * self.max_pages_per_batch + self.max_pages_per_batch, dtype=torch.int32, device=target_device)
+                else:
+                    page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)
+                    for seq_id in range(max_batch_size):
+                        page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)
             self.page_table_map[target_device] = page_table
             self.page_table_list.append(self.page_table_map[target_device])
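Both branches build a static identity mapping: row i of the page table simply points at the contiguous pages [i * P, (i + 1) * P). A minimal CPU sketch with hypothetical sizes shows the resulting layout:

import torch

# Sketch of the page table built above: row i owns pages [i*P, (i+1)*P).
max_batch_size, max_pages_per_batch = 2, 4  # hypothetical sizes
page_table = torch.zeros((max_batch_size, max_pages_per_batch), dtype=torch.int32)
for seq_id in range(max_batch_size):
    page_table[seq_id, :] = torch.arange(
        seq_id * max_pages_per_batch,
        (seq_id + 1) * max_pages_per_batch,
        dtype=torch.int32,
    )
print(page_table)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]], dtype=torch.int32)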
@@ -140,11 +169,24 @@ class StaticCache(transformers.StaticCache):
         self.past_tokens[layer_idx] += cache_position.size(0)
         #print(cache_position)
         if self.is_MLA:
-            page_idx = cache_position // self.page_size
-            page_offset = cache_position % self.page_size
-            # key shape: (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
-            k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
-            k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
+            if use_torch_npu:
+                page_idx = cache_position // self.page_size_tensor
+                page_offset = cache_position % self.page_size_tensor
+                page_idx = page_idx.unsqueeze(0).expand(self.max_batch_size, -1)
+                page_offset = page_offset.unsqueeze(0).expand(self.max_batch_size, -1)
+                page_idx_offset = torch.arange(self.max_batch_size, device=page_idx.device) * self.max_pages_per_batch
+                page_idx = page_idx + page_idx_offset.unsqueeze(1)
+                combined = torch.cat([key_states, value_states], dim=-1)
+                combined = combined.contiguous()
+            else:
+                page_idx = cache_position // self.page_size
+                page_offset = cache_position % self.page_size
+                # key shape: (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
+                k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
+                k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
             return k_out, self.page_table_list[layer_idx]
         else:
             k_out[:, :, cache_position] = key_states
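The NPU branch maps each token position to a (page, offset) pair and then shifts every row of page indices by that sequence's base page in the shared pool; the hunk shown above ends before `combined` is scattered back into the cache, so only the index computation is visible. A minimal CPU sketch of the broadcast arithmetic, with hypothetical sizes:

import torch

# Sketch of the NPU index math above: position -> (page, offset),
# then per-row shift by each sequence's base page.
max_batch_size, max_pages_per_batch, page_size = 2, 4, 128  # hypothetical
cache_position = torch.tensor([0, 1, 130])    # hypothetical token positions

page_idx = cache_position // page_size        # [0, 0, 1]
page_offset = cache_position % page_size      # [0, 1, 2]
page_idx = page_idx.unsqueeze(0).expand(max_batch_size, -1)
base = torch.arange(max_batch_size) * max_pages_per_batch  # [0, 4]
page_idx = page_idx + base.unsqueeze(1)
print(page_idx)
# row 0 -> pages [0, 0, 1]; row 1 -> pages [4, 4, 5]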
@@ -178,6 +220,9 @@ class StaticCache(transformers.StaticCache):
             if self.value_cache[layer_idx] is not None:
                 self.value_cache[layer_idx].zero_()
             self.past_tokens[layer_idx] = 0
+        if use_torch_npu:
+            self.position = [0]
     def remove_suffix(self, start_pos):
         for layer_idx in range(len(self.key_cache)):