support safetensor load, delete architectures argument

This commit is contained in:
qiyuxinlin 2025-05-09 10:38:29 +00:00
parent 900a7f7c3e
commit c6aa379de2
30 changed files with 1075 additions and 328 deletions

View file

@ -40,7 +40,7 @@ class flashInferAttn():
self.kv_layout = kv_layout
self.use_cuda_graph = use_cuda_graph
if flashInferAttn.float_workspace_buffer is None:
flashInferAttn.float_workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.uint8, device=device)
flashInferAttn.float_workspace_buffer = torch.empty(max_batch_token * 1024 * 1024, dtype=torch.uint8, device=device)
self.qo_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)
self.paged_kv_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)
self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device)