diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index f3f0373..f205ac5 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -172,13 +172,23 @@ class TransformersInterface(BackendInterfaceBase):
                 new_messages[-1]["content"] += m["content"]
             else:
                 new_messages.append(m)
-
+        # if (self.last_request_id is not None) and self.last_request_id == thread_id:
+        #     input_ids = self.tokenizer.encode(self.tokenizer.eos_token+self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt",tokenize=False, add_generation_prompt=True), add_special_tokens = False, return_tensors="pt").to(self.args.device)
+        # else:
+        #     input_ids = self.tokenizer.apply_chat_template(
+        #         new_messages, return_tensors="pt", add_generation_prompt=True
+        #     ).to(self.args.device)
+        input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device)
         if (self.last_request_id is not None) and self.last_request_id == thread_id:
-            input_ids = self.tokenizer.encode(self.tokenizer.eos_token+self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt",tokenize=False, add_generation_prompt=True), add_special_tokens = False, return_tensors="pt").to(self.args.device)
-        else:
-            input_ids = self.tokenizer.apply_chat_template(
-                new_messages, return_tensors="pt", add_generation_prompt=True
-            ).to(self.args.device)
+            x = self.generated_ids[:,:self.seq_length]
+            y = input_ids[:,:self.seq_length]
+            # We can only hope that the input_ids are the same
+            unequal_mask = torch.ne(x,y)
+            unequal_positions = torch.nonzero(unequal_mask)
+            num_unequal_elements = unequal_mask.sum().item()
+            logger.warning(f'num_unequal_elements: {num_unequal_elements}')
+
+            input_ids = input_ids[:,self.seq_length:]
         logger.debug(f"get input ids of shape {input_ids.shape}")
         return input_ids

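For context: the new branch assumes that re-applying the chat template to the whole conversation reproduces, token for token, the prefix already cached in self.generated_ids, so only the suffix beyond self.seq_length needs to be fed to the model. Tokenizers do not guarantee this round-trip, which is why the diff counts mismatches and logs a warning rather than asserting. The following is a minimal, self-contained sketch of that prefix check; the helper name split_reused_prefix and its standalone signature are hypothetical illustrations, not part of the ktransformers API.

    import logging

    import torch

    logger = logging.getLogger(__name__)

    def split_reused_prefix(generated_ids: torch.Tensor,
                            input_ids: torch.Tensor,
                            seq_length: int) -> torch.Tensor:
        """Hypothetical helper mirroring the diff: compare the cached prefix
        against the re-tokenized prompt, log any mismatches, and return only
        the new suffix tokens."""
        x = generated_ids[:, :seq_length]
        y = input_ids[:, :seq_length]
        # Mismatches are reported, not fatal: the cached KV state is reused
        # either way, matching the warning-only behavior in the diff.
        unequal_mask = torch.ne(x, y)
        num_unequal = unequal_mask.sum().item()
        if num_unequal:
            logger.warning("prefix mismatch: %d token(s) differ at %s",
                           num_unequal, torch.nonzero(unequal_mask).tolist())
        return input_ids[:, seq_length:]

    # Example: one mismatched token in a cached 4-token prefix.
    cached = torch.tensor([[1, 2, 3, 4]])
    prompt = torch.tensor([[1, 2, 9, 4, 7, 8]])
    suffix = split_reused_prefix(cached, prompt, seq_length=4)  # tensor([[7, 8]])

Logging instead of raising keeps a multi-turn session alive when the template re-tokenizes slightly differently, at the cost of silently reusing KV-cache entries that no longer match the prompt; the unused unequal_positions variable in the diff appears intended for the same kind of debug reporting.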