Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-09 22:05:30 +00:00)
Move KV cache creation to balance_serve

parent 8770b6d573
commit 38e841900d

2 changed files with 45 additions and 38 deletions
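In short: Engine.__init__ in the balance_serve backend gains a kvcache_event parameter; the engine sets the event once the model has been built and switched to eval mode, and only then fetches the inference context from the scheduler and loads the KV cache. The scheduler-launch code (pickled config, subprocess.Popen on sched_rpc, atexit cleanup) moves out of main() into BalanceServeInterface, which waits on kvcache_event before starting the scheduler.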
@@ -30,6 +30,7 @@ from ktransformers.server.balance_serve.sched_rpc import SchedulerClient
 from ktransformers.server.balance_serve.settings import sched_ext
 from torch.multiprocessing import Queue
 import torch.multiprocessing as mp
+from multiprocessing.synchronize import Event
 from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.utils.multi_timer import Profiler
 import zmq
@@ -41,8 +42,10 @@ import threading
 from contextlib import asynccontextmanager
 from fastapi import FastAPI, Request
 import os
+import pickle
+import subprocess
+import tempfile
+import atexit
 
 ktransformer_rules_dir = (
     os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
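A note on the new Event import: multiprocessing.synchronize.Event is the concrete class, so it works as a type annotation for the kvcache_event parameter added below, while instances are produced by a multiprocessing context. A minimal sketch of that split (the names here are illustrative, not from the repo):

import torch.multiprocessing as mp
from multiprocessing.synchronize import Event

def signal_ready(ev: Event) -> None:
    # Set the event so any process blocked in ev.wait() resumes.
    ev.set()

ctx = mp.get_context("spawn")
ev: Event = ctx.Event()  # the context factory returns a synchronize.Event instance
signal_ready(ev)
assert ev.is_set()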
@@ -99,7 +102,7 @@ class Engine:
     sampler: Sampler
     query_manager: QueryManager
     cache: KDeepSeekV3Cache
-    def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None):
+    def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None, kvcache_event: Event = None):
         self.args = args
 
         # the parent and child processes cannot share the config variable
@@ -115,14 +118,6 @@ class Engine:
 
         self.gen_queue = generated_token_queue
 
-        print(f"Getting inference context from sched_client.")
-        inference_context = self.sched_client.get_inference_context_raw()
-        print(f"Got inference context, sending it to subscribers.")
-        inference_context = self.sched_client.rebuild_inferece_context(inference_context)
-        self.cache.load(inference_context)
-        print(f"kv_cache loaded successfully.")
-
-        self.block_num = inference_context.k_cache[0].size(1)
         with torch.device("meta"):
             if config.architectures[0] == "DeepseekV3ForCausalLM":
                 self.model = KDeepseekV3ForCausalLM(config, self.cache)
@@ -165,6 +160,17 @@ class Engine:
         self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
 
         self.model.eval()
+        kvcache_event.set()
+        # load kvcache
+        print(f"Getting inference context from sched_client.")
+        inference_context = self.sched_client.get_inference_context_raw()
+        print(f"Got inference context, sending it to subscribers.")
+        inference_context = self.sched_client.rebuild_inferece_context(inference_context)
+        self.cache.load(inference_context)
+        print(f"kv_cache loaded successfully.")
+
+
+        self.block_num = inference_context.k_cache[0].size(1)
         #@TODO add config
         self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
 
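The ordering this hunk sets up — signal kvcache_event once the model is ready, then pull the KV cache from the scheduler — is a two-process event handshake. A minimal, self-contained sketch of the pattern, assuming nothing beyond torch.multiprocessing (engine_main and the variable names are illustrative, not the repo's):

import torch.multiprocessing as mp

def engine_main(kvcache_event):
    # ... build the model and switch it to eval mode here ...
    kvcache_event.set()   # tell the parent process the model is up
    # ... then request the inference context and load the KV cache ...

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    kvcache_event = ctx.Event()
    p = ctx.Process(target=engine_main, args=(kvcache_event,))
    p.start()
    kvcache_event.wait()  # resume only after the engine signals readiness
    # ... safe to start the scheduler subprocess at this point ...
    p.join()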
@@ -240,8 +246,8 @@ class BalanceServeThreadContext(ThreadContext):
         return local_messages
 
 
-def run_engine(args, token_queue, broadcast_endpoint, event):
-    engine = Engine(args, token_queue, broadcast_endpoint)
+def run_engine(args, token_queue, broadcast_endpoint, event, kvcache_event):
+    engine = Engine(args, token_queue, broadcast_endpoint, kvcache_event)
     if args.use_cuda_graph:
         engine.model_runner.warmup()
 
@@ -278,10 +284,34 @@ class BalanceServeInterface(BackendInterfaceBase):
         self.streamer = TextStreamer(self.tokenizer)
 
         start_event = ctx.Event()
+        kvcache_event = ctx.Event()
 
-        p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event))
+        p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event, kvcache_event))
         p.start()
         processes.append(p)
+        kvcache_event.wait()
+
+
+        def cleanup():
+            if sched_process.poll() is None:
+                sched_process.terminate()
+
+        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+            pickle.dump(args, temp_file)
+            temp_file_path = temp_file.name
+        current_file = __file__
+        target_file = os.path.join(os.path.dirname(current_file), "..", "..", "balance_serve", "sched_rpc.py")
+        target_file = os.path.normpath(target_file)
+        log_path = os.path.join(args.log_dir, "rpc.log")
+        log = open(log_path, "a")
+        sched_process = subprocess.Popen(
+            ["python3", target_file, "--config", temp_file_path],
+            stdout=log,
+            stderr=log
+        )
+        print("sched_rpc started with PID:", sched_process.pid)
+        atexit.register(cleanup)
+
         start_event.wait()
 
     def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
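The launcher block added above is a common pattern: pickle the parsed args to a named temp file, hand that path to a helper script via subprocess.Popen, and register an atexit hook so the child dies with the parent. A stripped-down sketch under assumed names (worker.py is a placeholder, not a file in this repo):

import atexit
import pickle
import subprocess
import tempfile

args = {"log_dir": "/tmp"}  # stand-in for the parsed server config

# Serialize the config to a temp file the child process can re-read.
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
    pickle.dump(args, temp_file)
    temp_file_path = temp_file.name

# Launch the helper on the pickled config (placeholder script name).
child = subprocess.Popen(["python3", "worker.py", "--config", temp_file_path])

def cleanup():
    if child.poll() is None:  # child still running?
        child.terminate()

atexit.register(cleanup)  # make sure the child is stopped on exit

The second changed file removes the same launcher from the server entrypoint: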
@@ -5,7 +5,6 @@ from fastapi.staticfiles import StaticFiles
 import uvicorn.logging
 import uvicorn
 import sys
-import atexit
 project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
 from fastapi.middleware.cors import CORSMiddleware
 from ktransformers.server.args import ArgumentParser
@@ -17,8 +16,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from ktransformers.server.api import router, post_db_creation_operations
 from ktransformers.server.utils.sql_utils import Base, SQLUtil
 from ktransformers.server.config.log import logger
-import subprocess
-import tempfile
 
 def mount_app_routes(mount_app: FastAPI):
     sql_util = SQLUtil()
@@ -108,27 +106,6 @@ def main():
     arg_parser = ArgumentParser(cfg)
 
     args = arg_parser.parse_args()
-    if args.backend_type == "balance_serve":
-        import pickle
-        def cleanup():
-            if sched_process.poll() is None:
-                sched_process.terminate()
-
-        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
-            pickle.dump(args, temp_file)
-            temp_file_path = temp_file.name
-        current_file = __file__
-        target_file = os.path.join(os.path.dirname(current_file), "balance_serve", "sched_rpc.py")
-        target_file = os.path.normpath(target_file)
-        log_path = os.path.join(args.log_dir, "rpc.log")
-        log = open(log_path, "a")
-        sched_process = subprocess.Popen(
-            ["python3", target_file, "--config", temp_file_path],
-            stdout=log,
-            stderr=log
-        )
-        print("sched_rpc started with PID:", sched_process.pid)
-        atexit.register(cleanup)
     create_interface(config=cfg, default_args=cfg)
     app = create_app()
     custom_openapi(app)