support safetensor load, delete architectures argument

qiyuxinlin 2025-05-09 10:38:29 +00:00
parent 900a7f7c3e
commit c6aa379de2
30 changed files with 1075 additions and 328 deletions


@@ -39,6 +39,17 @@ def pad_num_tokens(num_tokens):
def deduplicate_and_sort(lst):
    return sorted(set(lst))

def generate_cuda_graphs(chunk_size: int) -> list:
    # Assert if the input does not meet the requirements
    assert chunk_size <= 1024 or chunk_size % 1024 == 0, "chunk_size must be <= 1024 or a multiple of 1024"
    base_list = [1, 2, 3, Config().max_batch_size, 64, 256, 512, chunk_size]
    if chunk_size <= 1024:
        return base_list
    multiples = [i for i in range(1024, chunk_size + 1, 1024)]
    return deduplicate_and_sort(base_list + multiples)
class ModelRunner:
    """A ModelRunner runs the forward pass of a model with CUDA graph and torch.compile."""
@@ -56,7 +67,7 @@ class ModelRunner:
        self.features_buf = None
        self.output = None
        self.graph_memory_pool = None
        self.cuda_graphs = deduplicate_and_sort([1, 2, 3, Config().max_batch_size, 64, Config().chunk_size])
        self.cuda_graphs = generate_cuda_graphs(Config().chunk_size)
        self.use_cuda_graph = use_cuda_graph
        self.model_time = 0
        self.page_size = page_size
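For comparison, a rough before/after of the graph sizes ModelRunner would capture for an assumed 8192-token chunk_size (again taking max_batch_size as 4 for illustration only):

# Previous expression: only the small batch sizes plus the exact chunk_size.
old = sorted(set([1, 2, 3, 4, 64, 8192]))
# -> [1, 2, 3, 4, 64, 8192]

# New generate_cuda_graphs(8192): 256 and 512 plus every 1024-token step up to chunk_size.
new = sorted(set([1, 2, 3, 4, 64, 256, 512, 8192] + list(range(1024, 8192 + 1, 1024))))
# -> [1, 2, 3, 4, 64, 256, 512, 1024, 2048, 3072, 4096, 5120, 6144, 7168, 8192]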