mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 06:14:58 +00:00
✨: refactor local_chat and fix message slice bug in server
This commit is contained in:
parent
43fc7f44a6
commit
dd1d8667f3
13 changed files with 549 additions and 405 deletions
4
.flake8
Normal file
4
.flake8
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
[flake8]
|
||||||
|
max-line-length = 120
|
||||||
|
extend-select = B950
|
||||||
|
extend-ignore = E203,E501,E701, B001,B006,B007,B008,B009,B010,B011,B016,B028,B031,B950,E265,E266,E401,E402,E711,E712,E713,E721,E722,E731,F401,F403,F405,F541,F811,F821,F841,W391
|
21
Makefile
Normal file
21
Makefile
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
flake_find:
|
||||||
|
cd ktransformers && flake8 | grep -Eo '[A-Z][0-9]{3}' | sort | uniq| paste -sd ',' -
|
||||||
|
format:
|
||||||
|
@cd ktransformers && black .
|
||||||
|
@black setup.py
|
||||||
|
dev_install:
|
||||||
|
# clear build dirs
|
||||||
|
rm -rf build
|
||||||
|
rm -rf *.egg-info
|
||||||
|
rm -rf ktransformers/ktransformers_ext/build
|
||||||
|
rm -rf ktransformers/ktransformers_ext/cuda/build
|
||||||
|
rm -rf ktransformers/ktransformers_ext/cuda/dist
|
||||||
|
rm -rf ktransformers/ktransformers_ext/cuda/*.egg-info
|
||||||
|
|
||||||
|
# install ktransformers
|
||||||
|
echo "Installing python dependencies from requirements.txt"
|
||||||
|
pip install -r requirements-local_chat.txt
|
||||||
|
|
||||||
|
echo "Installing ktransformers"
|
||||||
|
KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . --no-build-isolation
|
||||||
|
echo "Installation completed successfully"
|
|
@ -7,7 +7,7 @@ log:
|
||||||
|
|
||||||
server:
|
server:
|
||||||
ip: 0.0.0.0
|
ip: 0.0.0.0
|
||||||
port: 12456
|
port: 10002
|
||||||
|
|
||||||
db:
|
db:
|
||||||
type: "sqllite"
|
type: "sqllite"
|
||||||
|
@ -24,10 +24,13 @@ model:
|
||||||
type: ktransformers
|
type: ktransformers
|
||||||
|
|
||||||
name: DeepSeek-Coder-V2-Instruct
|
name: DeepSeek-Coder-V2-Instruct
|
||||||
path: /mnt/data/model/DeepSeek-Coder-V2-Instruct/
|
# path: /mnt/data/model/DeepSeek-Coder-V2-Instruct/
|
||||||
gguf_path: /mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH/
|
path: deepseek-ai/DeepSeek-V2-Lite-Chat
|
||||||
|
# gguf_path: /mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH/
|
||||||
|
gguf_path: ./DeepSeek-V2-Lite-Chat-GGUF
|
||||||
|
|
||||||
device: cuda:0
|
device: cuda:0
|
||||||
|
cache_lens: 8192
|
||||||
|
|
||||||
web:
|
web:
|
||||||
mount: False
|
mount: False
|
||||||
|
@ -50,4 +53,7 @@ long_context:
|
||||||
head_select_mode: SHARED
|
head_select_mode: SHARED
|
||||||
preselect_block_count: 32
|
preselect_block_count: 32
|
||||||
layer_step: 1
|
layer_step: 1
|
||||||
token_step: 100
|
token_step:
|
||||||
|
|
||||||
|
local_chat:
|
||||||
|
prompt_file: "./ktransformers/p.txt"
|
|
@ -1,33 +1,23 @@
|
||||||
"""
|
"""
|
||||||
Description :
|
Description :
|
||||||
Author : Boxin Zhang, Azure-Tang
|
Author : Boxin Zhang, Azure-Tang
|
||||||
Version : 0.1.0
|
Version : 0.1.0
|
||||||
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from ktransformers.server.args import ArgumentParser
|
||||||
|
|
||||||
project_dir = os.path.dirname(os.path.dirname(__file__))
|
project_dir = os.path.dirname(os.path.dirname(__file__))
|
||||||
sys.path.insert(0, project_dir)
|
sys.path.insert(0, project_dir)
|
||||||
import torch
|
|
||||||
import logging
|
|
||||||
from transformers import (
|
|
||||||
AutoTokenizer,
|
|
||||||
AutoConfig,
|
|
||||||
AutoModelForCausalLM,
|
|
||||||
GenerationConfig,
|
|
||||||
TextStreamer,
|
|
||||||
)
|
|
||||||
import json
|
|
||||||
import fire
|
|
||||||
from ktransformers.optimize.optimize import optimize_and_load_gguf
|
|
||||||
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
|
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
|
||||||
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
|
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
|
||||||
from ktransformers.models.modeling_llama import LlamaForCausalLM
|
from ktransformers.models.modeling_llama import LlamaForCausalLM
|
||||||
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
|
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
|
||||||
from ktransformers.util.utils import prefill_and_generate
|
|
||||||
from ktransformers.server.config.config import Config
|
from ktransformers.server.config.config import Config
|
||||||
|
|
||||||
custom_models = {
|
custom_models = {
|
||||||
|
@ -37,9 +27,7 @@ custom_models = {
|
||||||
"MixtralForCausalLM": MixtralForCausalLM,
|
"MixtralForCausalLM": MixtralForCausalLM,
|
||||||
}
|
}
|
||||||
|
|
||||||
ktransformer_rules_dir = (
|
ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
|
||||||
os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
|
|
||||||
)
|
|
||||||
default_optimize_rules = {
|
default_optimize_rules = {
|
||||||
"DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
|
"DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
|
||||||
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
|
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
|
||||||
|
@ -48,75 +36,28 @@ default_optimize_rules = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def local_chat(
|
def local_chat():
|
||||||
model_path: str | None = None,
|
config = Config()
|
||||||
optimize_rule_path: str = None,
|
arg_parser = ArgumentParser(config)
|
||||||
gguf_path: str | None = None,
|
# 初始化消息
|
||||||
max_new_tokens: int = 1000,
|
arg_parser.parse_args()
|
||||||
cpu_infer: int = Config().cpu_infer,
|
if config.backend_type == "transformers":
|
||||||
use_cuda_graph: bool = True,
|
from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
|
||||||
prompt_file : str | None = None,
|
elif config.backend_type == "exllamav2":
|
||||||
mode: str = "normal",
|
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
|
||||||
):
|
elif config.backend_type == "ktransformers":
|
||||||
|
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
|
||||||
|
|
||||||
torch.set_grad_enabled(False)
|
|
||||||
|
|
||||||
Config().cpu_infer = cpu_infer
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
|
||||||
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
|
||||||
if mode == 'long_context':
|
|
||||||
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
|
|
||||||
torch.set_default_dtype(torch.float16)
|
|
||||||
else:
|
else:
|
||||||
torch.set_default_dtype(config.torch_dtype)
|
raise NotImplementedError(f"{config.backend_type} not implemented")
|
||||||
|
interface = BackendInterface(config)
|
||||||
with torch.device("meta"):
|
|
||||||
if config.architectures[0] in custom_models:
|
|
||||||
print("using custom modeling_xxx.py.")
|
|
||||||
if (
|
|
||||||
"Qwen2Moe" in config.architectures[0]
|
|
||||||
): # Qwen2Moe must use flash_attention_2 to avoid overflow.
|
|
||||||
config._attn_implementation = "flash_attention_2"
|
|
||||||
if "Llama" in config.architectures[0]:
|
|
||||||
config._attn_implementation = "eager"
|
|
||||||
if "Mixtral" in config.architectures[0]:
|
|
||||||
config._attn_implementation = "flash_attention_2"
|
|
||||||
|
|
||||||
model = custom_models[config.architectures[0]](config)
|
|
||||||
else:
|
|
||||||
model = AutoModelForCausalLM.from_config(
|
|
||||||
config, trust_remote_code=True, attn_implementation="flash_attention_2"
|
|
||||||
)
|
|
||||||
|
|
||||||
if optimize_rule_path is None:
|
|
||||||
if config.architectures[0] in default_optimize_rules:
|
|
||||||
print("using default_optimize_rule for", config.architectures[0])
|
|
||||||
optimize_rule_path = default_optimize_rules[config.architectures[0]]
|
|
||||||
else:
|
|
||||||
optimize_rule_path = input(
|
|
||||||
"please input the path of your rule file(yaml file containing optimize rules):"
|
|
||||||
)
|
|
||||||
|
|
||||||
if gguf_path is None:
|
|
||||||
gguf_path = input(
|
|
||||||
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
|
|
||||||
)
|
|
||||||
optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)
|
|
||||||
|
|
||||||
model.generation_config = GenerationConfig.from_pretrained(model_path)
|
|
||||||
if model.generation_config.pad_token_id is None:
|
|
||||||
model.generation_config.pad_token_id = model.generation_config.eos_token_id
|
|
||||||
model.eval()
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
|
||||||
system = platform.system()
|
system = platform.system()
|
||||||
if system == "Windows":
|
if system == "Windows":
|
||||||
os.system("cls")
|
os.system("cls")
|
||||||
else:
|
else:
|
||||||
os.system("clear")
|
os.system("clear")
|
||||||
|
# add a history chat content
|
||||||
|
his_content = []
|
||||||
while True:
|
while True:
|
||||||
content = input("Chat: ")
|
content = input("Chat: ")
|
||||||
if content.startswith('"""'): # prefix """
|
if content.startswith('"""'): # prefix """
|
||||||
|
@ -132,28 +73,27 @@ def local_chat(
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
content += line + "\n"
|
content += line + "\n"
|
||||||
|
|
||||||
if content == "":
|
if content == "":
|
||||||
if prompt_file != None:
|
if config.prompt_file == None or config.prompt_file == "":
|
||||||
content = open(prompt_file, "r").read()
|
|
||||||
else:
|
|
||||||
content = "Please write a piece of quicksort code in C++."
|
content = "Please write a piece of quicksort code in C++."
|
||||||
|
else:
|
||||||
|
content = open(config.prompt_file, "r").read()
|
||||||
elif os.path.isfile(content):
|
elif os.path.isfile(content):
|
||||||
content = open(content, "r").read()
|
content = open(content, "r").read()
|
||||||
messages = [{"role": "user", "content": content}]
|
messages = his_content + [{"role": "user", "content": content}]
|
||||||
input_tensor = tokenizer.apply_chat_template(
|
|
||||||
messages, add_generation_prompt=True, return_tensors="pt"
|
async def async_inference(messages):
|
||||||
)
|
generated = ""
|
||||||
if mode == 'long_context':
|
async for token in interface.inference(messages, "local_chat"):
|
||||||
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
|
generated += token
|
||||||
"please change max_seq_len in ~/.ktransformers/config.yaml"
|
return generated
|
||||||
torch.set_default_dtype(
|
|
||||||
torch.bfloat16
|
generated = asyncio.run(async_inference(messages))
|
||||||
) # TODO: Remove this, replace dtype using config
|
his_content += [
|
||||||
generated = prefill_and_generate(
|
{"role": "user", "content": content},
|
||||||
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode
|
{"role": "assitant", "content": generated},
|
||||||
)
|
]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
fire.Fire(local_chat)
|
local_chat()
|
||||||
|
|
|
@ -135,6 +135,8 @@ class OllamaShowResponse(BaseModel):
|
||||||
details: OllamaShowDetial
|
details: OllamaShowDetial
|
||||||
model_info: OllamaModelInfo
|
model_info: OllamaModelInfo
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
protected_namespaces = ()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
113
ktransformers/server/args.py
Normal file
113
ktransformers/server/args.py
Normal file
|
@ -0,0 +1,113 @@
|
||||||
|
import argparse
|
||||||
|
from ktransformers.server.backend.args import ConfigArgs, default_args
|
||||||
|
|
||||||
|
|
||||||
|
class ArgumentParser:
|
||||||
|
def __init__(self, cfg):
|
||||||
|
self.cfg = cfg
|
||||||
|
|
||||||
|
def parse_args(self):
|
||||||
|
parser = argparse.ArgumentParser(prog="kvcache.ai", description="Ktransformers")
|
||||||
|
parser.add_argument("--host", type=str, default=self.cfg.server_ip)
|
||||||
|
parser.add_argument("--port", type=int, default=self.cfg.server_port)
|
||||||
|
parser.add_argument("--ssl_keyfile", type=str)
|
||||||
|
parser.add_argument("--ssl_certfile", type=str)
|
||||||
|
parser.add_argument("--web", type=bool, default=self.cfg.mount_web)
|
||||||
|
parser.add_argument("--model_name", type=str, default=self.cfg.model_name)
|
||||||
|
parser.add_argument("--model_dir", type=str, default=self.cfg.model_dir)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter"
|
||||||
|
)
|
||||||
|
parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
|
||||||
|
parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
|
||||||
|
parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
|
||||||
|
parser.add_argument("--type", type=str, default=self.cfg.backend_type)
|
||||||
|
|
||||||
|
# model configs
|
||||||
|
# parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
|
||||||
|
parser.add_argument("--paged", type=bool, default=self.cfg.paged)
|
||||||
|
parser.add_argument("--total_context", type=int, default=self.cfg.total_context)
|
||||||
|
parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size)
|
||||||
|
parser.add_argument("--max_chunk_size", type=int, default=self.cfg.max_chunk_size)
|
||||||
|
parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens)
|
||||||
|
parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode)
|
||||||
|
parser.add_argument("--healing", type=bool, default=self.cfg.healing)
|
||||||
|
parser.add_argument("--ban_strings", type=list, default=self.cfg.ban_strings, required=False)
|
||||||
|
parser.add_argument("--gpu_split", type=str, default=self.cfg.gpu_split, required=False)
|
||||||
|
parser.add_argument("--length", type=int, default=self.cfg.length, required=False)
|
||||||
|
parser.add_argument("--rope_scale", type=float, default=self.cfg.rope_scale, required=False)
|
||||||
|
parser.add_argument("--rope_alpha", type=float, default=self.cfg.rope_alpha, required=False)
|
||||||
|
parser.add_argument("--no_flash_attn", type=bool, default=self.cfg.no_flash_attn)
|
||||||
|
parser.add_argument("--low_mem", type=bool, default=self.cfg.low_mem)
|
||||||
|
parser.add_argument("--experts_per_token", type=int, default=self.cfg.experts_per_token, required=False)
|
||||||
|
parser.add_argument("--load_q4", type=bool, default=self.cfg.load_q4)
|
||||||
|
parser.add_argument("--fast_safetensors", type=bool, default=self.cfg.fast_safetensors)
|
||||||
|
parser.add_argument("--draft_model_dir", type=str, default=self.cfg.draft_model_dir, required=False)
|
||||||
|
parser.add_argument("--no_draft_scale", type=bool, default=self.cfg.no_draft_scale)
|
||||||
|
parser.add_argument("--modes", type=bool, default=self.cfg.modes)
|
||||||
|
parser.add_argument("--mode", type=str, default=self.cfg.mode)
|
||||||
|
parser.add_argument("--username", type=str, default=self.cfg.username)
|
||||||
|
parser.add_argument("--botname", type=str, default=self.cfg.botname)
|
||||||
|
parser.add_argument("--system_prompt", type=str, default=self.cfg.system_prompt, required=False)
|
||||||
|
parser.add_argument("--temperature", type=float, default=self.cfg.temperature)
|
||||||
|
parser.add_argument("--smoothing_factor", type=float, default=self.cfg.smoothing_factor)
|
||||||
|
parser.add_argument("--dynamic_temperature", type=str, default=self.cfg.dynamic_temperature, required=False)
|
||||||
|
parser.add_argument("--top_k", type=int, default=self.cfg.top_k)
|
||||||
|
parser.add_argument("--top_p", type=float, default=self.cfg.top_p)
|
||||||
|
parser.add_argument("--top_a", type=float, default=self.cfg.top_a)
|
||||||
|
parser.add_argument("--skew", type=float, default=self.cfg.skew)
|
||||||
|
parser.add_argument("--typical", type=float, default=self.cfg.typical)
|
||||||
|
parser.add_argument("--repetition_penalty", type=float, default=self.cfg.repetition_penalty)
|
||||||
|
parser.add_argument("--frequency_penalty", type=float, default=self.cfg.frequency_penalty)
|
||||||
|
parser.add_argument("--presence_penalty", type=float, default=self.cfg.presence_penalty)
|
||||||
|
parser.add_argument("--max_response_tokens", type=int, default=self.cfg.max_response_tokens)
|
||||||
|
parser.add_argument("--response_chunk", type=int, default=self.cfg.response_chunk)
|
||||||
|
parser.add_argument("--no_code_formatting", type=bool, default=self.cfg.no_code_formatting)
|
||||||
|
parser.add_argument("--cache_8bit", type=bool, default=self.cfg.cache_8bit)
|
||||||
|
parser.add_argument("--cache_q4", type=bool, default=self.cfg.cache_q4)
|
||||||
|
parser.add_argument("--ngram_decoding", type=bool, default=self.cfg.ngram_decoding)
|
||||||
|
parser.add_argument("--print_timings", type=bool, default=self.cfg.print_timings)
|
||||||
|
parser.add_argument("--amnesia", type=bool, default=self.cfg.amnesia)
|
||||||
|
parser.add_argument("--batch_size", type=int, default=self.cfg.batch_size)
|
||||||
|
parser.add_argument("--cache_lens", type=int, default=self.cfg.cache_lens)
|
||||||
|
|
||||||
|
# log configs
|
||||||
|
# log level: debug, info, warn, error, crit
|
||||||
|
parser.add_argument("--log_dir", type=str, default=self.cfg.log_dir)
|
||||||
|
parser.add_argument("--log_file", type=str, default=self.cfg.log_file)
|
||||||
|
parser.add_argument("--log_level", type=str, default=self.cfg.log_level)
|
||||||
|
parser.add_argument("--backup_count", type=int, default=self.cfg.backup_count)
|
||||||
|
|
||||||
|
# db configs
|
||||||
|
parser.add_argument("--db_type", type=str, default=self.cfg.db_type)
|
||||||
|
parser.add_argument("--db_host", type=str, default=self.cfg.db_host)
|
||||||
|
parser.add_argument("--db_port", type=str, default=self.cfg.db_port)
|
||||||
|
parser.add_argument("--db_name", type=str, default=self.cfg.db_name)
|
||||||
|
parser.add_argument("--db_pool_size", type=int, default=self.cfg.db_pool_size)
|
||||||
|
parser.add_argument("--db_database", type=str, default=self.cfg.db_database)
|
||||||
|
|
||||||
|
# user config
|
||||||
|
parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
|
||||||
|
parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
|
||||||
|
|
||||||
|
# web config
|
||||||
|
parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
|
||||||
|
|
||||||
|
# file config
|
||||||
|
parser.add_argument("--file_upload_dir", type=str, default=self.cfg.file_upload_dir)
|
||||||
|
parser.add_argument("--assistant_store_dir", type=str, default=self.cfg.assistant_store_dir)
|
||||||
|
# local chat
|
||||||
|
parser.add_argument("--prompt_file", type=str, default=self.cfg.prompt_file)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
# set config from args
|
||||||
|
for key, value in vars(args).items():
|
||||||
|
if value is not None and hasattr(self.cfg, key):
|
||||||
|
setattr(self.cfg, key, value)
|
||||||
|
# we add the name not match args individually
|
||||||
|
self.cfg.model_device = args.device
|
||||||
|
self.cfg.mount_web = args.web
|
||||||
|
self.cfg.server_ip = args.host
|
||||||
|
self.cfg.server_port = args.port
|
||||||
|
self.cfg.backend_type = args.type
|
||||||
|
return args
|
|
@ -1,97 +1,89 @@
|
||||||
from pydantic import BaseModel,Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from ktransformers.server.config.config import Config
|
from ktransformers.server.config.config import Config
|
||||||
|
|
||||||
|
|
||||||
class ConfigArgs(BaseModel):
|
class ConfigArgs(BaseModel):
|
||||||
model_name : Optional[str] = Field(..., description="Model name")
|
model_name: Optional[str] = Field(..., description="Model name")
|
||||||
model_dir: Optional[str] = Field(..., description="Path to model directory")
|
model_dir: Optional[str] = Field(..., description="Path to model directory")
|
||||||
optimize_config_path: Optional[str] = Field('./KTransformers/optimize_config/DeepSeek-V2-Chat.json', description="Path of your optimize config json file")
|
optimize_config_path: Optional[str] = Field(None, description="Path of your optimize config yml file")
|
||||||
gguf_path: Optional[str] = Field('/models/DeepSeek-Coder-V2-Instruct-GGUF/DeepSeek-Coder-V2-Instruct-Q4_K_M.gguf', description="Path of your gguf file")
|
gguf_path: Optional[str] = Field(None, description="Path of your gguf file")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
protected_namespaces = ()
|
protected_namespaces = ()
|
||||||
|
|
||||||
paged : bool = Field(True,description='Wether to use paged attention kv cache')
|
paged: bool = Field(None, description="Whether to use paged attention kv cache")
|
||||||
|
total_context: int = Field(
|
||||||
# total_context: int = Field(16384, description="Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the total to distribute dynamically over however many jobs are active at once")
|
None,
|
||||||
total_context: int = Field(2**18, description="Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the total to distribute dynamically over however many jobs are active at once")
|
description=(
|
||||||
max_batch_size: int = Field(20 if paged else 1, description="Max number of batches to run at once, assuming the sequences will fit within total_context")
|
"Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the"
|
||||||
max_chunk_size: int = Field(2048, description="Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new job is started, but at the expense of overall prompt ingestion speed")
|
" total to distribute dynamically over however many jobs are active at once"
|
||||||
max_new_tokens: int = Field(500, description="Max new tokens per completion. For this example applies to all jobs")
|
),
|
||||||
json_mode: bool = Field(False, description="Use LMFE to constrain the output to JSON format. See schema and details below")
|
)
|
||||||
healing: bool = Field(False, description="Demonstrate token healing")
|
max_batch_size: int = Field(
|
||||||
|
None, description="Max number of batches to run at once, assuming the sequences will fit within total_context"
|
||||||
|
)
|
||||||
|
max_chunk_size: int = Field(
|
||||||
|
None,
|
||||||
|
description=(
|
||||||
|
"Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new"
|
||||||
|
" job is started, but at the expense of overall prompt ingestion speed"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
max_new_tokens: int = Field(None, description="Max new tokens per completion. For this example applies to all jobs")
|
||||||
|
json_mode: bool = Field(
|
||||||
|
None, description="Use LMFE to constrain the output to JSON format. See schema and details below"
|
||||||
|
)
|
||||||
|
healing: bool = Field(None, description="Demonstrate token healing")
|
||||||
ban_strings: Optional[list] = Field(None, description="Ban some phrases maybe")
|
ban_strings: Optional[list] = Field(None, description="Ban some phrases maybe")
|
||||||
|
|
||||||
gpu_split: Optional[str] = Field(None, description='"auto", or VRAM allocation per GPU in GB')
|
gpu_split: Optional[str] = Field(None, description='"auto", or VRAM allocation per GPU in GB')
|
||||||
length: Optional[int] = Field(None, description="Maximum sequence length")
|
length: Optional[int] = Field(None, description="Maximum sequence length")
|
||||||
rope_scale: Optional[float] = Field(None, description="RoPE scaling factor")
|
rope_scale: Optional[float] = Field(None, description="RoPE scaling factor")
|
||||||
rope_alpha: Optional[float] = Field(None, description="RoPE alpha value (NTK)")
|
rope_alpha: Optional[float] = Field(None, description="RoPE alpha value (NTK)")
|
||||||
no_flash_attn: bool = Field(False, description="Disable Flash Attention")
|
no_flash_attn: bool = Field(None, description="Disable Flash Attention")
|
||||||
low_mem: bool = Field(
|
low_mem: bool = Field(None, description="Enable VRAM optimizations, potentially trading off speed")
|
||||||
False,
|
|
||||||
description="Enable VRAM optimizations, potentially trading off speed",
|
|
||||||
)
|
|
||||||
experts_per_token: Optional[int] = Field(
|
experts_per_token: Optional[int] = Field(
|
||||||
None,
|
None, description="Override MoE model's default number of experts per token"
|
||||||
description="Override MoE model's default number of experts per token",
|
|
||||||
)
|
|
||||||
load_q4: bool = Field(False, description="Load weights in Q4 mode")
|
|
||||||
fast_safetensors: bool = Field(
|
|
||||||
False,
|
|
||||||
description="Optimized safetensors loading with direct I/O (experimental!)",
|
|
||||||
)
|
)
|
||||||
|
load_q4: bool = Field(None, description="Load weights in Q4 mode")
|
||||||
|
fast_safetensors: bool = Field(None, description="Optimized safetensors loading with direct I/O (experimental!)")
|
||||||
draft_model_dir: Optional[str] = Field(None, description="Path to draft model directory")
|
draft_model_dir: Optional[str] = Field(None, description="Path to draft model directory")
|
||||||
no_draft_scale: bool = Field(
|
no_draft_scale: bool = Field(
|
||||||
False,
|
None,
|
||||||
description="If draft model has smaller context size than model, don't apply alpha (NTK) scaling to extend it",
|
description="If draft model has smaller context size than model, don't apply alpha (NTK) scaling to extend it",
|
||||||
)
|
)
|
||||||
modes: bool = Field(False, description="List available modes and exit.")
|
modes: bool = Field(None, description="List available modes and exit.")
|
||||||
mode: str = Field(
|
mode: str = Field(None, description="Chat mode. Use llama for Llama 1/2 chat finetunes.")
|
||||||
"llama",
|
username: str = Field(None, description="Username when using raw chat mode")
|
||||||
description="Chat mode. Use llama for Llama 1/2 chat finetunes.",
|
botname: str = Field(None, description="Bot name when using raw chat mode")
|
||||||
)
|
|
||||||
username: str = Field("User", description="Username when using raw chat mode")
|
|
||||||
botname: str = Field("Chatbort", description="Bot name when using raw chat mode")
|
|
||||||
system_prompt: Optional[str] = Field(None, description="Use custom system prompt")
|
system_prompt: Optional[str] = Field(None, description="Use custom system prompt")
|
||||||
temperature: float = Field(0.95, description="Sampler temperature, default = 0.95 (1 to disable)")
|
temperature: float = Field(None, description="Sampler temperature, default = 0.95 (1 to disable)")
|
||||||
smoothing_factor: float = Field(0.0, description="Smoothing Factor, default = 0.0 (0 to disable)")
|
smoothing_factor: float = Field(None, description="Smoothing Factor, default = 0.0 (0 to disable)")
|
||||||
dynamic_temperature: Optional[str] = Field(
|
dynamic_temperature: Optional[str] = Field(
|
||||||
None,
|
None, description="Dynamic temperature min,max,exponent, e.g. -dyntemp 0.2,1.5,1"
|
||||||
description="Dynamic temperature min,max,exponent, e.g. -dyntemp 0.2,1.5,1",
|
|
||||||
)
|
)
|
||||||
top_k: int = Field(50, description="Sampler top-K, default = 50 (0 to disable)")
|
top_k: int = Field(None, description="Sampler top-K, default = 50 (0 to disable)")
|
||||||
top_p: float = Field(0.8, description="Sampler top-P, default = 0.8 (0 to disable)")
|
top_p: float = Field(None, description="Sampler top-P, default = 0.8 (0 to disable)")
|
||||||
top_a: float = Field(0.0, description="Sampler top-A, default = 0.0 (0 to disable)")
|
top_a: float = Field(None, description="Sampler top-A, default = 0.0 (0 to disable)")
|
||||||
skew: float = Field(0.0, description="Skew sampling, default = 0.0 (0 to disable)")
|
skew: float = Field(None, description="Skew sampling, default = 0.0 (0 to disable)")
|
||||||
typical: float = Field(
|
typical: float = Field(None, description="Sampler typical threshold, default = 0.0 (0 to disable)")
|
||||||
0.0,
|
repetition_penalty: float = Field(None, description="Sampler repetition penalty, default = 1.01 (1 to disable)")
|
||||||
description="Sampler typical threshold, default = 0.0 (0 to disable)",
|
frequency_penalty: float = Field(None, description="Sampler frequency penalty, default = 0.0 (0 to disable)")
|
||||||
)
|
presence_penalty: float = Field(None, description="Sampler presence penalty, default = 0.0 (0 to disable)")
|
||||||
repetition_penalty: float = Field(
|
max_response_tokens: int = Field(None, description="Max tokens per response, default = 1000")
|
||||||
1.01,
|
response_chunk: int = Field(None, description="Space to reserve in context for reply, default = 250")
|
||||||
description="Sampler repetition penalty, default = 1.01 (1 to disable)",
|
no_code_formatting: bool = Field(None, description="Disable code formatting/syntax highlighting")
|
||||||
)
|
cache_8bit: bool = Field(None, description="Use 8-bit (FP8) cache")
|
||||||
frequency_penalty: float = Field(
|
cache_q4: bool = Field(None, description="Use Q4 cache")
|
||||||
0.0,
|
ngram_decoding: bool = Field(None, description="Use n-gram speculative decoding")
|
||||||
description="Sampler frequency penalty, default = 0.0 (0 to disable)",
|
print_timings: bool = Field(None, description="Output timings after each prompt")
|
||||||
)
|
amnesia: bool = Field(None, description="Forget context after every response")
|
||||||
presence_penalty: float = Field(
|
|
||||||
0.0,
|
|
||||||
description="Sampler presence penalty, default = 0.0 (0 to disable)",
|
|
||||||
)
|
|
||||||
max_response_tokens: int = Field(300, description="Max tokens per response, default = 1000")
|
|
||||||
response_chunk: int = Field(250, description="Space to reserve in context for reply, default = 250")
|
|
||||||
no_code_formatting: bool = Field(False, description="Disable code formatting/syntax highlighting")
|
|
||||||
cache_8bit: bool = Field(False, description="Use 8-bit (FP8) cache")
|
|
||||||
cache_q4: bool = Field(True, description="Use Q4 cache")
|
|
||||||
ngram_decoding: bool = Field(False, description="Use n-gram speculative decoding")
|
|
||||||
print_timings: bool = Field(False, description="Output timings after each prompt")
|
|
||||||
amnesia: bool = Field(False, description="Forget context after every response")
|
|
||||||
|
|
||||||
# for transformers
|
# for transformers
|
||||||
batch_size :int = Field(1,description="Batch Size")
|
batch_size: int = Field(None, description="Batch Size")
|
||||||
cache_lens:int = Field(4096, description="Cache lens for transformers static cache")
|
cache_lens: int = Field(None, description="Cache lens for transformers static cache")
|
||||||
device:str = Field('cuda:2',description="device")
|
device: str = Field(None, description="device")
|
||||||
|
|
||||||
|
|
||||||
cfg = Config()
|
cfg = Config()
|
||||||
default_args = ConfigArgs(model_name=cfg.model_name,model_dir=cfg.model_path)
|
default_args = cfg
|
||||||
|
|
|
@ -1,6 +1,12 @@
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoTokenizer, AutoConfig, GenerationConfig
|
from transformers import AutoTokenizer, AutoConfig, GenerationConfig
|
||||||
from ktransformers.server.backend.interfaces.transformers import TransformersInterface,ConfigArgs, TransformersThreadContext,default_args,TextStreamer
|
from ktransformers.server.backend.interfaces.transformers import (
|
||||||
|
TransformersInterface,
|
||||||
|
ConfigArgs,
|
||||||
|
TransformersThreadContext,
|
||||||
|
default_args,
|
||||||
|
TextStreamer,
|
||||||
|
)
|
||||||
from ktransformers.server.config.log import logger
|
from ktransformers.server.config.log import logger
|
||||||
from ktransformers.optimize.optimize import optimize_and_load_gguf
|
from ktransformers.optimize.optimize import optimize_and_load_gguf
|
||||||
from ktransformers.models.custom_cache import StaticCache
|
from ktransformers.models.custom_cache import StaticCache
|
||||||
|
@ -14,71 +20,85 @@ class KTransformersThreadContext(TransformersThreadContext):
|
||||||
|
|
||||||
|
|
||||||
class KTransformersInterface(TransformersInterface):
|
class KTransformersInterface(TransformersInterface):
|
||||||
def __init__(self,args:ConfigArgs= default_args):
|
def __init__(self, args: ConfigArgs = default_args):
|
||||||
self.args = args
|
self.args = args
|
||||||
torch.set_default_dtype(torch.bfloat16)
|
torch.set_default_dtype(torch.bfloat16)
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir,device = args.device)
|
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device)
|
||||||
config=AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
|
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
|
||||||
if config.architectures[0] == "Qwen2MoeForCausalLM":
|
if config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||||
config._attn_implementation="flash_attention_2"
|
config._attn_implementation = "flash_attention_2"
|
||||||
|
|
||||||
with torch.device("meta"):
|
with torch.device("meta"):
|
||||||
self.model=custom_models[config.architectures[0]](config)
|
self.model = custom_models[config.architectures[0]](config)
|
||||||
if default_args.optimize_config_path is None:
|
if default_args.optimize_config_path is None:
|
||||||
optimize_rule_path = default_optimize_rules[config.architectures[0]]
|
optimize_rule_path = default_optimize_rules[config.architectures[0]]
|
||||||
else:
|
else:
|
||||||
optimize_rule_path = args.optimize_config_path
|
optimize_rule_path = args.optimize_config_path
|
||||||
|
|
||||||
# print(optimize_config)
|
# print(optimize_config)
|
||||||
|
|
||||||
gguf_path = args.gguf_path
|
gguf_path = args.gguf_path
|
||||||
if gguf_path is None:
|
if gguf_path is None:
|
||||||
gguf_path = input(
|
gguf_path = input(
|
||||||
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
|
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
|
||||||
|
" belong to current model):"
|
||||||
)
|
)
|
||||||
optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
|
optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
|
||||||
|
|
||||||
|
|
||||||
device_map = self.model.gguf_loader.tensor_device_map
|
device_map = self.model.gguf_loader.tensor_device_map
|
||||||
logger.info(f'{args.model_name} loaded from {args.model_dir} to {device_map}')
|
logger.info(f"{args.model_name} loaded from {args.model_dir} to {device_map}")
|
||||||
self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=device_map, dtype=self.model.dtype)
|
self.cache = StaticCache(
|
||||||
logger.info(f'StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}')
|
config=self.model.config,
|
||||||
|
max_batch_size=args.batch_size,
|
||||||
|
max_cache_len=args.cache_lens,
|
||||||
|
device=device_map,
|
||||||
|
dtype=self.model.dtype,
|
||||||
|
)
|
||||||
|
logger.info(f"StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}")
|
||||||
self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
|
self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
|
||||||
if self.model.generation_config.pad_token_id is None:
|
if self.model.generation_config.pad_token_id is None:
|
||||||
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
|
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
|
||||||
self.streamer = TextStreamer(self.tokenizer)
|
self.streamer = TextStreamer(self.tokenizer)
|
||||||
|
|
||||||
def decode_one_tokens(self):
|
def decode_one_tokens(self):
|
||||||
if not hasattr(self, "cuda_graph_runner"):
|
if not hasattr(self, "cuda_graph_runner"):
|
||||||
device_map = self.model.gguf_loader.tensor_device_map
|
device_map = self.model.gguf_loader.tensor_device_map
|
||||||
torch_device = get_device('blk.0.self_attn', device_map)
|
torch_device = get_device("blk.0.self_attn", device_map)
|
||||||
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
|
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
|
||||||
self.cuda_graph_runner = CUDAGraphRunner()
|
self.cuda_graph_runner = CUDAGraphRunner()
|
||||||
self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, main_device=torch_device, return_dict=False, use_cache=True)
|
self.cuda_graph_runner.capture(
|
||||||
|
self.model,
|
||||||
|
self.current_ids,
|
||||||
|
self.active_cache_position.unsqueeze(0),
|
||||||
|
self.active_cache_position,
|
||||||
|
self.cache,
|
||||||
|
main_device=torch_device,
|
||||||
|
return_dict=False,
|
||||||
|
use_cache=True,
|
||||||
|
)
|
||||||
|
|
||||||
if hasattr(self, "cuda_graph_runner"):
|
if hasattr(self, "cuda_graph_runner"):
|
||||||
logits = self.cuda_graph_runner(self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position)
|
logits = self.cuda_graph_runner(
|
||||||
|
self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
|
||||||
|
)
|
||||||
self.cache.change_seq_length(1)
|
self.cache.change_seq_length(1)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
logits = logits[0,-1,:]
|
logits = logits[0, -1, :]
|
||||||
return self.logits_to_token(logits)
|
return self.logits_to_token(logits)
|
||||||
|
|
||||||
if self.use_static_cache:
|
if self.use_static_cache:
|
||||||
mask = torch.ones((1,self.seq_length)).to(torch_device)
|
mask = torch.ones((1, self.seq_length)).to(torch_device)
|
||||||
logits = self.model(
|
logits = self.model(
|
||||||
self.current_ids,
|
self.current_ids,
|
||||||
cache_position=self.active_cache_position,
|
cache_position=self.active_cache_position,
|
||||||
past_key_values=self.cache,
|
past_key_values=self.cache,
|
||||||
attention_mask=mask,
|
attention_mask=mask,
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
use_cache=True
|
use_cache=True,
|
||||||
)[0]
|
)[0]
|
||||||
else:
|
else:
|
||||||
logits = self.model(
|
logits = self.model(self.current_ids, return_dict=False)[0]
|
||||||
self.current_ids,
|
logits = logits[0, -1, :]
|
||||||
return_dict=False
|
|
||||||
)[0]
|
|
||||||
logits = logits[0,-1,:]
|
|
||||||
|
|
||||||
return self.logits_to_token(logits)
|
return self.logits_to_token(logits)
|
||||||
|
|
|
@ -1,14 +1,22 @@
|
||||||
from typing import Any, List, Optional, Set
|
from typing import Any, List, Optional, Set
|
||||||
from transformers import LlamaTokenizer,AutoTokenizer, AutoConfig, LlamaForCausalLM,GenerationConfig, StaticCache, AutoModelForCausalLM,BitsAndBytesConfig
|
from transformers import (
|
||||||
|
LlamaTokenizer,
|
||||||
|
AutoTokenizer,
|
||||||
|
AutoConfig,
|
||||||
|
LlamaForCausalLM,
|
||||||
|
GenerationConfig,
|
||||||
|
StaticCache,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
BitsAndBytesConfig,
|
||||||
|
)
|
||||||
|
|
||||||
from ktransformers.server.schemas.base import ObjectID
|
from ktransformers.server.schemas.base import ObjectID
|
||||||
from ktransformers.server.utils.multi_timer import Profiler
|
from ktransformers.server.utils.multi_timer import Profiler
|
||||||
import torch
|
import torch
|
||||||
import sys, os
|
import sys, os
|
||||||
from ..base import ThreadContext,BackendInterfaceBase
|
from ..base import ThreadContext, BackendInterfaceBase
|
||||||
from ktransformers.server.config.log import logger
|
from ktransformers.server.config.log import logger
|
||||||
from ..args import ConfigArgs,default_args
|
from ..args import ConfigArgs, default_args
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
|
# This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
|
||||||
|
@ -28,21 +36,20 @@ class TextStreamer:
|
||||||
self.token_cache = []
|
self.token_cache = []
|
||||||
self.print_len = 0
|
self.print_len = 0
|
||||||
|
|
||||||
def put(self, value)->Optional[str]:
|
def put(self, value) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
|
Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
|
||||||
"""
|
"""
|
||||||
if not isinstance(value,int):
|
if not isinstance(value, int):
|
||||||
raise ValueError("TextStreamer only supports batch size 1, and int type input")
|
raise ValueError("TextStreamer only supports batch size 1, and int type input")
|
||||||
|
|
||||||
|
|
||||||
if self.skip_prompt and self.next_tokens_are_prompt:
|
if self.skip_prompt and self.next_tokens_are_prompt:
|
||||||
self.next_tokens_are_prompt = False
|
self.next_tokens_are_prompt = False
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Add the new token to the cache and decodes the entire thing.
|
# Add the new token to the cache and decodes the entire thing.
|
||||||
self.token_cache.append(value)
|
self.token_cache.append(value)
|
||||||
text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True,**self.decode_kwargs)
|
text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)
|
||||||
|
|
||||||
# After the symbol for a new line, we flush the cache.
|
# After the symbol for a new line, we flush the cache.
|
||||||
if text.endswith("\n"):
|
if text.endswith("\n"):
|
||||||
|
@ -59,7 +66,7 @@ class TextStreamer:
|
||||||
self.print_len += len(printable_text)
|
self.print_len += len(printable_text)
|
||||||
return printable_text
|
return printable_text
|
||||||
|
|
||||||
def end(self)->Optional[str]:
|
def end(self) -> Optional[str]:
|
||||||
"""Flushes any remaining cache and prints a newline to stdout."""
|
"""Flushes any remaining cache and prints a newline to stdout."""
|
||||||
# Flush the cache, if it exists
|
# Flush the cache, if it exists
|
||||||
if len(self.token_cache) > 0:
|
if len(self.token_cache) > 0:
|
||||||
|
@ -71,7 +78,7 @@ class TextStreamer:
|
||||||
|
|
||||||
self.next_tokens_are_prompt = True
|
self.next_tokens_are_prompt = True
|
||||||
return printable_text
|
return printable_text
|
||||||
|
|
||||||
def _is_chinese_char(self, cp):
|
def _is_chinese_char(self, cp):
|
||||||
"""Checks whether CP is the codepoint of a CJK character."""
|
"""Checks whether CP is the codepoint of a CJK character."""
|
||||||
# This defines a "chinese character" as anything in the CJK Unicode block:
|
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||||
|
@ -97,101 +104,91 @@ class TextStreamer:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class TransformersThreadContext(ThreadContext):
|
class TransformersThreadContext(ThreadContext):
|
||||||
def get_local_messages(self):
|
def get_local_messages(self):
|
||||||
local_messages = []
|
local_messages = []
|
||||||
for m in self.messages:
|
for m in self.messages:
|
||||||
local_messages.append(
|
local_messages.append({"role": m.role.value, "content": m.get_text_content()})
|
||||||
{'role':m.role.value,
|
|
||||||
'content':m.get_text_content()}
|
|
||||||
)
|
|
||||||
|
|
||||||
return local_messages
|
return local_messages
|
||||||
|
|
||||||
|
|
||||||
class TransformersInterface(BackendInterfaceBase):
|
class TransformersInterface(BackendInterfaceBase):
|
||||||
use_static_cache : bool = True
|
use_static_cache: bool = True
|
||||||
|
|
||||||
|
|
||||||
model: Any
|
model: Any
|
||||||
tokenizer: AutoTokenizer
|
tokenizer: AutoTokenizer
|
||||||
|
|
||||||
cache: StaticCache
|
cache: StaticCache
|
||||||
generated_ids:torch.Tensor
|
generated_ids: torch.Tensor
|
||||||
seq_length:int
|
seq_length: int
|
||||||
|
|
||||||
streamer: TextStreamer
|
streamer: TextStreamer
|
||||||
|
|
||||||
# thread_related
|
# thread_related
|
||||||
last_request_id: Optional[str] = None
|
last_request_id: Optional[str] = None
|
||||||
ever_generated_ids: Set[int] = set()
|
ever_generated_ids: Set[int] = set()
|
||||||
|
|
||||||
|
def __init__(self, args: ConfigArgs = default_args):
|
||||||
|
|
||||||
def __init__(self, args:ConfigArgs = default_args):
|
|
||||||
self.args = args
|
self.args = args
|
||||||
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device,use_safetensors=True)
|
|
||||||
logger.info(f'{args.model_name} loaded from {args.model_dir} to {args.device}')
|
|
||||||
|
|
||||||
self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=args.device, dtype=self.model.dtype)
|
|
||||||
logger.info(f'StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}')
|
|
||||||
|
|
||||||
self.streamer = TextStreamer(self.tokenizer)
|
|
||||||
|
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device, use_safetensors=True)
|
||||||
|
logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}")
|
||||||
|
|
||||||
|
self.cache = StaticCache(
|
||||||
|
config=self.model.config,
|
||||||
|
max_batch_size=args.batch_size,
|
||||||
|
max_cache_len=args.cache_lens,
|
||||||
|
device=args.device,
|
||||||
|
dtype=self.model.dtype,
|
||||||
|
)
|
||||||
|
logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}")
|
||||||
|
|
||||||
|
self.streamer = TextStreamer(self.tokenizer)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_ids(self):
|
def current_ids(self):
|
||||||
return self.generated_ids[:,self.seq_length-1].unsqueeze(1)
|
return self.generated_ids[:, self.seq_length - 1].unsqueeze(1)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def active_cache_position(self):
|
def active_cache_position(self):
|
||||||
return torch.tensor([self.seq_length-1], device=self.args.device)
|
return torch.tensor([self.seq_length - 1], device=self.args.device)
|
||||||
|
|
||||||
|
def tokenize_prompt(self, prompt: str):
|
||||||
def tokenize_prompt(self,prompt:str):
|
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device)
|
||||||
input_ids = self.tokenizer.encode(prompt,return_tensors='pt').to(self.args.device)
|
|
||||||
return input_ids
|
return input_ids
|
||||||
|
|
||||||
def format_and_tokenize_input_ids(self,thread_id:ObjectID,messages:List):
|
def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
|
||||||
for m in messages:
|
for m in messages:
|
||||||
if m['role']=='system':
|
if m["role"] == "system":
|
||||||
logger.warn(f'change {m["role"]} to user')
|
logger.warning(f'change {m["role"]} to user')
|
||||||
m['role'] = 'user'
|
m["role"] = "user"
|
||||||
|
|
||||||
new_messages = [messages[0]]
|
new_messages = [messages[0]]
|
||||||
for m in messages[1:]:
|
for m in messages[1:]:
|
||||||
if m['role'] == 'user' and new_messages[-1]['role']=='user':
|
if m["role"] == "user" and new_messages[-1]["role"] == "user":
|
||||||
logger.warn('merge two adjacent user messages')
|
logger.warning("merge two adjacent user messages")
|
||||||
new_messages[-1]['content']+=m['content']
|
new_messages[-1]["content"] += m["content"]
|
||||||
else:
|
else:
|
||||||
new_messages.append(m)
|
new_messages.append(m)
|
||||||
|
|
||||||
|
|
||||||
input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device)
|
|
||||||
|
|
||||||
if (self.last_request_id is not None) and self.last_request_id == thread_id:
|
if (self.last_request_id is not None) and self.last_request_id == thread_id:
|
||||||
x = self.generated_ids[:,:self.seq_length]
|
input_ids = self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt").to(self.args.device)
|
||||||
y = input_ids[:,:self.seq_length]
|
else:
|
||||||
# We can only hope that the input_ids are the same
|
input_ids = self.tokenizer.apply_chat_template(
|
||||||
unequal_mask = torch.ne(x,y)
|
new_messages, return_tensors="pt", add_generation_prompt=True
|
||||||
unequal_positions = torch.nonzero(unequal_mask)
|
).to(self.args.device)
|
||||||
num_unequal_elements = unequal_mask.sum().item()
|
logger.debug(f"get input ids of shape {input_ids.shape}")
|
||||||
logger.warn(f'num_unequal_elements: {num_unequal_elements}')
|
|
||||||
|
|
||||||
input_ids = input_ids[:,self.seq_length:]
|
|
||||||
logger.debug(f'get input ids of shape {input_ids.shape}')
|
|
||||||
return input_ids
|
return input_ids
|
||||||
|
|
||||||
def append_new_tokens(self,new_tokens:int)->Optional[str]:
|
def append_new_tokens(self, new_tokens: int) -> Optional[str]:
|
||||||
self.generated_ids[0,self.seq_length] = new_tokens
|
self.generated_ids[0, self.seq_length] = new_tokens
|
||||||
self.seq_length+=1
|
self.seq_length += 1
|
||||||
return self.streamer.put(new_tokens)
|
return self.streamer.put(new_tokens)
|
||||||
|
|
||||||
def logits_to_token(self,logits:torch.Tensor):
|
def logits_to_token(self, logits: torch.Tensor):
|
||||||
logits = logits/self.args.temperature
|
logits = logits / self.args.temperature
|
||||||
|
|
||||||
for token_idx in self.ever_generated_ids:
|
for token_idx in self.ever_generated_ids:
|
||||||
if logits[token_idx] < 0:
|
if logits[token_idx] < 0:
|
||||||
|
@ -200,7 +197,7 @@ class TransformersInterface(BackendInterfaceBase):
|
||||||
logits[token_idx] /= self.args.repetition_penalty
|
logits[token_idx] /= self.args.repetition_penalty
|
||||||
|
|
||||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||||
|
|
||||||
sample = True
|
sample = True
|
||||||
if sample:
|
if sample:
|
||||||
last = torch.multinomial(probs, num_samples=1)
|
last = torch.multinomial(probs, num_samples=1)
|
||||||
|
@ -211,127 +208,124 @@ class TransformersInterface(BackendInterfaceBase):
|
||||||
self.ever_generated_ids.add(last)
|
self.ever_generated_ids.add(last)
|
||||||
return last
|
return last
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def decode_one_tokens(self):
|
def decode_one_tokens(self):
|
||||||
if self.use_static_cache:
|
if self.use_static_cache:
|
||||||
mask = torch.ones((1,self.seq_length)).to(self.args.device)
|
mask = torch.ones((1, self.seq_length)).to(self.args.device)
|
||||||
logits = self.model(
|
logits = self.model(
|
||||||
self.current_ids,
|
self.current_ids,
|
||||||
cache_position=self.active_cache_position,
|
cache_position=self.active_cache_position,
|
||||||
past_key_values=self.cache,
|
past_key_values=self.cache,
|
||||||
attention_mask=mask,
|
attention_mask=mask,
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
use_cache=True
|
use_cache=True,
|
||||||
)[0]
|
)[0]
|
||||||
else:
|
else:
|
||||||
logits = self.model(
|
logits = self.model(self.current_ids, return_dict=False)[0]
|
||||||
self.current_ids,
|
logits = logits[0, -1, :]
|
||||||
return_dict=False
|
|
||||||
)[0]
|
|
||||||
logits = logits[0,-1,:]
|
|
||||||
|
|
||||||
return self.logits_to_token(logits)
|
return self.logits_to_token(logits)
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def prefill(self,input_ids:torch.Tensor,is_new:bool):
|
def prefill(self, input_ids: torch.Tensor, is_new: bool):
|
||||||
input_ids_length = input_ids.shape[-1]
|
input_ids_length = input_ids.shape[-1]
|
||||||
self.profiler.set_counter('prefill',input_ids_length)
|
self.profiler.set_counter("prefill", input_ids_length)
|
||||||
logger.debug(f'input_ids: {input_ids.shape}')
|
logger.debug(f"input_ids: {input_ids.shape}")
|
||||||
|
|
||||||
|
|
||||||
if is_new:
|
if is_new:
|
||||||
self.cache.reset()
|
self.cache.reset()
|
||||||
self.ever_generated_ids.clear()
|
self.ever_generated_ids.clear()
|
||||||
former_seq_length = 0
|
former_seq_length = 0
|
||||||
self.seq_length = input_ids_length
|
self.seq_length = input_ids_length
|
||||||
self.generated_ids = torch.zeros(
|
self.generated_ids = torch.zeros(
|
||||||
self.args.batch_size, self.seq_length + self.args.max_new_tokens + 1, dtype=torch.int, device=self.args.device
|
self.args.batch_size,
|
||||||
)
|
self.seq_length + self.args.max_new_tokens + 1,
|
||||||
|
dtype=torch.int,
|
||||||
|
device=self.args.device,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug(f'generate_ids: {self.generated_ids.shape}')
|
logger.debug(f"generate_ids: {self.generated_ids.shape}")
|
||||||
former_seq_length = self.seq_length
|
former_seq_length = self.seq_length
|
||||||
self.seq_length += input_ids_length
|
self.seq_length += input_ids_length
|
||||||
expected_length = self.seq_length + self.args.max_new_tokens+1
|
expected_length = self.seq_length + self.args.max_new_tokens + 1
|
||||||
delta_length = expected_length - self.generated_ids.shape[-1]
|
delta_length = expected_length - self.generated_ids.shape[-1]
|
||||||
if delta_length>0:
|
if delta_length > 0:
|
||||||
new_generate_ids = torch.zeros(
|
new_generate_ids = torch.zeros(
|
||||||
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
|
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
|
||||||
)
|
)
|
||||||
self.generated_ids = torch.cat([self.generated_ids,new_generate_ids],dim=-1)
|
self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
|
||||||
logger.debug(f'cache position: {former_seq_length} to {self.seq_length}')
|
logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
|
||||||
cache_position = torch.arange(former_seq_length,self.seq_length, device=self.args.device)
|
cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
|
||||||
self.generated_ids[:,cache_position] = input_ids.to(self.args.device).to(torch.int)
|
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
|
||||||
|
|
||||||
mask = torch.ones((1,self.seq_length)).to(self.args.device)
|
mask = torch.ones((1, self.seq_length)).to(self.args.device)
|
||||||
device = input_ids.device
|
device = input_ids.device
|
||||||
if not(type(self) is TransformersInterface):
|
if not (type(self) is TransformersInterface):
|
||||||
input_ids = input_ids.to("cpu")
|
input_ids = input_ids.to("cpu")
|
||||||
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
|
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
|
||||||
if self.use_static_cache:
|
if self.use_static_cache:
|
||||||
logits = self.model(
|
logits = self.model(
|
||||||
inputs_embeds=inputs_embeds, cache_position=cache_position, past_key_values=self.cache,return_dict=False, use_cache=True,attention_mask=mask
|
inputs_embeds=inputs_embeds,
|
||||||
|
cache_position=cache_position,
|
||||||
|
past_key_values=self.cache,
|
||||||
|
return_dict=False,
|
||||||
|
use_cache=True,
|
||||||
|
attention_mask=mask,
|
||||||
)[0]
|
)[0]
|
||||||
else:
|
else:
|
||||||
logits = self.model(
|
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
|
||||||
inputs_embeds=inputs_embeds,return_dict=False
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
|
next_token = self.logits_to_token(logits[0, -1, :])
|
||||||
|
|
||||||
next_token = self.logits_to_token(logits[0,-1,:])
|
|
||||||
yield self.append_new_tokens(next_token)
|
yield self.append_new_tokens(next_token)
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def generate(self):
|
def generate(self):
|
||||||
self.profiler.set_counter('decode',0)
|
self.profiler.set_counter("decode", 0)
|
||||||
for _ in range(1, self.args.max_new_tokens):
|
for _ in range(1, self.args.max_new_tokens):
|
||||||
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
||||||
next_token = self.decode_one_tokens()
|
next_token = self.decode_one_tokens()
|
||||||
self.profiler.inc('decode')
|
self.profiler.inc("decode")
|
||||||
if next_token == self.tokenizer.eos_token_id:
|
if next_token == self.tokenizer.eos_token_id:
|
||||||
assert self.args.batch_size == 1
|
assert self.args.batch_size == 1
|
||||||
break
|
break
|
||||||
yield self.append_new_tokens(next_token)
|
yield self.append_new_tokens(next_token)
|
||||||
yield self.streamer.end()
|
yield self.streamer.end()
|
||||||
|
|
||||||
def check_is_new(self,thread_id:str):
|
def check_is_new(self, thread_id: str):
|
||||||
if not self.use_static_cache:
|
if not self.use_static_cache:
|
||||||
return True
|
return True
|
||||||
if self.last_request_id is None:
|
if self.last_request_id is None:
|
||||||
self.last_request_id = thread_id
|
self.last_request_id = thread_id
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
if self.last_request_id==thread_id:
|
if self.last_request_id == thread_id:
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
self.last_request_id = thread_id
|
self.last_request_id = thread_id
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def inference(self,local_messages,thread_id:str):
|
async def inference(self, local_messages, thread_id: str):
|
||||||
self.profiler.create_and_start_timer('tokenize')
|
self.profiler.create_and_start_timer("tokenize")
|
||||||
if isinstance(local_messages,List):
|
if isinstance(local_messages, List):
|
||||||
input_ids = self.format_and_tokenize_input_ids(thread_id,local_messages)
|
input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
|
||||||
elif isinstance(local_messages,str):
|
elif isinstance(local_messages, str):
|
||||||
input_ids = self.tokenize_prompt(local_messages)
|
input_ids = self.tokenize_prompt(local_messages)
|
||||||
else:
|
else:
|
||||||
raise ValueError('local_messages should be List or str')
|
raise ValueError("local_messages should be List or str")
|
||||||
|
|
||||||
self.profiler.pause_timer('tokenize')
|
self.profiler.pause_timer("tokenize")
|
||||||
|
|
||||||
self.profiler.create_and_start_timer('prefill')
|
self.profiler.create_and_start_timer("prefill")
|
||||||
for t in self.prefill(input_ids,self.check_is_new(thread_id)):
|
for t in self.prefill(input_ids, self.check_is_new(thread_id)):
|
||||||
if t is not None:
|
if t is not None:
|
||||||
print(t,end='')
|
print(t, end="")
|
||||||
yield t
|
yield t
|
||||||
self.profiler.pause_timer('prefill')
|
self.profiler.pause_timer("prefill")
|
||||||
|
|
||||||
self.profiler.create_and_start_timer('decode')
|
self.profiler.create_and_start_timer("decode")
|
||||||
for t in self.generate():
|
for t in self.generate():
|
||||||
if t is not None:
|
if t is not None:
|
||||||
print(t,end='')
|
print(t, end="")
|
||||||
yield t
|
yield t
|
||||||
print('')
|
print("")
|
||||||
self.profiler.pause_timer('decode')
|
self.profiler.pause_timer("decode")
|
||||||
self.report_last_time_performance()
|
self.report_last_time_performance()
|
||||||
|
|
||||||
|
|
|
@ -1,23 +1,24 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
'''
|
"""
|
||||||
Description :
|
Description :
|
||||||
Author : unicornchan
|
Author : unicornchan
|
||||||
Date : 2024-06-11 16:35:42
|
Date : 2024-06-11 16:35:42
|
||||||
Version : 1.0.0
|
Version : 1.0.0
|
||||||
LastEditors : WuHao
|
LastEditors : WuHao
|
||||||
LastEditTime : 2024-08-12 06:31:14
|
LastEditTime : 2024-08-12 06:31:14
|
||||||
'''
|
"""
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from ktransformers.server.config.singleton import Singleton
|
from ktransformers.server.config.singleton import Singleton
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class Config(metaclass=Singleton):
|
class Config(metaclass=Singleton):
|
||||||
"""Singleton pattern Config class, used to get all configurations.
|
"""Singleton pattern Config class, used to get all configurations."""
|
||||||
"""
|
|
||||||
CONFIG_FILE_NAME = "config.yaml"
|
CONFIG_FILE_NAME = "config.yaml"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -27,22 +28,20 @@ class Config(metaclass=Singleton):
|
||||||
Returns:
|
Returns:
|
||||||
dict: all configs
|
dict: all configs
|
||||||
"""
|
"""
|
||||||
base_path: str = os.path.dirname(
|
base_path: str = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||||
os.path.dirname(os.path.dirname(__file__)))
|
config_yaml: str = os.path.join(base_path, "configs", Config.CONFIG_FILE_NAME)
|
||||||
config_yaml: str = os.path.join(
|
|
||||||
base_path, "configs", Config.CONFIG_FILE_NAME)
|
user_path: str = os.path.expanduser("~")
|
||||||
|
localstore_path: str = os.path.join(user_path, ".ktransformers")
|
||||||
user_path: str = os.path.expanduser('~')
|
config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME)
|
||||||
localstore_path: str = os.path.join(user_path,'.ktransformers')
|
|
||||||
config_path: str = os.path.join(localstore_path,Config.CONFIG_FILE_NAME)
|
|
||||||
if not os.path.exists(config_yaml):
|
if not os.path.exists(config_yaml):
|
||||||
print(f"Can't find config file, {config_yaml}")
|
print(f"Can't find config file, {config_yaml}")
|
||||||
exit(-1)
|
exit(-1)
|
||||||
if not os.path.exists(localstore_path):
|
if not os.path.exists(localstore_path):
|
||||||
os.mkdir(localstore_path)
|
os.mkdir(localstore_path)
|
||||||
if not os.path.exists(config_path):
|
if not os.path.exists(config_path):
|
||||||
shutil.copyfile(config_yaml,config_path)
|
shutil.copyfile(config_yaml, config_path)
|
||||||
with open(config_path, 'r', encoding="utf-8") as fp:
|
with open(config_path, "r", encoding="utf-8") as fp:
|
||||||
config = yaml.safe_load(fp)
|
config = yaml.safe_load(fp)
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
@ -52,16 +51,14 @@ class Config(metaclass=Singleton):
|
||||||
process file path
|
process file path
|
||||||
"""
|
"""
|
||||||
base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||||
real_path = path if os.path.isabs(
|
real_path = path if os.path.isabs(path) else os.path.join(base_path, path)
|
||||||
path) else os.path.join(base_path, path)
|
|
||||||
return real_path
|
return real_path
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
cfg = Config.load()
|
cfg = Config.load()
|
||||||
self.base_path = os.path.dirname(
|
self.base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||||
os.path.dirname(os.path.dirname(__file__)))
|
self.user_path: str = os.path.expanduser("~")
|
||||||
self.user_path: str = os.path.expanduser('~')
|
self.localstore_path: str = os.path.join(self.user_path, ".ktransformers")
|
||||||
self.localstore_path: str = os.path.join(self.user_path,'.ktransformers')
|
|
||||||
# log configs
|
# log configs
|
||||||
self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"]))
|
self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"]))
|
||||||
self.log_file = cfg["log"]["file"]
|
self.log_file = cfg["log"]["file"]
|
||||||
|
@ -69,7 +66,7 @@ class Config(metaclass=Singleton):
|
||||||
self.backup_count = cfg["log"]["backup_count"]
|
self.backup_count = cfg["log"]["backup_count"]
|
||||||
|
|
||||||
# server configs
|
# server configs
|
||||||
self.server: dict = cfg.get("server",{})
|
self.server: dict = cfg.get("server", {})
|
||||||
self.server_ip = self.server.get("ip", "0.0.0.0")
|
self.server_ip = self.server.get("ip", "0.0.0.0")
|
||||||
self.server_port = self.server.get("port", 9016)
|
self.server_port = self.server.get("port", 9016)
|
||||||
|
|
||||||
|
@ -86,16 +83,66 @@ class Config(metaclass=Singleton):
|
||||||
self.user_config: dict = cfg.get("user", {})
|
self.user_config: dict = cfg.get("user", {})
|
||||||
self.user_secret_key = self.user_config.get("secret_key", "")
|
self.user_secret_key = self.user_config.get("secret_key", "")
|
||||||
self.user_algorithm = self.user_config.get("algorithm", "")
|
self.user_algorithm = self.user_config.get("algorithm", "")
|
||||||
|
|
||||||
# model config
|
# model config
|
||||||
self.model:dict = cfg.get("model", {})
|
self.model: dict = cfg.get("model", {})
|
||||||
self.backend_type: str = self.model.get("type", "transformers")
|
self.backend_type: str = self.model.get("type", "transformers")
|
||||||
self.model_path: str = self.model.get("path", "")
|
self.model_dir: str = self.model.get("path", "")
|
||||||
self.model_name: str = self.model.get("name", "")
|
self.model_name: str = self.model.get("name", "")
|
||||||
self.model_device: str = self.model.get("device", "cuda:0")
|
self.model_device: str = self.model.get("device", "cuda:0")
|
||||||
self.gguf_path: str = self.model.get("gguf_path", "")
|
self.gguf_path: Optional[str] = self.model.get("gguf_path", None)
|
||||||
self.model_cache_lens = self.model.get("cache_lens")
|
# self.model_cache_lens = self.model.get("cache_lens")
|
||||||
|
self.optimize_config_path: Optional[str] = self.model.get(
|
||||||
|
"optimize_config_path", "./ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml"
|
||||||
|
)
|
||||||
|
self.paged = self.model.get("paged", True)
|
||||||
|
|
||||||
|
self.total_context = self.model.get("total_context", 2**18)
|
||||||
|
self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
|
||||||
|
self.max_chunk_size = self.model.get("max_chunk_size", 2048)
|
||||||
|
self.max_new_tokens = self.model.get("max_new_tokens", 500)
|
||||||
|
self.json_mode = self.model.get("json_mode", False)
|
||||||
|
self.healing = self.model.get("healing", False)
|
||||||
|
self.ban_strings: Optional[list] = self.model.get("ban_strings", None)
|
||||||
|
self.gpu_split: Optional[str] = self.model.get("gpu_split", None)
|
||||||
|
self.length: Optional[int] = self.model.get("length", None)
|
||||||
|
self.rope_scale: Optional[float] = self.model.get("rope_scale", None)
|
||||||
|
self.rope_alpha: Optional[float] = self.model.get("rope_alpha", None)
|
||||||
|
self.no_flash_attn = self.model.get("no_flash_attn", False)
|
||||||
|
self.low_mem = self.model.get("low_mem", False)
|
||||||
|
self.experts_per_token: Optional[int] = self.model.get("experts_per_token", None)
|
||||||
|
self.load_q4 = self.model.get("load_q4", False)
|
||||||
|
self.fast_safetensors = self.model.get("fast_safetensors", False)
|
||||||
|
self.draft_model_dir: Optional[str] = self.model.get("draft_model_dir", None)
|
||||||
|
self.no_draft_scale = self.model.get("no_draft_scale", False)
|
||||||
|
self.modes = self.model.get("modes", False)
|
||||||
|
self.mode = self.model.get("mode", "llama")
|
||||||
|
self.username = self.model.get("username", "User")
|
||||||
|
self.botname = self.model.get("botname", "Chatbort")
|
||||||
|
self.system_prompt: Optional[str] = self.model.get("system_prompt", None)
|
||||||
|
self.temperature = self.model.get("temperature", 0.95)
|
||||||
|
self.smoothing_factor = self.model.get("smoothing_factor", 0.0)
|
||||||
|
self.dynamic_temperature: Optional[str] = self.model.get("dynamic_temperature", None)
|
||||||
|
self.top_k = self.model.get("top_k", 50)
|
||||||
|
self.top_p = self.model.get("top_p", 0.8)
|
||||||
|
self.top_a = self.model.get("top_a", 0.0)
|
||||||
|
self.skew = self.model.get("skew", 0.0)
|
||||||
|
self.typical = self.model.get("typical", 0.0)
|
||||||
|
self.repetition_penalty = self.model.get("repetition_penalty", 1.01)
|
||||||
|
self.frequency_penalty = self.model.get("frequency_penalty", 0.0)
|
||||||
|
self.presence_penalty = self.model.get("presence_penalty", 0.0)
|
||||||
|
self.max_response_tokens = self.model.get("max_response_tokens", 300)
|
||||||
|
self.response_chunk = self.model.get("response_chunk", 250)
|
||||||
|
self.no_code_formatting = self.model.get("no_code_formatting", False)
|
||||||
|
self.cache_8bit = self.model.get("cache_8bit", False)
|
||||||
|
self.cache_q4 = self.model.get("cache_q4", True)
|
||||||
|
self.ngram_decoding = self.model.get("ngram_decoding", False)
|
||||||
|
self.print_timings = self.model.get("print_timings", False)
|
||||||
|
self.amnesia = self.model.get("amnesia", False)
|
||||||
|
self.batch_size = self.model.get("batch_size", 1)
|
||||||
|
self.cache_lens = self.model.get("cache_lens", 4096)
|
||||||
|
self.device = self.model.get("device", "cuda:2")
|
||||||
|
|
||||||
# web config
|
# web config
|
||||||
self.web: dict = cfg.get("web", {})
|
self.web: dict = cfg.get("web", {})
|
||||||
self.web_cross_domain: bool = self.web.get("open_cross_domain", True)
|
self.web_cross_domain: bool = self.web.get("open_cross_domain", True)
|
||||||
|
@ -104,10 +151,32 @@ class Config(metaclass=Singleton):
|
||||||
self.ext: dict = cfg.get("ext", {})
|
self.ext: dict = cfg.get("ext", {})
|
||||||
self.cpu_infer = self.ext.get("cpu_infer", 10)
|
self.cpu_infer = self.ext.get("cpu_infer", 10)
|
||||||
|
|
||||||
#file config
|
# file config
|
||||||
self.local_store_configs: dict = cfg.get("local_store",{})
|
self.local_store_configs: dict = cfg.get("local_store", {})
|
||||||
self.file_upload_dir: str = os.path.join(self.localstore_path,self.local_store_configs.get("file_upload_dir",""))
|
self.file_upload_dir: str = os.path.join(
|
||||||
self.assistant_store_dir: str = os.path.join(self.localstore_path,self.local_store_configs.get("assistant_store_dir",""))
|
self.localstore_path, self.local_store_configs.get("file_upload_dir", "")
|
||||||
|
)
|
||||||
|
self.assistant_store_dir: str = os.path.join(
|
||||||
|
self.localstore_path, self.local_store_configs.get("assistant_store_dir", "")
|
||||||
|
)
|
||||||
|
|
||||||
#long context config
|
# long context config
|
||||||
self.long_context_config: dict = cfg.get("long_context",{})
|
self.long_context_config: dict = cfg.get("long_context", {})
|
||||||
|
self.chunk_size = self.long_context_config.get("chunk_size", 4096)
|
||||||
|
self.max_seq_len = self.long_context_config.get("max_seq_len", 32000)
|
||||||
|
self.block_size = self.long_context_config.get("block_size", 128)
|
||||||
|
self.local_windows_len = self.long_context_config.get("local_windows_len", 4096)
|
||||||
|
self.second_select_num = self.long_context_config.get("second_select_num", 32)
|
||||||
|
self.anchor_type = self.long_context_config.get("anchor_type", "DYNAMIC")
|
||||||
|
self.kv_type = self.long_context_config.get("kv_type", "FP16")
|
||||||
|
self.dense_layer_num = self.long_context_config.get("dense_layer_num", 2)
|
||||||
|
self.anchor_num = self.long_context_config.get("anchor_num", 1)
|
||||||
|
self.preselect_block = self.long_context_config.get("preselect_block", True)
|
||||||
|
self.head_select_mode = self.long_context_config.get("head_select_mode", "SHARED")
|
||||||
|
self.preselect_block_count = self.long_context_config.get("preselect_block_count", 32)
|
||||||
|
self.layer_step = self.long_context_config.get("layer_step", 1)
|
||||||
|
self.token_step = self.long_context_config.get("token_step", 100)
|
||||||
|
|
||||||
|
# local chat
|
||||||
|
self.local_chat_config: dict = cfg.get("local_chat", {})
|
||||||
|
self.prompt_file = self.local_chat_config.get("prompt_file", None)
|
||||||
|
|
|
@ -3,11 +3,11 @@ import re
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
import uvicorn.logging
|
import uvicorn.logging
|
||||||
import argparse
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from ktransformers.server.args import ArgumentParser
|
||||||
from ktransformers.server.config.config import Config
|
from ktransformers.server.config.config import Config
|
||||||
from ktransformers.server.utils.create_interface import create_interface
|
from ktransformers.server.utils.create_interface import create_interface
|
||||||
from ktransformers.server.backend.args import default_args
|
from ktransformers.server.backend.args import default_args
|
||||||
from fastapi.openapi.utils import get_openapi
|
from fastapi.openapi.utils import get_openapi
|
||||||
|
|
||||||
|
@ -44,8 +44,11 @@ def create_app():
|
||||||
mount_index_routes(app)
|
mount_index_routes(app)
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
def update_web_port(config_file: str):
|
def update_web_port(config_file: str):
|
||||||
ip_port_pattern = r"(localhost|((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)):[0-9]{1,5}"
|
ip_port_pattern = (
|
||||||
|
r"(localhost|((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)):[0-9]{1,5}"
|
||||||
|
)
|
||||||
with open(config_file, "r", encoding="utf-8") as f_cfg:
|
with open(config_file, "r", encoding="utf-8") as f_cfg:
|
||||||
web_config = f_cfg.read()
|
web_config = f_cfg.read()
|
||||||
ip_port = "localhost:" + str(Config().server_port)
|
ip_port = "localhost:" + str(Config().server_port)
|
||||||
|
@ -70,14 +73,15 @@ def mount_index_routes(app: FastAPI):
|
||||||
|
|
||||||
def run_api(app, host, port, **kwargs):
|
def run_api(app, host, port, **kwargs):
|
||||||
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
|
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
|
||||||
uvicorn.run(app,
|
uvicorn.run(
|
||||||
host=host,
|
app,
|
||||||
port=port,
|
host=host,
|
||||||
ssl_keyfile=kwargs.get("ssl_keyfile"),
|
port=port,
|
||||||
ssl_certfile=kwargs.get("ssl_certfile"),
|
ssl_keyfile=kwargs.get("ssl_keyfile"),
|
||||||
)
|
ssl_certfile=kwargs.get("ssl_certfile"),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
uvicorn.run(app, host=host, port=port, log_level='debug')
|
uvicorn.run(app, host=host, port=port, log_level="debug")
|
||||||
|
|
||||||
|
|
||||||
def custom_openapi(app):
|
def custom_openapi(app):
|
||||||
|
@ -90,53 +94,27 @@ def custom_openapi(app):
|
||||||
description="We provided chat completion and openai assistant interfaces.",
|
description="We provided chat completion and openai assistant interfaces.",
|
||||||
routes=app.routes,
|
routes=app.routes,
|
||||||
)
|
)
|
||||||
openapi_schema["info"]["x-logo"] = {
|
openapi_schema["info"]["x-logo"] = {"url": "https://kvcache.ai/media/icon_1.png"}
|
||||||
"url": "https://kvcache.ai/media/icon_1.png"
|
|
||||||
}
|
|
||||||
app.openapi_schema = openapi_schema
|
app.openapi_schema = openapi_schema
|
||||||
return app.openapi_schema
|
return app.openapi_schema
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
cfg = Config()
|
cfg = Config()
|
||||||
parser = argparse.ArgumentParser(prog='kvcache.ai',
|
arg_parser = ArgumentParser(cfg)
|
||||||
description='Ktransformers')
|
|
||||||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
|
||||||
parser.add_argument("--port", type=int, default=cfg.server_port)
|
|
||||||
parser.add_argument("--ssl_keyfile", type=str)
|
|
||||||
parser.add_argument("--ssl_certfile", type=str)
|
|
||||||
parser.add_argument("--web", type=bool, default=False)
|
|
||||||
parser.add_argument("--model_name", type=str, default=cfg.model_name)
|
|
||||||
parser.add_argument("--model_path", type=str, default=cfg.model_path)
|
|
||||||
parser.add_argument("--device", type=str, default=cfg.model_device, help="Warning: Abandoning this parameter")
|
|
||||||
parser.add_argument("--gguf_path", type=str, default=cfg.gguf_path)
|
|
||||||
parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
|
|
||||||
parser.add_argument("--cpu_infer", type=int, default=cfg.cpu_infer)
|
|
||||||
parser.add_argument("--type", type=str, default=cfg.backend_type)
|
|
||||||
|
|
||||||
# 初始化消息
|
# 初始化消息
|
||||||
args = parser.parse_args()
|
args = arg_parser.parse_args()
|
||||||
cfg.model_name = args.model_name
|
|
||||||
cfg.model_path = args.model_path
|
|
||||||
cfg.model_device = args.device
|
|
||||||
cfg.mount_web = args.web
|
|
||||||
cfg.server_ip = args.host
|
|
||||||
cfg.server_port = args.port
|
|
||||||
cfg.cpu_infer = args.cpu_infer
|
|
||||||
cfg.backend_type = args.type
|
|
||||||
|
|
||||||
default_args.model_dir = args.model_path
|
|
||||||
default_args.device = args.device
|
|
||||||
default_args.gguf_path = args.gguf_path
|
|
||||||
default_args.optimize_config_path = args.optimize_config_path
|
|
||||||
|
|
||||||
app = create_app()
|
app = create_app()
|
||||||
custom_openapi(app)
|
custom_openapi(app)
|
||||||
create_interface(config=cfg, default_args=default_args)
|
create_interface(config=cfg, default_args=cfg)
|
||||||
run_api(app=app,
|
run_api(
|
||||||
host=args.host,
|
app=app,
|
||||||
port=args.port,
|
host=args.host,
|
||||||
ssl_keyfile=args.ssl_keyfile,
|
port=args.port,
|
||||||
ssl_certfile=args.ssl_certfile,)
|
ssl_keyfile=args.ssl_keyfile,
|
||||||
|
ssl_certfile=args.ssl_certfile,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
3
ktransformers/website/package-lock.json
generated
3
ktransformers/website/package-lock.json
generated
|
@ -4412,8 +4412,9 @@
|
||||||
},
|
},
|
||||||
"node_modules/@vue/cli": {
|
"node_modules/@vue/cli": {
|
||||||
"version": "5.0.8",
|
"version": "5.0.8",
|
||||||
"resolved": "https://registry.npmmirror.com/@vue/cli/-/cli-5.0.8.tgz",
|
"resolved": "https://registry.npmjs.org/@vue/cli/-/cli-5.0.8.tgz",
|
||||||
"integrity": "sha512-c/QKPdC09bYkW22m/boXkLaiz10z0Z2WHZO7zEeNdfSduqyWINZhKc6hVQU3Vk0NXW7BJAd7zWmcUrC8L9TuAA==",
|
"integrity": "sha512-c/QKPdC09bYkW22m/boXkLaiz10z0Z2WHZO7zEeNdfSduqyWINZhKc6hVQU3Vk0NXW7BJAd7zWmcUrC8L9TuAA==",
|
||||||
|
"license": "MIT",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@types/ejs": "^3.0.6",
|
"@types/ejs": "^3.0.6",
|
||||||
"@types/inquirer": "^8.1.3",
|
"@types/inquirer": "^8.1.3",
|
||||||
|
|
|
@ -69,4 +69,8 @@ ktransformers = "ktransformers.server.main:main"
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["./", ]
|
where = ["./", ]
|
||||||
include = ["ktransformers"]
|
include = ["ktransformers"]
|
||||||
|
[tool.black]
|
||||||
|
line-length = 120
|
||||||
|
preview = true
|
||||||
|
unstable = true
|
Loading…
Add table
Add a link
Reference in a new issue