Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-09 13:55:27 +00:00)
support safetensor load, delete architectures argument
parent 900a7f7c3e
commit c6aa379de2
30 changed files with 1075 additions and 328 deletions
@@ -39,6 +39,17 @@ def pad_num_tokens(num_tokens):
+def deduplicate_and_sort(lst):
+    return sorted(set(lst))
+
+def generate_cuda_graphs(chunk_size: int) -> list:
+    # Assert-fail if the input does not meet the requirements
+    assert chunk_size <= 1024 or chunk_size % 1024 == 0, "chunk_size must be <= 1024 or a multiple of 1024"
+    base_list = [1, 2, 3, Config().max_batch_size, 64, 256, 512, chunk_size]
+
+    if chunk_size <= 1024:
+        return base_list
+
+    multiples = [i for i in range(1024, chunk_size + 1, 1024)]
+
+    return deduplicate_and_sort(base_list + multiples)
 class ModelRunner:
     """A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile."""
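
For context, a minimal standalone sketch of what the new helper returns; Config() is stubbed here with a hypothetical MAX_BATCH_SIZE of 4, since the real value comes from the runtime config:

# MAX_BATCH_SIZE stands in for Config().max_batch_size (hypothetical value,
# for illustration only).
MAX_BATCH_SIZE = 4

def deduplicate_and_sort(lst):
    return sorted(set(lst))

def generate_cuda_graphs(chunk_size: int) -> list:
    assert chunk_size <= 1024 or chunk_size % 1024 == 0, \
        "chunk_size must be <= 1024 or a multiple of 1024"
    base_list = [1, 2, 3, MAX_BATCH_SIZE, 64, 256, 512, chunk_size]
    if chunk_size <= 1024:
        return base_list
    multiples = list(range(1024, chunk_size + 1, 1024))
    return deduplicate_and_sort(base_list + multiples)

print(generate_cuda_graphs(512))   # [1, 2, 3, 4, 64, 256, 512, 512]
print(generate_cuda_graphs(4096))  # [1, 2, 3, 4, 64, 256, 512, 1024, 2048, 3072, 4096]

Note that for chunk_size <= 1024 the base list is returned directly, without passing through deduplicate_and_sort, so it can contain duplicates (512 appears twice above).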
@@ -56,7 +67,7 @@ class ModelRunner:
         self.features_buf = None
         self.output = None
         self.graph_memory_pool = None
-        self.cuda_graphs = deduplicate_and_sort([1, 2, 3, Config().max_batch_size, 64, Config().chunk_size])
+        self.cuda_graphs = generate_cuda_graphs(Config().chunk_size)
         self.use_cuda_graph = use_cuda_graph
         self.model_time = 0
         self.page_size = page_size
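
self.cuda_graphs now lists the capture sizes, so the hardcoded list is gone and chunk sizes above 1024 get graphs at every 1024-token multiple. As a hedged sketch of how such a list is typically consumed (the selection logic below is illustrative, not taken from the repo), a runner replays the smallest captured graph that fits the request:

import bisect

def pick_graph_size(cuda_graphs: list, num_tokens: int) -> int:
    # Smallest captured size >= num_tokens; illustrative only, the actual
    # ktransformers dispatch may differ.
    idx = bisect.bisect_left(cuda_graphs, num_tokens)
    if idx == len(cuda_graphs):
        raise ValueError(f"{num_tokens} exceeds the largest captured graph size")
    return cuda_graphs[idx]

# generate_cuda_graphs(2048) with a max_batch_size of 4, hardcoded here:
graphs = [1, 2, 3, 4, 64, 256, 512, 1024, 2048]
assert pick_graph_size(graphs, 5) == 64
assert pick_graph_size(graphs, 512) == 512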