ktransformers (mirror of https://github.com/kvcache-ai/ktransformers.git)

Commit f4c198bd42 (parent e9b1216a9a)

    support absorb for prefill long context

8 changed files with 93 additions and 33 deletions. Only the diff of the flashinfer MLA wrapper module is shown below; the other changed files are not included here.
@@ -9,7 +9,7 @@ flashinfer_enabled = False
 try:
     import flashinfer
-    flashinfer_enabled = False # disabled now, TODO:use new version of flashinfer and enable
+    flashinfer_enabled = True
     print("found flashinfer")
 
 except ImportError:
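As context (none of this is in the commit), the try/except above is the usual optional-dependency gate, and the flag it flips to True is what call sites can branch on. A self-contained sketch of that gate, with placeholder path functions that are not part of this commit:

# Sketch only: the two path functions are illustrative placeholders.
try:
    import flashinfer  # noqa: F401  (only import success matters here)
    flashinfer_enabled = True
except ImportError:
    flashinfer_enabled = False

flashinfer_mla_path = lambda: "flashinfer MLA kernels"    # placeholder
torch_fallback_path = lambda: "plain PyTorch attention"   # placeholder

def pick_attention_path():
    return flashinfer_mla_path() if flashinfer_enabled else torch_fallback_path()

print(pick_attention_path())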
@@ -132,14 +132,14 @@ class MLAWrapper():
             head_dim_ckv,
             head_dim_kpe,
             page_size,
-            False, # causal is False for decoding
+            True, # causal
             sm_scale,
             q_data_type,
             kv_data_type,
         )
 
     def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
-        return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse)
+        return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)
 
 class MLAWrapperSingleton():
     wrappers:dict = {}
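The causal flag flip is the substantive change here: with a multi-token prefill query, attention must be causal rather than the decode-only (single query token) configuration. The run() change is a keyword-passing fix: if the underlying kernel's run() has any optional parameter ahead of return_lse, passing return_lse positionally binds it to the wrong slot, while the keyword form is robust either way. A self-contained toy showing that failure mode (the signature below is an illustrative stand-in, not flashinfer's actual API):

def kernel_run(q_nope, q_pe, out=None, return_lse=False):
    # Toy stand-in for an attention kernel entry point.
    result = q_nope + q_pe
    return (result, "lse") if return_lse else result

print(kernel_run(1, 2, True))              # True is swallowed by `out`; return_lse stays False -> 3
print(kernel_run(1, 2, return_lse=True))   # explicit keyword -> (3, 'lse')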
@@ -179,6 +179,17 @@ class MLAWrapperSingleton():
                 sm_scale,
                 q_data_type,
                 kv_data_type,)
             wrapper.need_plan = False
 
+    @classmethod
+    def need_plan_all(cls):
+        for device, wrapper in cls.wrappers.items():
+            wrapper.need_plan = True
+
+    @classmethod
+    def reset_buffer(cls):
+        for device, wrapper in cls.wrappers.items():
+            wrapper.qo_indptr_buf[1] = 1
+
+
 if __name__ == "__main__":
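A hedged sketch of how the two new classmethods might be used: the actual call sites are in the other files of this commit and are not shown in this hunk, and forward_one_chunk below is a hypothetical placeholder.

def chunked_prefill(chunks):
    for chunk in chunks:
        # Query length changes every chunk, so each cached per-device wrapper must re-plan.
        MLAWrapperSingleton.need_plan_all()
        forward_one_chunk(chunk)  # hypothetical model forward over this prefill chunk
    # Decoding emits one query token per step: restore qo_indptr_buf to [0, 1].
    MLAWrapperSingleton.reset_buffer()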
@@ -187,8 +198,9 @@ if __name__ == "__main__":
     page_size = 64
     num_heads = 128
 
-    q_nope = torch.randn((1, num_heads, 512), dtype=torch.bfloat16, device="cuda")
-    q_pe = torch.randn((1, num_heads, 64), dtype=torch.bfloat16, device="cuda")
+    q_len = 10
+    q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
+    q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
     ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
     k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")
 
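For orientation, a summary of the test tensors above; reading 512 as the MLA compressed-KV width and 64 as the rotary head dim is an interpretation, not something this diff states. Switching from a single query token to q_len = 10 is what actually exercises the prefill path.

# q_nope : [q_len, num_heads, 512]      queries against the compressed KV ("absorbed" path)
# q_pe   : [q_len, num_heads, 64]       rotary-embedded query part
# ckv    : [max_pages, page_size, 512]  paged compressed-KV cache
# k_pe   : [max_pages, page_size, 64]   paged rotary key cache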
@@ -199,10 +211,10 @@ if __name__ == "__main__":
         max_pages,
     )
 
-    kv_len_arr = torch.tensor([10], dtype=torch.int32, device="cuda")
-
+    kv_len_arr = torch.tensor([q_len], dtype=torch.int32, device="cuda")
+    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
     wrapper.plan(
-        None,
+        qo_indptr,
         None,
         None,
         kv_len_arr,
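qo_indptr is the usual ragged-batch offset array: entries i and i+1 bracket request i's query rows, so a single request of length q_len gives [0, q_len]. A small worked example; the two-request case is illustrative, not part of the test above.

import torch

# One request with q_len = 10            -> qo_indptr = [0, 10]
# Two requests with query lengths 3 and 5 -> qo_indptr = [0, 3, 8]
q_lens = torch.tensor([3, 5], dtype=torch.int32)
qo_indptr = torch.cat([torch.zeros(1, dtype=torch.int32),
                       torch.cumsum(q_lens, 0, dtype=torch.int32)])
print(qo_indptr)  # tensor([0, 3, 8], dtype=torch.int32)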
@@ -216,6 +228,7 @@ if __name__ == "__main__":
     )
 
     attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
+    print(attn_output.shape)
 
     k = (
         torch.cat([ckv, k_pe], dim=-1)
@@ -235,6 +248,7 @@ if __name__ == "__main__":
         False,
         192 ** (-0.5)
     )
+    print(attn_ref.shape)
 
     torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
     print("test past")
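For readers without flashinfer installed, here is an independent, self-contained sketch of the kind of reference the test compares against. It is not the repo's reference helper: it uses fresh random tensors instead of gathering from the paged cache, assumes the reference is causal to match the causal=True planning above (the hunk does not show which flag the real helper receives), and reads 192 ** (-0.5) as 1/sqrt(128 + 64), the un-absorbed query-key head dim, which is an interpretation.

import torch

q_len, num_heads = 10, 128
q = torch.randn(q_len, num_heads, 576)   # [q_nope | q_pe] concatenated per head: 512 + 64
k = torch.randn(q_len, 576)              # [ckv | k_pe] for the same tokens, shared across heads
v = k[:, :512]                           # values are just the compressed-KV part

# Plain causal attention with the MLA softmax scale.
scores = torch.einsum("qhd,kd->hqk", q, k) * 192 ** (-0.5)
mask = torch.triu(torch.ones(q_len, q_len, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(mask, float("-inf"))
out = torch.einsum("hqk,kd->qhd", torch.softmax(scores, dim=-1), v)
print(out.shape)  # torch.Size([10, 128, 512])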