add balance-serve, support concurrence

This commit is contained in:
Atream 2025-03-31 22:55:32 +08:00
parent 8d0292aa44
commit 25cee5810e
196 changed files with 22077 additions and 565 deletions

View file

@ -5,24 +5,20 @@ from fastapi.staticfiles import StaticFiles
import uvicorn.logging
import uvicorn
import sys
import atexit
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
sys.path.insert(0, project_dir)
from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.args import ArgumentParser
from ktransformers.server.config.config import Config
from ktransformers.server.utils.create_interface import create_interface
from ktransformers.server.backend.args import default_args
from ktransformers.server.utils.create_interface import create_interface, GlobalInterface
from fastapi.openapi.utils import get_openapi
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.api import router, post_db_creation_operations
from ktransformers.server.utils.sql_utils import Base, SQLUtil
from ktransformers.server.config.log import logger
import subprocess
import tempfile
def mount_app_routes(mount_app: FastAPI):
sql_util = SQLUtil()
@ -34,7 +30,10 @@ def mount_app_routes(mount_app: FastAPI):
def create_app():
cfg = Config()
app = FastAPI()
if(hasattr(GlobalInterface.interface, "lifespan")):
app = FastAPI(lifespan=GlobalInterface.interface.lifespan)
else:
app = FastAPI()
if Config().web_cross_domain:
app.add_middleware(
CORSMiddleware,
@ -108,11 +107,32 @@ def main():
arg_parser = ArgumentParser(cfg)
# 初始化消息
args = arg_parser.parse_args()
if args.backend_type == "balance_serve":
import pickle
def cleanup():
if sched_process.poll() is None:
sched_process.terminate()
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
pickle.dump(args, temp_file)
temp_file_path = temp_file.name
current_file = __file__
target_file = os.path.join(os.path.dirname(current_file), "balance_serve", "sched_rpc.py")
target_file = os.path.normpath(target_file)
log_path = os.path.join(args.log_dir, "rpc.log")
log = open(log_path, "a")
sched_process = subprocess.Popen(
["python3", target_file, "--config", temp_file_path],
stdout=log,
stderr=log
)
print("sched_rpc started with PID:", sched_process.pid)
atexit.register(cleanup)
create_interface(config=cfg, default_args=cfg)
app = create_app()
custom_openapi(app)
create_interface(config=cfg, default_args=cfg)
run_api(
app=app,
host=args.host,
@ -121,6 +141,5 @@ def main():
ssl_certfile=args.ssl_certfile,
)
if __name__ == "__main__":
main()