mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-07 04:59:55 +00:00
fix chat template encoding
This commit is contained in:
parent
449a83dff6
commit
46493789eb
1 changed files with 2 additions and 14 deletions
|
@ -387,23 +387,11 @@ class BalanceServeInterface(BackendInterfaceBase):
|
|||
return input_ids
|
||||
|
||||
def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
    """Normalize a chat message list, render it through the tokenizer's
    chat template, and return the encoded input ids.

    Args:
        thread_id: identifier of the conversation thread (not used by the
            tokenization itself; kept for interface compatibility).
        messages: list of ``{"role": ..., "content": ...}`` dicts, mutated
            in place when a system role is demoted to user.

    Returns:
        Token-id tensor of shape (1, seq_len) moved to ``self.args.device``.
    """
    # The template used here does not accept a system turn: demote any
    # system message to a user message (mutates the caller's dicts).
    for m in messages:
        if m["role"] == "system":
            logger.warning(f'change {m["role"]} to user')
            m["role"] = "user"

    # Chat templates typically reject two consecutive user turns; merge
    # adjacent user messages into one, separated by a newline.
    new_messages = [messages[0]]
    for m in messages[1:]:
        if m["role"] == "user" and new_messages[-1]["role"] == "user":
            logger.warning("merge two adjacent user messages")
            new_messages[-1]["content"] += '\n' + m["content"]
        else:
            new_messages.append(m)

    # Render the NORMALIZED list. (A previous revision re-assigned
    # input_str from the raw `messages`, which made the normalization
    # above dead code — keep the single, correct assignment.)
    input_str: str = self.tokenizer.apply_chat_template(
        new_messages, tokenize=False, add_generation_prompt=True
    )

    # drop <think> token in chat template: some templates append a
    # trailing '<think>\n' after the generation prompt — strip it so the
    # model decides whether to open a think block.
    if input_str.endswith('<think>\n'):
        input_str = input_str[:-len('<think>\n')]

    # The chat template has already inserted special tokens (e.g. BOS);
    # encoding with the default add_special_tokens=True would duplicate
    # them, so it is disabled here. Encode exactly once (a previous
    # revision encoded twice and discarded the first result).
    input_ids = self.tokenizer.encode(
        input_str, return_tensors="pt", add_special_tokens=False
    ).to(self.args.device)
    logger.debug(f"get input ids of shape {input_ids.shape}")
    return input_ids
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue