mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 06:14:58 +00:00
⚡ update force_think
This commit is contained in:
parent
a2fc2a8658
commit
e536e1420d
3 changed files with 11 additions and 0 deletions
|
@ -10,6 +10,7 @@ from transformers import (
|
|||
BitsAndBytesConfig,
|
||||
)
|
||||
|
||||
from ktransformers.server.config.config import Config
|
||||
from ktransformers.server.schemas.base import ObjectID
|
||||
from ktransformers.server.utils.multi_timer import Profiler
|
||||
import torch
|
||||
|
@ -323,10 +324,18 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
#input_ids = torch.tensor([[6366]], device=input_ids.device)
|
||||
else:
|
||||
raise ValueError("local_messages should be List or str")
|
||||
if Config().force_think:
|
||||
token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n",add_special_tokens=False)])
|
||||
input_ids = torch.cat(
|
||||
[input_ids, token_thinks], dim=1
|
||||
)
|
||||
|
||||
self.profiler.pause_timer("tokenize")
|
||||
|
||||
self.profiler.create_and_start_timer("prefill")
|
||||
if Config().force_think:
|
||||
print("<think>\n")
|
||||
yield "<think>\n"
|
||||
for t in self.prefill(input_ids, self.check_is_new(thread_id)):
|
||||
if t is not None:
|
||||
print(t, end="",flush=True)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue