diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index 80ada29..3e0d2ef 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -91,7 +91,7 @@ def local_chat():
         generated = asyncio.run(async_inference(messages))
         his_content += [
             {"role": "user", "content": content},
-            {"role": "assitant", "content": generated},
+            {"role": "assistant", "content": generated},
         ]
diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index 7f569c4..cddc198 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -164,7 +164,6 @@ class TransformersInterface(BackendInterfaceBase):
             if m["role"] == "system":
                 logger.warning(f'change {m["role"]} to user')
                 m["role"] = "user"
-
         new_messages = [messages[0]]
         for m in messages[1:]:
             if m["role"] == "user" and new_messages[-1]["role"] == "user":
@@ -173,12 +172,25 @@ class TransformersInterface(BackendInterfaceBase):
             else:
                 new_messages.append(m)
 
+        # if (self.last_request_id is not None) and self.last_request_id == thread_id:
+        #     logger.debug(f"last message: {new_messages[-1]}")
+        #     input_ids = self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt", add_generation_prompt=False).to(self.args.device)
+        # else:
+        #     input_ids = self.tokenizer.apply_chat_template(
+        #         new_messages, return_tensors="pt", add_generation_prompt=True
+        #     ).to(self.args.device)
+
+        input_ids = self.tokenizer.apply_chat_template(new_messages, return_tensors="pt", add_generation_prompt=True).to(self.args.device)
         if (self.last_request_id is not None) and self.last_request_id == thread_id:
-            input_ids = self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt").to(self.args.device)
-        else:
-            input_ids = self.tokenizer.apply_chat_template(
-                new_messages, return_tensors="pt", add_generation_prompt=True
-            ).to(self.args.device)
+            x = self.generated_ids[:, :self.seq_length]
+            y = input_ids[:, :self.seq_length]
+            # We can only hope that the input_ids are the same
+            unequal_mask = torch.ne(x, y)
+            unequal_positions = torch.nonzero(unequal_mask)
+            num_unequal_elements = unequal_mask.sum().item()
+            logger.warning(f'num_unequal_elements: {num_unequal_elements}')
+
+            input_ids = input_ids[:, self.seq_length:]
 
         logger.debug(f"get input ids of shape {input_ids.shape}")
         return input_ids
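
Note on the second hunk: the previous code had a fast path that templated only the last message when a request continued the same thread, but a chat template is not guaranteed to tokenize a single message into the exact same ids it produces inside the full conversation, so the cached prefix and the new ids could silently diverge. The patch instead templates the whole conversation every time, counts how many of the first `seq_length` tokens disagree with `self.generated_ids` (logging a warning rather than failing, hence the "We can only hope" comment), and feeds the model only the suffix so the prefix already in the KV cache is reused.

Below is a minimal standalone sketch of that prefix-reuse step, assuming a cached tensor of already-processed tokens; the function name `reuse_prefix` is illustrative and not part of ktransformers:

```python
import torch


def reuse_prefix(generated_ids: torch.Tensor, input_ids: torch.Tensor, seq_length: int) -> torch.Tensor:
    """Return only the tokens that extend the cached prefix.

    generated_ids: tokens already processed for this thread, shape [1, >= seq_length]
    input_ids:     freshly templated tokens for the full conversation, shape [1, >= seq_length]
    seq_length:    number of tokens currently covered by the KV cache
    """
    x = generated_ids[:, :seq_length]
    y = input_ids[:, :seq_length]
    # The re-applied chat template may not reproduce the cached prefix
    # token-for-token, so count mismatches and warn instead of asserting.
    num_unequal = torch.ne(x, y).sum().item()
    if num_unequal > 0:
        print(f"warning: {num_unequal} prefix tokens differ from the cache")
    # Only the suffix needs a prefill pass; the prefix stays in the cache.
    return input_ids[:, seq_length:]


# Toy usage: four tokens are cached, the re-templated conversation adds two more.
cached = torch.tensor([[101, 7592, 2088, 102]])
fresh = torch.tensor([[101, 7592, 2088, 102, 2047, 102]])
assert reuse_prefix(cached, fresh, seq_length=4).tolist() == [[2047, 102]]
```

The trade-off mirrored here is that re-templating the full conversation costs a little tokenizer work on every turn, but it keeps the token stream consistent with what a fresh request would see, which the disabled single-message path could not guarantee.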