kvcache-ai-ktransformers/ktransformers/server/config/config.py
sean.su 8699109129 Refactor the chat interface to support tool calling and parameter processing
Defined new data structures in chat.py to replace OpenAI's original implementation, adding support for tool calling.

Implemented logic for extracting and processing tool calls, enabling dynamic function invocation during conversations.

Added methods in balance_serve.py to retrieve sampling parameters, handling default values and edge cases.

Updated ktransformers.py and transformers.py to support the passing of tool parameters.

Modified the default value of top_p in config.py to 1.0 to increase generation diversity.

Extended the message model in chat.py to support the transmission of tool call information.

These changes enhance the system's flexibility and functionality, enabling more complex interaction patterns.
2025-04-14 15:23:37 +08:00

208 lines
9.8 KiB
Python

#!/usr/bin/env python
# coding=utf-8
"""
Description : Singleton server configuration loaded from ~/.ktransformers/config.yaml
Author : unicornchan
Date : 2024-06-11 16:35:42
Version : 1.0.0
LastEditors : WuHao
LastEditTime : 2024-08-12 06:31:14
"""
import os
import shutil
import yaml
import psutil
from ktransformers.server.config.singleton import Singleton
from typing import Optional
class Config(metaclass=Singleton):
    """Singleton pattern Config class, used to get all configurations.

    On construction the user's ``~/.ktransformers/config.yaml`` is parsed
    (it is seeded from the packaged ``configs/config.yaml`` on first use) and
    every setting is exposed as a plain instance attribute.

    The ``log``, ``async_server``, ``attn`` and ``kvc2`` sections are read
    with direct indexing and are therefore required; all other sections are
    optional and fall back to the defaults written in ``__init__``.
    """

    CONFIG_FILE_NAME = "config.yaml"

    @staticmethod
    def load() -> dict:
        """Load the config file, creating the user copy when it is missing.

        Ensures ``~/.ktransformers`` and ``~/.ktransformers/kvc2`` exist,
        copies the packaged template to ``~/.ktransformers/config.yaml`` if
        no user copy is present, then parses the user copy.

        Returns:
            dict: all configs

        Raises:
            SystemExit: with code ``-1`` when the packaged template
                ``<base>/configs/config.yaml`` cannot be found.
        """
        base_path: str = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
        config_yaml: str = os.path.join(base_path, "configs", Config.CONFIG_FILE_NAME)

        user_path: str = os.path.expanduser("~")
        localstore_path: str = os.path.join(user_path, ".ktransformers")
        kvc2_config_dir = os.path.join(localstore_path, "kvc2")
        config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME)

        if not os.path.exists(config_yaml):
            print(f"Can't find config file, {config_yaml}")
            # Raise SystemExit directly: the ``exit`` builtin is injected by
            # the ``site`` module and is not guaranteed to exist.
            raise SystemExit(-1)

        # exist_ok avoids the check-then-create race when several processes
        # start at the same time.
        os.makedirs(localstore_path, exist_ok=True)
        os.makedirs(kvc2_config_dir, exist_ok=True)

        if not os.path.exists(config_path):
            shutil.copyfile(config_yaml, config_path)
        with open(config_path, "r", encoding="utf-8") as fp:
            config = yaml.safe_load(fp)
        return config

    @staticmethod
    def to_path(path: str) -> str:
        """Resolve ``path`` against the package base directory.

        Args:
            path: an absolute path, or a path relative to the directory three
                levels above this file.

        Returns:
            str: ``path`` unchanged when it is absolute, otherwise the path
            joined onto the base directory.
        """
        base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
        real_path = path if os.path.isabs(path) else os.path.join(base_path, path)
        return real_path

    def __init__(self):
        """Read the configuration file and populate every attribute.

        Raises:
            KeyError: when a required section or key (``log``,
                ``async_server``, ``attn``, ``kvc2``) is missing.
            SystemExit: propagated from :meth:`load`.
        """
        cfg = Config.load()

        def section(name: str) -> dict:
            """Return the named top-level mapping, or {} if absent or null."""
            # A YAML key with no children parses to None, not to {}.
            return cfg.get(name) or {}

        self.base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
        self.user_path: str = os.path.expanduser("~")
        self.localstore_path: str = os.path.join(self.user_path, ".ktransformers")

        # log configs
        self.log_dir = os.path.join(self.localstore_path, cfg["log"]["dir"])
        # makedirs also handles a nested log dir and concurrent creation.
        os.makedirs(self.log_dir, exist_ok=True)
        self.log_file = cfg["log"]["file"]
        self.log_level = cfg["log"]["level"]
        self.backup_count = cfg["log"]["backup_count"]

        self.kvc2_config_dir = os.path.join(self.localstore_path, "kvc2")

        # server configs
        self.server: dict = section("server")
        self.server_ip = self.server.get("ip", "0.0.0.0")
        self.server_port = self.server.get("port", 9016)
        self.api_key = self.server.get("api_key", "")

        # db configs
        self.db_configs: dict = section("db")
        self.db_type = self.db_configs.get("type", "")
        self.db_host = self.localstore_path
        self.db_port = self.db_configs.get("port", "")
        self.db_name = self.db_configs.get("database", "")
        self.db_pool_size = self.db_configs.get("pool_size")
        self.db_database = self.db_configs.get("database", "")

        # user config
        self.user_config: dict = section("user")
        self.user_secret_key = self.user_config.get("secret_key", "")
        self.user_algorithm = self.user_config.get("algorithm", "")
        self.user_force_think = self.user_config.get("force_think", False)

        # model config
        self.model: dict = section("model")
        self.backend_type: str = self.model.get("type", "transformers")
        self.model_dir: str = self.model.get("path", "")
        # to make sure it consistent with previous version
        self.model_path: str = self.model_dir
        self.model_name: str = self.model.get("name", "")
        self.model_device: str = self.model.get("device", "cuda:0")
        self.gguf_path: Optional[str] = self.model.get("gguf_path", None)
        self.use_cuda_graph = self.model.get("use_cuda_graph", True)
        self.trust_remote_code = self.model.get("trust_remote_code", True)
        # self.model_cache_lens = self.model.get("cache_lens")
        self.optimize_config_path: Optional[str] = self.model.get(
            "optimize_config_path", None
        )

        self.max_new_tokens = self.model.get("max_new_tokens", 2000)
        self.json_mode = self.model.get("json_mode", False)
        self.healing = self.model.get("healing", False)
        self.ban_strings: Optional[list] = self.model.get("ban_strings", None)
        self.gpu_split: Optional[str] = self.model.get("gpu_split", None)
        self.length: Optional[int] = self.model.get("length", None)
        self.rope_scale: Optional[float] = self.model.get("rope_scale", None)
        self.rope_alpha: Optional[float] = self.model.get("rope_alpha", None)
        self.no_flash_attn = self.model.get("no_flash_attn", False)
        self.low_mem = self.model.get("low_mem", False)
        self.experts_per_token: Optional[int] = self.model.get("experts_per_token", None)
        self.load_q4 = self.model.get("load_q4", False)
        self.fast_safetensors = self.model.get("fast_safetensors", False)
        self.draft_model_dir: Optional[str] = self.model.get("draft_model_dir", None)
        self.no_draft_scale = self.model.get("no_draft_scale", False)
        self.modes = self.model.get("modes", False)
        self.mode = self.model.get("mode", "llama")
        self.username = self.model.get("username", "User")
        self.botname = self.model.get("botname", "Chatbort")
        self.system_prompt: Optional[str] = self.model.get("system_prompt", None)

        # sampling defaults
        self.temperature = self.model.get("temperature", 0.95)
        self.smoothing_factor = self.model.get("smoothing_factor", 0.0)
        self.dynamic_temperature: Optional[str] = self.model.get("dynamic_temperature", None)
        self.top_k = self.model.get("top_k", 50)
        self.top_p = self.model.get("top_p", 1.0)
        self.top_a = self.model.get("top_a", 0.0)
        self.skew = self.model.get("skew", 0.0)
        self.typical = self.model.get("typical", 0.0)
        self.repetition_penalty = self.model.get("repetition_penalty", 1.01)
        self.frequency_penalty = self.model.get("frequency_penalty", 0.0)
        self.presence_penalty = self.model.get("presence_penalty", 0.0)

        self.response_chunk = self.model.get("response_chunk", 250)
        self.no_code_formatting = self.model.get("no_code_formatting", False)
        self.cache_8bit = self.model.get("cache_8bit", False)
        self.cache_q4 = self.model.get("cache_q4", True)
        self.ngram_decoding = self.model.get("ngram_decoding", False)
        self.print_timings = self.model.get("print_timings", False)
        self.amnesia = self.model.get("amnesia", False)
        self.batch_size = self.model.get("batch_size", 1)
        # Rounded up to a multiple of page_size further below.
        self.cache_lens = self.model.get("cache_lens", 4096)
        # NOTE(review): reads the same "device" key as model_device but with a
        # different fallback ("cuda:2" vs "cuda:0") - confirm this is intended.
        self.device = self.model.get("device", "cuda:2")

        # web config
        self.web: dict = section("web")
        self.web_cross_domain: bool = self.web.get("open_cross_domain", True)
        self.mount_web: bool = self.web.get("mount", False)

        # ext
        self.ext: dict = section("ext")
        # psutil returns None when the physical core count is undetermined,
        # and "cores - 3" is zero or negative on small machines; keep at
        # least one inference thread in both cases.
        physical_cores = psutil.cpu_count(logical=False)
        self.cpu_infer = max(1, (physical_cores or 1) - 3)

        # file config
        self.local_store_configs: dict = section("local_store")
        self.file_upload_dir: str = os.path.join(
            self.localstore_path, self.local_store_configs.get("file_upload_dir", "")
        )
        self.assistant_store_dir: str = os.path.join(
            self.localstore_path, self.local_store_configs.get("assistant_store_dir", "")
        )

        # long context config
        self.long_context_config: dict = section("long_context")
        self.max_seq_len = self.long_context_config.get("max_seq_len", 32000)
        self.block_size = self.long_context_config.get("block_size", 128)
        self.local_windows_len = self.long_context_config.get("local_windows_len", 4096)
        self.second_select_num = self.long_context_config.get("second_select_num", 32)
        self.anchor_type = self.long_context_config.get("anchor_type", "DYNAMIC")
        self.kv_type = self.long_context_config.get("kv_type", "FP16")
        self.dense_layer_num = self.long_context_config.get("dense_layer_num", 2)
        self.anchor_num = self.long_context_config.get("anchor_num", 1)
        self.preselect_block = self.long_context_config.get("preselect_block", True)
        self.head_select_mode = self.long_context_config.get("head_select_mode", "SHARED")
        self.preselect_block_count = self.long_context_config.get("preselect_block_count", 32)
        self.layer_step = self.long_context_config.get("layer_step", 1)
        self.token_step = self.long_context_config.get("token_step", 100)

        # local chat
        self.local_chat_config: dict = section("local_chat")
        self.prompt_file = self.local_chat_config.get("prompt_file", None)

        # asyncserver (required section: a missing key raises KeyError)
        self.sched_strategy = cfg["async_server"]["sched_strategy"]
        self.sched_port = cfg["async_server"]["sched_port"]
        self.sched_metrics_port = cfg["async_server"]["sched_metrics_port"]
        self.kvc2_metrics_port = cfg["async_server"]["kvc2_metrics_port"]
        self.max_batch_size = cfg["async_server"]["max_batch_size"]
        self.page_size = cfg["attn"]["page_size"]
        self.chunk_size = cfg["attn"]["chunk_size"]
        self.memory_gpu_only = cfg["kvc2"]["gpu_only"]
        # Round cache_lens up to the next multiple of page_size.
        self.cache_lens = ((self.cache_lens + self.page_size - 1) // self.page_size) * self.page_size
        # NOTE(review): 2, 576 and 61 are hard-coded; presumably bytes per
        # element, per-token KV width and layer count of the target model -
        # TODO confirm.
        self.gpu_memory_size = 2*576*61*self.cache_lens
        self.utilization_percentage = 1.0 #cfg["kvc2"]["utilization_percentage"]
        self.cpu_memory_size_GB = cfg["kvc2"]["cpu_memory_size_GB"]
        # only support 2 prefill task
        self.max_prefill_batch_size = 2
        self.max_decode_batch_size = self.max_batch_size - self.max_prefill_batch_size