diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index cdc934c..75e12fb 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -70,6 +70,7 @@ def local_chat(
     torch.set_grad_enabled(False)
 
     Config().cpu_infer = cpu_infer
+    Config().chunk_size = chunk_size
     if torch.xpu.is_available():
         use_cuda_graph = False
diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py
index 33e6a1d..7a40168 100644
--- a/ktransformers/operators/experts.py
+++ b/ktransformers/operators/experts.py
@@ -213,7 +213,7 @@ class KExpertsCPU(KExpertsBase):
             self.config.num_experts_per_tok,
             self.config.hidden_size,
             self.config.moe_intermediate_size,
-            max(cuda_graphs),
+            max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,
             gate_ptr,
             up_ptr,
             down_ptr,
@@ -231,7 +231,7 @@ class KExpertsCPU(KExpertsBase):
             self.config.num_experts_per_tok,
             self.config.hidden_size,
             self.config.moe_intermediate_size,
-            max(cuda_graphs),
+            max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,
             gate_ptr,
             up_ptr,
             down_ptr,
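
A minimal sketch of the fallback logic this diff introduces, assuming cuda_graphs may arrive either as a list of captured CUDA-graph batch sizes or as a non-list value (e.g. an int or None); the helper name resolve_moe_buffer_size is hypothetical, not part of ktransformers:

# Sketch only, not the ktransformers implementation. When cuda_graphs is a
# list of captured graph batch sizes, size the CPU MoE buffer for the largest
# graph; otherwise fall back to the configured chunk size, mirroring the
# diff's ternary expression.
from typing import List, Optional, Union

def resolve_moe_buffer_size(cuda_graphs: Union[List[int], int, None],
                            chunk_size: int) -> int:
    if isinstance(cuda_graphs, list):
        return max(cuda_graphs)
    return chunk_size

# Usage: graphs captured for batch sizes 1 and 64 -> buffer sized for 64;
# no list of graphs -> fall back to the chunk size from Config().
assert resolve_moe_buffer_size([1, 64], chunk_size=8192) == 64
assert resolve_moe_buffer_size(None, chunk_size=8192) == 8192

Writing chunk_size into Config() in local_chat.py is what makes the fallback in KExpertsCPU possible: the operator reads Config().chunk_size instead of assuming cuda_graphs is always a list.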