mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 06:14:58 +00:00
wjh-change
This commit is contained in:
parent
7c94df4bcf
commit
2d67016d14
4 changed files with 74 additions and 26 deletions
|
@ -164,6 +164,7 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
if m["role"] == "system":
|
||||
logger.warning(f'change {m["role"]} to user')
|
||||
m["role"] = "user"
|
||||
|
||||
new_messages = [messages[0]]
|
||||
for m in messages[1:]:
|
||||
if m["role"] == "user" and new_messages[-1]["role"] == "user":
|
||||
|
@ -172,25 +173,12 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
else:
|
||||
new_messages.append(m)
|
||||
|
||||
# if (self.last_request_id is not None) and self.last_request_id == thread_id:
|
||||
# logger.debug(f"last message: {new_messages[-1]}")
|
||||
# input_ids = self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt",add_generation_prompt=False).to(self.args.device)
|
||||
# else:
|
||||
# input_ids = self.tokenizer.apply_chat_template(
|
||||
# new_messages, return_tensors="pt", add_generation_prompt=True
|
||||
# ).to(self.args.device)
|
||||
|
||||
input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device)
|
||||
if (self.last_request_id is not None) and self.last_request_id == thread_id:
|
||||
x = self.generated_ids[:,:self.seq_length]
|
||||
y = input_ids[:,:self.seq_length]
|
||||
# We can only hope that the input_ids are the same
|
||||
unequal_mask = torch.ne(x,y)
|
||||
unequal_positions = torch.nonzero(unequal_mask)
|
||||
num_unequal_elements = unequal_mask.sum().item()
|
||||
logger.warning(f'num_unequal_elements: {num_unequal_elements}')
|
||||
|
||||
input_ids = input_ids[:,self.seq_length:]
|
||||
input_ids = self.tokenizer.encode(self.tokenizer.eos_token+self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt",tokenize=False, add_generation_prompt=True), add_special_tokens = False, return_tensors="pt").to(self.args.device)
|
||||
else:
|
||||
input_ids = self.tokenizer.apply_chat_template(
|
||||
new_messages, return_tensors="pt", add_generation_prompt=True
|
||||
).to(self.args.device)
|
||||
logger.debug(f"get input ids of shape {input_ids.shape}")
|
||||
return input_ids
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue