support R1 force thinking

This commit is contained in:
liam 2025-02-11 14:02:19 +08:00
parent a339f573f0
commit d07087a7e2
4 changed files with 43 additions and 116 deletions

View file

@ -19,6 +19,8 @@
- [Dual socket version (64 cores)](#dual-socket-version-64-cores-1)
- [Some Explanations](#some-explanations)
- [FAQ](#faq)
- [R1 No Thinking](#r1-no-thinking)
- [More FAQ](#more-faq)
# SUMMARY
@ -110,21 +112,30 @@ Our local_chat test command is:
``` shell
git clone https://github.com/kvcache-ai/ktransformers.git
cd ktransformers
numactl -N 1 -m 1 python ./ktransformers/local_chat.py --model_path <your model path> --gguf_path <your gguf path> --prompt_file <your prompt txt file> --cpu_infer 33 --cache_lens 1536
git submodule init
git submodule update
numactl -N 1 -m 1 python ./ktransformers/local_chat.py --model_path <your model path> --gguf_path <your gguf path> --prompt_file <your prompt txt file> --cpu_infer 33 --max_new_tokens 1000
<when you see chat, then press enter to load the text prompt_file>
```
\<your model path\> can be local or set to an online Hugging Face id such as deepseek-ai/DeepSeek-V3. If the online source has connection problems, try the mirror (hf-mirror.com) <br>
\<your gguf path\> can also be online, but as it is large we recommend downloading it and quantizing the model to the format you want <br>
The command numactl -N 1 -m 1 aims to avoid data transfer between NUMA nodes
`<your model path>` can be local or set to an online Hugging Face id such as deepseek-ai/DeepSeek-V3. If the online source has connection problems, try the mirror (hf-mirror.com) <br>
`<your gguf path>` can also be online, but as it is large we recommend downloading it and quantizing the model to the format you want (note that this is the directory path) <br>
`--max_new_tokens 1000` is the maximum output token length. If you find the answer is truncated, you
can increase this number for a longer answer (but be aware of OOM; increasing it also slows down the generation rate).
<br>
The command `numactl -N 1 -m 1` aims to avoid data transfer between NUMA nodes.<br>
Attention! When testing R1, the model may skip the thinking stage. You can add the arg `--force_think true` to enforce it; this is explained in the [FAQ](#faq) section.
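If you prefer to drive this from Python instead of the command line, the same flags map onto the `local_chat()` function shown in the code changes below. A minimal sketch, assuming you can import and call it directly (the paths and model id are placeholders):
``` python
# Hypothetical direct call of local_chat(); adjust paths to your setup.
from ktransformers.local_chat import local_chat

local_chat(
    model_path="deepseek-ai/DeepSeek-V3",   # local config dir or Hugging Face id
    gguf_path="/path/to/gguf_dir",          # directory holding the GGUF files
    prompt_file="./prompt.txt",
    cpu_infer=33,
    max_new_tokens=1000,
    force_think=True,                       # see the FAQ on R1 skipping <think>
)
```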
#### Dual socket version (64 cores)
Make sure that before you install (using install.sh or `make dev_install`) you set the env var `USE_NUMA=1`, e.g. by `export USE_NUMA=1` (if already installed, reinstall it with this env var set) <br>
Our local_chat test command is:
``` shell
git clone https://github.com/kvcache-ai/ktransformers.git
cd ktransformers
git submodule init
git submodule update
export USE_NUMA=1
make dev_install # or sh ./install.sh
python ./ktransformers/local_chat.py --model_path <your model path> --gguf_path <your gguf path> --prompt_file <your prompt txt file> --cpu_infer 65 --cache_lens 1536
python ./ktransformers/local_chat.py --model_path <your model path> --gguf_path <your gguf path> --prompt_file <your prompt txt file> --cpu_infer 65 --max_new_tokens 1000
<when you see chat, then press enter to load the text prompt_file>
```
The parameters have the same meaning. But as we use dual sockets, we set cpu_infer to 65.
@ -135,7 +146,7 @@ Our local_chat test command is:
``` shell
wget https://github.com/kvcache-ai/ktransformers/releases/download/v0.1.4/ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl
pip install ./ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl
python -m ktransformers.local_chat --model_path <your model path> --gguf_path <your gguf path> --prompt_file <your prompt txt file> --cpu_infer 65 --cache_lens 1536
python -m ktransformers.local_chat --model_path <your model path> --gguf_path <your gguf path> --prompt_file <your prompt txt file> --cpu_infer 65 --max_new_tokens 1000
<when you see chat, then press enter to load the text prompt_file>
```
The parameters have the same meaning as in V0.2. But as we use dual sockets, we set cpu_infer to 65.
@ -160,4 +171,8 @@ DeepSeek's MLA operators are highly computationally intensive. While running eve
5. Why Intel CPUs?
Intel is currently the only CPU vendor that supports AMX-like instructions, which delivers significantly better performance compared to AVX-only alternatives.
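If you are unsure whether your CPU actually exposes AMX, one quick sanity check on Linux is to look for the AMX feature flags in `/proc/cpuinfo`. This snippet is only an illustration and not part of ktransformers:
``` python
# Illustrative AMX check on Linux; not part of ktransformers.
def has_amx(cpuinfo_path: str = "/proc/cpuinfo") -> bool:
    with open(cpuinfo_path) as f:
        for line in f:
            if line.startswith("flags"):
                # Sapphire Rapids and newer report amx_tile / amx_int8 / amx_bf16
                return any(flag in line for flag in ("amx_tile", "amx_int8", "amx_bf16"))
    return False

if __name__ == "__main__":
    print("AMX available:", has_amx())
```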
## FAQ
### R1 No Thinking
Attention! When testing R1, the model may skip the thinking stage. You can add the arg `--force_think true` to enforce it. The details are in the [FAQ](./FAQ.md). <br>
### More FAQ
[See detail](./FAQ.md)

View file

@ -7,4 +7,13 @@ sudo add-apt-repository ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install --only-upgrade libstdc++6
```
From: https://github.com/kvcache-ai/ktransformers/issues/117#issuecomment-2647542979
### 2 DeepSeek-R1 not outputting initial <think> token
> from deepseek-R1 doc:<br>
> Additionally, we have observed that the DeepSeek-R1 series models tend to bypass thinking pattern (i.e., outputting "\<think>\n\n\</think>") when responding to certain queries, which can adversely affect the model's performance. To ensure that the model engages in thorough reasoning, we recommend enforcing the model to initiate its response with "\<think>\n" at the beginning of every output.
So we fix this by manually appending the "\<think>\n" token at the end of the prompt (you can check this in local_chat.py),
and passing the arg `--force_think true` lets local_chat initiate the response with "\<think>\n".
From: https://github.com/kvcache-ai/ktransformers/issues/129#issue-2842799552
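For reference, here is a minimal sketch of that idea, using a generic Hugging Face tokenizer as a stand-in (the model id is a placeholder; local_chat.py applies the same `torch.cat` step to its own `input_tensor`):
``` python
# Sketch of force_think: append "<think>\n" to the tokenized prompt.
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")  # placeholder model id

messages = [{"role": "user", "content": "Why is the sky blue?"}]
input_tensor = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
)

# Force the model to open its reasoning block.
token_thinks = torch.tensor([tokenizer.encode("<think>\n", add_special_tokens=False)])
input_tensor = torch.cat([input_tensor, token_thinks], dim=1)
```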

View file

@ -1,109 +1,3 @@
# """
# Description :
# Author : Boxin Zhang, Azure-Tang
# Version : 0.1.0
# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
# """
# import asyncio
# import os
# import platform
# import sys
# project_dir = os.path.dirname(os.path.dirname(__file__))
# sys.path.insert(0, project_dir)
# from ktransformers.server.args import ArgumentParser
# from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
# from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
# from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
# from ktransformers.models.modeling_llama import LlamaForCausalLM
# from ktransformers.models.modeling_mixtral import MixtralForCausalLM
# from ktransformers.server.config.config import Config
# custom_models = {
# "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
# "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
# "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
# "LlamaForCausalLM": LlamaForCausalLM,
# "MixtralForCausalLM": MixtralForCausalLM,
# }
# ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
# default_optimize_rules = {
# "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
# "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml",
# "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
# "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml",
# "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml",
# }
# def local_chat():
# config = Config()
# arg_parser = ArgumentParser(config)
# Initialize messages
# arg_parser.parse_args()
# if config.backend_type == "transformers":
# from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
# elif config.backend_type == "exllamav2":
# from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
# elif config.backend_type == "ktransformers":
# from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
# else:
# raise NotImplementedError(f"{config.backend_type} not implemented")
# interface = BackendInterface(config)
# system = platform.system()
# if system == "Windows":
# os.system("cls")
# else:
# os.system("clear")
# # add a history chat content
# his_content = []
# while True:
# content = input("Chat: ")
# if content.startswith('"""'): # prefix """
# # multi lines input
# content = content[3:] + "\n"
# while True:
# line = input("")
# if line.endswith('"""'):
# # end multi lines input
# line = line[:-3] # suffix """
# if line:
# content += line + "\n"
# break
# else:
# content += line + "\n"
# if content == "":
# if not config.prompt_file:
# content = "hi"
# else:
# content = open(config.prompt_file, "r").read()
# print("User: ", content)
# elif os.path.isfile(content):
# content = open(content, "r").read()
# print("User: ", content)
# messages = his_content + [{"role": "user", "content": content}]
# async def async_inference(messages):
# generated = ""
# async for token in interface.inference(messages, "local_chat"):
# generated += token
# return generated
# generated = asyncio.run(async_inference(messages))
# his_content += [
# {"role": "user", "content": content},
# {"role": "assistant", "content": generated},
# ]
# if __name__ == "__main__":
# local_chat()
"""
Description :
Author : Boxin Zhang, Azure-Tang
@ -161,11 +55,12 @@ def local_chat(
    model_path: str | None = None,
    optimize_rule_path: str = None,
    gguf_path: str | None = None,
    max_new_tokens: int = 1000,
    max_new_tokens: int = 300,
    cpu_infer: int = Config().cpu_infer,
    use_cuda_graph: bool = True,
    prompt_file : str | None = None,
    mode: str = "normal",
    force_think: bool = False,
):
@ -259,10 +154,16 @@ def local_chat(
content = "Please write a piece of quicksort code in C++."
elif os.path.isfile(content):
content = open(content, "r").read()
messages = [{"role": "user", "content": content}]
input_tensor = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, return_tensors="pt"
)
if force_think:
token_thinks = torch.tensor([tokenizer.encode("<think>\\n",add_special_tokens=False)])
input_tensor = torch.cat(
[input_tensor, token_thinks], dim=1
)
if mode == 'long_context':
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
"please change max_seq_len in ~/.ktransformers/config.yaml"
@ -270,7 +171,7 @@ def local_chat(
            torch.bfloat16
        ) # TODO: Remove this, replace dtype using config
        generated = prefill_and_generate(
            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode
            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode, force_think
        )

View file

@ -85,7 +85,7 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
        module.load()

def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
                         mode = 'normal'):
                         mode = 'normal', force_think: bool = False):
    import os
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    torch._dynamo.config.suppress_errors = True
@ -172,6 +172,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
    prefill_count = seq_length
    prefill_time = first_token_time
    if force_think:
        # "<think>\n" was injected into the prompt, so print it here to keep the streamed output readable
        print("<think>\n")
    print(stream.put(next_token.item()), end="", flush=True)
    generated_ids[:, seq_length] = next_token
    tokens.append(int(next_token))
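For completeness, a hedged sketch of how the new `force_think` flag is passed down to `prefill_and_generate`, assuming `model`, `tokenizer`, and `input_tensor` are already prepared as in local_chat.py:
``` python
# Assumes model, tokenizer, and input_tensor were set up as in local_chat().
generated = prefill_and_generate(
    model,
    tokenizer,
    input_tensor.cuda(),
    max_new_tokens=1000,
    use_cuda_graph=True,
    mode="normal",
    force_think=True,  # new in this commit: prints "<think>\n" before streaming tokens
)
```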