support force thinking

This commit is contained in:
liam 2025-02-12 12:43:53 +08:00
parent 6f3a39be08
commit 4385e85096
3 changed files with 7 additions and 5 deletions

View file

@@ -160,7 +160,7 @@ def local_chat(
messages, add_generation_prompt=True, return_tensors="pt" messages, add_generation_prompt=True, return_tensors="pt"
) )
if force_think: if force_think:
token_thinks = torch.tensor([tokenizer.encode("<think>\\n",add_special_tokens=False)]) token_thinks = torch.tensor([tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_tensor.device)
input_tensor = torch.cat( input_tensor = torch.cat(
[input_tensor, token_thinks], dim=1 [input_tensor, token_thinks], dim=1
) )

View file

@@ -122,4 +122,5 @@ class ArgumentParser:
self.cfg.server_ip = args.host self.cfg.server_ip = args.host
self.cfg.server_port = args.port self.cfg.server_port = args.port
self.cfg.backend_type = args.type self.cfg.backend_type = args.type
self.cfg.user_force_think = args.force_think
return args return args

View file

@@ -325,7 +325,7 @@ class TransformersInterface(BackendInterfaceBase):
else: else:
raise ValueError("local_messages should be List or str") raise ValueError("local_messages should be List or str")
if Config().user_force_think: if Config().user_force_think:
token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n",add_special_tokens=False)]) token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_ids.device)
input_ids = torch.cat( input_ids = torch.cat(
[input_ids, token_thinks], dim=1 [input_ids, token_thinks], dim=1
) )
@@ -334,8 +334,9 @@ class TransformersInterface(BackendInterfaceBase):
self.profiler.create_and_start_timer("prefill") self.profiler.create_and_start_timer("prefill")
if Config().user_force_think: if Config().user_force_think:
print("<think>\n") t = "<think>\n"
yield "<think>\n" print(t,end="",flush=True)
yield t
for t in self.prefill(input_ids, self.check_is_new(thread_id)): for t in self.prefill(input_ids, self.check_is_new(thread_id)):
if t is not None: if t is not None:
print(t, end="",flush=True) print(t, end="",flush=True)