diff --git a/ktransformers/server/backend/interfaces/balance_serve.py b/ktransformers/server/backend/interfaces/balance_serve.py index 74c680d..008431e 100644 --- a/ktransformers/server/backend/interfaces/balance_serve.py +++ b/ktransformers/server/backend/interfaces/balance_serve.py @@ -46,6 +46,8 @@ import pickle import subprocess import tempfile import atexit +import signal + ktransformer_rules_dir = ( os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/") @@ -55,6 +57,7 @@ default_optimize_rules = { "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml", } + async def chat_stream(queue: asyncio.Queue, tokenizer: AutoTokenizer): streamer = TextStreamer(tokenizer) while True: @@ -293,10 +296,6 @@ class BalanceServeInterface(BackendInterfaceBase): kvcache_event.wait() - def cleanup(): - if sched_process.poll() is None: - sched_process.terminate() - with tempfile.NamedTemporaryFile(delete=False) as temp_file: pickle.dump(args, temp_file) temp_file_path = temp_file.name @@ -311,7 +310,27 @@ class BalanceServeInterface(BackendInterfaceBase): stderr=log ) print("sched_rpc started with PID:", sched_process.pid) - atexit.register(cleanup) + + def signal_handler(signum, frame): + print(f"Received signal {signum}, shutting down...") + cleanup() + os._exit(0) + + def cleanup(): + print("Cleaning up...") + + for p in processes: + if p.is_alive(): + print(f"Terminating subprocess {p.pid}") + p.terminate() + p.join() + + if sched_process and sched_process.poll() is None: + print(f"Terminating sched_process {sched_process.pid}") + sched_process.terminate() + sched_process.wait() + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) start_event.wait()