Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-08 05:29:29 +00:00)
fix: server: drop <think> tag in chat template

parent ca2090d89b
commit cd9f7f8f34
1 changed file with 7 additions and 2 deletions
@@ -170,7 +170,7 @@ class TransformersInterface(BackendInterfaceBase):
         for m in messages[1:]:
             if m["role"] == "user" and new_messages[-1]["role"] == "user":
                 logger.warning("merge two adjacent user messages")
-                new_messages[-1]["content"] += m["content"]
+                new_messages[-1]["content"] += '\n' + m["content"]
             else:
                 new_messages.append(m)
         # if (self.last_request_id is not None) and self.last_request_id == thread_id:
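
For illustration, the merge loop in isolation: with the newline join, two adjacent user messages keep a visible boundary instead of running together. A minimal sketch; seeding new_messages with the first message is an assumption about surrounding code this hunk does not show.

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "First question."},
        {"role": "user", "content": "Second question."},
    ]

    new_messages = [messages[0]]  # assumed seed; not shown in the hunk
    for m in messages[1:]:
        if m["role"] == "user" and new_messages[-1]["role"] == "user":
            # merge two adjacent user messages, separated by a newline
            new_messages[-1]["content"] += '\n' + m["content"]
        else:
            new_messages.append(m)

    assert new_messages[-1]["content"] == "First question.\nSecond question."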
@@ -179,7 +179,11 @@ class TransformersInterface(BackendInterfaceBase):
         # input_ids = self.tokenizer.apply_chat_template(
         #     new_messages, return_tensors="pt", add_generation_prompt=True
         # ).to(self.args.device)
-        input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device)
+        input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
+        # drop <think> token in chat template
+        if input_str.endswith('<think>\n'):
+            input_str = input_str[:-len('<think>\n')]
+        input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
         if (self.last_request_id is not None) and self.last_request_id == thread_id:
             x = self.generated_ids[:,:self.seq_length]
             y = input_ids[:,:self.seq_length]
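
Note on the fix: the old path tokenized the chat template output directly, so any "<think>\n" the template appends after the generation prompt always reached the model. The new path renders the template to a string (tokenize=False), strips a trailing "<think>\n" if present, and only then encodes; the user_force_think branch below is then apparently left as the one place that decides whether generation starts inside a think block. A minimal sketch of the string-level strip; THINK_TAG and drop_trailing_think are hypothetical names, not ktransformers API:

    THINK_TAG = '<think>\n'

    def drop_trailing_think(prompt: str) -> str:
        # remove a template-appended trailing <think> tag, if present
        if prompt.endswith(THINK_TAG):
            return prompt[:-len(THINK_TAG)]
        return prompt

    assert drop_trailing_think('...assistant prompt...<think>\n') == '...assistant prompt...'
    assert drop_trailing_think('no trailing tag') == 'no trailing tag'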
@@ -360,6 +364,7 @@ class TransformersInterface(BackendInterfaceBase):
             #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")
+
         if Config().user_force_think:
             token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
             input_ids = torch.cat(
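
The torch.cat call is truncated by the hunk. A hedged sketch of the user_force_think path, assuming the forced "<think>\n" ids are appended after the prompt along the sequence dimension (tensor values are stand-ins; the server builds token_thinks from self.tokenizer on self.args.device):

    import torch

    input_ids = torch.tensor([[1, 2, 3]])    # stand-in prompt token ids
    token_thinks = torch.tensor([[10, 11]])  # stand-in ids for "<think>\n"

    # assumed continuation of the truncated call: append along the sequence dim
    input_ids = torch.cat([input_ids, token_thinks], dim=1)
    # input_ids is now tensor([[ 1,  2,  3, 10, 11]])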