mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 22:05:30 +00:00
Move KV cache creation to balance_serve
This commit is contained in:
parent
8770b6d573
commit
38e841900d
2 changed files with 45 additions and 38 deletions
|
@ -30,6 +30,7 @@ from ktransformers.server.balance_serve.sched_rpc import SchedulerClient
|
|||
from ktransformers.server.balance_serve.settings import sched_ext
|
||||
from torch.multiprocessing import Queue
|
||||
import torch.multiprocessing as mp
|
||||
from multiprocessing.synchronize import Event
|
||||
from ktransformers.server.schemas.endpoints.chat import RawUsage
|
||||
from ktransformers.server.utils.multi_timer import Profiler
|
||||
import zmq
|
||||
|
@ -41,8 +42,10 @@ import threading
|
|||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, Request
|
||||
import os
|
||||
|
||||
|
||||
import pickle
|
||||
import subprocess
|
||||
import tempfile
|
||||
import atexit
|
||||
|
||||
ktransformer_rules_dir = (
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
|
||||
|
@ -99,7 +102,7 @@ class Engine:
|
|||
sampler: Sampler
|
||||
query_manager: QueryManager
|
||||
cache: KDeepSeekV3Cache
|
||||
def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None):
|
||||
def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None, kvcache_event: Event = None):
|
||||
self.args = args
|
||||
|
||||
# 子进程和父进程无法共享 config 变量
|
||||
|
@ -115,14 +118,6 @@ class Engine:
|
|||
|
||||
self.gen_queue = generated_token_queue
|
||||
|
||||
print(f"Getting inference context from sched_client.")
|
||||
inference_context = self.sched_client.get_inference_context_raw()
|
||||
print(f"Got inference context, sending it to subscribers.")
|
||||
inference_context = self.sched_client.rebuild_inferece_context(inference_context)
|
||||
self.cache.load(inference_context)
|
||||
print(f"kv_cache loaded successfully.")
|
||||
|
||||
self.block_num = inference_context.k_cache[0].size(1)
|
||||
with torch.device("meta"):
|
||||
if config.architectures[0] == "DeepseekV3ForCausalLM":
|
||||
self.model = KDeepseekV3ForCausalLM(config, self.cache)
|
||||
|
@ -165,6 +160,17 @@ class Engine:
|
|||
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
|
||||
|
||||
self.model.eval()
|
||||
kvcache_event.set()
|
||||
# load kvcache
|
||||
print(f"Getting inference context from sched_client.")
|
||||
inference_context = self.sched_client.get_inference_context_raw()
|
||||
print(f"Got inference context, sending it to subscribers.")
|
||||
inference_context = self.sched_client.rebuild_inferece_context(inference_context)
|
||||
self.cache.load(inference_context)
|
||||
print(f"kv_cache loaded successfully.")
|
||||
|
||||
|
||||
self.block_num = inference_context.k_cache[0].size(1)
|
||||
#@TODO add config
|
||||
self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
|
||||
|
||||
|
@ -240,8 +246,8 @@ class BalanceServeThreadContext(ThreadContext):
|
|||
return local_messages
|
||||
|
||||
|
||||
def run_engine(args, token_queue, broadcast_endpoint, event):
|
||||
engine = Engine(args, token_queue, broadcast_endpoint)
|
||||
def run_engine(args, token_queue, broadcast_endpoint, event, kvcache_event):
|
||||
engine = Engine(args, token_queue, broadcast_endpoint, kvcache_event)
|
||||
if args.use_cuda_graph:
|
||||
engine.model_runner.warmup()
|
||||
|
||||
|
@ -278,10 +284,34 @@ class BalanceServeInterface(BackendInterfaceBase):
|
|||
self.streamer = TextStreamer(self.tokenizer)
|
||||
|
||||
start_event = ctx.Event()
|
||||
kvcache_event = ctx.Event()
|
||||
|
||||
p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event))
|
||||
p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event, kvcache_event))
|
||||
p.start()
|
||||
processes.append(p)
|
||||
kvcache_event.wait()
|
||||
|
||||
|
||||
def cleanup():
|
||||
if sched_process.poll() is None:
|
||||
sched_process.terminate()
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
||||
pickle.dump(args, temp_file)
|
||||
temp_file_path = temp_file.name
|
||||
current_file = __file__
|
||||
target_file = os.path.join(os.path.dirname(current_file), "..", "..", "balance_serve", "sched_rpc.py")
|
||||
target_file = os.path.normpath(target_file)
|
||||
log_path = os.path.join(args.log_dir, "rpc.log")
|
||||
log = open(log_path, "a")
|
||||
sched_process = subprocess.Popen(
|
||||
["python3", target_file, "--config", temp_file_path],
|
||||
stdout=log,
|
||||
stderr=log
|
||||
)
|
||||
print("sched_rpc started with PID:", sched_process.pid)
|
||||
atexit.register(cleanup)
|
||||
|
||||
start_event.wait()
|
||||
|
||||
def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue