mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2026-04-28 03:39:48 +00:00
from typing import Any, AsyncIterator, List, Optional, Set
from ktransformers.models.custom_cache import KVC2StaticCache, KDeepSeekV3Cache, KGQACache, KVC2Qwen3Cache
from transformers import (
    AutoTokenizer,
    AutoConfig,
    GenerationConfig,
    StaticCache,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)

import torch.distributed as dist
from ktransformers.server.config.config import Config
from ..base import ThreadContext, BackendInterfaceBase
import torch
from ktransformers.server.backend.interfaces.transformers import (
    ConfigArgs,
    default_args,
    TextStreamer,
)
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.config.log import logger
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
from ktransformers.models.custom_modeling_smallthinker import KSmallThinkerForCausalLM
from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM
from ktransformers.models.custom_modeling_qwen3_next import KQwen3NextForCausalLM
from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
from ktransformers.models.configuration_smallthinker import SmallthinkerConfig
from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM

try:
    import torch_npu
    use_torch_npu = torch.npu.is_available()
except:
    use_torch_npu = False

if use_torch_npu:
    from ktransformers.models.ascend.custom_ascend_modeling_deepseek_v3 import KNPUDeepseekV3ForCausalLM
    from ktransformers.models.ascend.custom_ascend_modeling_qwen3 import KNPUQwen3MoeForCausalLM
    from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel, get_tensor_parallel_group, get_tensor_parallel_size

from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util import utils

custom_models = {
    "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
    "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
    "LlamaForCausalLM": LlamaForCausalLM,
    "MixtralForCausalLM": MixtralForCausalLM,
}

from ktransformers.server.balance_serve.inference.model_runner import ModelRunner, get_or_create_model_runner  # TODO: is get_or_create_model_runner NPU-only?
from ktransformers.models.configuration_qwen3_next import Qwen3NextConfig
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
from ktransformers.server.balance_serve.sched_rpc import SchedulerClient
from ktransformers.server.balance_serve.settings import sched_ext

from torch.multiprocessing import Queue
import torch.multiprocessing as mp
from multiprocessing.synchronize import Event
import datetime
from ktransformers.server.schemas.endpoints.chat import RawUsage
from ktransformers.server.utils.multi_timer import Profiler
import zmq
import time
import queue
import tempfile
import asyncio
import cProfile
import threading
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
import os
import pickle
import subprocess
import atexit
import signal

ktransformer_rules_dir = (
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
)

default_optimize_rules = {
    # "DeepseekV3ForCausalLM": ktransformer_rules_dir + "Moonlight-16B-A3B-serve.yaml",
    "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-serve.yaml",
    "Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml",
    "SmallThinkerForCausalLM": ktransformer_rules_dir + "Smallthinker-serve.yaml",
    "Glm4MoeForCausalLM": ktransformer_rules_dir + "Glm4Moe-serve.yaml",
    "Qwen3NextForCausalLM": ktransformer_rules_dir + "Qwen3Next-serve.yaml",
}

if use_torch_npu:
    default_optimize_rules["Qwen2MoeForCausalLM"] = ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml"

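# Illustrative note (not part of the upstream file): when args.optimize_config_path is unset,
# Engine.__init__ below resolves the optimize rule file through this mapping, roughly:
#
#     arch = config.architectures[0]             # e.g. "Qwen3MoeForCausalLM"
#     rule_yaml = default_optimize_rules[arch]   # .../optimize_rules/Qwen3Moe-serve.yaml
#
# An architecture missing from the dict would surface as a KeyError at that lookup.
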
async def chat_stream(queue: asyncio.Queue, tokenizer: AutoTokenizer):
    streamer = TextStreamer(tokenizer)
    while True:
        token = await queue.get()
        # print(f"Got token: {token}")
        if token is None:
            # str = f'{token}\n\n'
            # str = model.tokenizer.decode(token)
            s = streamer.end()
            if s is not None:
                yield s
            break
        else:
            # text output
            text = tokenizer.decode(token)
            print(text, end="", flush=True)

            # str = model.tokenizer.decode(token)
            yield streamer.put(token)

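# chat_stream() is consumed by BalanceServeInterface.inference() below; the engine pushes a None
# token into the per-query queue to signal end of generation, which flushes the streamer and ends
# the iteration. Minimal consumption sketch (hypothetical, for illustration only):
#
#     q: asyncio.Queue = asyncio.Queue()
#     async for piece in chat_stream(q, tokenizer):
#         if piece:
#             print(piece, end="")
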
def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_tokens: torch.Tensor, query_manager: QueryManager = None):
    # print(len(query_updates), generated_tokens.size(0), generated_tokens)
    for i in range(generated_tokens.size(0)):
        # print(generated_tokens[i].item())
        query_updates[i].generated_token = generated_tokens[i].item()
        if not query_manager.query_map[query_updates[i].id].is_prefill:
            pos = query_updates[i].active_position
            if pos < query_manager.query_map[query_updates[i].id].max_length:
                query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]

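# fill_generated_tokens() writes each sampled token back onto the scheduler update and, for decode
# queries, into the query's token buffer at its active position so the next step can consume it.
# (Descriptive comment added for clarity; not part of the upstream source.)
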
def report_last_time_performance(profiler: Profiler):
    try:
        tokenize_time = profiler.get_timer_sec('tokenize')
        prefill_time = profiler.get_timer_sec('prefill')
        decode_time = profiler.get_timer_sec('decode')
        prefill_count = profiler.get_counter('prefill')
        decode_count = profiler.get_counter('decode')

        logger.info(f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}')
    except:
        logger.info('Performance statistics not recorded')

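# Throughput above is tokens per second: prefill_count / prefill_time and decode_count / decode_time.
# The bare except also swallows missing timers and zero durations rather than raising.
# (Descriptive comment added for clarity; not part of the upstream source.)
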
class Engine:
    sched_client: SchedulerClient
    updates: list[sched_ext.QueryUpdate]
    batch: sched_ext.BatchQueryTodo
    model_runner: ModelRunner
    sampler: Sampler
    query_manager: QueryManager
    cache: KDeepSeekV3Cache | KGQACache | KVC2StaticCache

    def __init__(self, args: ConfigArgs = default_args, generated_token_queue: Queue = None, broadcast_endpoint: str = None, kvcache_event: Event = None):
        self.args = args

        # The spawned child process cannot share the Config singleton with the parent, so copy args into it here.
        for key, value in vars(args).items():
            if value is not None and hasattr(Config(), key):
                setattr(Config(), key, value)
        if use_torch_npu:
            utils.CUR_DEVICE = f"npu:{torch.npu.current_device()}"
            self.device = f"npu:{torch.npu.current_device()}"
        else:
            self.device = self.args.device
        self.sched_client = SchedulerClient(args.sched_port)
        self.updates = []

        print(f"args.architectures: {args.architectures}")

        if args.architectures == "Qwen3MoeForCausalLM":
            config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
        elif args.architectures == "Glm4MoeForCausalLM":
            config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
        elif args.architectures == "SmallThinkerForCausalLM":
            config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
            config._attn_implementation = "eager"
            config.moe_intermediate_size = config.moe_ffn_hidden_size
        elif args.architectures == "Qwen3NextForCausalLM":
            config = Qwen3NextConfig.from_pretrained(args.model_dir, trust_remote_code=True)
        else:
            try:
                config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
            except:
                raise ValueError(f"Model {args.architectures} not supported. Please check your model directory or model name.")

        self.gen_queue = generated_token_queue
        self.debug = False

        self.profiler_cprofile = cProfile.Profile()
        self.cprof_prof_cnt, self.max_cprof_prof_cnt = 0, 8
        with torch.device("meta"):
            if config.architectures[0] == "DeepseekV3ForCausalLM":
                if use_torch_npu:
                    self.cache = KVC2StaticCache(config, args.max_batch_size, self.args.page_size)
                    self.model = KNPUDeepseekV3ForCausalLM(config)
                else:
                    self.cache = KDeepSeekV3Cache(config, self.args.page_size)
                    self.model = KDeepseekV3ForCausalLM(config, self.cache)
            elif config.architectures[0] == "DeepseekV2ForCausalLM":
                self.cache = KDeepSeekV3Cache(config, self.args.page_size)
                self.model = KDeepseekV2ForCausalLM(config, self.cache)
            elif config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
                if not use_torch_npu:
                    self.cache = KGQACache(config, self.args.page_size)
                    if config.architectures[0] == "Qwen2MoeForCausalLM":
                        self.model = KQwen2MoeForCausalLM(config, self.cache)
                    else:
                        self.model = KQwen3MoeForCausalLM(config, self.cache)
                else:
                    self.cache = KVC2Qwen3Cache(config, args.max_batch_size, self.args.page_size)
                    self.model = KNPUQwen3MoeForCausalLM(config, self.cache)
            elif config.architectures[0] == "SmallThinkerForCausalLM":
                self.cache = KGQACache(config, self.args.page_size)
                self.model = KSmallThinkerForCausalLM(config, self.cache)
            elif config.architectures[0] == "Glm4MoeForCausalLM":
                self.cache = KGQACache(config, self.args.page_size)
                self.model = KGlm4MoeForCausalLM(config, self.cache)
            elif config.architectures[0] == "Qwen3NextForCausalLM":
                self.cache = KGQACache(config, self.args.page_size)
                self.model = KQwen3NextForCausalLM(config, self.cache)

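        # Note (added for clarity, not in the upstream source): the model is instantiated on the
        # "meta" device, so no real weights are allocated here; optimize_and_load_gguf() below
        # materializes the weights from the GGUF file according to the selected optimize rule.
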
        context = zmq.Context()
        if use_torch_npu:
            if torch.distributed.get_rank() == 0:
                self.pub_socket = context.socket(zmq.PUB)
                self.pub_socket.bind(f"ipc://{broadcast_endpoint}")
                self.sub_socket = None
            else:
                self.sub_socket = context.socket(zmq.SUB)
                self.sub_socket.connect(f"ipc://{broadcast_endpoint}")
                self.sub_socket.setsockopt_string(zmq.SUBSCRIBE, "")
                self.pub_socket = None
                # time.sleep(1) # make sure all subscribers are ready
        else:
            self.pub_socket = context.socket(zmq.PUB)
            self.pub_socket.bind(f"ipc://{broadcast_endpoint}")

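        # Broadcast topology (descriptive note, not in the upstream source): on NPU the rank-0
        # engine PUBlishes the next batch over the ipc:// endpoint and the other tensor-parallel
        # ranks SUBscribe to it; on GPU a single engine simply binds the PUB socket.
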
        try:
            generation_config = GenerationConfig.from_pretrained(args.model_dir)
        except:
            generation_config = GenerationConfig(
                max_length=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                do_sample=True
            )

        if args.optimize_config_path is None:
            optimize_config_path = default_optimize_rules[config.architectures[0]]
        else:
            optimize_config_path = args.optimize_config_path
        gguf_path = args.gguf_path
        if gguf_path is None:
            gguf_path = input(
                "please input the path of your GGUF file (all GGUF files in that directory must"
                " belong to the current model): "
            )
        if use_torch_npu:
            tp_group = get_tensor_parallel_group()
            torch.distributed.barrier(group=tp_group)
        optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
        if use_torch_npu:
            get_absort_weight(self.model, config)  # TODO
            torch.distributed.barrier(group=tp_group)
        self.model.generation_config = generation_config
        if self.model.generation_config.pad_token_id is None:
            self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id

        self.model.eval()
        kvcache_event.set()
        # load kvcache
        print("Getting inference context from sched_client.")
        inference_context = self.sched_client.get_inference_context_raw()
        print("Got inference context, sending it to subscribers.")
        inference_context = self.sched_client.rebuild_inferece_context(inference_context)
        self.cache.load(inference_context)
        print("kv_cache loaded successfully.")

        self.block_num = inference_context.k_cache[0].size(1)
        # TODO: clarify the difference between the ModelRunner variants
        # self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
        # @TODO add config
        if config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM", "Glm4MoeForCausalLM", "SmallThinkerForCausalLM", "Qwen3NextForCausalLM"):
            if not use_torch_npu:
                # NOTE: self.model_runner is only assigned below via get_or_create_model_runner;
                # this reference appears to predate that assignment on the non-NPU path.
                self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
            else:
                # NPU does not support flash attention
                self.model.init_wrapper()
        else:
            self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)

        # self.args.use_cuda_graph indicates whether graph-mode (graph sinking) execution is used
        self.model_runner = get_or_create_model_runner(self.model, self.cache, self.device, self.args.use_cuda_graph, page_size=args.page_size)
        self.sampler = Sampler()
        self.query_manager = QueryManager(device=self.device, page_size=args.page_size)

    def sampling(self, forward_output: ForwardBatchOutput):
        generated_tokens = []
        probs = []

        for i in range(forward_output.num_batchs):
            logit = forward_output.logits[i]
            if hasattr(forward_output, "temperatures"):
                temperatures = forward_output.temperatures[i]
            else:
                temperatures = None

            if hasattr(forward_output, "top_ps"):
                top_ps = forward_output.top_ps[i]
            else:
                top_ps = None

            sample_options = SamplingOptions(logit.size(0), self.device, pretrained_config=self.model.generation_config, temperatures=temperatures, top_ps=top_ps)
            generated_token, prob = self.sampler(logit, sample_options)
            generated_tokens.append(generated_token.clone())
            probs.append(prob.clone())
        generated_tokens, probs = torch.cat(generated_tokens), torch.cat(probs, dim=0)
        return generated_tokens, probs

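    # sampling() maps each entry of ForwardBatchOutput.logits through the Sampler, using per-query
    # temperature/top_p when the forward output carries them, then concatenates the per-batch
    # tokens and probabilities into flat tensors. (Descriptive comment added for clarity.)
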
    def loop(self):

        next_batch = None

        while True:
            self.batch = next_batch
            if self.batch is not None:
                if use_torch_npu:
                    batch_size = 0
                    for i in range(len(self.batch.decode_mini_batches)):
                        batch_size += len(self.batch.decode_mini_batches[i])
                    # logger.debug(f"prefill batch: {len(self.batch.prefill_mini_batches)} decode batch: {len(self.batch.decode_mini_batches)} {batch_size} \n")
                    self.model_runner.run_split(self.batch, self.query_manager)
                else:
                    self.model_runner.run(self.batch, self.query_manager)

            if len(self.updates) > 0:
                for q in self.updates:
                    if q.is_prefill:
                        continue
                    # print(f"Putting token {q.generated_token} into queue for query id: {q.id}")
                    try:
                        if use_torch_npu:
                            if torch.distributed.get_rank() == 0:
                                self.gen_queue.put((q.id, q.generated_token if not q.decode_done else None), timeout=5)
                        else:
                            self.gen_queue.put((q.id, q.generated_token if not q.decode_done else None), timeout=5)
                    except queue.Full:
                        pass  # print("Queue is full after timeout; unable to put more items.")
            if use_torch_npu:
                if torch.distributed.get_rank() == 0:
                    next_batch = self.sched_client.update_last_batch(self.updates)
                    if next_batch.query_ids == []:
                        next_batch = None
                    self.pub_socket.send_pyobj(next_batch)
                else:
                    next_batch = self.sub_socket.recv_pyobj()
            else:
                next_batch = self.sched_client.update_last_batch(self.updates)
                if next_batch.query_ids == []:
                    next_batch = None
                self.pub_socket.send_pyobj(next_batch)

            if next_batch is not None:
                self.query_manager.add_query(next_batch)

            if self.batch is not None:
                self.model_runner.sync()
                # print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms")
                # if self.rank == 0:

                generated_tokens, probs = self.sampling(self.model_runner.output)

                self.updates = self.query_manager.update(self.batch)
                fill_generated_tokens(self.updates, generated_tokens, self.query_manager)

            else:
                self.updates = []

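# Engine.loop() iterates: run the current batch on the ModelRunner, push freshly decoded tokens to
# the inter-process token queue, report the updates to the scheduler via update_last_batch() to
# fetch the next batch (broadcast to the other ranks on NPU), then sample from the runner's output
# and fold the sampled tokens into the updates for the next round. (Summary comment added for
# clarity; not part of the upstream source.)
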
class BalanceServeThreadContext(ThreadContext):
    def get_local_messages(self):
        local_messages = []
        for m in self.messages:
            local_messages.append({"role": m.role.value, "content": m.get_text_content()})

        return local_messages


def init_distributed(rank: int,
                     world_size: int,
                     tp_size: int,
                     master_addr: str = os.getenv("MASTER_ADDR", "127.0.0.1"),
                     master_port: int = os.getenv("MASTER_PORT", "29500"),
                     backend: str = "hccl"):  # TODO csx: is distributed setup only ever needed on NPU?
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)

    local_rank, world_size = setup_model_parallel(tp=tp_size)
    return local_rank, world_size


def run_engine(args, token_queue, broadcast_endpoint, event, kvcache_event, rank=None, world_size=None):
    if use_torch_npu:
        init_distributed(rank, world_size, args.tp, backend="hccl")  # TODO: same as above
        import torch.distributed as dist
    engine = Engine(args, token_queue, broadcast_endpoint, kvcache_event)
    if args.use_cuda_graph:
        if 'npu' in engine.device:
            print("[WARMUP-NPU] start", flush=True)
            engine.model_runner.warmup_npu()
        else:
            engine.model_runner.warmup()
    else:
        print("[WARMUP] skip warmup, eager mode!", flush=True)
    if use_torch_npu:
        args.port += torch.distributed.get_rank()
    event.set()
    engine.loop()

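# run_engine() is the subprocess entry point used by BalanceServeInterface below: it builds an
# Engine, optionally warms up CUDA graphs (or the NPU graph path), signals readiness through
# `event`, and then blocks in Engine.loop(). A rough spawn sketch under those assumptions:
#
#     ctx = mp.get_context("spawn")
#     p = ctx.Process(target=run_engine,
#                     args=(args, token_queue, broadcast_endpoint, start_event, kvcache_event))
#     p.start()
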
class BalanceServeInterface(BackendInterfaceBase):
    use_static_cache: bool = True

    model: Any
    tokenizer: AutoTokenizer

    cache: StaticCache
    generated_ids: torch.Tensor
    seq_length: int

    streamer: TextStreamer

    # thread_related
    last_request_id: Optional[str] = None
    ever_generated_ids: Set[int] = set()

    def __init__(self, args: ConfigArgs = default_args, input_args=None):
        self.args = input_args
        self.queue_map: dict[int, asyncio.Queue] = {}
        self.thread_map: dict[int, int] = {}
        processes = []
        self.broadcast_endpoint = tempfile.NamedTemporaryFile(delete=False).name  # @TODO add to config
        ctx = mp.get_context("spawn")
        self.token_queue = ctx.Queue(maxsize=1000)
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
        self.sched_client = SchedulerClient(args.sched_port)
        self.streamer = TextStreamer(self.tokenizer)
        if use_torch_npu:
            world_size = str(os.getenv("WORLD_SIZE", self.args.tp))
            if not isinstance(world_size, str):
                raise ValueError(f"world_size ({world_size}) must be str")
            start_events = []
            kvcache_events = []
            for rank in range(self.args.tp):
                if int(self.args.device[-1]) > 0:
                    break

                start_event = ctx.Event()
                kvcache_event = ctx.Event()

                p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event,
                                                         kvcache_event, rank, world_size))
                p.start()
                processes.append(p)
                start_events.append(start_event)
                kvcache_events.append(kvcache_event)

            for evt in kvcache_events:
                evt.wait()
            self._engines = processes
        else:
            start_event = ctx.Event()
            kvcache_event = ctx.Event()

            p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event,
                                                     kvcache_event))
            p.start()
            processes.append(p)

            kvcache_event.wait()
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            args.tp = input_args.tp
            pickle.dump(args, temp_file)
            temp_file_path = temp_file.name
        current_file = __file__
        target_file = os.path.join(os.path.dirname(current_file), "..", "..", "balance_serve", "sched_rpc.py")
        target_file = os.path.normpath(target_file)
        log_path = os.path.join(args.log_dir, "rpc.log")
        log = open(log_path, "a")
        sched_process = subprocess.Popen(
            ["python3", target_file, "--config", temp_file_path],
            stdout=log,
            stderr=log
        )
        print("sched_rpc started with PID:", sched_process.pid)

        def signal_handler(signum, frame):
            print(f"Received signal {signum}, shutting down...")
            cleanup()
            os._exit(0)

        def cleanup():
            print("Cleaning up...")

            for p in processes:
                if p.is_alive():
                    print(f"Terminating subprocess {p.pid}")
                    p.terminate()
                    p.join()

            if sched_process and sched_process.poll() is None:
                print(f"Terminating sched_process {sched_process.pid}")
                sched_process.terminate()
                sched_process.wait()
        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)
        if use_torch_npu:
            for evt in start_events:
                evt.wait()
        else:
            start_event.wait()

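    # Process topology set up above (descriptive summary, not in the upstream source): one or more
    # engine subprocesses (one per tensor-parallel rank on NPU, a single one on GPU) plus a
    # sched_rpc.py scheduler subprocess; kvcache_event gates scheduler startup, start_event gates
    # returning from __init__, and the signal handlers tear everything down on SIGINT/SIGTERM.
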
    def get_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None,
                   max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None) -> tuple[float, float, float]:
        """Get sampling parameters and handle default values and edge cases."""
        if max_tokens is not None:
            max_completion_tokens = max_tokens
        if max_completion_tokens is None:
            max_completion_tokens = self.args.max_new_tokens
        else:
            max_completion_tokens = min(self.args.max_new_tokens, max_completion_tokens)
        if temperature is None:
            temperature = self.args.temperature
        if top_p is None:
            top_p = self.args.top_p

        if temperature == 0:
            temperature = 0.0001
        if top_p == 0:
            top_p = 0.0001

        return temperature, top_p, max_completion_tokens

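    # Example of the clamping above (hypothetical values, added for illustration): a request with
    # temperature=0, top_p=None, max_tokens=None resolves to
    #     temperature=0.0001, top_p=self.args.top_p, max_completion_tokens=self.args.max_new_tokens
    # so greedy-like requests never hand a literal zero temperature to the sampler.
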
    def run_queue_proxy(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self.queue_proxy())

    @asynccontextmanager
    async def lifespan(self, app: FastAPI):
        asyncio.create_task(self.queue_proxy())
        yield

    async def queue_proxy(self):
        print("Queue Proxy Started")
        while True:
            try:
                query_id, token = self.token_queue.get_nowait()
                try:
                    # query id might not be allocated yet
                    self.queue_map[query_id].put_nowait(token)
                    # print(f"Proxy Put token: {token} to queue for query id: {query_id}")
                except asyncio.QueueFull:
                    # print(f"Queue for query id: {query_id} is full, waiting to put: {token}")
                    await self.queue_map[query_id].put(token)

            except queue.Empty:
                # print("no new token")
                # await asyncio.sleep(1)
                await asyncio.sleep(0)

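    # queue_proxy() bridges the multiprocessing token_queue filled by the engine subprocess into
    # the per-request asyncio queues consumed by chat_stream(); awaiting asyncio.sleep(0) yields
    # control so the polling loop does not starve the event loop. (Descriptive comment added for
    # clarity; not part of the upstream source.)
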
    def tokenize_prompt(self, prompt: str):
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device)
        return input_ids

    def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
        for m in messages:
            if m["role"] == "system":
                logger.warning(f'change {m["role"]} to user')
                m["role"] = "user"

        new_messages = [messages[0]]
        for m in messages[1:]:
            if m["role"] == "user" and new_messages[-1]["role"] == "user":
                logger.warning("merge two adjacent user messages")
                new_messages[-1]["content"] += '\n' + m["content"]
            else:
                new_messages.append(m)
        # input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
        # # drop <think> token in chat template
        # if input_str.endswith('<think>\n'):
        #     input_str = input_str[:-len('<think>\n')]
        input_ids = self.tokenizer.apply_chat_template(new_messages, add_generation_prompt=True, return_tensors="pt").to(self.args.device)
        return input_ids

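    # format_and_tokenize_input_ids() rewrites system messages to the user role and merges adjacent
    # user messages before applying the chat template, presumably because some chat templates
    # reject those shapes. (The stated reason is an assumption; the upstream code only logs the
    # rewrites.)
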
    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = 0, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
        profiler = Profiler()
        profiler.create_and_start_timer("tokenize")

        if isinstance(local_messages, List):
            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
        elif isinstance(local_messages, str):
            # local_messages = local_messages[0]['content']
            input_ids = self.tokenize_prompt(local_messages)
        else:
            raise ValueError("local_messages should be List or str")
        if Config().user_force_think:
            token_thinks = torch.tensor([self.tokenizer.encode("<think>\n", add_special_tokens=False)], device=input_ids.device)
            if not torch.equal(input_ids[0, -token_thinks.shape[-1]:], token_thinks[-1]):  # TODO: newly added line; check whether it affects the GPU path
                input_ids = torch.cat(
                    [input_ids, token_thinks], dim=1
                )
        logger.debug(f"get input ids of shape {input_ids.shape}")

        profiler.pause_timer("tokenize")

        profiler.create_and_start_timer("prefill")

        query_add = sched_ext.QueryAdd()
        query_add.query_token = input_ids[0].tolist()
        query_length = input_ids[0].shape[0]
        query_add.query_length = query_length
        profiler.set_counter("prefill", query_length)
        # @TODO add server
        stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False), self.tokenizer.encode("<|im_end|>")]
        query_add.stop_criteria = stop_criteria

        temperature, top_p, max_new_tokens = self.get_params(temperature, top_p, max_tokens, max_completion_tokens)

        query_add.sample_options.temperature = temperature
        if top_p == 0 or top_p is None:
            top_p = 0.0001
        query_add.sample_options.top_p = top_p
        query_add.estimated_length = min(self.args.cache_lens, query_length + max_new_tokens)
        query_id = self.sched_client.add_query(query_add)
        queue = asyncio.Queue(maxsize=max_new_tokens)
        self.queue_map[query_id] = queue
        self.thread_map[thread_id] = query_id
        is_first_token = True
        async for token in chat_stream(self.queue_map[query_id], self.tokenizer):
            if is_first_token:
                is_first_token = False
                profiler.pause_timer("prefill")
                profiler.create_and_start_timer("decode")
                profiler.set_counter("decode", 0)
                if Config().user_force_think:
                    think = '<think>\n'
                    print(think, end="", flush=True)
                    yield think, None
            else:
                profiler.inc("decode")
            # TODO: pass in the rank to avoid duplicate printing
            yield token, None
        profiler.pause_timer("decode")
        report_last_time_performance(profiler)
        yield self.streamer.end(), None
        if profiler.get_counter('decode') >= max_new_tokens - 1:
            yield "", "length"
        else:
            yield "", "stop"

        yield RawUsage(
            tokenize_time=profiler.get_timer_sec('tokenize'),
            prefill_time=profiler.get_timer_sec('prefill'),
            decode_time=profiler.get_timer_sec('decode'),
            prefill_count=profiler.get_counter('prefill'),
            decode_count=profiler.get_counter('decode'),
        )