mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 13:55:27 +00:00
support absorb for prefill long context
This commit is contained in:
parent
e9b1216a9a
commit
f4c198bd42
8 changed files with 93 additions and 33 deletions
|
@ -14,6 +14,7 @@ from ktransformers.models.custom_cache import StaticCache
|
|||
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
|
||||
from ktransformers.local_chat import custom_models, default_optimize_rules
|
||||
from ktransformers.util.utils import get_device
|
||||
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
|
||||
|
||||
|
||||
warm_uped = False
|
||||
|
@ -186,6 +187,8 @@ class KTransformersInterface(TransformersInterface):
|
|||
input_ids = input_ids.to("cpu")
|
||||
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
|
||||
torch.cuda.set_device(device)
|
||||
if flashinfer_enabled:
|
||||
MLAWrapperSingleton.need_plan_all()
|
||||
if self.use_static_cache:
|
||||
logits = self.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -198,6 +201,8 @@ class KTransformersInterface(TransformersInterface):
|
|||
else:
|
||||
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
|
||||
|
||||
if flashinfer_enabled:
|
||||
MLAWrapperSingleton.reset_buffer()
|
||||
self.prepare_logits_wrapper(input_ids, device)
|
||||
next_token = self.logits_to_token(logits[0, -1, :])
|
||||
yield self.append_new_tokens(next_token)
|
||||
|
|
|
@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
for i in range(1, self.args.max_new_tokens):
|
||||
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
||||
if i > 1 and flashinfer_enabled:
|
||||
if flashinfer_enabled:
|
||||
MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
|
||||
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
|
||||
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue