Merge pull request #110 from KMSorSMS/main

refactor local_chat & config setting
This commit is contained in:
UnicornChan 2024-11-06 09:44:12 +08:00 committed by GitHub
commit dddc42038d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 652 additions and 400 deletions

4
.flake8 Normal file
View file

@ -0,0 +1,4 @@
[flake8]
max-line-length = 120
extend-select = B950
extend-ignore = E203,E501,E701, B001,B006,B007,B008,B009,B010,B011,B016,B028,B031,B950,E265,E266,E401,E402,E711,E712,E713,E721,E722,E731,F401,F403,F405,F541,F811,F821,F841,W391

3
.gitignore vendored
View file

@ -17,4 +17,5 @@ compile_commands.json
*dist/
ktransformers/server/local_store/
ktransformers/server_test1.db
*.patch
*.patch
img/

21
Makefile Normal file
View file

@ -0,0 +1,21 @@
flake_find:
cd ktransformers && flake8 | grep -Eo '[A-Z][0-9]{3}' | sort | uniq| paste -sd ',' -
format:
@cd ktransformers && black .
@black setup.py
dev_install:
# clear build dirs
rm -rf build
rm -rf *.egg-info
rm -rf ktransformers/ktransformers_ext/build
rm -rf ktransformers/ktransformers_ext/cuda/build
rm -rf ktransformers/ktransformers_ext/cuda/dist
rm -rf ktransformers/ktransformers_ext/cuda/*.egg-info
# install ktransformers
echo "Installing python dependencies from requirements.txt"
pip install -r requirements-local_chat.txt
echo "Installing ktransformers"
KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . --no-build-isolation
echo "Installation completed successfully"

View file

@ -151,7 +151,7 @@ Some preparation:
```
install.bat
```
4. If you are developer, you can make use of the makefile to compile and format the code. <br> the detailed usage of makefile is [here](./doc/en/makefile_usage.md)
<h3>Local Chat</h3>
We provide a simple command-line local chat Python script that you can run for testing.

26
doc/en/makefile_usage.md Normal file
View file

@ -0,0 +1,26 @@
# Makefile
## Target
### flake_find:
```bash
make flake_find
```
find all the python files under ./ktransformers dir and find the Error, Warning, Fatal... (their codes) into a list that are not consistent with the pep8 standard. For now we have get all this list in the .flake8 file's extend-ignore section in order to let flakes8 ignore them temporarily.(we may improve them in the future)
### format:
```bash
make format
```
we use black to format all the python files under ./ktransformers dir. It obeys the pep8 standard
but we modify the line length to 120 by add
```toml
[tool.black]
line-length = 120
preview = true
unstable = true
```
in the pyproject.toml file.
### dev_install:
```bash
make dev_install
```
install the package in the development mode. It means that the package is installed in the editable mode. So if you modify the code, you don't need to reinstall the package. We recommend the developer to use this method to install the package.

View file

@ -7,7 +7,7 @@ log:
server:
ip: 0.0.0.0
port: 12456
port: 10002
db:
type: "sqllite"
@ -24,10 +24,11 @@ model:
type: ktransformers
name: DeepSeek-Coder-V2-Instruct
path: /mnt/data/model/DeepSeek-Coder-V2-Instruct/
gguf_path: /mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH/
path: deepseek-ai/DeepSeek-V2-Lite-Chat
gguf_path: ./DeepSeek-V2-Lite-Chat-GGUF
device: cuda:0
cache_lens: 8192
web:
mount: False
@ -50,4 +51,7 @@ long_context:
head_select_mode: SHARED
preselect_block_count: 32
layer_step: 1
token_step: 100
token_step:
local_chat:
prompt_file: "./ktransformers/p.txt"

View file

@ -1,33 +1,23 @@
"""
Description :
Description :
Author : Boxin Zhang, Azure-Tang
Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import asyncio
import os
import platform
import sys
project_dir = os.path.dirname(os.path.dirname(__file__))
sys.path.insert(0, project_dir)
import torch
import logging
from transformers import (
AutoTokenizer,
AutoConfig,
AutoModelForCausalLM,
GenerationConfig,
TextStreamer,
)
import json
import fire
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.server.args import ArgumentParser
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate
from ktransformers.server.config.config import Config
custom_models = {
@ -37,9 +27,7 @@ custom_models = {
"MixtralForCausalLM": MixtralForCausalLM,
}
ktransformer_rules_dir = (
os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
)
ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
default_optimize_rules = {
"DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
@ -48,75 +36,28 @@ default_optimize_rules = {
}
def local_chat(
model_path: str | None = None,
optimize_rule_path: str = None,
gguf_path: str | None = None,
max_new_tokens: int = 1000,
cpu_infer: int = Config().cpu_infer,
use_cuda_graph: bool = True,
prompt_file : str | None = None,
mode: str = "normal",
):
torch.set_grad_enabled(False)
Config().cpu_infer = cpu_infer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if mode == 'long_context':
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
torch.set_default_dtype(torch.float16)
def local_chat():
config = Config()
arg_parser = ArgumentParser(config)
# 初始化消息
arg_parser.parse_args()
if config.backend_type == "transformers":
from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
elif config.backend_type == "exllamav2":
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
elif config.backend_type == "ktransformers":
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
else:
torch.set_default_dtype(config.torch_dtype)
with torch.device("meta"):
if config.architectures[0] in custom_models:
print("using custom modeling_xxx.py.")
if (
"Qwen2Moe" in config.architectures[0]
): # Qwen2Moe must use flash_attention_2 to avoid overflow.
config._attn_implementation = "flash_attention_2"
if "Llama" in config.architectures[0]:
config._attn_implementation = "eager"
if "Mixtral" in config.architectures[0]:
config._attn_implementation = "flash_attention_2"
model = custom_models[config.architectures[0]](config)
else:
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=True, attn_implementation="flash_attention_2"
)
if optimize_rule_path is None:
if config.architectures[0] in default_optimize_rules:
print("using default_optimize_rule for", config.architectures[0])
optimize_rule_path = default_optimize_rules[config.architectures[0]]
else:
optimize_rule_path = input(
"please input the path of your rule file(yaml file containing optimize rules):"
)
if gguf_path is None:
gguf_path = input(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
)
optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)
model.generation_config = GenerationConfig.from_pretrained(model_path)
if model.generation_config.pad_token_id is None:
model.generation_config.pad_token_id = model.generation_config.eos_token_id
model.eval()
logging.basicConfig(level=logging.INFO)
raise NotImplementedError(f"{config.backend_type} not implemented")
interface = BackendInterface(config)
system = platform.system()
if system == "Windows":
os.system("cls")
else:
os.system("clear")
# add a history chat content
his_content = []
while True:
content = input("Chat: ")
if content.startswith('"""'): # prefix """
@ -132,28 +73,27 @@ def local_chat(
break
else:
content += line + "\n"
if content == "":
if prompt_file != None:
content = open(prompt_file, "r").read()
else:
if config.prompt_file == None or config.prompt_file == "":
content = "Please write a piece of quicksort code in C++."
else:
content = open(config.prompt_file, "r").read()
elif os.path.isfile(content):
content = open(content, "r").read()
messages = [{"role": "user", "content": content}]
input_tensor = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, return_tensors="pt"
)
if mode == 'long_context':
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
"please change max_seq_len in ~/.ktransformers/config.yaml"
torch.set_default_dtype(
torch.bfloat16
) # TODO: Remove this, replace dtype using config
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode
)
messages = his_content + [{"role": "user", "content": content}]
async def async_inference(messages):
generated = ""
async for token in interface.inference(messages, "local_chat"):
generated += token
return generated
generated = asyncio.run(async_inference(messages))
his_content += [
{"role": "user", "content": content},
{"role": "assistant", "content": generated},
]
if __name__ == "__main__":
fire.Fire(local_chat)
local_chat()

View file

@ -0,0 +1,56 @@
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

View file

@ -135,6 +135,8 @@ class OllamaShowResponse(BaseModel):
details: OllamaShowDetial
model_info: OllamaModelInfo
class Config:
protected_namespaces = ()

View file

@ -0,0 +1,124 @@
import argparse
from ktransformers.server.backend.args import ConfigArgs, default_args
class ArgumentParser:
def __init__(self, cfg):
self.cfg = cfg
def parse_args(self):
parser = argparse.ArgumentParser(prog="kvcache.ai", description="Ktransformers")
parser.add_argument("--host", type=str, default=self.cfg.server_ip)
parser.add_argument("--port", type=int, default=self.cfg.server_port)
parser.add_argument("--ssl_keyfile", type=str)
parser.add_argument("--ssl_certfile", type=str)
parser.add_argument("--web", type=bool, default=self.cfg.mount_web)
parser.add_argument("--model_name", type=str, default=self.cfg.model_name)
parser.add_argument("--model_dir", type=str)
parser.add_argument("--model_path", type=str)
parser.add_argument(
"--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter"
)
parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
parser.add_argument("--type", type=str, default=self.cfg.backend_type)
# model configs
# parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
parser.add_argument("--paged", type=bool, default=self.cfg.paged)
parser.add_argument("--total_context", type=int, default=self.cfg.total_context)
parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size)
parser.add_argument("--max_chunk_size", type=int, default=self.cfg.max_chunk_size)
parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens)
parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode)
parser.add_argument("--healing", type=bool, default=self.cfg.healing)
parser.add_argument("--ban_strings", type=list, default=self.cfg.ban_strings, required=False)
parser.add_argument("--gpu_split", type=str, default=self.cfg.gpu_split, required=False)
parser.add_argument("--length", type=int, default=self.cfg.length, required=False)
parser.add_argument("--rope_scale", type=float, default=self.cfg.rope_scale, required=False)
parser.add_argument("--rope_alpha", type=float, default=self.cfg.rope_alpha, required=False)
parser.add_argument("--no_flash_attn", type=bool, default=self.cfg.no_flash_attn)
parser.add_argument("--low_mem", type=bool, default=self.cfg.low_mem)
parser.add_argument("--experts_per_token", type=int, default=self.cfg.experts_per_token, required=False)
parser.add_argument("--load_q4", type=bool, default=self.cfg.load_q4)
parser.add_argument("--fast_safetensors", type=bool, default=self.cfg.fast_safetensors)
parser.add_argument("--draft_model_dir", type=str, default=self.cfg.draft_model_dir, required=False)
parser.add_argument("--no_draft_scale", type=bool, default=self.cfg.no_draft_scale)
parser.add_argument("--modes", type=bool, default=self.cfg.modes)
parser.add_argument("--mode", type=str, default=self.cfg.mode)
parser.add_argument("--username", type=str, default=self.cfg.username)
parser.add_argument("--botname", type=str, default=self.cfg.botname)
parser.add_argument("--system_prompt", type=str, default=self.cfg.system_prompt, required=False)
parser.add_argument("--temperature", type=float, default=self.cfg.temperature)
parser.add_argument("--smoothing_factor", type=float, default=self.cfg.smoothing_factor)
parser.add_argument("--dynamic_temperature", type=str, default=self.cfg.dynamic_temperature, required=False)
parser.add_argument("--top_k", type=int, default=self.cfg.top_k)
parser.add_argument("--top_p", type=float, default=self.cfg.top_p)
parser.add_argument("--top_a", type=float, default=self.cfg.top_a)
parser.add_argument("--skew", type=float, default=self.cfg.skew)
parser.add_argument("--typical", type=float, default=self.cfg.typical)
parser.add_argument("--repetition_penalty", type=float, default=self.cfg.repetition_penalty)
parser.add_argument("--frequency_penalty", type=float, default=self.cfg.frequency_penalty)
parser.add_argument("--presence_penalty", type=float, default=self.cfg.presence_penalty)
parser.add_argument("--max_response_tokens", type=int, default=self.cfg.max_response_tokens)
parser.add_argument("--response_chunk", type=int, default=self.cfg.response_chunk)
parser.add_argument("--no_code_formatting", type=bool, default=self.cfg.no_code_formatting)
parser.add_argument("--cache_8bit", type=bool, default=self.cfg.cache_8bit)
parser.add_argument("--cache_q4", type=bool, default=self.cfg.cache_q4)
parser.add_argument("--ngram_decoding", type=bool, default=self.cfg.ngram_decoding)
parser.add_argument("--print_timings", type=bool, default=self.cfg.print_timings)
parser.add_argument("--amnesia", type=bool, default=self.cfg.amnesia)
parser.add_argument("--batch_size", type=int, default=self.cfg.batch_size)
parser.add_argument("--cache_lens", type=int, default=self.cfg.cache_lens)
# log configs
# log level: debug, info, warn, error, crit
parser.add_argument("--log_dir", type=str, default=self.cfg.log_dir)
parser.add_argument("--log_file", type=str, default=self.cfg.log_file)
parser.add_argument("--log_level", type=str, default=self.cfg.log_level)
parser.add_argument("--backup_count", type=int, default=self.cfg.backup_count)
# db configs
parser.add_argument("--db_type", type=str, default=self.cfg.db_type)
parser.add_argument("--db_host", type=str, default=self.cfg.db_host)
parser.add_argument("--db_port", type=str, default=self.cfg.db_port)
parser.add_argument("--db_name", type=str, default=self.cfg.db_name)
parser.add_argument("--db_pool_size", type=int, default=self.cfg.db_pool_size)
parser.add_argument("--db_database", type=str, default=self.cfg.db_database)
# user config
parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
# web config
parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
# file config
parser.add_argument("--file_upload_dir", type=str, default=self.cfg.file_upload_dir)
parser.add_argument("--assistant_store_dir", type=str, default=self.cfg.assistant_store_dir)
# local chat
parser.add_argument("--prompt_file", type=str, default=self.cfg.prompt_file)
args = parser.parse_args()
if (args.model_dir is not None):
if (args.model_path is not None):
# if pass model_dir and model_path, we use model_path
args.model_dir = args.model_path
else:
# if only pass model_dir, we use model_dir
args.model_path = args.model_dir
else:
args.model_dir = self.cfg.model_dir
args.model_path = self.cfg.model_path
# set config from args
for key, value in vars(args).items():
if value is not None and hasattr(self.cfg, key):
setattr(self.cfg, key, value)
# we add the name not match args individually
self.cfg.model_device = args.device
self.cfg.mount_web = args.web
self.cfg.server_ip = args.host
self.cfg.server_port = args.port
self.cfg.backend_type = args.type
return args

View file

@ -1,97 +1,89 @@
from pydantic import BaseModel,Field
from pydantic import BaseModel, Field
from typing import Optional
from ktransformers.server.config.config import Config
class ConfigArgs(BaseModel):
model_name : Optional[str] = Field(..., description="Model name")
model_name: Optional[str] = Field(..., description="Model name")
model_dir: Optional[str] = Field(..., description="Path to model directory")
optimize_config_path: Optional[str] = Field('./KTransformers/optimize_config/DeepSeek-V2-Chat.json', description="Path of your optimize config json file")
gguf_path: Optional[str] = Field('/models/DeepSeek-Coder-V2-Instruct-GGUF/DeepSeek-Coder-V2-Instruct-Q4_K_M.gguf', description="Path of your gguf file")
optimize_config_path: Optional[str] = Field(None, description="Path of your optimize config yml file")
gguf_path: Optional[str] = Field(None, description="Path of your gguf file")
class Config:
protected_namespaces = ()
paged : bool = Field(True,description='Wether to use paged attention kv cache')
# total_context: int = Field(16384, description="Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the total to distribute dynamically over however many jobs are active at once")
total_context: int = Field(2**18, description="Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the total to distribute dynamically over however many jobs are active at once")
max_batch_size: int = Field(20 if paged else 1, description="Max number of batches to run at once, assuming the sequences will fit within total_context")
max_chunk_size: int = Field(2048, description="Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new job is started, but at the expense of overall prompt ingestion speed")
max_new_tokens: int = Field(500, description="Max new tokens per completion. For this example applies to all jobs")
json_mode: bool = Field(False, description="Use LMFE to constrain the output to JSON format. See schema and details below")
healing: bool = Field(False, description="Demonstrate token healing")
paged: bool = Field(None, description="Whether to use paged attention kv cache")
total_context: int = Field(
None,
description=(
"Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the"
" total to distribute dynamically over however many jobs are active at once"
),
)
max_batch_size: int = Field(
None, description="Max number of batches to run at once, assuming the sequences will fit within total_context"
)
max_chunk_size: int = Field(
None,
description=(
"Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new"
" job is started, but at the expense of overall prompt ingestion speed"
),
)
max_new_tokens: int = Field(None, description="Max new tokens per completion. For this example applies to all jobs")
json_mode: bool = Field(
None, description="Use LMFE to constrain the output to JSON format. See schema and details below"
)
healing: bool = Field(None, description="Demonstrate token healing")
ban_strings: Optional[list] = Field(None, description="Ban some phrases maybe")
gpu_split: Optional[str] = Field(None, description='"auto", or VRAM allocation per GPU in GB')
length: Optional[int] = Field(None, description="Maximum sequence length")
rope_scale: Optional[float] = Field(None, description="RoPE scaling factor")
rope_alpha: Optional[float] = Field(None, description="RoPE alpha value (NTK)")
no_flash_attn: bool = Field(False, description="Disable Flash Attention")
low_mem: bool = Field(
False,
description="Enable VRAM optimizations, potentially trading off speed",
)
no_flash_attn: bool = Field(None, description="Disable Flash Attention")
low_mem: bool = Field(None, description="Enable VRAM optimizations, potentially trading off speed")
experts_per_token: Optional[int] = Field(
None,
description="Override MoE model's default number of experts per token",
)
load_q4: bool = Field(False, description="Load weights in Q4 mode")
fast_safetensors: bool = Field(
False,
description="Optimized safetensors loading with direct I/O (experimental!)",
None, description="Override MoE model's default number of experts per token"
)
load_q4: bool = Field(None, description="Load weights in Q4 mode")
fast_safetensors: bool = Field(None, description="Optimized safetensors loading with direct I/O (experimental!)")
draft_model_dir: Optional[str] = Field(None, description="Path to draft model directory")
no_draft_scale: bool = Field(
False,
None,
description="If draft model has smaller context size than model, don't apply alpha (NTK) scaling to extend it",
)
modes: bool = Field(False, description="List available modes and exit.")
mode: str = Field(
"llama",
description="Chat mode. Use llama for Llama 1/2 chat finetunes.",
)
username: str = Field("User", description="Username when using raw chat mode")
botname: str = Field("Chatbort", description="Bot name when using raw chat mode")
modes: bool = Field(None, description="List available modes and exit.")
mode: str = Field(None, description="Chat mode. Use llama for Llama 1/2 chat finetunes.")
username: str = Field(None, description="Username when using raw chat mode")
botname: str = Field(None, description="Bot name when using raw chat mode")
system_prompt: Optional[str] = Field(None, description="Use custom system prompt")
temperature: float = Field(0.95, description="Sampler temperature, default = 0.95 (1 to disable)")
smoothing_factor: float = Field(0.0, description="Smoothing Factor, default = 0.0 (0 to disable)")
temperature: float = Field(None, description="Sampler temperature, default = 0.95 (1 to disable)")
smoothing_factor: float = Field(None, description="Smoothing Factor, default = 0.0 (0 to disable)")
dynamic_temperature: Optional[str] = Field(
None,
description="Dynamic temperature min,max,exponent, e.g. -dyntemp 0.2,1.5,1",
None, description="Dynamic temperature min,max,exponent, e.g. -dyntemp 0.2,1.5,1"
)
top_k: int = Field(50, description="Sampler top-K, default = 50 (0 to disable)")
top_p: float = Field(0.8, description="Sampler top-P, default = 0.8 (0 to disable)")
top_a: float = Field(0.0, description="Sampler top-A, default = 0.0 (0 to disable)")
skew: float = Field(0.0, description="Skew sampling, default = 0.0 (0 to disable)")
typical: float = Field(
0.0,
description="Sampler typical threshold, default = 0.0 (0 to disable)",
)
repetition_penalty: float = Field(
1.01,
description="Sampler repetition penalty, default = 1.01 (1 to disable)",
)
frequency_penalty: float = Field(
0.0,
description="Sampler frequency penalty, default = 0.0 (0 to disable)",
)
presence_penalty: float = Field(
0.0,
description="Sampler presence penalty, default = 0.0 (0 to disable)",
)
max_response_tokens: int = Field(300, description="Max tokens per response, default = 1000")
response_chunk: int = Field(250, description="Space to reserve in context for reply, default = 250")
no_code_formatting: bool = Field(False, description="Disable code formatting/syntax highlighting")
cache_8bit: bool = Field(False, description="Use 8-bit (FP8) cache")
cache_q4: bool = Field(True, description="Use Q4 cache")
ngram_decoding: bool = Field(False, description="Use n-gram speculative decoding")
print_timings: bool = Field(False, description="Output timings after each prompt")
amnesia: bool = Field(False, description="Forget context after every response")
top_k: int = Field(None, description="Sampler top-K, default = 50 (0 to disable)")
top_p: float = Field(None, description="Sampler top-P, default = 0.8 (0 to disable)")
top_a: float = Field(None, description="Sampler top-A, default = 0.0 (0 to disable)")
skew: float = Field(None, description="Skew sampling, default = 0.0 (0 to disable)")
typical: float = Field(None, description="Sampler typical threshold, default = 0.0 (0 to disable)")
repetition_penalty: float = Field(None, description="Sampler repetition penalty, default = 1.01 (1 to disable)")
frequency_penalty: float = Field(None, description="Sampler frequency penalty, default = 0.0 (0 to disable)")
presence_penalty: float = Field(None, description="Sampler presence penalty, default = 0.0 (0 to disable)")
max_response_tokens: int = Field(None, description="Max tokens per response, default = 1000")
response_chunk: int = Field(None, description="Space to reserve in context for reply, default = 250")
no_code_formatting: bool = Field(None, description="Disable code formatting/syntax highlighting")
cache_8bit: bool = Field(None, description="Use 8-bit (FP8) cache")
cache_q4: bool = Field(None, description="Use Q4 cache")
ngram_decoding: bool = Field(None, description="Use n-gram speculative decoding")
print_timings: bool = Field(None, description="Output timings after each prompt")
amnesia: bool = Field(None, description="Forget context after every response")
# for transformers
batch_size :int = Field(1,description="Batch Size")
cache_lens:int = Field(4096, description="Cache lens for transformers static cache")
device:str = Field('cuda:2',description="device")
batch_size: int = Field(None, description="Batch Size")
cache_lens: int = Field(None, description="Cache lens for transformers static cache")
device: str = Field(None, description="device")
cfg = Config()
default_args = ConfigArgs(model_name=cfg.model_name,model_dir=cfg.model_path)
default_args = cfg

View file

@ -1,6 +1,12 @@
import torch
from transformers import AutoTokenizer, AutoConfig, GenerationConfig
from ktransformers.server.backend.interfaces.transformers import TransformersInterface,ConfigArgs, TransformersThreadContext,default_args,TextStreamer
from ktransformers.server.backend.interfaces.transformers import (
TransformersInterface,
ConfigArgs,
TransformersThreadContext,
default_args,
TextStreamer,
)
from ktransformers.server.config.log import logger
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.custom_cache import StaticCache
@ -14,71 +20,85 @@ class KTransformersThreadContext(TransformersThreadContext):
class KTransformersInterface(TransformersInterface):
def __init__(self,args:ConfigArgs= default_args):
def __init__(self, args: ConfigArgs = default_args):
self.args = args
torch.set_default_dtype(torch.bfloat16)
torch.set_grad_enabled(False)
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir,device = args.device)
config=AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device)
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
if config.architectures[0] == "Qwen2MoeForCausalLM":
config._attn_implementation="flash_attention_2"
config._attn_implementation = "flash_attention_2"
with torch.device("meta"):
self.model=custom_models[config.architectures[0]](config)
self.model = custom_models[config.architectures[0]](config)
if default_args.optimize_config_path is None:
optimize_rule_path = default_optimize_rules[config.architectures[0]]
else:
optimize_rule_path = args.optimize_config_path
# print(optimize_config)
gguf_path = args.gguf_path
if gguf_path is None:
gguf_path = input(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
" belong to current model):"
)
optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
device_map = self.model.gguf_loader.tensor_device_map
logger.info(f'{args.model_name} loaded from {args.model_dir} to {device_map}')
self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=device_map, dtype=self.model.dtype)
logger.info(f'StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}')
logger.info(f"{args.model_name} loaded from {args.model_dir} to {device_map}")
self.cache = StaticCache(
config=self.model.config,
max_batch_size=args.batch_size,
max_cache_len=args.cache_lens,
device=device_map,
dtype=self.model.dtype,
)
logger.info(f"StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}")
self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
if self.model.generation_config.pad_token_id is None:
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
self.streamer = TextStreamer(self.tokenizer)
def decode_one_tokens(self):
if not hasattr(self, "cuda_graph_runner"):
device_map = self.model.gguf_loader.tensor_device_map
torch_device = get_device('blk.0.self_attn', device_map)
torch_device = get_device("blk.0.self_attn", device_map)
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
self.cuda_graph_runner = CUDAGraphRunner()
self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, main_device=torch_device, return_dict=False, use_cache=True)
self.cuda_graph_runner.capture(
self.model,
self.current_ids,
self.active_cache_position.unsqueeze(0),
self.active_cache_position,
self.cache,
main_device=torch_device,
return_dict=False,
use_cache=True,
)
if hasattr(self, "cuda_graph_runner"):
logits = self.cuda_graph_runner(self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position)
logits = self.cuda_graph_runner(
self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
)
self.cache.change_seq_length(1)
torch.cuda.synchronize()
logits = logits[0,-1,:]
logits = logits[0, -1, :]
return self.logits_to_token(logits)
if self.use_static_cache:
mask = torch.ones((1,self.seq_length)).to(torch_device)
mask = torch.ones((1, self.seq_length)).to(torch_device)
logits = self.model(
self.current_ids,
cache_position=self.active_cache_position,
past_key_values=self.cache,
attention_mask=mask,
return_dict=False,
use_cache=True
use_cache=True,
)[0]
else:
logits = self.model(
self.current_ids,
return_dict=False
)[0]
logits = logits[0,-1,:]
logits = self.model(self.current_ids, return_dict=False)[0]
logits = logits[0, -1, :]
return self.logits_to_token(logits)

View file

@ -1,14 +1,22 @@
from typing import Any, List, Optional, Set
from transformers import LlamaTokenizer,AutoTokenizer, AutoConfig, LlamaForCausalLM,GenerationConfig, StaticCache, AutoModelForCausalLM,BitsAndBytesConfig
from transformers import (
LlamaTokenizer,
AutoTokenizer,
AutoConfig,
LlamaForCausalLM,
GenerationConfig,
StaticCache,
AutoModelForCausalLM,
BitsAndBytesConfig,
)
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.utils.multi_timer import Profiler
import torch
import sys, os
from ..base import ThreadContext,BackendInterfaceBase
from ..base import ThreadContext, BackendInterfaceBase
from ktransformers.server.config.log import logger
from ..args import ConfigArgs,default_args
from ..args import ConfigArgs, default_args
# This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
@ -28,21 +36,20 @@ class TextStreamer:
self.token_cache = []
self.print_len = 0
def put(self, value)->Optional[str]:
def put(self, value) -> Optional[str]:
"""
Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
"""
if not isinstance(value,int):
"""
if not isinstance(value, int):
raise ValueError("TextStreamer only supports batch size 1, and int type input")
if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
return None
# Add the new token to the cache and decodes the entire thing.
self.token_cache.append(value)
text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True,**self.decode_kwargs)
text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)
# After the symbol for a new line, we flush the cache.
if text.endswith("\n"):
@ -59,7 +66,7 @@ class TextStreamer:
self.print_len += len(printable_text)
return printable_text
def end(self)->Optional[str]:
def end(self) -> Optional[str]:
"""Flushes any remaining cache and prints a newline to stdout."""
# Flush the cache, if it exists
if len(self.token_cache) > 0:
@ -71,7 +78,7 @@ class TextStreamer:
self.next_tokens_are_prompt = True
return printable_text
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
@ -97,81 +104,81 @@ class TextStreamer:
return False
class TransformersThreadContext(ThreadContext):
class TransformersThreadContext(ThreadContext):
def get_local_messages(self):
local_messages = []
for m in self.messages:
local_messages.append(
{'role':m.role.value,
'content':m.get_text_content()}
)
local_messages.append({"role": m.role.value, "content": m.get_text_content()})
return local_messages
class TransformersInterface(BackendInterfaceBase):
use_static_cache : bool = True
use_static_cache: bool = True
model: Any
tokenizer: AutoTokenizer
cache: StaticCache
generated_ids:torch.Tensor
seq_length:int
generated_ids: torch.Tensor
seq_length: int
streamer: TextStreamer
# thread_related
last_request_id: Optional[str] = None
ever_generated_ids: Set[int] = set()
def __init__(self, args:ConfigArgs = default_args):
def __init__(self, args: ConfigArgs = default_args):
self.args = args
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device,use_safetensors=True)
logger.info(f'{args.model_name} loaded from {args.model_dir} to {args.device}')
self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=args.device, dtype=self.model.dtype)
logger.info(f'StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}')
self.streamer = TextStreamer(self.tokenizer)
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device, use_safetensors=True)
logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}")
self.cache = StaticCache(
config=self.model.config,
max_batch_size=args.batch_size,
max_cache_len=args.cache_lens,
device=args.device,
dtype=self.model.dtype,
)
logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}")
self.streamer = TextStreamer(self.tokenizer)
@property
def current_ids(self):
return self.generated_ids[:,self.seq_length-1].unsqueeze(1)
return self.generated_ids[:, self.seq_length - 1].unsqueeze(1)
@property
def active_cache_position(self):
return torch.tensor([self.seq_length-1], device=self.args.device)
return torch.tensor([self.seq_length - 1], device=self.args.device)
def tokenize_prompt(self,prompt:str):
input_ids = self.tokenizer.encode(prompt,return_tensors='pt').to(self.args.device)
def tokenize_prompt(self, prompt: str):
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device)
return input_ids
def format_and_tokenize_input_ids(self,thread_id:ObjectID,messages:List):
def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
for m in messages:
if m['role']=='system':
logger.warn(f'change {m["role"]} to user')
m['role'] = 'user'
if m["role"] == "system":
logger.warning(f'change {m["role"]} to user')
m["role"] = "user"
new_messages = [messages[0]]
for m in messages[1:]:
if m['role'] == 'user' and new_messages[-1]['role']=='user':
logger.warn('merge two adjacent user messages')
new_messages[-1]['content']+=m['content']
for m in messages[1:]:
if m["role"] == "user" and new_messages[-1]["role"] == "user":
logger.warning("merge two adjacent user messages")
new_messages[-1]["content"] += m["content"]
else:
new_messages.append(m)
new_messages.append(m)
# if (self.last_request_id is not None) and self.last_request_id == thread_id:
# input_ids = self.tokenizer.encode(self.tokenizer.eos_token+self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt",tokenize=False, add_generation_prompt=True), add_special_tokens = False, return_tensors="pt").to(self.args.device)
# else:
# input_ids = self.tokenizer.apply_chat_template(
# new_messages, return_tensors="pt", add_generation_prompt=True
# ).to(self.args.device)
input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device)
if (self.last_request_id is not None) and self.last_request_id == thread_id:
x = self.generated_ids[:,:self.seq_length]
y = input_ids[:,:self.seq_length]
@ -179,19 +186,19 @@ class TransformersInterface(BackendInterfaceBase):
unequal_mask = torch.ne(x,y)
unequal_positions = torch.nonzero(unequal_mask)
num_unequal_elements = unequal_mask.sum().item()
logger.warn(f'num_unequal_elements: {num_unequal_elements}')
logger.warning(f'num_unequal_elements: {num_unequal_elements}')
input_ids = input_ids[:,self.seq_length:]
logger.debug(f'get input ids of shape {input_ids.shape}')
logger.debug(f"get input ids of shape {input_ids.shape}")
return input_ids
def append_new_tokens(self,new_tokens:int)->Optional[str]:
self.generated_ids[0,self.seq_length] = new_tokens
self.seq_length+=1
def append_new_tokens(self, new_tokens: int) -> Optional[str]:
self.generated_ids[0, self.seq_length] = new_tokens
self.seq_length += 1
return self.streamer.put(new_tokens)
def logits_to_token(self,logits:torch.Tensor):
logits = logits/self.args.temperature
def logits_to_token(self, logits: torch.Tensor):
logits = logits / self.args.temperature
for token_idx in self.ever_generated_ids:
if logits[token_idx] < 0:
@ -200,7 +207,7 @@ class TransformersInterface(BackendInterfaceBase):
logits[token_idx] /= self.args.repetition_penalty
probs = torch.nn.functional.softmax(logits, dim=-1)
sample = True
if sample:
last = torch.multinomial(probs, num_samples=1)
@ -211,127 +218,124 @@ class TransformersInterface(BackendInterfaceBase):
self.ever_generated_ids.add(last)
return last
def decode_one_tokens(self):
if self.use_static_cache:
mask = torch.ones((1,self.seq_length)).to(self.args.device)
mask = torch.ones((1, self.seq_length)).to(self.args.device)
logits = self.model(
self.current_ids,
cache_position=self.active_cache_position,
past_key_values=self.cache,
attention_mask=mask,
return_dict=False,
use_cache=True
use_cache=True,
)[0]
else:
logits = self.model(
self.current_ids,
return_dict=False
)[0]
logits = logits[0,-1,:]
logits = self.model(self.current_ids, return_dict=False)[0]
logits = logits[0, -1, :]
return self.logits_to_token(logits)
@torch.no_grad
def prefill(self,input_ids:torch.Tensor,is_new:bool):
def prefill(self, input_ids: torch.Tensor, is_new: bool):
input_ids_length = input_ids.shape[-1]
self.profiler.set_counter('prefill',input_ids_length)
logger.debug(f'input_ids: {input_ids.shape}')
self.profiler.set_counter("prefill", input_ids_length)
logger.debug(f"input_ids: {input_ids.shape}")
if is_new:
self.cache.reset()
self.ever_generated_ids.clear()
former_seq_length = 0
self.seq_length = input_ids_length
self.generated_ids = torch.zeros(
self.args.batch_size, self.seq_length + self.args.max_new_tokens + 1, dtype=torch.int, device=self.args.device
)
self.args.batch_size,
self.seq_length + self.args.max_new_tokens + 1,
dtype=torch.int,
device=self.args.device,
)
else:
logger.debug(f'generate_ids: {self.generated_ids.shape}')
logger.debug(f"generate_ids: {self.generated_ids.shape}")
former_seq_length = self.seq_length
self.seq_length += input_ids_length
expected_length = self.seq_length + self.args.max_new_tokens+1
expected_length = self.seq_length + self.args.max_new_tokens + 1
delta_length = expected_length - self.generated_ids.shape[-1]
if delta_length>0:
if delta_length > 0:
new_generate_ids = torch.zeros(
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
)
self.generated_ids = torch.cat([self.generated_ids,new_generate_ids],dim=-1)
logger.debug(f'cache position: {former_seq_length} to {self.seq_length}')
cache_position = torch.arange(former_seq_length,self.seq_length, device=self.args.device)
self.generated_ids[:,cache_position] = input_ids.to(self.args.device).to(torch.int)
self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
mask = torch.ones((1,self.seq_length)).to(self.args.device)
mask = torch.ones((1, self.seq_length)).to(self.args.device)
device = input_ids.device
if not(type(self) is TransformersInterface):
if not (type(self) is TransformersInterface):
input_ids = input_ids.to("cpu")
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
if self.use_static_cache:
logits = self.model(
inputs_embeds=inputs_embeds, cache_position=cache_position, past_key_values=self.cache,return_dict=False, use_cache=True,attention_mask=mask
inputs_embeds=inputs_embeds,
cache_position=cache_position,
past_key_values=self.cache,
return_dict=False,
use_cache=True,
attention_mask=mask,
)[0]
else:
logits = self.model(
inputs_embeds=inputs_embeds,return_dict=False
)[0]
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
next_token = self.logits_to_token(logits[0,-1,:])
next_token = self.logits_to_token(logits[0, -1, :])
yield self.append_new_tokens(next_token)
@torch.no_grad
def generate(self):
self.profiler.set_counter('decode',0)
self.profiler.set_counter("decode", 0)
for _ in range(1, self.args.max_new_tokens):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
next_token = self.decode_one_tokens()
self.profiler.inc('decode')
self.profiler.inc("decode")
if next_token == self.tokenizer.eos_token_id:
assert self.args.batch_size == 1
break
yield self.append_new_tokens(next_token)
yield self.streamer.end()
def check_is_new(self,thread_id:str):
def check_is_new(self, thread_id: str):
if not self.use_static_cache:
return True
if self.last_request_id is None:
self.last_request_id = thread_id
return True
else:
if self.last_request_id==thread_id:
if self.last_request_id == thread_id:
return False
else:
self.last_request_id = thread_id
return True
async def inference(self,local_messages,thread_id:str):
self.profiler.create_and_start_timer('tokenize')
if isinstance(local_messages,List):
input_ids = self.format_and_tokenize_input_ids(thread_id,local_messages)
elif isinstance(local_messages,str):
async def inference(self, local_messages, thread_id: str):
self.profiler.create_and_start_timer("tokenize")
if isinstance(local_messages, List):
input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
elif isinstance(local_messages, str):
input_ids = self.tokenize_prompt(local_messages)
else:
raise ValueError('local_messages should be List or str')
raise ValueError("local_messages should be List or str")
self.profiler.pause_timer('tokenize')
self.profiler.pause_timer("tokenize")
self.profiler.create_and_start_timer('prefill')
for t in self.prefill(input_ids,self.check_is_new(thread_id)):
self.profiler.create_and_start_timer("prefill")
for t in self.prefill(input_ids, self.check_is_new(thread_id)):
if t is not None:
print(t,end='')
print(t, end="")
yield t
self.profiler.pause_timer('prefill')
self.profiler.pause_timer("prefill")
self.profiler.create_and_start_timer('decode')
self.profiler.create_and_start_timer("decode")
for t in self.generate():
if t is not None:
print(t,end='')
print(t, end="")
yield t
print('')
self.profiler.pause_timer('decode')
print("")
self.profiler.pause_timer("decode")
self.report_last_time_performance()

View file

@ -1,23 +1,24 @@
#!/usr/bin/env python
# coding=utf-8
'''
Description :
"""
Description :
Author : unicornchan
Date : 2024-06-11 16:35:42
Version : 1.0.0
LastEditors : WuHao
LastEditors : WuHao
LastEditTime : 2024-08-12 06:31:14
'''
"""
import os
import shutil
import yaml
from ktransformers.server.config.singleton import Singleton
from typing import Optional
class Config(metaclass=Singleton):
"""Singleton pattern Config class, used to get all configurations.
"""
"""Singleton pattern Config class, used to get all configurations."""
CONFIG_FILE_NAME = "config.yaml"
@staticmethod
@ -27,22 +28,20 @@ class Config(metaclass=Singleton):
Returns:
dict: all configs
"""
base_path: str = os.path.dirname(
os.path.dirname(os.path.dirname(__file__)))
config_yaml: str = os.path.join(
base_path, "configs", Config.CONFIG_FILE_NAME)
user_path: str = os.path.expanduser('~')
localstore_path: str = os.path.join(user_path,'.ktransformers')
config_path: str = os.path.join(localstore_path,Config.CONFIG_FILE_NAME)
base_path: str = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
config_yaml: str = os.path.join(base_path, "configs", Config.CONFIG_FILE_NAME)
user_path: str = os.path.expanduser("~")
localstore_path: str = os.path.join(user_path, ".ktransformers")
config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME)
if not os.path.exists(config_yaml):
print(f"Can't find config file, {config_yaml}")
exit(-1)
if not os.path.exists(localstore_path):
os.mkdir(localstore_path)
if not os.path.exists(config_path):
shutil.copyfile(config_yaml,config_path)
with open(config_path, 'r', encoding="utf-8") as fp:
shutil.copyfile(config_yaml, config_path)
with open(config_path, "r", encoding="utf-8") as fp:
config = yaml.safe_load(fp)
return config
@ -52,16 +51,14 @@ class Config(metaclass=Singleton):
process file path
"""
base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
real_path = path if os.path.isabs(
path) else os.path.join(base_path, path)
real_path = path if os.path.isabs(path) else os.path.join(base_path, path)
return real_path
def __init__(self):
cfg = Config.load()
self.base_path = os.path.dirname(
os.path.dirname(os.path.dirname(__file__)))
self.user_path: str = os.path.expanduser('~')
self.localstore_path: str = os.path.join(self.user_path,'.ktransformers')
self.base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
self.user_path: str = os.path.expanduser("~")
self.localstore_path: str = os.path.join(self.user_path, ".ktransformers")
# log configs
self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"]))
self.log_file = cfg["log"]["file"]
@ -69,7 +66,7 @@ class Config(metaclass=Singleton):
self.backup_count = cfg["log"]["backup_count"]
# server configs
self.server: dict = cfg.get("server",{})
self.server: dict = cfg.get("server", {})
self.server_ip = self.server.get("ip", "0.0.0.0")
self.server_port = self.server.get("port", 9016)
@ -86,16 +83,68 @@ class Config(metaclass=Singleton):
self.user_config: dict = cfg.get("user", {})
self.user_secret_key = self.user_config.get("secret_key", "")
self.user_algorithm = self.user_config.get("algorithm", "")
# model config
self.model:dict = cfg.get("model", {})
self.model: dict = cfg.get("model", {})
self.backend_type: str = self.model.get("type", "transformers")
self.model_path: str = self.model.get("path", "")
self.model_dir: str = self.model.get("path", "")
# to make sure it consistent with previous version
self.model_path: str = self.model_dir
self.model_name: str = self.model.get("name", "")
self.model_device: str = self.model.get("device", "cuda:0")
self.gguf_path: str = self.model.get("gguf_path", "")
self.model_cache_lens = self.model.get("cache_lens")
self.gguf_path: Optional[str] = self.model.get("gguf_path", None)
# self.model_cache_lens = self.model.get("cache_lens")
self.optimize_config_path: Optional[str] = self.model.get(
"optimize_config_path", "./ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml"
)
self.paged = self.model.get("paged", True)
self.total_context = self.model.get("total_context", 2**18)
self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
self.max_chunk_size = self.model.get("max_chunk_size", 2048)
self.max_new_tokens = self.model.get("max_new_tokens", 500)
self.json_mode = self.model.get("json_mode", False)
self.healing = self.model.get("healing", False)
self.ban_strings: Optional[list] = self.model.get("ban_strings", None)
self.gpu_split: Optional[str] = self.model.get("gpu_split", None)
self.length: Optional[int] = self.model.get("length", None)
self.rope_scale: Optional[float] = self.model.get("rope_scale", None)
self.rope_alpha: Optional[float] = self.model.get("rope_alpha", None)
self.no_flash_attn = self.model.get("no_flash_attn", False)
self.low_mem = self.model.get("low_mem", False)
self.experts_per_token: Optional[int] = self.model.get("experts_per_token", None)
self.load_q4 = self.model.get("load_q4", False)
self.fast_safetensors = self.model.get("fast_safetensors", False)
self.draft_model_dir: Optional[str] = self.model.get("draft_model_dir", None)
self.no_draft_scale = self.model.get("no_draft_scale", False)
self.modes = self.model.get("modes", False)
self.mode = self.model.get("mode", "llama")
self.username = self.model.get("username", "User")
self.botname = self.model.get("botname", "Chatbort")
self.system_prompt: Optional[str] = self.model.get("system_prompt", None)
self.temperature = self.model.get("temperature", 0.95)
self.smoothing_factor = self.model.get("smoothing_factor", 0.0)
self.dynamic_temperature: Optional[str] = self.model.get("dynamic_temperature", None)
self.top_k = self.model.get("top_k", 50)
self.top_p = self.model.get("top_p", 0.8)
self.top_a = self.model.get("top_a", 0.0)
self.skew = self.model.get("skew", 0.0)
self.typical = self.model.get("typical", 0.0)
self.repetition_penalty = self.model.get("repetition_penalty", 1.01)
self.frequency_penalty = self.model.get("frequency_penalty", 0.0)
self.presence_penalty = self.model.get("presence_penalty", 0.0)
self.max_response_tokens = self.model.get("max_response_tokens", 300)
self.response_chunk = self.model.get("response_chunk", 250)
self.no_code_formatting = self.model.get("no_code_formatting", False)
self.cache_8bit = self.model.get("cache_8bit", False)
self.cache_q4 = self.model.get("cache_q4", True)
self.ngram_decoding = self.model.get("ngram_decoding", False)
self.print_timings = self.model.get("print_timings", False)
self.amnesia = self.model.get("amnesia", False)
self.batch_size = self.model.get("batch_size", 1)
self.cache_lens = self.model.get("cache_lens", 4096)
self.device = self.model.get("device", "cuda:2")
# web config
self.web: dict = cfg.get("web", {})
self.web_cross_domain: bool = self.web.get("open_cross_domain", True)
@ -104,10 +153,32 @@ class Config(metaclass=Singleton):
self.ext: dict = cfg.get("ext", {})
self.cpu_infer = self.ext.get("cpu_infer", 10)
#file config
self.local_store_configs: dict = cfg.get("local_store",{})
self.file_upload_dir: str = os.path.join(self.localstore_path,self.local_store_configs.get("file_upload_dir",""))
self.assistant_store_dir: str = os.path.join(self.localstore_path,self.local_store_configs.get("assistant_store_dir",""))
# file config
self.local_store_configs: dict = cfg.get("local_store", {})
self.file_upload_dir: str = os.path.join(
self.localstore_path, self.local_store_configs.get("file_upload_dir", "")
)
self.assistant_store_dir: str = os.path.join(
self.localstore_path, self.local_store_configs.get("assistant_store_dir", "")
)
#long context config
self.long_context_config: dict = cfg.get("long_context",{})
# long context config
self.long_context_config: dict = cfg.get("long_context", {})
self.chunk_size = self.long_context_config.get("chunk_size", 4096)
self.max_seq_len = self.long_context_config.get("max_seq_len", 32000)
self.block_size = self.long_context_config.get("block_size", 128)
self.local_windows_len = self.long_context_config.get("local_windows_len", 4096)
self.second_select_num = self.long_context_config.get("second_select_num", 32)
self.anchor_type = self.long_context_config.get("anchor_type", "DYNAMIC")
self.kv_type = self.long_context_config.get("kv_type", "FP16")
self.dense_layer_num = self.long_context_config.get("dense_layer_num", 2)
self.anchor_num = self.long_context_config.get("anchor_num", 1)
self.preselect_block = self.long_context_config.get("preselect_block", True)
self.head_select_mode = self.long_context_config.get("head_select_mode", "SHARED")
self.preselect_block_count = self.long_context_config.get("preselect_block_count", 32)
self.layer_step = self.long_context_config.get("layer_step", 1)
self.token_step = self.long_context_config.get("token_step", 100)
# local chat
self.local_chat_config: dict = cfg.get("local_chat", {})
self.prompt_file = self.local_chat_config.get("prompt_file", None)

View file

@ -3,11 +3,15 @@ import re
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import uvicorn.logging
import argparse
import uvicorn
import sys
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
sys.path.insert(0, project_dir)
from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.args import ArgumentParser
from ktransformers.server.config.config import Config
from ktransformers.server.utils.create_interface import create_interface
from ktransformers.server.utils.create_interface import create_interface
from ktransformers.server.backend.args import default_args
from fastapi.openapi.utils import get_openapi
@ -44,8 +48,11 @@ def create_app():
mount_index_routes(app)
return app
def update_web_port(config_file: str):
ip_port_pattern = r"(localhost|((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)):[0-9]{1,5}"
ip_port_pattern = (
r"(localhost|((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)):[0-9]{1,5}"
)
with open(config_file, "r", encoding="utf-8") as f_cfg:
web_config = f_cfg.read()
ip_port = "localhost:" + str(Config().server_port)
@ -70,14 +77,15 @@ def mount_index_routes(app: FastAPI):
def run_api(app, host, port, **kwargs):
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
uvicorn.run(app,
host=host,
port=port,
ssl_keyfile=kwargs.get("ssl_keyfile"),
ssl_certfile=kwargs.get("ssl_certfile"),
)
uvicorn.run(
app,
host=host,
port=port,
ssl_keyfile=kwargs.get("ssl_keyfile"),
ssl_certfile=kwargs.get("ssl_certfile"),
)
else:
uvicorn.run(app, host=host, port=port, log_level='debug')
uvicorn.run(app, host=host, port=port, log_level="debug")
def custom_openapi(app):
@ -90,53 +98,27 @@ def custom_openapi(app):
description="We provided chat completion and openai assistant interfaces.",
routes=app.routes,
)
openapi_schema["info"]["x-logo"] = {
"url": "https://kvcache.ai/media/icon_1.png"
}
openapi_schema["info"]["x-logo"] = {"url": "https://kvcache.ai/media/icon_1.png"}
app.openapi_schema = openapi_schema
return app.openapi_schema
def main():
cfg = Config()
parser = argparse.ArgumentParser(prog='kvcache.ai',
description='Ktransformers')
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=cfg.server_port)
parser.add_argument("--ssl_keyfile", type=str)
parser.add_argument("--ssl_certfile", type=str)
parser.add_argument("--web", type=bool, default=False)
parser.add_argument("--model_name", type=str, default=cfg.model_name)
parser.add_argument("--model_path", type=str, default=cfg.model_path)
parser.add_argument("--device", type=str, default=cfg.model_device, help="Warning: Abandoning this parameter")
parser.add_argument("--gguf_path", type=str, default=cfg.gguf_path)
parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
parser.add_argument("--cpu_infer", type=int, default=cfg.cpu_infer)
parser.add_argument("--type", type=str, default=cfg.backend_type)
arg_parser = ArgumentParser(cfg)
# 初始化消息
args = parser.parse_args()
cfg.model_name = args.model_name
cfg.model_path = args.model_path
cfg.model_device = args.device
cfg.mount_web = args.web
cfg.server_ip = args.host
cfg.server_port = args.port
cfg.cpu_infer = args.cpu_infer
cfg.backend_type = args.type
default_args.model_dir = args.model_path
default_args.device = args.device
default_args.gguf_path = args.gguf_path
default_args.optimize_config_path = args.optimize_config_path
args = arg_parser.parse_args()
app = create_app()
custom_openapi(app)
create_interface(config=cfg, default_args=default_args)
run_api(app=app,
host=args.host,
port=args.port,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,)
create_interface(config=cfg, default_args=cfg)
run_api(
app=app,
host=args.host,
port=args.port,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
)
if __name__ == "__main__":

View file

@ -4412,8 +4412,9 @@
},
"node_modules/@vue/cli": {
"version": "5.0.8",
"resolved": "https://registry.npmmirror.com/@vue/cli/-/cli-5.0.8.tgz",
"resolved": "https://registry.npmjs.org/@vue/cli/-/cli-5.0.8.tgz",
"integrity": "sha512-c/QKPdC09bYkW22m/boXkLaiz10z0Z2WHZO7zEeNdfSduqyWINZhKc6hVQU3Vk0NXW7BJAd7zWmcUrC8L9TuAA==",
"license": "MIT",
"dependencies": {
"@types/ejs": "^3.0.6",
"@types/inquirer": "^8.1.3",

View file

@ -69,4 +69,8 @@ ktransformers = "ktransformers.server.main:main"
[tool.setuptools.packages.find]
where = ["./", ]
include = ["ktransformers"]
include = ["ktransformers"]
[tool.black]
line-length = 120
preview = true
unstable = true