diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..fadb988 --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +max-line-length = 120 +extend-select = B950 +extend-ignore = E203,E501,E701, B001,B006,B007,B008,B009,B010,B011,B016,B028,B031,B950,E265,E266,E401,E402,E711,E712,E713,E721,E722,E731,F401,F403,F405,F541,F811,F821,F841,W391 \ No newline at end of file diff --git a/.gitignore b/.gitignore index 1bb8666..c33a95d 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,5 @@ compile_commands.json *dist/ ktransformers/server/local_store/ ktransformers/server_test1.db -*.patch \ No newline at end of file +*.patch +img/ \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..dbf771d --- /dev/null +++ b/Makefile @@ -0,0 +1,21 @@ +flake_find: + cd ktransformers && flake8 | grep -Eo '[A-Z][0-9]{3}' | sort | uniq| paste -sd ',' - +format: + @cd ktransformers && black . + @black setup.py +dev_install: +# clear build dirs + rm -rf build + rm -rf *.egg-info + rm -rf ktransformers/ktransformers_ext/build + rm -rf ktransformers/ktransformers_ext/cuda/build + rm -rf ktransformers/ktransformers_ext/cuda/dist + rm -rf ktransformers/ktransformers_ext/cuda/*.egg-info + +# install ktransformers + echo "Installing python dependencies from requirements.txt" + pip install -r requirements-local_chat.txt + + echo "Installing ktransformers" + KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . --no-build-isolation + echo "Installation completed successfully" \ No newline at end of file diff --git a/README.md b/README.md index 7d6f342..eb23bf8 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,7 @@ Some preparation: ``` install.bat ``` - +4. If you are developer, you can make use of the makefile to compile and format the code.
the detailed usage of makefile is [here](./doc/en/makefile_usage.md)

Local Chat

We provide a simple command-line local chat Python script that you can run for testing. diff --git a/doc/en/makefile_usage.md b/doc/en/makefile_usage.md new file mode 100644 index 0000000..599173b --- /dev/null +++ b/doc/en/makefile_usage.md @@ -0,0 +1,26 @@ +# Makefile +## Target +### flake_find: +```bash +make flake_find +``` +find all the python files under ./ktransformers dir and find the Error, Warning, Fatal... (their codes) into a list that are not consistent with the pep8 standard. For now we have get all this list in the .flake8 file's extend-ignore section in order to let flakes8 ignore them temporarily.(we may improve them in the future) +### format: +```bash +make format +``` +we use black to format all the python files under ./ktransformers dir. It obeys the pep8 standard +but we modify the line length to 120 by add +```toml +[tool.black] +line-length = 120 +preview = true +unstable = true +``` +in the pyproject.toml file. + +### dev_install: +```bash +make dev_install +``` +install the package in the development mode. It means that the package is installed in the editable mode. So if you modify the code, you don't need to reinstall the package. We recommend the developer to use this method to install the package. \ No newline at end of file diff --git a/ktransformers/configs/config.yaml b/ktransformers/configs/config.yaml index 4078e24..7bde376 100644 --- a/ktransformers/configs/config.yaml +++ b/ktransformers/configs/config.yaml @@ -7,7 +7,7 @@ log: server: ip: 0.0.0.0 - port: 12456 + port: 10002 db: type: "sqllite" @@ -24,10 +24,11 @@ model: type: ktransformers name: DeepSeek-Coder-V2-Instruct - path: /mnt/data/model/DeepSeek-Coder-V2-Instruct/ - gguf_path: /mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH/ + path: deepseek-ai/DeepSeek-V2-Lite-Chat + gguf_path: ./DeepSeek-V2-Lite-Chat-GGUF device: cuda:0 + cache_lens: 8192 web: mount: False @@ -50,4 +51,7 @@ long_context: head_select_mode: SHARED preselect_block_count: 32 layer_step: 1 - token_step: 100 \ No newline at end of file + token_step: + +local_chat: + prompt_file: "./ktransformers/p.txt" \ No newline at end of file diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index 1057e82..41f98a1 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -1,33 +1,23 @@ """ -Description : +Description : Author : Boxin Zhang, Azure-Tang Version : 0.1.0 -Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +Copyright (c) 2024 by KVCache.AI, All Rights Reserved. """ +import asyncio import os import platform import sys - project_dir = os.path.dirname(os.path.dirname(__file__)) sys.path.insert(0, project_dir) -import torch -import logging -from transformers import ( - AutoTokenizer, - AutoConfig, - AutoModelForCausalLM, - GenerationConfig, - TextStreamer, -) -import json -import fire -from ktransformers.optimize.optimize import optimize_and_load_gguf +from ktransformers.server.args import ArgumentParser + + from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM -from ktransformers.util.utils import prefill_and_generate from ktransformers.server.config.config import Config custom_models = { @@ -37,9 +27,7 @@ custom_models = { "MixtralForCausalLM": MixtralForCausalLM, } -ktransformer_rules_dir = ( - os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" -) +ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" default_optimize_rules = { "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml", @@ -48,75 +36,28 @@ default_optimize_rules = { } -def local_chat( - model_path: str | None = None, - optimize_rule_path: str = None, - gguf_path: str | None = None, - max_new_tokens: int = 1000, - cpu_infer: int = Config().cpu_infer, - use_cuda_graph: bool = True, - prompt_file : str | None = None, - mode: str = "normal", -): - - - torch.set_grad_enabled(False) - - Config().cpu_infer = cpu_infer - - tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - if mode == 'long_context': - assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode" - torch.set_default_dtype(torch.float16) +def local_chat(): + config = Config() + arg_parser = ArgumentParser(config) + # 初始化消息 + arg_parser.parse_args() + if config.backend_type == "transformers": + from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface + elif config.backend_type == "exllamav2": + from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface + elif config.backend_type == "ktransformers": + from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface else: - torch.set_default_dtype(config.torch_dtype) - - with torch.device("meta"): - if config.architectures[0] in custom_models: - print("using custom modeling_xxx.py.") - if ( - "Qwen2Moe" in config.architectures[0] - ): # Qwen2Moe must use flash_attention_2 to avoid overflow. - config._attn_implementation = "flash_attention_2" - if "Llama" in config.architectures[0]: - config._attn_implementation = "eager" - if "Mixtral" in config.architectures[0]: - config._attn_implementation = "flash_attention_2" - - model = custom_models[config.architectures[0]](config) - else: - model = AutoModelForCausalLM.from_config( - config, trust_remote_code=True, attn_implementation="flash_attention_2" - ) - - if optimize_rule_path is None: - if config.architectures[0] in default_optimize_rules: - print("using default_optimize_rule for", config.architectures[0]) - optimize_rule_path = default_optimize_rules[config.architectures[0]] - else: - optimize_rule_path = input( - "please input the path of your rule file(yaml file containing optimize rules):" - ) - - if gguf_path is None: - gguf_path = input( - "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):" - ) - optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config) - - model.generation_config = GenerationConfig.from_pretrained(model_path) - if model.generation_config.pad_token_id is None: - model.generation_config.pad_token_id = model.generation_config.eos_token_id - model.eval() - logging.basicConfig(level=logging.INFO) + raise NotImplementedError(f"{config.backend_type} not implemented") + interface = BackendInterface(config) system = platform.system() if system == "Windows": os.system("cls") else: os.system("clear") - + # add a history chat content + his_content = [] while True: content = input("Chat: ") if content.startswith('"""'): # prefix """ @@ -132,28 +73,27 @@ def local_chat( break else: content += line + "\n" - if content == "": - if prompt_file != None: - content = open(prompt_file, "r").read() - else: + if config.prompt_file == None or config.prompt_file == "": content = "Please write a piece of quicksort code in C++." + else: + content = open(config.prompt_file, "r").read() elif os.path.isfile(content): content = open(content, "r").read() - messages = [{"role": "user", "content": content}] - input_tensor = tokenizer.apply_chat_template( - messages, add_generation_prompt=True, return_tensors="pt" - ) - if mode == 'long_context': - assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \ - "please change max_seq_len in ~/.ktransformers/config.yaml" - torch.set_default_dtype( - torch.bfloat16 - ) # TODO: Remove this, replace dtype using config - generated = prefill_and_generate( - model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode - ) + messages = his_content + [{"role": "user", "content": content}] + + async def async_inference(messages): + generated = "" + async for token in interface.inference(messages, "local_chat"): + generated += token + return generated + + generated = asyncio.run(async_inference(messages)) + his_content += [ + {"role": "user", "content": content}, + {"role": "assistant", "content": generated}, + ] if __name__ == "__main__": - fire.Fire(local_chat) + local_chat() diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml new file mode 100644 index 0000000..b115aba --- /dev/null +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml @@ -0,0 +1,56 @@ +- match: + class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_deepseek.DeepseekV2MoE + replace: + class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KDeepseekV2Model" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" \ No newline at end of file diff --git a/ktransformers/server/api/ollama/completions.py b/ktransformers/server/api/ollama/completions.py index d0eb346..e3a1a51 100644 --- a/ktransformers/server/api/ollama/completions.py +++ b/ktransformers/server/api/ollama/completions.py @@ -135,6 +135,8 @@ class OllamaShowResponse(BaseModel): details: OllamaShowDetial model_info: OllamaModelInfo + class Config: + protected_namespaces = () diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py new file mode 100644 index 0000000..5c0bb03 --- /dev/null +++ b/ktransformers/server/args.py @@ -0,0 +1,124 @@ +import argparse +from ktransformers.server.backend.args import ConfigArgs, default_args + + +class ArgumentParser: + def __init__(self, cfg): + self.cfg = cfg + + def parse_args(self): + parser = argparse.ArgumentParser(prog="kvcache.ai", description="Ktransformers") + parser.add_argument("--host", type=str, default=self.cfg.server_ip) + parser.add_argument("--port", type=int, default=self.cfg.server_port) + parser.add_argument("--ssl_keyfile", type=str) + parser.add_argument("--ssl_certfile", type=str) + parser.add_argument("--web", type=bool, default=self.cfg.mount_web) + parser.add_argument("--model_name", type=str, default=self.cfg.model_name) + parser.add_argument("--model_dir", type=str) + parser.add_argument("--model_path", type=str) + parser.add_argument( + "--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter" + ) + parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path) + parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False) + parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer) + parser.add_argument("--type", type=str, default=self.cfg.backend_type) + + # model configs + # parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int? + parser.add_argument("--paged", type=bool, default=self.cfg.paged) + parser.add_argument("--total_context", type=int, default=self.cfg.total_context) + parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size) + parser.add_argument("--max_chunk_size", type=int, default=self.cfg.max_chunk_size) + parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens) + parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode) + parser.add_argument("--healing", type=bool, default=self.cfg.healing) + parser.add_argument("--ban_strings", type=list, default=self.cfg.ban_strings, required=False) + parser.add_argument("--gpu_split", type=str, default=self.cfg.gpu_split, required=False) + parser.add_argument("--length", type=int, default=self.cfg.length, required=False) + parser.add_argument("--rope_scale", type=float, default=self.cfg.rope_scale, required=False) + parser.add_argument("--rope_alpha", type=float, default=self.cfg.rope_alpha, required=False) + parser.add_argument("--no_flash_attn", type=bool, default=self.cfg.no_flash_attn) + parser.add_argument("--low_mem", type=bool, default=self.cfg.low_mem) + parser.add_argument("--experts_per_token", type=int, default=self.cfg.experts_per_token, required=False) + parser.add_argument("--load_q4", type=bool, default=self.cfg.load_q4) + parser.add_argument("--fast_safetensors", type=bool, default=self.cfg.fast_safetensors) + parser.add_argument("--draft_model_dir", type=str, default=self.cfg.draft_model_dir, required=False) + parser.add_argument("--no_draft_scale", type=bool, default=self.cfg.no_draft_scale) + parser.add_argument("--modes", type=bool, default=self.cfg.modes) + parser.add_argument("--mode", type=str, default=self.cfg.mode) + parser.add_argument("--username", type=str, default=self.cfg.username) + parser.add_argument("--botname", type=str, default=self.cfg.botname) + parser.add_argument("--system_prompt", type=str, default=self.cfg.system_prompt, required=False) + parser.add_argument("--temperature", type=float, default=self.cfg.temperature) + parser.add_argument("--smoothing_factor", type=float, default=self.cfg.smoothing_factor) + parser.add_argument("--dynamic_temperature", type=str, default=self.cfg.dynamic_temperature, required=False) + parser.add_argument("--top_k", type=int, default=self.cfg.top_k) + parser.add_argument("--top_p", type=float, default=self.cfg.top_p) + parser.add_argument("--top_a", type=float, default=self.cfg.top_a) + parser.add_argument("--skew", type=float, default=self.cfg.skew) + parser.add_argument("--typical", type=float, default=self.cfg.typical) + parser.add_argument("--repetition_penalty", type=float, default=self.cfg.repetition_penalty) + parser.add_argument("--frequency_penalty", type=float, default=self.cfg.frequency_penalty) + parser.add_argument("--presence_penalty", type=float, default=self.cfg.presence_penalty) + parser.add_argument("--max_response_tokens", type=int, default=self.cfg.max_response_tokens) + parser.add_argument("--response_chunk", type=int, default=self.cfg.response_chunk) + parser.add_argument("--no_code_formatting", type=bool, default=self.cfg.no_code_formatting) + parser.add_argument("--cache_8bit", type=bool, default=self.cfg.cache_8bit) + parser.add_argument("--cache_q4", type=bool, default=self.cfg.cache_q4) + parser.add_argument("--ngram_decoding", type=bool, default=self.cfg.ngram_decoding) + parser.add_argument("--print_timings", type=bool, default=self.cfg.print_timings) + parser.add_argument("--amnesia", type=bool, default=self.cfg.amnesia) + parser.add_argument("--batch_size", type=int, default=self.cfg.batch_size) + parser.add_argument("--cache_lens", type=int, default=self.cfg.cache_lens) + + # log configs + # log level: debug, info, warn, error, crit + parser.add_argument("--log_dir", type=str, default=self.cfg.log_dir) + parser.add_argument("--log_file", type=str, default=self.cfg.log_file) + parser.add_argument("--log_level", type=str, default=self.cfg.log_level) + parser.add_argument("--backup_count", type=int, default=self.cfg.backup_count) + + # db configs + parser.add_argument("--db_type", type=str, default=self.cfg.db_type) + parser.add_argument("--db_host", type=str, default=self.cfg.db_host) + parser.add_argument("--db_port", type=str, default=self.cfg.db_port) + parser.add_argument("--db_name", type=str, default=self.cfg.db_name) + parser.add_argument("--db_pool_size", type=int, default=self.cfg.db_pool_size) + parser.add_argument("--db_database", type=str, default=self.cfg.db_database) + + # user config + parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key) + parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm) + + # web config + parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain) + + # file config + parser.add_argument("--file_upload_dir", type=str, default=self.cfg.file_upload_dir) + parser.add_argument("--assistant_store_dir", type=str, default=self.cfg.assistant_store_dir) + # local chat + parser.add_argument("--prompt_file", type=str, default=self.cfg.prompt_file) + + args = parser.parse_args() + if (args.model_dir is not None): + if (args.model_path is not None): + # if pass model_dir and model_path, we use model_path + args.model_dir = args.model_path + else: + # if only pass model_dir, we use model_dir + args.model_path = args.model_dir + else: + args.model_dir = self.cfg.model_dir + args.model_path = self.cfg.model_path + # set config from args + for key, value in vars(args).items(): + if value is not None and hasattr(self.cfg, key): + setattr(self.cfg, key, value) + # we add the name not match args individually + self.cfg.model_device = args.device + self.cfg.mount_web = args.web + self.cfg.server_ip = args.host + self.cfg.server_port = args.port + self.cfg.backend_type = args.type + return args diff --git a/ktransformers/server/backend/args.py b/ktransformers/server/backend/args.py index e16e914..0b473e7 100644 --- a/ktransformers/server/backend/args.py +++ b/ktransformers/server/backend/args.py @@ -1,97 +1,89 @@ -from pydantic import BaseModel,Field +from pydantic import BaseModel, Field from typing import Optional from ktransformers.server.config.config import Config class ConfigArgs(BaseModel): - model_name : Optional[str] = Field(..., description="Model name") + model_name: Optional[str] = Field(..., description="Model name") model_dir: Optional[str] = Field(..., description="Path to model directory") - optimize_config_path: Optional[str] = Field('./KTransformers/optimize_config/DeepSeek-V2-Chat.json', description="Path of your optimize config json file") - gguf_path: Optional[str] = Field('/models/DeepSeek-Coder-V2-Instruct-GGUF/DeepSeek-Coder-V2-Instruct-Q4_K_M.gguf', description="Path of your gguf file") + optimize_config_path: Optional[str] = Field(None, description="Path of your optimize config yml file") + gguf_path: Optional[str] = Field(None, description="Path of your gguf file") + class Config: protected_namespaces = () - paged : bool = Field(True,description='Wether to use paged attention kv cache') - - # total_context: int = Field(16384, description="Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the total to distribute dynamically over however many jobs are active at once") - total_context: int = Field(2**18, description="Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the total to distribute dynamically over however many jobs are active at once") - max_batch_size: int = Field(20 if paged else 1, description="Max number of batches to run at once, assuming the sequences will fit within total_context") - max_chunk_size: int = Field(2048, description="Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new job is started, but at the expense of overall prompt ingestion speed") - max_new_tokens: int = Field(500, description="Max new tokens per completion. For this example applies to all jobs") - json_mode: bool = Field(False, description="Use LMFE to constrain the output to JSON format. See schema and details below") - healing: bool = Field(False, description="Demonstrate token healing") + paged: bool = Field(None, description="Whether to use paged attention kv cache") + total_context: int = Field( + None, + description=( + "Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the" + " total to distribute dynamically over however many jobs are active at once" + ), + ) + max_batch_size: int = Field( + None, description="Max number of batches to run at once, assuming the sequences will fit within total_context" + ) + max_chunk_size: int = Field( + None, + description=( + "Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new" + " job is started, but at the expense of overall prompt ingestion speed" + ), + ) + max_new_tokens: int = Field(None, description="Max new tokens per completion. For this example applies to all jobs") + json_mode: bool = Field( + None, description="Use LMFE to constrain the output to JSON format. See schema and details below" + ) + healing: bool = Field(None, description="Demonstrate token healing") ban_strings: Optional[list] = Field(None, description="Ban some phrases maybe") - gpu_split: Optional[str] = Field(None, description='"auto", or VRAM allocation per GPU in GB') length: Optional[int] = Field(None, description="Maximum sequence length") rope_scale: Optional[float] = Field(None, description="RoPE scaling factor") rope_alpha: Optional[float] = Field(None, description="RoPE alpha value (NTK)") - no_flash_attn: bool = Field(False, description="Disable Flash Attention") - low_mem: bool = Field( - False, - description="Enable VRAM optimizations, potentially trading off speed", - ) + no_flash_attn: bool = Field(None, description="Disable Flash Attention") + low_mem: bool = Field(None, description="Enable VRAM optimizations, potentially trading off speed") experts_per_token: Optional[int] = Field( - None, - description="Override MoE model's default number of experts per token", - ) - load_q4: bool = Field(False, description="Load weights in Q4 mode") - fast_safetensors: bool = Field( - False, - description="Optimized safetensors loading with direct I/O (experimental!)", + None, description="Override MoE model's default number of experts per token" ) + load_q4: bool = Field(None, description="Load weights in Q4 mode") + fast_safetensors: bool = Field(None, description="Optimized safetensors loading with direct I/O (experimental!)") draft_model_dir: Optional[str] = Field(None, description="Path to draft model directory") no_draft_scale: bool = Field( - False, + None, description="If draft model has smaller context size than model, don't apply alpha (NTK) scaling to extend it", ) - modes: bool = Field(False, description="List available modes and exit.") - mode: str = Field( - "llama", - description="Chat mode. Use llama for Llama 1/2 chat finetunes.", - ) - username: str = Field("User", description="Username when using raw chat mode") - botname: str = Field("Chatbort", description="Bot name when using raw chat mode") + modes: bool = Field(None, description="List available modes and exit.") + mode: str = Field(None, description="Chat mode. Use llama for Llama 1/2 chat finetunes.") + username: str = Field(None, description="Username when using raw chat mode") + botname: str = Field(None, description="Bot name when using raw chat mode") system_prompt: Optional[str] = Field(None, description="Use custom system prompt") - temperature: float = Field(0.95, description="Sampler temperature, default = 0.95 (1 to disable)") - smoothing_factor: float = Field(0.0, description="Smoothing Factor, default = 0.0 (0 to disable)") + temperature: float = Field(None, description="Sampler temperature, default = 0.95 (1 to disable)") + smoothing_factor: float = Field(None, description="Smoothing Factor, default = 0.0 (0 to disable)") dynamic_temperature: Optional[str] = Field( - None, - description="Dynamic temperature min,max,exponent, e.g. -dyntemp 0.2,1.5,1", + None, description="Dynamic temperature min,max,exponent, e.g. -dyntemp 0.2,1.5,1" ) - top_k: int = Field(50, description="Sampler top-K, default = 50 (0 to disable)") - top_p: float = Field(0.8, description="Sampler top-P, default = 0.8 (0 to disable)") - top_a: float = Field(0.0, description="Sampler top-A, default = 0.0 (0 to disable)") - skew: float = Field(0.0, description="Skew sampling, default = 0.0 (0 to disable)") - typical: float = Field( - 0.0, - description="Sampler typical threshold, default = 0.0 (0 to disable)", - ) - repetition_penalty: float = Field( - 1.01, - description="Sampler repetition penalty, default = 1.01 (1 to disable)", - ) - frequency_penalty: float = Field( - 0.0, - description="Sampler frequency penalty, default = 0.0 (0 to disable)", - ) - presence_penalty: float = Field( - 0.0, - description="Sampler presence penalty, default = 0.0 (0 to disable)", - ) - max_response_tokens: int = Field(300, description="Max tokens per response, default = 1000") - response_chunk: int = Field(250, description="Space to reserve in context for reply, default = 250") - no_code_formatting: bool = Field(False, description="Disable code formatting/syntax highlighting") - cache_8bit: bool = Field(False, description="Use 8-bit (FP8) cache") - cache_q4: bool = Field(True, description="Use Q4 cache") - ngram_decoding: bool = Field(False, description="Use n-gram speculative decoding") - print_timings: bool = Field(False, description="Output timings after each prompt") - amnesia: bool = Field(False, description="Forget context after every response") + top_k: int = Field(None, description="Sampler top-K, default = 50 (0 to disable)") + top_p: float = Field(None, description="Sampler top-P, default = 0.8 (0 to disable)") + top_a: float = Field(None, description="Sampler top-A, default = 0.0 (0 to disable)") + skew: float = Field(None, description="Skew sampling, default = 0.0 (0 to disable)") + typical: float = Field(None, description="Sampler typical threshold, default = 0.0 (0 to disable)") + repetition_penalty: float = Field(None, description="Sampler repetition penalty, default = 1.01 (1 to disable)") + frequency_penalty: float = Field(None, description="Sampler frequency penalty, default = 0.0 (0 to disable)") + presence_penalty: float = Field(None, description="Sampler presence penalty, default = 0.0 (0 to disable)") + max_response_tokens: int = Field(None, description="Max tokens per response, default = 1000") + response_chunk: int = Field(None, description="Space to reserve in context for reply, default = 250") + no_code_formatting: bool = Field(None, description="Disable code formatting/syntax highlighting") + cache_8bit: bool = Field(None, description="Use 8-bit (FP8) cache") + cache_q4: bool = Field(None, description="Use Q4 cache") + ngram_decoding: bool = Field(None, description="Use n-gram speculative decoding") + print_timings: bool = Field(None, description="Output timings after each prompt") + amnesia: bool = Field(None, description="Forget context after every response") # for transformers - batch_size :int = Field(1,description="Batch Size") - cache_lens:int = Field(4096, description="Cache lens for transformers static cache") - device:str = Field('cuda:2',description="device") + batch_size: int = Field(None, description="Batch Size") + cache_lens: int = Field(None, description="Cache lens for transformers static cache") + device: str = Field(None, description="device") + cfg = Config() -default_args = ConfigArgs(model_name=cfg.model_name,model_dir=cfg.model_path) +default_args = cfg diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py index 0e5a8c8..420f37e 100644 --- a/ktransformers/server/backend/interfaces/ktransformers.py +++ b/ktransformers/server/backend/interfaces/ktransformers.py @@ -1,6 +1,12 @@ import torch from transformers import AutoTokenizer, AutoConfig, GenerationConfig -from ktransformers.server.backend.interfaces.transformers import TransformersInterface,ConfigArgs, TransformersThreadContext,default_args,TextStreamer +from ktransformers.server.backend.interfaces.transformers import ( + TransformersInterface, + ConfigArgs, + TransformersThreadContext, + default_args, + TextStreamer, +) from ktransformers.server.config.log import logger from ktransformers.optimize.optimize import optimize_and_load_gguf from ktransformers.models.custom_cache import StaticCache @@ -14,71 +20,85 @@ class KTransformersThreadContext(TransformersThreadContext): class KTransformersInterface(TransformersInterface): - def __init__(self,args:ConfigArgs= default_args): + def __init__(self, args: ConfigArgs = default_args): self.args = args torch.set_default_dtype(torch.bfloat16) torch.set_grad_enabled(False) - self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir,device = args.device) - config=AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) + self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device) + config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) if config.architectures[0] == "Qwen2MoeForCausalLM": - config._attn_implementation="flash_attention_2" + config._attn_implementation = "flash_attention_2" with torch.device("meta"): - self.model=custom_models[config.architectures[0]](config) + self.model = custom_models[config.architectures[0]](config) if default_args.optimize_config_path is None: optimize_rule_path = default_optimize_rules[config.architectures[0]] else: optimize_rule_path = args.optimize_config_path - + # print(optimize_config) gguf_path = args.gguf_path if gguf_path is None: gguf_path = input( - "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):" + "please input the path of your gguf file(gguf file in the dir containing input gguf file must all" + " belong to current model):" ) optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config) - device_map = self.model.gguf_loader.tensor_device_map - logger.info(f'{args.model_name} loaded from {args.model_dir} to {device_map}') - self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=device_map, dtype=self.model.dtype) - logger.info(f'StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}') + logger.info(f"{args.model_name} loaded from {args.model_dir} to {device_map}") + self.cache = StaticCache( + config=self.model.config, + max_batch_size=args.batch_size, + max_cache_len=args.cache_lens, + device=device_map, + dtype=self.model.dtype, + ) + logger.info(f"StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}") self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir) if self.model.generation_config.pad_token_id is None: self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id self.streamer = TextStreamer(self.tokenizer) - + def decode_one_tokens(self): if not hasattr(self, "cuda_graph_runner"): device_map = self.model.gguf_loader.tensor_device_map - torch_device = get_device('blk.0.self_attn', device_map) + torch_device = get_device("blk.0.self_attn", device_map) torch_device = "cuda:0" if torch_device == "cuda" else torch_device self.cuda_graph_runner = CUDAGraphRunner() - self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, main_device=torch_device, return_dict=False, use_cache=True) - + self.cuda_graph_runner.capture( + self.model, + self.current_ids, + self.active_cache_position.unsqueeze(0), + self.active_cache_position, + self.cache, + main_device=torch_device, + return_dict=False, + use_cache=True, + ) + if hasattr(self, "cuda_graph_runner"): - logits = self.cuda_graph_runner(self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position) + logits = self.cuda_graph_runner( + self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position + ) self.cache.change_seq_length(1) torch.cuda.synchronize() - logits = logits[0,-1,:] + logits = logits[0, -1, :] return self.logits_to_token(logits) - + if self.use_static_cache: - mask = torch.ones((1,self.seq_length)).to(torch_device) + mask = torch.ones((1, self.seq_length)).to(torch_device) logits = self.model( self.current_ids, cache_position=self.active_cache_position, past_key_values=self.cache, attention_mask=mask, return_dict=False, - use_cache=True + use_cache=True, )[0] else: - logits = self.model( - self.current_ids, - return_dict=False - )[0] - logits = logits[0,-1,:] + logits = self.model(self.current_ids, return_dict=False)[0] + logits = logits[0, -1, :] return self.logits_to_token(logits) diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py index 2c4779d..f205ac5 100644 --- a/ktransformers/server/backend/interfaces/transformers.py +++ b/ktransformers/server/backend/interfaces/transformers.py @@ -1,14 +1,22 @@ from typing import Any, List, Optional, Set -from transformers import LlamaTokenizer,AutoTokenizer, AutoConfig, LlamaForCausalLM,GenerationConfig, StaticCache, AutoModelForCausalLM,BitsAndBytesConfig +from transformers import ( + LlamaTokenizer, + AutoTokenizer, + AutoConfig, + LlamaForCausalLM, + GenerationConfig, + StaticCache, + AutoModelForCausalLM, + BitsAndBytesConfig, +) from ktransformers.server.schemas.base import ObjectID from ktransformers.server.utils.multi_timer import Profiler import torch import sys, os -from ..base import ThreadContext,BackendInterfaceBase +from ..base import ThreadContext, BackendInterfaceBase from ktransformers.server.config.log import logger -from ..args import ConfigArgs,default_args - +from ..args import ConfigArgs, default_args # This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py @@ -28,21 +36,20 @@ class TextStreamer: self.token_cache = [] self.print_len = 0 - def put(self, value)->Optional[str]: + def put(self, value) -> Optional[str]: """ Receives tokens, decodes them, and prints them to stdout as soon as they form entire words. - """ - if not isinstance(value,int): + """ + if not isinstance(value, int): raise ValueError("TextStreamer only supports batch size 1, and int type input") - if self.skip_prompt and self.next_tokens_are_prompt: self.next_tokens_are_prompt = False return None # Add the new token to the cache and decodes the entire thing. self.token_cache.append(value) - text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True,**self.decode_kwargs) + text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs) # After the symbol for a new line, we flush the cache. if text.endswith("\n"): @@ -59,7 +66,7 @@ class TextStreamer: self.print_len += len(printable_text) return printable_text - def end(self)->Optional[str]: + def end(self) -> Optional[str]: """Flushes any remaining cache and prints a newline to stdout.""" # Flush the cache, if it exists if len(self.token_cache) > 0: @@ -71,7 +78,7 @@ class TextStreamer: self.next_tokens_are_prompt = True return printable_text - + def _is_chinese_char(self, cp): """Checks whether CP is the codepoint of a CJK character.""" # This defines a "chinese character" as anything in the CJK Unicode block: @@ -97,81 +104,81 @@ class TextStreamer: return False -class TransformersThreadContext(ThreadContext): +class TransformersThreadContext(ThreadContext): def get_local_messages(self): local_messages = [] for m in self.messages: - local_messages.append( - {'role':m.role.value, - 'content':m.get_text_content()} - ) - + local_messages.append({"role": m.role.value, "content": m.get_text_content()}) + return local_messages class TransformersInterface(BackendInterfaceBase): - use_static_cache : bool = True - + use_static_cache: bool = True model: Any tokenizer: AutoTokenizer - + cache: StaticCache - generated_ids:torch.Tensor - seq_length:int - + generated_ids: torch.Tensor + seq_length: int + streamer: TextStreamer # thread_related last_request_id: Optional[str] = None ever_generated_ids: Set[int] = set() - - - def __init__(self, args:ConfigArgs = default_args): + def __init__(self, args: ConfigArgs = default_args): self.args = args - - self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir) - self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device,use_safetensors=True) - logger.info(f'{args.model_name} loaded from {args.model_dir} to {args.device}') - - self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=args.device, dtype=self.model.dtype) - logger.info(f'StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}') - - self.streamer = TextStreamer(self.tokenizer) - + self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir) + self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device, use_safetensors=True) + logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}") + + self.cache = StaticCache( + config=self.model.config, + max_batch_size=args.batch_size, + max_cache_len=args.cache_lens, + device=args.device, + dtype=self.model.dtype, + ) + logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}") + + self.streamer = TextStreamer(self.tokenizer) @property def current_ids(self): - return self.generated_ids[:,self.seq_length-1].unsqueeze(1) - + return self.generated_ids[:, self.seq_length - 1].unsqueeze(1) + @property def active_cache_position(self): - return torch.tensor([self.seq_length-1], device=self.args.device) + return torch.tensor([self.seq_length - 1], device=self.args.device) - - def tokenize_prompt(self,prompt:str): - input_ids = self.tokenizer.encode(prompt,return_tensors='pt').to(self.args.device) + def tokenize_prompt(self, prompt: str): + input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device) return input_ids - def format_and_tokenize_input_ids(self,thread_id:ObjectID,messages:List): + def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List): for m in messages: - if m['role']=='system': - logger.warn(f'change {m["role"]} to user') - m['role'] = 'user' + if m["role"] == "system": + logger.warning(f'change {m["role"]} to user') + m["role"] = "user" new_messages = [messages[0]] - for m in messages[1:]: - if m['role'] == 'user' and new_messages[-1]['role']=='user': - logger.warn('merge two adjacent user messages') - new_messages[-1]['content']+=m['content'] + for m in messages[1:]: + if m["role"] == "user" and new_messages[-1]["role"] == "user": + logger.warning("merge two adjacent user messages") + new_messages[-1]["content"] += m["content"] else: - new_messages.append(m) - - + new_messages.append(m) + # if (self.last_request_id is not None) and self.last_request_id == thread_id: + # input_ids = self.tokenizer.encode(self.tokenizer.eos_token+self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt",tokenize=False, add_generation_prompt=True), add_special_tokens = False, return_tensors="pt").to(self.args.device) + # else: + # input_ids = self.tokenizer.apply_chat_template( + # new_messages, return_tensors="pt", add_generation_prompt=True + # ).to(self.args.device) input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device) - if (self.last_request_id is not None) and self.last_request_id == thread_id: x = self.generated_ids[:,:self.seq_length] y = input_ids[:,:self.seq_length] @@ -179,19 +186,19 @@ class TransformersInterface(BackendInterfaceBase): unequal_mask = torch.ne(x,y) unequal_positions = torch.nonzero(unequal_mask) num_unequal_elements = unequal_mask.sum().item() - logger.warn(f'num_unequal_elements: {num_unequal_elements}') + logger.warning(f'num_unequal_elements: {num_unequal_elements}') input_ids = input_ids[:,self.seq_length:] - logger.debug(f'get input ids of shape {input_ids.shape}') + logger.debug(f"get input ids of shape {input_ids.shape}") return input_ids - - def append_new_tokens(self,new_tokens:int)->Optional[str]: - self.generated_ids[0,self.seq_length] = new_tokens - self.seq_length+=1 + + def append_new_tokens(self, new_tokens: int) -> Optional[str]: + self.generated_ids[0, self.seq_length] = new_tokens + self.seq_length += 1 return self.streamer.put(new_tokens) - def logits_to_token(self,logits:torch.Tensor): - logits = logits/self.args.temperature + def logits_to_token(self, logits: torch.Tensor): + logits = logits / self.args.temperature for token_idx in self.ever_generated_ids: if logits[token_idx] < 0: @@ -200,7 +207,7 @@ class TransformersInterface(BackendInterfaceBase): logits[token_idx] /= self.args.repetition_penalty probs = torch.nn.functional.softmax(logits, dim=-1) - + sample = True if sample: last = torch.multinomial(probs, num_samples=1) @@ -211,127 +218,124 @@ class TransformersInterface(BackendInterfaceBase): self.ever_generated_ids.add(last) return last - - def decode_one_tokens(self): if self.use_static_cache: - mask = torch.ones((1,self.seq_length)).to(self.args.device) + mask = torch.ones((1, self.seq_length)).to(self.args.device) logits = self.model( self.current_ids, cache_position=self.active_cache_position, past_key_values=self.cache, attention_mask=mask, return_dict=False, - use_cache=True + use_cache=True, )[0] else: - logits = self.model( - self.current_ids, - return_dict=False - )[0] - logits = logits[0,-1,:] + logits = self.model(self.current_ids, return_dict=False)[0] + logits = logits[0, -1, :] return self.logits_to_token(logits) @torch.no_grad - def prefill(self,input_ids:torch.Tensor,is_new:bool): + def prefill(self, input_ids: torch.Tensor, is_new: bool): input_ids_length = input_ids.shape[-1] - self.profiler.set_counter('prefill',input_ids_length) - logger.debug(f'input_ids: {input_ids.shape}') + self.profiler.set_counter("prefill", input_ids_length) + logger.debug(f"input_ids: {input_ids.shape}") - if is_new: self.cache.reset() self.ever_generated_ids.clear() former_seq_length = 0 self.seq_length = input_ids_length self.generated_ids = torch.zeros( - self.args.batch_size, self.seq_length + self.args.max_new_tokens + 1, dtype=torch.int, device=self.args.device - ) + self.args.batch_size, + self.seq_length + self.args.max_new_tokens + 1, + dtype=torch.int, + device=self.args.device, + ) else: - logger.debug(f'generate_ids: {self.generated_ids.shape}') + logger.debug(f"generate_ids: {self.generated_ids.shape}") former_seq_length = self.seq_length self.seq_length += input_ids_length - expected_length = self.seq_length + self.args.max_new_tokens+1 + expected_length = self.seq_length + self.args.max_new_tokens + 1 delta_length = expected_length - self.generated_ids.shape[-1] - if delta_length>0: + if delta_length > 0: new_generate_ids = torch.zeros( self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device ) - self.generated_ids = torch.cat([self.generated_ids,new_generate_ids],dim=-1) - logger.debug(f'cache position: {former_seq_length} to {self.seq_length}') - cache_position = torch.arange(former_seq_length,self.seq_length, device=self.args.device) - self.generated_ids[:,cache_position] = input_ids.to(self.args.device).to(torch.int) + self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1) + logger.debug(f"cache position: {former_seq_length} to {self.seq_length}") + cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device) + self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int) - mask = torch.ones((1,self.seq_length)).to(self.args.device) + mask = torch.ones((1, self.seq_length)).to(self.args.device) device = input_ids.device - if not(type(self) is TransformersInterface): + if not (type(self) is TransformersInterface): input_ids = input_ids.to("cpu") inputs_embeds = self.model.model.embed_tokens(input_ids).to(device) if self.use_static_cache: logits = self.model( - inputs_embeds=inputs_embeds, cache_position=cache_position, past_key_values=self.cache,return_dict=False, use_cache=True,attention_mask=mask + inputs_embeds=inputs_embeds, + cache_position=cache_position, + past_key_values=self.cache, + return_dict=False, + use_cache=True, + attention_mask=mask, )[0] else: - logits = self.model( - inputs_embeds=inputs_embeds,return_dict=False - )[0] + logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0] - - - next_token = self.logits_to_token(logits[0,-1,:]) + next_token = self.logits_to_token(logits[0, -1, :]) yield self.append_new_tokens(next_token) @torch.no_grad def generate(self): - self.profiler.set_counter('decode',0) + self.profiler.set_counter("decode", 0) for _ in range(1, self.args.max_new_tokens): with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): next_token = self.decode_one_tokens() - self.profiler.inc('decode') + self.profiler.inc("decode") if next_token == self.tokenizer.eos_token_id: assert self.args.batch_size == 1 break yield self.append_new_tokens(next_token) yield self.streamer.end() - def check_is_new(self,thread_id:str): + def check_is_new(self, thread_id: str): if not self.use_static_cache: return True if self.last_request_id is None: self.last_request_id = thread_id return True else: - if self.last_request_id==thread_id: + if self.last_request_id == thread_id: return False else: self.last_request_id = thread_id return True - async def inference(self,local_messages,thread_id:str): - self.profiler.create_and_start_timer('tokenize') - if isinstance(local_messages,List): - input_ids = self.format_and_tokenize_input_ids(thread_id,local_messages) - elif isinstance(local_messages,str): + async def inference(self, local_messages, thread_id: str): + self.profiler.create_and_start_timer("tokenize") + if isinstance(local_messages, List): + input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages) + elif isinstance(local_messages, str): input_ids = self.tokenize_prompt(local_messages) else: - raise ValueError('local_messages should be List or str') + raise ValueError("local_messages should be List or str") - self.profiler.pause_timer('tokenize') + self.profiler.pause_timer("tokenize") - self.profiler.create_and_start_timer('prefill') - for t in self.prefill(input_ids,self.check_is_new(thread_id)): + self.profiler.create_and_start_timer("prefill") + for t in self.prefill(input_ids, self.check_is_new(thread_id)): if t is not None: - print(t,end='') + print(t, end="") yield t - self.profiler.pause_timer('prefill') + self.profiler.pause_timer("prefill") - self.profiler.create_and_start_timer('decode') + self.profiler.create_and_start_timer("decode") for t in self.generate(): if t is not None: - print(t,end='') + print(t, end="") yield t - print('') - self.profiler.pause_timer('decode') + print("") + self.profiler.pause_timer("decode") self.report_last_time_performance() - diff --git a/ktransformers/server/config/config.py b/ktransformers/server/config/config.py index d391d66..850b6db 100644 --- a/ktransformers/server/config/config.py +++ b/ktransformers/server/config/config.py @@ -1,23 +1,24 @@ #!/usr/bin/env python # coding=utf-8 -''' -Description : +""" +Description : Author : unicornchan Date : 2024-06-11 16:35:42 Version : 1.0.0 -LastEditors : WuHao +LastEditors : WuHao LastEditTime : 2024-08-12 06:31:14 -''' +""" import os import shutil import yaml from ktransformers.server.config.singleton import Singleton +from typing import Optional class Config(metaclass=Singleton): - """Singleton pattern Config class, used to get all configurations. - """ + """Singleton pattern Config class, used to get all configurations.""" + CONFIG_FILE_NAME = "config.yaml" @staticmethod @@ -27,22 +28,20 @@ class Config(metaclass=Singleton): Returns: dict: all configs """ - base_path: str = os.path.dirname( - os.path.dirname(os.path.dirname(__file__))) - config_yaml: str = os.path.join( - base_path, "configs", Config.CONFIG_FILE_NAME) - - user_path: str = os.path.expanduser('~') - localstore_path: str = os.path.join(user_path,'.ktransformers') - config_path: str = os.path.join(localstore_path,Config.CONFIG_FILE_NAME) + base_path: str = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + config_yaml: str = os.path.join(base_path, "configs", Config.CONFIG_FILE_NAME) + + user_path: str = os.path.expanduser("~") + localstore_path: str = os.path.join(user_path, ".ktransformers") + config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME) if not os.path.exists(config_yaml): print(f"Can't find config file, {config_yaml}") exit(-1) if not os.path.exists(localstore_path): os.mkdir(localstore_path) if not os.path.exists(config_path): - shutil.copyfile(config_yaml,config_path) - with open(config_path, 'r', encoding="utf-8") as fp: + shutil.copyfile(config_yaml, config_path) + with open(config_path, "r", encoding="utf-8") as fp: config = yaml.safe_load(fp) return config @@ -52,16 +51,14 @@ class Config(metaclass=Singleton): process file path """ base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) - real_path = path if os.path.isabs( - path) else os.path.join(base_path, path) + real_path = path if os.path.isabs(path) else os.path.join(base_path, path) return real_path def __init__(self): cfg = Config.load() - self.base_path = os.path.dirname( - os.path.dirname(os.path.dirname(__file__))) - self.user_path: str = os.path.expanduser('~') - self.localstore_path: str = os.path.join(self.user_path,'.ktransformers') + self.base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + self.user_path: str = os.path.expanduser("~") + self.localstore_path: str = os.path.join(self.user_path, ".ktransformers") # log configs self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"])) self.log_file = cfg["log"]["file"] @@ -69,7 +66,7 @@ class Config(metaclass=Singleton): self.backup_count = cfg["log"]["backup_count"] # server configs - self.server: dict = cfg.get("server",{}) + self.server: dict = cfg.get("server", {}) self.server_ip = self.server.get("ip", "0.0.0.0") self.server_port = self.server.get("port", 9016) @@ -86,16 +83,68 @@ class Config(metaclass=Singleton): self.user_config: dict = cfg.get("user", {}) self.user_secret_key = self.user_config.get("secret_key", "") self.user_algorithm = self.user_config.get("algorithm", "") - + # model config - self.model:dict = cfg.get("model", {}) + self.model: dict = cfg.get("model", {}) self.backend_type: str = self.model.get("type", "transformers") - self.model_path: str = self.model.get("path", "") + self.model_dir: str = self.model.get("path", "") + # to make sure it consistent with previous version + self.model_path: str = self.model_dir self.model_name: str = self.model.get("name", "") self.model_device: str = self.model.get("device", "cuda:0") - self.gguf_path: str = self.model.get("gguf_path", "") - self.model_cache_lens = self.model.get("cache_lens") - + self.gguf_path: Optional[str] = self.model.get("gguf_path", None) + # self.model_cache_lens = self.model.get("cache_lens") + self.optimize_config_path: Optional[str] = self.model.get( + "optimize_config_path", "./ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml" + ) + self.paged = self.model.get("paged", True) + + self.total_context = self.model.get("total_context", 2**18) + self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1) + self.max_chunk_size = self.model.get("max_chunk_size", 2048) + self.max_new_tokens = self.model.get("max_new_tokens", 500) + self.json_mode = self.model.get("json_mode", False) + self.healing = self.model.get("healing", False) + self.ban_strings: Optional[list] = self.model.get("ban_strings", None) + self.gpu_split: Optional[str] = self.model.get("gpu_split", None) + self.length: Optional[int] = self.model.get("length", None) + self.rope_scale: Optional[float] = self.model.get("rope_scale", None) + self.rope_alpha: Optional[float] = self.model.get("rope_alpha", None) + self.no_flash_attn = self.model.get("no_flash_attn", False) + self.low_mem = self.model.get("low_mem", False) + self.experts_per_token: Optional[int] = self.model.get("experts_per_token", None) + self.load_q4 = self.model.get("load_q4", False) + self.fast_safetensors = self.model.get("fast_safetensors", False) + self.draft_model_dir: Optional[str] = self.model.get("draft_model_dir", None) + self.no_draft_scale = self.model.get("no_draft_scale", False) + self.modes = self.model.get("modes", False) + self.mode = self.model.get("mode", "llama") + self.username = self.model.get("username", "User") + self.botname = self.model.get("botname", "Chatbort") + self.system_prompt: Optional[str] = self.model.get("system_prompt", None) + self.temperature = self.model.get("temperature", 0.95) + self.smoothing_factor = self.model.get("smoothing_factor", 0.0) + self.dynamic_temperature: Optional[str] = self.model.get("dynamic_temperature", None) + self.top_k = self.model.get("top_k", 50) + self.top_p = self.model.get("top_p", 0.8) + self.top_a = self.model.get("top_a", 0.0) + self.skew = self.model.get("skew", 0.0) + self.typical = self.model.get("typical", 0.0) + self.repetition_penalty = self.model.get("repetition_penalty", 1.01) + self.frequency_penalty = self.model.get("frequency_penalty", 0.0) + self.presence_penalty = self.model.get("presence_penalty", 0.0) + self.max_response_tokens = self.model.get("max_response_tokens", 300) + self.response_chunk = self.model.get("response_chunk", 250) + self.no_code_formatting = self.model.get("no_code_formatting", False) + self.cache_8bit = self.model.get("cache_8bit", False) + self.cache_q4 = self.model.get("cache_q4", True) + self.ngram_decoding = self.model.get("ngram_decoding", False) + self.print_timings = self.model.get("print_timings", False) + self.amnesia = self.model.get("amnesia", False) + self.batch_size = self.model.get("batch_size", 1) + self.cache_lens = self.model.get("cache_lens", 4096) + self.device = self.model.get("device", "cuda:2") + # web config self.web: dict = cfg.get("web", {}) self.web_cross_domain: bool = self.web.get("open_cross_domain", True) @@ -104,10 +153,32 @@ class Config(metaclass=Singleton): self.ext: dict = cfg.get("ext", {}) self.cpu_infer = self.ext.get("cpu_infer", 10) - #file config - self.local_store_configs: dict = cfg.get("local_store",{}) - self.file_upload_dir: str = os.path.join(self.localstore_path,self.local_store_configs.get("file_upload_dir","")) - self.assistant_store_dir: str = os.path.join(self.localstore_path,self.local_store_configs.get("assistant_store_dir","")) + # file config + self.local_store_configs: dict = cfg.get("local_store", {}) + self.file_upload_dir: str = os.path.join( + self.localstore_path, self.local_store_configs.get("file_upload_dir", "") + ) + self.assistant_store_dir: str = os.path.join( + self.localstore_path, self.local_store_configs.get("assistant_store_dir", "") + ) - #long context config - self.long_context_config: dict = cfg.get("long_context",{}) \ No newline at end of file + # long context config + self.long_context_config: dict = cfg.get("long_context", {}) + self.chunk_size = self.long_context_config.get("chunk_size", 4096) + self.max_seq_len = self.long_context_config.get("max_seq_len", 32000) + self.block_size = self.long_context_config.get("block_size", 128) + self.local_windows_len = self.long_context_config.get("local_windows_len", 4096) + self.second_select_num = self.long_context_config.get("second_select_num", 32) + self.anchor_type = self.long_context_config.get("anchor_type", "DYNAMIC") + self.kv_type = self.long_context_config.get("kv_type", "FP16") + self.dense_layer_num = self.long_context_config.get("dense_layer_num", 2) + self.anchor_num = self.long_context_config.get("anchor_num", 1) + self.preselect_block = self.long_context_config.get("preselect_block", True) + self.head_select_mode = self.long_context_config.get("head_select_mode", "SHARED") + self.preselect_block_count = self.long_context_config.get("preselect_block_count", 32) + self.layer_step = self.long_context_config.get("layer_step", 1) + self.token_step = self.long_context_config.get("token_step", 100) + + # local chat + self.local_chat_config: dict = cfg.get("local_chat", {}) + self.prompt_file = self.local_chat_config.get("prompt_file", None) diff --git a/ktransformers/server/main.py b/ktransformers/server/main.py index 0bb52cc..5e01a48 100644 --- a/ktransformers/server/main.py +++ b/ktransformers/server/main.py @@ -3,11 +3,15 @@ import re from fastapi import FastAPI from fastapi.staticfiles import StaticFiles import uvicorn.logging -import argparse import uvicorn +import sys + +project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) +sys.path.insert(0, project_dir) from fastapi.middleware.cors import CORSMiddleware +from ktransformers.server.args import ArgumentParser from ktransformers.server.config.config import Config -from ktransformers.server.utils.create_interface import create_interface +from ktransformers.server.utils.create_interface import create_interface from ktransformers.server.backend.args import default_args from fastapi.openapi.utils import get_openapi @@ -44,8 +48,11 @@ def create_app(): mount_index_routes(app) return app + def update_web_port(config_file: str): - ip_port_pattern = r"(localhost|((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)):[0-9]{1,5}" + ip_port_pattern = ( + r"(localhost|((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)):[0-9]{1,5}" + ) with open(config_file, "r", encoding="utf-8") as f_cfg: web_config = f_cfg.read() ip_port = "localhost:" + str(Config().server_port) @@ -70,14 +77,15 @@ def mount_index_routes(app: FastAPI): def run_api(app, host, port, **kwargs): if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"): - uvicorn.run(app, - host=host, - port=port, - ssl_keyfile=kwargs.get("ssl_keyfile"), - ssl_certfile=kwargs.get("ssl_certfile"), - ) + uvicorn.run( + app, + host=host, + port=port, + ssl_keyfile=kwargs.get("ssl_keyfile"), + ssl_certfile=kwargs.get("ssl_certfile"), + ) else: - uvicorn.run(app, host=host, port=port, log_level='debug') + uvicorn.run(app, host=host, port=port, log_level="debug") def custom_openapi(app): @@ -90,53 +98,27 @@ def custom_openapi(app): description="We provided chat completion and openai assistant interfaces.", routes=app.routes, ) - openapi_schema["info"]["x-logo"] = { - "url": "https://kvcache.ai/media/icon_1.png" - } + openapi_schema["info"]["x-logo"] = {"url": "https://kvcache.ai/media/icon_1.png"} app.openapi_schema = openapi_schema return app.openapi_schema + def main(): cfg = Config() - parser = argparse.ArgumentParser(prog='kvcache.ai', - description='Ktransformers') - parser.add_argument("--host", type=str, default="0.0.0.0") - parser.add_argument("--port", type=int, default=cfg.server_port) - parser.add_argument("--ssl_keyfile", type=str) - parser.add_argument("--ssl_certfile", type=str) - parser.add_argument("--web", type=bool, default=False) - parser.add_argument("--model_name", type=str, default=cfg.model_name) - parser.add_argument("--model_path", type=str, default=cfg.model_path) - parser.add_argument("--device", type=str, default=cfg.model_device, help="Warning: Abandoning this parameter") - parser.add_argument("--gguf_path", type=str, default=cfg.gguf_path) - parser.add_argument("--optimize_config_path", default=None, type=str, required=False) - parser.add_argument("--cpu_infer", type=int, default=cfg.cpu_infer) - parser.add_argument("--type", type=str, default=cfg.backend_type) + arg_parser = ArgumentParser(cfg) # 初始化消息 - args = parser.parse_args() - cfg.model_name = args.model_name - cfg.model_path = args.model_path - cfg.model_device = args.device - cfg.mount_web = args.web - cfg.server_ip = args.host - cfg.server_port = args.port - cfg.cpu_infer = args.cpu_infer - cfg.backend_type = args.type - - default_args.model_dir = args.model_path - default_args.device = args.device - default_args.gguf_path = args.gguf_path - default_args.optimize_config_path = args.optimize_config_path - + args = arg_parser.parse_args() app = create_app() custom_openapi(app) - create_interface(config=cfg, default_args=default_args) - run_api(app=app, - host=args.host, - port=args.port, - ssl_keyfile=args.ssl_keyfile, - ssl_certfile=args.ssl_certfile,) + create_interface(config=cfg, default_args=cfg) + run_api( + app=app, + host=args.host, + port=args.port, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ) if __name__ == "__main__": diff --git a/ktransformers/website/package-lock.json b/ktransformers/website/package-lock.json index 7d77ec1..444d585 100644 --- a/ktransformers/website/package-lock.json +++ b/ktransformers/website/package-lock.json @@ -4412,8 +4412,9 @@ }, "node_modules/@vue/cli": { "version": "5.0.8", - "resolved": "https://registry.npmmirror.com/@vue/cli/-/cli-5.0.8.tgz", + "resolved": "https://registry.npmjs.org/@vue/cli/-/cli-5.0.8.tgz", "integrity": "sha512-c/QKPdC09bYkW22m/boXkLaiz10z0Z2WHZO7zEeNdfSduqyWINZhKc6hVQU3Vk0NXW7BJAd7zWmcUrC8L9TuAA==", + "license": "MIT", "dependencies": { "@types/ejs": "^3.0.6", "@types/inquirer": "^8.1.3", diff --git a/pyproject.toml b/pyproject.toml index 8070241..028c6a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,4 +69,8 @@ ktransformers = "ktransformers.server.main:main" [tool.setuptools.packages.find] where = ["./", ] -include = ["ktransformers"] \ No newline at end of file +include = ["ktransformers"] +[tool.black] +line-length = 120 +preview = true +unstable = true \ No newline at end of file