Merge pull request #110 from KMSorSMS/main

refactor local_chat & config setting
UnicornChan 2024-11-06 09:44:12 +08:00 committed by GitHub
commit dddc42038d
17 changed files with 652 additions and 400 deletions

4
.flake8 Normal file
View file

@ -0,0 +1,4 @@
[flake8]
max-line-length = 120
extend-select = B950
extend-ignore = E203,E501,E701, B001,B006,B007,B008,B009,B010,B011,B016,B028,B031,B950,E265,E266,E401,E402,E711,E712,E713,E721,E722,E731,F401,F403,F405,F541,F811,F821,F841,W391

1
.gitignore vendored
View file

@ -18,3 +18,4 @@ compile_commands.json
ktransformers/server/local_store/ ktransformers/server/local_store/
ktransformers/server_test1.db ktransformers/server_test1.db
*.patch *.patch
img/

21
Makefile Normal file
View file

@ -0,0 +1,21 @@
flake_find:
cd ktransformers && flake8 | grep -Eo '[A-Z][0-9]{3}' | sort | uniq| paste -sd ',' -
format:
@cd ktransformers && black .
@black setup.py
dev_install:
# clear build dirs
rm -rf build
rm -rf *.egg-info
rm -rf ktransformers/ktransformers_ext/build
rm -rf ktransformers/ktransformers_ext/cuda/build
rm -rf ktransformers/ktransformers_ext/cuda/dist
rm -rf ktransformers/ktransformers_ext/cuda/*.egg-info
# install ktransformers
echo "Installing python dependencies from requirements.txt"
pip install -r requirements-local_chat.txt
echo "Installing ktransformers"
KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . --no-build-isolation
echo "Installation completed successfully"

View file

@ -151,7 +151,7 @@ Some preparation:
``` ```
install.bat install.bat
``` ```
4. If you are a developer, you can make use of the Makefile to compile and format the code. <br> The detailed usage of the Makefile is documented [here](./doc/en/makefile_usage.md)
<h3>Local Chat</h3> <h3>Local Chat</h3>
We provide a simple command-line local chat Python script that you can run for testing. We provide a simple command-line local chat Python script that you can run for testing.

26
doc/en/makefile_usage.md Normal file
View file

@ -0,0 +1,26 @@
# Makefile
## Targets
### flake_find:
```bash
make flake_find
```
Scans all Python files under the ./ktransformers directory and collects the codes of every error, warning, and fatal issue that is not consistent with the PEP 8 standard into a list. For now, all of these codes are placed in the extend-ignore section of the .flake8 file so that flake8 ignores them temporarily (we may fix them in the future).
### format:
```bash
make format
```
We use black to format all the Python files under the ./ktransformers directory. It follows the PEP 8 standard,
but we raise the line length to 120 by adding
```toml
[tool.black]
line-length = 120
preview = true
unstable = true
```
in the pyproject.toml file.
### dev_install:
```bash
make dev_install
```
Installs the package in development mode, i.e. an editable install: if you modify the code, you don't need to reinstall the package. We recommend that developers install the package this way.
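
As a quick, optional sanity check (not part of the Makefile itself), you can confirm that the editable install resolves to your source checkout rather than site-packages:

```python
# Minimal sketch: verify the editable (development-mode) install.
# After `make dev_install`, __file__ should point into your source tree.
import ktransformers

print(ktransformers.__file__)
```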

View file

@ -7,7 +7,7 @@ log:
server: server:
ip: 0.0.0.0 ip: 0.0.0.0
port: 12456 port: 10002
db: db:
type: "sqllite" type: "sqllite"
@ -24,10 +24,11 @@ model:
type: ktransformers type: ktransformers
name: DeepSeek-Coder-V2-Instruct name: DeepSeek-Coder-V2-Instruct
path: /mnt/data/model/DeepSeek-Coder-V2-Instruct/ path: deepseek-ai/DeepSeek-V2-Lite-Chat
gguf_path: /mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH/ gguf_path: ./DeepSeek-V2-Lite-Chat-GGUF
device: cuda:0 device: cuda:0
cache_lens: 8192
web: web:
mount: False mount: False
@ -50,4 +51,7 @@ long_context:
head_select_mode: SHARED head_select_mode: SHARED
preselect_block_count: 32 preselect_block_count: 32
layer_step: 1 layer_step: 1
token_step: 100 token_step:
local_chat:
prompt_file: "./ktransformers/p.txt"
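
The new local_chat section is consumed by ktransformers/local_chat.py (rewritten later in this diff): when the chat input is empty, the script falls back to the file configured in prompt_file. A minimal sketch of that fallback, assuming the Config singleton changed at the end of this diff:

```python
# Minimal sketch of the prompt_file fallback used by local_chat.py below.
from ktransformers.server.config.config import Config

config = Config()
content = input("Chat: ")
if content == "":
    if config.prompt_file:  # e.g. "./ktransformers/p.txt" from this config
        with open(config.prompt_file, "r") as f:
            content = f.read()
    else:
        content = "Please write a piece of quicksort code in C++."
```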

View file

@ -5,29 +5,19 @@ Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved. Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
""" """
import asyncio
import os import os
import platform import platform
import sys import sys
project_dir = os.path.dirname(os.path.dirname(__file__)) project_dir = os.path.dirname(os.path.dirname(__file__))
sys.path.insert(0, project_dir) sys.path.insert(0, project_dir)
import torch from ktransformers.server.args import ArgumentParser
import logging
from transformers import (
AutoTokenizer,
AutoConfig,
AutoModelForCausalLM,
GenerationConfig,
TextStreamer,
)
import json
import fire
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate
from ktransformers.server.config.config import Config from ktransformers.server.config.config import Config
custom_models = { custom_models = {
@ -37,9 +27,7 @@ custom_models = {
"MixtralForCausalLM": MixtralForCausalLM, "MixtralForCausalLM": MixtralForCausalLM,
} }
ktransformer_rules_dir = ( ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
)
default_optimize_rules = { default_optimize_rules = {
"DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml", "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
@ -48,75 +36,28 @@ default_optimize_rules = {
} }
def local_chat( def local_chat():
model_path: str | None = None, config = Config()
optimize_rule_path: str = None, arg_parser = ArgumentParser(config)
gguf_path: str | None = None, # initialize messages
max_new_tokens: int = 1000, arg_parser.parse_args()
cpu_infer: int = Config().cpu_infer, if config.backend_type == "transformers":
use_cuda_graph: bool = True, from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
prompt_file : str | None = None, elif config.backend_type == "exllamav2":
mode: str = "normal", from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
): elif config.backend_type == "ktransformers":
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
torch.set_grad_enabled(False)
Config().cpu_infer = cpu_infer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if mode == 'long_context':
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
torch.set_default_dtype(torch.float16)
else: else:
torch.set_default_dtype(config.torch_dtype) raise NotImplementedError(f"{config.backend_type} not implemented")
interface = BackendInterface(config)
with torch.device("meta"):
if config.architectures[0] in custom_models:
print("using custom modeling_xxx.py.")
if (
"Qwen2Moe" in config.architectures[0]
): # Qwen2Moe must use flash_attention_2 to avoid overflow.
config._attn_implementation = "flash_attention_2"
if "Llama" in config.architectures[0]:
config._attn_implementation = "eager"
if "Mixtral" in config.architectures[0]:
config._attn_implementation = "flash_attention_2"
model = custom_models[config.architectures[0]](config)
else:
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=True, attn_implementation="flash_attention_2"
)
if optimize_rule_path is None:
if config.architectures[0] in default_optimize_rules:
print("using default_optimize_rule for", config.architectures[0])
optimize_rule_path = default_optimize_rules[config.architectures[0]]
else:
optimize_rule_path = input(
"please input the path of your rule file(yaml file containing optimize rules):"
)
if gguf_path is None:
gguf_path = input(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
)
optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)
model.generation_config = GenerationConfig.from_pretrained(model_path)
if model.generation_config.pad_token_id is None:
model.generation_config.pad_token_id = model.generation_config.eos_token_id
model.eval()
logging.basicConfig(level=logging.INFO)
system = platform.system() system = platform.system()
if system == "Windows": if system == "Windows":
os.system("cls") os.system("cls")
else: else:
os.system("clear") os.system("clear")
# add a history chat content
his_content = []
while True: while True:
content = input("Chat: ") content = input("Chat: ")
if content.startswith('"""'): # prefix """ if content.startswith('"""'): # prefix """
@ -132,28 +73,27 @@ def local_chat(
break break
else: else:
content += line + "\n" content += line + "\n"
if content == "": if content == "":
if prompt_file != None: if config.prompt_file == None or config.prompt_file == "":
content = open(prompt_file, "r").read()
else:
content = "Please write a piece of quicksort code in C++." content = "Please write a piece of quicksort code in C++."
else:
content = open(config.prompt_file, "r").read()
elif os.path.isfile(content): elif os.path.isfile(content):
content = open(content, "r").read() content = open(content, "r").read()
messages = [{"role": "user", "content": content}] messages = his_content + [{"role": "user", "content": content}]
input_tensor = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, return_tensors="pt" async def async_inference(messages):
) generated = ""
if mode == 'long_context': async for token in interface.inference(messages, "local_chat"):
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \ generated += token
"please change max_seq_len in ~/.ktransformers/config.yaml" return generated
torch.set_default_dtype(
torch.bfloat16 generated = asyncio.run(async_inference(messages))
) # TODO: Remove this, replace dtype using config his_content += [
generated = prefill_and_generate( {"role": "user", "content": content},
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode {"role": "assistant", "content": generated},
) ]
if __name__ == "__main__": if __name__ == "__main__":
fire.Fire(local_chat) local_chat()
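
The rewritten script no longer builds the model itself; it picks a backend interface from config.backend_type and streams tokens from its async inference generator. A minimal sketch of driving such an interface the same way, assuming an already-constructed interface object as in the code above:

```python
# Minimal sketch: collect a streamed reply from a backend interface.
# `interface` is assumed to be one of the BackendInterface classes above,
# e.g. KTransformersInterface(config).
import asyncio


async def run_once(interface, messages):
    generated = ""
    # inference() is an async generator yielding decoded text chunks
    async for token in interface.inference(messages, "local_chat"):
        generated += token
    return generated


# messages = [{"role": "user", "content": "hello"}]
# reply = asyncio.run(run_once(interface, messages))
```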

View file

@ -0,0 +1,56 @@
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
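
The name fields in these rules are ordinary Python regular expressions matched against module paths, so, for example, the first Linear rule's negative lookahead keeps the self_attn.kv_b_proj projections on the default implementation. A small illustration (the module paths are made-up examples):

```python
# Illustration of how the injection rule's regex selects modules.
# The module paths below are hypothetical examples.
import re

pattern = re.compile(r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$")

print(bool(pattern.match("model.layers.0.mlp.gate_proj")))        # True  -> replaced
print(bool(pattern.match("model.layers.0.self_attn.kv_b_proj")))  # False -> left as-is
```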

View file

@ -135,6 +135,8 @@ class OllamaShowResponse(BaseModel):
details: OllamaShowDetial details: OllamaShowDetial
model_info: OllamaModelInfo model_info: OllamaModelInfo
class Config:
protected_namespaces = ()
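
The added inner Config with an empty protected_namespaces is needed because pydantic v2 reserves the model_ prefix and would otherwise emit a warning for fields such as model_info. A standalone sketch of the same pattern:

```python
# Minimal sketch: clear pydantic's protected "model_" namespace so fields
# like model_info no longer trigger a namespace warning.
from pydantic import BaseModel


class ShowResponse(BaseModel):
    model_info: dict = {}

    class Config:
        protected_namespaces = ()
```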

View file

@ -0,0 +1,124 @@
import argparse
from ktransformers.server.backend.args import ConfigArgs, default_args
class ArgumentParser:
def __init__(self, cfg):
self.cfg = cfg
def parse_args(self):
parser = argparse.ArgumentParser(prog="kvcache.ai", description="Ktransformers")
parser.add_argument("--host", type=str, default=self.cfg.server_ip)
parser.add_argument("--port", type=int, default=self.cfg.server_port)
parser.add_argument("--ssl_keyfile", type=str)
parser.add_argument("--ssl_certfile", type=str)
parser.add_argument("--web", type=bool, default=self.cfg.mount_web)
parser.add_argument("--model_name", type=str, default=self.cfg.model_name)
parser.add_argument("--model_dir", type=str)
parser.add_argument("--model_path", type=str)
parser.add_argument(
"--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter"
)
parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
parser.add_argument("--type", type=str, default=self.cfg.backend_type)
# model configs
# parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
parser.add_argument("--paged", type=bool, default=self.cfg.paged)
parser.add_argument("--total_context", type=int, default=self.cfg.total_context)
parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size)
parser.add_argument("--max_chunk_size", type=int, default=self.cfg.max_chunk_size)
parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens)
parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode)
parser.add_argument("--healing", type=bool, default=self.cfg.healing)
parser.add_argument("--ban_strings", type=list, default=self.cfg.ban_strings, required=False)
parser.add_argument("--gpu_split", type=str, default=self.cfg.gpu_split, required=False)
parser.add_argument("--length", type=int, default=self.cfg.length, required=False)
parser.add_argument("--rope_scale", type=float, default=self.cfg.rope_scale, required=False)
parser.add_argument("--rope_alpha", type=float, default=self.cfg.rope_alpha, required=False)
parser.add_argument("--no_flash_attn", type=bool, default=self.cfg.no_flash_attn)
parser.add_argument("--low_mem", type=bool, default=self.cfg.low_mem)
parser.add_argument("--experts_per_token", type=int, default=self.cfg.experts_per_token, required=False)
parser.add_argument("--load_q4", type=bool, default=self.cfg.load_q4)
parser.add_argument("--fast_safetensors", type=bool, default=self.cfg.fast_safetensors)
parser.add_argument("--draft_model_dir", type=str, default=self.cfg.draft_model_dir, required=False)
parser.add_argument("--no_draft_scale", type=bool, default=self.cfg.no_draft_scale)
parser.add_argument("--modes", type=bool, default=self.cfg.modes)
parser.add_argument("--mode", type=str, default=self.cfg.mode)
parser.add_argument("--username", type=str, default=self.cfg.username)
parser.add_argument("--botname", type=str, default=self.cfg.botname)
parser.add_argument("--system_prompt", type=str, default=self.cfg.system_prompt, required=False)
parser.add_argument("--temperature", type=float, default=self.cfg.temperature)
parser.add_argument("--smoothing_factor", type=float, default=self.cfg.smoothing_factor)
parser.add_argument("--dynamic_temperature", type=str, default=self.cfg.dynamic_temperature, required=False)
parser.add_argument("--top_k", type=int, default=self.cfg.top_k)
parser.add_argument("--top_p", type=float, default=self.cfg.top_p)
parser.add_argument("--top_a", type=float, default=self.cfg.top_a)
parser.add_argument("--skew", type=float, default=self.cfg.skew)
parser.add_argument("--typical", type=float, default=self.cfg.typical)
parser.add_argument("--repetition_penalty", type=float, default=self.cfg.repetition_penalty)
parser.add_argument("--frequency_penalty", type=float, default=self.cfg.frequency_penalty)
parser.add_argument("--presence_penalty", type=float, default=self.cfg.presence_penalty)
parser.add_argument("--max_response_tokens", type=int, default=self.cfg.max_response_tokens)
parser.add_argument("--response_chunk", type=int, default=self.cfg.response_chunk)
parser.add_argument("--no_code_formatting", type=bool, default=self.cfg.no_code_formatting)
parser.add_argument("--cache_8bit", type=bool, default=self.cfg.cache_8bit)
parser.add_argument("--cache_q4", type=bool, default=self.cfg.cache_q4)
parser.add_argument("--ngram_decoding", type=bool, default=self.cfg.ngram_decoding)
parser.add_argument("--print_timings", type=bool, default=self.cfg.print_timings)
parser.add_argument("--amnesia", type=bool, default=self.cfg.amnesia)
parser.add_argument("--batch_size", type=int, default=self.cfg.batch_size)
parser.add_argument("--cache_lens", type=int, default=self.cfg.cache_lens)
# log configs
# log level: debug, info, warn, error, crit
parser.add_argument("--log_dir", type=str, default=self.cfg.log_dir)
parser.add_argument("--log_file", type=str, default=self.cfg.log_file)
parser.add_argument("--log_level", type=str, default=self.cfg.log_level)
parser.add_argument("--backup_count", type=int, default=self.cfg.backup_count)
# db configs
parser.add_argument("--db_type", type=str, default=self.cfg.db_type)
parser.add_argument("--db_host", type=str, default=self.cfg.db_host)
parser.add_argument("--db_port", type=str, default=self.cfg.db_port)
parser.add_argument("--db_name", type=str, default=self.cfg.db_name)
parser.add_argument("--db_pool_size", type=int, default=self.cfg.db_pool_size)
parser.add_argument("--db_database", type=str, default=self.cfg.db_database)
# user config
parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
# web config
parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
# file config
parser.add_argument("--file_upload_dir", type=str, default=self.cfg.file_upload_dir)
parser.add_argument("--assistant_store_dir", type=str, default=self.cfg.assistant_store_dir)
# local chat
parser.add_argument("--prompt_file", type=str, default=self.cfg.prompt_file)
args = parser.parse_args()
if (args.model_dir is not None):
if (args.model_path is not None):
# if both model_dir and model_path are passed, use model_path
args.model_dir = args.model_path
else:
# if only model_dir is passed, use it as model_path
args.model_path = args.model_dir
else:
args.model_dir = self.cfg.model_dir
args.model_path = self.cfg.model_path
# set config from args
for key, value in vars(args).items():
if value is not None and hasattr(self.cfg, key):
setattr(self.cfg, key, value)
# handle the args whose names don't match config attributes individually
self.cfg.model_device = args.device
self.cfg.mount_web = args.web
self.cfg.server_ip = args.host
self.cfg.server_port = args.port
self.cfg.backend_type = args.type
return args
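
The tail of parse_args first reconciles the overlapping --model_dir and --model_path flags, then copies every recognised value onto the config object. A condensed sketch of the precedence logic in isolation (the function name is illustrative):

```python
# Condensed sketch of the model_dir / model_path precedence used above.
def resolve_model_paths(args, cfg):
    if args.model_dir is not None:
        if args.model_path is not None:
            # both were passed: model_path wins
            args.model_dir = args.model_path
        else:
            # only model_dir was passed: reuse it as model_path
            args.model_path = args.model_dir
    else:
        # model_dir not passed: fall back to the config.yaml values
        args.model_dir = cfg.model_dir
        args.model_path = cfg.model_path
    return args
```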

View file

@ -1,97 +1,89 @@
from pydantic import BaseModel,Field from pydantic import BaseModel, Field
from typing import Optional from typing import Optional
from ktransformers.server.config.config import Config from ktransformers.server.config.config import Config
class ConfigArgs(BaseModel): class ConfigArgs(BaseModel):
model_name : Optional[str] = Field(..., description="Model name") model_name: Optional[str] = Field(..., description="Model name")
model_dir: Optional[str] = Field(..., description="Path to model directory") model_dir: Optional[str] = Field(..., description="Path to model directory")
optimize_config_path: Optional[str] = Field('./KTransformers/optimize_config/DeepSeek-V2-Chat.json', description="Path of your optimize config json file") optimize_config_path: Optional[str] = Field(None, description="Path of your optimize config yml file")
gguf_path: Optional[str] = Field('/models/DeepSeek-Coder-V2-Instruct-GGUF/DeepSeek-Coder-V2-Instruct-Q4_K_M.gguf', description="Path of your gguf file") gguf_path: Optional[str] = Field(None, description="Path of your gguf file")
class Config: class Config:
protected_namespaces = () protected_namespaces = ()
paged : bool = Field(True,description='Wether to use paged attention kv cache') paged: bool = Field(None, description="Whether to use paged attention kv cache")
total_context: int = Field(
# total_context: int = Field(16384, description="Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the total to distribute dynamically over however many jobs are active at once") None,
total_context: int = Field(2**18, description="Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the total to distribute dynamically over however many jobs are active at once") description=(
max_batch_size: int = Field(20 if paged else 1, description="Max number of batches to run at once, assuming the sequences will fit within total_context") "Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the"
max_chunk_size: int = Field(2048, description="Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new job is started, but at the expense of overall prompt ingestion speed") " total to distribute dynamically over however many jobs are active at once"
max_new_tokens: int = Field(500, description="Max new tokens per completion. For this example applies to all jobs") ),
json_mode: bool = Field(False, description="Use LMFE to constrain the output to JSON format. See schema and details below") )
healing: bool = Field(False, description="Demonstrate token healing") max_batch_size: int = Field(
None, description="Max number of batches to run at once, assuming the sequences will fit within total_context"
)
max_chunk_size: int = Field(
None,
description=(
"Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new"
" job is started, but at the expense of overall prompt ingestion speed"
),
)
max_new_tokens: int = Field(None, description="Max new tokens per completion. For this example applies to all jobs")
json_mode: bool = Field(
None, description="Use LMFE to constrain the output to JSON format. See schema and details below"
)
healing: bool = Field(None, description="Demonstrate token healing")
ban_strings: Optional[list] = Field(None, description="Ban some phrases maybe") ban_strings: Optional[list] = Field(None, description="Ban some phrases maybe")
gpu_split: Optional[str] = Field(None, description='"auto", or VRAM allocation per GPU in GB') gpu_split: Optional[str] = Field(None, description='"auto", or VRAM allocation per GPU in GB')
length: Optional[int] = Field(None, description="Maximum sequence length") length: Optional[int] = Field(None, description="Maximum sequence length")
rope_scale: Optional[float] = Field(None, description="RoPE scaling factor") rope_scale: Optional[float] = Field(None, description="RoPE scaling factor")
rope_alpha: Optional[float] = Field(None, description="RoPE alpha value (NTK)") rope_alpha: Optional[float] = Field(None, description="RoPE alpha value (NTK)")
no_flash_attn: bool = Field(False, description="Disable Flash Attention") no_flash_attn: bool = Field(None, description="Disable Flash Attention")
low_mem: bool = Field( low_mem: bool = Field(None, description="Enable VRAM optimizations, potentially trading off speed")
False,
description="Enable VRAM optimizations, potentially trading off speed",
)
experts_per_token: Optional[int] = Field( experts_per_token: Optional[int] = Field(
None, None, description="Override MoE model's default number of experts per token"
description="Override MoE model's default number of experts per token",
)
load_q4: bool = Field(False, description="Load weights in Q4 mode")
fast_safetensors: bool = Field(
False,
description="Optimized safetensors loading with direct I/O (experimental!)",
) )
load_q4: bool = Field(None, description="Load weights in Q4 mode")
fast_safetensors: bool = Field(None, description="Optimized safetensors loading with direct I/O (experimental!)")
draft_model_dir: Optional[str] = Field(None, description="Path to draft model directory") draft_model_dir: Optional[str] = Field(None, description="Path to draft model directory")
no_draft_scale: bool = Field( no_draft_scale: bool = Field(
False, None,
description="If draft model has smaller context size than model, don't apply alpha (NTK) scaling to extend it", description="If draft model has smaller context size than model, don't apply alpha (NTK) scaling to extend it",
) )
modes: bool = Field(False, description="List available modes and exit.") modes: bool = Field(None, description="List available modes and exit.")
mode: str = Field( mode: str = Field(None, description="Chat mode. Use llama for Llama 1/2 chat finetunes.")
"llama", username: str = Field(None, description="Username when using raw chat mode")
description="Chat mode. Use llama for Llama 1/2 chat finetunes.", botname: str = Field(None, description="Bot name when using raw chat mode")
)
username: str = Field("User", description="Username when using raw chat mode")
botname: str = Field("Chatbort", description="Bot name when using raw chat mode")
system_prompt: Optional[str] = Field(None, description="Use custom system prompt") system_prompt: Optional[str] = Field(None, description="Use custom system prompt")
temperature: float = Field(0.95, description="Sampler temperature, default = 0.95 (1 to disable)") temperature: float = Field(None, description="Sampler temperature, default = 0.95 (1 to disable)")
smoothing_factor: float = Field(0.0, description="Smoothing Factor, default = 0.0 (0 to disable)") smoothing_factor: float = Field(None, description="Smoothing Factor, default = 0.0 (0 to disable)")
dynamic_temperature: Optional[str] = Field( dynamic_temperature: Optional[str] = Field(
None, None, description="Dynamic temperature min,max,exponent, e.g. -dyntemp 0.2,1.5,1"
description="Dynamic temperature min,max,exponent, e.g. -dyntemp 0.2,1.5,1",
) )
top_k: int = Field(50, description="Sampler top-K, default = 50 (0 to disable)") top_k: int = Field(None, description="Sampler top-K, default = 50 (0 to disable)")
top_p: float = Field(0.8, description="Sampler top-P, default = 0.8 (0 to disable)") top_p: float = Field(None, description="Sampler top-P, default = 0.8 (0 to disable)")
top_a: float = Field(0.0, description="Sampler top-A, default = 0.0 (0 to disable)") top_a: float = Field(None, description="Sampler top-A, default = 0.0 (0 to disable)")
skew: float = Field(0.0, description="Skew sampling, default = 0.0 (0 to disable)") skew: float = Field(None, description="Skew sampling, default = 0.0 (0 to disable)")
typical: float = Field( typical: float = Field(None, description="Sampler typical threshold, default = 0.0 (0 to disable)")
0.0, repetition_penalty: float = Field(None, description="Sampler repetition penalty, default = 1.01 (1 to disable)")
description="Sampler typical threshold, default = 0.0 (0 to disable)", frequency_penalty: float = Field(None, description="Sampler frequency penalty, default = 0.0 (0 to disable)")
) presence_penalty: float = Field(None, description="Sampler presence penalty, default = 0.0 (0 to disable)")
repetition_penalty: float = Field( max_response_tokens: int = Field(None, description="Max tokens per response, default = 1000")
1.01, response_chunk: int = Field(None, description="Space to reserve in context for reply, default = 250")
description="Sampler repetition penalty, default = 1.01 (1 to disable)", no_code_formatting: bool = Field(None, description="Disable code formatting/syntax highlighting")
) cache_8bit: bool = Field(None, description="Use 8-bit (FP8) cache")
frequency_penalty: float = Field( cache_q4: bool = Field(None, description="Use Q4 cache")
0.0, ngram_decoding: bool = Field(None, description="Use n-gram speculative decoding")
description="Sampler frequency penalty, default = 0.0 (0 to disable)", print_timings: bool = Field(None, description="Output timings after each prompt")
) amnesia: bool = Field(None, description="Forget context after every response")
presence_penalty: float = Field(
0.0,
description="Sampler presence penalty, default = 0.0 (0 to disable)",
)
max_response_tokens: int = Field(300, description="Max tokens per response, default = 1000")
response_chunk: int = Field(250, description="Space to reserve in context for reply, default = 250")
no_code_formatting: bool = Field(False, description="Disable code formatting/syntax highlighting")
cache_8bit: bool = Field(False, description="Use 8-bit (FP8) cache")
cache_q4: bool = Field(True, description="Use Q4 cache")
ngram_decoding: bool = Field(False, description="Use n-gram speculative decoding")
print_timings: bool = Field(False, description="Output timings after each prompt")
amnesia: bool = Field(False, description="Forget context after every response")
# for transformers # for transformers
batch_size :int = Field(1,description="Batch Size") batch_size: int = Field(None, description="Batch Size")
cache_lens:int = Field(4096, description="Cache lens for transformers static cache") cache_lens: int = Field(None, description="Cache lens for transformers static cache")
device:str = Field('cuda:2',description="device") device: str = Field(None, description="device")
cfg = Config() cfg = Config()
default_args = ConfigArgs(model_name=cfg.model_name,model_dir=cfg.model_path) default_args = cfg
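
Because default_args is now simply the Config singleton, the hard-coded Field defaults above are gone and every unset CLI flag falls back to whatever config.yaml provides. A short sketch of what that means for callers:

```python
# Minimal sketch: defaults now come from the Config singleton, not from
# hard-coded Field(...) values.
from ktransformers.server.backend.args import default_args

# Resolved from configs/config.yaml (or ~/.ktransformers/config.yaml),
# e.g. temperature=0.95 and max_new_tokens=500 per the config.py changes
# shown later in this diff.
print(default_args.temperature, default_args.max_new_tokens)
```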

View file

@ -1,6 +1,12 @@
import torch import torch
from transformers import AutoTokenizer, AutoConfig, GenerationConfig from transformers import AutoTokenizer, AutoConfig, GenerationConfig
from ktransformers.server.backend.interfaces.transformers import TransformersInterface,ConfigArgs, TransformersThreadContext,default_args,TextStreamer from ktransformers.server.backend.interfaces.transformers import (
TransformersInterface,
ConfigArgs,
TransformersThreadContext,
default_args,
TextStreamer,
)
from ktransformers.server.config.log import logger from ktransformers.server.config.log import logger
from ktransformers.optimize.optimize import optimize_and_load_gguf from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.custom_cache import StaticCache from ktransformers.models.custom_cache import StaticCache
@ -14,17 +20,17 @@ class KTransformersThreadContext(TransformersThreadContext):
class KTransformersInterface(TransformersInterface): class KTransformersInterface(TransformersInterface):
def __init__(self,args:ConfigArgs= default_args): def __init__(self, args: ConfigArgs = default_args):
self.args = args self.args = args
torch.set_default_dtype(torch.bfloat16) torch.set_default_dtype(torch.bfloat16)
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir,device = args.device) self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device)
config=AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
if config.architectures[0] == "Qwen2MoeForCausalLM": if config.architectures[0] == "Qwen2MoeForCausalLM":
config._attn_implementation="flash_attention_2" config._attn_implementation = "flash_attention_2"
with torch.device("meta"): with torch.device("meta"):
self.model=custom_models[config.architectures[0]](config) self.model = custom_models[config.architectures[0]](config)
if default_args.optimize_config_path is None: if default_args.optimize_config_path is None:
optimize_rule_path = default_optimize_rules[config.architectures[0]] optimize_rule_path = default_optimize_rules[config.architectures[0]]
else: else:
@ -35,15 +41,21 @@ class KTransformersInterface(TransformersInterface):
gguf_path = args.gguf_path gguf_path = args.gguf_path
if gguf_path is None: if gguf_path is None:
gguf_path = input( gguf_path = input(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):" "please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
" belong to current model):"
) )
optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config) optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
device_map = self.model.gguf_loader.tensor_device_map device_map = self.model.gguf_loader.tensor_device_map
logger.info(f'{args.model_name} loaded from {args.model_dir} to {device_map}') logger.info(f"{args.model_name} loaded from {args.model_dir} to {device_map}")
self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=device_map, dtype=self.model.dtype) self.cache = StaticCache(
logger.info(f'StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}') config=self.model.config,
max_batch_size=args.batch_size,
max_cache_len=args.cache_lens,
device=device_map,
dtype=self.model.dtype,
)
logger.info(f"StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}")
self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir) self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
if self.model.generation_config.pad_token_id is None: if self.model.generation_config.pad_token_id is None:
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
@ -52,33 +64,41 @@ class KTransformersInterface(TransformersInterface):
def decode_one_tokens(self): def decode_one_tokens(self):
if not hasattr(self, "cuda_graph_runner"): if not hasattr(self, "cuda_graph_runner"):
device_map = self.model.gguf_loader.tensor_device_map device_map = self.model.gguf_loader.tensor_device_map
torch_device = get_device('blk.0.self_attn', device_map) torch_device = get_device("blk.0.self_attn", device_map)
torch_device = "cuda:0" if torch_device == "cuda" else torch_device torch_device = "cuda:0" if torch_device == "cuda" else torch_device
self.cuda_graph_runner = CUDAGraphRunner() self.cuda_graph_runner = CUDAGraphRunner()
self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, main_device=torch_device, return_dict=False, use_cache=True) self.cuda_graph_runner.capture(
self.model,
self.current_ids,
self.active_cache_position.unsqueeze(0),
self.active_cache_position,
self.cache,
main_device=torch_device,
return_dict=False,
use_cache=True,
)
if hasattr(self, "cuda_graph_runner"): if hasattr(self, "cuda_graph_runner"):
logits = self.cuda_graph_runner(self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position) logits = self.cuda_graph_runner(
self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
)
self.cache.change_seq_length(1) self.cache.change_seq_length(1)
torch.cuda.synchronize() torch.cuda.synchronize()
logits = logits[0,-1,:] logits = logits[0, -1, :]
return self.logits_to_token(logits) return self.logits_to_token(logits)
if self.use_static_cache: if self.use_static_cache:
mask = torch.ones((1,self.seq_length)).to(torch_device) mask = torch.ones((1, self.seq_length)).to(torch_device)
logits = self.model( logits = self.model(
self.current_ids, self.current_ids,
cache_position=self.active_cache_position, cache_position=self.active_cache_position,
past_key_values=self.cache, past_key_values=self.cache,
attention_mask=mask, attention_mask=mask,
return_dict=False, return_dict=False,
use_cache=True use_cache=True,
)[0] )[0]
else: else:
logits = self.model( logits = self.model(self.current_ids, return_dict=False)[0]
self.current_ids, logits = logits[0, -1, :]
return_dict=False
)[0]
logits = logits[0,-1,:]
return self.logits_to_token(logits) return self.logits_to_token(logits)
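
decode_one_tokens now captures the single-token decode step into a CUDA graph on first use and replays it on later steps. CUDAGraphRunner is project-specific, but the underlying mechanism is PyTorch's standard capture/replay API; a generic sketch of that pattern, using a toy module as a stand-in for the decode step:

```python
# Generic capture-once / replay-many sketch of the pattern that
# CUDAGraphRunner wraps; the Linear layer is a toy stand-in for decoding.
import torch

step = torch.nn.Linear(8, 8).cuda()
static_in = torch.zeros(1, 8, device="cuda")

# Warm up on a side stream before capture, as PyTorch recommends.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        step(static_in)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = step(static_in)  # captured once

# Each later step: refresh the static input buffer and replay the graph.
static_in.copy_(torch.randn(1, 8, device="cuda"))
graph.replay()
print(static_out)
```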

View file

@ -1,14 +1,22 @@
from typing import Any, List, Optional, Set from typing import Any, List, Optional, Set
from transformers import LlamaTokenizer,AutoTokenizer, AutoConfig, LlamaForCausalLM,GenerationConfig, StaticCache, AutoModelForCausalLM,BitsAndBytesConfig from transformers import (
LlamaTokenizer,
AutoTokenizer,
AutoConfig,
LlamaForCausalLM,
GenerationConfig,
StaticCache,
AutoModelForCausalLM,
BitsAndBytesConfig,
)
from ktransformers.server.schemas.base import ObjectID from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.utils.multi_timer import Profiler from ktransformers.server.utils.multi_timer import Profiler
import torch import torch
import sys, os import sys, os
from ..base import ThreadContext,BackendInterfaceBase from ..base import ThreadContext, BackendInterfaceBase
from ktransformers.server.config.log import logger from ktransformers.server.config.log import logger
from ..args import ConfigArgs,default_args from ..args import ConfigArgs, default_args
# This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py # This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
@ -28,21 +36,20 @@ class TextStreamer:
self.token_cache = [] self.token_cache = []
self.print_len = 0 self.print_len = 0
def put(self, value)->Optional[str]: def put(self, value) -> Optional[str]:
""" """
Receives tokens, decodes them, and prints them to stdout as soon as they form entire words. Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
""" """
if not isinstance(value,int): if not isinstance(value, int):
raise ValueError("TextStreamer only supports batch size 1, and int type input") raise ValueError("TextStreamer only supports batch size 1, and int type input")
if self.skip_prompt and self.next_tokens_are_prompt: if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False self.next_tokens_are_prompt = False
return None return None
# Add the new token to the cache and decodes the entire thing. # Add the new token to the cache and decodes the entire thing.
self.token_cache.append(value) self.token_cache.append(value)
text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True,**self.decode_kwargs) text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)
# After the symbol for a new line, we flush the cache. # After the symbol for a new line, we flush the cache.
if text.endswith("\n"): if text.endswith("\n"):
@ -59,7 +66,7 @@ class TextStreamer:
self.print_len += len(printable_text) self.print_len += len(printable_text)
return printable_text return printable_text
def end(self)->Optional[str]: def end(self) -> Optional[str]:
"""Flushes any remaining cache and prints a newline to stdout.""" """Flushes any remaining cache and prints a newline to stdout."""
# Flush the cache, if it exists # Flush the cache, if it exists
if len(self.token_cache) > 0: if len(self.token_cache) > 0:
@ -101,24 +108,20 @@ class TransformersThreadContext(ThreadContext):
def get_local_messages(self): def get_local_messages(self):
local_messages = [] local_messages = []
for m in self.messages: for m in self.messages:
local_messages.append( local_messages.append({"role": m.role.value, "content": m.get_text_content()})
{'role':m.role.value,
'content':m.get_text_content()}
)
return local_messages return local_messages
class TransformersInterface(BackendInterfaceBase): class TransformersInterface(BackendInterfaceBase):
use_static_cache : bool = True use_static_cache: bool = True
model: Any model: Any
tokenizer: AutoTokenizer tokenizer: AutoTokenizer
cache: StaticCache cache: StaticCache
generated_ids:torch.Tensor generated_ids: torch.Tensor
seq_length:int seq_length: int
streamer: TextStreamer streamer: TextStreamer
@ -126,52 +129,56 @@ class TransformersInterface(BackendInterfaceBase):
last_request_id: Optional[str] = None last_request_id: Optional[str] = None
ever_generated_ids: Set[int] = set() ever_generated_ids: Set[int] = set()
def __init__(self, args: ConfigArgs = default_args):
def __init__(self, args:ConfigArgs = default_args):
self.args = args self.args = args
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir) self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device,use_safetensors=True) self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device, use_safetensors=True)
logger.info(f'{args.model_name} loaded from {args.model_dir} to {args.device}') logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}")
self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=args.device, dtype=self.model.dtype) self.cache = StaticCache(
logger.info(f'StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}') config=self.model.config,
max_batch_size=args.batch_size,
max_cache_len=args.cache_lens,
device=args.device,
dtype=self.model.dtype,
)
logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}")
self.streamer = TextStreamer(self.tokenizer) self.streamer = TextStreamer(self.tokenizer)
@property @property
def current_ids(self): def current_ids(self):
return self.generated_ids[:,self.seq_length-1].unsqueeze(1) return self.generated_ids[:, self.seq_length - 1].unsqueeze(1)
@property @property
def active_cache_position(self): def active_cache_position(self):
return torch.tensor([self.seq_length-1], device=self.args.device) return torch.tensor([self.seq_length - 1], device=self.args.device)
def tokenize_prompt(self, prompt: str):
def tokenize_prompt(self,prompt:str): input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device)
input_ids = self.tokenizer.encode(prompt,return_tensors='pt').to(self.args.device)
return input_ids return input_ids
def format_and_tokenize_input_ids(self,thread_id:ObjectID,messages:List): def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
for m in messages: for m in messages:
if m['role']=='system': if m["role"] == "system":
logger.warn(f'change {m["role"]} to user') logger.warning(f'change {m["role"]} to user')
m['role'] = 'user' m["role"] = "user"
new_messages = [messages[0]] new_messages = [messages[0]]
for m in messages[1:]: for m in messages[1:]:
if m['role'] == 'user' and new_messages[-1]['role']=='user': if m["role"] == "user" and new_messages[-1]["role"] == "user":
logger.warn('merge two adjacent user messages') logger.warning("merge two adjacent user messages")
new_messages[-1]['content']+=m['content'] new_messages[-1]["content"] += m["content"]
else: else:
new_messages.append(m) new_messages.append(m)
# if (self.last_request_id is not None) and self.last_request_id == thread_id:
# input_ids = self.tokenizer.encode(self.tokenizer.eos_token+self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt",tokenize=False, add_generation_prompt=True), add_special_tokens = False, return_tensors="pt").to(self.args.device)
# else:
# input_ids = self.tokenizer.apply_chat_template(
# new_messages, return_tensors="pt", add_generation_prompt=True
# ).to(self.args.device)
input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device) input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device)
if (self.last_request_id is not None) and self.last_request_id == thread_id: if (self.last_request_id is not None) and self.last_request_id == thread_id:
x = self.generated_ids[:,:self.seq_length] x = self.generated_ids[:,:self.seq_length]
y = input_ids[:,:self.seq_length] y = input_ids[:,:self.seq_length]
@ -179,19 +186,19 @@ class TransformersInterface(BackendInterfaceBase):
unequal_mask = torch.ne(x,y) unequal_mask = torch.ne(x,y)
unequal_positions = torch.nonzero(unequal_mask) unequal_positions = torch.nonzero(unequal_mask)
num_unequal_elements = unequal_mask.sum().item() num_unequal_elements = unequal_mask.sum().item()
logger.warn(f'num_unequal_elements: {num_unequal_elements}') logger.warning(f'num_unequal_elements: {num_unequal_elements}')
input_ids = input_ids[:,self.seq_length:] input_ids = input_ids[:,self.seq_length:]
logger.debug(f'get input ids of shape {input_ids.shape}') logger.debug(f"get input ids of shape {input_ids.shape}")
return input_ids return input_ids
def append_new_tokens(self,new_tokens:int)->Optional[str]: def append_new_tokens(self, new_tokens: int) -> Optional[str]:
self.generated_ids[0,self.seq_length] = new_tokens self.generated_ids[0, self.seq_length] = new_tokens
self.seq_length+=1 self.seq_length += 1
return self.streamer.put(new_tokens) return self.streamer.put(new_tokens)
def logits_to_token(self,logits:torch.Tensor): def logits_to_token(self, logits: torch.Tensor):
logits = logits/self.args.temperature logits = logits / self.args.temperature
for token_idx in self.ever_generated_ids: for token_idx in self.ever_generated_ids:
if logits[token_idx] < 0: if logits[token_idx] < 0:
@ -211,34 +218,28 @@ class TransformersInterface(BackendInterfaceBase):
self.ever_generated_ids.add(last) self.ever_generated_ids.add(last)
return last return last
def decode_one_tokens(self): def decode_one_tokens(self):
if self.use_static_cache: if self.use_static_cache:
mask = torch.ones((1,self.seq_length)).to(self.args.device) mask = torch.ones((1, self.seq_length)).to(self.args.device)
logits = self.model( logits = self.model(
self.current_ids, self.current_ids,
cache_position=self.active_cache_position, cache_position=self.active_cache_position,
past_key_values=self.cache, past_key_values=self.cache,
attention_mask=mask, attention_mask=mask,
return_dict=False, return_dict=False,
use_cache=True use_cache=True,
)[0] )[0]
else: else:
logits = self.model( logits = self.model(self.current_ids, return_dict=False)[0]
self.current_ids, logits = logits[0, -1, :]
return_dict=False
)[0]
logits = logits[0,-1,:]
return self.logits_to_token(logits) return self.logits_to_token(logits)
@torch.no_grad @torch.no_grad
def prefill(self,input_ids:torch.Tensor,is_new:bool): def prefill(self, input_ids: torch.Tensor, is_new: bool):
input_ids_length = input_ids.shape[-1] input_ids_length = input_ids.shape[-1]
self.profiler.set_counter('prefill',input_ids_length) self.profiler.set_counter("prefill", input_ids_length)
logger.debug(f'input_ids: {input_ids.shape}') logger.debug(f"input_ids: {input_ids.shape}")
if is_new: if is_new:
self.cache.reset() self.cache.reset()
@ -246,92 +247,95 @@ class TransformersInterface(BackendInterfaceBase):
former_seq_length = 0 former_seq_length = 0
self.seq_length = input_ids_length self.seq_length = input_ids_length
self.generated_ids = torch.zeros( self.generated_ids = torch.zeros(
self.args.batch_size, self.seq_length + self.args.max_new_tokens + 1, dtype=torch.int, device=self.args.device self.args.batch_size,
self.seq_length + self.args.max_new_tokens + 1,
dtype=torch.int,
device=self.args.device,
) )
else: else:
logger.debug(f'generate_ids: {self.generated_ids.shape}') logger.debug(f"generate_ids: {self.generated_ids.shape}")
former_seq_length = self.seq_length former_seq_length = self.seq_length
self.seq_length += input_ids_length self.seq_length += input_ids_length
expected_length = self.seq_length + self.args.max_new_tokens+1 expected_length = self.seq_length + self.args.max_new_tokens + 1
delta_length = expected_length - self.generated_ids.shape[-1] delta_length = expected_length - self.generated_ids.shape[-1]
if delta_length>0: if delta_length > 0:
new_generate_ids = torch.zeros( new_generate_ids = torch.zeros(
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
) )
self.generated_ids = torch.cat([self.generated_ids,new_generate_ids],dim=-1) self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
logger.debug(f'cache position: {former_seq_length} to {self.seq_length}') logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
cache_position = torch.arange(former_seq_length,self.seq_length, device=self.args.device) cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
self.generated_ids[:,cache_position] = input_ids.to(self.args.device).to(torch.int) self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
mask = torch.ones((1,self.seq_length)).to(self.args.device) mask = torch.ones((1, self.seq_length)).to(self.args.device)
device = input_ids.device device = input_ids.device
if not(type(self) is TransformersInterface): if not (type(self) is TransformersInterface):
input_ids = input_ids.to("cpu") input_ids = input_ids.to("cpu")
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device) inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
if self.use_static_cache: if self.use_static_cache:
logits = self.model( logits = self.model(
inputs_embeds=inputs_embeds, cache_position=cache_position, past_key_values=self.cache,return_dict=False, use_cache=True,attention_mask=mask inputs_embeds=inputs_embeds,
cache_position=cache_position,
past_key_values=self.cache,
return_dict=False,
use_cache=True,
attention_mask=mask,
)[0] )[0]
else: else:
logits = self.model( logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
inputs_embeds=inputs_embeds,return_dict=False
)[0]
next_token = self.logits_to_token(logits[0, -1, :])
next_token = self.logits_to_token(logits[0,-1,:])
yield self.append_new_tokens(next_token) yield self.append_new_tokens(next_token)
@torch.no_grad @torch.no_grad
def generate(self): def generate(self):
self.profiler.set_counter('decode',0) self.profiler.set_counter("decode", 0)
for _ in range(1, self.args.max_new_tokens): for _ in range(1, self.args.max_new_tokens):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
next_token = self.decode_one_tokens() next_token = self.decode_one_tokens()
self.profiler.inc('decode') self.profiler.inc("decode")
if next_token == self.tokenizer.eos_token_id: if next_token == self.tokenizer.eos_token_id:
assert self.args.batch_size == 1 assert self.args.batch_size == 1
break break
yield self.append_new_tokens(next_token) yield self.append_new_tokens(next_token)
yield self.streamer.end() yield self.streamer.end()
def check_is_new(self,thread_id:str): def check_is_new(self, thread_id: str):
if not self.use_static_cache: if not self.use_static_cache:
return True return True
if self.last_request_id is None: if self.last_request_id is None:
self.last_request_id = thread_id self.last_request_id = thread_id
return True return True
else: else:
if self.last_request_id==thread_id: if self.last_request_id == thread_id:
return False return False
else: else:
self.last_request_id = thread_id self.last_request_id = thread_id
return True return True
async def inference(self,local_messages,thread_id:str): async def inference(self, local_messages, thread_id: str):
self.profiler.create_and_start_timer('tokenize') self.profiler.create_and_start_timer("tokenize")
if isinstance(local_messages,List): if isinstance(local_messages, List):
input_ids = self.format_and_tokenize_input_ids(thread_id,local_messages) input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
elif isinstance(local_messages,str): elif isinstance(local_messages, str):
input_ids = self.tokenize_prompt(local_messages) input_ids = self.tokenize_prompt(local_messages)
else: else:
raise ValueError('local_messages should be List or str') raise ValueError("local_messages should be List or str")
self.profiler.pause_timer('tokenize') self.profiler.pause_timer("tokenize")
self.profiler.create_and_start_timer('prefill') self.profiler.create_and_start_timer("prefill")
for t in self.prefill(input_ids,self.check_is_new(thread_id)): for t in self.prefill(input_ids, self.check_is_new(thread_id)):
if t is not None: if t is not None:
print(t,end='') print(t, end="")
yield t yield t
self.profiler.pause_timer('prefill') self.profiler.pause_timer("prefill")
self.profiler.create_and_start_timer('decode') self.profiler.create_and_start_timer("decode")
for t in self.generate(): for t in self.generate():
if t is not None: if t is not None:
print(t,end='') print(t, end="")
yield t yield t
print('') print("")
self.profiler.pause_timer('decode') self.profiler.pause_timer("decode")
self.report_last_time_performance() self.report_last_time_performance()

View file

@ -1,23 +1,24 @@
#!/usr/bin/env python #!/usr/bin/env python
# coding=utf-8 # coding=utf-8
''' """
Description : Description :
Author : unicornchan Author : unicornchan
Date : 2024-06-11 16:35:42 Date : 2024-06-11 16:35:42
Version : 1.0.0 Version : 1.0.0
LastEditors : WuHao LastEditors : WuHao
LastEditTime : 2024-08-12 06:31:14 LastEditTime : 2024-08-12 06:31:14
''' """
import os import os
import shutil import shutil
import yaml import yaml
from ktransformers.server.config.singleton import Singleton from ktransformers.server.config.singleton import Singleton
from typing import Optional
class Config(metaclass=Singleton): class Config(metaclass=Singleton):
"""Singleton pattern Config class, used to get all configurations. """Singleton pattern Config class, used to get all configurations."""
"""
CONFIG_FILE_NAME = "config.yaml" CONFIG_FILE_NAME = "config.yaml"
@staticmethod @staticmethod
@ -27,22 +28,20 @@ class Config(metaclass=Singleton):
Returns: Returns:
dict: all configs dict: all configs
""" """
base_path: str = os.path.dirname( base_path: str = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
os.path.dirname(os.path.dirname(__file__))) config_yaml: str = os.path.join(base_path, "configs", Config.CONFIG_FILE_NAME)
config_yaml: str = os.path.join(
base_path, "configs", Config.CONFIG_FILE_NAME)
user_path: str = os.path.expanduser('~') user_path: str = os.path.expanduser("~")
localstore_path: str = os.path.join(user_path,'.ktransformers') localstore_path: str = os.path.join(user_path, ".ktransformers")
config_path: str = os.path.join(localstore_path,Config.CONFIG_FILE_NAME) config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME)
if not os.path.exists(config_yaml): if not os.path.exists(config_yaml):
print(f"Can't find config file, {config_yaml}") print(f"Can't find config file, {config_yaml}")
exit(-1) exit(-1)
if not os.path.exists(localstore_path): if not os.path.exists(localstore_path):
os.mkdir(localstore_path) os.mkdir(localstore_path)
if not os.path.exists(config_path): if not os.path.exists(config_path):
shutil.copyfile(config_yaml,config_path) shutil.copyfile(config_yaml, config_path)
with open(config_path, 'r', encoding="utf-8") as fp: with open(config_path, "r", encoding="utf-8") as fp:
config = yaml.safe_load(fp) config = yaml.safe_load(fp)
return config return config
@ -52,16 +51,14 @@ class Config(metaclass=Singleton):
process file path process file path
""" """
base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
real_path = path if os.path.isabs( real_path = path if os.path.isabs(path) else os.path.join(base_path, path)
path) else os.path.join(base_path, path)
return real_path return real_path
def __init__(self): def __init__(self):
cfg = Config.load() cfg = Config.load()
self.base_path = os.path.dirname( self.base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
os.path.dirname(os.path.dirname(__file__))) self.user_path: str = os.path.expanduser("~")
self.user_path: str = os.path.expanduser('~') self.localstore_path: str = os.path.join(self.user_path, ".ktransformers")
self.localstore_path: str = os.path.join(self.user_path,'.ktransformers')
# log configs # log configs
self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"])) self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"]))
self.log_file = cfg["log"]["file"] self.log_file = cfg["log"]["file"]
@ -69,7 +66,7 @@ class Config(metaclass=Singleton):
self.backup_count = cfg["log"]["backup_count"] self.backup_count = cfg["log"]["backup_count"]
# server configs # server configs
self.server: dict = cfg.get("server",{}) self.server: dict = cfg.get("server", {})
self.server_ip = self.server.get("ip", "0.0.0.0") self.server_ip = self.server.get("ip", "0.0.0.0")
self.server_port = self.server.get("port", 9016) self.server_port = self.server.get("port", 9016)
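Because `Config` is built on the `Singleton` metaclass, constructing it repeatedly returns one shared instance, so the YAML is effectively parsed once per process. A short usage sketch (printed values depend on the local `config.yaml`):

```python
from ktransformers.server.config.config import Config

cfg_a = Config()
cfg_b = Config()
assert cfg_a is cfg_b                 # same object everywhere in the process

# Server settings fall back to 0.0.0.0:9016 when the YAML omits them.
print(cfg_a.server_ip, cfg_a.server_port)
```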
@ -88,13 +85,65 @@ class Config(metaclass=Singleton):
self.user_algorithm = self.user_config.get("algorithm", "") self.user_algorithm = self.user_config.get("algorithm", "")
# model config # model config
self.model:dict = cfg.get("model", {}) self.model: dict = cfg.get("model", {})
self.backend_type: str = self.model.get("type", "transformers") self.backend_type: str = self.model.get("type", "transformers")
self.model_path: str = self.model.get("path", "") self.model_dir: str = self.model.get("path", "")
# to make sure it consistent with previous version
self.model_path: str = self.model_dir
self.model_name: str = self.model.get("name", "") self.model_name: str = self.model.get("name", "")
self.model_device: str = self.model.get("device", "cuda:0") self.model_device: str = self.model.get("device", "cuda:0")
self.gguf_path: str = self.model.get("gguf_path", "") self.gguf_path: Optional[str] = self.model.get("gguf_path", None)
self.model_cache_lens = self.model.get("cache_lens") # self.model_cache_lens = self.model.get("cache_lens")
self.optimize_config_path: Optional[str] = self.model.get(
"optimize_config_path", "./ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml"
)
self.paged = self.model.get("paged", True)
self.total_context = self.model.get("total_context", 2**18)
self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
self.max_chunk_size = self.model.get("max_chunk_size", 2048)
self.max_new_tokens = self.model.get("max_new_tokens", 500)
self.json_mode = self.model.get("json_mode", False)
self.healing = self.model.get("healing", False)
self.ban_strings: Optional[list] = self.model.get("ban_strings", None)
self.gpu_split: Optional[str] = self.model.get("gpu_split", None)
self.length: Optional[int] = self.model.get("length", None)
self.rope_scale: Optional[float] = self.model.get("rope_scale", None)
self.rope_alpha: Optional[float] = self.model.get("rope_alpha", None)
self.no_flash_attn = self.model.get("no_flash_attn", False)
self.low_mem = self.model.get("low_mem", False)
self.experts_per_token: Optional[int] = self.model.get("experts_per_token", None)
self.load_q4 = self.model.get("load_q4", False)
self.fast_safetensors = self.model.get("fast_safetensors", False)
self.draft_model_dir: Optional[str] = self.model.get("draft_model_dir", None)
self.no_draft_scale = self.model.get("no_draft_scale", False)
self.modes = self.model.get("modes", False)
self.mode = self.model.get("mode", "llama")
self.username = self.model.get("username", "User")
self.botname = self.model.get("botname", "Chatbort")
self.system_prompt: Optional[str] = self.model.get("system_prompt", None)
self.temperature = self.model.get("temperature", 0.95)
self.smoothing_factor = self.model.get("smoothing_factor", 0.0)
self.dynamic_temperature: Optional[str] = self.model.get("dynamic_temperature", None)
self.top_k = self.model.get("top_k", 50)
self.top_p = self.model.get("top_p", 0.8)
self.top_a = self.model.get("top_a", 0.0)
self.skew = self.model.get("skew", 0.0)
self.typical = self.model.get("typical", 0.0)
self.repetition_penalty = self.model.get("repetition_penalty", 1.01)
self.frequency_penalty = self.model.get("frequency_penalty", 0.0)
self.presence_penalty = self.model.get("presence_penalty", 0.0)
self.max_response_tokens = self.model.get("max_response_tokens", 300)
self.response_chunk = self.model.get("response_chunk", 250)
self.no_code_formatting = self.model.get("no_code_formatting", False)
self.cache_8bit = self.model.get("cache_8bit", False)
self.cache_q4 = self.model.get("cache_q4", True)
self.ngram_decoding = self.model.get("ngram_decoding", False)
self.print_timings = self.model.get("print_timings", False)
self.amnesia = self.model.get("amnesia", False)
self.batch_size = self.model.get("batch_size", 1)
self.cache_lens = self.model.get("cache_lens", 4096)
self.device = self.model.get("device", "cuda:2")
# web config # web config
self.web: dict = cfg.get("web", {}) self.web: dict = cfg.get("web", {})
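Every attribute in the model block above follows the same `dict.get(key, default)` pattern: keys present in `config.yaml` win, and anything missing silently takes the default listed in the code. A tiny illustration with made-up values:

```python
model = {"temperature": 0.6}                     # pretend the YAML only sets this

temperature = model.get("temperature", 0.95)     # 0.6, taken from the YAML
top_p = model.get("top_p", 0.8)                  # 0.8, default kicks in
cache_lens = model.get("cache_lens", 4096)       # 4096, default kicks in
print(temperature, top_p, cache_lens)
```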
@ -104,10 +153,32 @@ class Config(metaclass=Singleton):
self.ext: dict = cfg.get("ext", {}) self.ext: dict = cfg.get("ext", {})
self.cpu_infer = self.ext.get("cpu_infer", 10) self.cpu_infer = self.ext.get("cpu_infer", 10)
#file config # file config
self.local_store_configs: dict = cfg.get("local_store",{}) self.local_store_configs: dict = cfg.get("local_store", {})
self.file_upload_dir: str = os.path.join(self.localstore_path,self.local_store_configs.get("file_upload_dir","")) self.file_upload_dir: str = os.path.join(
self.assistant_store_dir: str = os.path.join(self.localstore_path,self.local_store_configs.get("assistant_store_dir","")) self.localstore_path, self.local_store_configs.get("file_upload_dir", "")
)
self.assistant_store_dir: str = os.path.join(
self.localstore_path, self.local_store_configs.get("assistant_store_dir", "")
)
#long context config # long context config
self.long_context_config: dict = cfg.get("long_context",{}) self.long_context_config: dict = cfg.get("long_context", {})
self.chunk_size = self.long_context_config.get("chunk_size", 4096)
self.max_seq_len = self.long_context_config.get("max_seq_len", 32000)
self.block_size = self.long_context_config.get("block_size", 128)
self.local_windows_len = self.long_context_config.get("local_windows_len", 4096)
self.second_select_num = self.long_context_config.get("second_select_num", 32)
self.anchor_type = self.long_context_config.get("anchor_type", "DYNAMIC")
self.kv_type = self.long_context_config.get("kv_type", "FP16")
self.dense_layer_num = self.long_context_config.get("dense_layer_num", 2)
self.anchor_num = self.long_context_config.get("anchor_num", 1)
self.preselect_block = self.long_context_config.get("preselect_block", True)
self.head_select_mode = self.long_context_config.get("head_select_mode", "SHARED")
self.preselect_block_count = self.long_context_config.get("preselect_block_count", 32)
self.layer_step = self.long_context_config.get("layer_step", 1)
self.token_step = self.long_context_config.get("token_step", 100)
# local chat
self.local_chat_config: dict = cfg.get("local_chat", {})
self.prompt_file = self.local_chat_config.get("prompt_file", None)
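The `long_context` block is read the same way. The sketch below assumes a YAML shape matching the keys above (PyYAML is already a dependency here) and shows how an omitted key falls back to its default:

```python
import yaml

snippet = """
long_context:
  max_seq_len: 32000
  block_size: 128
  anchor_type: DYNAMIC
"""
long_ctx = yaml.safe_load(snippet).get("long_context", {})

chunk_size = long_ctx.get("chunk_size", 4096)                    # key absent -> default
num_blocks = long_ctx["max_seq_len"] // long_ctx["block_size"]   # 32000 / 128 = 250
print(chunk_size, num_blocks)
```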


@ -3,9 +3,13 @@ import re
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
import uvicorn.logging import uvicorn.logging
import argparse
import uvicorn import uvicorn
import sys
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
sys.path.insert(0, project_dir)
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.args import ArgumentParser
from ktransformers.server.config.config import Config from ktransformers.server.config.config import Config
from ktransformers.server.utils.create_interface import create_interface from ktransformers.server.utils.create_interface import create_interface
from ktransformers.server.backend.args import default_args from ktransformers.server.backend.args import default_args
@ -44,8 +48,11 @@ def create_app():
mount_index_routes(app) mount_index_routes(app)
return app return app
def update_web_port(config_file: str): def update_web_port(config_file: str):
ip_port_pattern = r"(localhost|((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)):[0-9]{1,5}" ip_port_pattern = (
r"(localhost|((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)):[0-9]{1,5}"
)
with open(config_file, "r", encoding="utf-8") as f_cfg: with open(config_file, "r", encoding="utf-8") as f_cfg:
web_config = f_cfg.read() web_config = f_cfg.read()
ip_port = "localhost:" + str(Config().server_port) ip_port = "localhost:" + str(Config().server_port)
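The `ip_port_pattern` above matches either `localhost` or a dotted IPv4 address followed by a port, which is what lets `update_web_port` rewrite whatever host the prebuilt web bundle was pointed at. A quick standalone check (the config line is made up):

```python
import re

ip_port_pattern = (
    r"(localhost|((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)):[0-9]{1,5}"
)
web_config = 'VUE_APP_API_URL = "http://192.168.1.5:9016/api"'   # hypothetical file content

updated = re.sub(ip_port_pattern, "localhost:" + str(10002), web_config)
print(updated)   # -> VUE_APP_API_URL = "http://localhost:10002/api"
```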
@ -70,14 +77,15 @@ def mount_index_routes(app: FastAPI):
def run_api(app, host, port, **kwargs): def run_api(app, host, port, **kwargs):
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"): if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
uvicorn.run(app, uvicorn.run(
app,
host=host, host=host,
port=port, port=port,
ssl_keyfile=kwargs.get("ssl_keyfile"), ssl_keyfile=kwargs.get("ssl_keyfile"),
ssl_certfile=kwargs.get("ssl_certfile"), ssl_certfile=kwargs.get("ssl_certfile"),
) )
else: else:
uvicorn.run(app, host=host, port=port, log_level='debug') uvicorn.run(app, host=host, port=port, log_level="debug")
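`run_api` simply branches on whether TLS material was supplied. A standalone equivalent, with placeholder paths and port, would look like this:

```python
import uvicorn
from fastapi import FastAPI

app = FastAPI()
ssl_keyfile = None          # e.g. "/etc/ssl/private/server.key" (placeholder)
ssl_certfile = None         # e.g. "/etc/ssl/certs/server.crt" (placeholder)

if ssl_keyfile and ssl_certfile:
    uvicorn.run(app, host="0.0.0.0", port=10002,
                ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile)
else:
    uvicorn.run(app, host="0.0.0.0", port=10002, log_level="debug")
```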
def custom_openapi(app): def custom_openapi(app):
@ -90,53 +98,27 @@ def custom_openapi(app):
description="We provided chat completion and openai assistant interfaces.", description="We provided chat completion and openai assistant interfaces.",
routes=app.routes, routes=app.routes,
) )
openapi_schema["info"]["x-logo"] = { openapi_schema["info"]["x-logo"] = {"url": "https://kvcache.ai/media/icon_1.png"}
"url": "https://kvcache.ai/media/icon_1.png"
}
app.openapi_schema = openapi_schema app.openapi_schema = openapi_schema
return app.openapi_schema return app.openapi_schema
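For context, the schema customisation above relies on FastAPI's `get_openapi` helper and caches the result on the app. A minimal self-contained version (title and version strings are placeholders, not taken from the repository):

```python
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi

app = FastAPI()


def custom_openapi_sketch(app: FastAPI) -> dict:
    if app.openapi_schema:                      # reuse the cached schema
        return app.openapi_schema
    schema = get_openapi(
        title="ktransformers server",
        version="1.0.0",
        description="We provided chat completion and openai assistant interfaces.",
        routes=app.routes,
    )
    schema["info"]["x-logo"] = {"url": "https://kvcache.ai/media/icon_1.png"}
    app.openapi_schema = schema
    return schema
```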
def main(): def main():
cfg = Config() cfg = Config()
parser = argparse.ArgumentParser(prog='kvcache.ai', arg_parser = ArgumentParser(cfg)
description='Ktransformers')
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=cfg.server_port)
parser.add_argument("--ssl_keyfile", type=str)
parser.add_argument("--ssl_certfile", type=str)
parser.add_argument("--web", type=bool, default=False)
parser.add_argument("--model_name", type=str, default=cfg.model_name)
parser.add_argument("--model_path", type=str, default=cfg.model_path)
parser.add_argument("--device", type=str, default=cfg.model_device, help="Warning: Abandoning this parameter")
parser.add_argument("--gguf_path", type=str, default=cfg.gguf_path)
parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
parser.add_argument("--cpu_infer", type=int, default=cfg.cpu_infer)
parser.add_argument("--type", type=str, default=cfg.backend_type)
    # Initialize messages # Initialize messages
args = parser.parse_args() args = arg_parser.parse_args()
cfg.model_name = args.model_name
cfg.model_path = args.model_path
cfg.model_device = args.device
cfg.mount_web = args.web
cfg.server_ip = args.host
cfg.server_port = args.port
cfg.cpu_infer = args.cpu_infer
cfg.backend_type = args.type
default_args.model_dir = args.model_path
default_args.device = args.device
default_args.gguf_path = args.gguf_path
default_args.optimize_config_path = args.optimize_config_path
app = create_app() app = create_app()
custom_openapi(app) custom_openapi(app)
create_interface(config=cfg, default_args=default_args) create_interface(config=cfg, default_args=cfg)
run_api(app=app, run_api(
app=app,
host=args.host, host=args.host,
port=args.port, port=args.port,
ssl_keyfile=args.ssl_keyfile, ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,) ssl_certfile=args.ssl_certfile,
)
if __name__ == "__main__": if __name__ == "__main__":

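The `main()` refactor replaces the long inline argparse block with `ArgumentParser(cfg)` from `ktransformers.server.args`, so CLI flags are declared once and their parsed values flow back into the shared `Config`. That class is not shown in this diff; the sketch below is an assumed shape illustrating the idea, not its real implementation.

```python
import argparse


class ConfigArgumentParser:
    """Hypothetical config-driven parser: defaults come from cfg, results go back."""

    def __init__(self, cfg):
        self.cfg = cfg
        self.parser = argparse.ArgumentParser(prog="kvcache.ai", description="Ktransformers")
        self.parser.add_argument("--host", type=str, default=cfg.server_ip)
        self.parser.add_argument("--port", type=int, default=cfg.server_port)
        self.parser.add_argument("--model_name", type=str, default=cfg.model_name)
        self.parser.add_argument("--gguf_path", type=str, default=cfg.gguf_path)

    def parse_args(self):
        args = self.parser.parse_args()
        # push CLI overrides back into the shared Config singleton
        self.cfg.server_ip = args.host
        self.cfg.server_port = args.port
        self.cfg.model_name = args.model_name
        self.cfg.gguf_path = args.gguf_path
        return args
```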

@ -4412,8 +4412,9 @@
}, },
"node_modules/@vue/cli": { "node_modules/@vue/cli": {
"version": "5.0.8", "version": "5.0.8",
"resolved": "https://registry.npmmirror.com/@vue/cli/-/cli-5.0.8.tgz", "resolved": "https://registry.npmjs.org/@vue/cli/-/cli-5.0.8.tgz",
"integrity": "sha512-c/QKPdC09bYkW22m/boXkLaiz10z0Z2WHZO7zEeNdfSduqyWINZhKc6hVQU3Vk0NXW7BJAd7zWmcUrC8L9TuAA==", "integrity": "sha512-c/QKPdC09bYkW22m/boXkLaiz10z0Z2WHZO7zEeNdfSduqyWINZhKc6hVQU3Vk0NXW7BJAd7zWmcUrC8L9TuAA==",
"license": "MIT",
"dependencies": { "dependencies": {
"@types/ejs": "^3.0.6", "@types/ejs": "^3.0.6",
"@types/inquirer": "^8.1.3", "@types/inquirer": "^8.1.3",


@ -70,3 +70,7 @@ ktransformers = "ktransformers.server.main:main"
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
where = ["./", ] where = ["./", ]
include = ["ktransformers"] include = ["ktransformers"]
[tool.black]
line-length = 120
preview = true
unstable = true