diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py
index 660a782..e90ca2f 100644
--- a/ktransformers/server/args.py
+++ b/ktransformers/server/args.py
@@ -90,7 +90,7 @@ class ArgumentParser:
         # user config
         parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
         parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
-        parser.add_argument("--force_think", type=bool, default=self.cfg.force_think)
+        parser.add_argument("--force_think", type=bool, default=self.cfg.user_force_think)
 
         # web config
         parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index fd997b4..01a6b84 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -324,7 +324,7 @@ class TransformersInterface(BackendInterfaceBase):
             #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")
-        if Config().force_think:
+        if Config().user_force_think:
             token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)])
             input_ids = torch.cat(
                 [input_ids, token_thinks], dim=1
@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
         self.profiler.pause_timer("tokenize")
 
         self.profiler.create_and_start_timer("prefill")
-        if Config().force_think:
+        if Config().user_force_think:
             print("<think>\n")
             yield "<think>\n"
         for t in self.prefill(input_ids, self.check_is_new(thread_id)):