update force_think

This commit is contained in:
liam 2025-02-12 11:42:55 +08:00
parent a2fc2a8658
commit e536e1420d
3 changed files with 11 additions and 0 deletions

View file

@ -90,6 +90,7 @@ class ArgumentParser:
# user config
parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
parser.add_argument("--force_think", type=bool, default=self.cfg.force_think)
# web config
parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)

View file

@ -10,6 +10,7 @@ from transformers import (
BitsAndBytesConfig,
)
from ktransformers.server.config.config import Config
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.utils.multi_timer import Profiler
import torch
@ -323,10 +324,18 @@ class TransformersInterface(BackendInterfaceBase):
#input_ids = torch.tensor([[6366]], device=input_ids.device)
else:
raise ValueError("local_messages should be List or str")
if Config().force_think:
token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n",add_special_tokens=False)])
input_ids = torch.cat(
[input_ids, token_thinks], dim=1
)
self.profiler.pause_timer("tokenize")
self.profiler.create_and_start_timer("prefill")
if Config().force_think:
print("<think>\n")
yield "<think>\n"
for t in self.prefill(input_ids, self.check_is_new(thread_id)):
if t is not None:
print(t, end="",flush=True)

View file

@ -83,6 +83,7 @@ class Config(metaclass=Singleton):
self.user_config: dict = cfg.get("user", {})
self.user_secret_key = self.user_config.get("secret_key", "")
self.user_algorithm = self.user_config.get("algorithm", "")
self.user_force_think = self.user_config.get("force_think", False)
# model config
self.model: dict = cfg.get("model", {})