fix chat template encoding

This commit is contained in:
Atream 2025-04-24 12:44:16 +08:00 committed by GitHub
parent 449a83dff6
commit 46493789eb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -387,23 +387,11 @@ class BalanceServeInterface(BackendInterfaceBase):
return input_ids
def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
    """Render chat *messages* through the tokenizer's chat template and return input ids.

    The message list is passed to the chat template as-is (roles unchanged,
    no merging of adjacent messages). Any trailing ``<think>`` opener emitted
    by the template is stripped so generation starts cleanly.

    Args:
        thread_id: identifier of the requesting thread (unused here; kept for
            interface compatibility with callers).
        messages: list of ``{"role": ..., "content": ...}`` chat messages.

    Returns:
        A ``torch.Tensor`` of token ids on ``self.args.device``.
    """
    input_str: str = self.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # drop <think> token in chat template
    if input_str.endswith('<think>\n'):
        input_str = input_str[:-len('<think>\n')]
    # add_special_tokens=False: the chat template already inserted any
    # BOS/special tokens, so encoding must not add them a second time.
    input_ids = self.tokenizer.encode(
        input_str, return_tensors="pt", add_special_tokens=False
    ).to(self.args.device)
    logger.debug(f"get input ids of shape {input_ids.shape}")
    return input_ids