Mirror of https://github.com/kvcache-ai/ktransformers.git
Synced 2025-09-09 13:55:27 +00:00
add balance-serve, support concurrency
This commit is contained in:
parent 8d0292aa44
commit 25cee5810e
196 changed files with 22077 additions and 565 deletions
@@ -12,18 +12,10 @@ class ConfigArgs(BaseModel):
    class Config:
        protected_namespaces = ()

    paged: bool = Field(None, description="Whether to use paged attention kv cache")
    total_context: int = Field(
        None,
        description=(
            "Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the"
            " total to distribute dynamically over however many jobs are active at once"
        ),
    )
    max_batch_size: int = Field(
        None, description="Max number of batches to run at once, assuming the sequences will fit within total_context"
    )
-   chunk_prefill_size: int = Field(
+   chunk_size: int = Field(
        None,
        description=(
            "Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new"
@@ -70,7 +62,6 @@ class ConfigArgs(BaseModel):
    repetition_penalty: float = Field(None, description="Sampler repetition penalty, default = 1.01 (1 to disable)")
    frequency_penalty: float = Field(None, description="Sampler frequency penalty, default = 0.0 (0 to disable)")
    presence_penalty: float = Field(None, description="Sampler presence penalty, default = 0.0 (0 to disable)")
    max_response_tokens: int = Field(None, description="Max tokens per response, default = 1000")
    response_chunk: int = Field(None, description="Space to reserve in context for reply, default = 250")
    no_code_formatting: bool = Field(None, description="Disable code formatting/syntax highlighting")
    cache_8bit: bool = Field(None, description="Use 8-bit (FP8) cache")
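
For reference, callers read the renamed option as `args.chunk_size` wherever they previously read `args.chunk_prefill_size`; a minimal sketch of the chunked-prefill loop this drives (mirroring the KTransformersInterface hunk at the end of this diff; `chunk_prefill`, `input_ids`, `cache_position`, and `input_ids_length` are the surrounding variables, shown here only for illustration):

    chunk_start = 0
    while chunk_start < input_ids_length:
        chunk_end = min(chunk_start + args.chunk_size, input_ids_length)
        logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
        chunk_start += args.chunk_size
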
@@ -9,9 +9,11 @@ from ktransformers.server.backend.interfaces.transformers import TransformersThr
from ktransformers.server.backend.interfaces.ktransformers import KTransformersThreadContext
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaThreadContext


from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
from ktransformers.server.backend.interfaces.transformers import TransformersInterface
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface

class ThreadContextManager:
    lock: Lock
    threads_context: Dict[ObjectID, ThreadContext]
@@ -36,7 +38,16 @@ class ThreadContextManager:
        elif isinstance(self.interface, TransformersInterface):
            new_context = TransformersThreadContext(run, self.interface)
-       else:
-           raise NotImplementedError
+       from ktransformers.server.backend.interfaces.balance_serve import BalanceServeThreadContext
+       from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface
+       if isinstance(self.interface, BalanceServeInterface):
+           new_context = BalanceServeThreadContext(run, self.interface)
+       else:
+           raise NotImplementedError
+       # elif isinstance(self.interface, BalanceServeInterface):
+       # new_context = BalanceServeThreadContext(run, self.interface)
+       # else:
+       # raise NotImplementedError
        self.threads_context[run.thread_id] = new_context
        # self.threads_context[run.thread_id] = ExllamaInferenceContext(run)
        re = self.threads_context[run.thread_id]
ktransformers/server/backend/interfaces/balance_serve.py (new file, 406 lines)

@@ -0,0 +1,406 @@
from typing import Any, AsyncIterator, List, Optional, Set
from ktransformers.models.custom_cache import KDeepSeekV3Cache
from transformers import (
    AutoTokenizer,
    AutoConfig,
    GenerationConfig,
    StaticCache,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)

from ktransformers.server.config.config import Config
from ..base import ThreadContext, BackendInterfaceBase
import torch
from ktransformers.server.backend.interfaces.transformers import (
    ConfigArgs,
    default_args,
    TextStreamer,
)
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.config.log import logger
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.server.balance_serve.inference.model_runner import ModelRunner
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
from ktransformers.server.balance_serve.sched_rpc import SchedulerClient
from ktransformers.server.balance_serve.settings import sched_ext
from torch.multiprocessing import Queue
import torch.multiprocessing as mp
from ktransformers.server.schemas.endpoints.chat import RawUsage
from ktransformers.server.utils.multi_timer import Profiler
import zmq
import time
import queue
import tempfile
import asyncio
import threading
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
import os


ktransformer_rules_dir = (
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
)
default_optimize_rules = {
    "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml",
}
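
# Illustration (not part of this file): when no optimize rule file is given,
# Engine.__init__ below selects one from the mapping above, e.g.
#
#     config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
#     optimize_config_path = default_optimize_rules[config.architectures[0]]
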
async def chat_stream(queue: asyncio.Queue, tokenizer: AutoTokenizer):
    streamer = TextStreamer(tokenizer)
    while True:
        token = await queue.get()
        #print(f"Got token: {token}")
        if token is None:
            # str = f'{token}\n\n'
            # str = model.tokenizer.decode(token)
            s = streamer.end()
            if s is not None:
                yield s
            break

        # str = model.tokenizer.decode(token)
        yield streamer.put(token)
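
# Illustration (not part of this file): chat_stream() is driven by a producer
# that puts generated token ids on the asyncio.Queue and signals completion
# with None; a minimal consumer looks like
#
#     q: asyncio.Queue = asyncio.Queue()
#     async for text in chat_stream(q, tokenizer):
#         print(text, end="", flush=True)
#
# where the producer calls `await q.put(token_id)` per step and finally
# `await q.put(None)`.
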
def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_tokens: torch.Tensor, query_manager: QueryManager = None):
    #print(len(query_updates), generated_tokens.size(0), generated_tokens)
    for i in range(generated_tokens.size(0)):
        print(generated_tokens[i].item())
        query_updates[i].generated_token = generated_tokens[i].item()
        if not query_manager.query_map[query_updates[i].id].is_prefill:
            pos = query_updates[i].active_position
            query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]


def report_last_time_performance(profiler: Profiler):
    try:
        tokenize_time = profiler.get_timer_sec('tokenize')
        prefill_time = profiler.get_timer_sec('prefill')
        decode_time = profiler.get_timer_sec('decode')
        prefill_count = profiler.get_counter('prefill')
        decode_count = profiler.get_counter('decode')

        logger.info(f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}')
    except:
        logger.info(f'Performance statistics not recorded')
class Engine:
    sched_client : SchedulerClient
    updates : list[sched_ext.QueryUpdate]
    batch : sched_ext.BatchQueryTodo
    model_runner: ModelRunner
    sampler: Sampler
    query_manager: QueryManager
    cache: KDeepSeekV3Cache
    def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None):
        self.args = args

        # the child process and the parent process cannot share the Config variables
        for key, value in vars(args).items():
            if value is not None and hasattr(Config(), key):
                setattr(Config(), key, value)

        self.device = self.args.device
        self.sched_client = SchedulerClient(args.sched_port)
        self.updates = []
        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
        self.cache = KDeepSeekV3Cache(config, self.args.page_size)

        self.gen_queue = generated_token_queue

        print(f"Getting inference context from sched_client.")
        inference_context = self.sched_client.get_inference_context_raw()
        print(f"Got inference context, sending it to subscribers.")
        inference_context = self.sched_client.rebuild_inferece_context(inference_context)
        self.cache.load(inference_context)
        print(f"kv_cache loaded successfully.")

        self.block_num = inference_context.k_cache[0].size(1)
        with torch.device("meta"):
            if config.architectures[0] == "DeepseekV3ForCausalLM":
                self.model = KDeepseekV3ForCausalLM(config, self.cache)
            elif config.architectures[0] == "DeepseekV2ForCausalLM":
                self.model = KDeepseekV2ForCausalLM(config, self.cache)
        # print(self.block_num)

        context = zmq.Context()

        self.pub_socket = context.socket(zmq.PUB)
        self.pub_socket.bind(f"ipc://{broadcast_endpoint}")
        # time.sleep(1) # make sure all subscribers are ready

        try:
            generation_config = GenerationConfig.from_pretrained(args.model_dir)
        except:
            generation_config = GenerationConfig(
                max_length=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                do_sample=True
            )

        if args.optimize_config_path is None:
            optimize_config_path = default_optimize_rules[config.architectures[0]]

        else:
            optimize_config_path = args.optimize_config_path
        gguf_path = args.gguf_path
        if gguf_path is None:
            gguf_path = input(
                "please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
                " belong to current model):"
            )
        optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
        self.model.generation_config = generation_config
        if self.model.generation_config.pad_token_id is None:
            self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id

        self.model.eval()
        #@TODO add config
        self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)

        self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size)
        self.sampler = Sampler()
        self.query_manager = QueryManager(device = self.device, page_size = args.page_size)

    def sampling(self, forward_output: ForwardBatchOutput):
        generated_tokens = torch.empty(0, device=self.device, dtype=torch.int32)
        for i in range(forward_output.num_batchs):
            logit = forward_output.logits[i]
            if hasattr(forward_output, "temperatures"):
                temperatures = forward_output.temperatures[i]
            else:
                temperatures = None

            if hasattr(forward_output, "top_ps"):
                top_ps = forward_output.top_ps[i]
            else:
                top_ps = None

            sample_options = SamplingOptions(logit.size(0), self.device, pretrained_config=self.model.generation_config, temperatures=temperatures, top_ps=top_ps)
            generated_tokens, probs = self.sampler(logit, sample_options)
        return generated_tokens, probs

    def loop(self):

        next_batch = None

        while True:
            self.batch = next_batch
            if self.batch is not None:
                self.model_runner.run(self.batch, self.query_manager)

            if len(self.updates) > 0:
                for q in self.updates:
                    if q.is_prefill == True:
                        continue
                    # print(f"Putting token {q.generated_token} into queue for query id: {q.id}")
                    try:
                        self.gen_queue.put((q.id, q.generated_token if q.decode_done == False else None), timeout=5)
                    except queue.Full:
                        pass  # print("Queue is full after timeout; unable to put more items.")

            next_batch = self.sched_client.update_last_batch(self.updates)
            if next_batch.query_ids == []:
                next_batch = None
            self.pub_socket.send_pyobj(next_batch)

            if next_batch is not None:
                self.query_manager.add_query(next_batch)

            if self.batch is not None:
                self.model_runner.sync()
                print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms")
                # if self.rank == 0:

                generated_tokens, probs = self.sampling(self.model_runner.output)

                self.updates = self.query_manager.update(self.batch)
                fill_generated_tokens(self.updates, generated_tokens, self.query_manager)
            else:
                self.updates = []
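
# Illustration (not part of this file): each iteration of Engine.loop() above
# (1) runs the model on the batch received in the previous iteration,
# (2) pushes finished decode tokens to gen_queue (None once a query is done),
# (3) fetches the next batch from the scheduler and broadcasts it on the PUB
#     socket, and (4) samples next tokens and records them with
#     fill_generated_tokens().
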
class BalanceServeThreadContext(ThreadContext):
    def get_local_messages(self):
        local_messages = []
        for m in self.messages:
            local_messages.append({"role": m.role.value, "content": m.get_text_content()})

        return local_messages


def run_engine(args, token_queue, broadcast_endpoint, event):
    engine = Engine(args, token_queue, broadcast_endpoint)
    if args.use_cuda_graph:
        engine.model_runner.warmup()

    event.set()
    engine.loop()

class BalanceServeInterface(BackendInterfaceBase):
    use_static_cache: bool = True

    model: Any
    tokenizer: AutoTokenizer

    cache: StaticCache
    generated_ids: torch.Tensor
    seq_length: int

    streamer: TextStreamer

    # thread_related
    last_request_id: Optional[str] = None
    ever_generated_ids: Set[int] = set()
    def __init__(self, args: ConfigArgs = default_args):
        self.args = args
        self.queue_map:dict[int,asyncio.Queue] = {}
        self.thread_map: dict[int, int] = {}
        processes = []
        self.broadcast_endpoint = tempfile.NamedTemporaryFile(delete=False).name  # @TODO add to config
        ctx = mp.get_context("spawn")
        self.token_queue = ctx.Queue(maxsize=1000)
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
        self.sched_client = SchedulerClient(args.sched_port)
        self.streamer = TextStreamer(self.tokenizer)

        start_event = ctx.Event()

        p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event))
        p.start()
        processes.append(p)
        start_event.wait()

    def run_queue_proxy(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self.queue_proxy())

    @asynccontextmanager
    async def lifespan(self, app: FastAPI):
        asyncio.create_task(self.queue_proxy())
        yield

    async def queue_proxy(self):
        print("Queue Proxy Started")
        while True:
            try:
                query_id, token = self.token_queue.get_nowait()
                try:
                    # query id might not be allocated yet
                    self.queue_map[query_id].put_nowait(token)
                    #print(f"Proxy Put token: {token} to queue for query id: {query_id}")
                except asyncio.QueueFull:
                    #print(f"Queue for query id: {query_id} is full, waiting to put: {token}")
                    await self.queue_map[query_id].put(token)

            except queue.Empty:
                # print("no new token")
                # await asyncio.sleep(1)
                await asyncio.sleep(0)

    def tokenize_prompt(self, prompt: str):
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device)
        return input_ids

    def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
        for m in messages:
            if m["role"] == "system":
                logger.warning(f'change {m["role"]} to user')
                m["role"] = "user"

        new_messages = [messages[0]]
        for m in messages[1:]:
            if m["role"] == "user" and new_messages[-1]["role"] == "user":
                logger.warning("merge two adjacent user messages")
                new_messages[-1]["content"] += '\n' + m["content"]
            else:
                new_messages.append(m)
        input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
        # drop <think> token in chat template
        if input_str.endswith('<think>\n'):
            input_str = input_str[:-len('<think>\n')]
        input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
        logger.debug(f"get input ids of shape {input_ids.shape}")
        return input_ids

    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
        profiler = Profiler()
        profiler.create_and_start_timer("tokenize")

        if isinstance(local_messages, List):
            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
        elif isinstance(local_messages, str):
            #local_messages = local_messages[0]['content']
            input_ids = self.tokenize_prompt(local_messages)
        else:
            raise ValueError("local_messages should be List or str")
        if Config().user_force_think:
            token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
            input_ids = torch.cat(
                [input_ids, token_thinks], dim=1
            )

        profiler.pause_timer("tokenize")

        profiler.create_and_start_timer("prefill")

        query_add = sched_ext.QueryAdd()
        query_add.query_token = input_ids[0].tolist()
        query_length = input_ids[0].shape[0]
        query_add.query_length = query_length
        profiler.set_counter("prefill", query_length)
        #@TODO add server
        stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")]
        query_add.stop_criteria = stop_criteria
        query_add.sample_options.temperature = temperature
        query_add.sample_options.top_p = top_p
        query_add.estimated_length = min(self.args.cache_lens, query_length+self.args.max_new_tokens)
        query_id = self.sched_client.add_query(query_add)
        queue = asyncio.Queue(maxsize=self.args.max_new_tokens)
        self.queue_map[query_id] = queue
        self.thread_map[thread_id] = query_id
        is_first_token = True
        async for token in chat_stream(self.queue_map[query_id], self.tokenizer):
            if is_first_token:
                is_first_token = False
                profiler.pause_timer("prefill")
                profiler.create_and_start_timer("decode")
                profiler.set_counter("decode", 0)
                if Config().user_force_think:
                    think = '<think>\n'
                    print(think, end="",flush=True)
                    yield think, None
            else:
                profiler.inc("decode")
            yield token, None
        profiler.pause_timer("decode")
        report_last_time_performance(profiler)
        yield self.streamer.end(), None
        if profiler.get_counter('decode') >= self.args.max_new_tokens - 1:
            yield "", "length"
        else:
            yield "", "stop"

        yield RawUsage(
            tokenize_time = profiler.get_timer_sec('tokenize'),
            prefill_time = profiler.get_timer_sec('prefill'),
            decode_time = profiler.get_timer_sec('decode'),
            prefill_count = profiler.get_counter('prefill'),
            decode_count = profiler.get_counter('decode'),
        )

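For context, a caller consumes BalanceServeInterface.inference() as an async stream of (text, finish_reason) pairs followed by a final RawUsage record; a minimal sketch (interface, messages, and thread_id are placeholders):

    async for chunk in interface.inference(messages, thread_id):
        if isinstance(chunk, RawUsage):
            usage = chunk  # tokenize/prefill/decode timings and token counts
        else:
            text, finish_reason = chunk  # finish_reason stays None until "length" or "stop"
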
@@ -211,11 +211,11 @@ class KTransformersInterface(TransformersInterface):

        chunk_start = 0
        while chunk_start < input_ids_length:
-           chunk_end = min(chunk_start + self.args.chunk_prefill_size, input_ids_length)
+           chunk_end = min(chunk_start + self.args.chunk_size, input_ids_length)
            if self.cache != None:
                self.cache.cur_idx=cache_position[chunk_start:chunk_end]
            logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
-           chunk_start += self.args.chunk_prefill_size
+           chunk_start += self.args.chunk_size

        if flashinfer_enabled:
            MLAWrapperSingleton.reset_buffer()