mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 13:55:27 +00:00
Implement multi-batch support for v2, v3, and r1 models with backend_type configured as ktransformers.
This commit is contained in:
parent
890b0f1622
commit
a6ab9e349c
6 changed files with 383 additions and 52 deletions
|
@ -58,7 +58,11 @@ class StaticCache(transformers.StaticCache):
|
|||
# TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
|
||||
self.page_size = 64
|
||||
self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
|
||||
latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
|
||||
from ktransformers.server.backend.interfaces.ktransformers import multi_batch_enabled
|
||||
if multi_batch_enabled:
|
||||
latent_shape = (max_batch_size, self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
|
||||
else:
|
||||
latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
|
||||
self.kv_lora_rank = config.kv_lora_rank
|
||||
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||
# TODO: support real page table
|
||||
|
@ -143,8 +147,14 @@ class StaticCache(transformers.StaticCache):
|
|||
page_idx = cache_position // self.page_size
|
||||
page_offset = cache_position % self.page_size
|
||||
# key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
|
||||
k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
|
||||
k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
|
||||
from ktransformers.server.backend.interfaces.ktransformers import multi_batch_enabled
|
||||
if multi_batch_enabled:
|
||||
batch_size = key_states.size(0)
|
||||
k_out[:batch_size, page_idx, page_offset, :, :self.kv_lora_rank] = key_states
|
||||
k_out[:batch_size, page_idx, page_offset, :, self.kv_lora_rank:] = value_states
|
||||
else:
|
||||
k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
|
||||
k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
|
||||
return k_out, self.page_table_list[layer_idx]
|
||||
else:
|
||||
k_out[:, :, cache_position] = key_states
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue