Merge branch 'kvcache-ai:main' into main

Yuhao Tsui 2025-04-09 11:46:39 +08:00 committed by GitHub
commit 877aec858e
251 changed files with 47224 additions and 749 deletions


@@ -1,6 +1,6 @@
import argparse
from ktransformers.server.backend.args import ConfigArgs, default_args
from ktransformers.util.utils import get_free_ports
class ArgumentParser:
def __init__(self, cfg):
@@ -16,20 +16,18 @@ class ArgumentParser:
parser.add_argument("--web", type=bool, default=self.cfg.mount_web)
parser.add_argument("--model_name", type=str, default=self.cfg.model_name)
parser.add_argument("--model_dir", type=str)
parser.add_argument("--model_path", type=str)
parser.add_argument("--model_path", type=str, default=self.cfg.model_path)
parser.add_argument(
"--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter"
)
parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
parser.add_argument("--type", type=str, default=self.cfg.backend_type)
parser.add_argument("--chunk_prefill_size", type=int, default=8192)
parser.add_argument("--backend_type", type=str, default=self.cfg.backend_type)
parser.add_argument("--chunk_size", type=int, default=self.cfg.chunk_size)
# model configs
# parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
parser.add_argument("--paged", type=bool, default=self.cfg.paged)
parser.add_argument("--total_context", type=int, default=self.cfg.total_context)
parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size)
parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens)
parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode)
@@ -62,7 +60,6 @@ class ArgumentParser:
parser.add_argument("--repetition_penalty", type=float, default=self.cfg.repetition_penalty)
parser.add_argument("--frequency_penalty", type=float, default=self.cfg.frequency_penalty)
parser.add_argument("--presence_penalty", type=float, default=self.cfg.presence_penalty)
parser.add_argument("--max_response_tokens", type=int, default=self.cfg.max_response_tokens)
parser.add_argument("--response_chunk", type=int, default=self.cfg.response_chunk)
parser.add_argument("--no_code_formatting", type=bool, default=self.cfg.no_code_formatting)
parser.add_argument("--cache_8bit", type=bool, default=self.cfg.cache_8bit)
@@ -73,6 +70,9 @@ class ArgumentParser:
parser.add_argument("--batch_size", type=int, default=self.cfg.batch_size)
parser.add_argument("--cache_lens", type=int, default=self.cfg.cache_lens)
# kvc2 config
parser.add_argument("--kvc2_config_dir", type=str, default=self.cfg.kvc2_config_dir)
# log configs
# log level: debug, info, warn, error, crit
parser.add_argument("--log_dir", type=str, default=self.cfg.log_dir)
@@ -103,6 +103,18 @@ class ArgumentParser:
# local chat
parser.add_argument("--prompt_file", type=str, default=self.cfg.prompt_file)
# async server
parser.add_argument("--sched_strategy", type=str, default=self.cfg.sched_strategy)
# parser.add_argument("--sched_port", type=int, default=self.cfg.sched_port)
# parser.add_argument("--sched_metrics_port", type=int, default=self.cfg.sched_metrics_port)
# parser.add_argument("--kvc2_metrics_port", type=int, default=self.cfg.kvc2_metrics_port)
parser.add_argument("--page_size", type=str, default=self.cfg.page_size)
parser.add_argument("--memory_gpu_only", type=str, default=self.cfg.memory_gpu_only)
parser.add_argument("--utilization_percentage", type=str, default=self.cfg.utilization_percentage)
parser.add_argument("--cpu_memory_size_GB", type=str, default=self.cfg.cpu_memory_size_GB)
args = parser.parse_args()
if (args.model_dir is not None or args.model_path is not None):
if (args.model_path is not None):
@@ -123,6 +135,15 @@ class ArgumentParser:
self.cfg.mount_web = args.web
self.cfg.server_ip = args.host
self.cfg.server_port = args.port
self.cfg.backend_type = args.type
self.cfg.user_force_think = args.force_think
args.gpu_memory_size = args.cache_lens*2*576*61
self.cfg.gpu_memory_size = args.gpu_memory_size
free_ports = get_free_ports(3, [args.port])
args.sched_port = free_ports[0]
args.sched_metrics_port = free_ports[1]
args.kvc2_metrics_port = free_ports[2]
self.cfg.sched_port = free_ports[0]
self.cfg.sched_metrics_port = free_ports[1]
self.cfg.kvc2_metrics_port = free_ports[2]
return args
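
The hunk above replaces fixed scheduler/metrics ports with ports discovered at start-up through get_free_ports(3, [args.port]). As an illustration only (not the actual ktransformers.util.utils.get_free_ports implementation), a minimal free-port helper can bind to port 0, let the OS pick an unused port, and skip excluded ones:

import socket

def find_free_ports(count: int, exclude: list[int]) -> list[int]:
    """Illustrative sketch: ask the OS for `count` currently unused TCP ports."""
    ports, held = [], []
    while len(ports) < count:
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.bind(("", 0))                     # port 0 -> the OS chooses a free port
        port = s.getsockname()[1]
        if port in exclude or port in ports:
            s.close()
            continue
        held.append(s)                      # keep it bound so the same port is not picked twice
        ports.append(port)
    for s in held:
        s.close()                           # released here; a race is possible before the caller binds
    return ports

print(find_free_ports(3, exclude=[10002]))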


@@ -12,18 +12,10 @@ class ConfigArgs(BaseModel):
class Config:
protected_namespaces = ()
paged: bool = Field(None, description="Whether to use paged attention kv cache")
total_context: int = Field(
None,
description=(
"Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the"
" total to distribute dynamically over however many jobs are active at once"
),
)
max_batch_size: int = Field(
None, description="Max number of batches to run at once, assuming the sequences will fit within total_context"
)
chunk_prefill_size: int = Field(
chunk_size: int = Field(
None,
description=(
"Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new"
@@ -70,7 +62,6 @@ class ConfigArgs(BaseModel):
repetition_penalty: float = Field(None, description="Sampler repetition penalty, default = 1.01 (1 to disable)")
frequency_penalty: float = Field(None, description="Sampler frequency penalty, default = 0.0 (0 to disable)")
presence_penalty: float = Field(None, description="Sampler presence penalty, default = 0.0 (0 to disable)")
max_response_tokens: int = Field(None, description="Max tokens per response, default = 1000")
response_chunk: int = Field(None, description="Space to reserve in context for reply, default = 250")
no_code_formatting: bool = Field(None, description="Disable code formatting/syntax highlighting")
cache_8bit: bool = Field(None, description="Use 8-bit (FP8) cache")


@@ -9,9 +9,11 @@ from ktransformers.server.backend.interfaces.transformers import TransformersThr
from ktransformers.server.backend.interfaces.ktransformers import KTransformersThreadContext
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaThreadContext
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
from ktransformers.server.backend.interfaces.transformers import TransformersInterface
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface
class ThreadContextManager:
lock: Lock
threads_context: Dict[ObjectID, ThreadContext]
@@ -36,7 +38,16 @@ class ThreadContextManager:
elif isinstance(self.interface, TransformersInterface):
new_context = TransformersThreadContext(run, self.interface)
else:
raise NotImplementedError
from ktransformers.server.backend.interfaces.balance_serve import BalanceServeThreadContext
from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface
if isinstance(self.interface, BalanceServeInterface):
new_context = BalanceServeThreadContext(run, self.interface)
else:
raise NotImplementedError
# elif isinstance(self.interface, BalanceServeInterface):
# new_context = BalanceServeThreadContext(run, self.interface)
# else:
# raise NotImplementedError
self.threads_context[run.thread_id] = new_context
# self.threads_context[run.thread_id] = ExllamaInferenceContext(run)
re = self.threads_context[run.thread_id]


@@ -0,0 +1,410 @@
from typing import Any, AsyncIterator, List, Optional, Set
from ktransformers.models.custom_cache import KDeepSeekV3Cache
from transformers import (
AutoTokenizer,
AutoConfig,
GenerationConfig,
StaticCache,
AutoModelForCausalLM,
BitsAndBytesConfig,
)
from ktransformers.server.config.config import Config
from ..base import ThreadContext, BackendInterfaceBase
import torch
from ktransformers.server.backend.interfaces.transformers import (
ConfigArgs,
default_args,
TextStreamer,
)
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.config.log import logger
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.server.balance_serve.inference.model_runner import ModelRunner
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
from ktransformers.server.balance_serve.sched_rpc import SchedulerClient
from ktransformers.server.balance_serve.settings import sched_ext
from torch.multiprocessing import Queue
import torch.multiprocessing as mp
from ktransformers.server.schemas.endpoints.chat import RawUsage
from ktransformers.server.utils.multi_timer import Profiler
import zmq
import time
import queue
import tempfile
import asyncio
import threading
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
import os
ktransformer_rules_dir = (
os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
)
default_optimize_rules = {
"DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml",
}
async def chat_stream(queue: asyncio.Queue, tokenizer: AutoTokenizer):
streamer = TextStreamer(tokenizer)
while True:
token = await queue.get()
#print(f"Got token: {token}")
if token is None:
# str = f'{token}\n\n'
# str = model.tokenizer.decode(token)
s = streamer.end()
if s is not None:
yield s
break
# str = model.tokenizer.decode(token)
yield streamer.put(token)
def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_tokens: torch.Tensor, query_manager: QueryManager = None):
#print(len(query_updates), generated_tokens.size(0), generated_tokens)
for i in range(generated_tokens.size(0)):
print(generated_tokens[i].item())
query_updates[i].generated_token = generated_tokens[i].item()
if not query_manager.query_map[query_updates[i].id].is_prefill:
pos = query_updates[i].active_position
query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]
def report_last_time_performance(profiler: Profiler):
try:
tokenize_time = profiler.get_timer_sec('tokenize')
prefill_time = profiler.get_timer_sec('prefill')
decode_time = profiler.get_timer_sec('decode')
prefill_count = profiler.get_counter('prefill')
decode_count = profiler.get_counter('decode')
logger.info(f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}')
except:
logger.info(f'Performance statistics not recorded')
class Engine:
sched_client : SchedulerClient
updates : list[sched_ext.QueryUpdate]
batch : sched_ext.BatchQueryTodo
model_runner: ModelRunner
sampler: Sampler
query_manager: QueryManager
cache: KDeepSeekV3Cache
def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None):
self.args = args
# the child process and the parent process cannot share the config variables, so copy them from args
for key, value in vars(args).items():
if value is not None and hasattr(Config(), key):
setattr(Config(), key, value)
self.device = self.args.device
self.sched_client = SchedulerClient(args.sched_port)
self.updates = []
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
self.cache = KDeepSeekV3Cache(config, self.args.page_size)
self.gen_queue = generated_token_queue
print(f"Getting inference context from sched_client.")
inference_context = self.sched_client.get_inference_context_raw()
print(f"Got inference context, sending it to subscribers.")
inference_context = self.sched_client.rebuild_inferece_context(inference_context)
self.cache.load(inference_context)
print(f"kv_cache loaded successfully.")
self.block_num = inference_context.k_cache[0].size(1)
with torch.device("meta"):
if config.architectures[0] == "DeepseekV3ForCausalLM":
self.model = KDeepseekV3ForCausalLM(config, self.cache)
elif config.architectures[0] == "DeepseekV2ForCausalLM":
self.model = KDeepseekV2ForCausalLM(config, self.cache)
# print(self.block_num)
context = zmq.Context()
self.pub_socket = context.socket(zmq.PUB)
self.pub_socket.bind(f"ipc://{broadcast_endpoint}")
# time.sleep(1) # make sure all subscribers are ready
try:
generation_config = GenerationConfig.from_pretrained(args.model_dir)
except:
generation_config = GenerationConfig(
max_length=args.max_new_tokens,
temperature=args.temperature,
top_p=args.top_p,
do_sample=True
)
if args.optimize_config_path is None:
optimize_config_path = default_optimize_rules[config.architectures[0]]
else:
optimize_config_path = args.optimize_config_path
gguf_path = args.gguf_path
if gguf_path is None:
gguf_path = input(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
" belong to current model):"
)
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
self.model.generation_config = generation_config
if self.model.generation_config.pad_token_id is None:
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
self.model.eval()
#@TODO add config
self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size)
self.sampler = Sampler()
self.query_manager = QueryManager(device = self.device, page_size = args.page_size)
def sampling(self, forward_output: ForwardBatchOutput):
generated_tokens = torch.empty(0, device=self.device, dtype=torch.int32)
for i in range(forward_output.num_batchs):
logit = forward_output.logits[i]
if hasattr(forward_output, "temperatures"):
temperatures = forward_output.temperatures[i]
else:
temperatures = None
if hasattr(forward_output, "top_ps"):
top_ps = forward_output.top_ps[i]
else:
top_ps = None
sample_options = SamplingOptions(logit.size(0), self.device, pretrained_config=self.model.generation_config, temperatures=temperatures, top_ps=top_ps)
generated_tokens, probs=self.sampler(logit, sample_options)
return generated_tokens, probs
def loop(self):
next_batch = None
while True:
self.batch = next_batch
if self.batch is not None:
self.model_runner.run(self.batch, self.query_manager)
if len(self.updates) > 0:
for q in self.updates:
if q.is_prefill == True:
continue
# print(f"Putting token {q.generated_token} into queue for query id: {q.id}")
try:
self.gen_queue.put((q.id, q.generated_token if q.decode_done == False else None), timeout=5)
except queue.Full:
pass#print("Queue is full after timeout; unable to put more items.")
next_batch = self.sched_client.update_last_batch(self.updates)
if next_batch.query_ids == []:
next_batch = None
self.pub_socket.send_pyobj(next_batch)
if next_batch is not None:
self.query_manager.add_query(next_batch)
if self.batch is not None:
self.model_runner.sync()
print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms")
# if self.rank == 0:
generated_tokens, probs = self.sampling( self.model_runner.output)
self.updates = self.query_manager.update(self.batch)
fill_generated_tokens(self.updates, generated_tokens, self.query_manager)
else:
self.updates = []
class BalanceServeThreadContext(ThreadContext):
def get_local_messages(self):
local_messages = []
for m in self.messages:
local_messages.append({"role": m.role.value, "content": m.get_text_content()})
return local_messages
def run_engine(args, token_queue, broadcast_endpoint, event):
engine = Engine(args, token_queue, broadcast_endpoint)
if args.use_cuda_graph:
engine.model_runner.warmup()
event.set()
engine.loop()
class BalanceServeInterface(BackendInterfaceBase):
use_static_cache: bool = True
model: Any
tokenizer: AutoTokenizer
cache: StaticCache
generated_ids: torch.Tensor
seq_length: int
streamer: TextStreamer
# thread_related
last_request_id: Optional[str] = None
ever_generated_ids: Set[int] = set()
def __init__(self, args: ConfigArgs = default_args):
self.args = args
self.queue_map:dict[int,asyncio.Queue] = {}
self.thread_map: dict[int, int] = {}
processes = []
self.broadcast_endpoint = tempfile.NamedTemporaryFile(delete=False).name # @TODO add to config
ctx = mp.get_context("spawn")
self.token_queue = ctx.Queue(maxsize=1000)
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
self.sched_client = SchedulerClient(args.sched_port)
self.streamer = TextStreamer(self.tokenizer)
start_event = ctx.Event()
p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event))
p.start()
processes.append(p)
start_event.wait()
def run_queue_proxy(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self.queue_proxy())
@asynccontextmanager
async def lifespan(self, app: FastAPI):
asyncio.create_task(self.queue_proxy())
yield
async def queue_proxy(self):
print("Queue Proxy Started")
while True:
try:
query_id, token = self.token_queue.get_nowait()
try:
# query id might not be allocated yet
self.queue_map[query_id].put_nowait(token)
#print(f"Proxy Put token: {token} to queue for query id: {query_id}")
except asyncio.QueueFull:
#print(f"Queue for query id: {query_id} is full, waiting to put: {token}")
await self.queue_map[query_id].put(token)
except queue.Empty:
# print("no new token")
# await asyncio.sleep(1)
await asyncio.sleep(0)
def tokenize_prompt(self, prompt: str):
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device)
return input_ids
def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
for m in messages:
if m["role"] == "system":
logger.warning(f'change {m["role"]} to user')
m["role"] = "user"
new_messages = [messages[0]]
for m in messages[1:]:
if m["role"] == "user" and new_messages[-1]["role"] == "user":
logger.warning("merge two adjacent user messages")
new_messages[-1]["content"] += '\n' + m["content"]
else:
new_messages.append(m)
input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
# drop <think> token in chat template
if input_str.endswith('<think>\n'):
input_str = input_str[:-len('<think>\n')]
input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
logger.debug(f"get input ids of shape {input_ids.shape}")
return input_ids
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
profiler = Profiler()
profiler.create_and_start_timer("tokenize")
if isinstance(local_messages, List):
input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
elif isinstance(local_messages, str):
#local_messages = local_messages[0]['content']
input_ids = self.tokenize_prompt(local_messages)
else:
raise ValueError("local_messages should be List or str")
if Config().user_force_think:
token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
input_ids = torch.cat(
[input_ids, token_thinks], dim=1
)
profiler.pause_timer("tokenize")
profiler.create_and_start_timer("prefill")
query_add = sched_ext.QueryAdd()
query_add.query_token = input_ids[0].tolist()
query_length = input_ids[0].shape[0]
query_add.query_length = query_length
profiler.set_counter("prefill", query_length)
#@TODO add server
stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")]
query_add.stop_criteria = stop_criteria
if temperature == 0:
temperature = 0.0001
query_add.sample_options.temperature = temperature
if top_p == 0:
top_p = 0.0001
query_add.sample_options.top_p = top_p
query_add.estimated_length = min(self.args.cache_lens, query_length+self.args.max_new_tokens)
query_id = self.sched_client.add_query(query_add)
queue = asyncio.Queue(maxsize=self.args.max_new_tokens)
self.queue_map[query_id] = queue
self.thread_map[thread_id] = query_id
is_first_token = True
async for token in chat_stream(self.queue_map[query_id], self.tokenizer):
if is_first_token:
is_first_token=False
profiler.pause_timer("prefill")
profiler.create_and_start_timer("decode")
profiler.set_counter("decode", 0)
if Config().user_force_think:
think = '<think>\n'
print(think, end="",flush=True)
yield think, None
else:
profiler.inc("decode")
yield token, None
profiler.pause_timer("decode")
report_last_time_performance(profiler)
yield self.streamer.end(), None
if profiler.get_counter('decode') >= self.args.max_new_tokens - 1:
yield "", "length"
else:
yield "", "stop"
yield RawUsage(
tokenize_time = profiler.get_timer_sec('tokenize'),
prefill_time = profiler.get_timer_sec('prefill'),
decode_time = profiler.get_timer_sec('decode'),
prefill_count = profiler.get_counter('prefill'),
decode_count = profiler.get_counter('decode'),
)
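
The queue_proxy coroutine above bridges the spawned engine process and the asyncio server: (query_id, token) pairs arrive on a torch.multiprocessing Queue and are fanned out into per-query asyncio.Queues that chat_stream then consumes. A stripped-down, stdlib-only sketch of that bridge (illustrative names and payloads, not the ktransformers API):

import asyncio
import multiprocessing as mp
import queue

def engine_stub(out_q):
    # stand-in for the engine process: emit a few tokens, then None as end-of-stream
    for tok in (101, 102, 103, None):
        out_q.put(("q0", tok))

async def proxy(mp_q, queue_map):
    # drain the cross-process queue without blocking the event loop
    while True:
        try:
            qid, tok = mp_q.get_nowait()
        except queue.Empty:
            await asyncio.sleep(0)
            continue
        await queue_map[qid].put(tok)
        if tok is None:
            return

async def main():
    ctx = mp.get_context("spawn")
    mp_q = ctx.Queue()
    queue_map = {"q0": asyncio.Queue()}
    p = ctx.Process(target=engine_stub, args=(mp_q,))
    p.start()
    await proxy(mp_q, queue_map)
    while (tok := await queue_map["q0"].get()) is not None:
        print("token:", tok)
    p.join()

if __name__ == "__main__":
    asyncio.run(main())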


@@ -211,11 +211,11 @@ class KTransformersInterface(TransformersInterface):
chunk_start = 0
while chunk_start < input_ids_length:
chunk_end = min(chunk_start + self.args.chunk_prefill_size, input_ids_length)
chunk_end = min(chunk_start + self.args.chunk_size, input_ids_length)
if self.cache != None:
self.cache.cur_idx=cache_position[chunk_start:chunk_end]
logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
chunk_start += self.args.chunk_prefill_size
chunk_start += self.args.chunk_size
if flashinfer_enabled:
MLAWrapperSingleton.reset_buffer()
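
The hunk above only renames chunk_prefill_size to chunk_size; the pattern itself is ordinary chunked prefill, i.e. feeding the prompt through the model in fixed-size slices so a long prefill is split into several shorter calls. A generic sketch of the slicing (illustrative helper, not the ktransformers code):

import torch

def iter_prefill_chunks(input_ids: torch.Tensor, chunk_size: int):
    # yields (token_slice, cache_position_slice) pairs covering the whole prompt
    cache_position = torch.arange(input_ids.shape[1])
    chunk_start = 0
    while chunk_start < input_ids.shape[1]:
        chunk_end = min(chunk_start + chunk_size, input_ids.shape[1])
        yield input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end]
        chunk_start += chunk_size

ids = torch.arange(10).unsqueeze(0)          # a fake 10-token prompt
for chunk, pos in iter_prefill_chunks(ids, chunk_size=4):
    print(chunk.shape, pos.tolist())         # shapes (1,4), (1,4), (1,2)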


@@ -208,6 +208,8 @@ class TransformersInterface(BackendInterfaceBase):
temperature = self.model.generation_config.temperature
if top_p is None:
top_p = self.model.generation_config.top_p
if top_p == 0:
top_p = 0.0001
generation_config, model_kwargs = self.model._prepare_generation_config(
None, max_length=self.args.max_new_tokens,
do_sample=True,
@@ -341,7 +343,7 @@ class TransformersInterface(BackendInterfaceBase):
for i in range(1, self.max_new_tokens):
with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
if flashinfer_enabled:
MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1, None,
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)


@@ -0,0 +1,142 @@
'''
Date: 2024-11-07 07:30:16
LastEditors: djw
LastEditTime: 2024-11-15 14:23:26
'''
import math
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
import yaml
import json
from typing import Optional
class ModelConfig:
vocab_size: int = 32000
n_layer: int = 1
n_head: int = 32
dim: int = 4096
intermediate_size: int = 18944
n_local_heads: int = 8
head_dim: int = 128
rope_base: float = 1000000.0
norm_eps: float = 1e-06
rope_scaling: Optional[dict] = None
rms_norm_eps: float = 1e-6
hidden_act: str = "silu"
model_path: str
gguf_path: str
optimize_rule_path: str
speculative_rule_path: str
# quantize config
quant_algorithm: Optional[str] = None
quant_group_size: Optional[int] = None
quant_num_bits: Optional[int] = None
json_key_map = {
"vocab_size": "vocab_size",
"n_layer": "num_hidden_layers",
"n_head": "num_attention_heads",
"dim": "hidden_size",
"intermediate_size": "intermediate_size",
"n_local_heads": "num_key_value_heads",
"rope_base": "rope_theta",
"norm_eps": "norm_eps",
"rms_norm_eps": "rms_norm_eps",
"hidden_act": "hidden_act",
}
def __init__(self, config):
self.model_path = config["model"]["model_path"]
self.gguf_path = config["model"]["gguf_path"]
self.optimize_rule_path = config["model"]["optimize_rule_path"]
if "speculative_rule_path" in config["model"]:
self.speculative_rule_path = config["model"]["speculative_rule_path"]
self.speculative_gguf_path = config["model"]["speculative_gguf_path"]
self.speculative_model_path = config["model"]["speculative_model_path"]
self.quant_algorithm = config["model"]["quant"]["algorithm"]
self.quant_group_size = config["model"]["quant"]["group_size"]
self.quant_num_bits = config["model"]["quant"]["num_bits"]
self.load_config()
self.n_layer = config["model"]["n_layers"]
def load_config(self):
config_file = f"{self.model_path}/config.json"
try:
with open(config_file, "r") as f:
config_data = json.load(f)
except FileNotFoundError:
raise FileNotFoundError(f"Configuration file not found at {config_file}")
for attr, json_key in self.json_key_map.items():
if json_key in config_data:
setattr(self, attr, config_data[json_key])
else:
setattr(self, attr, getattr(self, attr))
class ParallelConfig:
def __init__(
self,
config,
) -> None:
self.pipeline_parallel_size = config["parallel"]["pp"]
self.tensor_parallel_size = config["parallel"]["tp"]
self.disable_custom_all_reduce = config["parallel"]["disable_custom_all_reduce"]
self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size
class AttnConfig:
page_size: int = 256
block_num: int = 32
max_batch_token : int = 256
max_batch_size: int = 32
def __init__(self, config):
self.page_size = config["attn"]["page_size"]
self.block_num = config["attn"]["block_num"]
self.max_batch_token = config["attn"]["max_batch_token"]
self.max_batch_size = config["attn"]["max_batch_size"]
class SamplerConfig():
# Batched sampling params
temperatures: float
is_all_greedy: bool
def __init__(self, config):
self.temperatures = config["sample"]["temperature"]
self.is_all_greedy = True
def load_yaml_config(file_path):
with open(file_path, "r") as f:
return yaml.safe_load(f)
class LLMConfig:
model_config: ModelConfig
parallel_config: ParallelConfig
attn_config: AttnConfig
sample_config: SamplerConfig
config_file: str
def __init__(self, config_file):
self.config_file = config_file
config = load_yaml_config(config_file)
self.model_config = ModelConfig(config)
self.parallel_config = ParallelConfig(config)
self.attn_config = AttnConfig(config)
self.sample_config = SamplerConfig(config)
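
For reference, the nested keys the classes above read can be collected into one dict (or the equivalent YAML file passed to load_yaml_config). The values below are invented for illustration, and only the sub-configs that do not touch the filesystem are instantiated here, since ModelConfig additionally opens <model_path>/config.json; the classes defined above are assumed to be in scope.

config = {
    "model": {
        "model_path": "/path/to/model",
        "gguf_path": "/path/to/gguf",
        "optimize_rule_path": "/path/to/rule.yaml",
        "quant": {"algorithm": None, "group_size": None, "num_bits": None},
        "n_layers": 61,
    },
    "parallel": {"pp": 1, "tp": 2, "disable_custom_all_reduce": False},
    "attn": {"page_size": 256, "block_num": 32, "max_batch_token": 256, "max_batch_size": 32},
    "sample": {"temperature": 0.6},
}
print(ParallelConfig(config).world_size)   # 2
print(AttnConfig(config).page_size)        # 256
print(SamplerConfig(config).temperatures)  # 0.6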


@@ -0,0 +1,3 @@
from .communication_op import *
from .parallel_state import *
from .utils import *


@@ -0,0 +1,39 @@
"""
Date: 2024-12-11 06:02:42
LastEditors: djw
LastEditTime: 2024-12-12 09:52:06
"""
from typing import Any, Dict, Optional, Union
import torch
import torch.distributed
from .parallel_state import get_tp_group
def tensor_model_parallel_all_reduce(input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
return get_tp_group().all_reduce(input_, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap)
def tensor_model_parallel_all_gather(
input_: torch.Tensor, dim: int = -1
) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_tp_group().all_gather(input_, dim)
def tensor_model_parallel_gather(
input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
"""Gather the input tensor across model parallel group."""
return get_tp_group().gather(input_, dst, dim)
def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0
):
if not torch.distributed.is_initialized():
return tensor_dict
return get_tp_group().broadcast_tensor_dict(tensor_dict, src)


@@ -0,0 +1,168 @@
"""This file is a pure Python wrapper for the cudart library.
It avoids the need to compile a separate shared library, and is
convenient for use when we just need to call a few functions.
"""
import ctypes
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
# this line makes it possible to directly load `libcudart.so` using `ctypes`
import torch # noqa
# === export types and functions from cudart to Python ===
# for the original cudart definition, please check
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
cudaError_t = ctypes.c_int
cudaMemcpyKind = ctypes.c_int
class cudaIpcMemHandle_t(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]
@dataclass
class Function:
name: str
restype: Any
argtypes: List[Any]
def find_loaded_library(lib_name) -> Optional[str]:
"""
According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
the file `/proc/self/maps` contains the memory maps of the process, which include the
shared libraries loaded by the process. We can use this file to find the path of
a loaded library.
""" # noqa
found = False
with open("/proc/self/maps") as f:
for line in f:
if lib_name in line:
found = True
break
if not found:
# the library is not loaded in the current process
return None
# if lib_name is libcudart, we need to match a line with:
# address /path/to/libcudart-hash.so.11.0
start = line.index("/")
path = line[start:].strip()
filename = path.split("/")[-1]
assert filename.rpartition(".so")[0].startswith(lib_name), \
f"Unexpected filename: {filename} for library {lib_name}"
return path
class CudaRTLibrary:
exported_functions = [
# cudaError_t cudaSetDevice ( int device )
Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
# cudaError_t cudaDeviceSynchronize ( void )
Function("cudaDeviceSynchronize", cudaError_t, []),
# cudaError_t cudaDeviceReset ( void )
Function("cudaDeviceReset", cudaError_t, []),
# const char* cudaGetErrorString ( cudaError_t error )
Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
# cudaError_t cudaMalloc ( void** devPtr, size_t size )
Function("cudaMalloc", cudaError_t,
[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
# cudaError_t cudaFree ( void* devPtr )
Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
# cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
Function("cudaMemset", cudaError_t,
[ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
# cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
Function("cudaMemcpy", cudaError_t, [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind
]),
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
Function("cudaIpcGetMemHandle", cudaError_t,
[ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
# cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
Function("cudaIpcOpenMemHandle", cudaError_t, [
ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint
]),
]
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {}
# class attribute to store the mapping from library path
# to the corresponding dictionary
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
if so_file is None:
so_file = find_loaded_library("libcudart")
assert so_file is not None, \
"libcudart is not loaded in the current process"
if so_file not in CudaRTLibrary.path_to_library_cache:
lib = ctypes.CDLL(so_file)
CudaRTLibrary.path_to_library_cache[so_file] = lib
self.lib = CudaRTLibrary.path_to_library_cache[so_file]
if so_file not in CudaRTLibrary.path_to_dict_mapping:
_funcs = {}
for func in CudaRTLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
_funcs[func.name] = f
CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]
def CUDART_CHECK(self, result: cudaError_t) -> None:
if result != 0:
error_str = self.cudaGetErrorString(result)
raise RuntimeError(f"CUDART error: {error_str}")
def cudaGetErrorString(self, error: cudaError_t) -> str:
return self.funcs["cudaGetErrorString"](error).decode("utf-8")
def cudaSetDevice(self, device: int) -> None:
self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))
def cudaDeviceSynchronize(self) -> None:
self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())
def cudaDeviceReset(self) -> None:
self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())
def cudaMalloc(self, size: int) -> ctypes.c_void_p:
devPtr = ctypes.c_void_p()
self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
return devPtr
def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))
def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
count: int) -> None:
self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))
def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
count: int) -> None:
cudaMemcpyDefault = 4
kind = cudaMemcpyDefault
self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))
def cudaIpcGetMemHandle(self,
devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
handle = cudaIpcMemHandle_t()
self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](
ctypes.byref(handle), devPtr))
return handle
def cudaIpcOpenMemHandle(self,
handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
cudaIpcMemLazyEnablePeerAccess = 1
devPtr = ctypes.c_void_p()
self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"](
ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))
return devPtr
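
A hedged usage sketch for the wrapper above. It assumes a Linux host with a CUDA device and that CudaRTLibrary is importable; importing torch first loads libcudart into the process so find_loaded_library can locate it via /proc/self/maps.

import ctypes
import torch  # noqa: F401  -- pulls libcudart into the process

lib = CudaRTLibrary()
lib.cudaSetDevice(0)
dev_ptr = lib.cudaMalloc(1024)             # 1 KiB of device memory
lib.cudaMemset(dev_ptr, 7, 1024)           # fill it with 0x07
host_buf = (ctypes.c_char * 1024)()
lib.cudaMemcpy(host_buf, dev_ptr, 1024)    # device -> host via cudaMemcpyDefault
assert host_buf.raw == b"\x07" * 1024
lib.cudaFree(dev_ptr)
lib.cudaDeviceSynchronize()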


@@ -0,0 +1,310 @@
import ctypes
from contextlib import contextmanager
from typing import List, Optional, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import server.envs as envs
from server.inference.distributed.cuda_wrapper import CudaRTLibrary
from server.inference.distributed.custom_all_reduce_utils import gpu_p2p_access_check
from server.inference.distributed.parallel_state import in_the_same_node_as
from server.inference.platforms import current_platform
from server.utils import cuda_device_count_stateless
import vLLMCustomAllreduce
try:
vLLMCustomAllreduce.meta_size()
custom_ar = True
except Exception:
# For AMD GPUs and CPUs
custom_ar = False
def _can_p2p(rank: int, world_size: int) -> bool:
for i in range(world_size):
if i == rank:
continue
if envs.VLLM_SKIP_P2P_CHECK:
print("Skipping P2P check and trusting the driver's P2P report.")
return torch.cuda.can_device_access_peer(rank, i)
if not gpu_p2p_access_check(rank, i):
return False
return True
def is_weak_contiguous(inp: torch.Tensor):
return inp.is_contiguous() or (
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size()
)
class CustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
# max_size: max supported allreduce size
def __init__(
self,
group: ProcessGroup,
device: Union[int, str, torch.device],
max_size=8192 * 1024,
) -> None:
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the CustomAllreduce to. If None,
it will be bound to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bound to a unique device, and all communicators in this group
are in the same node.
"""
self._IS_CAPTURING = False
self.disabled = True
if not custom_ar:
# disable because of missing custom allreduce library
# e.g. in a non-cuda environment
return
self.group = group
assert (
dist.get_backend(group) != dist.Backend.NCCL
), "CustomAllreduce should be attached to a non-NCCL group."
if not all(in_the_same_node_as(group, source_rank=0)):
# No need to initialize custom allreduce for multi-node case.
print(
"Custom allreduce is disabled because this process group"
" spans across nodes."
)
return
rank = dist.get_rank(group=self.group)
world_size = dist.get_world_size(group=self.group)
if world_size == 1:
# No need to initialize custom allreduce for single GPU case.
return
if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
print(
"Custom allreduce is disabled due to an unsupported world"
" size: %d. Supported world sizes: %s. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.",
world_size,
str(CustomAllreduce._SUPPORTED_WORLD_SIZES),
)
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(cuda_device_count_stateless()))
physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
gather_list = [
torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size)
]
dist.all_gather(gather_list, tensor, group=self.group)
physical_device_ids = [t.item() for t in gather_list]
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
assert current_platform.is_cuda()
from server.inference.platforms.cuda import CudaPlatform
cuda_platform: CudaPlatform = current_platform
full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids)
if world_size > 2 and not full_nvlink:
print(
"Custom allreduce is disabled because it's not supported on"
" more than two PCIe-only GPUs. To silence this warning, "
"specify disable_custom_all_reduce=True explicitly."
)
return
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
if not _can_p2p(rank, world_size):
print(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly."
)
return
self.disabled = False
# Buffers memory are owned by this Python class and passed to C++.
# Meta data composes of two parts: meta data for synchronization and a
# temporary buffer for storing intermediate allreduce results.
self.meta_ptrs = self.create_shared_buffer(
vLLMCustomAllreduce.meta_size() + max_size, group=group
)
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
# This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
# is enough for 131072 such tuples. The largest model I've seen only
# needs less than 10000 of registered tuples.
self.rank_data = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
)
self.max_size = max_size
self.rank = rank
self.world_size = world_size
self.full_nvlink = full_nvlink
self._ptr = vLLMCustomAllreduce.init_custom_ar(
self.meta_ptrs, self.rank_data, rank, self.full_nvlink
)
vLLMCustomAllreduce.register_buffer(self._ptr, self.buffer_ptrs)
@staticmethod
def create_shared_buffer(
size_in_bytes: int, group: Optional[ProcessGroup] = None
) -> List[int]:
"""
Creates a shared buffer and returns a list of pointers
representing the buffer on all processes in the group.
"""
lib = CudaRTLibrary()
pointer = lib.cudaMalloc(size_in_bytes)
handle = lib.cudaIpcGetMemHandle(pointer)
world_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=group)
pointers: List[int] = []
for i, h in enumerate(handles):
if i == rank:
pointers.append(pointer.value) # type: ignore
else:
pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore
return pointers
@staticmethod
def free_shared_buffer(
pointers: List[int], group: Optional[ProcessGroup] = None
) -> None:
rank = dist.get_rank(group=group)
lib = CudaRTLibrary()
lib.cudaFree(ctypes.c_void_p(pointers[rank]))
@contextmanager
def capture(self):
"""
The main responsibility of this context manager is the
`register_graph_buffers` call at the end of the context.
It records all the buffer addresses used in the CUDA graph.
"""
try:
self._IS_CAPTURING = True
yield
finally:
self._IS_CAPTURING = False
if not self.disabled:
self.register_graph_buffers()
def register_graph_buffers(self):
handle, offset = vLLMCustomAllreduce.get_graph_buffer_ipc_meta(self._ptr)
print("Registering %d cuda graph addresses", len(offset))
# We cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
all_data[self.rank] = [handle, offset]
ranks = sorted(dist.get_process_group_ranks(group=self.group))
for i, rank in enumerate(ranks):
dist.broadcast_object_list(
all_data[i], src=rank, group=self.group, device="cpu"
)
# Unpack list of tuples to tuple of lists.
handles = [d[0] for d in all_data] # type: ignore
offsets = [d[1] for d in all_data] # type: ignore
vLLMCustomAllreduce.register_graph_buffers(self._ptr, handles, offsets)
def should_custom_ar(self, inp: torch.Tensor):
if self.disabled:
return False
inp_size = inp.numel() * inp.element_size()
# custom allreduce requires input byte size to be multiples of 16
if inp_size % 16 != 0:
return False
if not is_weak_contiguous(inp):
return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
if self.world_size == 2 or self.full_nvlink:
return inp_size < self.max_size
return False
def all_reduce(
self, inp: torch.Tensor, *, out: torch.Tensor = None, bsz_tensor: torch.Tensor = None, registered: bool = False,
is_compute_bound=False, overlap=False
):
"""Performs an out-of-place all reduce.
If registered is True, this assumes inp's pointer is already
IPC-registered. Otherwise, inp is first copied into a pre-registered
buffer.
"""
if is_compute_bound:
sms = 2 if overlap else 36
else:
sms = 20 if overlap else 36
#print("all reduce sms", sms)
if out is None:
out = torch.empty_like(inp)
if registered:
vLLMCustomAllreduce.all_reduce(self._ptr, inp, out, 0, 0, bsz_tensor, block_limit=sms)
else:
vLLMCustomAllreduce.all_reduce(
self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size, bsz_tensor, block_limit=sms
)
return out
def custom_all_reduce(self, input: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> Optional[torch.Tensor]:
"""The main allreduce API that provides support for cuda graph."""
# When custom allreduce is disabled, this will be None.
if self.disabled or not self.should_custom_ar(input):
return None
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
return self.all_reduce(input, bsz_tensor=bsz_tensor, registered=True, is_compute_bound=is_compute_bound, overlap=overlap)
else:
# If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place.
return torch.empty_like(input)
else:
# Note: outside of cuda graph context, custom allreduce incurs a
# cost of cudaMemcpy, which should be small (<=1% of overall
# latency) compared to the performance gain of using custom kernels
return self.all_reduce(input, bsz_tensor=bsz_tensor, registered=False, is_compute_bound=is_compute_bound, overlap=overlap)
def close(self):
if not self.disabled and self._ptr:
vLLMCustomAllreduce.dispose(self._ptr)
self._ptr = 0
self.free_shared_buffer(self.meta_ptrs)
self.free_shared_buffer(self.buffer_ptrs)
def __del__(self):
self.close()
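
The eligibility test in should_custom_ar above reduces to two cheap tensor properties: the input's byte size must be a multiple of 16 and the tensor must be weakly contiguous (a contiguous tail of its storage). A CPU-only illustration, assuming the is_weak_contiguous helper defined above is in scope:

import torch

x = torch.empty(1024, dtype=torch.bfloat16)    # 2048 bytes of storage
tail = x[256:]                                 # non-zero storage offset, still contiguous
print(is_weak_contiguous(tail))                        # True
print(tail.numel() * tail.element_size() % 16 == 0)    # True  (1536 bytes)

strided = torch.empty(64, 64)[:, ::2]          # strided view over a larger storage
print(is_weak_contiguous(strided))                     # False -> would fall back to NCCL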


@@ -0,0 +1,272 @@
import ctypes
import json
import os
import pickle
import subprocess
import sys
import tempfile
from itertools import product
from typing import Dict, List, Optional, Sequence
import torch.distributed as dist
import torch.multiprocessing as mp
import server.envs as envs
from server.inference.distributed.cuda_wrapper import CudaRTLibrary
from server.utils import cuda_device_count_stateless, update_environment_variables
def producer(
batch_src: Sequence[int],
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices: Optional[str] = None,
):
if cuda_visible_devices is not None:
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
lib = CudaRTLibrary()
for i in batch_src:
lib.cudaSetDevice(i)
pointer = lib.cudaMalloc(1024)
lib.cudaMemset(pointer, 1, 1024)
lib.cudaDeviceSynchronize()
handle = lib.cudaIpcGetMemHandle(pointer)
producer_queue.put(handle)
open_success = consumer_queue.get()
if open_success:
# use two queues to simulate barrier
producer_queue.put(0)
consumer_queue.get()
# check if the memory is modified
host_data = (ctypes.c_char * 1024)()
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
for i in range(1024):
if ord(host_data[i]) != 2:
open_success = False
break
result_queue.put(open_success)
lib.cudaDeviceReset()
def consumer(
batch_tgt: Sequence[int],
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices: Optional[str] = None,
):
if cuda_visible_devices is not None:
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
lib = CudaRTLibrary()
for j in batch_tgt:
lib.cudaSetDevice(j)
handle = producer_queue.get()
open_success = False
try:
pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore
open_success = True
except RuntimeError:
# cannot error out here, because the producer process
# is still waiting for the response.
pass
consumer_queue.put(open_success)
if open_success:
# modify the memory
lib.cudaMemset(pointer, 2, 1024)
lib.cudaDeviceSynchronize()
# use two queues to simulate barrier
producer_queue.get()
consumer_queue.put(0)
# check if the memory is modified
host_data = (ctypes.c_char * 1024)()
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
for i in range(1024):
if ord(host_data[i]) != 2:
open_success = False
break
result_queue.put(open_success)
lib.cudaDeviceReset()
def can_actually_p2p(
batch_src: Sequence[int],
batch_tgt: Sequence[int],
) -> Sequence[bool]:
"""
Usually, checking if P2P access is enabled can be done by
`torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
returns `True` even if P2P access is not actually possible.
See https://github.com/vllm-project/vllm/issues/2728 and
https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
Therefore, we have to perform a real P2P access to check if it is actually
possible.
Note on p2p and cuda IPC:
Usually, one process uses one GPU:
GPU src --> cuda context src --> tensor src --> process src
We need to combine p2p and cuda IPC, so that:
GPU src --> cuda context src --> tensor src --> process src
|shared|
GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
That is to say, process src creates a tensor in GPU src, passes IPC handle to
process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
tensor in process tgt will be reflected in the tensor in process src, because
they are the same memory segment.
It is important to note that process tgt accesses the tensor in GPU tgt, not
GPU src. That's why we need p2p access.
The most time-consuming part is the process creation. To avoid creating
processes for every pair of GPUs, we use batched testing. We create two
processes for testing all pairs of GPUs in batch. The trick is to reset
the device after each test (which is not available in PyTorch).
""" # noqa
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
# pass the CUDA_VISIBLE_DEVICES to the child process
# to make sure they see the same set of GPUs
# make sure the processes are spawned
smp = mp.get_context("spawn")
producer_queue = smp.Queue()
consumer_queue = smp.Queue()
result_queue = smp.Queue()
p_src = smp.Process(
target=producer,
args=(
batch_src,
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices,
),
)
p_tgt = smp.Process(
target=consumer,
args=(
batch_tgt,
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices,
),
)
p_src.start()
p_tgt.start()
p_src.join()
p_tgt.join()
assert p_src.exitcode == 0 and p_tgt.exitcode == 0
result: List[bool] = []
for src, tgt in zip(batch_src, batch_tgt):
a = result_queue.get()
b = result_queue.get()
if a != b:
print(
"Two processes do not agree on the P2P access"
" status on %d -> %d, treat as disabled.",
src,
tgt,
)
result.append(False)
else:
result.append(a)
return result
# why do we need this cache?
# we are testing peer-to-peer (p2p) access between GPUs, across processes.
# if we tested it every time, it would be very slow, because we would need to
# create N * N * 2 processes, where N is the world size.
# to reduce the time, we use a cache file to store the p2p access status.
# the cache file is generated by the master process if it does not exist.
# then all the processes can read the cache file to check the p2p access status.
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
# e.g. used by different vllm engines. The device id in the cache file is a
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
# of visible devices in the vllm engine.
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
def gpu_p2p_access_check(src: int, tgt: int) -> bool:
"""Check if GPU src can access GPU tgt."""
# if the cache variable is already calculated,
# read from the cache instead of checking it again
global _gpu_p2p_access_cache
if _gpu_p2p_access_cache is not None:
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
is_distributed = dist.is_initialized()
num_dev = cuda_device_count_stateless()
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices is None:
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
path = os.path.join(
envs.VLLM_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
)
os.makedirs(os.path.dirname(path), exist_ok=True)
from server.inference.distributed.parallel_state import get_world_group
if (not is_distributed or get_world_group().local_rank == 0) and (
not os.path.exists(path)
):
# only the local master process (with local_rank == 0) can
# enter this block to calculate the cache
print("generating GPU P2P access cache in %s", path)
cache: Dict[str, bool] = {}
ids = list(range(num_dev))
# batch of all pairs of GPUs
batch_src, batch_tgt = zip(*list(product(ids, ids)))
# NOTE: we use `subprocess` rather than `multiprocessing` here
# because the caller might not have `if __name__ == "__main__":`,
# in that case we cannot use spawn method in multiprocessing.
# However, `can_actually_p2p` requires spawn method.
# The fix is, we use `subprocess` to call the function,
# where we have `if __name__ == "__main__":` in this file.
# use a temporary file to store the result
# we don't use the output of the subprocess directly,
# because the subprocess might produce logging output
with tempfile.NamedTemporaryFile() as output_file:
input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name))
returned = subprocess.run(
[sys.executable, __file__], input=input_bytes, capture_output=True
)
# check if the subprocess is successful
try:
returned.check_returncode()
except Exception as e:
# wrap raised exception to provide more information
raise RuntimeError(
f"Error happened when batch testing "
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
f"{returned.stderr.decode()}"
) from e
with open(output_file.name, "rb") as f:
result = pickle.load(f)
for _i, _j, r in zip(batch_src, batch_tgt, result):
cache[f"{_i}->{_j}"] = r
with open(path, "w") as f:
json.dump(cache, f, indent=4)
if is_distributed:
get_world_group().barrier()
print("reading GPU P2P access cache from %s", path)
with open(path) as f:
cache = json.load(f)
_gpu_p2p_access_cache = cache
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
__all__ = ["gpu_p2p_access_check"]
if __name__ == "__main__":
batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
result = can_actually_p2p(batch_src, batch_tgt)
with open(output_file, "wb") as f:
f.write(pickle.dumps(result))
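
The producer and consumer processes above stay in lock-step with a pair of queues used as a makeshift barrier (the "use two queues to simulate barrier" comments). A GPU-free sketch of just that handshake, with illustrative payloads standing in for the IPC handle and status flags:

import multiprocessing as mp

def producer(pq, cq):
    pq.put("handle")                 # hand the (fake) IPC handle to the consumer
    assert cq.get() == "opened"      # consumer reports whether it could open it
    pq.put(0)                        # barrier, first half
    cq.get()                         # barrier, second half: both sides are now past this point

def consumer(pq, cq):
    assert pq.get() == "handle"
    cq.put("opened")
    pq.get()                         # barrier, first half
    cq.put(0)                        # barrier, second half

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    pq, cq = ctx.Queue(), ctx.Queue()
    procs = [ctx.Process(target=producer, args=(pq, cq)),
             ctx.Process(target=consumer, args=(pq, cq))]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print("both processes passed the barrier")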

File diff suppressed because it is too large.


@@ -0,0 +1,201 @@
from contextlib import contextmanager
from typing import Optional, Union
# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
from server.inference.distributed.pynccl_wrapper import (
NCCLLibrary,
buffer_type,
cudaStream_t,
ncclComm_t,
ncclDataTypeEnum,
ncclRedOpTypeEnum,
ncclUniqueId,
)
from server.inference.distributed.utils import StatelessProcessGroup
class PyNcclCommunicator:
def __init__(
self,
group: Union[ProcessGroup, StatelessProcessGroup],
device: Union[int, str, torch.device],
library_path: Optional[str] = None,
):
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the PyNcclCommunicator to. If None,
it will be bound to f"cuda:{local_rank}".
library_path: the path to the NCCL library. If None, it will
use the default library path.
It is the caller's responsibility to make sure each communicator
is bound to a unique device.
"""
if not isinstance(group, StatelessProcessGroup):
assert dist.is_initialized()
assert (
dist.get_backend(group) != dist.Backend.NCCL
), "PyNcclCommunicator should be attached to a non-NCCL group."
# note: this rank is the rank in the group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
else:
self.rank = group.rank
self.world_size = group.world_size
self.group = group
# if world_size == 1, no need to create communicator
if self.world_size == 1:
self.available = False
self.disabled = True
self.stream = None
return
try:
self.nccl = NCCLLibrary(library_path)
except Exception:
# disable because of missing NCCL library
# e.g. in a non-GPU environment
self.available = False
self.disabled = True
self.stream = None
return
self.available = True
self.disabled = False
print("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
if self.rank == 0:
# get the unique id from NCCL
self.unique_id = self.nccl.ncclGetUniqueId()
else:
# construct an empty unique id
self.unique_id = ncclUniqueId()
if not isinstance(group, StatelessProcessGroup):
tensor = torch.ByteTensor(list(self.unique_id.internal))
ranks = dist.get_process_group_ranks(group)
# arg `src` in `broadcast` is the global rank
dist.broadcast(tensor, src=ranks[0], group=group)
byte_list = tensor.tolist()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
else:
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
# nccl communicator and stream will use this device
# `torch.cuda.device` is a context manager that changes the
# current cuda device to the specified one
with torch.cuda.device(device):
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
self.world_size, self.unique_id, self.rank
)
self.stream = torch.cuda.Stream()
# A small all_reduce for warmup.
data = torch.zeros(1, device=device)
self.all_reduce(data)
self.stream.synchronize()
del data
# by default it is disabled, e.g. in profiling models and prefill phase.
# to use it, use under `with obj.change_state(enable=True)`, usually
# when we are using CUDA graph.
self.disabled = True
def all_reduce(
self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}"
)
if stream is None:
stream = self.stream
self.nccl.ncclAllReduce(
buffer_type(tensor.data_ptr()),
buffer_type(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
ncclRedOpTypeEnum.from_torch(op),
self.comm,
cudaStream_t(stream.cuda_stream),
)
def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}"
)
if stream is None:
stream = self.stream
self.nccl.ncclSend(
buffer_type(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
dst,
self.comm,
cudaStream_t(stream.cuda_stream),
)
def recv(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}"
)
if stream is None:
stream = self.stream
self.nccl.ncclRecv(
buffer_type(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
src,
self.comm,
cudaStream_t(stream.cuda_stream),
)
@contextmanager
def change_state(
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
):
"""
A context manager to change the state of the communicator.
"""
if enable is None:
# guess a default value when not specified
enable = self.available
if stream is None:
stream = self.stream
old_disable = self.disabled
old_stream = self.stream
self.stream = stream
self.disabled = not enable
yield
self.disabled = old_disable
self.stream = old_stream
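# Illustrative usage sketch (not part of the original module): how a caller might
# drive this communicator. Assumes two externally launched ranks with one GPU each,
# that the surrounding class is named PyNcclCommunicator, and that the rendezvous
# host/port below are placeholder values chosen by the caller.
def _example_pynccl_all_reduce(rank: int, world_size: int) -> torch.Tensor:
    group = StatelessProcessGroup.create(
        host="127.0.0.1", port=29555, rank=rank, world_size=world_size
    )
    comm = PyNcclCommunicator(group, device=rank)
    data = torch.ones(16, device=f"cuda:{rank}")
    # all_reduce is a no-op while the communicator is disabled (the default);
    # it has to be enabled explicitly, as is done around CUDA-graph capture/replay.
    with comm.change_state(enable=True):
        comm.all_reduce(data)
        comm.stream.synchronize()
    return data  # expected to hold `world_size` in every element (ReduceOp.SUM)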

View file

@ -0,0 +1,276 @@
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
# often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# contains many other potential CUDA APIs that are not allowed during
# CUDA graph capture. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related to NCCL versions and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
# variable in the code.
import ctypes
import platform
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import torch
from torch.distributed import ReduceOp
from server.utils import find_nccl_library
# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
ncclResult_t = ctypes.c_int
ncclComm_t = ctypes.c_void_p
class ncclUniqueId(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]
cudaStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p
ncclDataType_t = ctypes.c_int
class ncclDataTypeEnum:
ncclInt8 = 0
ncclChar = 0
ncclUint8 = 1
ncclInt32 = 2
ncclInt = 2
ncclUint32 = 3
ncclInt64 = 4
ncclUint64 = 5
ncclFloat16 = 6
ncclHalf = 6
ncclFloat32 = 7
ncclFloat = 7
ncclFloat64 = 8
ncclDouble = 8
ncclBfloat16 = 9
ncclNumTypes = 10
@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.ncclInt8
if dtype == torch.uint8:
return cls.ncclUint8
if dtype == torch.int32:
return cls.ncclInt32
if dtype == torch.int64:
return cls.ncclInt64
if dtype == torch.float16:
return cls.ncclFloat16
if dtype == torch.float32:
return cls.ncclFloat32
if dtype == torch.float64:
return cls.ncclFloat64
if dtype == torch.bfloat16:
return cls.ncclBfloat16
raise ValueError(f"Unsupported dtype: {dtype}")
ncclRedOp_t = ctypes.c_int
class ncclRedOpTypeEnum:
ncclSum = 0
ncclProd = 1
ncclMax = 2
ncclMin = 3
ncclAvg = 4
ncclNumOps = 5
@classmethod
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.ncclSum
if op == ReduceOp.PRODUCT:
return cls.ncclProd
if op == ReduceOp.MAX:
return cls.ncclMax
if op == ReduceOp.MIN:
return cls.ncclMin
if op == ReduceOp.AVG:
return cls.ncclAvg
raise ValueError(f"Unsupported op: {op}")
@dataclass
class Function:
name: str
restype: Any
argtypes: List[Any]
class NCCLLibrary:
exported_functions = [
# const char* ncclGetErrorString(ncclResult_t result)
Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
# ncclResult_t ncclGetVersion(int *version);
Function("ncclGetVersion", ncclResult_t,
[ctypes.POINTER(ctypes.c_int)]),
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
Function("ncclGetUniqueId", ncclResult_t,
[ctypes.POINTER(ncclUniqueId)]),
# ncclResult_t ncclCommInitRank(
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
# note that ncclComm_t is a pointer type, so the first argument
# is a pointer to a pointer
Function("ncclCommInitRank", ncclResult_t, [
ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
ctypes.c_int
]),
# ncclResult_t ncclAllReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclAllReduce", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ncclComm_t, cudaStream_t
]),
# ncclResult_t ncclSend(
# const void* sendbuff, size_t count, ncclDataType_t datatype,
# int dest, ncclComm_t comm, cudaStream_t stream);
Function("ncclSend", ncclResult_t, [
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
ncclComm_t, cudaStream_t
]),
# ncclResult_t ncclRecv(
# void* recvbuff, size_t count, ncclDataType_t datatype,
# int src, ncclComm_t comm, cudaStream_t stream);
Function("ncclRecv", ncclResult_t, [
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
ncclComm_t, cudaStream_t
]),
# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
# it is better not to call it at all.
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
]
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {}
# class attribute to store the mapping from library path
# to the corresponding dictionary
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
so_file = so_file or find_nccl_library()
try:
if so_file not in NCCLLibrary.path_to_dict_mapping:
lib = ctypes.CDLL(so_file)
NCCLLibrary.path_to_library_cache[so_file] = lib
self.lib = NCCLLibrary.path_to_library_cache[so_file]
except Exception as e:
print(
f"Failed to load NCCL library from {so_file}. "
"It is expected if you are not running on NVIDIA/AMD GPUs. "
"Otherwise, the nccl library might not exist, be corrupted, "
f"or it does not support the current platform {platform.platform()}. "
"If you already have the library, please set the "
"environment variable VLLM_NCCL_SO_PATH "
"to point to the correct nccl library path.")
raise e
if so_file not in NCCLLibrary.path_to_dict_mapping:
_funcs: Dict[str, Any] = {}
for func in NCCLLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
_funcs[func.name] = f
NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
def ncclGetErrorString(self, result: ncclResult_t) -> str:
return self._funcs["ncclGetErrorString"](result).decode("utf-8")
def NCCL_CHECK(self, result: ncclResult_t) -> None:
if result != 0:
error_str = self.ncclGetErrorString(result)
raise RuntimeError(f"NCCL error: {error_str}")
def ncclGetVersion(self) -> str:
version = ctypes.c_int()
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
version_str = str(version.value)
# something like 21903 --> "2.19.3"
major = version_str[0].lstrip("0")
minor = version_str[1:3].lstrip("0")
patch = version_str[3:].lstrip("0")
return f"{major}.{minor}.{patch}"
def ncclGetUniqueId(self) -> ncclUniqueId:
unique_id = ncclUniqueId()
self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
ctypes.byref(unique_id)))
return unique_id
def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
rank: int) -> ncclComm_t:
comm = ncclComm_t()
self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
world_size, unique_id,
rank))
return comm
def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
datatype, op, comm,
stream))
def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
dest, comm, stream))
def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
comm, stream))
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
__all__ = [
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
"ncclComm_t", "cudaStream_t", "buffer_type"
]
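# Illustrative sketch (assumes an NCCL shared object is discoverable via
# find_nccl_library() or the VLLM_NCCL_SO_PATH environment variable): loading
# the library and querying its version, which needs no process group.
def _example_probe_nccl_version() -> str:
    lib = NCCLLibrary()          # loads and caches the .so through ctypes
    return lib.ncclGetVersion()  # e.g. "2.19.3"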

View file

@ -0,0 +1,219 @@
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import pickle
import time
from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple
import torch
from torch.distributed import TCPStore
import server.envs as envs
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator
)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Returns:
A list of Tensors
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# NOTE: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
def get_pp_indices(
num_hidden_layers: int, pp_rank: int, pp_size: int
) -> Tuple[int, int]:
"""Try to evenly distribute layers across partitions.
If the number of layers is not divisible by the number of partitions,
the last partition will have the remaining layers.
"""
partition_list_str = envs.VLLM_PP_LAYER_PARTITION
if partition_list_str is not None:
try:
partitions = [int(layer) for layer in partition_list_str.split(",")]
except ValueError as err:
raise ValueError(
"Invalid partition string: {}".format(partition_list_str)
) from err
if len(partitions) != pp_size:
raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
if sum(partitions) != num_hidden_layers:
raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.")
start_layer = sum(partitions[:pp_rank])
end_layer = start_layer + partitions[pp_rank]
else:
layers_per_partition = num_hidden_layers // pp_size
start_layer = pp_rank * layers_per_partition
end_layer = start_layer + layers_per_partition
if pp_rank == pp_size - 1:
end_layer = num_hidden_layers
return (start_layer, end_layer)
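# Worked example (illustrative): with num_hidden_layers=61, pp_size=4 and no
# VLLM_PP_LAYER_PARTITION override, each of the first three ranks gets 15 layers
# and the last rank absorbs the remainder:
#   get_pp_indices(61, 0, 4) -> (0, 15)
#   get_pp_indices(61, 3, 4) -> (45, 61)
# Setting VLLM_PP_LAYER_PARTITION="16,16,16,13" would impose that split instead.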
@dataclasses.dataclass
class StatelessProcessGroup:
"""A dataclass to hold a metadata store, and the rank, world_size of the
group. Only use it to communicate metadata between processes.
For data-plane communication, create NCCL-related objects.
"""
rank: int
world_size: int
store: torch._C._distributed_c10d.Store
data_expiration_seconds: int = 3600 # 1 hour
# dst rank -> counter
send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
# src rank -> counter
recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
broadcast_send_counter: int = 0
broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
# A deque to store the data entries, with key and timestamp.
entries: Deque[Tuple[str, float]] = dataclasses.field(default_factory=deque)
def __post_init__(self):
assert self.rank < self.world_size
self.send_dst_counter = {i: 0 for i in range(self.world_size)}
self.recv_src_counter = {i: 0 for i in range(self.world_size)}
self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)}
def send_obj(self, obj: Any, dst: int):
"""Send an object to a destination rank."""
self.expire_data()
key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
self.store.set(key, pickle.dumps(obj))
self.send_dst_counter[dst] += 1
self.entries.append((key, time.time()))
def expire_data(self):
"""Expire data that is older than `data_expiration_seconds` seconds."""
while self.entries:
# check the oldest entry
key, timestamp = self.entries[0]
if time.time() - timestamp > self.data_expiration_seconds:
self.store.delete_key(key)
self.entries.popleft()
else:
break
def recv_obj(self, src: int) -> Any:
"""Receive an object from a source rank."""
obj = pickle.loads(
self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}")
)
self.recv_src_counter[src] += 1
return obj
def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
"""Broadcast an object from a source rank to all other ranks.
It does not clean up after all ranks have received the object.
Use it for limited times, e.g., for initialization.
"""
if self.rank == src:
self.expire_data()
key = f"broadcast_from/{src}/" f"{self.broadcast_send_counter}"
self.store.set(key, pickle.dumps(obj))
self.broadcast_send_counter += 1
self.entries.append((key, time.time()))
return obj
else:
key = f"broadcast_from/{src}/" f"{self.broadcast_recv_src_counter[src]}"
recv_obj = pickle.loads(self.store.get(key))
self.broadcast_recv_src_counter[src] += 1
return recv_obj
def all_gather_obj(self, obj: Any) -> list[Any]:
"""All gather an object from all ranks."""
gathered_objs = []
for i in range(self.world_size):
if i == self.rank:
gathered_objs.append(obj)
self.broadcast_obj(obj, src=self.rank)
else:
recv_obj = self.broadcast_obj(None, src=i)
gathered_objs.append(recv_obj)
return gathered_objs
def barrier(self):
"""A barrier to synchronize all ranks."""
for i in range(self.world_size):
if i == self.rank:
self.broadcast_obj(None, src=self.rank)
else:
self.broadcast_obj(None, src=i)
@staticmethod
def create(
host: str,
port: int,
rank: int,
world_size: int,
data_expiration_seconds: int = 3600,
) -> "StatelessProcessGroup":
"""A replacement for `torch.distributed.init_process_group` that does not
pollute the global state.
If process A and process B call `torch.distributed.init_process_group`
to form a group, and we then want to form another group with processes A, B, C,
and D, that is not possible in PyTorch, because A and B have already
formed a group and C and D cannot join it. This
function is a workaround for this issue.
`torch.distributed.init_process_group` is a global call, while this function
is a stateless call. It will return a `StatelessProcessGroup` object that can be
used for exchanging metadata. With this function, process A and process B
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
C, and D can call `StatelessProcessGroup.create` to form another group.
""" # noqa
store = TCPStore(
host_name=host,
port=port,
world_size=world_size,
is_master=(rank == 0),
)
return StatelessProcessGroup(
rank=rank,
world_size=world_size,
store=store,
data_expiration_seconds=data_expiration_seconds,
)
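# Illustrative sketch (not called by the library): every rank runs the same code
# with its own `rank`; the host/port below are placeholder assumptions.
def _example_stateless_group(rank: int, world_size: int) -> dict:
    pg = StatelessProcessGroup.create(
        host="127.0.0.1", port=29600, rank=rank, world_size=world_size
    )
    # rank 0 broadcasts a small config dict; all ranks receive the same object
    cfg = pg.broadcast_obj({"page_size": 256} if rank == 0 else None, src=0)
    pg.barrier()
    return cfg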

View file

@ -0,0 +1,284 @@
'''
Date: 2024-11-12 14:15:16
LastEditors: Xie Weiyu ervinxie@qq.com
LastEditTime: 2024-11-26 08:12:49
'''
import torch
from ktransformers.server.balance_serve.settings import sched_ext
from ktransformers.server.balance_serve.inference.query_manager import QueryManager, QueryInfo
import time
from ktransformers.server.config.config import Config
class ForwardBatchInput:
class ForwardMiniBatch:
q_indptr: torch.Tensor
kv_indptr: torch.Tensor
kv_indices: torch.Tensor
kv_last_page_len: torch.Tensor
kv_len: torch.Tensor
position_ids: torch.Tensor
tokens: torch.Tensor
batch_indices: torch.Tensor
positions: torch.Tensor
chunk_size: int
decode_batch: int
is_last_prefill_chunk: bool
logits_start: list
temperatures: torch.Tensor
top_ps: torch.Tensor
def __init__(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, device = torch.device('cuda'), page_size = 256):
batch_decode = len(decode_querys_info)
batch_prefill = len(prefill_querys_info)
self.q_indptr = torch.tensor([0], device=device, dtype=torch.int32)
self.kv_indptr = torch.tensor([0], device=device, dtype=torch.int32)
self.kv_indices = torch.tensor([], device=device, dtype=torch.int32)
self.kv_len = torch.tensor([], device=device, dtype=torch.int32)
self.kv_last_page_len = torch.tensor([], device=device, dtype=torch.int32)
self.position_ids = torch.tensor([], device=device, dtype=torch.int32)
self.tokens = torch.tensor([], device=device, dtype=torch.int32)
self.temperatures = torch.tensor([], device=device, dtype=torch.float32)
self.top_ps = torch.tensor([], device=device, dtype=torch.float32)
self.logits_start = []
self.decode_batch = batch_decode
self.num_tokens = batch_decode + sum(prefill_l)
self.batch_size = batch_decode + batch_prefill
for i, prefill_query_info in enumerate(prefill_querys_info):
if prefill_query_info is not None:
prefill_kv_block_len = (prefill_query_info.active_position + prefill_l[i] + page_size - 1) // page_size if prefill_query_info is not None else 0
# print(f"block_len: {prefill_kv_block_len}, page_size: {page_size}")
self.q_indptr = torch.concat((self.q_indptr, torch.tensor([prefill_l[i] + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([prefill_kv_block_len + self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
self.kv_indices = torch.concat((self.kv_indices, prefill_query_info.block_index[:prefill_kv_block_len]), dim=0)
self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i]) % page_size if (prefill_query_info.active_position + prefill_l[i]) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
self.kv_len = torch.concat((self.kv_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i])], device=device, dtype=torch.int32)), dim=0)
self.position_ids = torch.concat((self.position_ids, torch.arange(prefill_s[i], prefill_l[i] + prefill_s[i], device=device, dtype=torch.int32)), dim=0)
self.tokens = torch.concat((self.tokens, prefill_query_info.query_tokens[prefill_s[i]:prefill_s[i] + prefill_l[i]]), dim=0)
self.logits_start.append(prefill_l[i] - 1 if len(self.logits_start) == 0 else sum(prefill_l[:i+1])-1)
self.temperatures = torch.concat((self.temperatures, torch.tensor([prefill_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
self.top_ps = torch.concat((self.top_ps, torch.tensor([prefill_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
for decode_query_info in decode_querys_info:
decode_kv_block_len = (decode_query_info.active_position + 1 + page_size - 1) // page_size
self.q_indptr = torch.concat((self.q_indptr, torch.tensor([1 + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([decode_kv_block_len+self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
self.kv_indices = torch.concat((self.kv_indices, decode_query_info.block_index[:decode_kv_block_len]), dim=0)
self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(decode_query_info.active_position + 1) % page_size if (decode_query_info.active_position + 1) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
self.kv_len = torch.concat((self.kv_len, torch.tensor([(decode_query_info.active_position + 1)], device=device, dtype=torch.int32)), dim=0)
self.position_ids = torch.concat((self.position_ids, torch.arange(decode_query_info.active_position, decode_query_info.active_position + 1, device=device, dtype=torch.int32)), dim=0)
if decode_query_info.active_position > 0:
self.tokens = torch.concat((self.tokens, decode_query_info.query_tokens[decode_query_info.active_position:decode_query_info.active_position+1]), dim=0)
else:
self.tokens = torch.concat((self.tokens, torch.tensor([0], device=device, dtype=torch.int32)), dim=0)
self.logits_start.append(0 if len(self.logits_start) == 0 else self.logits_start[-1]+1)
self.temperatures = torch.concat((self.temperatures, torch.tensor([decode_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
self.top_ps = torch.concat((self.top_ps, torch.tensor([decode_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
self.q_indptr = self.q_indptr.contiguous()
self.kv_indptr = self.kv_indptr.contiguous()
self.kv_indices = self.kv_indices.contiguous()
self.kv_len = self.kv_len.contiguous()
self.kv_last_page_len = self.kv_last_page_len.contiguous()
self.position_ids = self.position_ids.contiguous()
self.tokens = self.tokens.contiguous()
self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32)
def fill(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, device = torch.device('cuda'), page_size = 256):
batch_decode = len(decode_querys_info)
batch_prefill = len(prefill_querys_info)
self.q_indptr = torch.tensor([0], device=device, dtype=torch.int32)
self.kv_indptr = torch.tensor([0], device=device, dtype=torch.int32)
self.kv_indices = torch.tensor([], device=device, dtype=torch.int32)
self.kv_len = torch.tensor([], device=device, dtype=torch.int32)
self.kv_last_page_len = torch.tensor([], device=device, dtype=torch.int32)
new_position_ids = torch.tensor([], device=device, dtype=torch.int32)
new_tokens = torch.tensor([], device=device, dtype=torch.int32)
self.temperatures = torch.tensor([], device=device, dtype=torch.float32)
self.top_ps = torch.tensor([], device=device, dtype=torch.float32)
self.logits_start = []
self.decode_batch = batch_decode
self.num_tokens = batch_decode + sum(prefill_l)
self.batch_size = batch_decode + batch_prefill
for i, prefill_query_info in enumerate(prefill_querys_info):
prefill_kv_block_len = (prefill_query_info.active_position + prefill_l[i] + page_size - 1) // page_size if prefill_query_info is not None else 0
# print(f"block_len: {prefill_kv_block_len}, page_size: {page_size}")
self.q_indptr = torch.concat((self.q_indptr, torch.tensor([prefill_l[i] + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([prefill_kv_block_len + self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
self.kv_indices = torch.concat((self.kv_indices, prefill_query_info.block_index[:prefill_kv_block_len]), dim=0)
self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i]) % page_size if (prefill_query_info.active_position + prefill_l[i]) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
self.kv_len = torch.concat((self.kv_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i])], device=device, dtype=torch.int32)), dim=0)
new_position_ids = torch.concat((new_position_ids, torch.arange(prefill_s[i], prefill_l[i] + prefill_s[i], device=device, dtype=torch.int32)), dim=0)
new_tokens = torch.concat((new_tokens, prefill_query_info.query_tokens[prefill_s[i]:prefill_s[i] + prefill_l[i]]), dim=0)
self.logits_start.append(prefill_l[i] - 1 if len(self.logits_start) == 0 else sum(prefill_l[:i+1])-1)
self.temperatures = torch.concat((self.temperatures, torch.tensor([prefill_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
self.top_ps = torch.concat((self.top_ps, torch.tensor([prefill_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
for decode_query_info in decode_querys_info:
decode_kv_block_len = (decode_query_info.active_position + 1 + page_size - 1) // page_size
self.q_indptr = torch.concat((self.q_indptr, torch.tensor([1 + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([decode_kv_block_len+self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
self.kv_indices = torch.concat((self.kv_indices, decode_query_info.block_index[:decode_kv_block_len]), dim=0)
self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(decode_query_info.active_position + 1) % page_size if (decode_query_info.active_position + 1) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
self.kv_len = torch.concat((self.kv_len, torch.tensor([(decode_query_info.active_position + 1)], device=device, dtype=torch.int32)), dim=0)
new_position_ids = torch.concat((new_position_ids, torch.arange(decode_query_info.active_position, decode_query_info.active_position + 1, device=device, dtype=torch.int32)), dim=0)
if decode_query_info.active_position > 0:
new_tokens = torch.concat((new_tokens, decode_query_info.query_tokens[decode_query_info.active_position:decode_query_info.active_position+1]), dim=0)
else:
new_tokens = torch.concat((new_tokens, torch.tensor([0], device=device, dtype=torch.int32)), dim=0)
self.logits_start.append(0 if len(self.logits_start) == 0 else self.logits_start[-1]+1)
self.temperatures = torch.concat((self.temperatures, torch.tensor([decode_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
self.top_ps = torch.concat((self.top_ps, torch.tensor([decode_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
self.q_indptr = self.q_indptr.contiguous()
self.kv_indptr = self.kv_indptr.contiguous()
self.kv_indices = self.kv_indices.contiguous()
self.kv_len = self.kv_len.contiguous()
self.kv_last_page_len = self.kv_last_page_len.contiguous()
self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32)
# copy new_position_ids and new_tokens to self.position_ids and self.tokens
# print("new_position_ids: ", new_position_ids)
# self.print()
self.position_ids[:new_position_ids.size(0)].copy_(new_position_ids)
self.position_ids[new_position_ids.size(0):].zero_()
self.tokens[:new_tokens.size(0)].copy_(new_tokens)
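# Worked example (illustrative numbers): with page_size=256, one prefill chunk of
# 8 tokens starting from an empty cache plus two decode queries at active positions
# 300 and 10 are packed as
#   q_indptr         = [0, 8, 9, 10]   # 8 prefill tokens, then 1 token per decode
#   kv_indptr        = [0, 1, 3, 4]    # prefix sum of pages touched per query
#   kv_last_page_len = [8, 45, 11]     # tokens resident in each query's last page
#   logits_start     = [7, 8, 9]       # index of each query's final token in the batch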
forward_minibatchs: list[ForwardMiniBatch]
batch_size: int
minibatch: ForwardMiniBatch
def __init__(self, batch : sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None, device=None, tokens: torch.Tensor = None):
if batch is None:
return
prefill_minibatches = batch.prefill_mini_batches
decode_mini_batches = [item for sublist in batch.decode_mini_batches for item in sublist]
prefill_querys_info = []
prefill_s = []
prefill_l = []
decode_querys_info = []
self.batch_size = 1
for (id, s, l) in prefill_minibatches:
prefill_querys_info.append(query_manager.query_map[id])
prefill_s.append(s)
prefill_l.append(l)
for decode_batch_idx in decode_mini_batches:
if query_manager.query_map[decode_batch_idx].decode_start_time is None:
query_manager.query_map[decode_batch_idx].decode_start_time = time.time()
decode_querys_info.append(query_manager.query_map[decode_batch_idx])
minibatch = ForwardBatchInput.ForwardMiniBatch(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device = query_manager.device, page_size = query_manager.page_size)
self.minibatch = minibatch
@classmethod
def gen_max_forward_batch(
cls,
device=None,
tokens: torch.Tensor = None,
num_mini_batches: int = 1,
max_seq_length: int = 1024, # TODO: add to yaml
prefill_query_length: int = (Config().chunk_size - Config().max_decode_batch_size) // Config().max_prefill_batch_size, # TODO: use config
prefill_active_length: int = (Config().chunk_size - Config().max_decode_batch_size) // Config().max_prefill_batch_size,
gen_prefill: bool = True,
decode_batch_size: int = Config().max_decode_batch_size,
decode_active_position: torch.Tensor = None,
page_size = 256,
cuda_lens = 1
):
instance = cls()
instance.batch_size = num_mini_batches
page_size = page_size
prefill_query_info = []
offset = 0
if gen_prefill and prefill_query_length != 0:
for i in range(Config().max_prefill_batch_size):
prefill_query_info.append(QueryInfo(i, prefill_query_length, max_seq_length, page_size, device, offset=offset))
offset += max_seq_length // page_size
decode_querys_info = []
for i in range(min(decode_batch_size, cuda_lens)):
query_info = QueryInfo(i+Config().max_prefill_batch_size, prefill_query_length, max_seq_length, page_size, device, is_prefill=False, offset=offset)
offset += max_seq_length // page_size
if tokens is not None:
query_info.query_tokens[prefill_active_length:prefill_active_length + 1].copy_(tokens)
if decode_active_position is None:
query_info.active_position = prefill_active_length
else:
query_info.active_position = decode_active_position[i]
decode_querys_info.append(query_info)
if prefill_query_length*Config().max_prefill_batch_size + len(decode_querys_info) < cuda_lens:
decode_querys_info.append(query_info)
instance.minibatch = ForwardBatchInput.ForwardMiniBatch(prefill_query_info, decode_querys_info, [0, 0], [prefill_active_length for _ in range(Config().max_prefill_batch_size)], device, page_size)
return instance
def fill(self, batch : sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None, page_size = 256):
if batch is None:
return
prefill_minibatches = batch.prefill_mini_batches
decode_mini_batches = [item for sublist in batch.decode_mini_batches for item in sublist]
prefill_querys_info = []
prefill_s = []
prefill_l = []
decode_querys_info = []
self.batch_size = 1
for (id, s, l) in prefill_minibatches:
prefill_querys_info.append(query_manager.query_map[id])
prefill_s.append(s)
prefill_l.append(l)
for decode_batch_idx in decode_mini_batches:
if query_manager.query_map[decode_batch_idx].decode_start_time is None:
query_manager.query_map[decode_batch_idx].decode_start_time = time.time()
decode_querys_info.append(query_manager.query_map[decode_batch_idx])
self.minibatch.fill(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device=query_manager.device, page_size=page_size)
class ForwardBatchOutput:
logits: list[torch.Tensor]
num_batchs: int
batch_sizes: list[int]
generated_tokens_num: list[int]
lm_start: list[int]
temperatures: list[torch.Tensor]
top_ps: list[torch.Tensor]
def __init__(self):
self.logits = []
self.batch_sizes = []
self.generated_tokens_num = []
self.top_ps = []
self.temperatures = []
pass

View file

@ -0,0 +1,306 @@
"""
Date: 2024-11-07 07:02:20
LastEditors: djw
LastEditTime: 2024-12-10 08:48:32
"""
import torch
from torch import nn
import queue
import signal
from typing import AsyncIterable
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from contextlib import asynccontextmanager
from pydantic import BaseModel, Field
import asyncio
import multiprocessing
import time
import torch.multiprocessing as mp
import random
import torch.distributed as dist
import zmq
import tempfile
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
from ktransformers.server.config.config import Config
from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
from ktransformers.server.balance_serve.settings import sched_ext
def pad_num_tokens(num_tokens):
return (num_tokens + 63) // 64 * 64
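# Worked example (illustrative): token counts are padded up to the next multiple
# of 64, e.g. pad_num_tokens(1) -> 64 and pad_num_tokens(65) -> 128.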
def deduplicate_and_sort(lst):
return sorted(set(lst))
class ModelRunner:
"""A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile."""
model: KDeepseekV3ForCausalLM
input: ForwardBatchInput | list[ForwardBatchInput]
output: ForwardBatchOutput
def __init__(self, model = None, device = None, use_cuda_graph = False, max_decode_batch_size = 1, max_chunk_size = 4096, num_mini_batches: int = 1, page_size = 256):
self.stream = torch.cuda.Stream(device=device)
# commented out for now
self.model = model # Compile and move model to the specified device
self.device = device
self.input = None
self.features_buf = None
self.output = None
self.graph_memory_pool = None
self.cuda_graphs = deduplicate_and_sort([1, 2, 3, Config().max_batch_size, 64, Config().chunk_size])
self.use_cuda_graph = use_cuda_graph
self.model_time = 0
self.page_size = page_size
# GPU timing for model execution
self.start_model_event = torch.cuda.Event(enable_timing=True)
self.end_model_event = torch.cuda.Event(enable_timing=True)
if isinstance(self.cuda_graphs, list):
self.graphs = [torch.cuda.CUDAGraph() for _ in range(len(self.cuda_graphs))]
self.page_idx_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]
self.page_offset_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]
else:
self.graphs = torch.cuda.CUDAGraph()
self.page_idx_buf = torch.zeros([self.cuda_graphs], dtype=torch.int32, device = self.device)
self.page_offset_buf = torch.zeros([self.cuda_graphs], dtype=torch.int32, device = self.device)
self.num_mini_batches = num_mini_batches
self.max_chunk_size = max_chunk_size
self.bsz_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device)
self.num_tokens_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device)
def warmup(self):
def capture_graphs(cuda_graph_idx=-1):
if cuda_graph_idx != -1:
with torch.cuda.graph(self.graphs[cuda_graph_idx], pool=self.graph_memory_pool, stream=self.stream):
self.outputs_buf[cuda_graph_idx] = self.model(self.input[cuda_graph_idx], self.features_buf[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[cuda_graph_idx], self.page_offset_buf[cuda_graph_idx], cuda_graph_idx=cuda_graph_idx)
self.graph_memory_pool = self.graphs[cuda_graph_idx].pool()
else:
with torch.cuda.graph(self.graphs, pool=self.graph_memory_pool, stream=self.stream):
self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)
self.graph_memory_pool = self.graphs.pool()
if isinstance(self.cuda_graphs, list):
self.input = []
self.features_buf = []
self.outputs_buf = []
self.bsz_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)
self.num_tokens_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)
for i in range(len(self.cuda_graphs)):
prefill_query_length = (self.cuda_graphs[i] - Config().max_decode_batch_size) // Config().max_prefill_batch_size if self.cuda_graphs[i] > Config().max_decode_batch_size else 0 # TODO: only supports 2 prefill batches
self.input.append(ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches, prefill_query_length=prefill_query_length, prefill_active_length=prefill_query_length, page_size=self.page_size, cuda_lens = self.cuda_graphs[i]))
self.features_buf.append(self.model.batch_embeddings(self.input[i]))
batch_size = self.input[i].minibatch.q_indptr.size(0)-1
num_tokens = self.features_buf[i][0].size(0)
print("capturing cuda graph", batch_size, num_tokens)
self.bsz_tensor_buf[0] = batch_size
self.num_tokens_tensor_buf[0] = num_tokens
self.model.flash_infer_attn_plan(self.input[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf,
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
page_idx, page_offset = self.model.cache.get_page_table(self.input[i].minibatch.position_ids, self.input[i].minibatch.q_indptr, self.input[i].minibatch.kv_indptr, self.input[i].minibatch.kv_indices, self.num_tokens_tensor_buf)
self.page_idx_buf[i][:num_tokens].copy_(page_idx[:num_tokens])
self.page_offset_buf[i][:num_tokens].copy_(page_offset[:num_tokens])
self.page_idx_buf[i][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size -1)
self.outputs_buf.append(None)
torch.cuda.synchronize()
for warm_up_iters in range(11):
with torch.cuda.stream(self.stream):
self.outputs_buf[i] = self.model(self.input[i], self.features_buf[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[i], self.page_offset_buf[i])
torch.cuda.synchronize()
capture_graphs(i)
with torch.cuda.stream(self.stream):
self.graphs[i].replay()
self.sync(calc_time=False)
print(f"cuda_graph: {i+1}/{len(self.cuda_graphs)}, warmup finished.")
else:
self.input = ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches)
self.features_buf = self.model.batch_embeddings(self.input)
batch_size = self.input.minibatch.q_indptr.size(0)-1
num_tokens = self.features_buf[0].size(0)
self.bsz_tensor_buf = torch.tensor([batch_size], dtype=torch.int32, device=self.device)
self.num_tokens_tensor_buf = torch.tensor([num_tokens], dtype=torch.int32, device=self.device)
self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
page_idx, page_offset = self.model.cache.get_page_table(self.input.minibatch.position_ids, self.input.minibatch.q_indptr, self.input.minibatch.kv_indptr, self.input.minibatch.kv_indices, self.num_tokens_tensor_buf)
self.page_idx_buf[:num_tokens].copy_(page_idx[:num_tokens])
self.page_offset_buf[:num_tokens].copy_(page_offset[:num_tokens])
self.page_idx_buf[num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
torch.cuda.synchronize()
for warm_up_iters in range(11):
with torch.cuda.stream(self.stream):
self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)
torch.cuda.synchronize()
def capture_graphs():
with torch.cuda.graph(self.graphs, stream=self.stream):
self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)
# self.graph_memory_pool = self.graphs.pool()
capture_graphs()
with torch.cuda.stream(self.stream):
self.graphs.replay()
self.sync(calc_time=False)
print("warmup finished.")
def run(self, batch: sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None):
with torch.cuda.stream(self.stream):
batch_size = len(batch.prefill_mini_batches) # TODO: calc this
num_tokens = 0
for i in range(len(batch.decode_mini_batches)):
batch_size += len(batch.decode_mini_batches[i])
num_tokens += len(batch.decode_mini_batches[i])
print(f'decode_batch_i: {len(batch.decode_mini_batches[i])},')
for i in range(len(batch.prefill_mini_batches)):
num_tokens += batch.prefill_mini_batches[i][2]
print(f'prefill_batch_i: {batch.prefill_mini_batches[i][2]},')
if isinstance(self.cuda_graphs, list):
# cuda_graph_idx is the smallest i such that self.cuda_graphs[i] >= num_tokens
cuda_graph_idx = next((i for i, token in enumerate(self.cuda_graphs) if token >= num_tokens), len(self.cuda_graphs))
if cuda_graph_idx == len(self.cuda_graphs):
assert False, "num_tokens is too large"
else:
cuda_graph_idx = -1
if self.use_cuda_graph:
if cuda_graph_idx != -1:
self.input[cuda_graph_idx].fill(batch, query_manager, self.page_size)
else:
self.input.fill(batch, query_manager, self.page_size)
else:
self.input = ForwardBatchInput(batch=batch, query_manager=query_manager, device=self.device)
if cuda_graph_idx != -1 and self.use_cuda_graph:
self.features = self.model.batch_embeddings(self.input[cuda_graph_idx], device=self.device)
else:
self.features = self.model.batch_embeddings(self.input, device=self.device)
self.bsz_tensor_buf.copy_(batch_size)
self.num_tokens_tensor_buf.copy_(torch.tensor([num_tokens], dtype=torch.int32, device=self.device))
if self.use_cuda_graph:
if cuda_graph_idx != -1:
self.features_buf[cuda_graph_idx][0].copy_(self.features[0], non_blocking=True)
else:
self.features_buf[0].copy_(self.features[0], non_blocking=True)
"""
if num_tokens_0 > 64:
padded_num_tokens_0 = pad_num_tokens(num_tokens_0)
self.features_buf[0][num_tokens_0:padded_num_tokens_0] = 0
"""
#self.input.forward_minibatchs[0].print()
# print([[hash(k[i].float().cpu().numpy().tobytes()) for i in self.input.forward_minibatchs[0].kv_indices] for k in self.model.cache.k_caches])
# print(f"overlap: {overlap}, is_compute_bound: {is_compute_bound}")
# self.model.flash_infer_attn_plan(self.input, self.bsz_tensors, self.num_tokens_tensors)
"""
if self.use_cuda_graph:
print("before replay features_buf", self.features_buf[0])
print("features_buf addr", self.features_buf[0].data_ptr())
else:
print("before run features", self.features[0])
"""
if cuda_graph_idx != -1 and self.use_cuda_graph:
self.model.flash_infer_attn_plan(self.input[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf,
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
self.start_model_event.record(self.stream)
page_idx, page_offset = self.model.cache.get_page_table(self.input[cuda_graph_idx].minibatch.position_ids, self.input[cuda_graph_idx].minibatch.q_indptr, self.input[cuda_graph_idx].minibatch.kv_indptr, self.input[cuda_graph_idx].minibatch.kv_indices, self.num_tokens_tensor_buf)
if self.use_cuda_graph:
self.page_idx_buf[cuda_graph_idx][:num_tokens].copy_(page_idx[:num_tokens])
self.page_offset_buf[cuda_graph_idx][:num_tokens].copy_(page_offset[:num_tokens])
self.page_idx_buf[cuda_graph_idx][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
self.replay(cuda_graph_idx)
self.output = ForwardBatchOutput()
self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps)
self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures)
self.output.logits.append(self.outputs_buf[cuda_graph_idx].logits[0][self.input[cuda_graph_idx].minibatch.logits_start].clone())
else:
self.output = self.model(self.input[cuda_graph_idx], self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset)
self.output.logits[0] = self.output.logits[0][self.input[cuda_graph_idx].minibatch.logits_start]
self.end_model_event.record(self.stream)
else:
self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
self.start_model_event.record(self.stream)
page_idx, page_offset = self.model.cache.get_page_table(self.input.minibatch.position_ids, self.input.minibatch.q_indptr, self.input.minibatch.kv_indptr, self.input.minibatch.kv_indices, self.num_tokens_tensor_buf)
if self.use_cuda_graph:
self.page_idx_buf[:num_tokens].copy_(page_idx[:num_tokens])
self.page_offset_buf[:num_tokens].copy_(page_offset[:num_tokens])
self.page_idx_buf[num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
self.replay(cuda_graph_idx)
self.output = ForwardBatchOutput()
self.output.top_ps.append(self.input.minibatch.top_ps)
self.output.temperatures.append(self.input.minibatch.temperatures)
self.output.logits.append(self.outputs_buf.logits[0][self.input.minibatch.logits_start].clone())
else:
self.output = self.model(self.input, self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset)
self.output.logits[0] = self.output.logits[0][self.input.minibatch.logits_start]
self.output.top_ps.append(self.input.minibatch.top_ps)
self.output.temperatures.append(self.input.minibatch.temperatures)
self.end_model_event.record(self.stream)
if not self.use_cuda_graph:
self.output.num_batchs = self.input.batch_size
else:
self.output.num_batchs = self.input[cuda_graph_idx].batch_size
def replay(self, cuda_graph_idx=-1):
with torch.cuda.stream(self.stream):
if cuda_graph_idx != -1:
self.graphs[cuda_graph_idx].replay()
else:
self.graphs.replay()
def sync(self, calc_time = True):
self.stream.synchronize()
if calc_time:
self.model_time = self.start_model_event.elapsed_time(self.end_model_event) # In ms
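# Illustrative driving loop (sketch): `model`, `sched_client` and `query_manager`
# stand in for objects that the balance_serve server constructs elsewhere, and
# `get_next_batch()` is a placeholder name for however the scheduler hands out work.
#
#   runner = ModelRunner(model, device="cuda:0", use_cuda_graph=True)
#   runner.warmup()                            # captures one CUDA graph per bucket in self.cuda_graphs
#   while True:
#       batch = sched_client.get_next_batch()  # a sched_ext.BatchQueryTodo
#       query_manager.add_query(batch)
#       runner.run(batch, query_manager)       # fills the static buffers and replays the matching graph
#       runner.sync()                          # waits on runner.stream and records model_time
#       updates = query_manager.update(batch)  # reported back to the scheduler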

View file

@ -0,0 +1,158 @@
'''
Date: 2024-11-14 12:23:45
LastEditors: djw
LastEditTime: 2024-11-20 04:06:23
'''
import torch
from ktransformers.server.balance_serve.settings import sched_ext
import random
import time
class QueryInfo:
id: int
active_position: int
query_length: int
is_prefill: bool
block_index: torch.Tensor
query_tokens: torch.Tensor
stop_criteria: list[torch.Tensor]
temperature: float
top_p: float
max_length: int
def __init__(self, id, query_length: int, max_length: int, page_size: int, device: torch.device, is_prefill: bool = True, offset: int = 0, active_position: int = 0, temperature: float = 0.01, top_p: float = 1.0):
self.id = id
self.is_prefill = is_prefill
self.active_position = active_position
self.max_length = max_length - 1
self.query_tokens = torch.zeros((max_length,), dtype=torch.int, device = device)
self.stop_criteria = []
self.block_index = torch.arange(offset, offset + (max_length + active_position + page_size - 1) // page_size, dtype=torch.int, device = device)
self.query_length = query_length
self.enqueue_time = time.time()
self.decode_start_time = None
self.speculative_token = {} # {position: (accept, token)}
self.temperature = temperature
self.top_p = top_p
def check_stop(self):
if self.active_position >= self.max_length - 2:
return True
# iterate over every stop criterion
for stop_tensor in self.stop_criteria:
stop_len = len(stop_tensor)
# skip if the stop criterion is longer than the tokens generated so far
if stop_len >= self.active_position:
continue
#print(f"stop_tensor: {stop_tensor}, stop_len: {stop_len}, active_position: {self.active_position}, query_token: {self.query_tokens[self.active_position - stop_len - 1:self.active_position - 1]}")
if (torch.equal(self.query_tokens[self.active_position - stop_len - 1:self.active_position - 1], stop_tensor) and self.active_position) or self.max_length <= self.active_position + 3:
self.life_time = time.time() - self.enqueue_time
self.decode_duration_time = time.time() - self.decode_start_time
self.decode_tps = (self.active_position - self.query_length) / self.decode_duration_time
print(f"prefill length: {self.query_length}, prefill time: {self.prefill_duration_time}, prefill tps {self.prefill_tps}, decode length: {self.active_position - self.query_length}, decode time: {self.decode_duration_time}, decode tps {self.decode_tps}")
return True # a matching stop criterion was found
return False # no stop criterion matched
def print(self):
print(f"active_position: {self.active_position}, query_length: {self.query_length}, is_prefill: {self.is_prefill}")
print(f"block_index_shape: {self.block_index.shape}, query_tokens_shape: {self.query_tokens.shape}")
class QueryManager:
max_length: int = 65536
page_size: int = 256
device: torch.device
query_map : dict[int, QueryInfo]
def __init__(self, max_length = 65536, page_size = 256, device = torch.device('cuda')):
self.max_length = max_length
self.page_size = page_size
self.device = device
self.query_map = {}
def add_query(self, batch: sched_ext.BatchQueryTodo):
for i in range(len(batch.query_ids)):
id = batch.query_ids[i]
if id not in self.query_map:
print(f"add query id: {id}, batch.query_lengths: {batch.query_lengths[i]}, batch_query_tokens: {batch.query_tokens[i].shape}, batch.block_indexes: {batch.block_indexes[i]}")
assert batch.query_tokens[i].size(0) < self.max_length, "query max length in batchquerytodo exceeds internal max_length"
query_info = QueryInfo(id=id, query_length=batch.query_lengths[i], max_length=batch.query_tokens[i].size(0) + 1, page_size=self.page_size, device=self.device, temperature=batch.sample_options[i].temperature, top_p=batch.sample_options[i].top_p)
query_info.query_tokens[:query_info.query_length].copy_(batch.query_tokens[i][:query_info.query_length].to(self.device))
for stop_token_list in batch.stop_criteria[i]:
query_info.stop_criteria.append(torch.tensor(stop_token_list, dtype=torch.int, device = self.device))
block_num = batch.block_indexes[i].size(0)
query_info.block_index[:block_num].copy_(batch.block_indexes[i].to(self.device))
self.query_map[id] = query_info
prefill_mini_batches = batch.prefill_mini_batches
for (prefill_id, s, l) in prefill_mini_batches:
if prefill_id == id:
self.query_map[prefill_id].active_position = s
def update(self, batch: sched_ext.BatchQueryTodo) -> list[sched_ext.QueryUpdate]:
query_updates = []
prefill_mini_batches = batch.prefill_mini_batches
for (id, s, l) in prefill_mini_batches:
if id not in self.query_map:
assert False, f"query id {id} not found in query_map"
# update query_info
query_info = self.query_map[id]
query_info.active_position += l
if query_info.active_position >= query_info.query_length and query_info.is_prefill:
query_info.is_prefill = False
query_info.prefill_duration_time = time.time() - query_info.enqueue_time
query_info.prefill_tps = query_info.query_length / query_info.prefill_duration_time
# generate schedule query_update
query_update = sched_ext.QueryUpdate()
query_update.id = id
query_update.ok = True
query_update.is_prefill = query_info.is_prefill
query_update.active_position = query_info.active_position
# if(not query_info.is_prefill):
query_updates.append(query_update)
decode_mini_batches = batch.decode_mini_batches
for ids in decode_mini_batches:
for id in ids:
if id not in self.query_map:
assert False, f"query id {id} not found in query_map"
query_info = self.query_map[id]
query_info.active_position += 1
query_update = sched_ext.QueryUpdate()
query_update.id = id
query_update.ok = True
query_update.is_prefill = query_info.is_prefill
query_update.decode_done = query_info.check_stop()
query_update.active_position = query_info.active_position
query_updates.append(query_update)
return query_updates
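# Flow sketch (illustrative): after each ModelRunner step the server calls
# update(batch). Prefill entries advance active_position by their chunk length and
# flip is_prefill once the whole prompt has been consumed; decode entries advance
# by one token and report decode_done via check_stop(). The returned QueryUpdate
# list is handed back to the scheduler.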

View file

@ -0,0 +1,13 @@
from .orchestrator import BatchedPenalizerOrchestrator
from .penalizers.frequency_penalty import BatchedFrequencyPenalizer
from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer
from .penalizers.presence_penalty import BatchedPresencePenalizer
from .penalizers.repetition_penalty import BatchedRepetitionPenalizer
__all__ = [
"BatchedFrequencyPenalizer",
"BatchedMinNewTokensPenalizer",
"BatchedPresencePenalizer",
"BatchedRepetitionPenalizer",
"BatchedPenalizerOrchestrator",
]

View file

@ -0,0 +1,376 @@
import abc
import dataclasses
import typing
import torch
@dataclasses.dataclass
class _ReqLike:
origin_input_ids: typing.Union[torch.Tensor, typing.List[int]]
@dataclasses.dataclass
class _BatchLike:
reqs: typing.List[_ReqLike]
def batch_size(self):
return len(self.reqs)
class BatchedPenalizerOrchestrator:
batch: _BatchLike
device: str
vocab_size: int
penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"]
def __init__(
self,
vocab_size: int,
batch: _BatchLike,
device: str,
Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]],
):
self.vocab_size = vocab_size
self.batch = batch
self.device = device
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
is_required = False
for penalizer in self.penalizers.values():
pen_is_required = penalizer.prepare_if_required()
is_required |= pen_is_required
self.is_required = is_required
if self.is_required:
self.cumulate_input_tokens(
input_ids=[req.origin_input_ids for req in self.reqs()]
)
def reqs(self):
return self.batch.reqs
def batch_size(self):
return self.batch.batch_size()
def cumulate_input_tokens(
self,
input_ids: typing.Union[
typing.List[torch.Tensor], typing.List[typing.List[int]]
],
):
"""
Feed the input tokens to the penalizers.
Args:
input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens.
"""
token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
for penalizer in self.penalizers.values():
penalizer.cumulate_input_tokens(input_ids=token_ids)
def cumulate_output_tokens(
self,
output_ids: typing.Union[
typing.List[torch.Tensor], typing.List[typing.List[int]]
],
):
"""
Feed the output tokens to the penalizers.
Args:
output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
"""
if not self.is_required:
return
token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
for penalizer in self.penalizers.values():
penalizer.cumulate_output_tokens(output_ids=token_ids)
def apply(self, logits: torch.Tensor) -> torch.Tensor:
"""
Apply the penalizers to the logits.
Note that it may apply the penalizers in-place.
Args:
logits (torch.Tensor): The logits to apply the penalizers to.
Returns:
torch.Tensor: The logits after applying the penalizers.
"""
if not self.is_required:
return
for penalizer in self.penalizers.values():
logits = penalizer.apply(logits)
return logits
def filter(
self,
indices_to_keep: typing.List[int],
indices_tensor_to_keep: torch.Tensor = None,
):
"""
Filter the penalizers based on the indices to keep in the batch.
Args:
indices_to_keep (typing.List[int]): List of indices to keep in the batch.
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
"""
if not self.is_required:
return
empty_indices = len(indices_to_keep) == 0
is_required = False
for penalizer in self.penalizers.values():
tmp_is_required = penalizer.is_required()
is_required = is_required or tmp_is_required
if not tmp_is_required or empty_indices:
penalizer.teardown()
else:
# create tensor index only when it's needed
if indices_tensor_to_keep is None:
indices_tensor_to_keep = torch.tensor(
indices_to_keep, dtype=torch.int32, device=self.device
)
penalizer.filter(
indices_to_keep=indices_to_keep,
indices_tensor_to_keep=indices_tensor_to_keep,
)
self.is_required = is_required
def merge(self, their: "BatchedPenalizerOrchestrator"):
"""
Merge the penalizers of another orchestrator into this one.
Note that this function **must** be called _before_ self.batch.reqs is updated (filtered).
Each unprepared penalizers would have to be prepared (creating tensors, etc.) first before merging.
This step requires the original batch.reqs, before it gets merged with other batch.reqs.
Args:
their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
"""
if not self.is_required and not their.is_required:
return
self.is_required |= their.is_required
for Penalizer, their_penalizer in their.penalizers.items():
if Penalizer not in self.penalizers:
raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
self.penalizers[Penalizer].merge(their_penalizer)
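# Illustrative sketch (assumes the concrete penalizers exported next to this
# module, e.g. BatchedRepetitionPenalizer / BatchedPresencePenalizer): the
# orchestrator is built once per batch, fed the newly sampled tokens every step,
# and applied to the logits before sampling.
#
#   orchestrator = BatchedPenalizerOrchestrator(
#       vocab_size=vocab_size,
#       batch=_BatchLike(reqs=[_ReqLike(origin_input_ids=ids) for ids in prompts]),
#       device="cuda",
#       Penalizers={BatchedRepetitionPenalizer, BatchedPresencePenalizer},
#   )
#   ...
#   orchestrator.cumulate_output_tokens(new_token_ids)
#   logits = orchestrator.apply(logits)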
class _TokenIDs:
"""
A class that wraps token IDs to provide additional utility functions to penalizers.
Attributes:
orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs.
cached_counts (torch.Tensor): The cached occurrence count tensor.
"""
orchestrator: BatchedPenalizerOrchestrator
token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]
cached_counts: torch.Tensor = None
def __init__(
self,
orchestrator: BatchedPenalizerOrchestrator,
token_ids: typing.Union[
typing.List[torch.Tensor], typing.List[typing.List[int]]
],
):
self.orchestrator = orchestrator
if not isinstance(token_ids[0], torch.Tensor):
token_ids = [
torch.tensor(
data=ids, dtype=torch.int64, device=self.orchestrator.device
)
for ids in token_ids
]
self.token_ids = token_ids
def occurrence_count(self) -> torch.Tensor:
"""
Returns a tensor of shape (batch_size, vocab_size) where element [i, j] is the number of times token j appears in sequence i of the batch.
Returns:
torch.Tensor: The occurrence count tensor.
"""
if self.cached_counts is not None:
return self.cached_counts
token_ids = self.token_ids
if isinstance(token_ids, torch.Tensor):
token_ids = token_ids.unsqueeze(1)
# needs to be long to be used as index in scatter_add
if token_ids.dtype != torch.int64:
token_ids = token_ids.to(torch.int64)
padded_token_ids = torch.nn.utils.rnn.pad_sequence(
sequences=token_ids,
batch_first=True,
padding_value=self.orchestrator.vocab_size,
)
self.cached_counts = torch.zeros(
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
dtype=torch.int64,
device=self.orchestrator.device,
).scatter_add_(
dim=1,
index=padded_token_ids,
src=torch.ones_like(padded_token_ids),
)[
:, : self.orchestrator.vocab_size
]
return self.cached_counts
class _BatchedPenalizer(abc.ABC):
"""
An abstract class for a batched penalizer.
"""
orchestrator: BatchedPenalizerOrchestrator
_is_prepared: bool = False
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
self.orchestrator = orchestrator
def is_prepared(self) -> bool:
return self._is_prepared
def is_required(self) -> bool:
return self._is_required()
def prepare(self):
if not self.is_prepared():
self._prepare()
self._is_prepared = True
def prepare_if_required(self):
if self.is_required():
self.prepare()
return True
else:
return False
def teardown(self):
if self.is_prepared():
self._teardown()
self._is_prepared = False
def cumulate_input_tokens(self, input_ids: _TokenIDs):
if not self.is_prepared():
return
self._cumulate_input_tokens(input_ids=input_ids)
def cumulate_output_tokens(self, output_ids: _TokenIDs):
if not self.is_prepared():
return
self._cumulate_output_tokens(output_ids=output_ids)
def apply(self, logits: torch.Tensor) -> torch.Tensor:
if not self.is_prepared():
return logits
return self._apply(logits=logits)
def filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
if not self.is_prepared():
return
self._filter(
indices_to_keep=indices_to_keep,
indices_tensor_to_keep=indices_tensor_to_keep,
)
def merge(self, their: "_BatchedPenalizer"):
if not self.is_prepared() and not their.is_prepared():
return
self.prepare()
their.prepare()
self._merge(their)
@abc.abstractmethod
def _is_required(self) -> bool:
"""
Check if the penalizer is required to be prepared.
"""
pass
@abc.abstractmethod
def _prepare(self):
"""
Prepare the penalizer.
Usually, this is where the penalizer initializes its tensors.
"""
pass
@abc.abstractmethod
def _teardown(self):
"""
Tear down the penalizer.
Usually, this is where the penalizer frees its tensors.
"""
pass
@abc.abstractmethod
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
"""
Cumulate the input tokens.
Orchestrator will call this function to feed the input tokens to the penalizer.
"""
pass
@abc.abstractmethod
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
"""
Cumulate the output tokens.
Orchestrator will call this function to feed the output tokens to the penalizer.
"""
pass
@abc.abstractmethod
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
"""
Apply the penalizer to the logits.
Penalizers can modify the logits in-place if needed.
"""
pass
@abc.abstractmethod
def _filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
"""
Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
"""
pass
@abc.abstractmethod
def _merge(self, their: "_BatchedPenalizer"):
"""
Merge the penalizer with another penalizer.
"""
pass
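The occurrence-count trick used by _TokenIDs.occurrence_count above (pad the ragged token lists with vocab_size so the padding lands in an extra, throwaway column, scatter-add ones, then drop that column) can be exercised on its own; a minimal runnable sketch with a toy vocabulary:

# Standalone sketch of the occurrence-count trick; vocab size and token lists are toy values.
import torch

vocab_size = 8
token_ids = [
    torch.tensor([1, 3, 3, 5], dtype=torch.int64),  # request 0
    torch.tensor([2, 2], dtype=torch.int64),        # request 1 (shorter sequence)
]

# Pad ragged sequences with vocab_size so the padding falls into an extra column.
padded = torch.nn.utils.rnn.pad_sequence(
    sequences=token_ids, batch_first=True, padding_value=vocab_size
)

counts = torch.zeros(
    (len(token_ids), vocab_size + 1), dtype=torch.int64
).scatter_add_(dim=1, index=padded, src=torch.ones_like(padded))[:, :vocab_size]

print(counts)
# tensor([[0, 1, 0, 2, 0, 1, 0, 0],
#         [0, 0, 2, 0, 0, 0, 0, 0]])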


@ -0,0 +1,80 @@
import typing
import torch
from ..orchestrator import _BatchedPenalizer, _TokenIDs
class BatchedFrequencyPenalizer(_BatchedPenalizer):
"""
Frequency penalizer penalizes tokens based on their frequency in the output.
"""
frequency_penalties: torch.Tensor = None
cumulated_frequency_penalties: torch.Tensor = None
def _is_required(self) -> bool:
return any(
req.sampling_params.frequency_penalty != 0.0
for req in self.orchestrator.reqs()
)
def _prepare(self):
self.cumulated_frequency_penalties = (
torch.tensor(
data=[0.0 for _ in self.orchestrator.reqs()],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.repeat(1, self.orchestrator.vocab_size)
)
self.frequency_penalties = (
torch.tensor(
data=[
req.sampling_params.frequency_penalty
for req in self.orchestrator.reqs()
],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.expand_as(self.cumulated_frequency_penalties)
)
def _teardown(self):
del self.frequency_penalties
del self.cumulated_frequency_penalties
self.frequency_penalties = None
self.cumulated_frequency_penalties = None
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
pass
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
self.cumulated_frequency_penalties += (
self.frequency_penalties * output_ids.occurrence_count()
)
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
logits -= self.cumulated_frequency_penalties
return logits
def _filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
indices_tensor_to_keep
]
def _merge(self, their: "BatchedFrequencyPenalizer"):
self.frequency_penalties = torch.cat(
[self.frequency_penalties, their.frequency_penalties], dim=0
)
self.cumulated_frequency_penalties = torch.cat(
[self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
dim=0,
)
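As a sanity check of the arithmetic above (toy numbers, not the class itself): the cumulative penalty is frequency_penalty times the running occurrence count, subtracted from the logits.

import torch

# One request, toy vocab of 4 tokens, frequency_penalty = 0.5 (made-up values).
frequency_penalty = torch.tensor([[0.5]])           # shape (batch=1, 1), broadcast over the vocab
cumulated = torch.zeros(1, 4)

# Suppose the model has emitted token 0 once and token 2 twice so far.
occurrence_count = torch.tensor([[1, 0, 2, 0]], dtype=torch.float32)
cumulated += frequency_penalty * occurrence_count   # -> [[0.5, 0.0, 1.0, 0.0]]

logits = torch.tensor([[2.0, 2.0, 2.0, 2.0]])
penalized = logits - cumulated                      # token 2 is pushed down the most
print(penalized)                                    # tensor([[1.5000, 2.0000, 1.0000, 2.0000]])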


@ -0,0 +1,108 @@
import typing
import torch
from ..orchestrator import _BatchedPenalizer, _TokenIDs
class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
"""
Min-new-tokens penalizer masks stop tokens until each request has generated at least min_new_tokens output tokens.
"""
min_new_tokens: torch.Tensor = None
stop_token_penalties: torch.Tensor = None
len_output_tokens: torch.Tensor = None
def _is_required(self) -> bool:
return any(
req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs()
)
def _prepare(self):
self.min_new_tokens = torch.tensor(
data=[
req.sampling_params.min_new_tokens for req in self.orchestrator.reqs()
],
dtype=torch.int32,
device=self.orchestrator.device,
).unsqueeze_(1)
padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence(
sequences=[
torch.tensor(
data=(
list(
(req.sampling_params.stop_token_ids or set())
| (req.tokenizer.additional_stop_token_ids or set())
| {req.tokenizer.eos_token_id}
)
),
dtype=torch.int64,
device=self.orchestrator.device,
)
for req in self.orchestrator.reqs()
],
batch_first=True,
padding_value=self.orchestrator.vocab_size,
)
self.stop_token_penalties = torch.zeros(
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
dtype=torch.float32,
device=self.orchestrator.device,
).scatter_add_(
dim=1,
index=padded_stop_token_ids,
src=torch.full_like(
input=padded_stop_token_ids,
dtype=torch.float32,
fill_value=float("-inf"),
device=self.orchestrator.device,
),
)[
:, : self.orchestrator.vocab_size
]
self.len_output_tokens = torch.zeros(
size=(self.orchestrator.batch_size(), 1),
dtype=torch.int32,
device=self.orchestrator.device,
)
def _teardown(self):
del self.min_new_tokens
del self.stop_token_penalties
del self.len_output_tokens
self.min_new_tokens = None
self.stop_token_penalties = None
self.len_output_tokens = None
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
pass
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
self.len_output_tokens += 1
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)
logits[mask] += self.stop_token_penalties[mask]
return logits
def _filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
def _merge(self, their: "BatchedMinNewTokensPenalizer"):
self.min_new_tokens = torch.cat(
[self.min_new_tokens, their.min_new_tokens], dim=0
)
self.stop_token_penalties = torch.cat(
[self.stop_token_penalties, their.stop_token_penalties], dim=0
)
self.len_output_tokens = torch.cat(
[self.len_output_tokens, their.len_output_tokens], dim=0
)
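A minimal sketch of the masking idea, assuming a toy vocabulary and a single stop token: stop-token logits are driven to -inf while the output is still shorter than min_new_tokens, so the stop token cannot be sampled early.

import torch

vocab_size, eos_token_id, min_new_tokens = 6, 5, 3   # toy values
stop_token_penalties = torch.zeros(1, vocab_size)
stop_token_penalties[0, eos_token_id] = float("-inf")

logits = torch.randn(1, vocab_size)
len_output_tokens = torch.tensor([[1]])              # only one token generated so far

mask = (len_output_tokens < min_new_tokens).expand_as(logits)
logits = logits.clone()
logits[mask] += stop_token_penalties[mask]           # the EOS logit becomes -inf
print(logits[0, eos_token_id])                       # tensor(-inf)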


@ -0,0 +1,79 @@
import typing
import torch
from ..orchestrator import _BatchedPenalizer, _TokenIDs
class BatchedPresencePenalizer(_BatchedPenalizer):
"""
Presence penalizer penalizes tokens based on their presence in the output.
"""
presence_penalties: torch.Tensor = None
cumulated_presence_penalties: torch.Tensor = None
def _is_required(self) -> bool:
return any(
req.sampling_params.presence_penalty != 0.0
for req in self.orchestrator.reqs()
)
def _prepare(self):
self.cumulated_presence_penalties = (
torch.tensor(
data=[0.0 for _ in self.orchestrator.reqs()],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.repeat(1, self.orchestrator.vocab_size)
)
self.presence_penalties = (
torch.tensor(
data=[
req.sampling_params.presence_penalty
for req in self.orchestrator.reqs()
],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.expand_as(self.cumulated_presence_penalties)
)
def _teardown(self):
del self.presence_penalties
del self.cumulated_presence_penalties
self.presence_penalties = None
self.cumulated_presence_penalties = None
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
pass
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
mask = output_ids.occurrence_count() > 0
self.cumulated_presence_penalties[mask] = self.presence_penalties[mask]
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
logits -= self.cumulated_presence_penalties
return logits
def _filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
self.cumulated_presence_penalties = self.cumulated_presence_penalties[
indices_tensor_to_keep
]
def _merge(self, their: "BatchedPresencePenalizer"):
self.presence_penalties = torch.cat(
[self.presence_penalties, their.presence_penalties], dim=0
)
self.cumulated_presence_penalties = torch.cat(
[self.cumulated_presence_penalties, their.cumulated_presence_penalties],
dim=0,
)


@ -0,0 +1,83 @@
import typing
import torch
from ..orchestrator import _BatchedPenalizer, _TokenIDs
class BatchedRepetitionPenalizer(_BatchedPenalizer):
"""
Repetition penalizer penalizes tokens based on their repetition in the input and output.
"""
repetition_penalties: torch.Tensor = None
cumulated_repetition_penalties: torch.Tensor = None
def _is_required(self) -> bool:
return any(
req.sampling_params.repetition_penalty != 1.0
for req in self.orchestrator.reqs()
)
def _prepare(self):
self.cumulated_repetition_penalties = (
torch.tensor(
data=[1.0 for _ in self.orchestrator.reqs()],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.repeat(1, self.orchestrator.vocab_size)
)
self.repetition_penalties = (
torch.tensor(
data=[
req.sampling_params.repetition_penalty
for req in self.orchestrator.reqs()
],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.expand_as(self.cumulated_repetition_penalties)
)
def _teardown(self):
del self.repetition_penalties
del self.cumulated_repetition_penalties
self.repetition_penalties = None
self.cumulated_repetition_penalties = None
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
mask = input_ids.occurrence_count() > 0
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
mask = output_ids.occurrence_count() > 0
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
return torch.where(
logits > 0,
logits / self.cumulated_repetition_penalties,
logits * self.cumulated_repetition_penalties,
)
def _filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
indices_tensor_to_keep
]
def _merge(self, their: "BatchedRepetitionPenalizer"):
self.repetition_penalties = torch.cat(
[self.repetition_penalties, their.repetition_penalties], dim=0
)
self.cumulated_repetition_penalties = torch.cat(
[self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
dim=0,
)
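The torch.where rule above divides positive logits and multiplies negative logits by the penalty, pushing repeated tokens down in both cases; a toy check with invented numbers:

import torch

repetition_penalty = 1.2
logits = torch.tensor([[2.4, -2.4, 0.6]])
penalties = torch.full_like(logits, repetition_penalty)  # as if every token had been seen

penalized = torch.where(logits > 0, logits / penalties, logits * penalties)
print(penalized)   # tensor([[ 2.0000, -2.8800,  0.5000]])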


@ -0,0 +1,100 @@
'''
Date: 2024-11-14 12:23:45
LastEditors: Xie Weiyu ervinxie@qq.com
LastEditTime: 2024-11-25 08:59:23
'''
import logging
import torch
from torch import nn
from transformers import GenerationConfig
from flashinfer.sampling import (
min_p_sampling_from_probs,
top_k_renorm_probs,
top_k_top_p_sampling_from_logits,
top_p_renorm_probs,
)
logger = logging.getLogger(__name__)
class SamplingOptions():
# Batched sampling params
temperatures: torch.Tensor
top_ps: torch.Tensor
top_ks: torch.Tensor
min_ps: torch.Tensor
# All requests use greedy sampling
is_all_greedy: bool
# Dispatch in CUDA graph
need_min_p_sampling: bool
def __init__(self, bsz = 1, device = torch.device('cuda'), pretrained_config:GenerationConfig = None, temperatures: torch.Tensor = None, top_ps: torch.Tensor = None):
if pretrained_config is None and temperatures is None:
self.temperatures = torch.full((bsz, 1), 0, device=device, dtype=torch.float32)
self.top_ps = torch.ones((bsz, 1), device=device, dtype=torch.float32)
self.top_ks = torch.ones((bsz, 1), device=device, dtype=torch.float32)
self.need_min_p_sampling = False
self.is_all_greedy = True
else:
if temperatures is not None:
self.temperatures = temperatures.unsqueeze(-1)
else:
self.temperatures = torch.full((bsz, 1), pretrained_config.temperature, device=device, dtype=torch.float32)
if top_ps is not None:
self.top_ps = top_ps.unsqueeze(-1)
else:
self.top_ps = torch.full((bsz, 1), pretrained_config.top_p, device=device, dtype=torch.float32)
self.top_ks = torch.full((bsz, 1), pretrained_config.top_k, device=device, dtype=torch.float32)
self.need_min_p_sampling = False
self.is_all_greedy = False
class Sampler(nn.Module):
def __init__(self):
super().__init__()
def forward(
self,
logits: torch.Tensor,
sampling_config: SamplingOptions = None,
):
if sampling_config is None:
sampling_config = SamplingOptions()
logits = logits.contiguous()
origin_logits = logits.clone()
if sampling_config.is_all_greedy:
# Use torch.argmax if all requests use greedy sampling
probs = logits
batch_next_token_ids = torch.argmax(logits, -1)
else:
# Post process logits
logits.div_(sampling_config.temperatures)
max_top_k_round, batch_size = 32, logits.shape[0]
if sampling_config.need_min_p_sampling:
probs = torch.softmax(logits, dim=-1)
logits = None
del logits
probs = top_k_renorm_probs(probs, sampling_config.top_ks)
probs = top_p_renorm_probs(probs, sampling_config.top_ps)
batch_next_token_ids = min_p_sampling_from_probs(
probs, sampling_config.min_ps
)
temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
else:
# TODO: use a different kernel when top_k or top_p is not needed
# TODO: compute real probs here (logits are returned as a placeholder)
probs = logits
batch_next_token_ids = top_k_top_p_sampling_from_logits(
logits,
sampling_config.top_ks,
sampling_config.top_ps,
filter_apply_order="joint",
)
temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
return batch_next_token_ids.to(torch.int32), probs
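The forward pass above leans on fused flashinfer kernels; the sketch below is a minimal pure-PyTorch illustration of the same joint top-k / top-p idea, not the kernel itself, and the toy batch and vocab sizes are made up.

import torch

def top_k_top_p_sample(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    # logits: (batch, vocab)
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)

    # Top-k: zero out everything past the k-th most likely token.
    sorted_probs[:, top_k:] = 0.0

    # Top-p: zero out tokens once the cumulative mass of earlier tokens exceeds top_p
    # (the most likely token is always kept).
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    cutoff = cumulative - sorted_probs > top_p
    sorted_probs[cutoff] = 0.0

    # Renormalize, sample in the sorted space, then map back to vocabulary ids.
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, choice).squeeze(-1)

logits = torch.randn(2, 32)            # toy batch of 2, vocab of 32
print(top_k_top_p_sample(logits, top_k=5, top_p=0.9))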


@ -0,0 +1,213 @@
from datetime import datetime
import os
from typing import Optional
import zmq
import pickle
import threading
import torch.multiprocessing as mp
import sys
current_file_path = os.path.abspath(__file__)
# sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
import pickle
import argparse
from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings
if mp.get_start_method(allow_none=True) is None:
print('set start method')
mp.set_start_method('spawn')
else:
print(f'start method already set to {mp.get_start_method(allow_none=True)}')
class SchedulerServer:
def __init__(self, settings, main_args):
# Create and initialize the Scheduler instance
self.sched = sched_ext.create_scheduler(settings)
# Initialize the ZeroMQ context and sockets
self.context = zmq.Context()
self.frontend = self.context.socket(zmq.ROUTER)
print(f"sched zmq rpc server on port {main_args.sched_port}")
self.frontend.bind(f"tcp://*:{main_args.sched_port}")
# Create an internal DEALER socket for communicating with the worker threads
self.backend = self.context.socket(zmq.DEALER)
self.backend.bind("inproc://backend")
# Start the scheduler
def run_scheduler(self):
self.sched.run()
# Stop the scheduler
def stop_scheduler(self):
self.sched.stop()
# Handle client requests
def start_proxy(self):
# Use ZMQ's built-in proxy to dispatch frontend requests to the backend worker threads
zmq.proxy(self.frontend, self.backend)
# Worker thread that handles requests
def worker_routine(self):
worker = self.context.socket(zmq.REP)
worker.connect("inproc://backend")
while True:
try:
# Receive a client request
message = worker.recv()
data = pickle.loads(message)
method = data.get('method')
params = data.get('params', {})
# print(f"Received request: {method}")
if method == 'add_query':
query_add = params.get('query') # a QueryAdd object, passed as-is
# Add the query
query_id = self.sched.add_query(query_add)
# Send the response
response = {'status': 'ok', 'query_id': query_id}
worker.send(pickle.dumps(response))
elif method == 'cancel_query':
query_id = params.get('query_id')
# Assumes the Scheduler class implements a cancel method
self.sched.cancel(query_id)
response = {'status': 'ok'}
worker.send(pickle.dumps(response))
elif method == 'update_last_batch':
updates = params.get('updates') # a list of QueryUpdate objects, passed as-is
# Update the last batch
batch_todo = self.sched.update_last_batch(updates)
# Send the batch_todo object directly
response = {'status': 'ok', 'batch_todo': batch_todo}
# print (batch_todo.query_lengths, batch_todo.query_ids)
worker.send(pickle.dumps(response))
elif method == 'get_inference_context':
inference_context = self.sched.get_inference_context()
data = {
"k_cache":inference_context.k_cache,
"v_cache":inference_context.v_cache
}
print(f"Serializing KVCache")
data["k_cache"] = [mp.reductions.reduce_tensor(t) for t in data['k_cache']]
data["v_cache"] = [mp.reductions.reduce_tensor(t) for t in data['v_cache']]
# print(data)
response = {'status': 'ok', 'inference_context': data}
worker.send(pickle.dumps(response))
# response['inference_context'].k_cache[0][0, 0, 0, 0, 0] = 1
# print("k_cache update")
else:
# Unknown method
response = {'status': 'error', 'message': 'Unknown method'}
worker.send(pickle.dumps(response))
except Exception as e:
# Handle the exception and send an error response
response = {'status': 'error', 'message': str(e)}
worker.send(pickle.dumps(response))
# Start the RPC service
def start_rpc_service(self):
try:
print("Scheduler RPC service is running...")
# Run the scheduler in a separate thread
threading.Thread(target=self.run_scheduler, daemon=True).start()
# Start the worker threads
for _ in range(10): # adjust the thread count as needed
threading.Thread(target=self.worker_routine, daemon=True).start()
# Start the proxy and begin listening for requests
self.start_proxy()
except KeyboardInterrupt:
print("Shutting down scheduler RPC service...")
self.stop_rpc_service()
# Stop the RPC service
def stop_rpc_service(self):
self.stop_scheduler()
self.frontend.close()
self.backend.close()
self.context.term()
def start_server(settings, main_args):
server = SchedulerServer(settings, main_args)
server.start_rpc_service()
# Add async client for webserver
class SchedulerClient:
def __init__(self, sched_port):
address=f'tcp://localhost:{sched_port}'
self.address = address
self.context = zmq.Context()
self.socket = self.context.socket(zmq.REQ)
self.socket.connect(self.address)
print(f"Connected to server at {self.address}")
def __del__(self):
self.socket.close()
self.context.term()
def send_request(self, method, params=None):
if params is None:
params = {}
request = {
'method': method,
'params': params
}
# print(f'send request {request}')
self.socket.send(pickle.dumps(request))
response = self.socket.recv()
# print(response)
response = pickle.loads(response)
if response.get('status') == 'ok':
return response
else:
raise Exception(f"Error from server: {response.get('message')}")
def add_query(self, query):
response = self.send_request('add_query', {'query': query})
return response.get('query_id')
def cancel_query(self, query_id):
self.send_request('cancel_query', {'query_id': query_id})
def update_last_batch(self, updates):
response = self.send_request('update_last_batch', {'updates': updates})
# print(f"update_last_batch response {response}")
return response.get('batch_todo')
def rebuild_inferece_context(self,response):
data = response.get('inference_context')
inference_context = sched_ext.InferenceContext()
print('Rebuilding kvcache')
inference_context.k_cache = [fn(*args) for fn,args in data['k_cache']]
inference_context.v_cache = [fn(*args) for fn,args in data['v_cache']]
return inference_context
def get_inference_context_raw(self):
response = self.send_request('get_inference_context')
return response
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
args = parser.parse_args()
with open(args.config, "rb") as f:
main_args = pickle.load(f)
settings = create_sched_settings(main_args)
start_server(settings, main_args)
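The server and client above exchange pickled {'method', 'params'} requests and {'status', ...} replies over ZeroMQ; a self-contained toy of that envelope, with an invented port and a fake add_query handler standing in for the real scheduler:

import pickle
import threading
import zmq

PORT = 5599  # arbitrary test port, not the real sched_port

def toy_server():
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.REP)
    sock.bind(f"tcp://*:{PORT}")
    # Handle a single request, then exit.
    request = pickle.loads(sock.recv())
    if request["method"] == "add_query":
        reply = {"status": "ok", "query_id": 42}   # fake query id for illustration
    else:
        reply = {"status": "error", "message": "Unknown method"}
    sock.send(pickle.dumps(reply))
    sock.close()

threading.Thread(target=toy_server, daemon=True).start()

ctx = zmq.Context.instance()
client = ctx.socket(zmq.REQ)
client.connect(f"tcp://localhost:{PORT}")
client.send(pickle.dumps({"method": "add_query", "params": {"query": "toy"}}))
print(pickle.loads(client.recv()))                 # {'status': 'ok', 'query_id': 42}
client.close()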


@ -0,0 +1,73 @@
'''
Date: 2024-11-13 09:43:39
LastEditors: djw
LastEditTime: 2024-11-18 16:41:03
'''
import sys, os
import yaml, json
from time import sleep
import sched_ext
from transformers import AutoConfig
def create_sched_settings(args):
default_sample_options = sched_ext.SampleOptions()
model_name = os.path.basename(os.path.normpath(args.model_dir))
input_model_settings = sched_ext.ModelSettings()
input_model_settings.model_path = args.model_dir
input_model_settings.params_count = int(0)
model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
input_model_settings.layer_count = model_config.num_hidden_layers
input_model_settings.num_k_heads = 1 # model_config["num_key_value_heads"]
input_model_settings.k_head_dim = 576
input_model_settings.bytes_per_params = 2
input_model_settings.bytes_per_kv_cache_element = 2
settings = sched_ext.Settings()
settings.model_name = model_name
settings.quant_type = "BF16"
settings.model_settings = input_model_settings
settings.page_size = args.page_size
settings.gpu_device_count = 1 # tp
settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]
# settings.gpu_memory_size = args.cache_lens*576*2
settings.gpu_memory_size = args.gpu_memory_size
settings.memory_utilization_percentage = args.utilization_percentage
max_batch_size = args.max_batch_size
chunk_size = args.chunk_size
max_decode_batch_size = max_batch_size - 2
settings.max_batch_size = max_batch_size
settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2
settings.sample_options = default_sample_options
settings.sched_metrics_port = args.sched_metrics_port
settings.gpu_only = args.memory_gpu_only
settings.use_self_defined_head_dim = True
settings.self_defined_head_dim = 576
settings.full_kv_cache_on_each_gpu = True
settings.k_cache_on = True
settings.v_cache_on = False
settings.kvc2_root_path = '/mnt/data/persist-kvc'
settings.kvc2_config_path = args.kvc2_config_dir
settings.memory_pool_size_GB = args.cpu_memory_size_GB
settings.evict_count = 40
settings.kvc2_metrics_port = args.kvc2_metrics_port
settings.load_from_disk = False
settings.save_to_disk = True
settings.strategy_name = args.sched_strategy
settings.auto_derive()
return settings


@ -11,6 +11,7 @@ LastEditTime : 2024-08-12 06:31:14
import os
import shutil
import yaml
import psutil
from ktransformers.server.config.singleton import Singleton
from typing import Optional
@ -33,12 +34,15 @@ class Config(metaclass=Singleton):
user_path: str = os.path.expanduser("~")
localstore_path: str = os.path.join(user_path, ".ktransformers")
kvc2_config_dir = os.path.join(localstore_path, "kvc2")
config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME)
if not os.path.exists(config_yaml):
print(f"Can't find config file, {config_yaml}")
exit(-1)
if not os.path.exists(localstore_path):
os.mkdir(localstore_path)
if not os.path.exists(kvc2_config_dir):
os.mkdir(kvc2_config_dir)
if not os.path.exists(config_path):
shutil.copyfile(config_yaml, config_path)
with open(config_path, "r", encoding="utf-8") as fp:
@ -60,11 +64,14 @@ class Config(metaclass=Singleton):
self.user_path: str = os.path.expanduser("~")
self.localstore_path: str = os.path.join(self.user_path, ".ktransformers")
# log configs
self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"]))
self.log_dir = os.path.join(self.localstore_path, cfg["log"]["dir"])
if not os.path.exists(self.log_dir):
os.mkdir(self.log_dir)
self.log_file = cfg["log"]["file"]
self.log_level = cfg["log"]["level"]
self.backup_count = cfg["log"]["backup_count"]
self.kvc2_config_dir = os.path.join(self.localstore_path, "kvc2")
# server configs
self.server: dict = cfg.get("server", {})
self.server_ip = self.server.get("ip", "0.0.0.0")
@ -74,7 +81,7 @@ class Config(metaclass=Singleton):
# db configs
self.db_configs: dict = cfg.get("db", {})
self.db_type = self.db_configs.get("type", "")
self.db_host = os.path.join(self.base_path, self.db_configs.get("host", ""))
self.db_host = self.localstore_path
self.db_port = self.db_configs.get("port", "")
self.db_name = self.db_configs.get("database", "")
self.db_pool_size = self.db_configs.get("pool_size")
@ -101,11 +108,6 @@ class Config(metaclass=Singleton):
self.optimize_config_path: Optional[str] = self.model.get(
"optimize_config_path", None
)
self.paged = self.model.get("paged", True)
self.total_context = self.model.get("total_context", 2**18)
self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
self.chunk_prefill_size = self.model.get("chunk_prefill_size", 8192)
self.max_new_tokens = self.model.get("max_new_tokens", 2000)
self.json_mode = self.model.get("json_mode", False)
@ -138,7 +140,6 @@ class Config(metaclass=Singleton):
self.repetition_penalty = self.model.get("repetition_penalty", 1.01)
self.frequency_penalty = self.model.get("frequency_penalty", 0.0)
self.presence_penalty = self.model.get("presence_penalty", 0.0)
self.max_response_tokens = self.model.get("max_response_tokens", 300)
self.response_chunk = self.model.get("response_chunk", 250)
self.no_code_formatting = self.model.get("no_code_formatting", False)
self.cache_8bit = self.model.get("cache_8bit", False)
@ -155,8 +156,9 @@ class Config(metaclass=Singleton):
self.web_cross_domain: bool = self.web.get("open_cross_domain", True)
self.mount_web: bool = self.web.get("mount", False)
# ext
self.ext: dict = cfg.get("ext", {})
self.cpu_infer = self.ext.get("cpu_infer", 10)
self.cpu_infer = psutil.cpu_count(logical=False) - 3
# file config
self.local_store_configs: dict = cfg.get("local_store", {})
@ -169,7 +171,6 @@ class Config(metaclass=Singleton):
# long context config
self.long_context_config: dict = cfg.get("long_context", {})
self.chunk_size = self.long_context_config.get("chunk_size", 4096)
self.max_seq_len = self.long_context_config.get("max_seq_len", 32000)
self.block_size = self.long_context_config.get("block_size", 128)
self.local_windows_len = self.long_context_config.get("local_windows_len", 4096)
@ -187,3 +188,21 @@ class Config(metaclass=Singleton):
# local chat
self.local_chat_config: dict = cfg.get("local_chat", {})
self.prompt_file = self.local_chat_config.get("prompt_file", None)
# asyncserver
self.sched_strategy = cfg["async_server"]["sched_strategy"]
self.sched_port = cfg["async_server"]["sched_port"]
self.sched_metrics_port = cfg["async_server"]["sched_metrics_port"]
self.kvc2_metrics_port = cfg["async_server"]["kvc2_metrics_port"]
self.max_batch_size = cfg["async_server"]["max_batch_size"]
self.page_size = cfg["attn"]["page_size"]
self.chunk_size = cfg["attn"]["chunk_size"]
self.memory_gpu_only = cfg["kvc2"]["gpu_only"]
self.cache_lens = ((self.cache_lens + self.page_size - 1) // self.page_size) * self.page_size
self.gpu_memory_size = 2*576*61*self.cache_lens
self.utilization_percentage = 1.0 #cfg["kvc2"]["utilization_percentage"]
self.cpu_memory_size_GB = cfg["kvc2"]["cpu_memory_size_GB"]
# only support 2 prefill task
self.max_prefill_batch_size = 2
self.max_decode_batch_size = self.max_batch_size - self.max_prefill_batch_size
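A quick check of the page alignment and KV-cache sizing above; page_size and cache_lens here are toy values, and the constants are read off the hard-coded formula (2 bytes per element, the 576-wide KV head also used in create_sched_settings, and a 61-layer model) rather than derived from config, so treat the layer count as an assumption tied to the target model.

# Toy numbers only; the real page_size comes from cfg["attn"]["page_size"].
page_size = 256
cache_lens = 1000

aligned = ((cache_lens + page_size - 1) // page_size) * page_size
print(aligned)                              # 1024: rounded up to a whole number of pages

gpu_memory_size = 2 * 576 * 61 * aligned    # bytes = dtype size * head dim * layers * tokens
print(gpu_memory_size / 2**20, "MiB")       # 68.625 MiB for this toy cache length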


@ -5,24 +5,20 @@ from fastapi.staticfiles import StaticFiles
import uvicorn.logging
import uvicorn
import sys
import atexit
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
sys.path.insert(0, project_dir)
from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.args import ArgumentParser
from ktransformers.server.config.config import Config
from ktransformers.server.utils.create_interface import create_interface
from ktransformers.server.backend.args import default_args
from ktransformers.server.utils.create_interface import create_interface, GlobalInterface
from fastapi.openapi.utils import get_openapi
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.api import router, post_db_creation_operations
from ktransformers.server.utils.sql_utils import Base, SQLUtil
from ktransformers.server.config.log import logger
import subprocess
import tempfile
def mount_app_routes(mount_app: FastAPI):
sql_util = SQLUtil()
@ -34,7 +30,10 @@ def mount_app_routes(mount_app: FastAPI):
def create_app():
cfg = Config()
app = FastAPI()
if hasattr(GlobalInterface.interface, "lifespan"):
app = FastAPI(lifespan=GlobalInterface.interface.lifespan)
else:
app = FastAPI()
if Config().web_cross_domain:
app.add_middleware(
CORSMiddleware,
@ -108,11 +107,32 @@ def main():
arg_parser = ArgumentParser(cfg)
# Initialization messages
args = arg_parser.parse_args()
if args.backend_type == "balance_serve":
import pickle
def cleanup():
if sched_process.poll() is None:
sched_process.terminate()
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
pickle.dump(args, temp_file)
temp_file_path = temp_file.name
current_file = __file__
target_file = os.path.join(os.path.dirname(current_file), "balance_serve", "sched_rpc.py")
target_file = os.path.normpath(target_file)
log_path = os.path.join(args.log_dir, "rpc.log")
log = open(log_path, "a")
sched_process = subprocess.Popen(
["python3", target_file, "--config", temp_file_path],
stdout=log,
stderr=log
)
print("sched_rpc started with PID:", sched_process.pid)
atexit.register(cleanup)
create_interface(config=cfg, default_args=cfg)
app = create_app()
custom_openapi(app)
create_interface(config=cfg, default_args=cfg)
run_api(
app=app,
host=args.host,
@ -121,6 +141,5 @@ def main():
ssl_certfile=args.ssl_certfile,
)
if __name__ == "__main__":
main()


@ -1,4 +1,4 @@
torch >= 2.3.0,<=2.3.1
torch >= 2.3.0
transformers == 4.43.2
fastapi >= 0.111.0
langchain >= 0.2.0
@ -11,4 +11,6 @@ build
ninja
wheel
colorlog
fire
fire
zmq
psutil


@ -2,7 +2,7 @@ from typing import List, Optional
from typing_extensions import Literal
from enum import Enum
from pydantic import BaseModel
from pydantic import BaseModel, Field
from ktransformers.server.schemas.base import Object
@ -30,8 +30,8 @@ class ChatCompletionCreate(BaseModel):
messages: List[Message]
model : str
stream : bool = False
temperature: Optional[float] = None
top_p: Optional[float] = None
temperature: Optional[float] = Field(default=1.0)
top_p: Optional[float] = Field(default=1.0)
def get_tokenizer_messages(self):
return [m.to_tokenizer_message() for m in self.messages]


@ -15,6 +15,7 @@ from ktransformers.server.backend.context_manager import ThreadContextManager
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
from ktransformers.server.backend.interfaces.transformers import TransformersInterface
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface
def create_interface(config: Config, default_args: ConfigArgs):
if config.backend_type=='transformers':
from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
@ -22,6 +23,8 @@ def create_interface(config: Config, default_args: ConfigArgs):
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
elif config.backend_type == 'ktransformers':
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
elif config.backend_type == 'balance_serve':
from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface as BackendInterface
else:
raise NotImplementedError(f'{config.backend_type} not implemented')
GlobalInterface.interface = BackendInterface(default_args)
@ -30,9 +33,9 @@ def create_interface(config: Config, default_args: ConfigArgs):
class GlobalContextManager:
context_manager: ThreadContextManager
class GlobalInterface:
interface: TransformersInterface | KTransformersInterface | ExllamaInterface
interface: TransformersInterface | KTransformersInterface | ExllamaInterface
def get_thread_context_manager() -> ThreadContextManager:
def get_thread_context_manager() -> GlobalContextManager:
return GlobalContextManager.context_manager
def get_interface() -> TransformersInterface | KTransformersInterface | ExllamaInterface:
def get_interface() -> GlobalInterface:
return GlobalInterface.interface