mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-05 12:09:48 +00:00
Merge pull request #110 from KMSorSMS/main
refactor local_chat & config setting
This commit is contained in:
commit
dddc42038d
17 changed files with 652 additions and 400 deletions
4
.flake8
Normal file
4
.flake8
Normal file
|
@ -0,0 +1,4 @@
|
|||
[flake8]
|
||||
max-line-length = 120
|
||||
extend-select = B950
|
||||
extend-ignore = E203,E501,E701, B001,B006,B007,B008,B009,B010,B011,B016,B028,B031,B950,E265,E266,E401,E402,E711,E712,E713,E721,E722,E731,F401,F403,F405,F541,F811,F821,F841,W391
|
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -17,4 +17,5 @@ compile_commands.json
|
|||
*dist/
|
||||
ktransformers/server/local_store/
|
||||
ktransformers/server_test1.db
|
||||
*.patch
|
||||
*.patch
|
||||
img/
|
21
Makefile
Normal file
21
Makefile
Normal file
|
@ -0,0 +1,21 @@
|
|||
flake_find:
|
||||
cd ktransformers && flake8 | grep -Eo '[A-Z][0-9]{3}' | sort | uniq| paste -sd ',' -
|
||||
format:
|
||||
@cd ktransformers && black .
|
||||
@black setup.py
|
||||
dev_install:
|
||||
# clear build dirs
|
||||
rm -rf build
|
||||
rm -rf *.egg-info
|
||||
rm -rf ktransformers/ktransformers_ext/build
|
||||
rm -rf ktransformers/ktransformers_ext/cuda/build
|
||||
rm -rf ktransformers/ktransformers_ext/cuda/dist
|
||||
rm -rf ktransformers/ktransformers_ext/cuda/*.egg-info
|
||||
|
||||
# install ktransformers
|
||||
echo "Installing python dependencies from requirements.txt"
|
||||
pip install -r requirements-local_chat.txt
|
||||
|
||||
echo "Installing ktransformers"
|
||||
KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . --no-build-isolation
|
||||
echo "Installation completed successfully"
|
|
@ -151,7 +151,7 @@ Some preparation:
|
|||
```
|
||||
install.bat
|
||||
```
|
||||
|
||||
4. If you are developer, you can make use of the makefile to compile and format the code. <br> the detailed usage of makefile is [here](./doc/en/makefile_usage.md)
|
||||
<h3>Local Chat</h3>
|
||||
We provide a simple command-line local chat Python script that you can run for testing.
|
||||
|
||||
|
|
26
doc/en/makefile_usage.md
Normal file
26
doc/en/makefile_usage.md
Normal file
|
@ -0,0 +1,26 @@
|
|||
# Makefile
|
||||
## Target
|
||||
### flake_find:
|
||||
```bash
|
||||
make flake_find
|
||||
```
|
||||
find all the python files under ./ktransformers dir and find the Error, Warning, Fatal... (their codes) into a list that are not consistent with the pep8 standard. For now we have get all this list in the .flake8 file's extend-ignore section in order to let flakes8 ignore them temporarily.(we may improve them in the future)
|
||||
### format:
|
||||
```bash
|
||||
make format
|
||||
```
|
||||
we use black to format all the python files under ./ktransformers dir. It obeys the pep8 standard
|
||||
but we modify the line length to 120 by add
|
||||
```toml
|
||||
[tool.black]
|
||||
line-length = 120
|
||||
preview = true
|
||||
unstable = true
|
||||
```
|
||||
in the pyproject.toml file.
|
||||
|
||||
### dev_install:
|
||||
```bash
|
||||
make dev_install
|
||||
```
|
||||
install the package in the development mode. It means that the package is installed in the editable mode. So if you modify the code, you don't need to reinstall the package. We recommend the developer to use this method to install the package.
|
|
@ -7,7 +7,7 @@ log:
|
|||
|
||||
server:
|
||||
ip: 0.0.0.0
|
||||
port: 12456
|
||||
port: 10002
|
||||
|
||||
db:
|
||||
type: "sqllite"
|
||||
|
@ -24,10 +24,11 @@ model:
|
|||
type: ktransformers
|
||||
|
||||
name: DeepSeek-Coder-V2-Instruct
|
||||
path: /mnt/data/model/DeepSeek-Coder-V2-Instruct/
|
||||
gguf_path: /mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH/
|
||||
path: deepseek-ai/DeepSeek-V2-Lite-Chat
|
||||
gguf_path: ./DeepSeek-V2-Lite-Chat-GGUF
|
||||
|
||||
device: cuda:0
|
||||
cache_lens: 8192
|
||||
|
||||
web:
|
||||
mount: False
|
||||
|
@ -50,4 +51,7 @@ long_context:
|
|||
head_select_mode: SHARED
|
||||
preselect_block_count: 32
|
||||
layer_step: 1
|
||||
token_step: 100
|
||||
token_step:
|
||||
|
||||
local_chat:
|
||||
prompt_file: "./ktransformers/p.txt"
|
|
@ -1,33 +1,23 @@
|
|||
"""
|
||||
Description :
|
||||
Description :
|
||||
Author : Boxin Zhang, Azure-Tang
|
||||
Version : 0.1.0
|
||||
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
|
||||
project_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
sys.path.insert(0, project_dir)
|
||||
import torch
|
||||
import logging
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
GenerationConfig,
|
||||
TextStreamer,
|
||||
)
|
||||
import json
|
||||
import fire
|
||||
from ktransformers.optimize.optimize import optimize_and_load_gguf
|
||||
from ktransformers.server.args import ArgumentParser
|
||||
|
||||
|
||||
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
|
||||
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
|
||||
from ktransformers.models.modeling_llama import LlamaForCausalLM
|
||||
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
|
||||
from ktransformers.util.utils import prefill_and_generate
|
||||
from ktransformers.server.config.config import Config
|
||||
|
||||
custom_models = {
|
||||
|
@ -37,9 +27,7 @@ custom_models = {
|
|||
"MixtralForCausalLM": MixtralForCausalLM,
|
||||
}
|
||||
|
||||
ktransformer_rules_dir = (
|
||||
os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
|
||||
)
|
||||
ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
|
||||
default_optimize_rules = {
|
||||
"DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
|
||||
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
|
||||
|
@ -48,75 +36,28 @@ default_optimize_rules = {
|
|||
}
|
||||
|
||||
|
||||
def local_chat(
|
||||
model_path: str | None = None,
|
||||
optimize_rule_path: str = None,
|
||||
gguf_path: str | None = None,
|
||||
max_new_tokens: int = 1000,
|
||||
cpu_infer: int = Config().cpu_infer,
|
||||
use_cuda_graph: bool = True,
|
||||
prompt_file : str | None = None,
|
||||
mode: str = "normal",
|
||||
):
|
||||
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
Config().cpu_infer = cpu_infer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||
if mode == 'long_context':
|
||||
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
|
||||
torch.set_default_dtype(torch.float16)
|
||||
def local_chat():
|
||||
config = Config()
|
||||
arg_parser = ArgumentParser(config)
|
||||
# 初始化消息
|
||||
arg_parser.parse_args()
|
||||
if config.backend_type == "transformers":
|
||||
from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
|
||||
elif config.backend_type == "exllamav2":
|
||||
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
|
||||
elif config.backend_type == "ktransformers":
|
||||
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
|
||||
else:
|
||||
torch.set_default_dtype(config.torch_dtype)
|
||||
|
||||
with torch.device("meta"):
|
||||
if config.architectures[0] in custom_models:
|
||||
print("using custom modeling_xxx.py.")
|
||||
if (
|
||||
"Qwen2Moe" in config.architectures[0]
|
||||
): # Qwen2Moe must use flash_attention_2 to avoid overflow.
|
||||
config._attn_implementation = "flash_attention_2"
|
||||
if "Llama" in config.architectures[0]:
|
||||
config._attn_implementation = "eager"
|
||||
if "Mixtral" in config.architectures[0]:
|
||||
config._attn_implementation = "flash_attention_2"
|
||||
|
||||
model = custom_models[config.architectures[0]](config)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
config, trust_remote_code=True, attn_implementation="flash_attention_2"
|
||||
)
|
||||
|
||||
if optimize_rule_path is None:
|
||||
if config.architectures[0] in default_optimize_rules:
|
||||
print("using default_optimize_rule for", config.architectures[0])
|
||||
optimize_rule_path = default_optimize_rules[config.architectures[0]]
|
||||
else:
|
||||
optimize_rule_path = input(
|
||||
"please input the path of your rule file(yaml file containing optimize rules):"
|
||||
)
|
||||
|
||||
if gguf_path is None:
|
||||
gguf_path = input(
|
||||
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
|
||||
)
|
||||
optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)
|
||||
|
||||
model.generation_config = GenerationConfig.from_pretrained(model_path)
|
||||
if model.generation_config.pad_token_id is None:
|
||||
model.generation_config.pad_token_id = model.generation_config.eos_token_id
|
||||
model.eval()
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
raise NotImplementedError(f"{config.backend_type} not implemented")
|
||||
interface = BackendInterface(config)
|
||||
|
||||
system = platform.system()
|
||||
if system == "Windows":
|
||||
os.system("cls")
|
||||
else:
|
||||
os.system("clear")
|
||||
|
||||
# add a history chat content
|
||||
his_content = []
|
||||
while True:
|
||||
content = input("Chat: ")
|
||||
if content.startswith('"""'): # prefix """
|
||||
|
@ -132,28 +73,27 @@ def local_chat(
|
|||
break
|
||||
else:
|
||||
content += line + "\n"
|
||||
|
||||
if content == "":
|
||||
if prompt_file != None:
|
||||
content = open(prompt_file, "r").read()
|
||||
else:
|
||||
if config.prompt_file == None or config.prompt_file == "":
|
||||
content = "Please write a piece of quicksort code in C++."
|
||||
else:
|
||||
content = open(config.prompt_file, "r").read()
|
||||
elif os.path.isfile(content):
|
||||
content = open(content, "r").read()
|
||||
messages = [{"role": "user", "content": content}]
|
||||
input_tensor = tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
if mode == 'long_context':
|
||||
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
|
||||
"please change max_seq_len in ~/.ktransformers/config.yaml"
|
||||
torch.set_default_dtype(
|
||||
torch.bfloat16
|
||||
) # TODO: Remove this, replace dtype using config
|
||||
generated = prefill_and_generate(
|
||||
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode
|
||||
)
|
||||
messages = his_content + [{"role": "user", "content": content}]
|
||||
|
||||
async def async_inference(messages):
|
||||
generated = ""
|
||||
async for token in interface.inference(messages, "local_chat"):
|
||||
generated += token
|
||||
return generated
|
||||
|
||||
generated = asyncio.run(async_inference(messages))
|
||||
his_content += [
|
||||
{"role": "user", "content": content},
|
||||
{"role": "assistant", "content": generated},
|
||||
]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(local_chat)
|
||||
local_chat()
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
- match:
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "KLinearMarlin"
|
||||
prefill_op: "KLinearTorch"
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
|
||||
kwargs:
|
||||
prefill_device: "cuda"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KExpertsCPU"
|
||||
out_device: "cuda"
|
||||
recursive: False # don't recursively inject submodules of this module
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model$"
|
||||
replace:
|
||||
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||
kwargs:
|
||||
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
|
||||
- match:
|
||||
name: "^model.embed_tokens"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
|
@ -135,6 +135,8 @@ class OllamaShowResponse(BaseModel):
|
|||
details: OllamaShowDetial
|
||||
model_info: OllamaModelInfo
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
|
||||
|
|
124
ktransformers/server/args.py
Normal file
124
ktransformers/server/args.py
Normal file
|
@ -0,0 +1,124 @@
|
|||
import argparse
|
||||
from ktransformers.server.backend.args import ConfigArgs, default_args
|
||||
|
||||
|
||||
class ArgumentParser:
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
|
||||
def parse_args(self):
|
||||
parser = argparse.ArgumentParser(prog="kvcache.ai", description="Ktransformers")
|
||||
parser.add_argument("--host", type=str, default=self.cfg.server_ip)
|
||||
parser.add_argument("--port", type=int, default=self.cfg.server_port)
|
||||
parser.add_argument("--ssl_keyfile", type=str)
|
||||
parser.add_argument("--ssl_certfile", type=str)
|
||||
parser.add_argument("--web", type=bool, default=self.cfg.mount_web)
|
||||
parser.add_argument("--model_name", type=str, default=self.cfg.model_name)
|
||||
parser.add_argument("--model_dir", type=str)
|
||||
parser.add_argument("--model_path", type=str)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter"
|
||||
)
|
||||
parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
|
||||
parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
|
||||
parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
|
||||
parser.add_argument("--type", type=str, default=self.cfg.backend_type)
|
||||
|
||||
# model configs
|
||||
# parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
|
||||
parser.add_argument("--paged", type=bool, default=self.cfg.paged)
|
||||
parser.add_argument("--total_context", type=int, default=self.cfg.total_context)
|
||||
parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size)
|
||||
parser.add_argument("--max_chunk_size", type=int, default=self.cfg.max_chunk_size)
|
||||
parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens)
|
||||
parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode)
|
||||
parser.add_argument("--healing", type=bool, default=self.cfg.healing)
|
||||
parser.add_argument("--ban_strings", type=list, default=self.cfg.ban_strings, required=False)
|
||||
parser.add_argument("--gpu_split", type=str, default=self.cfg.gpu_split, required=False)
|
||||
parser.add_argument("--length", type=int, default=self.cfg.length, required=False)
|
||||
parser.add_argument("--rope_scale", type=float, default=self.cfg.rope_scale, required=False)
|
||||
parser.add_argument("--rope_alpha", type=float, default=self.cfg.rope_alpha, required=False)
|
||||
parser.add_argument("--no_flash_attn", type=bool, default=self.cfg.no_flash_attn)
|
||||
parser.add_argument("--low_mem", type=bool, default=self.cfg.low_mem)
|
||||
parser.add_argument("--experts_per_token", type=int, default=self.cfg.experts_per_token, required=False)
|
||||
parser.add_argument("--load_q4", type=bool, default=self.cfg.load_q4)
|
||||
parser.add_argument("--fast_safetensors", type=bool, default=self.cfg.fast_safetensors)
|
||||
parser.add_argument("--draft_model_dir", type=str, default=self.cfg.draft_model_dir, required=False)
|
||||
parser.add_argument("--no_draft_scale", type=bool, default=self.cfg.no_draft_scale)
|
||||
parser.add_argument("--modes", type=bool, default=self.cfg.modes)
|
||||
parser.add_argument("--mode", type=str, default=self.cfg.mode)
|
||||
parser.add_argument("--username", type=str, default=self.cfg.username)
|
||||
parser.add_argument("--botname", type=str, default=self.cfg.botname)
|
||||
parser.add_argument("--system_prompt", type=str, default=self.cfg.system_prompt, required=False)
|
||||
parser.add_argument("--temperature", type=float, default=self.cfg.temperature)
|
||||
parser.add_argument("--smoothing_factor", type=float, default=self.cfg.smoothing_factor)
|
||||
parser.add_argument("--dynamic_temperature", type=str, default=self.cfg.dynamic_temperature, required=False)
|
||||
parser.add_argument("--top_k", type=int, default=self.cfg.top_k)
|
||||
parser.add_argument("--top_p", type=float, default=self.cfg.top_p)
|
||||
parser.add_argument("--top_a", type=float, default=self.cfg.top_a)
|
||||
parser.add_argument("--skew", type=float, default=self.cfg.skew)
|
||||
parser.add_argument("--typical", type=float, default=self.cfg.typical)
|
||||
parser.add_argument("--repetition_penalty", type=float, default=self.cfg.repetition_penalty)
|
||||
parser.add_argument("--frequency_penalty", type=float, default=self.cfg.frequency_penalty)
|
||||
parser.add_argument("--presence_penalty", type=float, default=self.cfg.presence_penalty)
|
||||
parser.add_argument("--max_response_tokens", type=int, default=self.cfg.max_response_tokens)
|
||||
parser.add_argument("--response_chunk", type=int, default=self.cfg.response_chunk)
|
||||
parser.add_argument("--no_code_formatting", type=bool, default=self.cfg.no_code_formatting)
|
||||
parser.add_argument("--cache_8bit", type=bool, default=self.cfg.cache_8bit)
|
||||
parser.add_argument("--cache_q4", type=bool, default=self.cfg.cache_q4)
|
||||
parser.add_argument("--ngram_decoding", type=bool, default=self.cfg.ngram_decoding)
|
||||
parser.add_argument("--print_timings", type=bool, default=self.cfg.print_timings)
|
||||
parser.add_argument("--amnesia", type=bool, default=self.cfg.amnesia)
|
||||
parser.add_argument("--batch_size", type=int, default=self.cfg.batch_size)
|
||||
parser.add_argument("--cache_lens", type=int, default=self.cfg.cache_lens)
|
||||
|
||||
# log configs
|
||||
# log level: debug, info, warn, error, crit
|
||||
parser.add_argument("--log_dir", type=str, default=self.cfg.log_dir)
|
||||
parser.add_argument("--log_file", type=str, default=self.cfg.log_file)
|
||||
parser.add_argument("--log_level", type=str, default=self.cfg.log_level)
|
||||
parser.add_argument("--backup_count", type=int, default=self.cfg.backup_count)
|
||||
|
||||
# db configs
|
||||
parser.add_argument("--db_type", type=str, default=self.cfg.db_type)
|
||||
parser.add_argument("--db_host", type=str, default=self.cfg.db_host)
|
||||
parser.add_argument("--db_port", type=str, default=self.cfg.db_port)
|
||||
parser.add_argument("--db_name", type=str, default=self.cfg.db_name)
|
||||
parser.add_argument("--db_pool_size", type=int, default=self.cfg.db_pool_size)
|
||||
parser.add_argument("--db_database", type=str, default=self.cfg.db_database)
|
||||
|
||||
# user config
|
||||
parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
|
||||
parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
|
||||
|
||||
# web config
|
||||
parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
|
||||
|
||||
# file config
|
||||
parser.add_argument("--file_upload_dir", type=str, default=self.cfg.file_upload_dir)
|
||||
parser.add_argument("--assistant_store_dir", type=str, default=self.cfg.assistant_store_dir)
|
||||
# local chat
|
||||
parser.add_argument("--prompt_file", type=str, default=self.cfg.prompt_file)
|
||||
|
||||
args = parser.parse_args()
|
||||
if (args.model_dir is not None):
|
||||
if (args.model_path is not None):
|
||||
# if pass model_dir and model_path, we use model_path
|
||||
args.model_dir = args.model_path
|
||||
else:
|
||||
# if only pass model_dir, we use model_dir
|
||||
args.model_path = args.model_dir
|
||||
else:
|
||||
args.model_dir = self.cfg.model_dir
|
||||
args.model_path = self.cfg.model_path
|
||||
# set config from args
|
||||
for key, value in vars(args).items():
|
||||
if value is not None and hasattr(self.cfg, key):
|
||||
setattr(self.cfg, key, value)
|
||||
# we add the name not match args individually
|
||||
self.cfg.model_device = args.device
|
||||
self.cfg.mount_web = args.web
|
||||
self.cfg.server_ip = args.host
|
||||
self.cfg.server_port = args.port
|
||||
self.cfg.backend_type = args.type
|
||||
return args
|
|
@ -1,97 +1,89 @@
|
|||
from pydantic import BaseModel,Field
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from ktransformers.server.config.config import Config
|
||||
|
||||
|
||||
class ConfigArgs(BaseModel):
|
||||
model_name : Optional[str] = Field(..., description="Model name")
|
||||
model_name: Optional[str] = Field(..., description="Model name")
|
||||
model_dir: Optional[str] = Field(..., description="Path to model directory")
|
||||
optimize_config_path: Optional[str] = Field('./KTransformers/optimize_config/DeepSeek-V2-Chat.json', description="Path of your optimize config json file")
|
||||
gguf_path: Optional[str] = Field('/models/DeepSeek-Coder-V2-Instruct-GGUF/DeepSeek-Coder-V2-Instruct-Q4_K_M.gguf', description="Path of your gguf file")
|
||||
optimize_config_path: Optional[str] = Field(None, description="Path of your optimize config yml file")
|
||||
gguf_path: Optional[str] = Field(None, description="Path of your gguf file")
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
paged : bool = Field(True,description='Wether to use paged attention kv cache')
|
||||
|
||||
# total_context: int = Field(16384, description="Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the total to distribute dynamically over however many jobs are active at once")
|
||||
total_context: int = Field(2**18, description="Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the total to distribute dynamically over however many jobs are active at once")
|
||||
max_batch_size: int = Field(20 if paged else 1, description="Max number of batches to run at once, assuming the sequences will fit within total_context")
|
||||
max_chunk_size: int = Field(2048, description="Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new job is started, but at the expense of overall prompt ingestion speed")
|
||||
max_new_tokens: int = Field(500, description="Max new tokens per completion. For this example applies to all jobs")
|
||||
json_mode: bool = Field(False, description="Use LMFE to constrain the output to JSON format. See schema and details below")
|
||||
healing: bool = Field(False, description="Demonstrate token healing")
|
||||
paged: bool = Field(None, description="Whether to use paged attention kv cache")
|
||||
total_context: int = Field(
|
||||
None,
|
||||
description=(
|
||||
"Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the"
|
||||
" total to distribute dynamically over however many jobs are active at once"
|
||||
),
|
||||
)
|
||||
max_batch_size: int = Field(
|
||||
None, description="Max number of batches to run at once, assuming the sequences will fit within total_context"
|
||||
)
|
||||
max_chunk_size: int = Field(
|
||||
None,
|
||||
description=(
|
||||
"Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new"
|
||||
" job is started, but at the expense of overall prompt ingestion speed"
|
||||
),
|
||||
)
|
||||
max_new_tokens: int = Field(None, description="Max new tokens per completion. For this example applies to all jobs")
|
||||
json_mode: bool = Field(
|
||||
None, description="Use LMFE to constrain the output to JSON format. See schema and details below"
|
||||
)
|
||||
healing: bool = Field(None, description="Demonstrate token healing")
|
||||
ban_strings: Optional[list] = Field(None, description="Ban some phrases maybe")
|
||||
|
||||
gpu_split: Optional[str] = Field(None, description='"auto", or VRAM allocation per GPU in GB')
|
||||
length: Optional[int] = Field(None, description="Maximum sequence length")
|
||||
rope_scale: Optional[float] = Field(None, description="RoPE scaling factor")
|
||||
rope_alpha: Optional[float] = Field(None, description="RoPE alpha value (NTK)")
|
||||
no_flash_attn: bool = Field(False, description="Disable Flash Attention")
|
||||
low_mem: bool = Field(
|
||||
False,
|
||||
description="Enable VRAM optimizations, potentially trading off speed",
|
||||
)
|
||||
no_flash_attn: bool = Field(None, description="Disable Flash Attention")
|
||||
low_mem: bool = Field(None, description="Enable VRAM optimizations, potentially trading off speed")
|
||||
experts_per_token: Optional[int] = Field(
|
||||
None,
|
||||
description="Override MoE model's default number of experts per token",
|
||||
)
|
||||
load_q4: bool = Field(False, description="Load weights in Q4 mode")
|
||||
fast_safetensors: bool = Field(
|
||||
False,
|
||||
description="Optimized safetensors loading with direct I/O (experimental!)",
|
||||
None, description="Override MoE model's default number of experts per token"
|
||||
)
|
||||
load_q4: bool = Field(None, description="Load weights in Q4 mode")
|
||||
fast_safetensors: bool = Field(None, description="Optimized safetensors loading with direct I/O (experimental!)")
|
||||
draft_model_dir: Optional[str] = Field(None, description="Path to draft model directory")
|
||||
no_draft_scale: bool = Field(
|
||||
False,
|
||||
None,
|
||||
description="If draft model has smaller context size than model, don't apply alpha (NTK) scaling to extend it",
|
||||
)
|
||||
modes: bool = Field(False, description="List available modes and exit.")
|
||||
mode: str = Field(
|
||||
"llama",
|
||||
description="Chat mode. Use llama for Llama 1/2 chat finetunes.",
|
||||
)
|
||||
username: str = Field("User", description="Username when using raw chat mode")
|
||||
botname: str = Field("Chatbort", description="Bot name when using raw chat mode")
|
||||
modes: bool = Field(None, description="List available modes and exit.")
|
||||
mode: str = Field(None, description="Chat mode. Use llama for Llama 1/2 chat finetunes.")
|
||||
username: str = Field(None, description="Username when using raw chat mode")
|
||||
botname: str = Field(None, description="Bot name when using raw chat mode")
|
||||
system_prompt: Optional[str] = Field(None, description="Use custom system prompt")
|
||||
temperature: float = Field(0.95, description="Sampler temperature, default = 0.95 (1 to disable)")
|
||||
smoothing_factor: float = Field(0.0, description="Smoothing Factor, default = 0.0 (0 to disable)")
|
||||
temperature: float = Field(None, description="Sampler temperature, default = 0.95 (1 to disable)")
|
||||
smoothing_factor: float = Field(None, description="Smoothing Factor, default = 0.0 (0 to disable)")
|
||||
dynamic_temperature: Optional[str] = Field(
|
||||
None,
|
||||
description="Dynamic temperature min,max,exponent, e.g. -dyntemp 0.2,1.5,1",
|
||||
None, description="Dynamic temperature min,max,exponent, e.g. -dyntemp 0.2,1.5,1"
|
||||
)
|
||||
top_k: int = Field(50, description="Sampler top-K, default = 50 (0 to disable)")
|
||||
top_p: float = Field(0.8, description="Sampler top-P, default = 0.8 (0 to disable)")
|
||||
top_a: float = Field(0.0, description="Sampler top-A, default = 0.0 (0 to disable)")
|
||||
skew: float = Field(0.0, description="Skew sampling, default = 0.0 (0 to disable)")
|
||||
typical: float = Field(
|
||||
0.0,
|
||||
description="Sampler typical threshold, default = 0.0 (0 to disable)",
|
||||
)
|
||||
repetition_penalty: float = Field(
|
||||
1.01,
|
||||
description="Sampler repetition penalty, default = 1.01 (1 to disable)",
|
||||
)
|
||||
frequency_penalty: float = Field(
|
||||
0.0,
|
||||
description="Sampler frequency penalty, default = 0.0 (0 to disable)",
|
||||
)
|
||||
presence_penalty: float = Field(
|
||||
0.0,
|
||||
description="Sampler presence penalty, default = 0.0 (0 to disable)",
|
||||
)
|
||||
max_response_tokens: int = Field(300, description="Max tokens per response, default = 1000")
|
||||
response_chunk: int = Field(250, description="Space to reserve in context for reply, default = 250")
|
||||
no_code_formatting: bool = Field(False, description="Disable code formatting/syntax highlighting")
|
||||
cache_8bit: bool = Field(False, description="Use 8-bit (FP8) cache")
|
||||
cache_q4: bool = Field(True, description="Use Q4 cache")
|
||||
ngram_decoding: bool = Field(False, description="Use n-gram speculative decoding")
|
||||
print_timings: bool = Field(False, description="Output timings after each prompt")
|
||||
amnesia: bool = Field(False, description="Forget context after every response")
|
||||
top_k: int = Field(None, description="Sampler top-K, default = 50 (0 to disable)")
|
||||
top_p: float = Field(None, description="Sampler top-P, default = 0.8 (0 to disable)")
|
||||
top_a: float = Field(None, description="Sampler top-A, default = 0.0 (0 to disable)")
|
||||
skew: float = Field(None, description="Skew sampling, default = 0.0 (0 to disable)")
|
||||
typical: float = Field(None, description="Sampler typical threshold, default = 0.0 (0 to disable)")
|
||||
repetition_penalty: float = Field(None, description="Sampler repetition penalty, default = 1.01 (1 to disable)")
|
||||
frequency_penalty: float = Field(None, description="Sampler frequency penalty, default = 0.0 (0 to disable)")
|
||||
presence_penalty: float = Field(None, description="Sampler presence penalty, default = 0.0 (0 to disable)")
|
||||
max_response_tokens: int = Field(None, description="Max tokens per response, default = 1000")
|
||||
response_chunk: int = Field(None, description="Space to reserve in context for reply, default = 250")
|
||||
no_code_formatting: bool = Field(None, description="Disable code formatting/syntax highlighting")
|
||||
cache_8bit: bool = Field(None, description="Use 8-bit (FP8) cache")
|
||||
cache_q4: bool = Field(None, description="Use Q4 cache")
|
||||
ngram_decoding: bool = Field(None, description="Use n-gram speculative decoding")
|
||||
print_timings: bool = Field(None, description="Output timings after each prompt")
|
||||
amnesia: bool = Field(None, description="Forget context after every response")
|
||||
|
||||
# for transformers
|
||||
batch_size :int = Field(1,description="Batch Size")
|
||||
cache_lens:int = Field(4096, description="Cache lens for transformers static cache")
|
||||
device:str = Field('cuda:2',description="device")
|
||||
batch_size: int = Field(None, description="Batch Size")
|
||||
cache_lens: int = Field(None, description="Cache lens for transformers static cache")
|
||||
device: str = Field(None, description="device")
|
||||
|
||||
|
||||
cfg = Config()
|
||||
default_args = ConfigArgs(model_name=cfg.model_name,model_dir=cfg.model_path)
|
||||
default_args = cfg
|
||||
|
|
|
@ -1,6 +1,12 @@
|
|||
import torch
|
||||
from transformers import AutoTokenizer, AutoConfig, GenerationConfig
|
||||
from ktransformers.server.backend.interfaces.transformers import TransformersInterface,ConfigArgs, TransformersThreadContext,default_args,TextStreamer
|
||||
from ktransformers.server.backend.interfaces.transformers import (
|
||||
TransformersInterface,
|
||||
ConfigArgs,
|
||||
TransformersThreadContext,
|
||||
default_args,
|
||||
TextStreamer,
|
||||
)
|
||||
from ktransformers.server.config.log import logger
|
||||
from ktransformers.optimize.optimize import optimize_and_load_gguf
|
||||
from ktransformers.models.custom_cache import StaticCache
|
||||
|
@ -14,71 +20,85 @@ class KTransformersThreadContext(TransformersThreadContext):
|
|||
|
||||
|
||||
class KTransformersInterface(TransformersInterface):
|
||||
def __init__(self,args:ConfigArgs= default_args):
|
||||
def __init__(self, args: ConfigArgs = default_args):
|
||||
self.args = args
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
torch.set_grad_enabled(False)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir,device = args.device)
|
||||
config=AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device)
|
||||
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
|
||||
if config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||
config._attn_implementation="flash_attention_2"
|
||||
config._attn_implementation = "flash_attention_2"
|
||||
|
||||
with torch.device("meta"):
|
||||
self.model=custom_models[config.architectures[0]](config)
|
||||
self.model = custom_models[config.architectures[0]](config)
|
||||
if default_args.optimize_config_path is None:
|
||||
optimize_rule_path = default_optimize_rules[config.architectures[0]]
|
||||
else:
|
||||
optimize_rule_path = args.optimize_config_path
|
||||
|
||||
|
||||
# print(optimize_config)
|
||||
|
||||
gguf_path = args.gguf_path
|
||||
if gguf_path is None:
|
||||
gguf_path = input(
|
||||
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
|
||||
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
|
||||
" belong to current model):"
|
||||
)
|
||||
optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
|
||||
|
||||
|
||||
device_map = self.model.gguf_loader.tensor_device_map
|
||||
logger.info(f'{args.model_name} loaded from {args.model_dir} to {device_map}')
|
||||
self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=device_map, dtype=self.model.dtype)
|
||||
logger.info(f'StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}')
|
||||
logger.info(f"{args.model_name} loaded from {args.model_dir} to {device_map}")
|
||||
self.cache = StaticCache(
|
||||
config=self.model.config,
|
||||
max_batch_size=args.batch_size,
|
||||
max_cache_len=args.cache_lens,
|
||||
device=device_map,
|
||||
dtype=self.model.dtype,
|
||||
)
|
||||
logger.info(f"StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}")
|
||||
self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
|
||||
if self.model.generation_config.pad_token_id is None:
|
||||
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
|
||||
self.streamer = TextStreamer(self.tokenizer)
|
||||
|
||||
|
||||
def decode_one_tokens(self):
|
||||
if not hasattr(self, "cuda_graph_runner"):
|
||||
device_map = self.model.gguf_loader.tensor_device_map
|
||||
torch_device = get_device('blk.0.self_attn', device_map)
|
||||
torch_device = get_device("blk.0.self_attn", device_map)
|
||||
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
|
||||
self.cuda_graph_runner = CUDAGraphRunner()
|
||||
self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, main_device=torch_device, return_dict=False, use_cache=True)
|
||||
|
||||
self.cuda_graph_runner.capture(
|
||||
self.model,
|
||||
self.current_ids,
|
||||
self.active_cache_position.unsqueeze(0),
|
||||
self.active_cache_position,
|
||||
self.cache,
|
||||
main_device=torch_device,
|
||||
return_dict=False,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
if hasattr(self, "cuda_graph_runner"):
|
||||
logits = self.cuda_graph_runner(self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position)
|
||||
logits = self.cuda_graph_runner(
|
||||
self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
|
||||
)
|
||||
self.cache.change_seq_length(1)
|
||||
torch.cuda.synchronize()
|
||||
logits = logits[0,-1,:]
|
||||
logits = logits[0, -1, :]
|
||||
return self.logits_to_token(logits)
|
||||
|
||||
|
||||
if self.use_static_cache:
|
||||
mask = torch.ones((1,self.seq_length)).to(torch_device)
|
||||
mask = torch.ones((1, self.seq_length)).to(torch_device)
|
||||
logits = self.model(
|
||||
self.current_ids,
|
||||
cache_position=self.active_cache_position,
|
||||
past_key_values=self.cache,
|
||||
attention_mask=mask,
|
||||
return_dict=False,
|
||||
use_cache=True
|
||||
use_cache=True,
|
||||
)[0]
|
||||
else:
|
||||
logits = self.model(
|
||||
self.current_ids,
|
||||
return_dict=False
|
||||
)[0]
|
||||
logits = logits[0,-1,:]
|
||||
logits = self.model(self.current_ids, return_dict=False)[0]
|
||||
logits = logits[0, -1, :]
|
||||
|
||||
return self.logits_to_token(logits)
|
||||
|
|
|
@ -1,14 +1,22 @@
|
|||
from typing import Any, List, Optional, Set
|
||||
from transformers import LlamaTokenizer,AutoTokenizer, AutoConfig, LlamaForCausalLM,GenerationConfig, StaticCache, AutoModelForCausalLM,BitsAndBytesConfig
|
||||
from transformers import (
|
||||
LlamaTokenizer,
|
||||
AutoTokenizer,
|
||||
AutoConfig,
|
||||
LlamaForCausalLM,
|
||||
GenerationConfig,
|
||||
StaticCache,
|
||||
AutoModelForCausalLM,
|
||||
BitsAndBytesConfig,
|
||||
)
|
||||
|
||||
from ktransformers.server.schemas.base import ObjectID
|
||||
from ktransformers.server.utils.multi_timer import Profiler
|
||||
import torch
|
||||
import sys, os
|
||||
from ..base import ThreadContext,BackendInterfaceBase
|
||||
from ..base import ThreadContext, BackendInterfaceBase
|
||||
from ktransformers.server.config.log import logger
|
||||
from ..args import ConfigArgs,default_args
|
||||
|
||||
from ..args import ConfigArgs, default_args
|
||||
|
||||
|
||||
# This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
|
||||
|
@ -28,21 +36,20 @@ class TextStreamer:
|
|||
self.token_cache = []
|
||||
self.print_len = 0
|
||||
|
||||
def put(self, value)->Optional[str]:
|
||||
def put(self, value) -> Optional[str]:
|
||||
"""
|
||||
Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
|
||||
"""
|
||||
if not isinstance(value,int):
|
||||
"""
|
||||
if not isinstance(value, int):
|
||||
raise ValueError("TextStreamer only supports batch size 1, and int type input")
|
||||
|
||||
|
||||
if self.skip_prompt and self.next_tokens_are_prompt:
|
||||
self.next_tokens_are_prompt = False
|
||||
return None
|
||||
|
||||
# Add the new token to the cache and decodes the entire thing.
|
||||
self.token_cache.append(value)
|
||||
text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True,**self.decode_kwargs)
|
||||
text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)
|
||||
|
||||
# After the symbol for a new line, we flush the cache.
|
||||
if text.endswith("\n"):
|
||||
|
@ -59,7 +66,7 @@ class TextStreamer:
|
|||
self.print_len += len(printable_text)
|
||||
return printable_text
|
||||
|
||||
def end(self)->Optional[str]:
|
||||
def end(self) -> Optional[str]:
|
||||
"""Flushes any remaining cache and prints a newline to stdout."""
|
||||
# Flush the cache, if it exists
|
||||
if len(self.token_cache) > 0:
|
||||
|
@ -71,7 +78,7 @@ class TextStreamer:
|
|||
|
||||
self.next_tokens_are_prompt = True
|
||||
return printable_text
|
||||
|
||||
|
||||
def _is_chinese_char(self, cp):
|
||||
"""Checks whether CP is the codepoint of a CJK character."""
|
||||
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||
|
@ -97,81 +104,81 @@ class TextStreamer:
|
|||
return False
|
||||
|
||||
|
||||
class TransformersThreadContext(ThreadContext):
|
||||
class TransformersThreadContext(ThreadContext):
|
||||
def get_local_messages(self):
|
||||
local_messages = []
|
||||
for m in self.messages:
|
||||
local_messages.append(
|
||||
{'role':m.role.value,
|
||||
'content':m.get_text_content()}
|
||||
)
|
||||
|
||||
local_messages.append({"role": m.role.value, "content": m.get_text_content()})
|
||||
|
||||
return local_messages
|
||||
|
||||
|
||||
class TransformersInterface(BackendInterfaceBase):
|
||||
use_static_cache : bool = True
|
||||
|
||||
use_static_cache: bool = True
|
||||
|
||||
model: Any
|
||||
tokenizer: AutoTokenizer
|
||||
|
||||
|
||||
cache: StaticCache
|
||||
generated_ids:torch.Tensor
|
||||
seq_length:int
|
||||
|
||||
generated_ids: torch.Tensor
|
||||
seq_length: int
|
||||
|
||||
streamer: TextStreamer
|
||||
|
||||
# thread_related
|
||||
last_request_id: Optional[str] = None
|
||||
ever_generated_ids: Set[int] = set()
|
||||
|
||||
|
||||
|
||||
def __init__(self, args:ConfigArgs = default_args):
|
||||
def __init__(self, args: ConfigArgs = default_args):
|
||||
self.args = args
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device,use_safetensors=True)
|
||||
logger.info(f'{args.model_name} loaded from {args.model_dir} to {args.device}')
|
||||
|
||||
self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=args.device, dtype=self.model.dtype)
|
||||
logger.info(f'StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}')
|
||||
|
||||
self.streamer = TextStreamer(self.tokenizer)
|
||||
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device, use_safetensors=True)
|
||||
logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}")
|
||||
|
||||
self.cache = StaticCache(
|
||||
config=self.model.config,
|
||||
max_batch_size=args.batch_size,
|
||||
max_cache_len=args.cache_lens,
|
||||
device=args.device,
|
||||
dtype=self.model.dtype,
|
||||
)
|
||||
logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}")
|
||||
|
||||
self.streamer = TextStreamer(self.tokenizer)
|
||||
|
||||
@property
|
||||
def current_ids(self):
|
||||
return self.generated_ids[:,self.seq_length-1].unsqueeze(1)
|
||||
|
||||
return self.generated_ids[:, self.seq_length - 1].unsqueeze(1)
|
||||
|
||||
@property
|
||||
def active_cache_position(self):
|
||||
return torch.tensor([self.seq_length-1], device=self.args.device)
|
||||
return torch.tensor([self.seq_length - 1], device=self.args.device)
|
||||
|
||||
|
||||
def tokenize_prompt(self,prompt:str):
|
||||
input_ids = self.tokenizer.encode(prompt,return_tensors='pt').to(self.args.device)
|
||||
def tokenize_prompt(self, prompt: str):
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device)
|
||||
return input_ids
|
||||
|
||||
def format_and_tokenize_input_ids(self,thread_id:ObjectID,messages:List):
|
||||
def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
|
||||
for m in messages:
|
||||
if m['role']=='system':
|
||||
logger.warn(f'change {m["role"]} to user')
|
||||
m['role'] = 'user'
|
||||
if m["role"] == "system":
|
||||
logger.warning(f'change {m["role"]} to user')
|
||||
m["role"] = "user"
|
||||
|
||||
new_messages = [messages[0]]
|
||||
for m in messages[1:]:
|
||||
if m['role'] == 'user' and new_messages[-1]['role']=='user':
|
||||
logger.warn('merge two adjacent user messages')
|
||||
new_messages[-1]['content']+=m['content']
|
||||
for m in messages[1:]:
|
||||
if m["role"] == "user" and new_messages[-1]["role"] == "user":
|
||||
logger.warning("merge two adjacent user messages")
|
||||
new_messages[-1]["content"] += m["content"]
|
||||
else:
|
||||
new_messages.append(m)
|
||||
|
||||
|
||||
new_messages.append(m)
|
||||
# if (self.last_request_id is not None) and self.last_request_id == thread_id:
|
||||
# input_ids = self.tokenizer.encode(self.tokenizer.eos_token+self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt",tokenize=False, add_generation_prompt=True), add_special_tokens = False, return_tensors="pt").to(self.args.device)
|
||||
# else:
|
||||
# input_ids = self.tokenizer.apply_chat_template(
|
||||
# new_messages, return_tensors="pt", add_generation_prompt=True
|
||||
# ).to(self.args.device)
|
||||
input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device)
|
||||
|
||||
if (self.last_request_id is not None) and self.last_request_id == thread_id:
|
||||
x = self.generated_ids[:,:self.seq_length]
|
||||
y = input_ids[:,:self.seq_length]
|
||||
|
@ -179,19 +186,19 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
unequal_mask = torch.ne(x,y)
|
||||
unequal_positions = torch.nonzero(unequal_mask)
|
||||
num_unequal_elements = unequal_mask.sum().item()
|
||||
logger.warn(f'num_unequal_elements: {num_unequal_elements}')
|
||||
logger.warning(f'num_unequal_elements: {num_unequal_elements}')
|
||||
|
||||
input_ids = input_ids[:,self.seq_length:]
|
||||
logger.debug(f'get input ids of shape {input_ids.shape}')
|
||||
logger.debug(f"get input ids of shape {input_ids.shape}")
|
||||
return input_ids
|
||||
|
||||
def append_new_tokens(self,new_tokens:int)->Optional[str]:
|
||||
self.generated_ids[0,self.seq_length] = new_tokens
|
||||
self.seq_length+=1
|
||||
|
||||
def append_new_tokens(self, new_tokens: int) -> Optional[str]:
|
||||
self.generated_ids[0, self.seq_length] = new_tokens
|
||||
self.seq_length += 1
|
||||
return self.streamer.put(new_tokens)
|
||||
|
||||
def logits_to_token(self,logits:torch.Tensor):
|
||||
logits = logits/self.args.temperature
|
||||
def logits_to_token(self, logits: torch.Tensor):
|
||||
logits = logits / self.args.temperature
|
||||
|
||||
for token_idx in self.ever_generated_ids:
|
||||
if logits[token_idx] < 0:
|
||||
|
@ -200,7 +207,7 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
logits[token_idx] /= self.args.repetition_penalty
|
||||
|
||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||
|
||||
|
||||
sample = True
|
||||
if sample:
|
||||
last = torch.multinomial(probs, num_samples=1)
|
||||
|
@ -211,127 +218,124 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
self.ever_generated_ids.add(last)
|
||||
return last
|
||||
|
||||
|
||||
|
||||
def decode_one_tokens(self):
|
||||
if self.use_static_cache:
|
||||
mask = torch.ones((1,self.seq_length)).to(self.args.device)
|
||||
mask = torch.ones((1, self.seq_length)).to(self.args.device)
|
||||
logits = self.model(
|
||||
self.current_ids,
|
||||
cache_position=self.active_cache_position,
|
||||
past_key_values=self.cache,
|
||||
attention_mask=mask,
|
||||
return_dict=False,
|
||||
use_cache=True
|
||||
use_cache=True,
|
||||
)[0]
|
||||
else:
|
||||
logits = self.model(
|
||||
self.current_ids,
|
||||
return_dict=False
|
||||
)[0]
|
||||
logits = logits[0,-1,:]
|
||||
logits = self.model(self.current_ids, return_dict=False)[0]
|
||||
logits = logits[0, -1, :]
|
||||
|
||||
return self.logits_to_token(logits)
|
||||
|
||||
@torch.no_grad
|
||||
def prefill(self,input_ids:torch.Tensor,is_new:bool):
|
||||
def prefill(self, input_ids: torch.Tensor, is_new: bool):
|
||||
input_ids_length = input_ids.shape[-1]
|
||||
self.profiler.set_counter('prefill',input_ids_length)
|
||||
logger.debug(f'input_ids: {input_ids.shape}')
|
||||
self.profiler.set_counter("prefill", input_ids_length)
|
||||
logger.debug(f"input_ids: {input_ids.shape}")
|
||||
|
||||
|
||||
if is_new:
|
||||
self.cache.reset()
|
||||
self.ever_generated_ids.clear()
|
||||
former_seq_length = 0
|
||||
self.seq_length = input_ids_length
|
||||
self.generated_ids = torch.zeros(
|
||||
self.args.batch_size, self.seq_length + self.args.max_new_tokens + 1, dtype=torch.int, device=self.args.device
|
||||
)
|
||||
self.args.batch_size,
|
||||
self.seq_length + self.args.max_new_tokens + 1,
|
||||
dtype=torch.int,
|
||||
device=self.args.device,
|
||||
)
|
||||
else:
|
||||
logger.debug(f'generate_ids: {self.generated_ids.shape}')
|
||||
logger.debug(f"generate_ids: {self.generated_ids.shape}")
|
||||
former_seq_length = self.seq_length
|
||||
self.seq_length += input_ids_length
|
||||
expected_length = self.seq_length + self.args.max_new_tokens+1
|
||||
expected_length = self.seq_length + self.args.max_new_tokens + 1
|
||||
delta_length = expected_length - self.generated_ids.shape[-1]
|
||||
if delta_length>0:
|
||||
if delta_length > 0:
|
||||
new_generate_ids = torch.zeros(
|
||||
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
|
||||
)
|
||||
self.generated_ids = torch.cat([self.generated_ids,new_generate_ids],dim=-1)
|
||||
logger.debug(f'cache position: {former_seq_length} to {self.seq_length}')
|
||||
cache_position = torch.arange(former_seq_length,self.seq_length, device=self.args.device)
|
||||
self.generated_ids[:,cache_position] = input_ids.to(self.args.device).to(torch.int)
|
||||
self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
|
||||
logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
|
||||
cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
|
||||
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
|
||||
|
||||
mask = torch.ones((1,self.seq_length)).to(self.args.device)
|
||||
mask = torch.ones((1, self.seq_length)).to(self.args.device)
|
||||
device = input_ids.device
|
||||
if not(type(self) is TransformersInterface):
|
||||
if not (type(self) is TransformersInterface):
|
||||
input_ids = input_ids.to("cpu")
|
||||
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
|
||||
if self.use_static_cache:
|
||||
logits = self.model(
|
||||
inputs_embeds=inputs_embeds, cache_position=cache_position, past_key_values=self.cache,return_dict=False, use_cache=True,attention_mask=mask
|
||||
inputs_embeds=inputs_embeds,
|
||||
cache_position=cache_position,
|
||||
past_key_values=self.cache,
|
||||
return_dict=False,
|
||||
use_cache=True,
|
||||
attention_mask=mask,
|
||||
)[0]
|
||||
else:
|
||||
logits = self.model(
|
||||
inputs_embeds=inputs_embeds,return_dict=False
|
||||
)[0]
|
||||
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
|
||||
|
||||
|
||||
|
||||
next_token = self.logits_to_token(logits[0,-1,:])
|
||||
next_token = self.logits_to_token(logits[0, -1, :])
|
||||
yield self.append_new_tokens(next_token)
|
||||
|
||||
@torch.no_grad
|
||||
def generate(self):
|
||||
self.profiler.set_counter('decode',0)
|
||||
self.profiler.set_counter("decode", 0)
|
||||
for _ in range(1, self.args.max_new_tokens):
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
||||
next_token = self.decode_one_tokens()
|
||||
self.profiler.inc('decode')
|
||||
self.profiler.inc("decode")
|
||||
if next_token == self.tokenizer.eos_token_id:
|
||||
assert self.args.batch_size == 1
|
||||
break
|
||||
yield self.append_new_tokens(next_token)
|
||||
yield self.streamer.end()
|
||||
|
||||
def check_is_new(self,thread_id:str):
|
||||
def check_is_new(self, thread_id: str):
|
||||
if not self.use_static_cache:
|
||||
return True
|
||||
if self.last_request_id is None:
|
||||
self.last_request_id = thread_id
|
||||
return True
|
||||
else:
|
||||
if self.last_request_id==thread_id:
|
||||
if self.last_request_id == thread_id:
|
||||
return False
|
||||
else:
|
||||
self.last_request_id = thread_id
|
||||
return True
|
||||
|
||||
async def inference(self,local_messages,thread_id:str):
|
||||
self.profiler.create_and_start_timer('tokenize')
|
||||
if isinstance(local_messages,List):
|
||||
input_ids = self.format_and_tokenize_input_ids(thread_id,local_messages)
|
||||
elif isinstance(local_messages,str):
|
||||
async def inference(self, local_messages, thread_id: str):
|
||||
self.profiler.create_and_start_timer("tokenize")
|
||||
if isinstance(local_messages, List):
|
||||
input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
|
||||
elif isinstance(local_messages, str):
|
||||
input_ids = self.tokenize_prompt(local_messages)
|
||||
else:
|
||||
raise ValueError('local_messages should be List or str')
|
||||
raise ValueError("local_messages should be List or str")
|
||||
|
||||
self.profiler.pause_timer('tokenize')
|
||||
self.profiler.pause_timer("tokenize")
|
||||
|
||||
self.profiler.create_and_start_timer('prefill')
|
||||
for t in self.prefill(input_ids,self.check_is_new(thread_id)):
|
||||
self.profiler.create_and_start_timer("prefill")
|
||||
for t in self.prefill(input_ids, self.check_is_new(thread_id)):
|
||||
if t is not None:
|
||||
print(t,end='')
|
||||
print(t, end="")
|
||||
yield t
|
||||
self.profiler.pause_timer('prefill')
|
||||
self.profiler.pause_timer("prefill")
|
||||
|
||||
self.profiler.create_and_start_timer('decode')
|
||||
self.profiler.create_and_start_timer("decode")
|
||||
for t in self.generate():
|
||||
if t is not None:
|
||||
print(t,end='')
|
||||
print(t, end="")
|
||||
yield t
|
||||
print('')
|
||||
self.profiler.pause_timer('decode')
|
||||
print("")
|
||||
self.profiler.pause_timer("decode")
|
||||
self.report_last_time_performance()
|
||||
|
||||
|
|
|
@ -1,23 +1,24 @@
|
|||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
'''
|
||||
Description :
|
||||
"""
|
||||
Description :
|
||||
Author : unicornchan
|
||||
Date : 2024-06-11 16:35:42
|
||||
Version : 1.0.0
|
||||
LastEditors : WuHao
|
||||
LastEditors : WuHao
|
||||
LastEditTime : 2024-08-12 06:31:14
|
||||
'''
|
||||
"""
|
||||
import os
|
||||
import shutil
|
||||
import yaml
|
||||
|
||||
from ktransformers.server.config.singleton import Singleton
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class Config(metaclass=Singleton):
|
||||
"""Singleton pattern Config class, used to get all configurations.
|
||||
"""
|
||||
"""Singleton pattern Config class, used to get all configurations."""
|
||||
|
||||
CONFIG_FILE_NAME = "config.yaml"
|
||||
|
||||
@staticmethod
|
||||
|
@ -27,22 +28,20 @@ class Config(metaclass=Singleton):
|
|||
Returns:
|
||||
dict: all configs
|
||||
"""
|
||||
base_path: str = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(__file__)))
|
||||
config_yaml: str = os.path.join(
|
||||
base_path, "configs", Config.CONFIG_FILE_NAME)
|
||||
|
||||
user_path: str = os.path.expanduser('~')
|
||||
localstore_path: str = os.path.join(user_path,'.ktransformers')
|
||||
config_path: str = os.path.join(localstore_path,Config.CONFIG_FILE_NAME)
|
||||
base_path: str = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
config_yaml: str = os.path.join(base_path, "configs", Config.CONFIG_FILE_NAME)
|
||||
|
||||
user_path: str = os.path.expanduser("~")
|
||||
localstore_path: str = os.path.join(user_path, ".ktransformers")
|
||||
config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME)
|
||||
if not os.path.exists(config_yaml):
|
||||
print(f"Can't find config file, {config_yaml}")
|
||||
exit(-1)
|
||||
if not os.path.exists(localstore_path):
|
||||
os.mkdir(localstore_path)
|
||||
if not os.path.exists(config_path):
|
||||
shutil.copyfile(config_yaml,config_path)
|
||||
with open(config_path, 'r', encoding="utf-8") as fp:
|
||||
shutil.copyfile(config_yaml, config_path)
|
||||
with open(config_path, "r", encoding="utf-8") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
return config
|
||||
|
||||
|
@ -52,16 +51,14 @@ class Config(metaclass=Singleton):
|
|||
process file path
|
||||
"""
|
||||
base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
real_path = path if os.path.isabs(
|
||||
path) else os.path.join(base_path, path)
|
||||
real_path = path if os.path.isabs(path) else os.path.join(base_path, path)
|
||||
return real_path
|
||||
|
||||
def __init__(self):
|
||||
cfg = Config.load()
|
||||
self.base_path = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(__file__)))
|
||||
self.user_path: str = os.path.expanduser('~')
|
||||
self.localstore_path: str = os.path.join(self.user_path,'.ktransformers')
|
||||
self.base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
self.user_path: str = os.path.expanduser("~")
|
||||
self.localstore_path: str = os.path.join(self.user_path, ".ktransformers")
|
||||
# log configs
|
||||
self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"]))
|
||||
self.log_file = cfg["log"]["file"]
|
||||
|
@ -69,7 +66,7 @@ class Config(metaclass=Singleton):
|
|||
self.backup_count = cfg["log"]["backup_count"]
|
||||
|
||||
# server configs
|
||||
self.server: dict = cfg.get("server",{})
|
||||
self.server: dict = cfg.get("server", {})
|
||||
self.server_ip = self.server.get("ip", "0.0.0.0")
|
||||
self.server_port = self.server.get("port", 9016)
|
||||
|
||||
|
@ -86,16 +83,68 @@ class Config(metaclass=Singleton):
|
|||
self.user_config: dict = cfg.get("user", {})
|
||||
self.user_secret_key = self.user_config.get("secret_key", "")
|
||||
self.user_algorithm = self.user_config.get("algorithm", "")
|
||||
|
||||
|
||||
# model config
|
||||
self.model:dict = cfg.get("model", {})
|
||||
self.model: dict = cfg.get("model", {})
|
||||
self.backend_type: str = self.model.get("type", "transformers")
|
||||
self.model_path: str = self.model.get("path", "")
|
||||
self.model_dir: str = self.model.get("path", "")
|
||||
# to make sure it consistent with previous version
|
||||
self.model_path: str = self.model_dir
|
||||
self.model_name: str = self.model.get("name", "")
|
||||
self.model_device: str = self.model.get("device", "cuda:0")
|
||||
self.gguf_path: str = self.model.get("gguf_path", "")
|
||||
self.model_cache_lens = self.model.get("cache_lens")
|
||||
|
||||
self.gguf_path: Optional[str] = self.model.get("gguf_path", None)
|
||||
# self.model_cache_lens = self.model.get("cache_lens")
|
||||
self.optimize_config_path: Optional[str] = self.model.get(
|
||||
"optimize_config_path", "./ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml"
|
||||
)
|
||||
self.paged = self.model.get("paged", True)
|
||||
|
||||
self.total_context = self.model.get("total_context", 2**18)
|
||||
self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
|
||||
self.max_chunk_size = self.model.get("max_chunk_size", 2048)
|
||||
self.max_new_tokens = self.model.get("max_new_tokens", 500)
|
||||
self.json_mode = self.model.get("json_mode", False)
|
||||
self.healing = self.model.get("healing", False)
|
||||
self.ban_strings: Optional[list] = self.model.get("ban_strings", None)
|
||||
self.gpu_split: Optional[str] = self.model.get("gpu_split", None)
|
||||
self.length: Optional[int] = self.model.get("length", None)
|
||||
self.rope_scale: Optional[float] = self.model.get("rope_scale", None)
|
||||
self.rope_alpha: Optional[float] = self.model.get("rope_alpha", None)
|
||||
self.no_flash_attn = self.model.get("no_flash_attn", False)
|
||||
self.low_mem = self.model.get("low_mem", False)
|
||||
self.experts_per_token: Optional[int] = self.model.get("experts_per_token", None)
|
||||
self.load_q4 = self.model.get("load_q4", False)
|
||||
self.fast_safetensors = self.model.get("fast_safetensors", False)
|
||||
self.draft_model_dir: Optional[str] = self.model.get("draft_model_dir", None)
|
||||
self.no_draft_scale = self.model.get("no_draft_scale", False)
|
||||
self.modes = self.model.get("modes", False)
|
||||
self.mode = self.model.get("mode", "llama")
|
||||
self.username = self.model.get("username", "User")
|
||||
self.botname = self.model.get("botname", "Chatbort")
|
||||
self.system_prompt: Optional[str] = self.model.get("system_prompt", None)
|
||||
self.temperature = self.model.get("temperature", 0.95)
|
||||
self.smoothing_factor = self.model.get("smoothing_factor", 0.0)
|
||||
self.dynamic_temperature: Optional[str] = self.model.get("dynamic_temperature", None)
|
||||
self.top_k = self.model.get("top_k", 50)
|
||||
self.top_p = self.model.get("top_p", 0.8)
|
||||
self.top_a = self.model.get("top_a", 0.0)
|
||||
self.skew = self.model.get("skew", 0.0)
|
||||
self.typical = self.model.get("typical", 0.0)
|
||||
self.repetition_penalty = self.model.get("repetition_penalty", 1.01)
|
||||
self.frequency_penalty = self.model.get("frequency_penalty", 0.0)
|
||||
self.presence_penalty = self.model.get("presence_penalty", 0.0)
|
||||
self.max_response_tokens = self.model.get("max_response_tokens", 300)
|
||||
self.response_chunk = self.model.get("response_chunk", 250)
|
||||
self.no_code_formatting = self.model.get("no_code_formatting", False)
|
||||
self.cache_8bit = self.model.get("cache_8bit", False)
|
||||
self.cache_q4 = self.model.get("cache_q4", True)
|
||||
self.ngram_decoding = self.model.get("ngram_decoding", False)
|
||||
self.print_timings = self.model.get("print_timings", False)
|
||||
self.amnesia = self.model.get("amnesia", False)
|
||||
self.batch_size = self.model.get("batch_size", 1)
|
||||
self.cache_lens = self.model.get("cache_lens", 4096)
|
||||
self.device = self.model.get("device", "cuda:2")
|
||||
|
||||
# web config
|
||||
self.web: dict = cfg.get("web", {})
|
||||
self.web_cross_domain: bool = self.web.get("open_cross_domain", True)
|
||||
|
@ -104,10 +153,32 @@ class Config(metaclass=Singleton):
|
|||
self.ext: dict = cfg.get("ext", {})
|
||||
self.cpu_infer = self.ext.get("cpu_infer", 10)
|
||||
|
||||
#file config
|
||||
self.local_store_configs: dict = cfg.get("local_store",{})
|
||||
self.file_upload_dir: str = os.path.join(self.localstore_path,self.local_store_configs.get("file_upload_dir",""))
|
||||
self.assistant_store_dir: str = os.path.join(self.localstore_path,self.local_store_configs.get("assistant_store_dir",""))
|
||||
# file config
|
||||
self.local_store_configs: dict = cfg.get("local_store", {})
|
||||
self.file_upload_dir: str = os.path.join(
|
||||
self.localstore_path, self.local_store_configs.get("file_upload_dir", "")
|
||||
)
|
||||
self.assistant_store_dir: str = os.path.join(
|
||||
self.localstore_path, self.local_store_configs.get("assistant_store_dir", "")
|
||||
)
|
||||
|
||||
#long context config
|
||||
self.long_context_config: dict = cfg.get("long_context",{})
|
||||
# long context config
|
||||
self.long_context_config: dict = cfg.get("long_context", {})
|
||||
self.chunk_size = self.long_context_config.get("chunk_size", 4096)
|
||||
self.max_seq_len = self.long_context_config.get("max_seq_len", 32000)
|
||||
self.block_size = self.long_context_config.get("block_size", 128)
|
||||
self.local_windows_len = self.long_context_config.get("local_windows_len", 4096)
|
||||
self.second_select_num = self.long_context_config.get("second_select_num", 32)
|
||||
self.anchor_type = self.long_context_config.get("anchor_type", "DYNAMIC")
|
||||
self.kv_type = self.long_context_config.get("kv_type", "FP16")
|
||||
self.dense_layer_num = self.long_context_config.get("dense_layer_num", 2)
|
||||
self.anchor_num = self.long_context_config.get("anchor_num", 1)
|
||||
self.preselect_block = self.long_context_config.get("preselect_block", True)
|
||||
self.head_select_mode = self.long_context_config.get("head_select_mode", "SHARED")
|
||||
self.preselect_block_count = self.long_context_config.get("preselect_block_count", 32)
|
||||
self.layer_step = self.long_context_config.get("layer_step", 1)
|
||||
self.token_step = self.long_context_config.get("token_step", 100)
|
||||
|
||||
# local chat
|
||||
self.local_chat_config: dict = cfg.get("local_chat", {})
|
||||
self.prompt_file = self.local_chat_config.get("prompt_file", None)
|
||||
|
|
|
@ -3,11 +3,15 @@ import re
|
|||
from fastapi import FastAPI
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
import uvicorn.logging
|
||||
import argparse
|
||||
import uvicorn
|
||||
import sys
|
||||
|
||||
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
sys.path.insert(0, project_dir)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from ktransformers.server.args import ArgumentParser
|
||||
from ktransformers.server.config.config import Config
|
||||
from ktransformers.server.utils.create_interface import create_interface
|
||||
from ktransformers.server.utils.create_interface import create_interface
|
||||
from ktransformers.server.backend.args import default_args
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
|
@ -44,8 +48,11 @@ def create_app():
|
|||
mount_index_routes(app)
|
||||
return app
|
||||
|
||||
|
||||
def update_web_port(config_file: str):
|
||||
ip_port_pattern = r"(localhost|((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)):[0-9]{1,5}"
|
||||
ip_port_pattern = (
|
||||
r"(localhost|((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)):[0-9]{1,5}"
|
||||
)
|
||||
with open(config_file, "r", encoding="utf-8") as f_cfg:
|
||||
web_config = f_cfg.read()
|
||||
ip_port = "localhost:" + str(Config().server_port)
|
||||
|
@ -70,14 +77,15 @@ def mount_index_routes(app: FastAPI):
|
|||
|
||||
def run_api(app, host, port, **kwargs):
|
||||
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
|
||||
uvicorn.run(app,
|
||||
host=host,
|
||||
port=port,
|
||||
ssl_keyfile=kwargs.get("ssl_keyfile"),
|
||||
ssl_certfile=kwargs.get("ssl_certfile"),
|
||||
)
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
ssl_keyfile=kwargs.get("ssl_keyfile"),
|
||||
ssl_certfile=kwargs.get("ssl_certfile"),
|
||||
)
|
||||
else:
|
||||
uvicorn.run(app, host=host, port=port, log_level='debug')
|
||||
uvicorn.run(app, host=host, port=port, log_level="debug")
|
||||
|
||||
|
||||
def custom_openapi(app):
|
||||
|
@ -90,53 +98,27 @@ def custom_openapi(app):
|
|||
description="We provided chat completion and openai assistant interfaces.",
|
||||
routes=app.routes,
|
||||
)
|
||||
openapi_schema["info"]["x-logo"] = {
|
||||
"url": "https://kvcache.ai/media/icon_1.png"
|
||||
}
|
||||
openapi_schema["info"]["x-logo"] = {"url": "https://kvcache.ai/media/icon_1.png"}
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
|
||||
def main():
|
||||
cfg = Config()
|
||||
parser = argparse.ArgumentParser(prog='kvcache.ai',
|
||||
description='Ktransformers')
|
||||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
||||
parser.add_argument("--port", type=int, default=cfg.server_port)
|
||||
parser.add_argument("--ssl_keyfile", type=str)
|
||||
parser.add_argument("--ssl_certfile", type=str)
|
||||
parser.add_argument("--web", type=bool, default=False)
|
||||
parser.add_argument("--model_name", type=str, default=cfg.model_name)
|
||||
parser.add_argument("--model_path", type=str, default=cfg.model_path)
|
||||
parser.add_argument("--device", type=str, default=cfg.model_device, help="Warning: Abandoning this parameter")
|
||||
parser.add_argument("--gguf_path", type=str, default=cfg.gguf_path)
|
||||
parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
|
||||
parser.add_argument("--cpu_infer", type=int, default=cfg.cpu_infer)
|
||||
parser.add_argument("--type", type=str, default=cfg.backend_type)
|
||||
arg_parser = ArgumentParser(cfg)
|
||||
|
||||
# 初始化消息
|
||||
args = parser.parse_args()
|
||||
cfg.model_name = args.model_name
|
||||
cfg.model_path = args.model_path
|
||||
cfg.model_device = args.device
|
||||
cfg.mount_web = args.web
|
||||
cfg.server_ip = args.host
|
||||
cfg.server_port = args.port
|
||||
cfg.cpu_infer = args.cpu_infer
|
||||
cfg.backend_type = args.type
|
||||
|
||||
default_args.model_dir = args.model_path
|
||||
default_args.device = args.device
|
||||
default_args.gguf_path = args.gguf_path
|
||||
default_args.optimize_config_path = args.optimize_config_path
|
||||
|
||||
args = arg_parser.parse_args()
|
||||
app = create_app()
|
||||
custom_openapi(app)
|
||||
create_interface(config=cfg, default_args=default_args)
|
||||
run_api(app=app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
ssl_keyfile=args.ssl_keyfile,
|
||||
ssl_certfile=args.ssl_certfile,)
|
||||
create_interface(config=cfg, default_args=cfg)
|
||||
run_api(
|
||||
app=app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
ssl_keyfile=args.ssl_keyfile,
|
||||
ssl_certfile=args.ssl_certfile,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
3
ktransformers/website/package-lock.json
generated
3
ktransformers/website/package-lock.json
generated
|
@ -4412,8 +4412,9 @@
|
|||
},
|
||||
"node_modules/@vue/cli": {
|
||||
"version": "5.0.8",
|
||||
"resolved": "https://registry.npmmirror.com/@vue/cli/-/cli-5.0.8.tgz",
|
||||
"resolved": "https://registry.npmjs.org/@vue/cli/-/cli-5.0.8.tgz",
|
||||
"integrity": "sha512-c/QKPdC09bYkW22m/boXkLaiz10z0Z2WHZO7zEeNdfSduqyWINZhKc6hVQU3Vk0NXW7BJAd7zWmcUrC8L9TuAA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@types/ejs": "^3.0.6",
|
||||
"@types/inquirer": "^8.1.3",
|
||||
|
|
|
@ -69,4 +69,8 @@ ktransformers = "ktransformers.server.main:main"
|
|||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["./", ]
|
||||
include = ["ktransformers"]
|
||||
include = ["ktransformers"]
|
||||
[tool.black]
|
||||
line-length = 120
|
||||
preview = true
|
||||
unstable = true
|
Loading…
Add table
Reference in a new issue