diff --git a/ktransformers/server/backend/interfaces/balance_serve.py b/ktransformers/server/backend/interfaces/balance_serve.py
index 50341a3..582fabb 100644
--- a/ktransformers/server/backend/interfaces/balance_serve.py
+++ b/ktransformers/server/backend/interfaces/balance_serve.py
@@ -30,6 +30,7 @@ from ktransformers.server.balance_serve.sched_rpc import SchedulerClient
 from ktransformers.server.balance_serve.settings import sched_ext
 from torch.multiprocessing import Queue
 import torch.multiprocessing as mp
+from multiprocessing.synchronize import Event
 from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.utils.multi_timer import Profiler
 import zmq
@@ -41,8 +42,10 @@ import threading
 from contextlib import asynccontextmanager
 from fastapi import FastAPI, Request
 import os
-
-
+import pickle
+import subprocess
+import tempfile
+import atexit

 ktransformer_rules_dir = (
     os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
@@ -99,7 +102,7 @@ class Engine:
     sampler: Sampler
     query_manager: QueryManager
     cache: KDeepSeekV3Cache
-    def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None):
+    def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None, kvcache_event: Event = None):
         self.args = args

         # child and parent processes cannot share the config variable
@@ -115,14 +118,6 @@ class Engine:

         self.gen_queue = generated_token_queue

-        print(f"Getting inference context from sched_client.")
-        inference_context = self.sched_client.get_inference_context_raw()
-        print(f"Got inference context, sending it to subscribers.")
-        inference_context = self.sched_client.rebuild_inferece_context(inference_context)
-        self.cache.load(inference_context)
-        print(f"kv_cache loaded successfully.")
-
-        self.block_num = inference_context.k_cache[0].size(1)
         with torch.device("meta"):
             if config.architectures[0] == "DeepseekV3ForCausalLM":
                 self.model = KDeepseekV3ForCausalLM(config, self.cache)
@@ -165,6 +160,17 @@ class Engine:
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
         self.model.eval()

+        kvcache_event.set()
+        # load kvcache
+        print(f"Getting inference context from sched_client.")
+        inference_context = self.sched_client.get_inference_context_raw()
+        print(f"Got inference context, sending it to subscribers.")
+        inference_context = self.sched_client.rebuild_inferece_context(inference_context)
+        self.cache.load(inference_context)
+        print(f"kv_cache loaded successfully.")
+
+
+        self.block_num = inference_context.k_cache[0].size(1)
         #@TODO add config
         self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)

@@ -240,8 +246,8 @@ class BalanceServeThreadContext(ThreadContext):
         return local_messages


-def run_engine(args, token_queue, broadcast_endpoint, event):
-    engine = Engine(args, token_queue, broadcast_endpoint)
+def run_engine(args, token_queue, broadcast_endpoint, event, kvcache_event):
+    engine = Engine(args, token_queue, broadcast_endpoint, kvcache_event)
     if args.use_cuda_graph:
         engine.model_runner.warmup()

@@ -278,10 +284,34 @@ class BalanceServeInterface(BackendInterfaceBase):
         self.streamer = TextStreamer(self.tokenizer)

         start_event = ctx.Event()
+        kvcache_event = ctx.Event()

-        p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event))
+        p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event, kvcache_event))
         p.start()
         processes.append(p)
+        kvcache_event.wait()
+
+
+        def cleanup():
+            if sched_process.poll() is None:
+                sched_process.terminate()
+
+        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+            pickle.dump(args, temp_file)
+            temp_file_path = temp_file.name
+        current_file = __file__
+        target_file = os.path.join(os.path.dirname(current_file), "..", "..", "balance_serve", "sched_rpc.py")
+        target_file = os.path.normpath(target_file)
+        log_path = os.path.join(args.log_dir, "rpc.log")
+        log = open(log_path, "a")
+        sched_process = subprocess.Popen(
+            ["python3", target_file, "--config", temp_file_path],
+            stdout=log,
+            stderr=log
+        )
+        print("sched_rpc started with PID:", sched_process.pid)
+        atexit.register(cleanup)
+
         start_event.wait()

     def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
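Note on the startup ordering above: Engine.__init__ now loads the model weights first, sets kvcache_event, and only afterwards fetches the inference context and loads the KV cache, while the parent waits on kvcache_event before spawning the scheduler. The scheduler's KV-cache allocation therefore cannot race ahead of weight loading. Below is a minimal runnable sketch of the same two-event handshake; only the event choreography mirrors the patch, and the prints stand in for the real weight and KV-cache loading:

    # Sketch of the two-event startup handshake, using only the standard
    # library; the engine body is a stand-in for the real Engine/ModelRunner.
    import multiprocessing as mp
    from multiprocessing.synchronize import Event

    def run_engine(kvcache_event: Event, start_event: Event) -> None:
        # Phase 1: load weights (this is what claims GPU memory first
        # in the real engine).
        print("engine: weights loaded")
        kvcache_event.set()        # parent may start the scheduler now
        # Phase 2: in the real code this blocks until the scheduler is
        # reachable, then loads the KV cache from the inference context.
        print("engine: kv cache loaded")
        start_event.set()          # engine is fully initialized

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        kvcache_event, start_event = ctx.Event(), ctx.Event()
        p = ctx.Process(target=run_engine, args=(kvcache_event, start_event))
        p.start()
        kvcache_event.wait()       # mirror of kvcache_event.wait() in the patch
        print("parent: launching sched_rpc here")
        start_event.wait()         # mirror of start_event.wait() in the patch
        p.join()

The new "from multiprocessing.synchronize import Event" import exists because multiprocessing.Event is a factory function, not a class; the class usable in type hints such as kvcache_event: Event lives in multiprocessing.synchronize.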
diff --git a/ktransformers/server/main.py b/ktransformers/server/main.py
index 8108a3c..3341ee9 100644
--- a/ktransformers/server/main.py
+++ b/ktransformers/server/main.py
@@ -5,7 +5,6 @@ from fastapi.staticfiles import StaticFiles
 import uvicorn.logging
 import uvicorn
 import sys
-import atexit
 project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
 from fastapi.middleware.cors import CORSMiddleware
 from ktransformers.server.args import ArgumentParser
@@ -17,8 +16,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from ktransformers.server.api import router, post_db_creation_operations
 from ktransformers.server.utils.sql_utils import Base, SQLUtil
 from ktransformers.server.config.log import logger
-import subprocess
-import tempfile
+

 def mount_app_routes(mount_app: FastAPI):
     sql_util = SQLUtil()
@@ -108,27 +106,6 @@ def main():
     arg_parser = ArgumentParser(cfg)
     args = arg_parser.parse_args()

-    if args.backend_type == "balance_serve":
-        import pickle
-        def cleanup():
-            if sched_process.poll() is None:
-                sched_process.terminate()
-
-        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
-            pickle.dump(args, temp_file)
-            temp_file_path = temp_file.name
-        current_file = __file__
-        target_file = os.path.join(os.path.dirname(current_file), "balance_serve", "sched_rpc.py")
-        target_file = os.path.normpath(target_file)
-        log_path = os.path.join(args.log_dir, "rpc.log")
-        log = open(log_path, "a")
-        sched_process = subprocess.Popen(
-            ["python3", target_file, "--config", temp_file_path],
-            stdout=log,
-            stderr=log
-        )
-        print("sched_rpc started with PID:", sched_process.pid)
-        atexit.register(cleanup)
     create_interface(config=cfg, default_args=cfg)
     app = create_app()
     custom_openapi(app)
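The block removed from main.py and re-added in balance_serve.py follows one pattern: pickle the parsed args into a NamedTemporaryFile(delete=False), hand the path to the scheduler subprocess via --config, and register an atexit hook so the child is terminated when the server exits. Below is a runnable sketch of that pattern under stated assumptions: the dict args and the child's inline pickle.load are illustrative stand-ins, since the real sched_rpc.py parses --config itself:

    # Sketch of the pickle-to-temp-file config handoff plus atexit cleanup.
    import atexit
    import pickle
    import subprocess
    import sys
    import tempfile

    args = {"log_dir": ".", "backend_type": "balance_serve"}  # stand-in for the real args

    # Parent side: serialize args and keep the file around for the child.
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        pickle.dump(args, temp_file)
        temp_file_path = temp_file.name

    # Child side: re-read the config from the path it was given (here an
    # inline script; sched_rpc.py would do the equivalent in its --config handler).
    child = subprocess.Popen(
        [sys.executable, "-c",
         "import pickle, sys; cfg = pickle.load(open(sys.argv[1], 'rb')); print('child got', cfg)",
         temp_file_path],
    )

    def cleanup():
        # Mirror of the patch's cleanup(): terminate only if still running.
        if child.poll() is None:
            child.terminate()

    atexit.register(cleanup)
    child.wait()

delete=False matters here: the child opens the file after the parent's with-block has closed it, and on some platforms a still-open NamedTemporaryFile cannot be reopened by another process.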