Move KV cache creation to balance_serve

commit 38e841900d
parent 8770b6d573
Author: qiyuxinlin
Date: 2025-04-18 10:10:07 +00:00
2 changed files with 45 additions and 38 deletions
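In effect, the commit delays KV cache allocation until the scheduler (balance_serve) is running: the engine subprocess now sets a multiprocessing Event once the model is built, the parent waits on that event before launching sched_rpc, and only afterwards does the engine fetch the inference context and load the KV cache. A minimal sketch of that handshake is below; the time.sleep calls stand in for model construction and the scheduler RPC, which are not reproduced here:

```python
import multiprocessing as mp
import time

def run_engine(start_event, kvcache_event):
    time.sleep(0.1)        # stand-in for building the model and loading weights
    kvcache_event.set()    # tell the parent it may now start the scheduler
    time.sleep(0.1)        # stand-in for fetching the inference context via RPC
    start_event.set()      # engine fully initialized, KV cache loaded

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    start_event = ctx.Event()
    kvcache_event = ctx.Event()
    p = ctx.Process(target=run_engine, args=(start_event, kvcache_event))
    p.start()
    kvcache_event.wait()   # only now is it safe to launch sched_rpc
    print("engine ready for KV cache; scheduler would be launched here")
    start_event.wait()     # engine finished loading the KV cache
    p.join()
```

The spawn start method is used here because CUDA workloads require it; the real code gets the same behavior through torch.multiprocessing.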

ktransformers/server/backend/interfaces/balance_serve.py

@@ -30,6 +30,7 @@ from ktransformers.server.balance_serve.sched_rpc import SchedulerClient
 from ktransformers.server.balance_serve.settings import sched_ext
 from torch.multiprocessing import Queue
 import torch.multiprocessing as mp
+from multiprocessing.synchronize import Event
 from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.utils.multi_timer import Profiler
 import zmq
@@ -41,8 +42,10 @@ import threading
 from contextlib import asynccontextmanager
 from fastapi import FastAPI, Request
 import os
+import pickle
+import subprocess
+import tempfile
+import atexit
 
 ktransformer_rules_dir = (
     os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
@@ -99,7 +102,7 @@ class Engine:
     sampler: Sampler
     query_manager: QueryManager
     cache: KDeepSeekV3Cache
-    def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None):
+    def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None, kvcache_event: Event = None):
         self.args = args
         # the child process and the parent process cannot share the config variable
@@ -115,14 +118,6 @@ class Engine:
         self.gen_queue = generated_token_queue
 
-        print(f"Getting inference context from sched_client.")
-        inference_context = self.sched_client.get_inference_context_raw()
-        print(f"Got inference context, sending it to subscribers.")
-        inference_context = self.sched_client.rebuild_inferece_context(inference_context)
-        self.cache.load(inference_context)
-        print(f"kv_cache loaded successfully.")
-        self.block_num = inference_context.k_cache[0].size(1)
-
         with torch.device("meta"):
             if config.architectures[0] == "DeepseekV3ForCausalLM":
                 self.model = KDeepseekV3ForCausalLM(config, self.cache)
@@ -165,6 +160,17 @@ class Engine:
         self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
         self.model.eval()
 
+        kvcache_event.set()
+
+        # load kvcache
+        print(f"Getting inference context from sched_client.")
+        inference_context = self.sched_client.get_inference_context_raw()
+        print(f"Got inference context, sending it to subscribers.")
+        inference_context = self.sched_client.rebuild_inferece_context(inference_context)
+        self.cache.load(inference_context)
+        print(f"kv_cache loaded successfully.")
+        self.block_num = inference_context.k_cache[0].size(1)
+
         #@TODO add config
         self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
@@ -240,8 +246,8 @@ class BalanceServeThreadContext(ThreadContext):
         return local_messages
 
-def run_engine(args, token_queue, broadcast_endpoint, event):
-    engine = Engine(args, token_queue, broadcast_endpoint)
+def run_engine(args, token_queue, broadcast_endpoint, event, kvcache_event):
+    engine = Engine(args, token_queue, broadcast_endpoint, kvcache_event)
     if args.use_cuda_graph:
         engine.model_runner.warmup()
@@ -278,10 +284,34 @@ class BalanceServeInterface(BackendInterfaceBase):
         self.streamer = TextStreamer(self.tokenizer)
 
         start_event = ctx.Event()
-        p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event))
+        kvcache_event = ctx.Event()
+        p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event, kvcache_event))
         p.start()
         processes.append(p)
+        kvcache_event.wait()
+
+        def cleanup():
+            if sched_process.poll() is None:
+                sched_process.terminate()
+
+        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+            pickle.dump(args, temp_file)
+            temp_file_path = temp_file.name
+
+        current_file = __file__
+        target_file = os.path.join(os.path.dirname(current_file), "..", "..", "balance_serve", "sched_rpc.py")
+        target_file = os.path.normpath(target_file)
+
+        log_path = os.path.join(args.log_dir, "rpc.log")
+        log = open(log_path, "a")
+
+        sched_process = subprocess.Popen(
+            ["python3", target_file, "--config", temp_file_path],
+            stdout=log,
+            stderr=log
+        )
+        print("sched_rpc started with PID:", sched_process.pid)
+        atexit.register(cleanup)
+
         start_event.wait()
 
     def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
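The launch-and-cleanup idiom added above (and removed from main.py below) can be read in isolation as the following sketch. launch_sched_rpc, script_path, and log_dir are illustrative names rather than identifiers from the commit, and sys.executable is substituted for the hard-coded python3:

```python
import atexit
import os
import pickle
import subprocess
import sys
import tempfile

def launch_sched_rpc(args, script_path, log_dir="."):
    # Serialize the parsed args so the child process can re-read them;
    # the temp file is deliberately not deleted, since the child opens it later.
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        pickle.dump(args, temp_file)
        temp_file_path = temp_file.name

    # Append scheduler output to a log file instead of inheriting stdio.
    log = open(os.path.join(log_dir, "rpc.log"), "a")
    proc = subprocess.Popen(
        [sys.executable, script_path, "--config", temp_file_path],
        stdout=log,
        stderr=log,
    )

    # Terminate the scheduler at interpreter exit, if it is still running.
    def cleanup():
        if proc.poll() is None:
            proc.terminate()
    atexit.register(cleanup)
    return proc
```

Passing config through a pickled temp file sidesteps argv length and quoting limits, at the cost of leaving the file behind for the child to read.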

ktransformers/server/main.py

@@ -5,7 +5,6 @@ from fastapi.staticfiles import StaticFiles
 import uvicorn.logging
 import uvicorn
 import sys
-import atexit
 
 project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
 from fastapi.middleware.cors import CORSMiddleware
 from ktransformers.server.args import ArgumentParser
@@ -17,8 +16,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from ktransformers.server.api import router, post_db_creation_operations
 from ktransformers.server.utils.sql_utils import Base, SQLUtil
 from ktransformers.server.config.log import logger
-import subprocess
-import tempfile
 
 def mount_app_routes(mount_app: FastAPI):
     sql_util = SQLUtil()
@@ -108,27 +106,6 @@ def main():
     arg_parser = ArgumentParser(cfg)
     args = arg_parser.parse_args()
 
-    if args.backend_type == "balance_serve":
-        import pickle
-
-        def cleanup():
-            if sched_process.poll() is None:
-                sched_process.terminate()
-
-        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
-            pickle.dump(args, temp_file)
-            temp_file_path = temp_file.name
-
-        current_file = __file__
-        target_file = os.path.join(os.path.dirname(current_file), "balance_serve", "sched_rpc.py")
-        target_file = os.path.normpath(target_file)
-
-        log_path = os.path.join(args.log_dir, "rpc.log")
-        log = open(log_path, "a")
-
-        sched_process = subprocess.Popen(
-            ["python3", target_file, "--config", temp_file_path],
-            stdout=log,
-            stderr=log
-        )
-        print("sched_rpc started with PID:", sched_process.pid)
-        atexit.register(cleanup)
-
     create_interface(config=cfg, default_args=cfg)
     app = create_app()
     custom_openapi(app)
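For completeness, the receiving end in sched_rpc.py presumably recovers the config with something like the fragment below. This counterpart is not part of the diff; it is a hedged guess at how the --config path that the launcher passes is consumed:

```python
import argparse
import pickle

# Hypothetical consumer side (not in this commit): sched_rpc.py is
# invoked as `python3 sched_rpc.py --config <pickled-args-path>`.
parser = argparse.ArgumentParser()
parser.add_argument("--config", required=True,
                    help="path to the pickled args written by the launcher")
cli = parser.parse_args()

with open(cli.config, "rb") as f:
    args = pickle.load(f)  # the same args object the parent pickled
```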