use generation config from json file in official repo

This commit is contained in:
Atream 2025-02-27 11:48:34 +00:00
parent 9660b2cc1e
commit e645d84794
4 changed files with 57 additions and 21 deletions

View file

@ -122,7 +122,7 @@ class MLAWrapper():
if kv_indices is None:
assert self.max_batch_size == 1
kv_indices = self.kv_indices_buf
self.wrapper.plan(
qo_indptr,
kv_indptr,
@ -139,6 +139,11 @@ class MLAWrapper():
)
def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
#print("run")
#print(self.wrapper._qo_indptr_buf)
#print(self.wrapper._kv_indptr_buf)
#print(self.wrapper._kv_indices_buf)
#print(self.wrapper._kv_len_arr_buf)
return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)
class MLAWrapperSingleton():
@ -201,11 +206,12 @@ class MLAWrapperSingleton():
if __name__ == "__main__":
max_batch_size = 1
max_pages = 1
max_pages = 128
page_size = 64
num_heads = 128
q_len = 10
kv_len = 2069
q_len = 1
q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
@ -218,7 +224,7 @@ if __name__ == "__main__":
max_pages,
)
kv_len_arr = torch.tensor([q_len], dtype=torch.int32, device="cuda")
kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
wrapper.plan(
qo_indptr,
@ -244,15 +250,15 @@ if __name__ == "__main__":
)
v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
print(k[:10].shape)
print(v[:10].shape)
print(k[:kv_len].shape)
print(v[:kv_len].shape)
attn_ref, lse_ref = attention_ref(
max_batch_size,
torch.cat([q_nope, q_pe], dim=-1),
k[:10],
v[:10],
False,
k[:kv_len],
v[:kv_len],
True,
192 ** (-0.5)
)
print(attn_ref.shape)