diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index 33331d0..5086a3b 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -170,7 +170,7 @@ class TransformersInterface(BackendInterfaceBase):
         for m in messages[1:]:
             if m["role"] == "user" and new_messages[-1]["role"] == "user":
                 logger.warning("merge two adjacent user messages")
-                new_messages[-1]["content"] += m["content"]
+                new_messages[-1]["content"] += '\n' + m["content"]
             else:
                 new_messages.append(m)
         # if (self.last_request_id is not None) and self.last_request_id == thread_id:
@@ -179,7 +179,11 @@ class TransformersInterface(BackendInterfaceBase):
         # input_ids = self.tokenizer.apply_chat_template(
         #     new_messages, return_tensors="pt", add_generation_prompt=True
         # ).to(self.args.device)
-        input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device)
+        input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
+        # drop <think> token in chat template
+        if input_str.endswith('<think>\n'):
+            input_str = input_str[:-len('<think>\n')]
+        input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
         if (self.last_request_id is not None) and self.last_request_id == thread_id:
             x = self.generated_ids[:,:self.seq_length]
             y = input_ids[:,:self.seq_length]
@@ -360,6 +364,7 @@ class TransformersInterface(BackendInterfaceBase):
             #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")
+
         if Config().user_force_think:
             token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
             input_ids = torch.cat(
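
Note: a minimal sketch of the new prompt-building path in isolation, assuming a DeepSeek-R1-style chat template that ends the rendered prompt with `<think>\n`; the model id and the `force_think` flag below are illustrative stand-ins, not part of this patch:

```python
import torch
from transformers import AutoTokenizer

# Illustrative model id (assumption): any tokenizer whose chat template
# appends "<think>\n" after the generation prompt behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")

messages = [{"role": "user", "content": "Hello"}]

# Render the chat template to a string instead of token ids, as the patch does.
input_str = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# Drop the trailing "<think>" token emitted by the template ...
if input_str.endswith("<think>\n"):
    input_str = input_str[: -len("<think>\n")]

input_ids = tokenizer.encode(input_str, return_tensors="pt")

# ... so it can be re-appended explicitly, mirroring the
# Config().user_force_think branch in the server.
force_think = True  # stand-in for Config().user_force_think
if force_think:
    token_thinks = torch.tensor(
        [tokenizer.encode("<think>\n", add_special_tokens=False)]
    )
    input_ids = torch.cat([input_ids, token_thinks], dim=1)
```

This keeps the "force think" behavior under the server's control, independent of whether the template itself injects the think token.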