support absorb for prefill long context

This commit is contained in:
Atream 2025-02-25 08:52:02 +00:00
parent e9b1216a9a
commit f4c198bd42
8 changed files with 93 additions and 33 deletions

View file

@ -9,7 +9,7 @@ flashinfer_enabled = False
try:
import flashinfer
flashinfer_enabled = False # disabled now, TODO:use new version of flashinfer and enable
flashinfer_enabled = True
print("found flashinfer")
except ImportError:
@ -132,14 +132,14 @@ class MLAWrapper():
head_dim_ckv,
head_dim_kpe,
page_size,
False, # causal is False for decoding
True, # causal
sm_scale,
q_data_type,
kv_data_type,
)
def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse)
return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)
class MLAWrapperSingleton():
wrappers:dict = {}
@ -179,6 +179,17 @@ class MLAWrapperSingleton():
sm_scale,
q_data_type,
kv_data_type,)
wrapper.need_plan = False
@classmethod
def need_plan_all(cls):
for device, wrapper in cls.wrappers.items():
wrapper.need_plan = True
@classmethod
def reset_buffer(cls):
for device, wrapper in cls.wrappers.items():
wrapper.qo_indptr_buf[1] = 1
if __name__ == "__main__":
@ -187,8 +198,9 @@ if __name__ == "__main__":
page_size = 64
num_heads = 128
q_nope = torch.randn((1, num_heads, 512), dtype=torch.bfloat16, device="cuda")
q_pe = torch.randn((1, num_heads, 64), dtype=torch.bfloat16, device="cuda")
q_len = 10
q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")
@ -199,10 +211,10 @@ if __name__ == "__main__":
max_pages,
)
kv_len_arr = torch.tensor([10], dtype=torch.int32, device="cuda")
kv_len_arr = torch.tensor([q_len], dtype=torch.int32, device="cuda")
qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
wrapper.plan(
None,
qo_indptr,
None,
None,
kv_len_arr,
@ -216,6 +228,7 @@ if __name__ == "__main__":
)
attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
print(attn_output.shape)
k = (
torch.cat([ckv, k_pe], dim=-1)
@ -235,6 +248,7 @@ if __name__ == "__main__":
False,
192 ** (-0.5)
)
print(attn_ref.shape)
torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
print("test past")