diff --git a/ktransformers/server/backend/interfaces/balance_serve.py b/ktransformers/server/backend/interfaces/balance_serve.py
index 008431e..0abe2e0 100644
--- a/ktransformers/server/backend/interfaces/balance_serve.py
+++ b/ktransformers/server/backend/interfaces/balance_serve.py
@@ -387,23 +387,11 @@ class BalanceServeInterface(BackendInterfaceBase):
         return input_ids
 
     def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
-        for m in messages:
-            if m["role"] == "system":
-                logger.warning(f'change {m["role"]} to user')
-                m["role"] = "user"
-
-        new_messages = [messages[0]]
-        for m in messages[1:]:
-            if m["role"] == "user" and new_messages[-1]["role"] == "user":
-                logger.warning("merge two adjacent user messages")
-                new_messages[-1]["content"] += '\n' + m["content"]
-            else:
-                new_messages.append(m)
-        input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
+        input_str: str = self.tokenizer.apply_chat_template(messages,tokenize=False,add_generation_prompt=True)
         # drop <think> token in chat template
         if input_str.endswith('<think>\n'):
             input_str = input_str[:-len('<think>\n')]
-        input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
+        input_ids = self.tokenizer.encode(input_str, return_tensors="pt", add_special_tokens=False).to(self.args.device)
         logger.debug(f"get input ids of shape {input_ids.shape}")
         return input_ids
 
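
Note (not part of the patch): a minimal sketch of why add_special_tokens=False matters when re-encoding a string produced by apply_chat_template(tokenize=False). It assumes a Hugging Face tokenizer whose chat template already emits the BOS token; the model name below is only an illustrative assumption. With the default add_special_tokens=True, encode() would prepend a second BOS on top of the one already rendered by the template.

# Sketch only; not part of the diff. The model name is an illustrative assumption.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3")

messages = [{"role": "user", "content": "Hello"}]
# Render the chat template to a string, as format_and_tokenize_input_ids does.
input_str = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Default encode may add BOS again, duplicating the BOS already present in the rendered template.
with_specials = tokenizer.encode(input_str, return_tensors="pt")
# Matching the patch: encode exactly what the template produced, nothing more.
without_specials = tokenizer.encode(input_str, return_tensors="pt", add_special_tokens=False)

print(with_specials.shape, without_specials.shape)  # the first is typically one token longer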