diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..fadb988 --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +max-line-length = 120 +extend-select = B950 +extend-ignore = E203,E501,E701, B001,B006,B007,B008,B009,B010,B011,B016,B028,B031,B950,E265,E266,E401,E402,E711,E712,E713,E721,E722,E731,F401,F403,F405,F541,F811,F821,F841,W391 \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..dbf771d --- /dev/null +++ b/Makefile @@ -0,0 +1,21 @@ +flake_find: + cd ktransformers && flake8 | grep -Eo '[A-Z][0-9]{3}' | sort | uniq| paste -sd ',' - +format: + @cd ktransformers && black . + @black setup.py +dev_install: +# clear build dirs + rm -rf build + rm -rf *.egg-info + rm -rf ktransformers/ktransformers_ext/build + rm -rf ktransformers/ktransformers_ext/cuda/build + rm -rf ktransformers/ktransformers_ext/cuda/dist + rm -rf ktransformers/ktransformers_ext/cuda/*.egg-info + +# install ktransformers + echo "Installing python dependencies from requirements.txt" + pip install -r requirements-local_chat.txt + + echo "Installing ktransformers" + KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . --no-build-isolation + echo "Installation completed successfully" \ No newline at end of file diff --git a/ktransformers/configs/config.yaml b/ktransformers/configs/config.yaml index 4078e24..f6f6776 100644 --- a/ktransformers/configs/config.yaml +++ b/ktransformers/configs/config.yaml @@ -7,7 +7,7 @@ log: server: ip: 0.0.0.0 - port: 12456 + port: 10002 db: type: "sqllite" @@ -24,10 +24,13 @@ model: type: ktransformers name: DeepSeek-Coder-V2-Instruct - path: /mnt/data/model/DeepSeek-Coder-V2-Instruct/ - gguf_path: /mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH/ + # path: /mnt/data/model/DeepSeek-Coder-V2-Instruct/ + path: deepseek-ai/DeepSeek-V2-Lite-Chat + # gguf_path: /mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH/ + gguf_path: ./DeepSeek-V2-Lite-Chat-GGUF device: cuda:0 + cache_lens: 8192 web: mount: False @@ -50,4 +53,7 @@ long_context: head_select_mode: SHARED preselect_block_count: 32 layer_step: 1 - token_step: 100 \ No newline at end of file + token_step: + +local_chat: + prompt_file: "./ktransformers/p.txt" \ No newline at end of file diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index 1057e82..80ada29 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -1,33 +1,23 @@ """ -Description : +Description : Author : Boxin Zhang, Azure-Tang Version : 0.1.0 -Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
""" +import asyncio import os import platform import sys +from ktransformers.server.args import ArgumentParser + project_dir = os.path.dirname(os.path.dirname(__file__)) sys.path.insert(0, project_dir) -import torch -import logging -from transformers import ( - AutoTokenizer, - AutoConfig, - AutoModelForCausalLM, - GenerationConfig, - TextStreamer, -) -import json -import fire -from ktransformers.optimize.optimize import optimize_and_load_gguf from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM -from ktransformers.util.utils import prefill_and_generate from ktransformers.server.config.config import Config custom_models = { @@ -37,9 +27,7 @@ custom_models = { "MixtralForCausalLM": MixtralForCausalLM, } -ktransformer_rules_dir = ( - os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" -) +ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" default_optimize_rules = { "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml", @@ -48,75 +36,28 @@ default_optimize_rules = { } -def local_chat( - model_path: str | None = None, - optimize_rule_path: str = None, - gguf_path: str | None = None, - max_new_tokens: int = 1000, - cpu_infer: int = Config().cpu_infer, - use_cuda_graph: bool = True, - prompt_file : str | None = None, - mode: str = "normal", -): - - - torch.set_grad_enabled(False) - - Config().cpu_infer = cpu_infer - - tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - if mode == 'long_context': - assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode" - torch.set_default_dtype(torch.float16) +def local_chat(): + config = Config() + arg_parser = ArgumentParser(config) + # 初始化消息 + arg_parser.parse_args() + if config.backend_type == "transformers": + from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface + elif config.backend_type == "exllamav2": + from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface + elif config.backend_type == "ktransformers": + from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface else: - torch.set_default_dtype(config.torch_dtype) - - with torch.device("meta"): - if config.architectures[0] in custom_models: - print("using custom modeling_xxx.py.") - if ( - "Qwen2Moe" in config.architectures[0] - ): # Qwen2Moe must use flash_attention_2 to avoid overflow. 
- config._attn_implementation = "flash_attention_2" - if "Llama" in config.architectures[0]: - config._attn_implementation = "eager" - if "Mixtral" in config.architectures[0]: - config._attn_implementation = "flash_attention_2" - - model = custom_models[config.architectures[0]](config) - else: - model = AutoModelForCausalLM.from_config( - config, trust_remote_code=True, attn_implementation="flash_attention_2" - ) - - if optimize_rule_path is None: - if config.architectures[0] in default_optimize_rules: - print("using default_optimize_rule for", config.architectures[0]) - optimize_rule_path = default_optimize_rules[config.architectures[0]] - else: - optimize_rule_path = input( - "please input the path of your rule file(yaml file containing optimize rules):" - ) - - if gguf_path is None: - gguf_path = input( - "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):" - ) - optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config) - - model.generation_config = GenerationConfig.from_pretrained(model_path) - if model.generation_config.pad_token_id is None: - model.generation_config.pad_token_id = model.generation_config.eos_token_id - model.eval() - logging.basicConfig(level=logging.INFO) + raise NotImplementedError(f"{config.backend_type} not implemented") + interface = BackendInterface(config) system = platform.system() if system == "Windows": os.system("cls") else: os.system("clear") - + # add a history chat content + his_content = [] while True: content = input("Chat: ") if content.startswith('"""'): # prefix """ @@ -132,28 +73,27 @@ def local_chat( break else: content += line + "\n" - if content == "": - if prompt_file != None: - content = open(prompt_file, "r").read() - else: + if config.prompt_file == None or config.prompt_file == "": content = "Please write a piece of quicksort code in C++." 
+ else: + content = open(config.prompt_file, "r").read() elif os.path.isfile(content): content = open(content, "r").read() - messages = [{"role": "user", "content": content}] - input_tensor = tokenizer.apply_chat_template( - messages, add_generation_prompt=True, return_tensors="pt" - ) - if mode == 'long_context': - assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \ - "please change max_seq_len in ~/.ktransformers/config.yaml" - torch.set_default_dtype( - torch.bfloat16 - ) # TODO: Remove this, replace dtype using config - generated = prefill_and_generate( - model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode - ) + messages = his_content + [{"role": "user", "content": content}] + + async def async_inference(messages): + generated = "" + async for token in interface.inference(messages, "local_chat"): + generated += token + return generated + + generated = asyncio.run(async_inference(messages)) + his_content += [ + {"role": "user", "content": content}, + {"role": "assistant", "content": generated}, + ] if __name__ == "__main__": - fire.Fire(local_chat) + local_chat() diff --git a/ktransformers/server/api/ollama/completions.py b/ktransformers/server/api/ollama/completions.py index d0eb346..e3a1a51 100644 --- a/ktransformers/server/api/ollama/completions.py +++ b/ktransformers/server/api/ollama/completions.py @@ -135,6 +135,8 @@ class OllamaShowResponse(BaseModel): details: OllamaShowDetial model_info: OllamaModelInfo + class Config: + protected_namespaces = () diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py new file mode 100644 index 0000000..61b88e7 --- /dev/null +++ b/ktransformers/server/args.py @@ -0,0 +1,113 @@ +import argparse +from ktransformers.server.backend.args import ConfigArgs, default_args + + +class ArgumentParser: + def __init__(self, cfg): + self.cfg = cfg + + def parse_args(self): + parser = argparse.ArgumentParser(prog="kvcache.ai", description="Ktransformers") + parser.add_argument("--host", type=str, default=self.cfg.server_ip) + parser.add_argument("--port", type=int, default=self.cfg.server_port) + parser.add_argument("--ssl_keyfile", type=str) + parser.add_argument("--ssl_certfile", type=str) + parser.add_argument("--web", type=bool, default=self.cfg.mount_web) + parser.add_argument("--model_name", type=str, default=self.cfg.model_name) + parser.add_argument("--model_dir", type=str, default=self.cfg.model_dir) + parser.add_argument( + "--device", type=str, default=self.cfg.model_device, help="Warning: this parameter is deprecated" + ) + parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path) + parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False) + parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer) + parser.add_argument("--type", type=str, default=self.cfg.backend_type) + + # model configs + # parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int? 
+ parser.add_argument("--paged", type=bool, default=self.cfg.paged) + parser.add_argument("--total_context", type=int, default=self.cfg.total_context) + parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size) + parser.add_argument("--max_chunk_size", type=int, default=self.cfg.max_chunk_size) + parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens) + parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode) + parser.add_argument("--healing", type=bool, default=self.cfg.healing) + parser.add_argument("--ban_strings", type=list, default=self.cfg.ban_strings, required=False) + parser.add_argument("--gpu_split", type=str, default=self.cfg.gpu_split, required=False) + parser.add_argument("--length", type=int, default=self.cfg.length, required=False) + parser.add_argument("--rope_scale", type=float, default=self.cfg.rope_scale, required=False) + parser.add_argument("--rope_alpha", type=float, default=self.cfg.rope_alpha, required=False) + parser.add_argument("--no_flash_attn", type=bool, default=self.cfg.no_flash_attn) + parser.add_argument("--low_mem", type=bool, default=self.cfg.low_mem) + parser.add_argument("--experts_per_token", type=int, default=self.cfg.experts_per_token, required=False) + parser.add_argument("--load_q4", type=bool, default=self.cfg.load_q4) + parser.add_argument("--fast_safetensors", type=bool, default=self.cfg.fast_safetensors) + parser.add_argument("--draft_model_dir", type=str, default=self.cfg.draft_model_dir, required=False) + parser.add_argument("--no_draft_scale", type=bool, default=self.cfg.no_draft_scale) + parser.add_argument("--modes", type=bool, default=self.cfg.modes) + parser.add_argument("--mode", type=str, default=self.cfg.mode) + parser.add_argument("--username", type=str, default=self.cfg.username) + parser.add_argument("--botname", type=str, default=self.cfg.botname) + parser.add_argument("--system_prompt", type=str, default=self.cfg.system_prompt, required=False) + parser.add_argument("--temperature", type=float, default=self.cfg.temperature) + parser.add_argument("--smoothing_factor", type=float, default=self.cfg.smoothing_factor) + parser.add_argument("--dynamic_temperature", type=str, default=self.cfg.dynamic_temperature, required=False) + parser.add_argument("--top_k", type=int, default=self.cfg.top_k) + parser.add_argument("--top_p", type=float, default=self.cfg.top_p) + parser.add_argument("--top_a", type=float, default=self.cfg.top_a) + parser.add_argument("--skew", type=float, default=self.cfg.skew) + parser.add_argument("--typical", type=float, default=self.cfg.typical) + parser.add_argument("--repetition_penalty", type=float, default=self.cfg.repetition_penalty) + parser.add_argument("--frequency_penalty", type=float, default=self.cfg.frequency_penalty) + parser.add_argument("--presence_penalty", type=float, default=self.cfg.presence_penalty) + parser.add_argument("--max_response_tokens", type=int, default=self.cfg.max_response_tokens) + parser.add_argument("--response_chunk", type=int, default=self.cfg.response_chunk) + parser.add_argument("--no_code_formatting", type=bool, default=self.cfg.no_code_formatting) + parser.add_argument("--cache_8bit", type=bool, default=self.cfg.cache_8bit) + parser.add_argument("--cache_q4", type=bool, default=self.cfg.cache_q4) + parser.add_argument("--ngram_decoding", type=bool, default=self.cfg.ngram_decoding) + parser.add_argument("--print_timings", type=bool, default=self.cfg.print_timings) + parser.add_argument("--amnesia", type=bool, 
default=self.cfg.amnesia) + parser.add_argument("--batch_size", type=int, default=self.cfg.batch_size) + parser.add_argument("--cache_lens", type=int, default=self.cfg.cache_lens) + + # log configs + # log level: debug, info, warn, error, crit + parser.add_argument("--log_dir", type=str, default=self.cfg.log_dir) + parser.add_argument("--log_file", type=str, default=self.cfg.log_file) + parser.add_argument("--log_level", type=str, default=self.cfg.log_level) + parser.add_argument("--backup_count", type=int, default=self.cfg.backup_count) + + # db configs + parser.add_argument("--db_type", type=str, default=self.cfg.db_type) + parser.add_argument("--db_host", type=str, default=self.cfg.db_host) + parser.add_argument("--db_port", type=str, default=self.cfg.db_port) + parser.add_argument("--db_name", type=str, default=self.cfg.db_name) + parser.add_argument("--db_pool_size", type=int, default=self.cfg.db_pool_size) + parser.add_argument("--db_database", type=str, default=self.cfg.db_database) + + # user config + parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key) + parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm) + + # web config + parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain) + + # file config + parser.add_argument("--file_upload_dir", type=str, default=self.cfg.file_upload_dir) + parser.add_argument("--assistant_store_dir", type=str, default=self.cfg.assistant_store_dir) + # local chat + parser.add_argument("--prompt_file", type=str, default=self.cfg.prompt_file) + + args = parser.parse_args() + # set config from args + for key, value in vars(args).items(): + if value is not None and hasattr(self.cfg, key): + setattr(self.cfg, key, value) + # we add the name not match args individually + self.cfg.model_device = args.device + self.cfg.mount_web = args.web + self.cfg.server_ip = args.host + self.cfg.server_port = args.port + self.cfg.backend_type = args.type + return args diff --git a/ktransformers/server/backend/args.py b/ktransformers/server/backend/args.py index e16e914..0b473e7 100644 --- a/ktransformers/server/backend/args.py +++ b/ktransformers/server/backend/args.py @@ -1,97 +1,89 @@ -from pydantic import BaseModel,Field +from pydantic import BaseModel, Field from typing import Optional from ktransformers.server.config.config import Config class ConfigArgs(BaseModel): - model_name : Optional[str] = Field(..., description="Model name") + model_name: Optional[str] = Field(..., description="Model name") model_dir: Optional[str] = Field(..., description="Path to model directory") - optimize_config_path: Optional[str] = Field('./KTransformers/optimize_config/DeepSeek-V2-Chat.json', description="Path of your optimize config json file") - gguf_path: Optional[str] = Field('/models/DeepSeek-Coder-V2-Instruct-GGUF/DeepSeek-Coder-V2-Instruct-Q4_K_M.gguf', description="Path of your gguf file") + optimize_config_path: Optional[str] = Field(None, description="Path of your optimize config yml file") + gguf_path: Optional[str] = Field(None, description="Path of your gguf file") + class Config: protected_namespaces = () - paged : bool = Field(True,description='Wether to use paged attention kv cache') - - # total_context: int = Field(16384, description="Total number of tokens to allocate space for. 
This is not the max_seq_len supported by the model but the total to distribute dynamically over however many jobs are active at once") - total_context: int = Field(2**18, description="Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the total to distribute dynamically over however many jobs are active at once") - max_batch_size: int = Field(20 if paged else 1, description="Max number of batches to run at once, assuming the sequences will fit within total_context") - max_chunk_size: int = Field(2048, description="Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new job is started, but at the expense of overall prompt ingestion speed") - max_new_tokens: int = Field(500, description="Max new tokens per completion. For this example applies to all jobs") - json_mode: bool = Field(False, description="Use LMFE to constrain the output to JSON format. See schema and details below") - healing: bool = Field(False, description="Demonstrate token healing") + paged: bool = Field(None, description="Whether to use paged attention kv cache") + total_context: int = Field( + None, + description=( + "Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the" + " total to distribute dynamically over however many jobs are active at once" + ), + ) + max_batch_size: int = Field( + None, description="Max number of batches to run at once, assuming the sequences will fit within total_context" + ) + max_chunk_size: int = Field( + None, + description=( + "Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new" + " job is started, but at the expense of overall prompt ingestion speed" + ), + ) + max_new_tokens: int = Field(None, description="Max new tokens per completion. For this example applies to all jobs") + json_mode: bool = Field( + None, description="Use LMFE to constrain the output to JSON format. 
See schema and details below" + ) + healing: bool = Field(None, description="Demonstrate token healing") ban_strings: Optional[list] = Field(None, description="Ban some phrases maybe") - gpu_split: Optional[str] = Field(None, description='"auto", or VRAM allocation per GPU in GB') length: Optional[int] = Field(None, description="Maximum sequence length") rope_scale: Optional[float] = Field(None, description="RoPE scaling factor") rope_alpha: Optional[float] = Field(None, description="RoPE alpha value (NTK)") - no_flash_attn: bool = Field(False, description="Disable Flash Attention") - low_mem: bool = Field( - False, - description="Enable VRAM optimizations, potentially trading off speed", - ) + no_flash_attn: bool = Field(None, description="Disable Flash Attention") + low_mem: bool = Field(None, description="Enable VRAM optimizations, potentially trading off speed") experts_per_token: Optional[int] = Field( - None, - description="Override MoE model's default number of experts per token", - ) - load_q4: bool = Field(False, description="Load weights in Q4 mode") - fast_safetensors: bool = Field( - False, - description="Optimized safetensors loading with direct I/O (experimental!)", + None, description="Override MoE model's default number of experts per token" ) + load_q4: bool = Field(None, description="Load weights in Q4 mode") + fast_safetensors: bool = Field(None, description="Optimized safetensors loading with direct I/O (experimental!)") draft_model_dir: Optional[str] = Field(None, description="Path to draft model directory") no_draft_scale: bool = Field( - False, + None, description="If draft model has smaller context size than model, don't apply alpha (NTK) scaling to extend it", ) - modes: bool = Field(False, description="List available modes and exit.") - mode: str = Field( - "llama", - description="Chat mode. Use llama for Llama 1/2 chat finetunes.", - ) - username: str = Field("User", description="Username when using raw chat mode") - botname: str = Field("Chatbort", description="Bot name when using raw chat mode") + modes: bool = Field(None, description="List available modes and exit.") + mode: str = Field(None, description="Chat mode. Use llama for Llama 1/2 chat finetunes.") + username: str = Field(None, description="Username when using raw chat mode") + botname: str = Field(None, description="Bot name when using raw chat mode") system_prompt: Optional[str] = Field(None, description="Use custom system prompt") - temperature: float = Field(0.95, description="Sampler temperature, default = 0.95 (1 to disable)") - smoothing_factor: float = Field(0.0, description="Smoothing Factor, default = 0.0 (0 to disable)") + temperature: float = Field(None, description="Sampler temperature, default = 0.95 (1 to disable)") + smoothing_factor: float = Field(None, description="Smoothing Factor, default = 0.0 (0 to disable)") dynamic_temperature: Optional[str] = Field( - None, - description="Dynamic temperature min,max,exponent, e.g. -dyntemp 0.2,1.5,1", + None, description="Dynamic temperature min,max,exponent, e.g. 
-dyntemp 0.2,1.5,1" ) - top_k: int = Field(50, description="Sampler top-K, default = 50 (0 to disable)") - top_p: float = Field(0.8, description="Sampler top-P, default = 0.8 (0 to disable)") - top_a: float = Field(0.0, description="Sampler top-A, default = 0.0 (0 to disable)") - skew: float = Field(0.0, description="Skew sampling, default = 0.0 (0 to disable)") - typical: float = Field( - 0.0, - description="Sampler typical threshold, default = 0.0 (0 to disable)", - ) - repetition_penalty: float = Field( - 1.01, - description="Sampler repetition penalty, default = 1.01 (1 to disable)", - ) - frequency_penalty: float = Field( - 0.0, - description="Sampler frequency penalty, default = 0.0 (0 to disable)", - ) - presence_penalty: float = Field( - 0.0, - description="Sampler presence penalty, default = 0.0 (0 to disable)", - ) - max_response_tokens: int = Field(300, description="Max tokens per response, default = 1000") - response_chunk: int = Field(250, description="Space to reserve in context for reply, default = 250") - no_code_formatting: bool = Field(False, description="Disable code formatting/syntax highlighting") - cache_8bit: bool = Field(False, description="Use 8-bit (FP8) cache") - cache_q4: bool = Field(True, description="Use Q4 cache") - ngram_decoding: bool = Field(False, description="Use n-gram speculative decoding") - print_timings: bool = Field(False, description="Output timings after each prompt") - amnesia: bool = Field(False, description="Forget context after every response") + top_k: int = Field(None, description="Sampler top-K, default = 50 (0 to disable)") + top_p: float = Field(None, description="Sampler top-P, default = 0.8 (0 to disable)") + top_a: float = Field(None, description="Sampler top-A, default = 0.0 (0 to disable)") + skew: float = Field(None, description="Skew sampling, default = 0.0 (0 to disable)") + typical: float = Field(None, description="Sampler typical threshold, default = 0.0 (0 to disable)") + repetition_penalty: float = Field(None, description="Sampler repetition penalty, default = 1.01 (1 to disable)") + frequency_penalty: float = Field(None, description="Sampler frequency penalty, default = 0.0 (0 to disable)") + presence_penalty: float = Field(None, description="Sampler presence penalty, default = 0.0 (0 to disable)") + max_response_tokens: int = Field(None, description="Max tokens per response, default = 1000") + response_chunk: int = Field(None, description="Space to reserve in context for reply, default = 250") + no_code_formatting: bool = Field(None, description="Disable code formatting/syntax highlighting") + cache_8bit: bool = Field(None, description="Use 8-bit (FP8) cache") + cache_q4: bool = Field(None, description="Use Q4 cache") + ngram_decoding: bool = Field(None, description="Use n-gram speculative decoding") + print_timings: bool = Field(None, description="Output timings after each prompt") + amnesia: bool = Field(None, description="Forget context after every response") # for transformers - batch_size :int = Field(1,description="Batch Size") - cache_lens:int = Field(4096, description="Cache lens for transformers static cache") - device:str = Field('cuda:2',description="device") + batch_size: int = Field(None, description="Batch Size") + cache_lens: int = Field(None, description="Cache lens for transformers static cache") + device: str = Field(None, description="device") + cfg = Config() -default_args = ConfigArgs(model_name=cfg.model_name,model_dir=cfg.model_path) +default_args = cfg diff --git 
a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py index 0e5a8c8..420f37e 100644 --- a/ktransformers/server/backend/interfaces/ktransformers.py +++ b/ktransformers/server/backend/interfaces/ktransformers.py @@ -1,6 +1,12 @@ import torch from transformers import AutoTokenizer, AutoConfig, GenerationConfig -from ktransformers.server.backend.interfaces.transformers import TransformersInterface,ConfigArgs, TransformersThreadContext,default_args,TextStreamer +from ktransformers.server.backend.interfaces.transformers import ( + TransformersInterface, + ConfigArgs, + TransformersThreadContext, + default_args, + TextStreamer, +) from ktransformers.server.config.log import logger from ktransformers.optimize.optimize import optimize_and_load_gguf from ktransformers.models.custom_cache import StaticCache @@ -14,71 +20,85 @@ class KTransformersThreadContext(TransformersThreadContext): class KTransformersInterface(TransformersInterface): - def __init__(self,args:ConfigArgs= default_args): + def __init__(self, args: ConfigArgs = default_args): self.args = args torch.set_default_dtype(torch.bfloat16) torch.set_grad_enabled(False) - self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir,device = args.device) - config=AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) + self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device) + config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) if config.architectures[0] == "Qwen2MoeForCausalLM": - config._attn_implementation="flash_attention_2" + config._attn_implementation = "flash_attention_2" with torch.device("meta"): - self.model=custom_models[config.architectures[0]](config) + self.model = custom_models[config.architectures[0]](config) if default_args.optimize_config_path is None: optimize_rule_path = default_optimize_rules[config.architectures[0]] else: optimize_rule_path = args.optimize_config_path - + # print(optimize_config) gguf_path = args.gguf_path if gguf_path is None: gguf_path = input( - "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):" + "please input the path of your gguf file(gguf file in the dir containing input gguf file must all" + " belong to current model):" ) optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config) - device_map = self.model.gguf_loader.tensor_device_map - logger.info(f'{args.model_name} loaded from {args.model_dir} to {device_map}') - self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=device_map, dtype=self.model.dtype) - logger.info(f'StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}') + logger.info(f"{args.model_name} loaded from {args.model_dir} to {device_map}") + self.cache = StaticCache( + config=self.model.config, + max_batch_size=args.batch_size, + max_cache_len=args.cache_lens, + device=device_map, + dtype=self.model.dtype, + ) + logger.info(f"StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}") self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir) if self.model.generation_config.pad_token_id is None: self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id self.streamer = TextStreamer(self.tokenizer) - + def decode_one_tokens(self): if not hasattr(self, "cuda_graph_runner"): device_map 
= self.model.gguf_loader.tensor_device_map - torch_device = get_device('blk.0.self_attn', device_map) + torch_device = get_device("blk.0.self_attn", device_map) torch_device = "cuda:0" if torch_device == "cuda" else torch_device self.cuda_graph_runner = CUDAGraphRunner() - self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, main_device=torch_device, return_dict=False, use_cache=True) - + self.cuda_graph_runner.capture( + self.model, + self.current_ids, + self.active_cache_position.unsqueeze(0), + self.active_cache_position, + self.cache, + main_device=torch_device, + return_dict=False, + use_cache=True, + ) + if hasattr(self, "cuda_graph_runner"): - logits = self.cuda_graph_runner(self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position) + logits = self.cuda_graph_runner( + self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position + ) self.cache.change_seq_length(1) torch.cuda.synchronize() - logits = logits[0,-1,:] + logits = logits[0, -1, :] return self.logits_to_token(logits) - + if self.use_static_cache: - mask = torch.ones((1,self.seq_length)).to(torch_device) + mask = torch.ones((1, self.seq_length)).to(torch_device) logits = self.model( self.current_ids, cache_position=self.active_cache_position, past_key_values=self.cache, attention_mask=mask, return_dict=False, - use_cache=True + use_cache=True, )[0] else: - logits = self.model( - self.current_ids, - return_dict=False - )[0] - logits = logits[0,-1,:] + logits = self.model(self.current_ids, return_dict=False)[0] + logits = logits[0, -1, :] return self.logits_to_token(logits) diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py index 2c4779d..7f569c4 100644 --- a/ktransformers/server/backend/interfaces/transformers.py +++ b/ktransformers/server/backend/interfaces/transformers.py @@ -1,14 +1,22 @@ from typing import Any, List, Optional, Set -from transformers import LlamaTokenizer,AutoTokenizer, AutoConfig, LlamaForCausalLM,GenerationConfig, StaticCache, AutoModelForCausalLM,BitsAndBytesConfig +from transformers import ( + LlamaTokenizer, + AutoTokenizer, + AutoConfig, + LlamaForCausalLM, + GenerationConfig, + StaticCache, + AutoModelForCausalLM, + BitsAndBytesConfig, +) from ktransformers.server.schemas.base import ObjectID from ktransformers.server.utils.multi_timer import Profiler import torch import sys, os -from ..base import ThreadContext,BackendInterfaceBase +from ..base import ThreadContext, BackendInterfaceBase from ktransformers.server.config.log import logger -from ..args import ConfigArgs,default_args - +from ..args import ConfigArgs, default_args # This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py @@ -28,21 +36,20 @@ class TextStreamer: self.token_cache = [] self.print_len = 0 - def put(self, value)->Optional[str]: + def put(self, value) -> Optional[str]: """ Receives tokens, decodes them, and prints them to stdout as soon as they form entire words. - """ - if not isinstance(value,int): + """ + if not isinstance(value, int): raise ValueError("TextStreamer only supports batch size 1, and int type input") - if self.skip_prompt and self.next_tokens_are_prompt: self.next_tokens_are_prompt = False return None # Add the new token to the cache and decodes the entire thing. 
self.token_cache.append(value) - text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True,**self.decode_kwargs) + text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs) # After the symbol for a new line, we flush the cache. if text.endswith("\n"): @@ -59,7 +66,7 @@ class TextStreamer: self.print_len += len(printable_text) return printable_text - def end(self)->Optional[str]: + def end(self) -> Optional[str]: """Flushes any remaining cache and prints a newline to stdout.""" # Flush the cache, if it exists if len(self.token_cache) > 0: @@ -71,7 +78,7 @@ class TextStreamer: self.next_tokens_are_prompt = True return printable_text - + def _is_chinese_char(self, cp): """Checks whether CP is the codepoint of a CJK character.""" # This defines a "chinese character" as anything in the CJK Unicode block: @@ -97,101 +104,91 @@ class TextStreamer: return False -class TransformersThreadContext(ThreadContext): +class TransformersThreadContext(ThreadContext): def get_local_messages(self): local_messages = [] for m in self.messages: - local_messages.append( - {'role':m.role.value, - 'content':m.get_text_content()} - ) - + local_messages.append({"role": m.role.value, "content": m.get_text_content()}) + return local_messages class TransformersInterface(BackendInterfaceBase): - use_static_cache : bool = True - + use_static_cache: bool = True model: Any tokenizer: AutoTokenizer - + cache: StaticCache - generated_ids:torch.Tensor - seq_length:int - + generated_ids: torch.Tensor + seq_length: int + streamer: TextStreamer # thread_related last_request_id: Optional[str] = None ever_generated_ids: Set[int] = set() - - - def __init__(self, args:ConfigArgs = default_args): + def __init__(self, args: ConfigArgs = default_args): self.args = args - - self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir) - self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device,use_safetensors=True) - logger.info(f'{args.model_name} loaded from {args.model_dir} to {args.device}') - - self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=args.device, dtype=self.model.dtype) - logger.info(f'StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}') - - self.streamer = TextStreamer(self.tokenizer) - + self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir) + self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device, use_safetensors=True) + logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}") + + self.cache = StaticCache( + config=self.model.config, + max_batch_size=args.batch_size, + max_cache_len=args.cache_lens, + device=args.device, + dtype=self.model.dtype, + ) + logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}") + + self.streamer = TextStreamer(self.tokenizer) @property def current_ids(self): - return self.generated_ids[:,self.seq_length-1].unsqueeze(1) - + return self.generated_ids[:, self.seq_length - 1].unsqueeze(1) + @property def active_cache_position(self): - return torch.tensor([self.seq_length-1], device=self.args.device) + return torch.tensor([self.seq_length - 1], device=self.args.device) - - def tokenize_prompt(self,prompt:str): - input_ids = self.tokenizer.encode(prompt,return_tensors='pt').to(self.args.device) + def tokenize_prompt(self, prompt: str): + input_ids = self.tokenizer.encode(prompt, 
return_tensors="pt").to(self.args.device) return input_ids - def format_and_tokenize_input_ids(self,thread_id:ObjectID,messages:List): + def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List): for m in messages: - if m['role']=='system': - logger.warn(f'change {m["role"]} to user') - m['role'] = 'user' + if m["role"] == "system": + logger.warning(f'change {m["role"]} to user') + m["role"] = "user" new_messages = [messages[0]] - for m in messages[1:]: - if m['role'] == 'user' and new_messages[-1]['role']=='user': - logger.warn('merge two adjacent user messages') - new_messages[-1]['content']+=m['content'] + for m in messages[1:]: + if m["role"] == "user" and new_messages[-1]["role"] == "user": + logger.warning("merge two adjacent user messages") + new_messages[-1]["content"] += m["content"] else: - new_messages.append(m) + new_messages.append(m) - - input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device) - if (self.last_request_id is not None) and self.last_request_id == thread_id: - x = self.generated_ids[:,:self.seq_length] - y = input_ids[:,:self.seq_length] - # We can only hope that the input_ids are the same - unequal_mask = torch.ne(x,y) - unequal_positions = torch.nonzero(unequal_mask) - num_unequal_elements = unequal_mask.sum().item() - logger.warn(f'num_unequal_elements: {num_unequal_elements}') - - input_ids = input_ids[:,self.seq_length:] - logger.debug(f'get input ids of shape {input_ids.shape}') + input_ids = self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt").to(self.args.device) + else: + input_ids = self.tokenizer.apply_chat_template( + new_messages, return_tensors="pt", add_generation_prompt=True + ).to(self.args.device) + logger.debug(f"get input ids of shape {input_ids.shape}") return input_ids - - def append_new_tokens(self,new_tokens:int)->Optional[str]: - self.generated_ids[0,self.seq_length] = new_tokens - self.seq_length+=1 + + def append_new_tokens(self, new_tokens: int) -> Optional[str]: + self.generated_ids[0, self.seq_length] = new_tokens + self.seq_length += 1 return self.streamer.put(new_tokens) - def logits_to_token(self,logits:torch.Tensor): - logits = logits/self.args.temperature + def logits_to_token(self, logits: torch.Tensor): + logits = logits / self.args.temperature for token_idx in self.ever_generated_ids: if logits[token_idx] < 0: @@ -200,7 +197,7 @@ class TransformersInterface(BackendInterfaceBase): logits[token_idx] /= self.args.repetition_penalty probs = torch.nn.functional.softmax(logits, dim=-1) - + sample = True if sample: last = torch.multinomial(probs, num_samples=1) @@ -211,127 +208,124 @@ class TransformersInterface(BackendInterfaceBase): self.ever_generated_ids.add(last) return last - - def decode_one_tokens(self): if self.use_static_cache: - mask = torch.ones((1,self.seq_length)).to(self.args.device) + mask = torch.ones((1, self.seq_length)).to(self.args.device) logits = self.model( self.current_ids, cache_position=self.active_cache_position, past_key_values=self.cache, attention_mask=mask, return_dict=False, - use_cache=True + use_cache=True, )[0] else: - logits = self.model( - self.current_ids, - return_dict=False - )[0] - logits = logits[0,-1,:] + logits = self.model(self.current_ids, return_dict=False)[0] + logits = logits[0, -1, :] return self.logits_to_token(logits) @torch.no_grad - def prefill(self,input_ids:torch.Tensor,is_new:bool): + def prefill(self, input_ids: torch.Tensor, is_new: bool): input_ids_length = 
input_ids.shape[-1] - self.profiler.set_counter('prefill',input_ids_length) - logger.debug(f'input_ids: {input_ids.shape}') + self.profiler.set_counter("prefill", input_ids_length) + logger.debug(f"input_ids: {input_ids.shape}") - if is_new: self.cache.reset() self.ever_generated_ids.clear() former_seq_length = 0 self.seq_length = input_ids_length self.generated_ids = torch.zeros( - self.args.batch_size, self.seq_length + self.args.max_new_tokens + 1, dtype=torch.int, device=self.args.device - ) + self.args.batch_size, + self.seq_length + self.args.max_new_tokens + 1, + dtype=torch.int, + device=self.args.device, + ) else: - logger.debug(f'generate_ids: {self.generated_ids.shape}') + logger.debug(f"generate_ids: {self.generated_ids.shape}") former_seq_length = self.seq_length self.seq_length += input_ids_length - expected_length = self.seq_length + self.args.max_new_tokens+1 + expected_length = self.seq_length + self.args.max_new_tokens + 1 delta_length = expected_length - self.generated_ids.shape[-1] - if delta_length>0: + if delta_length > 0: new_generate_ids = torch.zeros( self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device ) - self.generated_ids = torch.cat([self.generated_ids,new_generate_ids],dim=-1) - logger.debug(f'cache position: {former_seq_length} to {self.seq_length}') - cache_position = torch.arange(former_seq_length,self.seq_length, device=self.args.device) - self.generated_ids[:,cache_position] = input_ids.to(self.args.device).to(torch.int) + self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1) + logger.debug(f"cache position: {former_seq_length} to {self.seq_length}") + cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device) + self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int) - mask = torch.ones((1,self.seq_length)).to(self.args.device) + mask = torch.ones((1, self.seq_length)).to(self.args.device) device = input_ids.device - if not(type(self) is TransformersInterface): + if not (type(self) is TransformersInterface): input_ids = input_ids.to("cpu") inputs_embeds = self.model.model.embed_tokens(input_ids).to(device) if self.use_static_cache: logits = self.model( - inputs_embeds=inputs_embeds, cache_position=cache_position, past_key_values=self.cache,return_dict=False, use_cache=True,attention_mask=mask + inputs_embeds=inputs_embeds, + cache_position=cache_position, + past_key_values=self.cache, + return_dict=False, + use_cache=True, + attention_mask=mask, )[0] else: - logits = self.model( - inputs_embeds=inputs_embeds,return_dict=False - )[0] + logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0] - - - next_token = self.logits_to_token(logits[0,-1,:]) + next_token = self.logits_to_token(logits[0, -1, :]) yield self.append_new_tokens(next_token) @torch.no_grad def generate(self): - self.profiler.set_counter('decode',0) + self.profiler.set_counter("decode", 0) for _ in range(1, self.args.max_new_tokens): with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): next_token = self.decode_one_tokens() - self.profiler.inc('decode') + self.profiler.inc("decode") if next_token == self.tokenizer.eos_token_id: assert self.args.batch_size == 1 break yield self.append_new_tokens(next_token) yield self.streamer.end() - def check_is_new(self,thread_id:str): + def check_is_new(self, thread_id: str): if not self.use_static_cache: return True if self.last_request_id is None: self.last_request_id = thread_id 
return True else: - if self.last_request_id==thread_id: + if self.last_request_id == thread_id: return False else: self.last_request_id = thread_id return True - async def inference(self,local_messages,thread_id:str): - self.profiler.create_and_start_timer('tokenize') - if isinstance(local_messages,List): - input_ids = self.format_and_tokenize_input_ids(thread_id,local_messages) - elif isinstance(local_messages,str): + async def inference(self, local_messages, thread_id: str): + self.profiler.create_and_start_timer("tokenize") + if isinstance(local_messages, List): + input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages) + elif isinstance(local_messages, str): input_ids = self.tokenize_prompt(local_messages) else: - raise ValueError('local_messages should be List or str') + raise ValueError("local_messages should be List or str") - self.profiler.pause_timer('tokenize') + self.profiler.pause_timer("tokenize") - self.profiler.create_and_start_timer('prefill') - for t in self.prefill(input_ids,self.check_is_new(thread_id)): + self.profiler.create_and_start_timer("prefill") + for t in self.prefill(input_ids, self.check_is_new(thread_id)): if t is not None: - print(t,end='') + print(t, end="") yield t - self.profiler.pause_timer('prefill') + self.profiler.pause_timer("prefill") - self.profiler.create_and_start_timer('decode') + self.profiler.create_and_start_timer("decode") for t in self.generate(): if t is not None: - print(t,end='') + print(t, end="") yield t - print('') - self.profiler.pause_timer('decode') + print("") + self.profiler.pause_timer("decode") self.report_last_time_performance() - diff --git a/ktransformers/server/config/config.py b/ktransformers/server/config/config.py index d391d66..3b38b1e 100644 --- a/ktransformers/server/config/config.py +++ b/ktransformers/server/config/config.py @@ -1,23 +1,24 @@ #!/usr/bin/env python # coding=utf-8 -''' -Description : +""" +Description : Author : unicornchan Date : 2024-06-11 16:35:42 Version : 1.0.0 -LastEditors : WuHao +LastEditors : WuHao LastEditTime : 2024-08-12 06:31:14 -''' +""" import os import shutil import yaml from ktransformers.server.config.singleton import Singleton +from typing import Optional class Config(metaclass=Singleton): - """Singleton pattern Config class, used to get all configurations. 
- """ + """Singleton pattern Config class, used to get all configurations.""" + CONFIG_FILE_NAME = "config.yaml" @staticmethod @@ -27,22 +28,20 @@ class Config(metaclass=Singleton): Returns: dict: all configs """ - base_path: str = os.path.dirname( - os.path.dirname(os.path.dirname(__file__))) - config_yaml: str = os.path.join( - base_path, "configs", Config.CONFIG_FILE_NAME) - - user_path: str = os.path.expanduser('~') - localstore_path: str = os.path.join(user_path,'.ktransformers') - config_path: str = os.path.join(localstore_path,Config.CONFIG_FILE_NAME) + base_path: str = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + config_yaml: str = os.path.join(base_path, "configs", Config.CONFIG_FILE_NAME) + + user_path: str = os.path.expanduser("~") + localstore_path: str = os.path.join(user_path, ".ktransformers") + config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME) if not os.path.exists(config_yaml): print(f"Can't find config file, {config_yaml}") exit(-1) if not os.path.exists(localstore_path): os.mkdir(localstore_path) if not os.path.exists(config_path): - shutil.copyfile(config_yaml,config_path) - with open(config_path, 'r', encoding="utf-8") as fp: + shutil.copyfile(config_yaml, config_path) + with open(config_path, "r", encoding="utf-8") as fp: config = yaml.safe_load(fp) return config @@ -52,16 +51,14 @@ class Config(metaclass=Singleton): process file path """ base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) - real_path = path if os.path.isabs( - path) else os.path.join(base_path, path) + real_path = path if os.path.isabs(path) else os.path.join(base_path, path) return real_path def __init__(self): cfg = Config.load() - self.base_path = os.path.dirname( - os.path.dirname(os.path.dirname(__file__))) - self.user_path: str = os.path.expanduser('~') - self.localstore_path: str = os.path.join(self.user_path,'.ktransformers') + self.base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + self.user_path: str = os.path.expanduser("~") + self.localstore_path: str = os.path.join(self.user_path, ".ktransformers") # log configs self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"])) self.log_file = cfg["log"]["file"] @@ -69,7 +66,7 @@ class Config(metaclass=Singleton): self.backup_count = cfg["log"]["backup_count"] # server configs - self.server: dict = cfg.get("server",{}) + self.server: dict = cfg.get("server", {}) self.server_ip = self.server.get("ip", "0.0.0.0") self.server_port = self.server.get("port", 9016) @@ -86,16 +83,66 @@ class Config(metaclass=Singleton): self.user_config: dict = cfg.get("user", {}) self.user_secret_key = self.user_config.get("secret_key", "") self.user_algorithm = self.user_config.get("algorithm", "") - + # model config - self.model:dict = cfg.get("model", {}) + self.model: dict = cfg.get("model", {}) self.backend_type: str = self.model.get("type", "transformers") - self.model_path: str = self.model.get("path", "") + self.model_dir: str = self.model.get("path", "") self.model_name: str = self.model.get("name", "") self.model_device: str = self.model.get("device", "cuda:0") - self.gguf_path: str = self.model.get("gguf_path", "") - self.model_cache_lens = self.model.get("cache_lens") - + self.gguf_path: Optional[str] = self.model.get("gguf_path", None) + # self.model_cache_lens = self.model.get("cache_lens") + self.optimize_config_path: Optional[str] = self.model.get( + "optimize_config_path", "./ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml" + ) + 
self.paged = self.model.get("paged", True) + + self.total_context = self.model.get("total_context", 2**18) + self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1) + self.max_chunk_size = self.model.get("max_chunk_size", 2048) + self.max_new_tokens = self.model.get("max_new_tokens", 500) + self.json_mode = self.model.get("json_mode", False) + self.healing = self.model.get("healing", False) + self.ban_strings: Optional[list] = self.model.get("ban_strings", None) + self.gpu_split: Optional[str] = self.model.get("gpu_split", None) + self.length: Optional[int] = self.model.get("length", None) + self.rope_scale: Optional[float] = self.model.get("rope_scale", None) + self.rope_alpha: Optional[float] = self.model.get("rope_alpha", None) + self.no_flash_attn = self.model.get("no_flash_attn", False) + self.low_mem = self.model.get("low_mem", False) + self.experts_per_token: Optional[int] = self.model.get("experts_per_token", None) + self.load_q4 = self.model.get("load_q4", False) + self.fast_safetensors = self.model.get("fast_safetensors", False) + self.draft_model_dir: Optional[str] = self.model.get("draft_model_dir", None) + self.no_draft_scale = self.model.get("no_draft_scale", False) + self.modes = self.model.get("modes", False) + self.mode = self.model.get("mode", "llama") + self.username = self.model.get("username", "User") + self.botname = self.model.get("botname", "Chatbort") + self.system_prompt: Optional[str] = self.model.get("system_prompt", None) + self.temperature = self.model.get("temperature", 0.95) + self.smoothing_factor = self.model.get("smoothing_factor", 0.0) + self.dynamic_temperature: Optional[str] = self.model.get("dynamic_temperature", None) + self.top_k = self.model.get("top_k", 50) + self.top_p = self.model.get("top_p", 0.8) + self.top_a = self.model.get("top_a", 0.0) + self.skew = self.model.get("skew", 0.0) + self.typical = self.model.get("typical", 0.0) + self.repetition_penalty = self.model.get("repetition_penalty", 1.01) + self.frequency_penalty = self.model.get("frequency_penalty", 0.0) + self.presence_penalty = self.model.get("presence_penalty", 0.0) + self.max_response_tokens = self.model.get("max_response_tokens", 300) + self.response_chunk = self.model.get("response_chunk", 250) + self.no_code_formatting = self.model.get("no_code_formatting", False) + self.cache_8bit = self.model.get("cache_8bit", False) + self.cache_q4 = self.model.get("cache_q4", True) + self.ngram_decoding = self.model.get("ngram_decoding", False) + self.print_timings = self.model.get("print_timings", False) + self.amnesia = self.model.get("amnesia", False) + self.batch_size = self.model.get("batch_size", 1) + self.cache_lens = self.model.get("cache_lens", 4096) + self.device = self.model.get("device", "cuda:2") + # web config self.web: dict = cfg.get("web", {}) self.web_cross_domain: bool = self.web.get("open_cross_domain", True) @@ -104,10 +151,32 @@ class Config(metaclass=Singleton): self.ext: dict = cfg.get("ext", {}) self.cpu_infer = self.ext.get("cpu_infer", 10) - #file config - self.local_store_configs: dict = cfg.get("local_store",{}) - self.file_upload_dir: str = os.path.join(self.localstore_path,self.local_store_configs.get("file_upload_dir","")) - self.assistant_store_dir: str = os.path.join(self.localstore_path,self.local_store_configs.get("assistant_store_dir","")) + # file config + self.local_store_configs: dict = cfg.get("local_store", {}) + self.file_upload_dir: str = os.path.join( + self.localstore_path, self.local_store_configs.get("file_upload_dir", 
"") + ) + self.assistant_store_dir: str = os.path.join( + self.localstore_path, self.local_store_configs.get("assistant_store_dir", "") + ) - #long context config - self.long_context_config: dict = cfg.get("long_context",{}) \ No newline at end of file + # long context config + self.long_context_config: dict = cfg.get("long_context", {}) + self.chunk_size = self.long_context_config.get("chunk_size", 4096) + self.max_seq_len = self.long_context_config.get("max_seq_len", 32000) + self.block_size = self.long_context_config.get("block_size", 128) + self.local_windows_len = self.long_context_config.get("local_windows_len", 4096) + self.second_select_num = self.long_context_config.get("second_select_num", 32) + self.anchor_type = self.long_context_config.get("anchor_type", "DYNAMIC") + self.kv_type = self.long_context_config.get("kv_type", "FP16") + self.dense_layer_num = self.long_context_config.get("dense_layer_num", 2) + self.anchor_num = self.long_context_config.get("anchor_num", 1) + self.preselect_block = self.long_context_config.get("preselect_block", True) + self.head_select_mode = self.long_context_config.get("head_select_mode", "SHARED") + self.preselect_block_count = self.long_context_config.get("preselect_block_count", 32) + self.layer_step = self.long_context_config.get("layer_step", 1) + self.token_step = self.long_context_config.get("token_step", 100) + + # local chat + self.local_chat_config: dict = cfg.get("local_chat", {}) + self.prompt_file = self.local_chat_config.get("prompt_file", None) diff --git a/ktransformers/server/main.py b/ktransformers/server/main.py index 0bb52cc..90adef5 100644 --- a/ktransformers/server/main.py +++ b/ktransformers/server/main.py @@ -3,11 +3,11 @@ import re from fastapi import FastAPI from fastapi.staticfiles import StaticFiles import uvicorn.logging -import argparse import uvicorn from fastapi.middleware.cors import CORSMiddleware +from ktransformers.server.args import ArgumentParser from ktransformers.server.config.config import Config -from ktransformers.server.utils.create_interface import create_interface +from ktransformers.server.utils.create_interface import create_interface from ktransformers.server.backend.args import default_args from fastapi.openapi.utils import get_openapi @@ -44,8 +44,11 @@ def create_app(): mount_index_routes(app) return app + def update_web_port(config_file: str): - ip_port_pattern = r"(localhost|((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)):[0-9]{1,5}" + ip_port_pattern = ( + r"(localhost|((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)):[0-9]{1,5}" + ) with open(config_file, "r", encoding="utf-8") as f_cfg: web_config = f_cfg.read() ip_port = "localhost:" + str(Config().server_port) @@ -70,14 +73,15 @@ def mount_index_routes(app: FastAPI): def run_api(app, host, port, **kwargs): if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"): - uvicorn.run(app, - host=host, - port=port, - ssl_keyfile=kwargs.get("ssl_keyfile"), - ssl_certfile=kwargs.get("ssl_certfile"), - ) + uvicorn.run( + app, + host=host, + port=port, + ssl_keyfile=kwargs.get("ssl_keyfile"), + ssl_certfile=kwargs.get("ssl_certfile"), + ) else: - uvicorn.run(app, host=host, port=port, log_level='debug') + uvicorn.run(app, host=host, port=port, log_level="debug") def custom_openapi(app): @@ -90,53 +94,27 @@ def custom_openapi(app): description="We provided chat completion and openai assistant interfaces.", routes=app.routes, ) - openapi_schema["info"]["x-logo"] = { - "url": 
"https://kvcache.ai/media/icon_1.png" - } + openapi_schema["info"]["x-logo"] = {"url": "https://kvcache.ai/media/icon_1.png"} app.openapi_schema = openapi_schema return app.openapi_schema + def main(): cfg = Config() - parser = argparse.ArgumentParser(prog='kvcache.ai', - description='Ktransformers') - parser.add_argument("--host", type=str, default="0.0.0.0") - parser.add_argument("--port", type=int, default=cfg.server_port) - parser.add_argument("--ssl_keyfile", type=str) - parser.add_argument("--ssl_certfile", type=str) - parser.add_argument("--web", type=bool, default=False) - parser.add_argument("--model_name", type=str, default=cfg.model_name) - parser.add_argument("--model_path", type=str, default=cfg.model_path) - parser.add_argument("--device", type=str, default=cfg.model_device, help="Warning: Abandoning this parameter") - parser.add_argument("--gguf_path", type=str, default=cfg.gguf_path) - parser.add_argument("--optimize_config_path", default=None, type=str, required=False) - parser.add_argument("--cpu_infer", type=int, default=cfg.cpu_infer) - parser.add_argument("--type", type=str, default=cfg.backend_type) + arg_parser = ArgumentParser(cfg) # 初始化消息 - args = parser.parse_args() - cfg.model_name = args.model_name - cfg.model_path = args.model_path - cfg.model_device = args.device - cfg.mount_web = args.web - cfg.server_ip = args.host - cfg.server_port = args.port - cfg.cpu_infer = args.cpu_infer - cfg.backend_type = args.type - - default_args.model_dir = args.model_path - default_args.device = args.device - default_args.gguf_path = args.gguf_path - default_args.optimize_config_path = args.optimize_config_path - + args = arg_parser.parse_args() app = create_app() custom_openapi(app) - create_interface(config=cfg, default_args=default_args) - run_api(app=app, - host=args.host, - port=args.port, - ssl_keyfile=args.ssl_keyfile, - ssl_certfile=args.ssl_certfile,) + create_interface(config=cfg, default_args=cfg) + run_api( + app=app, + host=args.host, + port=args.port, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ) if __name__ == "__main__": diff --git a/ktransformers/website/package-lock.json b/ktransformers/website/package-lock.json index 7d77ec1..444d585 100644 --- a/ktransformers/website/package-lock.json +++ b/ktransformers/website/package-lock.json @@ -4412,8 +4412,9 @@ }, "node_modules/@vue/cli": { "version": "5.0.8", - "resolved": "https://registry.npmmirror.com/@vue/cli/-/cli-5.0.8.tgz", + "resolved": "https://registry.npmjs.org/@vue/cli/-/cli-5.0.8.tgz", "integrity": "sha512-c/QKPdC09bYkW22m/boXkLaiz10z0Z2WHZO7zEeNdfSduqyWINZhKc6hVQU3Vk0NXW7BJAd7zWmcUrC8L9TuAA==", + "license": "MIT", "dependencies": { "@types/ejs": "^3.0.6", "@types/inquirer": "^8.1.3", diff --git a/pyproject.toml b/pyproject.toml index 8070241..028c6a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,4 +69,8 @@ ktransformers = "ktransformers.server.main:main" [tool.setuptools.packages.find] where = ["./", ] -include = ["ktransformers"] \ No newline at end of file +include = ["ktransformers"] +[tool.black] +line-length = 120 +preview = true +unstable = true \ No newline at end of file