Implement multi-batch support for v2, v3, and r1 models with backend_type configured as ktransformers.

This commit is contained in:
jiafei96 2025-07-09 09:09:47 +00:00
parent 890b0f1622
commit a6ab9e349c
6 changed files with 383 additions and 52 deletions

View file

@ -58,7 +58,11 @@ class StaticCache(transformers.StaticCache):
# TODO: for DeepSeek, the cache shape differs depending on whether Absorbed MLA is used; detect this automatically
self.page_size = 64
self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
from ktransformers.server.backend.interfaces.ktransformers import multi_batch_enabled
if multi_batch_enabled:
latent_shape = (max_batch_size, self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
else:
latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
self.kv_lora_rank = config.kv_lora_rank
self.qk_rope_head_dim = config.qk_rope_head_dim
# TODO: support real page table
@ -143,8 +147,14 @@ class StaticCache(transformers.StaticCache):
page_idx = cache_position // self.page_size
page_offset = cache_position % self.page_size
# key shape: (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim);
# when multi_batch_enabled is set, there is an additional leading max_batch_size dimension
k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
from ktransformers.server.backend.interfaces.ktransformers import multi_batch_enabled
if multi_batch_enabled:
batch_size = key_states.size(0)
k_out[:batch_size, page_idx, page_offset, :, :self.kv_lora_rank] = key_states
k_out[:batch_size, page_idx, page_offset, :, self.kv_lora_rank:] = value_states
else:
k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
return k_out, self.page_table_list[layer_idx]
else:
k_out[:, :, cache_position] = key_states