mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 22:05:30 +00:00
✨: refactor local_chat and fix message slice bug in server
This commit is contained in:
parent
43fc7f44a6
commit
dd1d8667f3
13 changed files with 549 additions and 405 deletions
|
@ -3,11 +3,11 @@ import re
|
|||
from fastapi import FastAPI
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
import uvicorn.logging
|
||||
import argparse
|
||||
import uvicorn
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from ktransformers.server.args import ArgumentParser
|
||||
from ktransformers.server.config.config import Config
|
||||
from ktransformers.server.utils.create_interface import create_interface
|
||||
from ktransformers.server.utils.create_interface import create_interface
|
||||
from ktransformers.server.backend.args import default_args
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
|
@ -44,8 +44,11 @@ def create_app():
|
|||
mount_index_routes(app)
|
||||
return app
|
||||
|
||||
|
||||
def update_web_port(config_file: str):
|
||||
ip_port_pattern = r"(localhost|((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)):[0-9]{1,5}"
|
||||
ip_port_pattern = (
|
||||
r"(localhost|((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)):[0-9]{1,5}"
|
||||
)
|
||||
with open(config_file, "r", encoding="utf-8") as f_cfg:
|
||||
web_config = f_cfg.read()
|
||||
ip_port = "localhost:" + str(Config().server_port)
|
||||
|
@ -70,14 +73,15 @@ def mount_index_routes(app: FastAPI):
|
|||
|
||||
def run_api(app, host, port, **kwargs):
|
||||
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
|
||||
uvicorn.run(app,
|
||||
host=host,
|
||||
port=port,
|
||||
ssl_keyfile=kwargs.get("ssl_keyfile"),
|
||||
ssl_certfile=kwargs.get("ssl_certfile"),
|
||||
)
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
ssl_keyfile=kwargs.get("ssl_keyfile"),
|
||||
ssl_certfile=kwargs.get("ssl_certfile"),
|
||||
)
|
||||
else:
|
||||
uvicorn.run(app, host=host, port=port, log_level='debug')
|
||||
uvicorn.run(app, host=host, port=port, log_level="debug")
|
||||
|
||||
|
||||
def custom_openapi(app):
|
||||
|
@ -90,53 +94,27 @@ def custom_openapi(app):
|
|||
description="We provided chat completion and openai assistant interfaces.",
|
||||
routes=app.routes,
|
||||
)
|
||||
openapi_schema["info"]["x-logo"] = {
|
||||
"url": "https://kvcache.ai/media/icon_1.png"
|
||||
}
|
||||
openapi_schema["info"]["x-logo"] = {"url": "https://kvcache.ai/media/icon_1.png"}
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
|
||||
def main():
|
||||
cfg = Config()
|
||||
parser = argparse.ArgumentParser(prog='kvcache.ai',
|
||||
description='Ktransformers')
|
||||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
||||
parser.add_argument("--port", type=int, default=cfg.server_port)
|
||||
parser.add_argument("--ssl_keyfile", type=str)
|
||||
parser.add_argument("--ssl_certfile", type=str)
|
||||
parser.add_argument("--web", type=bool, default=False)
|
||||
parser.add_argument("--model_name", type=str, default=cfg.model_name)
|
||||
parser.add_argument("--model_path", type=str, default=cfg.model_path)
|
||||
parser.add_argument("--device", type=str, default=cfg.model_device, help="Warning: Abandoning this parameter")
|
||||
parser.add_argument("--gguf_path", type=str, default=cfg.gguf_path)
|
||||
parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
|
||||
parser.add_argument("--cpu_infer", type=int, default=cfg.cpu_infer)
|
||||
parser.add_argument("--type", type=str, default=cfg.backend_type)
|
||||
arg_parser = ArgumentParser(cfg)
|
||||
|
||||
# 初始化消息
|
||||
args = parser.parse_args()
|
||||
cfg.model_name = args.model_name
|
||||
cfg.model_path = args.model_path
|
||||
cfg.model_device = args.device
|
||||
cfg.mount_web = args.web
|
||||
cfg.server_ip = args.host
|
||||
cfg.server_port = args.port
|
||||
cfg.cpu_infer = args.cpu_infer
|
||||
cfg.backend_type = args.type
|
||||
|
||||
default_args.model_dir = args.model_path
|
||||
default_args.device = args.device
|
||||
default_args.gguf_path = args.gguf_path
|
||||
default_args.optimize_config_path = args.optimize_config_path
|
||||
|
||||
args = arg_parser.parse_args()
|
||||
app = create_app()
|
||||
custom_openapi(app)
|
||||
create_interface(config=cfg, default_args=default_args)
|
||||
run_api(app=app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
ssl_keyfile=args.ssl_keyfile,
|
||||
ssl_certfile=args.ssl_certfile,)
|
||||
create_interface(config=cfg, default_args=cfg)
|
||||
run_api(
|
||||
app=app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
ssl_keyfile=args.ssl_keyfile,
|
||||
ssl_certfile=args.ssl_certfile,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue