🚑️: roll back the buggy transformers.py change, and fix a typo in local_chat.py

liam 2024-10-28 21:09:40 +08:00
parent dd1d8667f3
commit 7c94df4bcf
2 changed files with 19 additions and 7 deletions

local_chat.py

@@ -91,7 +91,7 @@ def local_chat():
         generated = asyncio.run(async_inference(messages))
         his_content += [
             {"role": "user", "content": content},
-            {"role": "assitant", "content": generated},
+            {"role": "assistant", "content": generated},
         ]
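The only substantive change above is the role string: "assitant" is not a role any standard chat template recognizes, so the turn may be rendered incorrectly or rejected outright, depending on the template. A hypothetical guard (not part of this repo) that would have caught the typo at append time:

VALID_ROLES = {"system", "user", "assistant"}

def append_turn(history: list, role: str, content: str) -> list:
    # Fail fast on misspelled roles ("assitant") instead of letting
    # them reach tokenizer.apply_chat_template later in the pipeline.
    if role not in VALID_ROLES:
        raise ValueError(f"unknown chat role: {role!r}")
    history.append({"role": role, "content": content})
    return history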

transformers.py

@@ -164,7 +164,6 @@ class TransformersInterface(BackendInterfaceBase):
             if m["role"] == "system":
                 logger.warning(f'change {m["role"]} to user')
                 m["role"] = "user"
-
         new_messages = [messages[0]]
         for m in messages[1:]:
             if m["role"] == "user" and new_messages[-1]["role"] == "user":
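For context, the unchanged code around these hunks normalizes the message list before templating: system turns are demoted to user, and consecutive user turns are merged so roles strictly alternate. A minimal standalone sketch of that logic; the merge body falls between the two hunks, so the concatenation shown is an assumption:

def normalize_messages(messages):
    # Demote "system" turns to "user", as the hunk at -164 does.
    for m in messages:
        if m["role"] == "system":
            m["role"] = "user"
    # Merge consecutive user turns so roles alternate; the actual
    # merge body is not visible in the diff, so "+=" is assumed.
    new_messages = [messages[0]]
    for m in messages[1:]:
        if m["role"] == "user" and new_messages[-1]["role"] == "user":
            new_messages[-1]["content"] += "\n" + m["content"]
        else:
            new_messages.append(m)
    return new_messages

# normalize_messages([{"role": "system", "content": "Be brief."},
#                     {"role": "user", "content": "Hi"}])
# -> one merged user turn: "Be brief.\nHi"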
@@ -173,12 +172,25 @@ class TransformersInterface(BackendInterfaceBase):
             else:
                 new_messages.append(m)
 
+        # if (self.last_request_id is not None) and self.last_request_id == thread_id:
+        #     logger.debug(f"last message: {new_messages[-1]}")
+        #     input_ids = self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt",add_generation_prompt=False).to(self.args.device)
+        # else:
+        #     input_ids = self.tokenizer.apply_chat_template(
+        #         new_messages, return_tensors="pt", add_generation_prompt=True
+        #     ).to(self.args.device)
+        input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device)
         if (self.last_request_id is not None) and self.last_request_id == thread_id:
-            input_ids = self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt").to(self.args.device)
-        else:
-            input_ids = self.tokenizer.apply_chat_template(
-                new_messages, return_tensors="pt", add_generation_prompt=True
-            ).to(self.args.device)
+            x = self.generated_ids[:,:self.seq_length]
+            y = input_ids[:,:self.seq_length]
+            # We can only hope that the input_ids are the same
+            unequal_mask = torch.ne(x,y)
+            unequal_positions = torch.nonzero(unequal_mask)
+            num_unequal_elements = unequal_mask.sum().item()
+            logger.warning(f'num_unequal_elements: {num_unequal_elements}')
+            input_ids = input_ids[:,self.seq_length:]
 
         logger.debug(f"get input ids of shape {input_ids.shape}")
         return input_ids
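The reverted shortcut templated only the last message when a request continued the previous thread; the new code always templates the full conversation, then checks that its first seq_length tokens agree with the cached generated_ids before slicing off the already-processed prefix. Hence the "we can only hope" comment: a mismatch can only be logged, not repaired, presumably because those tokens have already been consumed by the model. A minimal sketch of that check in isolation, using only torch; the tensor names mirror the diff:

import torch

def continue_thread(input_ids: torch.Tensor,
                    generated_ids: torch.Tensor,
                    seq_length: int) -> torch.Tensor:
    """Return only the tokens the model has not yet consumed."""
    # Compare the freshly templated prefix against the cached ids;
    # ideally they are identical and the suffix can be fed alone.
    unequal_mask = torch.ne(generated_ids[:, :seq_length],
                            input_ids[:, :seq_length])
    if unequal_mask.any():
        # As in the diff, a mismatch can only be reported, not fixed.
        print(f"num_unequal_elements: {unequal_mask.sum().item()}")
    return input_ids[:, seq_length:]

# Toy run: the cache holds 4 tokens; the re-rendered prompt repeats
# them and adds 2 new ones, so only the last 2 are returned.
cached = torch.tensor([[1, 2, 3, 4]])
fresh = torch.tensor([[1, 2, 3, 4, 5, 6]])
print(continue_thread(fresh, cached, seq_length=4))  # tensor([[5, 6]])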