mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 13:55:27 +00:00

⚡ support R1 force thinking

This commit is contained in:
  parent a339f573f0
  commit d07087a7e2

4 changed files with 43 additions and 116 deletions

@@ -1,109 +1,3 @@
-# """
-# Description :
-# Author : Boxin Zhang, Azure-Tang
-# Version : 0.1.0
-# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
-# """
-
-# import asyncio
-# import os
-# import platform
-# import sys
-# project_dir = os.path.dirname(os.path.dirname(__file__))
-# sys.path.insert(0, project_dir)
-# from ktransformers.server.args import ArgumentParser
-
-
-# from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
-# from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
-# from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
-# from ktransformers.models.modeling_llama import LlamaForCausalLM
-# from ktransformers.models.modeling_mixtral import MixtralForCausalLM
-# from ktransformers.server.config.config import Config
-
-# custom_models = {
-#     "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
-#     "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
-#     "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
-#     "LlamaForCausalLM": LlamaForCausalLM,
-#     "MixtralForCausalLM": MixtralForCausalLM,
-# }
-
-# ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
-# default_optimize_rules = {
-#     "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
-#     "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml",
-#     "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
-#     "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml",
-#     "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml",
-# }
-
-
-# def local_chat():
-#     config = Config()
-#     arg_parser = ArgumentParser(config)
-#     # initialize messages
-#     arg_parser.parse_args()
-#     if config.backend_type == "transformers":
-#         from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
-#     elif config.backend_type == "exllamav2":
-#         from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
-#     elif config.backend_type == "ktransformers":
-#         from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
-#     else:
-#         raise NotImplementedError(f"{config.backend_type} not implemented")
-#     interface = BackendInterface(config)
-
-#     system = platform.system()
-#     if system == "Windows":
-#         os.system("cls")
-#     else:
-#         os.system("clear")
-#     # add a history chat content
-#     his_content = []
-#     while True:
-#         content = input("Chat: ")
-#         if content.startswith('"""'):  # prefix """
-#             # multi lines input
-#             content = content[3:] + "\n"
-#             while True:
-#                 line = input("")
-#                 if line.endswith('"""'):
-#                     # end multi lines input
-#                     line = line[:-3]  # suffix """
-#                     if line:
-#                         content += line + "\n"
-#                     break
-#                 else:
-#                     content += line + "\n"
-#         if content == "":
-#             if not config.prompt_file:
-#                 content = "hi"
-#             else:
-#                 content = open(config.prompt_file, "r").read()
-#             print("User: ", content)
-#         elif os.path.isfile(content):
-#             content = open(content, "r").read()
-#             print("User: ", content)
-#         messages = his_content + [{"role": "user", "content": content}]
-
-#         async def async_inference(messages):
-#             generated = ""
-#             async for token in interface.inference(messages, "local_chat"):
-#                 generated += token
-#             return generated
-
-#         generated = asyncio.run(async_inference(messages))
-#         his_content += [
-#             {"role": "user", "content": content},
-#             {"role": "assistant", "content": generated},
-#         ]
-
-
-# if __name__ == "__main__":
-#     local_chat()
-
-
 """
 Description : 
 Author : Boxin Zhang, Azure-Tang
@@ -161,11 +55,12 @@ def local_chat(
     model_path: str | None = None,
     optimize_rule_path: str = None,
     gguf_path: str | None = None,
-    max_new_tokens: int = 1000,
+    max_new_tokens: int = 300,
     cpu_infer: int = Config().cpu_infer,
     use_cuda_graph: bool = True,
     prompt_file : str | None = None,
     mode: str = "normal",
+    force_think: bool = False,
 ):

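A hedged usage sketch (not part of the diff): how the new `force_think` switch and the lowered `max_new_tokens` default might be exercised by calling `local_chat` directly. The model and GGUF paths are placeholders, and the assumption that the CLI wrapper simply forwards flags to these keyword arguments is mine, not something this hunk shows.

```python
# Hypothetical invocation of the updated entry point; adjust paths to your setup.
from ktransformers.local_chat import local_chat

local_chat(
    model_path="deepseek-ai/DeepSeek-R1",  # placeholder HF config/tokenizer path
    gguf_path="./DeepSeek-R1-GGUF",        # placeholder directory with GGUF weights
    max_new_tokens=300,                    # new default introduced by this commit
    force_think=True,                      # request the forced "<think>" prefix
)
```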
@@ -259,10 +154,16 @@ def local_chat(
             content = "Please write a piece of quicksort code in C++."
         elif os.path.isfile(content):
             content = open(content, "r").read()
+
         messages = [{"role": "user", "content": content}]
         input_tensor = tokenizer.apply_chat_template(
             messages, add_generation_prompt=True, return_tensors="pt"
         )
+        if force_think:
+            token_thinks = torch.tensor([tokenizer.encode("<think>\\n",add_special_tokens=False)])
+            input_tensor = torch.cat(
+                [input_tensor, token_thinks], dim=1
+            )
         if mode == 'long_context':
             assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
             "please change max_seq_len in ~/.ktransformers/config.yaml"
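The hunk above is the heart of the feature: when `force_think` is set, the token ids for the `<think>` opener are appended after the rendered chat template, so generation starts inside a think block. A minimal, self-contained sketch of that concatenation, using dummy token ids in place of a real tokenizer:

```python
import torch

# Stand-in for tokenizer.apply_chat_template(..., return_tensors="pt"):
# a batch of one prompt; the ids are dummies.
input_tensor = torch.tensor([[1, 100, 101, 102, 2]])

# Stand-in for tokenizer.encode("<think>\n", add_special_tokens=False):
# hypothetical ids for the "<think>" opener and the newline.
token_thinks = torch.tensor([[9000, 10]])

# Same concatenation the commit performs when force_think is enabled:
# the forced opener is appended along the sequence dimension.
forced = torch.cat([input_tensor, token_thinks], dim=1)
print(forced)        # -> [[1, 100, 101, 102, 2, 9000, 10]]
print(forced.shape)  # -> torch.Size([1, 7])
```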
@@ -270,7 +171,7 @@ def local_chat(
             torch.bfloat16
         ) # TODO: Remove this, replace dtype using config
         generated = prefill_and_generate(
-            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode
+            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode, force_think
         )
