mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 15:29:39 +00:00
⚡ support force thinking
This commit is contained in:
parent
6f3a39be08
commit
4385e85096
3 changed files with 7 additions and 5 deletions
|
@ -160,7 +160,7 @@ def local_chat(
|
||||||
messages, add_generation_prompt=True, return_tensors="pt"
|
messages, add_generation_prompt=True, return_tensors="pt"
|
||||||
)
|
)
|
||||||
if force_think:
|
if force_think:
|
||||||
token_thinks = torch.tensor([tokenizer.encode("<think>\\n",add_special_tokens=False)])
|
token_thinks = torch.tensor([tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_tensor.device)
|
||||||
input_tensor = torch.cat(
|
input_tensor = torch.cat(
|
||||||
[input_tensor, token_thinks], dim=1
|
[input_tensor, token_thinks], dim=1
|
||||||
)
|
)
|
||||||
|
|
|
@ -122,4 +122,5 @@ class ArgumentParser:
|
||||||
self.cfg.server_ip = args.host
|
self.cfg.server_ip = args.host
|
||||||
self.cfg.server_port = args.port
|
self.cfg.server_port = args.port
|
||||||
self.cfg.backend_type = args.type
|
self.cfg.backend_type = args.type
|
||||||
|
self.cfg.user_force_think = args.force_think
|
||||||
return args
|
return args
|
||||||
|
|
|
@ -325,7 +325,7 @@ class TransformersInterface(BackendInterfaceBase):
|
||||||
else:
|
else:
|
||||||
raise ValueError("local_messages should be List or str")
|
raise ValueError("local_messages should be List or str")
|
||||||
if Config().user_force_think:
|
if Config().user_force_think:
|
||||||
token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n",add_special_tokens=False)])
|
token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_ids.device)
|
||||||
input_ids = torch.cat(
|
input_ids = torch.cat(
|
||||||
[input_ids, token_thinks], dim=1
|
[input_ids, token_thinks], dim=1
|
||||||
)
|
)
|
||||||
|
@ -334,8 +334,9 @@ class TransformersInterface(BackendInterfaceBase):
|
||||||
|
|
||||||
self.profiler.create_and_start_timer("prefill")
|
self.profiler.create_and_start_timer("prefill")
|
||||||
if Config().user_force_think:
|
if Config().user_force_think:
|
||||||
print("<think>\n")
|
t = "<think>\n"
|
||||||
yield "<think>\n"
|
print(t,end="",flush=True)
|
||||||
|
yield t
|
||||||
for t in self.prefill(input_ids, self.check_is_new(thread_id)):
|
for t in self.prefill(input_ids, self.check_is_new(thread_id)):
|
||||||
if t is not None:
|
if t is not None:
|
||||||
print(t, end="",flush=True)
|
print(t, end="",flush=True)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue