diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index 4e006b6..676ea67 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -160,7 +160,7 @@ def local_chat(
         messages, add_generation_prompt=True, return_tensors="pt"
     )
     if force_think:
-        token_thinks = torch.tensor([tokenizer.encode("\\n",add_special_tokens=False)])
+        token_thinks = torch.tensor([tokenizer.encode("\\n",add_special_tokens=False)],device=input_tensor.device)
         input_tensor = torch.cat(
             [input_tensor, token_thinks], dim=1
         )
diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py
index e90ca2f..44fe7d2 100644
--- a/ktransformers/server/args.py
+++ b/ktransformers/server/args.py
@@ -122,4 +122,5 @@ class ArgumentParser:
         self.cfg.server_ip = args.host
         self.cfg.server_port = args.port
         self.cfg.backend_type = args.type
+        self.cfg.user_force_think = args.force_think
         return args
diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index 01a6b84..f18581a 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -325,7 +325,7 @@ class TransformersInterface(BackendInterfaceBase):
         else:
             raise ValueError("local_messages should be List or str")
         if Config().user_force_think:
-            token_thinks = torch.tensor([self.tokenizer.encode("\\n",add_special_tokens=False)])
+            token_thinks = torch.tensor([self.tokenizer.encode("\\n",add_special_tokens=False)],device=input_ids.device)
             input_ids = torch.cat(
                 [input_ids, token_thinks], dim=1
             )
@@ -334,8 +334,9 @@ class TransformersInterface(BackendInterfaceBase):
         self.profiler.create_and_start_timer("prefill")
         if Config().user_force_think:
-            print("\n")
-            yield "\n"
+            t = "\n"
+            print(t,end="",flush=True)
+            yield t
         for t in self.prefill(input_ids, self.check_is_new(thread_id)):
             if t is not None:
                 print(t, end="",flush=True)
                 yield t
@@ -346,7 +347,7 @@ class TransformersInterface(BackendInterfaceBase):
         for t in self.generate():
             if t is not None:
                 print(t, end="",flush=True)
-            yield t
+                yield t
         print("")
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()