kvcache-ai-ktransformers/ktransformers/server/balance_serve/sched_rpc.py

218 lines
7.9 KiB
Python

from datetime import datetime
import os
from typing import Optional
import zmq
import pickle
import threading
import torch.multiprocessing as mp
import sys
current_file_path = os.path.abspath(__file__)
# sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
import pickle
import argparse
from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings, create_sched_settings_qwen2moe, create_sched_settings_qwen3moe
if mp.get_start_method(allow_none=True) is None:
print('set start method')
mp.set_start_method('spawn')
else:
print(f'start method already set to {mp.get_start_method(allow_none=True)}')
class SchedulerServer:
def __init__(self, settings, main_args):
# 创建 Scheduler 实例并初始化
self.sched = sched_ext.create_scheduler(settings)
# 初始化 ZeroMQ 上下文和套接字
self.context = zmq.Context()
self.frontend = self.context.socket(zmq.ROUTER)
print(f"sched zmq rpc server on port {main_args.sched_port}")
self.frontend.bind(f"tcp://*:{main_args.sched_port}")
# 创建内部的 DEALER 套接字,用于与工作线程通信
self.backend = self.context.socket(zmq.DEALER)
self.backend.bind("inproc://backend")
# 启动调度器
def run_scheduler(self):
self.sched.run()
# 停止调度器
def stop_scheduler(self):
self.sched.stop()
# 处理客户端请求
def start_proxy(self):
# 使用 ZMQ 的内置代理,将前端请求分发给后端工作线程
zmq.proxy(self.frontend, self.backend)
# 工作线程处理请求
def worker_routine(self):
worker = self.context.socket(zmq.REP)
worker.connect("inproc://backend")
while True:
try:
# 接收客户端请求
message = worker.recv()
data = pickle.loads(message)
method = data.get('method')
params = data.get('params', {})
# print(f"Received request: {method}")
if method == 'add_query':
query_add = params.get('query') # 直接是一个 QueryAdd 对象
# 添加查询
query_id = self.sched.add_query(query_add)
# 发送响应
response = {'status': 'ok', 'query_id': query_id}
worker.send(pickle.dumps(response))
elif method == 'cancel_query':
query_id = params.get('query_id')
# 假设您的 Scheduler 类实现了 cancel 方法
self.sched.cancel(query_id)
response = {'status': 'ok'}
worker.send(pickle.dumps(response))
elif method == 'update_last_batch':
updates = params.get('updates') # 直接是一个列表,包含 QueryUpdate 对象
# 更新最后一个批次
batch_todo = self.sched.update_last_batch(updates)
# 直接发送 batch_todo 对象
response = {'status': 'ok', 'batch_todo': batch_todo}
# print (batch_todo.query_lengths, batch_todo.query_ids)
worker.send(pickle.dumps(response))
elif method == 'get_inference_context':
inference_context = self.sched.get_inference_context()
data = {
"k_cache":inference_context.k_cache,
"v_cache":inference_context.v_cache
}
print(f"Serializing KVCache")
data["k_cache"] = [mp.reductions.reduce_tensor(t) for t in data['k_cache']]
data["v_cache"] = [mp.reductions.reduce_tensor(t) for t in data['v_cache']]
# print(data)
response = {'status': 'ok', 'inference_context': data}
worker.send(pickle.dumps(response))
# response['inference_context'].k_cache[0][0, 0, 0, 0, 0] = 1
# print("k_cache update")
else:
# 未知方法
response = {'status': 'error', 'message': 'Unknown method'}
worker.send(pickle.dumps(response))
except Exception as e:
# 处理异常并发送错误响应
response = {'status': 'error', 'message': str(e)}
worker.send(pickle.dumps(response))
# 启动 RPC 服务
def start_rpc_service(self):
try:
print("Scheduler RPC service is running...")
# 在单独的线程中运行调度器
threading.Thread(target=self.run_scheduler, daemon=True).start()
# 启动工作线程
for _ in range(10): # 根据需要调整线程数
threading.Thread(target=self.worker_routine, daemon=True).start()
# 启动代理,开始监听请求
self.start_proxy()
except KeyboardInterrupt:
print("Shutting down scheduler RPC service...")
self.stop_rpc_service()
# 停止 RPC 服务
def stop_rpc_service(self):
self.stop_scheduler()
self.frontend.close()
self.backend.close()
self.context.term()
def start_server(settings, main_args):
server = SchedulerServer(settings, main_args)
server.start_rpc_service()
# Add async client for webserver
class SchedulerClient:
def __init__(self, sched_port):
address=f'tcp://localhost:{sched_port}'
self.address = address
self.context = zmq.Context()
self.socket = self.context.socket(zmq.REQ)
self.socket.connect(self.address)
print(f"Connected to server at {self.address}")
def __del__(self):
self.socket.close()
self.context.term()
def send_request(self, method, params=None):
if params is None:
params = {}
request = {
'method': method,
'params': params
}
# print(f'send request {request}')
self.socket.send(pickle.dumps(request))
response = self.socket.recv()
# print(response)
response = pickle.loads(response)
if response.get('status') == 'ok':
return response
else:
raise Exception(f"Error from server: {response.get('message')}")
def add_query(self, query):
response = self.send_request('add_query', {'query': query})
return response.get('query_id')
def cancel_query(self, query_id):
self.send_request('cancel_query', {'query_id': query_id})
def update_last_batch(self, updates):
response = self.send_request('update_last_batch', {'updates': updates})
# print(f"update_last_batch response {response}")
return response.get('batch_todo')
def rebuild_inferece_context(self,response):
data = response.get('inference_context')
inference_context = sched_ext.InferenceContext()
print('Rebuilding kvcache')
inference_context.k_cache = [fn(*args) for fn,args in data['k_cache']]
inference_context.v_cache = [fn(*args) for fn,args in data['v_cache']]
return inference_context
def get_inference_context_raw(self):
response = self.send_request('get_inference_context')
return response
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
args = parser.parse_args()
with open(args.config, "rb") as f:
main_args = pickle.load(f)
if main_args.architectures == "Qwen2MoeForCausalLM":
settings = create_sched_settings_qwen2moe(main_args)
elif main_args.architectures == "Qwen3MoeForCausalLM":
settings = create_sched_settings_qwen3moe(main_args)
else:
settings = create_sched_settings(main_args)
start_server(settings, main_args)