Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-14 17:19:42 +00:00)
Merge branch 'kvcache-ai:main' into main
Commit 877aec858e
251 changed files with 47224 additions and 749 deletions
@@ -1,6 +1,6 @@
import argparse
from ktransformers.server.backend.args import ConfigArgs, default_args

from ktransformers.util.utils import get_free_ports

class ArgumentParser:
    def __init__(self, cfg):
@@ -16,20 +16,18 @@ class ArgumentParser:
        parser.add_argument("--web", type=bool, default=self.cfg.mount_web)
        parser.add_argument("--model_name", type=str, default=self.cfg.model_name)
        parser.add_argument("--model_dir", type=str)
        parser.add_argument("--model_path", type=str)
        parser.add_argument("--model_path", type=str, default=self.cfg.model_path)
        parser.add_argument(
            "--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter"
        )
        parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
        parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
        parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
        parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
        parser.add_argument("--type", type=str, default=self.cfg.backend_type)
        parser.add_argument("--chunk_prefill_size", type=int, default=8192)
        parser.add_argument("--backend_type", type=str, default=self.cfg.backend_type)
        parser.add_argument("--chunk_size", type=int, default=self.cfg.chunk_size)

        # model configs
        # parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
        parser.add_argument("--paged", type=bool, default=self.cfg.paged)
        parser.add_argument("--total_context", type=int, default=self.cfg.total_context)
        parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size)
        parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens)
        parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode)
@@ -62,7 +60,6 @@ class ArgumentParser:
        parser.add_argument("--repetition_penalty", type=float, default=self.cfg.repetition_penalty)
        parser.add_argument("--frequency_penalty", type=float, default=self.cfg.frequency_penalty)
        parser.add_argument("--presence_penalty", type=float, default=self.cfg.presence_penalty)
        parser.add_argument("--max_response_tokens", type=int, default=self.cfg.max_response_tokens)
        parser.add_argument("--response_chunk", type=int, default=self.cfg.response_chunk)
        parser.add_argument("--no_code_formatting", type=bool, default=self.cfg.no_code_formatting)
        parser.add_argument("--cache_8bit", type=bool, default=self.cfg.cache_8bit)
@@ -73,6 +70,9 @@ class ArgumentParser:
        parser.add_argument("--batch_size", type=int, default=self.cfg.batch_size)
        parser.add_argument("--cache_lens", type=int, default=self.cfg.cache_lens)

        # kvc2 config
        parser.add_argument("--kvc2_config_dir", type=str, default=self.cfg.kvc2_config_dir)

        # log configs
        # log level: debug, info, warn, error, crit
        parser.add_argument("--log_dir", type=str, default=self.cfg.log_dir)
@@ -103,6 +103,18 @@ class ArgumentParser:
        # local chat
        parser.add_argument("--prompt_file", type=str, default=self.cfg.prompt_file)

        # async server
        parser.add_argument("--sched_strategy", type=str, default=self.cfg.sched_strategy)
        # parser.add_argument("--sched_port", type=int, default=self.cfg.sched_port)
        # parser.add_argument("--sched_metrics_port", type=int, default=self.cfg.sched_metrics_port)
        # parser.add_argument("--kvc2_metrics_port", type=int, default=self.cfg.kvc2_metrics_port)
        parser.add_argument("--page_size", type=str, default=self.cfg.page_size)
        parser.add_argument("--memory_gpu_only", type=str, default=self.cfg.memory_gpu_only)
        parser.add_argument("--utilization_percentage", type=str, default=self.cfg.utilization_percentage)
        parser.add_argument("--cpu_memory_size_GB", type=str, default=self.cfg.cpu_memory_size_GB)

        args = parser.parse_args()
        if (args.model_dir is not None or args.model_path is not None):
            if (args.model_path is not None):
@@ -123,6 +135,15 @@ class ArgumentParser:
        self.cfg.mount_web = args.web
        self.cfg.server_ip = args.host
        self.cfg.server_port = args.port
        self.cfg.backend_type = args.type
        self.cfg.user_force_think = args.force_think

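        # Presumably an estimate of the GPU KV-cache size in bytes: cache_lens tokens * 2 bytes (bf16)
        # * 576 (DeepSeek MLA latent width: 512 kv_lora_rank + 64 rope dims) * 61 layers. (Comment added for clarity, not in the original diff.)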
        args.gpu_memory_size = args.cache_lens*2*576*61
        self.cfg.gpu_memory_size = args.gpu_memory_size
        free_ports = get_free_ports(3, [args.port])
        args.sched_port = free_ports[0]
        args.sched_metrics_port = free_ports[1]
        args.kvc2_metrics_port = free_ports[2]
        self.cfg.sched_port = free_ports[0]
        self.cfg.sched_metrics_port = free_ports[1]
        self.cfg.kvc2_metrics_port = free_ports[2]
        return args
@@ -12,18 +12,10 @@ class ConfigArgs(BaseModel):
    class Config:
        protected_namespaces = ()

    paged: bool = Field(None, description="Whether to use paged attention kv cache")
    total_context: int = Field(
        None,
        description=(
            "Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the"
            " total to distribute dynamically over however many jobs are active at once"
        ),
    )
    max_batch_size: int = Field(
        None, description="Max number of batches to run at once, assuming the sequences will fit within total_context"
    )
    chunk_prefill_size: int = Field(
    chunk_size: int = Field(
        None,
        description=(
            "Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new"
@@ -70,7 +62,6 @@ class ConfigArgs(BaseModel):
    repetition_penalty: float = Field(None, description="Sampler repetition penalty, default = 1.01 (1 to disable)")
    frequency_penalty: float = Field(None, description="Sampler frequency penalty, default = 0.0 (0 to disable)")
    presence_penalty: float = Field(None, description="Sampler presence penalty, default = 0.0 (0 to disable)")
    max_response_tokens: int = Field(None, description="Max tokens per response, default = 1000")
    response_chunk: int = Field(None, description="Space to reserve in context for reply, default = 250")
    no_code_formatting: bool = Field(None, description="Disable code formatting/syntax highlighting")
    cache_8bit: bool = Field(None, description="Use 8-bit (FP8) cache")
@@ -9,9 +9,11 @@ from ktransformers.server.backend.interfaces.transformers import TransformersThr
from ktransformers.server.backend.interfaces.ktransformers import KTransformersThreadContext
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaThreadContext


from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
from ktransformers.server.backend.interfaces.transformers import TransformersInterface
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface

class ThreadContextManager:
    lock: Lock
    threads_context: Dict[ObjectID, ThreadContext]
@@ -36,7 +38,16 @@ class ThreadContextManager:
        elif isinstance(self.interface, TransformersInterface):
            new_context = TransformersThreadContext(run, self.interface)
        else:
            raise NotImplementedError
            from ktransformers.server.backend.interfaces.balance_serve import BalanceServeThreadContext
            from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface
            if isinstance(self.interface, BalanceServeInterface):
                new_context = BalanceServeThreadContext(run, self.interface)
            else:
                raise NotImplementedError
        # elif isinstance(self.interface, BalanceServeInterface):
        #     new_context = BalanceServeThreadContext(run, self.interface)
        # else:
        #     raise NotImplementedError
        self.threads_context[run.thread_id] = new_context
        # self.threads_context[run.thread_id] = ExllamaInferenceContext(run)
        re = self.threads_context[run.thread_id]
ktransformers/server/backend/interfaces/balance_serve.py (new file, 410 lines)
@@ -0,0 +1,410 @@
|
|||
from typing import Any, AsyncIterator, List, Optional, Set
|
||||
from ktransformers.models.custom_cache import KDeepSeekV3Cache
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoConfig,
|
||||
GenerationConfig,
|
||||
StaticCache,
|
||||
AutoModelForCausalLM,
|
||||
BitsAndBytesConfig,
|
||||
)
|
||||
|
||||
from ktransformers.server.config.config import Config
|
||||
from ..base import ThreadContext, BackendInterfaceBase
|
||||
import torch
|
||||
from ktransformers.server.backend.interfaces.transformers import (
|
||||
ConfigArgs,
|
||||
default_args,
|
||||
TextStreamer,
|
||||
)
|
||||
from ktransformers.server.schemas.base import ObjectID
|
||||
from ktransformers.server.config.log import logger
|
||||
from ktransformers.optimize.optimize import optimize_and_load_gguf
|
||||
from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
|
||||
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
|
||||
from ktransformers.server.balance_serve.inference.model_runner import ModelRunner
|
||||
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
|
||||
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
|
||||
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
|
||||
from ktransformers.server.balance_serve.sched_rpc import SchedulerClient
|
||||
from ktransformers.server.balance_serve.settings import sched_ext
|
||||
from torch.multiprocessing import Queue
|
||||
import torch.multiprocessing as mp
|
||||
from ktransformers.server.schemas.endpoints.chat import RawUsage
|
||||
from ktransformers.server.utils.multi_timer import Profiler
|
||||
import zmq
|
||||
import time
|
||||
import queue
|
||||
import tempfile
|
||||
import asyncio
|
||||
import threading
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, Request
|
||||
import os
|
||||
|
||||
|
||||
|
||||
ktransformer_rules_dir = (
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
|
||||
)
|
||||
default_optimize_rules = {
|
||||
"DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
|
||||
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml",
|
||||
}
|
||||
|
||||
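# Async generator: drains the per-query token queue and yields decoded text pieces until a None end-of-stream sentinel arrives.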
async def chat_stream(queue: asyncio.Queue, tokenizer: AutoTokenizer):
|
||||
streamer = TextStreamer(tokenizer)
|
||||
while True:
|
||||
token = await queue.get()
|
||||
#print(f"Got token: {token}")
|
||||
if token is None:
|
||||
# str = f'{token}\n\n'
|
||||
# str = model.tokenizer.decode(token)
|
||||
s = streamer.end()
|
||||
if s is not None:
|
||||
yield s
|
||||
break
|
||||
|
||||
# str = model.tokenizer.decode(token)
|
||||
yield streamer.put(token)
|
||||
|
||||
|
||||
|
||||
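# Copies each newly sampled token into its query's update record and, for decode-phase queries, back into that query's token buffer.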
def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_tokens: torch.Tensor, query_manager: QueryManager = None):
|
||||
#print(len(query_updates), generated_tokens.size(0), generated_tokens)
|
||||
for i in range(generated_tokens.size(0)):
|
||||
print(generated_tokens[i].item())
|
||||
query_updates[i].generated_token = generated_tokens[i].item()
|
||||
if not query_manager.query_map[query_updates[i].id].is_prefill:
|
||||
pos = query_updates[i].active_position
|
||||
query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]
|
||||
|
||||
def report_last_time_performance(profiler: Profiler):
|
||||
try:
|
||||
tokenize_time = profiler.get_timer_sec('tokenize')
|
||||
prefill_time = profiler.get_timer_sec('prefill')
|
||||
decode_time = profiler.get_timer_sec('decode')
|
||||
prefill_count = profiler.get_counter('prefill')
|
||||
decode_count = profiler.get_counter('decode')
|
||||
|
||||
logger.info(f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}')
|
||||
except:
|
||||
logger.info(f'Performance statistics not recorded')
|
||||
|
||||
class Engine:
|
||||
sched_client : SchedulerClient
|
||||
updates : list[sched_ext.QueryUpdate]
|
||||
batch : sched_ext.BatchQueryTodo
|
||||
model_runner: ModelRunner
|
||||
sampler: Sampler
|
||||
query_manager: QueryManager
|
||||
cache: KDeepSeekV3Cache
|
||||
def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None):
|
||||
self.args = args
|
||||
|
||||
# The child process and the parent process cannot share the Config variables
|
||||
for key, value in vars(args).items():
|
||||
if value is not None and hasattr(Config(), key):
|
||||
setattr(Config(), key, value)
|
||||
|
||||
self.device = self.args.device
|
||||
self.sched_client = SchedulerClient(args.sched_port)
|
||||
self.updates = []
|
||||
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
|
||||
self.cache = KDeepSeekV3Cache(config, self.args.page_size)
|
||||
|
||||
self.gen_queue = generated_token_queue
|
||||
|
||||
print(f"Getting inference context from sched_client.")
|
||||
inference_context = self.sched_client.get_inference_context_raw()
|
||||
print(f"Got inference context, sending it to subscribers.")
|
||||
inference_context = self.sched_client.rebuild_inferece_context(inference_context)
|
||||
self.cache.load(inference_context)
|
||||
print(f"kv_cache loaded successfully.")
|
||||
|
||||
self.block_num = inference_context.k_cache[0].size(1)
|
||||
with torch.device("meta"):
|
||||
if config.architectures[0] == "DeepseekV3ForCausalLM":
|
||||
self.model = KDeepseekV3ForCausalLM(config, self.cache)
|
||||
elif config.architectures[0] == "DeepseekV2ForCausalLM":
|
||||
self.model = KDeepseekV2ForCausalLM(config, self.cache)
|
||||
# print(self.block_num)
|
||||
|
||||
context = zmq.Context()
|
||||
|
||||
|
||||
self.pub_socket = context.socket(zmq.PUB)
|
||||
self.pub_socket.bind(f"ipc://{broadcast_endpoint}")
|
||||
# time.sleep(1) # make sure all subscribers are ready
|
||||
|
||||
|
||||
try:
|
||||
generation_config = GenerationConfig.from_pretrained(args.model_dir)
|
||||
except:
|
||||
generation_config = GenerationConfig(
|
||||
max_length=args.max_new_tokens,
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p,
|
||||
do_sample=True
|
||||
)
|
||||
|
||||
if args.optimize_config_path is None:
|
||||
optimize_config_path = default_optimize_rules[config.architectures[0]]
|
||||
|
||||
else:
|
||||
optimize_config_path = args.optimize_config_path
|
||||
gguf_path = args.gguf_path
|
||||
if gguf_path is None:
|
||||
gguf_path = input(
|
||||
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
|
||||
" belong to current model):"
|
||||
)
|
||||
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
|
||||
self.model.generation_config = generation_config
|
||||
if self.model.generation_config.pad_token_id is None:
|
||||
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
|
||||
|
||||
self.model.eval()
|
||||
#@TODO add config
|
||||
self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
|
||||
|
||||
self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size)
|
||||
self.sampler = Sampler()
|
||||
self.query_manager = QueryManager(device = self.device, page_size = args.page_size)
|
||||
|
||||
|
||||
def sampling(self, forward_output: ForwardBatchOutput):
|
||||
generated_tokens = torch.empty(0, device=self.device, dtype=torch.int32)
|
||||
for i in range(forward_output.num_batchs):
|
||||
logit = forward_output.logits[i]
|
||||
if hasattr(forward_output, "temperatures"):
|
||||
temperatures = forward_output.temperatures[i]
|
||||
else:
|
||||
temperatures = None
|
||||
|
||||
if hasattr(forward_output, "top_ps"):
|
||||
top_ps = forward_output.top_ps[i]
|
||||
else:
|
||||
top_ps = None
|
||||
|
||||
sample_options = SamplingOptions(logit.size(0), self.device, pretrained_config=self.model.generation_config, temperatures=temperatures, top_ps=top_ps)
|
||||
generated_tokens, probs=self.sampler(logit, sample_options)
|
||||
return generated_tokens, probs
|
||||
|
||||
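# Main serving loop: run the current batch, stream generated tokens to the queue, report query updates to the scheduler, and broadcast the next batch to workers.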
def loop(self):
|
||||
|
||||
next_batch = None
|
||||
|
||||
while True:
|
||||
self.batch = next_batch
|
||||
if self.batch is not None:
|
||||
self.model_runner.run(self.batch, self.query_manager)
|
||||
|
||||
if len(self.updates) > 0:
|
||||
for q in self.updates:
|
||||
if q.is_prefill == True:
|
||||
continue
|
||||
# print(f"Putting token {q.generated_token} into queue for query id: {q.id}")
|
||||
try:
|
||||
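# A None token is the end-of-stream sentinel consumed by chat_stream() once decoding for this query is done.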
self.gen_queue.put((q.id, q.generated_token if q.decode_done == False else None), timeout=5)
|
||||
except queue.Full:
|
||||
pass#print("Queue is full after timeout; unable to put more items.")
|
||||
|
||||
next_batch = self.sched_client.update_last_batch(self.updates)
|
||||
if next_batch.query_ids == []:
|
||||
next_batch = None
|
||||
self.pub_socket.send_pyobj(next_batch)
|
||||
|
||||
if next_batch is not None:
|
||||
self.query_manager.add_query(next_batch)
|
||||
|
||||
|
||||
if self.batch is not None:
|
||||
self.model_runner.sync()
|
||||
print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms")
|
||||
# if self.rank == 0:
|
||||
|
||||
generated_tokens, probs = self.sampling( self.model_runner.output)
|
||||
|
||||
self.updates = self.query_manager.update(self.batch)
|
||||
fill_generated_tokens(self.updates, generated_tokens, self.query_manager)
|
||||
else:
|
||||
self.updates = []
|
||||
|
||||
class BalanceServeThreadContext(ThreadContext):
|
||||
def get_local_messages(self):
|
||||
local_messages = []
|
||||
for m in self.messages:
|
||||
local_messages.append({"role": m.role.value, "content": m.get_text_content()})
|
||||
|
||||
return local_messages
|
||||
|
||||
|
||||
def run_engine(args, token_queue, broadcast_endpoint, event):
|
||||
engine = Engine(args, token_queue, broadcast_endpoint)
|
||||
if args.use_cuda_graph:
|
||||
engine.model_runner.warmup()
|
||||
|
||||
event.set()
|
||||
engine.loop()
|
||||
|
||||
|
||||
class BalanceServeInterface(BackendInterfaceBase):
|
||||
use_static_cache: bool = True
|
||||
|
||||
model: Any
|
||||
tokenizer: AutoTokenizer
|
||||
|
||||
cache: StaticCache
|
||||
generated_ids: torch.Tensor
|
||||
seq_length: int
|
||||
|
||||
streamer: TextStreamer
|
||||
|
||||
# thread_related
|
||||
last_request_id: Optional[str] = None
|
||||
ever_generated_ids: Set[int] = set()
|
||||
def __init__(self, args: ConfigArgs = default_args):
|
||||
self.args = args
|
||||
self.queue_map:dict[int,asyncio.Queue] = {}
|
||||
self.thread_map: dict[int, int] = {}
|
||||
processes = []
|
||||
self.broadcast_endpoint = tempfile.NamedTemporaryFile(delete=False).name # @TODO add to config
|
||||
ctx = mp.get_context("spawn")
|
||||
self.token_queue = ctx.Queue(maxsize=1000)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
|
||||
self.sched_client = SchedulerClient(args.sched_port)
|
||||
self.streamer = TextStreamer(self.tokenizer)
|
||||
|
||||
start_event = ctx.Event()
|
||||
|
||||
p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event))
|
||||
p.start()
|
||||
processes.append(p)
|
||||
start_event.wait()
|
||||
|
||||
def run_queue_proxy(self):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(self.queue_proxy())
|
||||
|
||||
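# FastAPI lifespan hook: starts the queue-proxy task that forwards tokens from the engine process into the per-query asyncio queues.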
@asynccontextmanager
|
||||
async def lifespan(self, app: FastAPI):
|
||||
asyncio.create_task(self.queue_proxy())
|
||||
yield
|
||||
|
||||
async def queue_proxy(self):
|
||||
print("Queue Proxy Started")
|
||||
while True:
|
||||
try:
|
||||
query_id, token = self.token_queue.get_nowait()
|
||||
try:
|
||||
# query id might not be allocated yet
|
||||
self.queue_map[query_id].put_nowait(token)
|
||||
#print(f"Proxy Put token: {token} to queue for query id: {query_id}")
|
||||
except asyncio.QueueFull:
|
||||
#print(f"Queue for query id: {query_id} is full, waiting to put: {token}")
|
||||
await self.queue_map[query_id].put(token)
|
||||
|
||||
except queue.Empty:
|
||||
# print("no new token")
|
||||
# await asyncio.sleep(1)
|
||||
await asyncio.sleep(0)
|
||||
def tokenize_prompt(self, prompt: str):
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device)
|
||||
return input_ids
|
||||
|
||||
def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
|
||||
for m in messages:
|
||||
if m["role"] == "system":
|
||||
logger.warning(f'change {m["role"]} to user')
|
||||
m["role"] = "user"
|
||||
|
||||
new_messages = [messages[0]]
|
||||
for m in messages[1:]:
|
||||
if m["role"] == "user" and new_messages[-1]["role"] == "user":
|
||||
logger.warning("merge two adjacent user messages")
|
||||
new_messages[-1]["content"] += '\n' + m["content"]
|
||||
else:
|
||||
new_messages.append(m)
|
||||
input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
|
||||
# drop <think> token in chat template
|
||||
if input_str.endswith('<think>\n'):
|
||||
input_str = input_str[:-len('<think>\n')]
|
||||
input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
|
||||
logger.debug(f"get input ids of shape {input_ids.shape}")
|
||||
return input_ids
|
||||
|
||||
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
|
||||
profiler = Profiler()
|
||||
profiler.create_and_start_timer("tokenize")
|
||||
|
||||
if isinstance(local_messages, List):
|
||||
input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
|
||||
elif isinstance(local_messages, str):
|
||||
#local_messages = local_messages[0]['content']
|
||||
input_ids = self.tokenize_prompt(local_messages)
|
||||
else:
|
||||
raise ValueError("local_messages should be List or str")
|
||||
if Config().user_force_think:
|
||||
token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
|
||||
input_ids = torch.cat(
|
||||
[input_ids, token_thinks], dim=1
|
||||
)
|
||||
|
||||
|
||||
profiler.pause_timer("tokenize")
|
||||
|
||||
profiler.create_and_start_timer("prefill")
|
||||
|
||||
|
||||
|
||||
query_add = sched_ext.QueryAdd()
|
||||
query_add.query_token = input_ids[0].tolist()
|
||||
query_length = input_ids[0].shape[0]
|
||||
query_add.query_length = query_length
|
||||
profiler.set_counter("prefill", query_length)
|
||||
#@TODO add server
|
||||
stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")]
|
||||
query_add.stop_criteria = stop_criteria
|
||||
if temperature == 0:
|
||||
temperature = 0.0001
|
||||
query_add.sample_options.temperature = temperature
|
||||
if top_p == 0:
|
||||
top_p = 0.0001
|
||||
query_add.sample_options.top_p = top_p
|
||||
query_add.estimated_length = min(self.args.cache_lens, query_length+self.args.max_new_tokens)
|
||||
query_id = self.sched_client.add_query(query_add)
|
||||
queue = asyncio.Queue(maxsize=self.args.max_new_tokens)
|
||||
self.queue_map[query_id] = queue
|
||||
self.thread_map[thread_id] = query_id
|
||||
is_first_token = True
|
||||
async for token in chat_stream(self.queue_map[query_id], self.tokenizer):
|
||||
if is_first_token:
|
||||
is_first_token=False
|
||||
profiler.pause_timer("prefill")
|
||||
profiler.create_and_start_timer("decode")
|
||||
profiler.set_counter("decode", 0)
|
||||
if Config().user_force_think:
|
||||
think = '<think>\n'
|
||||
print(think, end="",flush=True)
|
||||
yield think, None
|
||||
else:
|
||||
profiler.inc("decode")
|
||||
yield token, None
|
||||
profiler.pause_timer("decode")
|
||||
report_last_time_performance(profiler)
|
||||
yield self.streamer.end(), None
|
||||
if profiler.get_counter('decode') >= self.args.max_new_tokens - 1:
|
||||
yield "", "length"
|
||||
else:
|
||||
yield "", "stop"
|
||||
|
||||
|
||||
yield RawUsage(
|
||||
tokenize_time = profiler.get_timer_sec('tokenize'),
|
||||
prefill_time = profiler.get_timer_sec('prefill'),
|
||||
decode_time = profiler.get_timer_sec('decode'),
|
||||
prefill_count = profiler.get_counter('prefill'),
|
||||
decode_count = profiler.get_counter('decode'),
|
||||
)
|
|
@@ -211,11 +211,11 @@ class KTransformersInterface(TransformersInterface):
|
|||
|
||||
chunk_start = 0
|
||||
while chunk_start < input_ids_length:
|
||||
chunk_end = min(chunk_start + self.args.chunk_prefill_size, input_ids_length)
|
||||
chunk_end = min(chunk_start + self.args.chunk_size, input_ids_length)
|
||||
if self.cache != None:
|
||||
self.cache.cur_idx=cache_position[chunk_start:chunk_end]
|
||||
logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
|
||||
chunk_start += self.args.chunk_prefill_size
|
||||
chunk_start += self.args.chunk_size
|
||||
|
||||
if flashinfer_enabled:
|
||||
MLAWrapperSingleton.reset_buffer()
|
||||
|
|
|
@@ -208,6 +208,8 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
temperature = self.model.generation_config.temperature
|
||||
if top_p is None:
|
||||
top_p = self.model.generation_config.top_p
|
||||
if top_p == 0:
|
||||
top_p = 0.0001
|
||||
generation_config, model_kwargs = self.model._prepare_generation_config(
|
||||
None, max_length=self.args.max_new_tokens,
|
||||
do_sample=True,
|
||||
|
@@ -341,7 +343,7 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
for i in range(1, self.max_new_tokens):
|
||||
with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
|
||||
if flashinfer_enabled:
|
||||
MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
|
||||
MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1, None,
|
||||
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
|
||||
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
|
||||
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
|
||||
|
|
ktransformers/server/balance_serve/inference/__init__.py (new file, 0 lines)
ktransformers/server/balance_serve/inference/config.py (new file, 142 lines)
@@ -0,0 +1,142 @@
|
|||
'''
|
||||
Date: 2024-11-07 07:30:16
|
||||
LastEditors: djw
|
||||
LastEditTime: 2024-11-15 14:23:26
|
||||
'''
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
import yaml
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
class ModelConfig:
|
||||
vocab_size: int = 32000
|
||||
n_layer: int = 1
|
||||
n_head: int = 32
|
||||
dim: int = 4096
|
||||
intermediate_size: int = 18944
|
||||
n_local_heads: int = 8
|
||||
head_dim: int = 128
|
||||
rope_base: float = 1000000.0
|
||||
norm_eps: float = 1e-06
|
||||
rope_scaling: Optional[dict] = None
|
||||
rms_norm_eps: float = 1e-6
|
||||
hidden_act: str = "silu"
|
||||
model_path: str
|
||||
gguf_path: str
|
||||
optimize_rule_path: str
|
||||
speculative_rule_path: str
|
||||
|
||||
|
||||
# quantize config
|
||||
quant_algorithm: Optional[str] = None
|
||||
quant_group_size: Optional[int] = None
|
||||
quant_num_bits: Optional[int] = None
|
||||
|
||||
json_key_map = {
|
||||
"vocab_size": "vocab_size",
|
||||
"n_layer": "num_hidden_layers",
|
||||
"n_head": "num_attention_heads",
|
||||
"dim": "hidden_size",
|
||||
"intermediate_size": "intermediate_size",
|
||||
"n_local_heads": "num_key_value_heads",
|
||||
"rope_base": "rope_theta",
|
||||
"norm_eps": "norm_eps",
|
||||
"rms_norm_eps": "rms_norm_eps",
|
||||
"hidden_act": "hidden_act",
|
||||
}
|
||||
|
||||
def __init__(self, config):
|
||||
self.model_path = config["model"]["model_path"]
|
||||
self.gguf_path = config["model"]["gguf_path"]
|
||||
self.optimize_rule_path = config["model"]["optimize_rule_path"]
|
||||
if "speculative_rule_path" in config["model"]:
|
||||
self.speculative_rule_path = config["model"]["speculative_rule_path"]
|
||||
self.speculative_gguf_path = config["model"]["speculative_gguf_path"]
|
||||
self.speculative_model_path = config["model"]["speculative_model_path"]
|
||||
self.quant_algorithm = config["model"]["quant"]["algorithm"]
|
||||
self.quant_group_size = config["model"]["quant"]["group_size"]
|
||||
self.quant_num_bits = config["model"]["quant"]["num_bits"]
|
||||
self.load_config()
|
||||
self.n_layer = config["model"]["n_layers"]
|
||||
|
||||
def load_config(self):
|
||||
config_file = f"{self.model_path}/config.json"
|
||||
try:
|
||||
with open(config_file, "r") as f:
|
||||
config_data = json.load(f)
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(f"Configuration file not found at {config_file}")
|
||||
|
||||
for attr, json_key in self.json_key_map.items():
|
||||
if json_key in config_data:
|
||||
setattr(self, attr, config_data[json_key])
|
||||
else:
|
||||
setattr(self, attr, getattr(self, attr))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class ParallelConfig:
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
) -> None:
|
||||
self.pipeline_parallel_size = config["parallel"]["pp"]
|
||||
self.tensor_parallel_size = config["parallel"]["tp"]
|
||||
self.disable_custom_all_reduce = config["parallel"]["disable_custom_all_reduce"]
|
||||
self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size
|
||||
|
||||
class AttnConfig:
|
||||
page_size: int = 256
|
||||
block_num: int = 32
|
||||
max_batch_token : int = 256
|
||||
max_batch_size: int = 32
|
||||
|
||||
def __init__(self, config):
|
||||
self.page_size = config["attn"]["page_size"]
|
||||
self.block_num = config["attn"]["block_num"]
|
||||
self.max_batch_token = config["attn"]["max_batch_token"]
|
||||
self.max_batch_size = config["attn"]["max_batch_size"]
|
||||
|
||||
|
||||
class SamplerConfig():
|
||||
# Batched sampling params
|
||||
temperatures: float
|
||||
is_all_greedy: bool
|
||||
|
||||
def __init__(self, config):
|
||||
self.temperatures = config["sample"]["temperature"]
|
||||
self.is_all_greedy = True
|
||||
|
||||
|
||||
def load_yaml_config(file_path):
|
||||
with open(file_path, "r") as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
|
||||
|
||||
|
||||
class LLMConfig:
|
||||
model_config: ModelConfig
|
||||
parallel_config: ParallelConfig
|
||||
attn_config: AttnConfig
|
||||
sample_config: SamplerConfig
|
||||
config_file: str
|
||||
|
||||
def __init__(self, config_file):
|
||||
self.config_file = config_file
|
||||
config = load_yaml_config(config_file)
|
||||
self.model_config = ModelConfig(config)
|
||||
self.parallel_config = ParallelConfig(config)
|
||||
self.attn_config = AttnConfig(config)
|
||||
self.sample_config = SamplerConfig(config)
|
||||
|
|
@@ -0,0 +1,3 @@
|
|||
from .communication_op import *
|
||||
from .parallel_state import *
|
||||
from .utils import *
|
|
@@ -0,0 +1,39 @@
|
|||
"""
|
||||
Date: 2024-12-11 06:02:42
|
||||
LastEditors: djw
|
||||
LastEditTime: 2024-12-12 09:52:06
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from .parallel_state import get_tp_group
|
||||
|
||||
|
||||
def tensor_model_parallel_all_reduce(input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> torch.Tensor:
|
||||
"""All-reduce the input tensor across model parallel group."""
|
||||
return get_tp_group().all_reduce(input_, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap)
|
||||
|
||||
|
||||
def tensor_model_parallel_all_gather(
|
||||
input_: torch.Tensor, dim: int = -1
|
||||
) -> torch.Tensor:
|
||||
"""All-gather the input tensor across model parallel group."""
|
||||
return get_tp_group().all_gather(input_, dim)
|
||||
|
||||
|
||||
def tensor_model_parallel_gather(
|
||||
input_: torch.Tensor, dst: int = 0, dim: int = -1
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""Gather the input tensor across model parallel group."""
|
||||
return get_tp_group().gather(input_, dst, dim)
|
||||
|
||||
|
||||
def broadcast_tensor_dict(
|
||||
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0
|
||||
):
|
||||
if not torch.distributed.is_initialized():
|
||||
return tensor_dict
|
||||
return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
|
|
@@ -0,0 +1,168 @@
|
|||
"""This file is a pure Python wrapper for the cudart library.
|
||||
It avoids the need to compile a separate shared library, and is
|
||||
convenient for use when we just need to call a few functions.
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# this line makes it possible to directly load `libcudart.so` using `ctypes`
|
||||
import torch # noqa
|
||||
|
||||
# === export types and functions from cudart to Python ===
|
||||
# for the original cudart definition, please check
|
||||
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
|
||||
|
||||
cudaError_t = ctypes.c_int
|
||||
cudaMemcpyKind = ctypes.c_int
|
||||
|
||||
|
||||
class cudaIpcMemHandle_t(ctypes.Structure):
|
||||
_fields_ = [("internal", ctypes.c_byte * 128)]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: List[Any]
|
||||
|
||||
|
||||
def find_loaded_library(lib_name) -> Optional[str]:
|
||||
"""
|
||||
According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
|
||||
the file `/proc/self/maps` contains the memory maps of the process, which includes the
|
||||
shared libraries loaded by the process. We can use this file to find the path of
a loaded library.
|
||||
""" # noqa
|
||||
found = False
|
||||
with open("/proc/self/maps") as f:
|
||||
for line in f:
|
||||
if lib_name in line:
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
# the library is not loaded in the current process
|
||||
return None
|
||||
# if lib_name is libcudart, we need to match a line with:
|
||||
# address /path/to/libcudart-hash.so.11.0
|
||||
start = line.index("/")
|
||||
path = line[start:].strip()
|
||||
filename = path.split("/")[-1]
|
||||
assert filename.rpartition(".so")[0].startswith(lib_name), \
|
||||
f"Unexpected filename: {filename} for library {lib_name}"
|
||||
return path
|
||||
|
||||
|
||||
class CudaRTLibrary:
|
||||
exported_functions = [
|
||||
# cudaError_t cudaSetDevice ( int device )
|
||||
Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
|
||||
# cudaError_t cudaDeviceSynchronize ( void )
|
||||
Function("cudaDeviceSynchronize", cudaError_t, []),
|
||||
# cudaError_t cudaDeviceReset ( void )
|
||||
Function("cudaDeviceReset", cudaError_t, []),
|
||||
|
||||
# const char* cudaGetErrorString ( cudaError_t error )
|
||||
Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
|
||||
|
||||
# cudaError_t cudaMalloc ( void** devPtr, size_t size )
|
||||
Function("cudaMalloc", cudaError_t,
|
||||
[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
|
||||
# cudaError_t cudaFree ( void* devPtr )
|
||||
Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
|
||||
# cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
|
||||
Function("cudaMemset", cudaError_t,
|
||||
[ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
|
||||
# cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
|
||||
Function("cudaMemcpy", cudaError_t, [
|
||||
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind
|
||||
]),
|
||||
|
||||
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
|
||||
Function("cudaIpcGetMemHandle", cudaError_t,
|
||||
[ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
|
||||
# cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
|
||||
Function("cudaIpcOpenMemHandle", cudaError_t, [
|
||||
ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint
|
||||
]),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: Dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the corresponding dictionary
|
||||
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
if so_file is None:
|
||||
so_file = find_loaded_library("libcudart")
|
||||
assert so_file is not None, \
|
||||
"libcudart is not loaded in the current process"
|
||||
if so_file not in CudaRTLibrary.path_to_library_cache:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
CudaRTLibrary.path_to_library_cache[so_file] = lib
|
||||
self.lib = CudaRTLibrary.path_to_library_cache[so_file]
|
||||
|
||||
if so_file not in CudaRTLibrary.path_to_dict_mapping:
|
||||
_funcs = {}
|
||||
for func in CudaRTLibrary.exported_functions:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
f.argtypes = func.argtypes
|
||||
_funcs[func.name] = f
|
||||
CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
|
||||
self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]
|
||||
|
||||
def CUDART_CHECK(self, result: cudaError_t) -> None:
|
||||
if result != 0:
|
||||
error_str = self.cudaGetErrorString(result)
|
||||
raise RuntimeError(f"CUDART error: {error_str}")
|
||||
|
||||
def cudaGetErrorString(self, error: cudaError_t) -> str:
|
||||
return self.funcs["cudaGetErrorString"](error).decode("utf-8")
|
||||
|
||||
def cudaSetDevice(self, device: int) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))
|
||||
|
||||
def cudaDeviceSynchronize(self) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())
|
||||
|
||||
def cudaDeviceReset(self) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())
|
||||
|
||||
def cudaMalloc(self, size: int) -> ctypes.c_void_p:
|
||||
devPtr = ctypes.c_void_p()
|
||||
self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
|
||||
return devPtr
|
||||
|
||||
def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))
|
||||
|
||||
def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
|
||||
count: int) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))
|
||||
|
||||
def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
|
||||
count: int) -> None:
|
||||
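# cudaMemcpyDefault (value 4 in the CUDA runtime API) lets the driver infer the copy direction from the pointer types.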
cudaMemcpyDefault = 4
|
||||
kind = cudaMemcpyDefault
|
||||
self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))
|
||||
|
||||
def cudaIpcGetMemHandle(self,
|
||||
devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
|
||||
handle = cudaIpcMemHandle_t()
|
||||
self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](
|
||||
ctypes.byref(handle), devPtr))
|
||||
return handle
|
||||
|
||||
def cudaIpcOpenMemHandle(self,
|
||||
handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
|
||||
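# cudaIpcMemLazyEnablePeerAccess (value 1 in the CUDA runtime API) enables lazy peer access when opening the IPC handle.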
cudaIpcMemLazyEnablePeerAccess = 1
|
||||
devPtr = ctypes.c_void_p()
|
||||
self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"](
|
||||
ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))
|
||||
return devPtr
|
|
@@ -0,0 +1,310 @@
|
|||
import ctypes
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
import server.envs as envs
|
||||
from server.inference.distributed.cuda_wrapper import CudaRTLibrary
|
||||
from server.inference.distributed.custom_all_reduce_utils import gpu_p2p_access_check
|
||||
from server.inference.distributed.parallel_state import in_the_same_node_as
|
||||
from server.inference.platforms import current_platform
|
||||
from server.utils import cuda_device_count_stateless
|
||||
import vLLMCustomAllreduce
|
||||
|
||||
try:
|
||||
vLLMCustomAllreduce.meta_size()
|
||||
custom_ar = True
|
||||
except Exception:
|
||||
# For AMD GPUs and CPUs
|
||||
custom_ar = False
|
||||
|
||||
|
||||
def _can_p2p(rank: int, world_size: int) -> bool:
|
||||
for i in range(world_size):
|
||||
if i == rank:
|
||||
continue
|
||||
if envs.VLLM_SKIP_P2P_CHECK:
|
||||
print("Skipping P2P check and trusting the driver's P2P report.")
|
||||
return torch.cuda.can_device_access_peer(rank, i)
|
||||
if not gpu_p2p_access_check(rank, i):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_weak_contiguous(inp: torch.Tensor):
|
||||
return inp.is_contiguous() or (
|
||||
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size()
|
||||
)
|
||||
|
||||
|
||||
class CustomAllreduce:
|
||||
|
||||
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
||||
|
||||
# max_size: max supported allreduce size
|
||||
def __init__(
|
||||
self,
|
||||
group: ProcessGroup,
|
||||
device: Union[int, str, torch.device],
|
||||
max_size=8192 * 1024,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the CustomAllreduce to. If None,
|
||||
it will be bound to f"cuda:{local_rank}".
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bound to a unique device, and all communicators in this group
|
||||
are in the same node.
|
||||
"""
|
||||
self._IS_CAPTURING = False
|
||||
self.disabled = True
|
||||
|
||||
if not custom_ar:
|
||||
# disable because of missing custom allreduce library
|
||||
# e.g. in a non-cuda environment
|
||||
return
|
||||
|
||||
self.group = group
|
||||
|
||||
assert (
|
||||
dist.get_backend(group) != dist.Backend.NCCL
|
||||
), "CustomAllreduce should be attached to a non-NCCL group."
|
||||
|
||||
if not all(in_the_same_node_as(group, source_rank=0)):
|
||||
# No need to initialize custom allreduce for multi-node case.
|
||||
print(
|
||||
"Custom allreduce is disabled because this process group"
|
||||
" spans across nodes."
|
||||
)
|
||||
return
|
||||
|
||||
rank = dist.get_rank(group=self.group)
|
||||
world_size = dist.get_world_size(group=self.group)
|
||||
if world_size == 1:
|
||||
# No need to initialize custom allreduce for single GPU case.
|
||||
return
|
||||
|
||||
if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
|
||||
print(
|
||||
"Custom allreduce is disabled due to an unsupported world"
|
||||
" size: %d. Supported world sizes: %s. To silence this "
|
||||
"warning, specify disable_custom_all_reduce=True explicitly.",
|
||||
world_size,
|
||||
str(CustomAllreduce._SUPPORTED_WORLD_SIZES),
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"cuda:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
# now `device` is a `torch.device` object
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
|
||||
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||
if cuda_visible_devices:
|
||||
device_ids = list(map(int, cuda_visible_devices.split(",")))
|
||||
else:
|
||||
device_ids = list(range(cuda_device_count_stateless()))
|
||||
|
||||
physical_device_id = device_ids[device.index]
|
||||
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
|
||||
gather_list = [
|
||||
torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size)
|
||||
]
|
||||
dist.all_gather(gather_list, tensor, group=self.group)
|
||||
physical_device_ids = [t.item() for t in gather_list]
|
||||
|
||||
# test nvlink first, this will filter out most of the cases
|
||||
# where custom allreduce is not supported
|
||||
# this checks hardware and driver support for NVLink
|
||||
assert current_platform.is_cuda()
|
||||
from server.inference.platforms.cuda import CudaPlatform
|
||||
|
||||
cuda_platform: CudaPlatform = current_platform
|
||||
full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids)
|
||||
if world_size > 2 and not full_nvlink:
|
||||
print(
|
||||
"Custom allreduce is disabled because it's not supported on"
|
||||
" more than two PCIe-only GPUs. To silence this warning, "
|
||||
"specify disable_custom_all_reduce=True explicitly."
|
||||
)
|
||||
return
|
||||
# test P2P capability, this checks software/cudaruntime support
|
||||
# this is expensive to compute at the first time
|
||||
# then we cache the result
|
||||
if not _can_p2p(rank, world_size):
|
||||
print(
|
||||
"Custom allreduce is disabled because your platform lacks "
|
||||
"GPU P2P capability or P2P test failed. To silence this "
|
||||
"warning, specify disable_custom_all_reduce=True explicitly."
|
||||
)
|
||||
return
|
||||
|
||||
self.disabled = False
|
||||
# Buffers memory are owned by this Python class and passed to C++.
|
||||
# Meta data composes of two parts: meta data for synchronization and a
|
||||
# temporary buffer for storing intermediate allreduce results.
|
||||
self.meta_ptrs = self.create_shared_buffer(
|
||||
vLLMCustomAllreduce.meta_size() + max_size, group=group
|
||||
)
|
||||
# This is a pre-registered IPC buffer. In eager mode, input tensors
|
||||
# are first copied into this buffer before allreduce is performed
|
||||
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
|
||||
# This is a buffer for storing the tuples of pointers pointing to
|
||||
# IPC buffers from all ranks. Each registered tuple has size of
|
||||
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
|
||||
# is enough for 131072 such tuples. The largest model I've seen only
|
||||
# needs less than 10000 of registered tuples.
|
||||
self.rank_data = torch.empty(
|
||||
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
|
||||
)
|
||||
self.max_size = max_size
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.full_nvlink = full_nvlink
|
||||
self._ptr = vLLMCustomAllreduce.init_custom_ar(
|
||||
self.meta_ptrs, self.rank_data, rank, self.full_nvlink
|
||||
)
|
||||
vLLMCustomAllreduce.register_buffer(self._ptr, self.buffer_ptrs)
|
||||
|
||||
@staticmethod
|
||||
def create_shared_buffer(
|
||||
size_in_bytes: int, group: Optional[ProcessGroup] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Creates a shared buffer and returns a list of pointers
|
||||
representing the buffer on all processes in the group.
|
||||
"""
|
||||
lib = CudaRTLibrary()
|
||||
pointer = lib.cudaMalloc(size_in_bytes)
|
||||
handle = lib.cudaIpcGetMemHandle(pointer)
|
||||
world_size = dist.get_world_size(group=group)
|
||||
rank = dist.get_rank(group=group)
|
||||
handles = [None] * world_size
|
||||
dist.all_gather_object(handles, handle, group=group)
|
||||
|
||||
pointers: List[int] = []
|
||||
for i, h in enumerate(handles):
|
||||
if i == rank:
|
||||
pointers.append(pointer.value) # type: ignore
|
||||
else:
|
||||
pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore
|
||||
|
||||
return pointers
|
||||
|
||||
@staticmethod
|
||||
def free_shared_buffer(
|
||||
pointers: List[int], group: Optional[ProcessGroup] = None
|
||||
) -> None:
|
||||
rank = dist.get_rank(group=group)
|
||||
lib = CudaRTLibrary()
|
||||
lib.cudaFree(ctypes.c_void_p(pointers[rank]))
|
||||
|
||||
@contextmanager
|
||||
def capture(self):
|
||||
"""
|
||||
The main responsibility of this context manager is the
|
||||
`register_graph_buffers` call at the end of the context.
|
||||
It records all the buffer addresses used in the CUDA graph.
|
||||
"""
|
||||
try:
|
||||
self._IS_CAPTURING = True
|
||||
yield
|
||||
finally:
|
||||
self._IS_CAPTURING = False
|
||||
if not self.disabled:
|
||||
self.register_graph_buffers()
|
||||
|
||||
def register_graph_buffers(self):
|
||||
handle, offset = vLLMCustomAllreduce.get_graph_buffer_ipc_meta(self._ptr)
|
||||
print("Registering %d cuda graph addresses", len(offset))
|
||||
# We cannot directly use `dist.all_gather_object` here
|
||||
# because it is incompatible with `gloo` backend under inference mode.
|
||||
# see https://github.com/pytorch/pytorch/issues/126032 for details.
|
||||
all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
|
||||
all_data[self.rank] = [handle, offset]
|
||||
ranks = sorted(dist.get_process_group_ranks(group=self.group))
|
||||
for i, rank in enumerate(ranks):
|
||||
dist.broadcast_object_list(
|
||||
all_data[i], src=rank, group=self.group, device="cpu"
|
||||
)
|
||||
# Unpack list of tuples to tuple of lists.
|
||||
handles = [d[0] for d in all_data] # type: ignore
|
||||
offsets = [d[1] for d in all_data] # type: ignore
|
||||
vLLMCustomAllreduce.register_graph_buffers(self._ptr, handles, offsets)
|
||||
|
||||
def should_custom_ar(self, inp: torch.Tensor):
|
||||
if self.disabled:
|
||||
return False
|
||||
inp_size = inp.numel() * inp.element_size()
|
||||
# custom allreduce requires input byte size to be multiples of 16
|
||||
if inp_size % 16 != 0:
|
||||
return False
|
||||
if not is_weak_contiguous(inp):
|
||||
return False
|
||||
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
|
||||
# little performance improvement over NCCL.
|
||||
if self.world_size == 2 or self.full_nvlink:
|
||||
return inp_size < self.max_size
|
||||
return False
|
||||
|
||||
def all_reduce(
|
||||
self, inp: torch.Tensor, *, out: torch.Tensor = None, bsz_tensor: torch.Tensor = None, registered: bool = False,
|
||||
is_compute_bound=False, overlap=False
|
||||
):
|
||||
"""Performs an out-of-place all reduce.
|
||||
|
||||
If registered is True, this assumes inp's pointer is already
|
||||
IPC-registered. Otherwise, inp is first copied into a pre-registered
|
||||
buffer.
|
||||
"""
|
||||
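# Presumably an SM/thread-block budget for the custom allreduce kernel (passed as block_limit below): use fewer blocks when overlapping with other GPU work.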
if is_compute_bound:
|
||||
sms = 2 if overlap else 36
|
||||
else:
|
||||
sms = 20 if overlap else 36
|
||||
#print("all reduce sms", sms)
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
if registered:
|
||||
vLLMCustomAllreduce.all_reduce(self._ptr, inp, out, 0, 0, bsz_tensor, block_limit=sms)
|
||||
else:
|
||||
vLLMCustomAllreduce.all_reduce(
|
||||
self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size, bsz_tensor, block_limit=sms
|
||||
)
|
||||
return out
|
||||
|
||||
def custom_all_reduce(self, input: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> Optional[torch.Tensor]:
|
||||
"""The main allreduce API that provides support for cuda graph."""
|
||||
# When custom allreduce is disabled, this will be None.
|
||||
if self.disabled or not self.should_custom_ar(input):
|
||||
return None
|
||||
if self._IS_CAPTURING:
|
||||
if torch.cuda.is_current_stream_capturing():
|
||||
return self.all_reduce(input, bsz_tensor=bsz_tensor, registered=True, is_compute_bound=is_compute_bound, overlap=overlap)
|
||||
else:
|
||||
# If warm up, mimic the allocation pattern since custom
|
||||
# allreduce is out-of-place.
|
||||
return torch.empty_like(input)
|
||||
else:
|
||||
# Note: outside of cuda graph context, custom allreduce incurs a
|
||||
# cost of cudaMemcpy, which should be small (<=1% of overall
|
||||
# latency) compared to the performance gain of using custom kernels
|
||||
return self.all_reduce(input, bsz_tensor=bsz_tensor, registered=False, is_compute_bound=is_compute_bound, overlap=overlap)
|
||||
|
||||
def close(self):
|
||||
if not self.disabled and self._ptr:
|
||||
vLLMCustomAllreduce.dispose(self._ptr)
|
||||
self._ptr = 0
|
||||
self.free_shared_buffer(self.meta_ptrs)
|
||||
self.free_shared_buffer(self.buffer_ptrs)
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
|
@@ -0,0 +1,272 @@
|
|||
import ctypes
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from itertools import product
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import server.envs as envs
|
||||
from server.inference.distributed.cuda_wrapper import CudaRTLibrary
|
||||
from server.utils import cuda_device_count_stateless, update_environment_variables
|
||||
|
||||
|
||||
def producer(
|
||||
batch_src: Sequence[int],
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: Optional[str] = None,
|
||||
):
|
||||
if cuda_visible_devices is not None:
|
||||
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
|
||||
lib = CudaRTLibrary()
|
||||
for i in batch_src:
|
||||
lib.cudaSetDevice(i)
|
||||
pointer = lib.cudaMalloc(1024)
|
||||
lib.cudaMemset(pointer, 1, 1024)
|
||||
lib.cudaDeviceSynchronize()
|
||||
handle = lib.cudaIpcGetMemHandle(pointer)
|
||||
producer_queue.put(handle)
|
||||
open_success = consumer_queue.get()
|
||||
if open_success:
|
||||
# use two queues to simulate barrier
|
||||
producer_queue.put(0)
|
||||
consumer_queue.get()
|
||||
# check if the memory is modified
|
||||
host_data = (ctypes.c_char * 1024)()
|
||||
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
|
||||
for i in range(1024):
|
||||
if ord(host_data[i]) != 2:
|
||||
open_success = False
|
||||
break
|
||||
result_queue.put(open_success)
|
||||
lib.cudaDeviceReset()
|
||||
|
||||
|
||||
def consumer(
|
||||
batch_tgt: Sequence[int],
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: Optional[str] = None,
|
||||
):
|
||||
if cuda_visible_devices is not None:
|
||||
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
|
||||
lib = CudaRTLibrary()
|
||||
for j in batch_tgt:
|
||||
lib.cudaSetDevice(j)
|
||||
handle = producer_queue.get()
|
||||
open_success = False
|
||||
try:
|
||||
pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore
|
||||
open_success = True
|
||||
except RuntimeError:
|
||||
# cannot error out here, because the producer process
|
||||
# is still waiting for the response.
|
||||
pass
|
||||
consumer_queue.put(open_success)
|
||||
if open_success:
|
||||
# modify the memory
|
||||
lib.cudaMemset(pointer, 2, 1024)
|
||||
lib.cudaDeviceSynchronize()
|
||||
# use two queues to simulate barrier
|
||||
producer_queue.get()
|
||||
consumer_queue.put(0)
|
||||
# check if the memory is modified
|
||||
host_data = (ctypes.c_char * 1024)()
|
||||
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
|
||||
for i in range(1024):
|
||||
if ord(host_data[i]) != 2:
|
||||
open_success = False
|
||||
break
|
||||
result_queue.put(open_success)
|
||||
lib.cudaDeviceReset()
|
||||
|
||||
|
||||
def can_actually_p2p(
|
||||
batch_src: Sequence[int],
|
||||
batch_tgt: Sequence[int],
|
||||
) -> Sequence[bool]:
|
||||
"""
|
||||
Usually, checking if P2P access is enabled can be done by
|
||||
`torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
|
||||
the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
|
||||
returns `True` even if P2P access is not actually possible.
|
||||
See https://github.com/vllm-project/vllm/issues/2728 and
|
||||
https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
|
||||
Therefore, we have to perform a real P2P access to check if it is actually
|
||||
possible.
|
||||
|
||||
Note on p2p and cuda IPC:
|
||||
Usually, one process uses one GPU:
|
||||
GPU src --> cuda context src --> tensor src --> process src
|
||||
|
||||
We need to combine p2p and cuda IPC, so that:
|
||||
GPU src --> cuda context src --> tensor src --> process src
|
||||
|shared|
|
||||
GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
|
||||
That is to say, process src creates a tensor in GPU src, passes IPC handle to
|
||||
process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
|
||||
tensor in process tgt will be reflected in the tensor in process src, because
|
||||
they are the same memory segment.
|
||||
It is important to note that process tgt accesses the tensor in GPU tgt, not
|
||||
GPU src. That's why we need p2p access.
|
||||
|
||||
The most time-consuming part is the process creation. To avoid creating
|
||||
processes for every pair of GPUs, we use batched testing. We create two
|
||||
processes for testing all pairs of GPUs in batch. The trick is to reset
|
||||
the device after each test (which is not available in PyTorch).
|
||||
""" # noqa
|
||||
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||
# pass the CUDA_VISIBLE_DEVICES to the child process
|
||||
# to make sure they see the same set of GPUs
|
||||
|
||||
# make sure the processes are spawned
|
||||
smp = mp.get_context("spawn")
|
||||
producer_queue = smp.Queue()
|
||||
consumer_queue = smp.Queue()
|
||||
result_queue = smp.Queue()
|
||||
p_src = smp.Process(
|
||||
target=producer,
|
||||
args=(
|
||||
batch_src,
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices,
|
||||
),
|
||||
)
|
||||
p_tgt = smp.Process(
|
||||
target=consumer,
|
||||
args=(
|
||||
batch_tgt,
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices,
|
||||
),
|
||||
)
|
||||
p_src.start()
|
||||
p_tgt.start()
|
||||
p_src.join()
|
||||
p_tgt.join()
|
||||
assert p_src.exitcode == 0 and p_tgt.exitcode == 0
|
||||
result: List[bool] = []
|
||||
for src, tgt in zip(batch_src, batch_tgt):
|
||||
a = result_queue.get()
|
||||
b = result_queue.get()
|
||||
if a != b:
|
||||
            print(
                f"Two processes do not agree on the P2P access"
                f" status on {src} -> {tgt}, treating it as disabled."
            )
|
||||
result.append(False)
|
||||
else:
|
||||
result.append(a)
|
||||
return result
|
||||
|
||||
|
||||
# why do we need this cache?
|
||||
# we are testing peer-to-peer (p2p) access between GPUs, across processes.
|
||||
# if we test it every time, it will be very slow, because we need to create
|
||||
# N * N * 2 processes, where N is the world size. This is very slow.
|
||||
# to reduce the time, we use a cache file to store the p2p access status.
|
||||
# the cache file is generated by the master process if it does not exist.
|
||||
# then all the processes can read the cache file to check the p2p access status.
|
||||
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
|
||||
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
|
||||
# e.g. used by different vllm engines. The device id in the cache file is a
|
||||
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
|
||||
# of visible devices in the vllm engine.
|
||||
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
|
||||
|
||||
|
||||
def gpu_p2p_access_check(src: int, tgt: int) -> bool:
|
||||
"""Check if GPU src can access GPU tgt."""
|
||||
|
||||
# if the cache variable is already calculated,
|
||||
# read from the cache instead of checking it again
|
||||
global _gpu_p2p_access_cache
|
||||
if _gpu_p2p_access_cache is not None:
|
||||
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
|
||||
|
||||
is_distributed = dist.is_initialized()
|
||||
|
||||
num_dev = cuda_device_count_stateless()
|
||||
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||
if cuda_visible_devices is None:
|
||||
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
|
||||
|
||||
path = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
|
||||
)
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
from server.inference.distributed.parallel_state import get_world_group
|
||||
|
||||
if (not is_distributed or get_world_group().local_rank == 0) and (
|
||||
not os.path.exists(path)
|
||||
):
|
||||
# only the local master process (with local_rank == 0) can
|
||||
# enter this block to calculate the cache
|
||||
print("generating GPU P2P access cache in %s", path)
|
||||
cache: Dict[str, bool] = {}
|
||||
ids = list(range(num_dev))
|
||||
# batch of all pairs of GPUs
|
||||
batch_src, batch_tgt = zip(*list(product(ids, ids)))
|
||||
# NOTE: we use `subprocess` rather than `multiprocessing` here
|
||||
# because the caller might not have `if __name__ == "__main__":`,
|
||||
# in that case we cannot use spawn method in multiprocessing.
|
||||
# However, `can_actually_p2p` requires spawn method.
|
||||
# The fix is, we use `subprocess` to call the function,
|
||||
# where we have `if __name__ == "__main__":` in this file.
|
||||
|
||||
# use a temporary file to store the result
|
||||
# we don't use the output of the subprocess directly,
|
||||
# because the subprocess might produce logging output
|
||||
with tempfile.NamedTemporaryFile() as output_file:
|
||||
input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name))
|
||||
returned = subprocess.run(
|
||||
[sys.executable, __file__], input=input_bytes, capture_output=True
|
||||
)
|
||||
# check if the subprocess is successful
|
||||
try:
|
||||
returned.check_returncode()
|
||||
except Exception as e:
|
||||
# wrap raised exception to provide more information
|
||||
raise RuntimeError(
|
||||
f"Error happened when batch testing "
|
||||
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
|
||||
f"{returned.stderr.decode()}"
|
||||
) from e
|
||||
with open(output_file.name, "rb") as f:
|
||||
result = pickle.load(f)
|
||||
for _i, _j, r in zip(batch_src, batch_tgt, result):
|
||||
cache[f"{_i}->{_j}"] = r
|
||||
with open(path, "w") as f:
|
||||
json.dump(cache, f, indent=4)
|
||||
if is_distributed:
|
||||
get_world_group().barrier()
|
||||
print("reading GPU P2P access cache from %s", path)
|
||||
with open(path) as f:
|
||||
cache = json.load(f)
|
||||
_gpu_p2p_access_cache = cache
|
||||
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
|
||||
|
||||
|
||||
__all__ = ["gpu_p2p_access_check"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
|
||||
result = can_actually_p2p(batch_src, batch_tgt)
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(pickle.dumps(result))
|
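# --- Illustrative usage sketch (not part of the diff above): callers normally go
# through `gpu_p2p_access_check` alone; the batched `can_actually_p2p` path runs
# inside the subprocess spawned above. Everything referenced here is defined in
# this file; the GPU count of 2 and the helper name are just examples.
def _example_p2p_matrix(num_gpus: int = 2) -> Dict[str, bool]:
    """Query the cached peer-to-peer status for every ordered GPU pair."""
    return {
        f"{src}->{tgt}": gpu_p2p_access_check(src, tgt)
        for src in range(num_gpus)
        for tgt in range(num_gpus)
    }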
|
@@ -0,0 +1,201 @@
|
|||
from contextlib import contextmanager
from typing import Optional, Union

# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp

from server.inference.distributed.pynccl_wrapper import (
    NCCLLibrary,
    buffer_type,
    cudaStream_t,
    ncclComm_t,
    ncclDataTypeEnum,
    ncclRedOpTypeEnum,
    ncclUniqueId,
)
from server.inference.distributed.utils import StatelessProcessGroup


class PyNcclCommunicator:

    def __init__(
        self,
        group: Union[ProcessGroup, StatelessProcessGroup],
        device: Union[int, str, torch.device],
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the PyNcclCommunicator to. If None,
                it will be bound to f"cuda:{local_rank}".
            library_path: the path to the NCCL library. If None, it will
                use the default library path.
        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
        """
        if not isinstance(group, StatelessProcessGroup):
            assert dist.is_initialized()
            assert (
                dist.get_backend(group) != dist.Backend.NCCL
            ), "PyNcclCommunicator should be attached to a non-NCCL group."
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)
        else:
            self.rank = group.rank
            self.world_size = group.world_size

        self.group = group

        # if world_size == 1, no need to create communicator
        if self.world_size == 1:
            self.available = False
            self.disabled = True
            self.stream = None
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # disable because of missing NCCL library
            # e.g. in a non-GPU environment
            self.available = False
            self.disabled = True
            self.stream = None
            return

        self.available = True
        self.disabled = False

        print(f"vLLM is using nccl=={self.nccl.ncclGetVersion()}")

        if self.rank == 0:
            # get the unique id from NCCL
            self.unique_id = self.nccl.ncclGetUniqueId()
        else:
            # construct an empty unique id
            self.unique_id = ncclUniqueId()

        if not isinstance(group, StatelessProcessGroup):
            tensor = torch.ByteTensor(list(self.unique_id.internal))
            ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(tensor, src=ranks[0], group=group)
            byte_list = tensor.tolist()
            for i, byte in enumerate(byte_list):
                self.unique_id.internal[i] = byte
        else:
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)
        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # nccl communicator and stream will use this device
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank
            )
            self.stream = torch.cuda.Stream()

            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            self.stream.synchronize()
            del data

        # by default it is disabled, e.g. in profiling models and prefill phase.
        # to use it, use under `with obj.change_state(enable=True)`, usually
        # when we are using CUDA graph.
        self.disabled = True

    def all_reduce(
        self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
    ):
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )
        if stream is None:
            stream = self.stream
        self.nccl.ncclAllReduce(
            buffer_type(tensor.data_ptr()),
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op),
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )
        if stream is None:
            stream = self.stream
        self.nccl.ncclSend(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            dst,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )
        if stream is None:
            stream = self.stream
        self.nccl.ncclRecv(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            src,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    @contextmanager
    def change_state(
        self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
    ):
        """
        A context manager to change the state of the communicator.
        """
        if enable is None:
            # guess a default value when not specified
            enable = self.available

        if stream is None:
            stream = self.stream

        old_disable = self.disabled
        old_stream = self.stream

        self.stream = stream
        self.disabled = not enable
        yield

        self.disabled = old_disable
        self.stream = old_stream
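# --- Illustrative sketch (not part of the diff above). It shows the intended
# calling pattern: collectives are no-ops until enabled via `change_state`.
# The host/port values and the out-of-band launch of the two ranks are
# assumptions; they are not shown in this diff.
def _example_warmup_and_allreduce(rank: int, world_size: int) -> torch.Tensor:
    group = StatelessProcessGroup.create(
        host="127.0.0.1", port=29555, rank=rank, world_size=world_size
    )
    comm = PyNcclCommunicator(group, device=rank)
    x = torch.ones(16, device=comm.device)
    with comm.change_state(enable=True):  # disabled by default, see __init__
        comm.all_reduce(x)  # in-place: send and recv buffers share a pointer
    comm.stream.synchronize()
    return x  # each rank now holds the elementwise sum across ranks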
@@ -0,0 +1,276 @@
|
|||
# This file is a pure Python wrapper for the NCCL library.
|
||||
# The main purpose is to use NCCL combined with CUDA graph.
|
||||
# Before writing this script, we tried the following approach:
|
||||
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
|
||||
# often gets stuck when initializing the NCCL communicator.
|
||||
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
|
||||
# contains many other potential cuda APIs, that are not allowed during
|
||||
# capturing the CUDA graph. For further details, please check
|
||||
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
|
||||
#
|
||||
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
|
||||
# doable, but we often encounter issues related with nccl versions, and need
|
||||
# to switch between different versions of NCCL. See
|
||||
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
|
||||
# A C/C++ binding is not flexible enough to handle this. It requires
|
||||
# recompilation of the code every time we want to switch between different
|
||||
# versions. This current implementation, with a **pure** Python wrapper, is
|
||||
# more flexible. We can easily switch between different versions of NCCL by
|
||||
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
|
||||
# variable in the code.
|
||||
|
||||
import ctypes
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from server.utils import find_nccl_library
|
||||
|
||||
|
||||
# === export types and functions from nccl to Python ===
|
||||
# for the original nccl definition, please check
|
||||
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
|
||||
|
||||
ncclResult_t = ctypes.c_int
|
||||
ncclComm_t = ctypes.c_void_p
|
||||
|
||||
|
||||
class ncclUniqueId(ctypes.Structure):
|
||||
_fields_ = [("internal", ctypes.c_byte * 128)]
|
||||
|
||||
|
||||
cudaStream_t = ctypes.c_void_p
|
||||
buffer_type = ctypes.c_void_p
|
||||
|
||||
ncclDataType_t = ctypes.c_int
|
||||
|
||||
|
||||
class ncclDataTypeEnum:
|
||||
ncclInt8 = 0
|
||||
ncclChar = 0
|
||||
ncclUint8 = 1
|
||||
ncclInt32 = 2
|
||||
ncclInt = 2
|
||||
ncclUint32 = 3
|
||||
ncclInt64 = 4
|
||||
ncclUint64 = 5
|
||||
ncclFloat16 = 6
|
||||
ncclHalf = 6
|
||||
ncclFloat32 = 7
|
||||
ncclFloat = 7
|
||||
ncclFloat64 = 8
|
||||
ncclDouble = 8
|
||||
ncclBfloat16 = 9
|
||||
ncclNumTypes = 10
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, dtype: torch.dtype) -> int:
|
||||
if dtype == torch.int8:
|
||||
return cls.ncclInt8
|
||||
if dtype == torch.uint8:
|
||||
return cls.ncclUint8
|
||||
if dtype == torch.int32:
|
||||
return cls.ncclInt32
|
||||
if dtype == torch.int64:
|
||||
return cls.ncclInt64
|
||||
if dtype == torch.float16:
|
||||
return cls.ncclFloat16
|
||||
if dtype == torch.float32:
|
||||
return cls.ncclFloat32
|
||||
if dtype == torch.float64:
|
||||
return cls.ncclFloat64
|
||||
if dtype == torch.bfloat16:
|
||||
return cls.ncclBfloat16
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
ncclRedOp_t = ctypes.c_int
|
||||
|
||||
|
||||
class ncclRedOpTypeEnum:
|
||||
ncclSum = 0
|
||||
ncclProd = 1
|
||||
ncclMax = 2
|
||||
ncclMin = 3
|
||||
ncclAvg = 4
|
||||
ncclNumOps = 5
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, op: ReduceOp) -> int:
|
||||
if op == ReduceOp.SUM:
|
||||
return cls.ncclSum
|
||||
if op == ReduceOp.PRODUCT:
|
||||
return cls.ncclProd
|
||||
if op == ReduceOp.MAX:
|
||||
return cls.ncclMax
|
||||
if op == ReduceOp.MIN:
|
||||
return cls.ncclMin
|
||||
if op == ReduceOp.AVG:
|
||||
return cls.ncclAvg
|
||||
raise ValueError(f"Unsupported op: {op}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: List[Any]
|
||||
|
||||
|
||||
class NCCLLibrary:
|
||||
exported_functions = [
|
||||
# const char* ncclGetErrorString(ncclResult_t result)
|
||||
Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
|
||||
# ncclResult_t ncclGetVersion(int *version);
|
||||
Function("ncclGetVersion", ncclResult_t,
|
||||
[ctypes.POINTER(ctypes.c_int)]),
|
||||
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
|
||||
Function("ncclGetUniqueId", ncclResult_t,
|
||||
[ctypes.POINTER(ncclUniqueId)]),
|
||||
# ncclResult_t ncclCommInitRank(
|
||||
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
|
||||
# note that ncclComm_t is a pointer type, so the first argument
|
||||
# is a pointer to a pointer
|
||||
Function("ncclCommInitRank", ncclResult_t, [
|
||||
ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
|
||||
ctypes.c_int
|
||||
]),
|
||||
# ncclResult_t ncclAllReduce(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function("ncclAllReduce", ncclResult_t, [
|
||||
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
|
||||
ncclRedOp_t, ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# ncclResult_t ncclSend(
|
||||
# const void* sendbuff, size_t count, ncclDataType_t datatype,
|
||||
# int dest, ncclComm_t comm, cudaStream_t stream);
|
||||
Function("ncclSend", ncclResult_t, [
|
||||
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
|
||||
ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# ncclResult_t ncclRecv(
|
||||
# void* recvbuff, size_t count, ncclDataType_t datatype,
|
||||
# int src, ncclComm_t comm, cudaStream_t stream);
|
||||
Function("ncclRecv", ncclResult_t, [
|
||||
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
|
||||
ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# be cautious! this is a collective call, it will block until all
|
||||
# processes in the communicator have called this function.
|
||||
# because Python object destruction can happen in random order,
|
||||
# it is better not to call it at all.
|
||||
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
|
||||
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: Dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the corresponding dictionary
|
||||
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
|
||||
so_file = so_file or find_nccl_library()
|
||||
|
||||
try:
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
NCCLLibrary.path_to_library_cache[so_file] = lib
|
||||
self.lib = NCCLLibrary.path_to_library_cache[so_file]
|
||||
except Exception as e:
|
||||
            print(
                f"Failed to load NCCL library from {so_file}. "
                "It is expected if you are not running on NVIDIA/AMD GPUs. "
                "Otherwise, the nccl library might not exist, be corrupted, "
                f"or it may not support the current platform {platform.platform()}. "
                "If you already have the library, please set the environment "
                "variable VLLM_NCCL_SO_PATH to point to the correct nccl "
                "library path."
            )
            raise e
|
||||
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
_funcs: Dict[str, Any] = {}
|
||||
for func in NCCLLibrary.exported_functions:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
f.argtypes = func.argtypes
|
||||
_funcs[func.name] = f
|
||||
NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
|
||||
self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
|
||||
|
||||
def ncclGetErrorString(self, result: ncclResult_t) -> str:
|
||||
return self._funcs["ncclGetErrorString"](result).decode("utf-8")
|
||||
|
||||
def NCCL_CHECK(self, result: ncclResult_t) -> None:
|
||||
if result != 0:
|
||||
error_str = self.ncclGetErrorString(result)
|
||||
raise RuntimeError(f"NCCL error: {error_str}")
|
||||
|
||||
def ncclGetVersion(self) -> str:
|
||||
version = ctypes.c_int()
|
||||
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
|
||||
version_str = str(version.value)
|
||||
# something like 21903 --> "2.19.3"
|
||||
major = version_str[0].lstrip("0")
|
||||
minor = version_str[1:3].lstrip("0")
|
||||
patch = version_str[3:].lstrip("0")
|
||||
return f"{major}.{minor}.{patch}"
|
||||
|
||||
def ncclGetUniqueId(self) -> ncclUniqueId:
|
||||
unique_id = ncclUniqueId()
|
||||
self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
|
||||
ctypes.byref(unique_id)))
|
||||
return unique_id
|
||||
|
||||
def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
|
||||
rank: int) -> ncclComm_t:
|
||||
comm = ncclComm_t()
|
||||
self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
|
||||
world_size, unique_id,
|
||||
rank))
|
||||
return comm
|
||||
|
||||
def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, op: int, comm: ncclComm_t,
|
||||
stream: cudaStream_t) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# and `op` should be `ncclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
|
||||
datatype, op, comm,
|
||||
stream))
|
||||
|
||||
def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
|
||||
dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
|
||||
dest, comm, stream))
|
||||
|
||||
def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
|
||||
src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
|
||||
comm, stream))
|
||||
|
||||
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
|
||||
"ncclComm_t", "cudaStream_t", "buffer_type"
|
||||
]
|
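# --- Illustrative sketch (not part of the diff above): the wrapper can be
# smoke-tested without any process group or peer setup, since ncclGetVersion
# needs no communicator.
#   nccl = NCCLLibrary()  # or NCCLLibrary("/path/to/libnccl.so.2") to pin a build
#   print("loaded NCCL", nccl.ncclGetVersion())  # e.g. "2.19.3"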
|
@@ -0,0 +1,219 @@
|
|||
# Copyright 2023 The vLLM team.
|
||||
# Adapted from
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
import dataclasses
|
||||
import pickle
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Any, Deque, Dict, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from torch.distributed import TCPStore
|
||||
|
||||
import server.envs as envs
|
||||
|
||||
|
||||
def ensure_divisibility(numerator, denominator):
|
||||
"""Ensure that numerator is divisible by the denominator."""
|
||||
assert numerator % denominator == 0, "{} is not divisible by {}".format(
|
||||
numerator, denominator
|
||||
)
|
||||
|
||||
|
||||
def divide(numerator, denominator):
|
||||
"""Ensure that numerator is divisible by the denominator and return
|
||||
the division value."""
|
||||
ensure_divisibility(numerator, denominator)
|
||||
return numerator // denominator
|
||||
|
||||
|
||||
def split_tensor_along_last_dim(
|
||||
tensor: torch.Tensor,
|
||||
num_partitions: int,
|
||||
contiguous_split_chunks: bool = False,
|
||||
) -> Sequence[torch.Tensor]:
|
||||
"""Split a tensor along its last dimension.
|
||||
|
||||
Arguments:
|
||||
tensor: input tensor.
|
||||
num_partitions: number of partitions to split the tensor
|
||||
contiguous_split_chunks: If True, make each chunk contiguous
|
||||
in memory.
|
||||
|
||||
Returns:
|
||||
A list of Tensors
|
||||
"""
|
||||
# Get the size and dimension.
|
||||
last_dim = tensor.dim() - 1
|
||||
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
|
||||
# Split.
|
||||
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
||||
# NOTE: torch.split does not create contiguous tensors by default.
|
||||
if contiguous_split_chunks:
|
||||
return tuple(chunk.contiguous() for chunk in tensor_list)
|
||||
|
||||
return tensor_list
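    # Illustrative example (not part of the file): splitting a [4, 6] tensor into
    # 3 partitions returns three [4, 2] views of the original storage; pass
    # contiguous_split_chunks=True when downstream kernels need contiguous memory.
    #   chunks = split_tensor_along_last_dim(torch.randn(4, 6), 3)
    #   [c.shape for c in chunks]  ->  [torch.Size([4, 2])] * 3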
|
||||
|
||||
|
||||
def get_pp_indices(
|
||||
num_hidden_layers: int, pp_rank: int, pp_size: int
|
||||
) -> Tuple[int, int]:
|
||||
"""Try to evenly distribute layers across partitions.
|
||||
If the number of layers is not divisible by the number of partitions,
|
||||
the last partition will have the remaining layers.
|
||||
"""
|
||||
partition_list_str = envs.VLLM_PP_LAYER_PARTITION
|
||||
if partition_list_str is not None:
|
||||
try:
|
||||
partitions = [int(layer) for layer in partition_list_str.split(",")]
|
||||
except ValueError as err:
|
||||
raise ValueError(
|
||||
"Invalid partition string: {}".format(partition_list_str)
|
||||
) from err
|
||||
if len(partitions) != pp_size:
|
||||
raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
|
||||
if sum(partitions) != num_hidden_layers:
|
||||
raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.")
|
||||
start_layer = sum(partitions[:pp_rank])
|
||||
end_layer = start_layer + partitions[pp_rank]
|
||||
else:
|
||||
layers_per_partition = num_hidden_layers // pp_size
|
||||
start_layer = pp_rank * layers_per_partition
|
||||
end_layer = start_layer + layers_per_partition
|
||||
|
||||
if pp_rank == pp_size - 1:
|
||||
end_layer = num_hidden_layers
|
||||
|
||||
return (start_layer, end_layer)
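    # Illustrative check of the rule above (not part of the file), assuming
    # VLLM_PP_LAYER_PARTITION is unset: 61 layers over 4 stages gives
    # floor(61 / 4) = 15 layers per stage, with the last stage taking the rest.
    #   [get_pp_indices(61, r, 4) for r in range(4)]
    #   -> [(0, 15), (15, 30), (30, 45), (45, 61)]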
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class StatelessProcessGroup:
|
||||
"""A dataclass to hold a metadata store, and the rank, world_size of the
|
||||
group. Only use it to communicate metadata between processes.
|
||||
For data-plane communication, create NCCL-related objects.
|
||||
"""
|
||||
|
||||
rank: int
|
||||
world_size: int
|
||||
store: torch._C._distributed_c10d.Store
|
||||
data_expiration_seconds: int = 3600 # 1 hour
|
||||
|
||||
# dst rank -> counter
|
||||
send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
|
||||
# src rank -> counter
|
||||
recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
|
||||
broadcast_send_counter: int = 0
|
||||
broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
|
||||
|
||||
# A deque to store the data entries, with key and timestamp.
|
||||
entries: Deque[Tuple[str, float]] = dataclasses.field(default_factory=deque)
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.rank < self.world_size
|
||||
self.send_dst_counter = {i: 0 for i in range(self.world_size)}
|
||||
self.recv_src_counter = {i: 0 for i in range(self.world_size)}
|
||||
self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)}
|
||||
|
||||
def send_obj(self, obj: Any, dst: int):
|
||||
"""Send an object to a destination rank."""
|
||||
self.expire_data()
|
||||
key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
|
||||
self.store.set(key, pickle.dumps(obj))
|
||||
self.send_dst_counter[dst] += 1
|
||||
self.entries.append((key, time.time()))
|
||||
|
||||
def expire_data(self):
|
||||
"""Expire data that is older than `data_expiration_seconds` seconds."""
|
||||
while self.entries:
|
||||
# check the oldest entry
|
||||
key, timestamp = self.entries[0]
|
||||
if time.time() - timestamp > self.data_expiration_seconds:
|
||||
self.store.delete_key(key)
|
||||
self.entries.popleft()
|
||||
else:
|
||||
break
|
||||
|
||||
def recv_obj(self, src: int) -> Any:
|
||||
"""Receive an object from a source rank."""
|
||||
obj = pickle.loads(
|
||||
self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}")
|
||||
)
|
||||
self.recv_src_counter[src] += 1
|
||||
return obj
|
||||
|
||||
def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
|
||||
"""Broadcast an object from a source rank to all other ranks.
|
||||
It does not clean up after all ranks have received the object.
|
||||
Use it for limited times, e.g., for initialization.
|
||||
"""
|
||||
if self.rank == src:
|
||||
self.expire_data()
|
||||
key = f"broadcast_from/{src}/" f"{self.broadcast_send_counter}"
|
||||
self.store.set(key, pickle.dumps(obj))
|
||||
self.broadcast_send_counter += 1
|
||||
self.entries.append((key, time.time()))
|
||||
return obj
|
||||
else:
|
||||
key = f"broadcast_from/{src}/" f"{self.broadcast_recv_src_counter[src]}"
|
||||
recv_obj = pickle.loads(self.store.get(key))
|
||||
self.broadcast_recv_src_counter[src] += 1
|
||||
return recv_obj
|
||||
|
||||
def all_gather_obj(self, obj: Any) -> list[Any]:
|
||||
"""All gather an object from all ranks."""
|
||||
gathered_objs = []
|
||||
for i in range(self.world_size):
|
||||
if i == self.rank:
|
||||
gathered_objs.append(obj)
|
||||
self.broadcast_obj(obj, src=self.rank)
|
||||
else:
|
||||
recv_obj = self.broadcast_obj(None, src=i)
|
||||
gathered_objs.append(recv_obj)
|
||||
return gathered_objs
|
||||
|
||||
def barrier(self):
|
||||
"""A barrier to synchronize all ranks."""
|
||||
for i in range(self.world_size):
|
||||
if i == self.rank:
|
||||
self.broadcast_obj(None, src=self.rank)
|
||||
else:
|
||||
self.broadcast_obj(None, src=i)
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
host: str,
|
||||
port: int,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
data_expiration_seconds: int = 3600,
|
||||
) -> "StatelessProcessGroup":
|
||||
"""A replacement for `torch.distributed.init_process_group` that does not
|
||||
pollute the global state.
|
||||
|
||||
        If processes A and B have called `torch.distributed.init_process_group` to
        form a group, and we later want to form another group with processes A, B,
        C, and D, that is not possible in PyTorch: A and B already belong to a
        group, and C and D cannot join it. This function is a workaround for that
        limitation.
|
||||
|
||||
`torch.distributed.init_process_group` is a global call, while this function
|
||||
is a stateless call. It will return a `StatelessProcessGroup` object that can be
|
||||
used for exchanging metadata. With this function, process A and process B
|
||||
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
|
||||
C, and D can call `StatelessProcessGroup.create` to form another group.
|
||||
""" # noqa
|
||||
store = TCPStore(
|
||||
host_name=host,
|
||||
port=port,
|
||||
world_size=world_size,
|
||||
is_master=(rank == 0),
|
||||
)
|
||||
|
||||
return StatelessProcessGroup(
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
store=store,
|
||||
data_expiration_seconds=data_expiration_seconds,
|
||||
)
|
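# --- Illustrative sketch (not part of the diff above): the metadata-only API of
# StatelessProcessGroup. Assumes two ranks are launched out-of-band (two shells,
# torchrun, or multiprocessing); host/port values are placeholders.
def _example_exchange_metadata(rank: int) -> None:
    pg = StatelessProcessGroup.create(
        host="127.0.0.1", port=29556, rank=rank, world_size=2
    )
    gathered = pg.all_gather_obj({"rank": rank, "kv_pages": 128 * (rank + 1)})
    if rank == 0:
        pg.send_obj("ready", dst=1)
    else:
        assert pg.recv_obj(src=0) == "ready"
    pg.barrier()
    print(rank, gathered)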
284 ktransformers/server/balance_serve/inference/forward_batch.py (new file)
|
@@ -0,0 +1,284 @@
|
|||
'''
|
||||
Date: 2024-11-12 14:15:16
|
||||
LastEditors: Xie Weiyu ervinxie@qq.com
|
||||
LastEditTime: 2024-11-26 08:12:49
|
||||
'''
|
||||
import torch
|
||||
from ktransformers.server.balance_serve.settings import sched_ext
|
||||
from ktransformers.server.balance_serve.inference.query_manager import QueryManager, QueryInfo
|
||||
import time
|
||||
from ktransformers.server.config.config import Config
|
||||
class ForwardBatchInput:
|
||||
|
||||
class ForwardMiniBatch:
|
||||
q_indptr: torch.Tensor
|
||||
kv_indptr: torch.Tensor
|
||||
kv_indices: torch.Tensor
|
||||
kv_last_page_len: torch.Tensor
|
||||
kv_len: torch.Tensor
|
||||
position_ids: torch.Tensor
|
||||
tokens: torch.Tensor
|
||||
batch_indices: torch.Tensor
|
||||
positions: torch.Tensor
|
||||
chunk_size: int
|
||||
decode_batch: int
|
||||
is_last_prefill_chunk: bool
|
||||
logits_start: list
|
||||
|
||||
temperatures: torch.Tensor
|
||||
top_ps: torch.Tensor
|
||||
|
||||
def __init__(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, device = torch.device('cuda'), page_size = 256):
|
||||
batch_decode = len(decode_querys_info)
|
||||
batch_prefill = len(prefill_querys_info)
|
||||
|
||||
self.q_indptr = torch.tensor([0], device=device, dtype=torch.int32)
|
||||
self.kv_indptr = torch.tensor([0], device=device, dtype=torch.int32)
|
||||
self.kv_indices = torch.tensor([], device=device, dtype=torch.int32)
|
||||
self.kv_len = torch.tensor([], device=device, dtype=torch.int32)
|
||||
self.kv_last_page_len = torch.tensor([], device=device, dtype=torch.int32)
|
||||
self.position_ids = torch.tensor([], device=device, dtype=torch.int32)
|
||||
self.tokens = torch.tensor([], device=device, dtype=torch.int32)
|
||||
|
||||
self.temperatures = torch.tensor([], device=device, dtype=torch.float32)
|
||||
self.top_ps = torch.tensor([], device=device, dtype=torch.float32)
|
||||
|
||||
self.logits_start = []
|
||||
self.decode_batch = batch_decode
|
||||
self.num_tokens = batch_decode + sum(prefill_l)
|
||||
self.batch_size = batch_decode + batch_prefill
|
||||
|
||||
for i, prefill_query_info in enumerate(prefill_querys_info):
|
||||
            if prefill_query_info is not None:
|
||||
prefill_kv_block_len = (prefill_query_info.active_position + prefill_l[i] + page_size - 1) // page_size if prefill_query_info is not None else 0
|
||||
# print(f"block_len: {prefill_kv_block_len}, page_size: {page_size}")
|
||||
self.q_indptr = torch.concat((self.q_indptr, torch.tensor([prefill_l[i] + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
|
||||
self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([prefill_kv_block_len + self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
|
||||
self.kv_indices = torch.concat((self.kv_indices, prefill_query_info.block_index[:prefill_kv_block_len]), dim=0)
|
||||
self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i]) % page_size if (prefill_query_info.active_position + prefill_l[i]) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
|
||||
self.kv_len = torch.concat((self.kv_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i])], device=device, dtype=torch.int32)), dim=0)
|
||||
self.position_ids = torch.concat((self.position_ids, torch.arange(prefill_s[i], prefill_l[i] + prefill_s[i], device=device, dtype=torch.int32)), dim=0)
|
||||
self.tokens = torch.concat((self.tokens, prefill_query_info.query_tokens[prefill_s[i]:prefill_s[i] + prefill_l[i]]), dim=0)
|
||||
self.logits_start.append(prefill_l[i] - 1 if len(self.logits_start) == 0 else sum(prefill_l[:i+1])-1)
|
||||
|
||||
self.temperatures = torch.concat((self.temperatures, torch.tensor([prefill_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
|
||||
self.top_ps = torch.concat((self.top_ps, torch.tensor([prefill_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
|
||||
|
||||
for decode_query_info in decode_querys_info:
|
||||
decode_kv_block_len = (decode_query_info.active_position + 1 + page_size - 1) // page_size
|
||||
self.q_indptr = torch.concat((self.q_indptr, torch.tensor([1 + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
|
||||
self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([decode_kv_block_len+self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
|
||||
self.kv_indices = torch.concat((self.kv_indices, decode_query_info.block_index[:decode_kv_block_len]), dim=0)
|
||||
self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(decode_query_info.active_position + 1) % page_size if (decode_query_info.active_position + 1) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
|
||||
self.kv_len = torch.concat((self.kv_len, torch.tensor([(decode_query_info.active_position + 1)], device=device, dtype=torch.int32)), dim=0)
|
||||
self.position_ids = torch.concat((self.position_ids, torch.arange(decode_query_info.active_position, decode_query_info.active_position + 1, device=device, dtype=torch.int32)), dim=0)
|
||||
if decode_query_info.active_position > 0:
|
||||
self.tokens = torch.concat((self.tokens, decode_query_info.query_tokens[decode_query_info.active_position:decode_query_info.active_position+1]), dim=0)
|
||||
else:
|
||||
self.tokens = torch.concat((self.tokens, torch.tensor([0], device=device, dtype=torch.int32)), dim=0)
|
||||
self.logits_start.append(0 if len(self.logits_start) == 0 else self.logits_start[-1]+1)
|
||||
|
||||
self.temperatures = torch.concat((self.temperatures, torch.tensor([decode_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
|
||||
self.top_ps = torch.concat((self.top_ps, torch.tensor([decode_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
|
||||
|
||||
self.q_indptr = self.q_indptr.contiguous()
|
||||
self.kv_indptr = self.kv_indptr.contiguous()
|
||||
self.kv_indices = self.kv_indices.contiguous()
|
||||
self.kv_len = self.kv_len.contiguous()
|
||||
self.kv_last_page_len = self.kv_last_page_len.contiguous()
|
||||
self.position_ids = self.position_ids.contiguous()
|
||||
self.tokens = self.tokens.contiguous()
|
||||
|
||||
self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32)
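        # Illustrative note (not part of the file): the ragged layout built above
        # follows the CSR convention used by paged-attention kernels. For example,
        # one prefill chunk of 5 tokens plus two fresh single-token decode requests
        # with page_size=256 yields:
        #   q_indptr         = [0, 5, 6, 7]   # per-request token offsets
        #   kv_indptr        = [0, 1, 2, 3]   # per-request page offsets
        #   kv_last_page_len = [5, 1, 1]      # valid slots in each request's last page
        #   logits_start     = [4, 5, 6]      # only the final token of each request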
|
||||
|
||||
def fill(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, device = torch.device('cuda'), page_size = 256):
|
||||
batch_decode = len(decode_querys_info)
|
||||
batch_prefill = len(prefill_querys_info)
|
||||
|
||||
self.q_indptr = torch.tensor([0], device=device, dtype=torch.int32)
|
||||
self.kv_indptr = torch.tensor([0], device=device, dtype=torch.int32)
|
||||
self.kv_indices = torch.tensor([], device=device, dtype=torch.int32)
|
||||
self.kv_len = torch.tensor([], device=device, dtype=torch.int32)
|
||||
self.kv_last_page_len = torch.tensor([], device=device, dtype=torch.int32)
|
||||
new_position_ids = torch.tensor([], device=device, dtype=torch.int32)
|
||||
new_tokens = torch.tensor([], device=device, dtype=torch.int32)
|
||||
|
||||
self.temperatures = torch.tensor([], device=device, dtype=torch.float32)
|
||||
self.top_ps = torch.tensor([], device=device, dtype=torch.float32)
|
||||
|
||||
self.logits_start = []
|
||||
self.decode_batch = batch_decode
|
||||
self.num_tokens = batch_decode + sum(prefill_l)
|
||||
self.batch_size = batch_decode + batch_prefill
|
||||
|
||||
for i, prefill_query_info in enumerate(prefill_querys_info):
|
||||
prefill_kv_block_len = (prefill_query_info.active_position + prefill_l[i] + page_size - 1) // page_size if prefill_query_info is not None else 0
|
||||
# print(f"block_len: {prefill_kv_block_len}, page_size: {page_size}")
|
||||
self.q_indptr = torch.concat((self.q_indptr, torch.tensor([prefill_l[i] + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
|
||||
self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([prefill_kv_block_len + self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
|
||||
self.kv_indices = torch.concat((self.kv_indices, prefill_query_info.block_index[:prefill_kv_block_len]), dim=0)
|
||||
self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i]) % page_size if (prefill_query_info.active_position + prefill_l[i]) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
|
||||
self.kv_len = torch.concat((self.kv_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i])], device=device, dtype=torch.int32)), dim=0)
|
||||
new_position_ids = torch.concat((new_position_ids, torch.arange(prefill_s[i], prefill_l[i] + prefill_s[i], device=device, dtype=torch.int32)), dim=0)
|
||||
new_tokens = torch.concat((new_tokens, prefill_query_info.query_tokens[prefill_s[i]:prefill_s[i] + prefill_l[i]]), dim=0)
|
||||
self.logits_start.append(prefill_l[i] - 1 if len(self.logits_start) == 0 else sum(prefill_l[:i+1])-1)
|
||||
|
||||
self.temperatures = torch.concat((self.temperatures, torch.tensor([prefill_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
|
||||
self.top_ps = torch.concat((self.top_ps, torch.tensor([prefill_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
|
||||
|
||||
|
||||
for decode_query_info in decode_querys_info:
|
||||
decode_kv_block_len = (decode_query_info.active_position + 1 + page_size - 1) // page_size
|
||||
self.q_indptr = torch.concat((self.q_indptr, torch.tensor([1 + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
|
||||
self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([decode_kv_block_len+self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
|
||||
self.kv_indices = torch.concat((self.kv_indices, decode_query_info.block_index[:decode_kv_block_len]), dim=0)
|
||||
self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(decode_query_info.active_position + 1) % page_size if (decode_query_info.active_position + 1) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
|
||||
self.kv_len = torch.concat((self.kv_len, torch.tensor([(decode_query_info.active_position + 1)], device=device, dtype=torch.int32)), dim=0)
|
||||
new_position_ids = torch.concat((new_position_ids, torch.arange(decode_query_info.active_position, decode_query_info.active_position + 1, device=device, dtype=torch.int32)), dim=0)
|
||||
if decode_query_info.active_position > 0:
|
||||
new_tokens = torch.concat((new_tokens, decode_query_info.query_tokens[decode_query_info.active_position:decode_query_info.active_position+1]), dim=0)
|
||||
else:
|
||||
new_tokens = torch.concat((new_tokens, torch.tensor([0], device=device, dtype=torch.int32)), dim=0)
|
||||
self.logits_start.append(0 if len(self.logits_start) == 0 else self.logits_start[-1]+1)
|
||||
|
||||
self.temperatures = torch.concat((self.temperatures, torch.tensor([decode_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
|
||||
self.top_ps = torch.concat((self.top_ps, torch.tensor([decode_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
|
||||
|
||||
|
||||
self.q_indptr = self.q_indptr.contiguous()
|
||||
self.kv_indptr = self.kv_indptr.contiguous()
|
||||
self.kv_indices = self.kv_indices.contiguous()
|
||||
self.kv_len = self.kv_len.contiguous()
|
||||
self.kv_last_page_len = self.kv_last_page_len.contiguous()
|
||||
|
||||
self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32)
|
||||
|
||||
# copy new_position_ids and new_tokens to self.position_ids and self.tokens
|
||||
# print("new_position_ids: ", new_position_ids)
|
||||
# self.print()
|
||||
self.position_ids[:new_position_ids.size(0)].copy_(new_position_ids)
|
||||
self.position_ids[new_position_ids.size(0):].zero_()
|
||||
self.tokens[:new_tokens.size(0)].copy_(new_tokens)
|
||||
|
||||
|
||||
forward_minibatchs: list[ForwardMiniBatch]
|
||||
batch_size: int
|
||||
minibatch: ForwardMiniBatch
|
||||
|
||||
|
||||
|
||||
def __init__(self, batch : sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None, device=None, tokens: torch.Tensor = None):
|
||||
|
||||
if batch is None:
|
||||
return
|
||||
|
||||
|
||||
prefill_minibatches = batch.prefill_mini_batches
|
||||
decode_mini_batches = [item for sublist in batch.decode_mini_batches for item in sublist]
|
||||
prefill_querys_info = []
|
||||
prefill_s = []
|
||||
prefill_l = []
|
||||
decode_querys_info = []
|
||||
self.batch_size = 1
|
||||
for (id, s, l) in prefill_minibatches:
|
||||
prefill_querys_info.append(query_manager.query_map[id])
|
||||
prefill_s.append(s)
|
||||
prefill_l.append(l)
|
||||
for decode_batch_idx in decode_mini_batches:
|
||||
if query_manager.query_map[decode_batch_idx].decode_start_time is None:
|
||||
query_manager.query_map[decode_batch_idx].decode_start_time =time.time()
|
||||
decode_querys_info.append(query_manager.query_map[decode_batch_idx])
|
||||
|
||||
|
||||
minibatch = ForwardBatchInput.ForwardMiniBatch(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device = query_manager.device, page_size = query_manager.page_size)
|
||||
|
||||
self.minibatch = minibatch
|
||||
|
||||
@classmethod
|
||||
def gen_max_forward_batch(
|
||||
cls,
|
||||
device=None,
|
||||
tokens: torch.Tensor = None,
|
||||
num_mini_batches: int = 1,
|
||||
max_seq_length: int = 1024, # TODO: add to yaml
|
||||
prefill_query_length: int = (Config().chunk_size - Config().max_decode_batch_size) // Config().max_prefill_batch_size, # TODO: use config
|
||||
prefill_active_length: int = (Config().chunk_size - Config().max_decode_batch_size) // Config().max_prefill_batch_size,
|
||||
gen_prefill: bool = True,
|
||||
decode_batch_size: int = Config().max_decode_batch_size,
|
||||
decode_active_position: torch.Tensor = None,
|
||||
page_size = 256,
|
||||
cuda_lens = 1
|
||||
):
|
||||
instance = cls()
|
||||
|
||||
instance.batch_size = num_mini_batches
|
||||
page_size = page_size
|
||||
|
||||
prefill_query_info = []
|
||||
offset = 0
|
||||
if gen_prefill and prefill_query_length != 0:
|
||||
for i in range(Config().max_prefill_batch_size):
|
||||
prefill_query_info.append(QueryInfo(i, prefill_query_length, max_seq_length, page_size, device, offset=offset))
|
||||
offset += max_seq_length // page_size
|
||||
|
||||
decode_querys_info = []
|
||||
for i in range(min(decode_batch_size, cuda_lens)):
|
||||
query_info = QueryInfo(i+Config().max_prefill_batch_size, prefill_query_length, max_seq_length, page_size, device, is_prefill=False, offset=offset)
|
||||
offset += max_seq_length // page_size
|
||||
if tokens is not None:
|
||||
query_info.query_tokens[prefill_active_length:prefill_active_length + 1].copy_(tokens)
|
||||
if decode_active_position is None:
|
||||
query_info.active_position = prefill_active_length
|
||||
else:
|
||||
query_info.active_position = decode_active_position[i]
|
||||
|
||||
decode_querys_info.append(query_info)
|
||||
|
||||
if prefill_query_length*Config().max_prefill_batch_size + len(decode_querys_info) < cuda_lens:
|
||||
decode_querys_info.append(query_info)
|
||||
|
||||
instance.minibatch = ForwardBatchInput.ForwardMiniBatch(prefill_query_info, decode_querys_info, [0, 0], [prefill_active_length for _ in range(Config().max_prefill_batch_size)], device, page_size)
|
||||
|
||||
return instance
|
||||
|
||||
def fill(self, batch : sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None, page_size = 256):
|
||||
if batch is None:
|
||||
return
|
||||
prefill_minibatches = batch.prefill_mini_batches
|
||||
decode_mini_batches = [item for sublist in batch.decode_mini_batches for item in sublist]
|
||||
|
||||
prefill_querys_info = []
|
||||
prefill_s = []
|
||||
prefill_l = []
|
||||
decode_querys_info = []
|
||||
self.batch_size = 1
|
||||
for (id, s, l) in prefill_minibatches:
|
||||
prefill_querys_info.append(query_manager.query_map[id])
|
||||
prefill_s.append(s)
|
||||
prefill_l.append(l)
|
||||
for decode_batch_idx in decode_mini_batches:
|
||||
if query_manager.query_map[decode_batch_idx].decode_start_time is None:
|
||||
query_manager.query_map[decode_batch_idx].decode_start_time =time.time()
|
||||
decode_querys_info.append(query_manager.query_map[decode_batch_idx])
|
||||
|
||||
self.minibatch.fill(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device=query_manager.device, page_size=page_size)
|
||||
|
||||
|
||||
|
||||
class ForwardBatchOutput:
|
||||
logits: list[torch.Tensor]
|
||||
num_batchs: int
|
||||
batch_sizes: list[int]
|
||||
generated_tokens_num: list[int]
|
||||
lm_start: list[int]
|
||||
|
||||
temperatures: list[torch.Tensor]
|
||||
top_ps: list[torch.Tensor]
|
||||
|
||||
def __init__(self):
|
||||
self.logits = []
|
||||
self.batch_sizes = []
|
||||
self.generated_tokens_num = []
|
||||
self.top_ps = []
|
||||
self.temperatures = []
|
||||
pass
|
306 ktransformers/server/balance_serve/inference/model_runner.py (new file)
|
@@ -0,0 +1,306 @@
|
|||
"""
|
||||
Date: 2024-11-07 07:02:20
|
||||
LastEditors: djw
|
||||
LastEditTime: 2024-12-10 08:48:32
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import queue
import signal
|
||||
from typing import AsyncIterable
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from contextlib import asynccontextmanager
|
||||
from pydantic import BaseModel, Field
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import time
|
||||
import torch.multiprocessing as mp
|
||||
import random
|
||||
import torch.distributed as dist
|
||||
import zmq
|
||||
import tempfile
|
||||
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
|
||||
|
||||
from ktransformers.server.config.config import Config
|
||||
from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
|
||||
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
|
||||
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
|
||||
from ktransformers.server.balance_serve.settings import sched_ext
|
||||
|
||||
|
||||
|
||||
def pad_num_tokens(num_tokens):
|
||||
return (num_tokens + 63) // 64 * 64
|
||||
|
||||
def deduplicate_and_sort(lst):
|
||||
return sorted(set(lst))
|
||||
class ModelRunner:
|
||||
"""A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile."""
|
||||
|
||||
model: KDeepseekV3ForCausalLM
|
||||
input: ForwardBatchInput | list[ForwardBatchInput]
|
||||
output: ForwardBatchOutput
|
||||
|
||||
def __init__(self, model = None, device = None, use_cuda_graph = False, max_decode_batch_size = 1, max_chunk_size = 4096, num_mini_batches: int = 1, page_size = 256):
|
||||
|
||||
self.stream = torch.cuda.Stream(device=device)
|
||||
        # temporarily commented out
|
||||
self.model = model # Compile and move model to the specified device
|
||||
self.device = device
|
||||
self.input = None
|
||||
self.features_buf = None
|
||||
self.output = None
|
||||
self.graph_memory_pool = None
|
||||
self.cuda_graphs = deduplicate_and_sort([1, 2, 3, Config().max_batch_size, 64, Config().chunk_size])
|
||||
self.use_cuda_graph = use_cuda_graph
|
||||
self.model_time = 0
|
||||
self.page_size = page_size
|
||||
# GPU timing for model execution
|
||||
self.start_model_event = torch.cuda.Event(enable_timing=True)
|
||||
self.end_model_event = torch.cuda.Event(enable_timing=True)
|
||||
if isinstance(self.cuda_graphs, list):
|
||||
self.graphs = [torch.cuda.CUDAGraph() for _ in range(len(self.cuda_graphs))]
|
||||
self.page_idx_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]
|
||||
self.page_offset_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]
|
||||
else:
|
||||
self.graphs = torch.cuda.CUDAGraph()
|
||||
self.page_idx_buf = torch.zeros([self.cuda_graphs], dtype=torch.int32, device = self.device)
|
||||
self.page_offset_buf = torch.zeros([self.cuda_graphs], dtype=torch.int32, device = self.device)
|
||||
self.num_mini_batches = num_mini_batches
|
||||
|
||||
self.max_chunk_size = max_chunk_size
|
||||
|
||||
self.bsz_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device)
|
||||
self.num_tokens_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device)
|
||||
def warmup(self):
|
||||
|
||||
def capture_graphs(cuda_graph_idx=-1):
|
||||
if cuda_graph_idx != -1:
|
||||
with torch.cuda.graph(self.graphs[cuda_graph_idx], pool=self.graph_memory_pool, stream=self.stream):
|
||||
self.outputs_buf[cuda_graph_idx] = self.model(self.input[cuda_graph_idx], self.features_buf[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[cuda_graph_idx], self.page_offset_buf[cuda_graph_idx], cuda_graph_idx=cuda_graph_idx)
|
||||
self.graph_memory_pool = self.graphs[cuda_graph_idx].pool()
|
||||
else:
|
||||
with torch.cuda.graph(self.graphs, pool=self.graph_memory_pool, stream=self.stream):
|
||||
self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)
|
||||
self.graph_memory_pool = self.graphs.pool()
|
||||
|
||||
if isinstance(self.cuda_graphs, list):
|
||||
self.input = []
|
||||
self.features_buf = []
|
||||
self.outputs_buf = []
|
||||
self.bsz_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)
|
||||
self.num_tokens_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)
|
||||
for i in range(len(self.cuda_graphs)):
|
||||
                prefill_query_length = (self.cuda_graphs[i] - Config().max_decode_batch_size) // Config().max_prefill_batch_size if self.cuda_graphs[i] > Config().max_decode_batch_size else 0  # TODO: currently only supports 2 prefill batches
|
||||
self.input.append(ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches, prefill_query_length=prefill_query_length, prefill_active_length=prefill_query_length, page_size=self.page_size, cuda_lens = self.cuda_graphs[i]))
|
||||
|
||||
self.features_buf.append(self.model.batch_embeddings(self.input[i]))
|
||||
batch_size = self.input[i].minibatch.q_indptr.size(0)-1
|
||||
num_tokens = self.features_buf[i][0].size(0)
|
||||
print("capturing cuda graph", batch_size, num_tokens)
|
||||
self.bsz_tensor_buf[0] = batch_size
|
||||
self.num_tokens_tensor_buf[0] = num_tokens
|
||||
|
||||
self.model.flash_infer_attn_plan(self.input[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf,
|
||||
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
|
||||
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
|
||||
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
|
||||
|
||||
page_idx, page_offset = self.model.cache.get_page_table(self.input[i].minibatch.position_ids, self.input[i].minibatch.q_indptr, self.input[i].minibatch.kv_indptr, self.input[i].minibatch.kv_indices, self.num_tokens_tensor_buf)
|
||||
|
||||
self.page_idx_buf[i][:num_tokens].copy_(page_idx[:num_tokens])
|
||||
self.page_offset_buf[i][:num_tokens].copy_(page_offset[:num_tokens])
|
||||
self.page_idx_buf[i][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size -1)
|
||||
|
||||
self.outputs_buf.append(None)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
for warm_up_iters in range(11):
|
||||
with torch.cuda.stream(self.stream):
|
||||
self.outputs_buf[i] = self.model(self.input[i], self.features_buf[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[i], self.page_offset_buf[i])
|
||||
torch.cuda.synchronize()
|
||||
|
||||
capture_graphs(i)
|
||||
|
||||
with torch.cuda.stream(self.stream):
|
||||
self.graphs[i].replay()
|
||||
|
||||
self.sync(calc_time=False)
|
||||
print(f"cuda_graph: {i+1}/{len(self.cuda_graphs)}, warmup finished.")
|
||||
else:
|
||||
self.input = ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches)
|
||||
|
||||
self.features_buf = self.model.batch_embeddings(self.input)
|
||||
batch_size = self.input.minibatch.q_indptr.size(0)-1
|
||||
num_tokens = self.features_buf[0].size(0)
|
||||
|
||||
|
||||
self.bsz_tensor_buf = torch.tensor([batch_size], dtype=torch.int32, device=self.device)
|
||||
self.num_tokens_tensor_buf = torch.tensor([num_tokens], dtype=torch.int32, device=self.device)
|
||||
|
||||
|
||||
self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
|
||||
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
|
||||
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
|
||||
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
|
||||
|
||||
page_idx, page_offset = self.model.cache.get_page_table(self.input.minibatch.position_ids, self.input.minibatch.q_indptr, self.input.minibatch.kv_indptr, self.input.minibatch.kv_indices, self.num_tokens_tensor_buf)
|
||||
self.page_idx_buf[:num_tokens].copy_(page_idx[:num_tokens])
|
||||
self.page_offset_buf[:num_tokens].copy_(page_offset[:num_tokens])
|
||||
self.page_idx_buf[num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
|
||||
|
||||
|
||||
torch.cuda.synchronize()
|
||||
for warm_up_iters in range(11):
|
||||
with torch.cuda.stream(self.stream):
|
||||
self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def capture_graphs():
|
||||
with torch.cuda.graph(self.graphs, stream=self.stream):
|
||||
self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)
|
||||
# self.graph_memory_pool = self.graphs.pool()
|
||||
|
||||
|
||||
capture_graphs()
|
||||
|
||||
with torch.cuda.stream(self.stream):
|
||||
self.graphs.replay()
|
||||
|
||||
self.sync(calc_time=False)
|
||||
print("warmup finished.")
|
||||
|
||||
def run(self, batch: sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None):
|
||||
with torch.cuda.stream(self.stream):
|
||||
|
||||
batch_size = len(batch.prefill_mini_batches) # TODO: calc this
|
||||
num_tokens = 0
|
||||
for i in range(len(batch.decode_mini_batches)):
|
||||
batch_size += len(batch.decode_mini_batches[i])
|
||||
num_tokens += len(batch.decode_mini_batches[i])
|
||||
print(f'decode_batch_i: {len(batch.decode_mini_batches[i])},')
|
||||
|
||||
for i in range(len(batch.prefill_mini_batches)):
|
||||
num_tokens += batch.prefill_mini_batches[i][2]
|
||||
print(f'prefill_batch_i: {batch.prefill_mini_batches[i][2]},')
|
||||
|
||||
|
||||
|
||||
if isinstance(self.cuda_graphs, list):
|
||||
# cuda_graph_idx is the smallest index i such that self.cuda_graphs[i] >= num_tokens
|
||||
cuda_graph_idx = next((i for i, token in enumerate(self.cuda_graphs) if token >= num_tokens), len(self.cuda_graphs))
|
||||
if cuda_graph_idx == len(self.cuda_graphs):
|
||||
assert False, "num_tokens is too large"
|
||||
else:
|
||||
cuda_graph_idx = -1
|
||||
|
||||
if self.use_cuda_graph:
|
||||
if cuda_graph_idx != -1:
|
||||
self.input[cuda_graph_idx].fill(batch, query_manager, self.page_size)
|
||||
else:
|
||||
self.input.fill(batch, query_manager, self.page_size)
|
||||
else:
|
||||
self.input = ForwardBatchInput(batch=batch, query_manager=query_manager, device=self.device)
|
||||
|
||||
|
||||
if cuda_graph_idx != -1 and self.use_cuda_graph:
|
||||
self.features = self.model.batch_embeddings(self.input[cuda_graph_idx], device=self.device)
|
||||
else:
|
||||
self.features = self.model.batch_embeddings(self.input, device=self.device)
|
||||
|
||||
|
||||
self.bsz_tensor_buf.copy_(batch_size)
|
||||
self.num_tokens_tensor_buf.copy_(torch.tensor([num_tokens], dtype=torch.int32, device=self.device))
|
||||
|
||||
if self.use_cuda_graph:
|
||||
if cuda_graph_idx != -1:
|
||||
self.features_buf[cuda_graph_idx][0].copy_(self.features[0], non_blocking=True)
|
||||
else:
|
||||
self.features_buf[0].copy_(self.features[0], non_blocking=True)
|
||||
"""
|
||||
if num_tokens_0 > 64:
|
||||
padded_num_tokens_0 = pad_num_tokens(num_tokens_0)
|
||||
self.features_buf[0][num_tokens_0:padded_num_tokens_0] = 0
|
||||
"""
|
||||
#self.input.forward_minibatchs[0].print()
|
||||
# print([[hash(k[i].float().cpu().numpy().tobytes()) for i in self.input.forward_minibatchs[0].kv_indices] for k in self.model.cache.k_caches])
|
||||
# print(f"overlap: {overlap}, is_compute_bound: {is_compute_bound}")
|
||||
|
||||
# self.model.flash_infer_attn_plan(self.input, self.bsz_tensors, self.num_tokens_tensors)
|
||||
|
||||
"""
|
||||
if self.use_cuda_graph:
|
||||
print("before replay features_buf", self.features_buf[0])
|
||||
print("features_buf addr", self.features_buf[0].data_ptr())
|
||||
else:
|
||||
print("before run features", self.features[0])
|
||||
"""
|
||||
if cuda_graph_idx != -1 and self.use_cuda_graph:
|
||||
self.model.flash_infer_attn_plan(self.input[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf,
|
||||
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
|
||||
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
|
||||
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
|
||||
self.start_model_event.record(self.stream)
|
||||
page_idx, page_offset = self.model.cache.get_page_table(self.input[cuda_graph_idx].minibatch.position_ids, self.input[cuda_graph_idx].minibatch.q_indptr, self.input[cuda_graph_idx].minibatch.kv_indptr, self.input[cuda_graph_idx].minibatch.kv_indices, self.num_tokens_tensor_buf)
|
||||
if self.use_cuda_graph:
|
||||
self.page_idx_buf[cuda_graph_idx][:num_tokens].copy_(page_idx[:num_tokens])
|
||||
self.page_offset_buf[cuda_graph_idx][:num_tokens].copy_(page_offset[:num_tokens])
|
||||
self.page_idx_buf[cuda_graph_idx][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
|
||||
self.replay(cuda_graph_idx)
|
||||
self.output = ForwardBatchOutput()
|
||||
|
||||
self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps)
|
||||
self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures)
|
||||
|
||||
self.output.logits.append(self.outputs_buf[cuda_graph_idx].logits[0][self.input[cuda_graph_idx].minibatch.logits_start].clone())
|
||||
else:
|
||||
self.output = self.model(self.input[cuda_graph_idx], self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset)
|
||||
self.output.logits[0] = self.output.logits[0][self.input[cuda_graph_idx].minibatch.logits_start]
|
||||
self.end_model_event.record(self.stream)
|
||||
else:
|
||||
self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
|
||||
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
|
||||
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
|
||||
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
|
||||
self.start_model_event.record(self.stream)
|
||||
page_idx, page_offset = self.model.cache.get_page_table(self.input.minibatch.position_ids, self.input.minibatch.q_indptr, self.input.minibatch.kv_indptr, self.input.minibatch.kv_indices, self.num_tokens_tensor_buf)
|
||||
if self.use_cuda_graph:
|
||||
self.page_idx_buf[:num_tokens].copy_(page_idx[:num_tokens])
|
||||
self.page_offset_buf[:num_tokens].copy_(page_offset[:num_tokens])
|
||||
self.page_idx_buf[num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
|
||||
self.replay(cuda_graph_idx)
|
||||
self.output = ForwardBatchOutput()
|
||||
|
||||
self.output.top_ps.append(self.input.minibatch.top_ps)
|
||||
self.output.temperatures.append(self.input.minibatch.temperatures)
|
||||
|
||||
self.output.logits.append(self.outputs_buf.logits[0][self.input.minibatch.logits_start].clone())
|
||||
else:
|
||||
self.output = self.model(self.input, self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset)
|
||||
self.output.logits[0] = self.output.logits[0][self.input.minibatch.logits_start]
|
||||
self.output.top_ps.append(self.input.minibatch.top_ps)
|
||||
self.output.temperatures.append(self.input.minibatch.temperatures)
|
||||
|
||||
self.end_model_event.record(self.stream)
|
||||
|
||||
if not self.use_cuda_graph:
|
||||
self.output.num_batchs = self.input.batch_size
|
||||
else:
|
||||
self.output.num_batchs = self.input[cuda_graph_idx].batch_size
|
||||
|
||||
|
||||
def replay(self, cuda_graph_idx=-1):
|
||||
with torch.cuda.stream(self.stream):
|
||||
if cuda_graph_idx != -1:
|
||||
self.graphs[cuda_graph_idx].replay()
|
||||
else:
|
||||
self.graphs.replay()
|
||||
|
||||
|
||||
def sync(self, calc_time = True):
|
||||
self.stream.synchronize()
|
||||
if calc_time:
|
||||
self.model_time = self.start_model_event.elapsed_time(self.end_model_event) # In ms
|
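The warmup and capture logic above follows the standard PyTorch CUDA-graph recipe: run a few forward passes on a dedicated stream, capture one pass into a torch.cuda.CUDAGraph while reading from persistent input buffers, then at serve time refill those buffers in place and replay the captured kernels. A minimal sketch of that recipe, using a toy nn.Linear in place of the real model (names and shapes here are placeholders, not the engine's):

import torch

# Minimal sketch of the warmup -> capture -> replay pattern used above.
model = torch.nn.Linear(16, 16).cuda()
static_in = torch.zeros(8, 16, device="cuda")        # persistent input buffer
stream = torch.cuda.Stream()

with torch.cuda.stream(stream):                       # warm up outside the graph
    for _ in range(3):
        static_out = model(static_in)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):          # capture one forward pass
    static_out = model(static_in)

static_in.copy_(torch.randn(8, 16, device="cuda"))    # refill the buffer in place
graph.replay()                                        # re-run the captured kernels
print(static_out[0, :4])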
ktransformers/server/balance_serve/inference/query_manager.py (new file, 158 lines)
@ -0,0 +1,158 @@
|
|||
'''
|
||||
Date: 2024-11-14 12:23:45
|
||||
LastEditors: djw
|
||||
LastEditTime: 2024-11-20 04:06:23
|
||||
'''
|
||||
import torch
|
||||
from ktransformers.server.balance_serve.settings import sched_ext
|
||||
import random
|
||||
import time
|
||||
|
||||
class QueryInfo:
|
||||
id: int
|
||||
active_position: int
|
||||
query_length: int
|
||||
is_prefill: int
|
||||
block_index: torch.Tensor
|
||||
query_tokens: torch.Tensor
|
||||
stop_criteria: list[torch.Tensor]
|
||||
|
||||
temperature: float
|
||||
top_p: float
|
||||
|
||||
max_length: int
|
||||
|
||||
def __init__(self, id, query_length: int, max_length: int, page_size: int, device: torch.device, is_prefill: bool = True, offset: int = 0, active_position: int = 0, temperature: float = 0.01, top_p: float = 1.0):
|
||||
self.id = id
|
||||
self.is_prefill = is_prefill
|
||||
self.active_position = active_position
|
||||
self.max_length = max_length - 1
|
||||
self.query_tokens = torch.zeros((max_length,), dtype=torch.int, device = device)
|
||||
self.stop_criteria = []
|
||||
self.block_index = torch.arange(offset, offset + (max_length + active_position + page_size - 1) // page_size, dtype=torch.int, device = device)
|
||||
self.query_length = query_length
|
||||
self.enqueue_time = time.time()
|
||||
self.decode_start_time = None
|
||||
self.speculative_token = {} # {position: (accept, token)}
|
||||
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
|
||||
def check_stop(self):
|
||||
if self.active_position >= self.max_length - 2:
|
||||
return True
|
||||
|
||||
# Iterate over every stop condition
|
||||
for stop_tensor in self.stop_criteria:
|
||||
stop_len = len(stop_tensor)
|
||||
|
||||
# Skip this stop condition if it is longer than the tokens generated so far
|
||||
if stop_len >= self.active_position:
|
||||
continue
|
||||
|
||||
#print(f"stop_tensor: {stop_tensor}, stop_len: {stop_len}, active_position: {self.active_position}, query_token: {self.query_tokens[self.active_position - stop_len - 1:self.active_position - 1]}")
|
||||
|
||||
if (torch.equal(self.query_tokens[self.active_position - stop_len - 1:self.active_position - 1], stop_tensor) and self.active_position) or self.max_length <= self.active_position + 3:
|
||||
self.life_time = time.time() - self.enqueue_time
|
||||
self.decode_duration_time = time.time() - self.decode_start_time
|
||||
self.decode_tps = (self.active_position - self.query_length) / self.decode_duration_time
|
||||
print(f"prefill length: {self.query_length}, prefill time: {self.prefill_duration_time}, prefill tps {self.prefill_tps}, decode length: {self.active_position - self.query_length}, decode time: {self.decode_duration_time}, decode tps {self.decode_tps}")
|
||||
return True # a matching stop condition was found
|
||||
|
||||
|
||||
return False # no stop condition matched
|
||||
|
||||
|
||||
def print(self):
|
||||
print(f"active_position: {self.active_position}, query_length: {self.query_length}, is_prefill: {self.is_prefill}")
|
||||
print(f"block_index_shape: {self.block_index.shape}, query_tokens_shape: {self.query_tokens.shape}")
|
||||
|
||||
|
||||
class QueryManager:
|
||||
|
||||
max_length: int = 65536
|
||||
page_size: int = 256
|
||||
device: torch.device
|
||||
query_map : dict[int, QueryInfo]
|
||||
|
||||
def __init__(self, max_length = 65536, page_size = 256, device = torch.device('cuda')):
|
||||
self.max_length = max_length
|
||||
self.page_size = page_size
|
||||
self.device = device
|
||||
self.query_map = {}
|
||||
|
||||
def add_query(self, batch: sched_ext.BatchQueryTodo):
|
||||
|
||||
for i in range(len(batch.query_ids)):
|
||||
id = batch.query_ids[i]
|
||||
if id not in self.query_map:
|
||||
print(f"add query id: {id}, batch.query_lengths: {batch.query_lengths[i]}, batch_query_tokens: {batch.query_tokens[i].shape}, batch.block_indexes: {batch.block_indexes[i]}")
|
||||
assert batch.query_tokens[i].size(0) < self.max_length, "query max length in batchquerytodo exceeds internal max_length"
|
||||
query_info = QueryInfo(id=id, query_length=batch.query_lengths[i], max_length=batch.query_tokens[i].size(0) + 1, page_size=self.page_size, device=self.device, temperature=batch.sample_options[i].temperature, top_p=batch.sample_options[i].top_p)
|
||||
query_info.query_tokens[:query_info.query_length].copy_(batch.query_tokens[i][:query_info.query_length].to(self.device))
|
||||
|
||||
for stop_token_list in batch.stop_criteria[i]:
|
||||
query_info.stop_criteria.append(torch.tensor(stop_token_list, dtype=torch.int, device = self.device))
|
||||
|
||||
block_num = batch.block_indexes[i].size(0)
|
||||
query_info.block_index[:block_num].copy_(batch.block_indexes[i].to(self.device))
|
||||
|
||||
self.query_map[id] = query_info
|
||||
|
||||
prefill_mini_batches = batch.prefill_mini_batches
|
||||
for (prefill_id, s, l) in prefill_mini_batches:
|
||||
if prefill_id == id:
|
||||
self.query_map[prefill_id].active_position = s
|
||||
|
||||
|
||||
def update(self, batch: sched_ext.BatchQueryTodo) -> list[sched_ext.QueryUpdate]:
|
||||
query_updates = []
|
||||
|
||||
prefill_mini_batches = batch.prefill_mini_batches
|
||||
|
||||
for (id, s, l) in prefill_mini_batches:
|
||||
|
||||
if id not in self.query_map:
|
||||
assert False, f"query id {id} not found in query_map"
|
||||
|
||||
# update query_info
|
||||
query_info = self.query_map[id]
|
||||
query_info.active_position += l
|
||||
|
||||
if query_info.active_position >= query_info.query_length and query_info.is_prefill:
|
||||
query_info.is_prefill = False
|
||||
query_info.prefill_duration_time = time.time() - query_info.enqueue_time
|
||||
query_info.prefill_tps = query_info.query_length / query_info.prefill_duration_time
|
||||
|
||||
|
||||
# generate schedule query_update
|
||||
query_update = sched_ext.QueryUpdate()
|
||||
query_update.id = id
|
||||
query_update.ok = True
|
||||
query_update.is_prefill = query_info.is_prefill
|
||||
query_update.active_position = query_info.active_position
|
||||
# if(not query_info.is_prefill):
|
||||
query_updates.append(query_update)
|
||||
|
||||
|
||||
decode_mini_batches = batch.decode_mini_batches
|
||||
|
||||
for ids in decode_mini_batches:
|
||||
for id in ids:
|
||||
if id not in self.query_map:
|
||||
assert False, f"query id {id} not found in query_map"
|
||||
|
||||
query_info = self.query_map[id]
|
||||
query_info.active_position += 1
|
||||
|
||||
query_update = sched_ext.QueryUpdate()
|
||||
query_update.id = id
|
||||
query_update.ok = True
|
||||
query_update.is_prefill = query_info.is_prefill
|
||||
|
||||
query_update.decode_done = query_info.check_stop()
|
||||
|
||||
query_update.active_position = query_info.active_position
|
||||
query_updates.append(query_update)
|
||||
|
||||
return query_updates
|
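QueryInfo.check_stop above decides whether a request is finished by comparing the tail of the generated token stream against each registered stop sequence (plus length limits). A stand-alone sketch of that tail-matching idea, outside the QueryInfo bookkeeping and without its exact window offsets:

import torch

def tail_matches_stop(generated: torch.Tensor, stop_sequences: list[torch.Tensor]) -> bool:
    # True if the generated token stream ends with any of the stop sequences.
    for stop in stop_sequences:
        if len(stop) <= len(generated) and torch.equal(generated[-len(stop):], stop):
            return True
    return False

tokens = torch.tensor([5, 7, 9, 2, 0])
stops = [torch.tensor([2, 0]), torch.tensor([42])]
print(tail_matches_stop(tokens, stops))  # True: the stream ends with [2, 0]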
|
@ -0,0 +1,13 @@
|
|||
from .orchestrator import BatchedPenalizerOrchestrator
from .penalizers.frequency_penalty import BatchedFrequencyPenalizer
from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer
from .penalizers.presence_penalty import BatchedPresencePenalizer
from .penalizers.repetition_penalty import BatchedRepetitionPenalizer

__all__ = [
    "BatchedFrequencyPenalizer",
    "BatchedMinNewTokensPenalizer",
    "BatchedPresencePenalizer",
    "BatchedRepetitionPenalizer",
    "BatchedPenalizerOrchestrator",
]
|
|
@ -0,0 +1,376 @@
|
|||
import abc
|
||||
import dataclasses
|
||||
import typing
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _ReqLike:
|
||||
origin_input_ids: typing.Union[torch.Tensor, typing.List[int]]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _BatchLike:
|
||||
reqs: typing.List[_ReqLike]
|
||||
|
||||
def batch_size(self):
|
||||
return len(self.reqs)
|
||||
|
||||
|
||||
class BatchedPenalizerOrchestrator:
|
||||
batch: _BatchLike
|
||||
device: str
|
||||
vocab_size: int
|
||||
penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
batch: _BatchLike,
|
||||
device: str,
|
||||
Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]],
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.batch = batch
|
||||
self.device = device
|
||||
|
||||
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
|
||||
|
||||
is_required = False
|
||||
for penalizer in self.penalizers.values():
|
||||
pen_is_required = penalizer.prepare_if_required()
|
||||
is_required |= pen_is_required
|
||||
self.is_required = is_required
|
||||
|
||||
if self.is_required:
|
||||
self.cumulate_input_tokens(
|
||||
input_ids=[req.origin_input_ids for req in self.reqs()]
|
||||
)
|
||||
|
||||
def reqs(self):
|
||||
return self.batch.reqs
|
||||
|
||||
def batch_size(self):
|
||||
return self.batch.batch_size()
|
||||
|
||||
def cumulate_input_tokens(
|
||||
self,
|
||||
input_ids: typing.Union[
|
||||
typing.List[torch.Tensor], typing.List[typing.List[int]]
|
||||
],
|
||||
):
|
||||
"""
|
||||
Feed the input tokens to the penalizers.
|
||||
|
||||
Args:
|
||||
input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens.
|
||||
"""
|
||||
token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
|
||||
|
||||
for penalizer in self.penalizers.values():
|
||||
penalizer.cumulate_input_tokens(input_ids=token_ids)
|
||||
|
||||
def cumulate_output_tokens(
|
||||
self,
|
||||
output_ids: typing.Union[
|
||||
typing.List[torch.Tensor], typing.List[typing.List[int]]
|
||||
],
|
||||
):
|
||||
"""
|
||||
Feed the output tokens to the penalizers.
|
||||
|
||||
Args:
|
||||
output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
|
||||
"""
|
||||
if not self.is_required:
|
||||
return
|
||||
|
||||
token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
|
||||
|
||||
for penalizer in self.penalizers.values():
|
||||
penalizer.cumulate_output_tokens(output_ids=token_ids)
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply the penalizers to the logits.
|
||||
Note that it may apply the penalizers in-place.
|
||||
|
||||
Args:
|
||||
logits (torch.Tensor): The logits to apply the penalizers to.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The logits after applying the penalizers.
|
||||
"""
|
||||
if not self.is_required:
|
||||
return
|
||||
|
||||
for penalizer in self.penalizers.values():
|
||||
logits = penalizer.apply(logits)
|
||||
|
||||
return logits
|
||||
|
||||
def filter(
|
||||
self,
|
||||
indices_to_keep: typing.List[int],
|
||||
indices_tensor_to_keep: torch.Tensor = None,
|
||||
):
|
||||
"""
|
||||
Filter the penalizers based on the indices to keep in the batch.
|
||||
|
||||
Args:
|
||||
indices_to_keep (typing.List[int]): List of indices to keep in the batch.
|
||||
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
|
||||
"""
|
||||
if not self.is_required:
|
||||
return
|
||||
|
||||
empty_indices = len(indices_to_keep) == 0
|
||||
|
||||
is_required = False
|
||||
for penalizer in self.penalizers.values():
|
||||
tmp_is_required = penalizer.is_required()
|
||||
is_required = is_required or tmp_is_required
|
||||
if not tmp_is_required or empty_indices:
|
||||
penalizer.teardown()
|
||||
else:
|
||||
# create tensor index only when it's needed
|
||||
if indices_tensor_to_keep is None:
|
||||
indices_tensor_to_keep = torch.tensor(
|
||||
indices_to_keep, dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
penalizer.filter(
|
||||
indices_to_keep=indices_to_keep,
|
||||
indices_tensor_to_keep=indices_tensor_to_keep,
|
||||
)
|
||||
self.is_required = is_required
|
||||
|
||||
def merge(self, their: "BatchedPenalizerOrchestrator"):
|
||||
"""
|
||||
Merge the penalizers of another orchestrator into this one.
|
||||
|
||||
Note that this function **must** be called _before_ self.batch.reqs is updated (filtered).
|
||||
Each unprepared penalizers would have to be prepared (creating tensors, etc.) first before merging.
|
||||
This step requires the original batch.reqs, before it gets merged with other batch.reqs.
|
||||
|
||||
Args:
|
||||
their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
|
||||
"""
|
||||
if not self.is_required and not their.is_required:
|
||||
return
|
||||
|
||||
self.is_required |= their.is_required
|
||||
for Penalizer, their_penalizer in their.penalizers.items():
|
||||
if Penalizer not in self.penalizers:
|
||||
raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
|
||||
|
||||
self.penalizers[Penalizer].merge(their_penalizer)
|
||||
|
||||
|
||||
class _TokenIDs:
|
||||
"""
|
||||
A class that wraps token IDs to provide additional utility functions to penalizers.
|
||||
|
||||
Attributes:
|
||||
orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
|
||||
token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs.
|
||||
cached_counts (torch.Tensor): The cached occurrence count tensor.
|
||||
"""
|
||||
|
||||
orchestrator: BatchedPenalizerOrchestrator
|
||||
token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]
|
||||
cached_counts: torch.Tensor = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
orchestrator: BatchedPenalizerOrchestrator,
|
||||
token_ids: typing.Union[
|
||||
typing.List[torch.Tensor], typing.List[typing.List[int]]
|
||||
],
|
||||
):
|
||||
self.orchestrator = orchestrator
|
||||
|
||||
if not isinstance(token_ids[0], torch.Tensor):
|
||||
token_ids = [
|
||||
torch.tensor(
|
||||
data=ids, dtype=torch.int64, device=self.orchestrator.device
|
||||
)
|
||||
for ids in token_ids
|
||||
]
|
||||
|
||||
self.token_ids = token_ids
|
||||
|
||||
def occurrence_count(self) -> torch.Tensor:
|
||||
"""
|
||||
Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in the batch.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The occurrence count tensor.
|
||||
"""
|
||||
if self.cached_counts is not None:
|
||||
return self.cached_counts
|
||||
|
||||
token_ids = self.token_ids
|
||||
|
||||
if isinstance(token_ids, torch.Tensor):
|
||||
token_ids = token_ids.unsqueeze(1)
|
||||
|
||||
# indices used with scatter_add must be int64
if isinstance(token_ids, torch.Tensor):
    if token_ids.dtype != torch.int64:
        token_ids = token_ids.to(torch.int64)
else:
    # token_ids can also be a list of per-request tensors; cast each one
    token_ids = [ids.to(torch.int64) for ids in token_ids]
|
||||
|
||||
padded_token_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
sequences=token_ids,
|
||||
batch_first=True,
|
||||
padding_value=self.orchestrator.vocab_size,
|
||||
)
|
||||
|
||||
self.cached_counts = torch.zeros(
|
||||
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
|
||||
dtype=torch.int64,
|
||||
device=self.orchestrator.device,
|
||||
).scatter_add_(
|
||||
dim=1,
|
||||
index=padded_token_ids,
|
||||
src=torch.ones_like(padded_token_ids),
|
||||
)[
|
||||
:, : self.orchestrator.vocab_size
|
||||
]
|
||||
|
||||
return self.cached_counts
|
||||
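occurrence_count pads the per-request token lists with vocab_size as a sentinel id, scatter-adds ones into a (batch, vocab_size + 1) table, and slices off the last column so the padding never pollutes real counts. A small numeric sketch of that trick:

import torch

vocab_size = 6
token_ids = [torch.tensor([1, 1, 4]), torch.tensor([2])]   # two requests

padded = torch.nn.utils.rnn.pad_sequence(
    token_ids, batch_first=True, padding_value=vocab_size  # pad with the sentinel id
)
counts = torch.zeros(len(token_ids), vocab_size + 1, dtype=torch.int64).scatter_add_(
    dim=1, index=padded, src=torch.ones_like(padded)
)[:, :vocab_size]  # drop the sentinel column that absorbed the padding

print(counts)  # row 0 counts token 1 twice and token 4 once; row 1 counts token 2 once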
|
||||
|
||||
class _BatchedPenalizer(abc.ABC):
|
||||
"""
|
||||
An abstract class for a batched penalizer.
|
||||
"""
|
||||
|
||||
orchestrator: BatchedPenalizerOrchestrator
|
||||
_is_prepared: bool = False
|
||||
|
||||
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
|
||||
self.orchestrator = orchestrator
|
||||
|
||||
def is_prepared(self) -> bool:
|
||||
return self._is_prepared
|
||||
|
||||
def is_required(self) -> bool:
|
||||
return self._is_required()
|
||||
|
||||
def prepare(self):
|
||||
if not self.is_prepared():
|
||||
self._prepare()
|
||||
self._is_prepared = True
|
||||
|
||||
def prepare_if_required(self):
|
||||
if self.is_required():
|
||||
self.prepare()
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def teardown(self):
|
||||
if self.is_prepared():
|
||||
self._teardown()
|
||||
self._is_prepared = False
|
||||
|
||||
def cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
if not self.is_prepared():
|
||||
return
|
||||
|
||||
self._cumulate_input_tokens(input_ids=input_ids)
|
||||
|
||||
def cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
if not self.is_prepared():
|
||||
return
|
||||
|
||||
self._cumulate_output_tokens(output_ids=output_ids)
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
if not self.is_prepared():
|
||||
return logits
|
||||
|
||||
return self._apply(logits=logits)
|
||||
|
||||
def filter(
|
||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
||||
):
|
||||
if not self.is_prepared():
|
||||
return
|
||||
|
||||
self._filter(
|
||||
indices_to_keep=indices_to_keep,
|
||||
indices_tensor_to_keep=indices_tensor_to_keep,
|
||||
)
|
||||
|
||||
def merge(self, their: "_BatchedPenalizer"):
|
||||
if not self.is_prepared() and not their.is_prepared():
|
||||
return
|
||||
|
||||
self.prepare()
|
||||
their.prepare()
|
||||
self._merge(their)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _is_required(self) -> bool:
|
||||
"""
|
||||
Check if the penalizer is required to be prepared.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _prepare(self):
|
||||
"""
|
||||
Prepare the penalizer.
|
||||
Usually, this is where the penalizer initializes its tensors.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _teardown(self):
|
||||
"""
|
||||
Tear down the penalizer.
|
||||
Usually, this is where the penalizer frees its tensors.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
"""
|
||||
Cumulate the input tokens.
|
||||
Orchestrator will call this function to feed the input tokens to the penalizer.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
"""
|
||||
Cumulate the output tokens.
|
||||
Orchestrator will call this function to feed the output tokens to the penalizer.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply the penalizer to the logits.
|
||||
Penalizers can modify the logits in-place if needed.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _filter(
|
||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
||||
):
|
||||
"""
|
||||
Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _merge(self, their: "_BatchedPenalizer"):
|
||||
"""
|
||||
Merge the penalizer with another penalizer.
|
||||
"""
|
||||
pass
|
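A minimal lifecycle sketch of the orchestrator with a single frequency penalizer. The import paths, the _Batch stand-in, and the SimpleNamespace request are assumptions for illustration; real requests and batches come from the serving layer:

import torch
from types import SimpleNamespace
# assumed module paths for this sketch; adjust to the actual package layout
from ktransformers.server.balance_serve.inference.sampling.orchestrator import BatchedPenalizerOrchestrator
from ktransformers.server.balance_serve.inference.sampling.penalizers.frequency_penalty import BatchedFrequencyPenalizer

class _Batch:                        # stand-in for the serving layer's batch object
    def __init__(self, reqs):
        self.reqs = reqs
    def batch_size(self):
        return len(self.reqs)

req = SimpleNamespace(               # hypothetical request; real requests come from the scheduler
    origin_input_ids=[3, 1, 4, 1],
    sampling_params=SimpleNamespace(frequency_penalty=0.5),
)
orch = BatchedPenalizerOrchestrator(
    vocab_size=8, batch=_Batch([req]), device="cpu",
    Penalizers={BatchedFrequencyPenalizer},
)
logits = orch.apply(torch.zeros(1, 8))   # unchanged until output tokens are cumulated
# per decode step: orch.cumulate_output_tokens(output_ids=[[next_token]]); orch.apply(logits)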
|
@ -0,0 +1,80 @@
|
|||
import typing
|
||||
|
||||
import torch
|
||||
|
||||
from ..orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
|
||||
|
||||
class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
||||
"""
|
||||
Frequency penalizer penalizes tokens based on their frequency in the output.
|
||||
"""
|
||||
|
||||
frequency_penalties: torch.Tensor = None
|
||||
cumulated_frequency_penalties: torch.Tensor = None
|
||||
|
||||
def _is_required(self) -> bool:
|
||||
return any(
|
||||
req.sampling_params.frequency_penalty != 0.0
|
||||
for req in self.orchestrator.reqs()
|
||||
)
|
||||
|
||||
def _prepare(self):
|
||||
self.cumulated_frequency_penalties = (
|
||||
torch.tensor(
|
||||
data=[0.0 for _ in self.orchestrator.reqs()],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
.unsqueeze_(1)
|
||||
.repeat(1, self.orchestrator.vocab_size)
|
||||
)
|
||||
|
||||
self.frequency_penalties = (
|
||||
torch.tensor(
|
||||
data=[
|
||||
req.sampling_params.frequency_penalty
|
||||
for req in self.orchestrator.reqs()
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
.unsqueeze_(1)
|
||||
.expand_as(self.cumulated_frequency_penalties)
|
||||
)
|
||||
|
||||
def _teardown(self):
|
||||
del self.frequency_penalties
|
||||
del self.cumulated_frequency_penalties
|
||||
|
||||
self.frequency_penalties = None
|
||||
self.cumulated_frequency_penalties = None
|
||||
|
||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
pass
|
||||
|
||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
self.cumulated_frequency_penalties += (
|
||||
self.frequency_penalties * output_ids.occurrence_count()
|
||||
)
|
||||
|
||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
logits -= self.cumulated_frequency_penalties
|
||||
return logits
|
||||
|
||||
def _filter(
|
||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
||||
):
|
||||
self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
|
||||
self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
|
||||
indices_tensor_to_keep
|
||||
]
|
||||
|
||||
def _merge(self, their: "BatchedFrequencyPenalizer"):
|
||||
self.frequency_penalties = torch.cat(
|
||||
[self.frequency_penalties, their.frequency_penalties], dim=0
|
||||
)
|
||||
self.cumulated_frequency_penalties = torch.cat(
|
||||
[self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
|
||||
dim=0,
|
||||
)
|
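The frequency penalizer accumulates frequency_penalty * occurrence_count per token and subtracts it from the logits, so a token emitted twice with penalty 0.5 loses 1.0 of logit mass. A short numeric check of that update, on a toy four-token vocabulary:

import torch

frequency_penalty = torch.tensor([[0.5]])   # one request
counts = torch.tensor([[0, 2, 1, 0]])       # token 1 emitted twice, token 2 once
logits = torch.zeros(1, 4)
logits -= frequency_penalty * counts        # same update as _cumulate_output_tokens + _apply
print(logits)                               # tensor([[ 0.0, -1.0, -0.5, 0.0]])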
|
@ -0,0 +1,108 @@
|
|||
import typing
|
||||
|
||||
import torch
|
||||
|
||||
from ..orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
|
||||
|
||||
class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
||||
"""
|
||||
Min new tokens penalizer penalizes tokens based on the length of the output.
|
||||
"""
|
||||
|
||||
min_new_tokens: torch.Tensor = None
|
||||
stop_token_penalties: torch.Tensor = None
|
||||
len_output_tokens: torch.Tensor = None
|
||||
|
||||
def _is_required(self) -> bool:
|
||||
return any(
|
||||
req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs()
|
||||
)
|
||||
|
||||
def _prepare(self):
|
||||
self.min_new_tokens = torch.tensor(
|
||||
data=[
|
||||
req.sampling_params.min_new_tokens for req in self.orchestrator.reqs()
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device=self.orchestrator.device,
|
||||
).unsqueeze_(1)
|
||||
|
||||
padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
sequences=[
|
||||
torch.tensor(
|
||||
data=(
|
||||
list(
|
||||
(req.sampling_params.stop_token_ids or set())
|
||||
| (req.tokenizer.additional_stop_token_ids or set())
|
||||
| {req.tokenizer.eos_token_id}
|
||||
)
|
||||
),
|
||||
dtype=torch.int64,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
for req in self.orchestrator.reqs()
|
||||
],
|
||||
batch_first=True,
|
||||
padding_value=self.orchestrator.vocab_size,
|
||||
)
|
||||
self.stop_token_penalties = torch.zeros(
|
||||
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
).scatter_add_(
|
||||
dim=1,
|
||||
index=padded_stop_token_ids,
|
||||
src=torch.full_like(
|
||||
input=padded_stop_token_ids,
|
||||
dtype=torch.float32,
|
||||
fill_value=float("-inf"),
|
||||
device=self.orchestrator.device,
|
||||
),
|
||||
)[
|
||||
:, : self.orchestrator.vocab_size
|
||||
]
|
||||
|
||||
self.len_output_tokens = torch.zeros(
|
||||
size=(self.orchestrator.batch_size(), 1),
|
||||
dtype=torch.int32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
|
||||
def _teardown(self):
|
||||
del self.min_new_tokens
|
||||
del self.stop_token_penalties
|
||||
del self.len_output_tokens
|
||||
|
||||
self.min_new_tokens = None
|
||||
self.stop_token_penalties = None
|
||||
self.len_output_tokens = None
|
||||
|
||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
pass
|
||||
|
||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
self.len_output_tokens += 1
|
||||
|
||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)
|
||||
logits[mask] += self.stop_token_penalties[mask]
|
||||
return logits
|
||||
|
||||
def _filter(
|
||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
||||
):
|
||||
self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
|
||||
self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
|
||||
self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
|
||||
|
||||
def _merge(self, their: "BatchedMinNewTokensPenalizer"):
|
||||
self.min_new_tokens = torch.cat(
|
||||
[self.min_new_tokens, their.min_new_tokens], dim=0
|
||||
)
|
||||
self.stop_token_penalties = torch.cat(
|
||||
[self.stop_token_penalties, their.stop_token_penalties], dim=0
|
||||
)
|
||||
self.len_output_tokens = torch.cat(
|
||||
[self.len_output_tokens, their.len_output_tokens], dim=0
|
||||
)
|
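The min-new-tokens penalizer precomputes a row holding -inf at every stop-token id and adds it to the logits only while fewer than min_new_tokens outputs exist, which makes EOS and other stop tokens unsampleable early in the generation. A compact sketch with a toy vocabulary where the EOS id is 3:

import torch

vocab_size, eos_id = 4, 3
stop_penalty = torch.zeros(1, vocab_size)
stop_penalty[0, eos_id] = float("-inf")     # -inf at every stop-token id

min_new_tokens = torch.tensor([[2]])
len_output_tokens = torch.tensor([[1]])     # only one token generated so far

logits = torch.zeros(1, vocab_size)
mask = (len_output_tokens < min_new_tokens).expand_as(logits)
logits[mask] += stop_penalty[mask]          # EOS stays masked until 2 tokens exist
print(logits)                               # tensor([[0., 0., 0., -inf]])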
|
@ -0,0 +1,79 @@
|
|||
import typing
|
||||
|
||||
import torch
|
||||
|
||||
from ..orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
|
||||
|
||||
class BatchedPresencePenalizer(_BatchedPenalizer):
|
||||
"""
|
||||
Presence penalizer penalizes tokens based on their presence in the output.
|
||||
"""
|
||||
|
||||
presence_penalties: torch.Tensor = None
|
||||
cumulated_presence_penalties: torch.Tensor = None
|
||||
|
||||
def _is_required(self) -> bool:
|
||||
return any(
|
||||
req.sampling_params.presence_penalty != 0.0
|
||||
for req in self.orchestrator.reqs()
|
||||
)
|
||||
|
||||
def _prepare(self):
|
||||
self.cumulated_presence_penalties = (
|
||||
torch.tensor(
|
||||
data=[0.0 for _ in self.orchestrator.reqs()],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
.unsqueeze_(1)
|
||||
.repeat(1, self.orchestrator.vocab_size)
|
||||
)
|
||||
|
||||
self.presence_penalties = (
|
||||
torch.tensor(
|
||||
data=[
|
||||
req.sampling_params.presence_penalty
|
||||
for req in self.orchestrator.reqs()
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
.unsqueeze_(1)
|
||||
.expand_as(self.cumulated_presence_penalties)
|
||||
)
|
||||
|
||||
def _teardown(self):
|
||||
del self.presence_penalties
|
||||
del self.cumulated_presence_penalties
|
||||
|
||||
self.presence_penalties = None
|
||||
self.cumulated_presence_penalties = None
|
||||
|
||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
pass
|
||||
|
||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
mask = output_ids.occurrence_count() > 0
|
||||
self.cumulated_presence_penalties[mask] = self.presence_penalties[mask]
|
||||
|
||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
logits -= self.cumulated_presence_penalties
|
||||
return logits
|
||||
|
||||
def _filter(
|
||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
||||
):
|
||||
self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
|
||||
self.cumulated_presence_penalties = self.cumulated_presence_penalties[
|
||||
indices_tensor_to_keep
|
||||
]
|
||||
|
||||
def _merge(self, their: "BatchedPresencePenalizer"):
|
||||
self.presence_penalties = torch.cat(
|
||||
[self.presence_penalties, their.presence_penalties], dim=0
|
||||
)
|
||||
self.cumulated_presence_penalties = torch.cat(
|
||||
[self.cumulated_presence_penalties, their.cumulated_presence_penalties],
|
||||
dim=0,
|
||||
)
|
|
@ -0,0 +1,83 @@
|
|||
import typing
|
||||
|
||||
import torch
|
||||
|
||||
from ..orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
|
||||
|
||||
class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
||||
"""
|
||||
Repetition penalizer penalizes tokens based on their repetition in the input and output.
|
||||
"""
|
||||
|
||||
repetition_penalties: torch.Tensor = None
|
||||
cumulated_repetition_penalties: torch.Tensor = None
|
||||
|
||||
def _is_required(self) -> bool:
|
||||
return any(
|
||||
req.sampling_params.repetition_penalty != 1.0
|
||||
for req in self.orchestrator.reqs()
|
||||
)
|
||||
|
||||
def _prepare(self):
|
||||
self.cumulated_repetition_penalties = (
|
||||
torch.tensor(
|
||||
data=[1.0 for _ in self.orchestrator.reqs()],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
.unsqueeze_(1)
|
||||
.repeat(1, self.orchestrator.vocab_size)
|
||||
)
|
||||
|
||||
self.repetition_penalties = (
|
||||
torch.tensor(
|
||||
data=[
|
||||
req.sampling_params.repetition_penalty
|
||||
for req in self.orchestrator.reqs()
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
.unsqueeze_(1)
|
||||
.expand_as(self.cumulated_repetition_penalties)
|
||||
)
|
||||
|
||||
def _teardown(self):
|
||||
del self.repetition_penalties
|
||||
del self.cumulated_repetition_penalties
|
||||
|
||||
self.repetition_penalties = None
|
||||
self.cumulated_repetition_penalties = None
|
||||
|
||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
mask = input_ids.occurrence_count() > 0
|
||||
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
|
||||
|
||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
mask = output_ids.occurrence_count() > 0
|
||||
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
|
||||
|
||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
return torch.where(
|
||||
logits > 0,
|
||||
logits / self.cumulated_repetition_penalties,
|
||||
logits * self.cumulated_repetition_penalties,
|
||||
)
|
||||
|
||||
def _filter(
|
||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
||||
):
|
||||
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
|
||||
self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
|
||||
indices_tensor_to_keep
|
||||
]
|
||||
|
||||
def _merge(self, their: "BatchedRepetitionPenalizer"):
|
||||
self.repetition_penalties = torch.cat(
|
||||
[self.repetition_penalties, their.repetition_penalties], dim=0
|
||||
)
|
||||
self.cumulated_repetition_penalties = torch.cat(
|
||||
[self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
|
||||
dim=0,
|
||||
)
|
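The repetition penalizer applies the CTRL-style rule: logits of previously seen tokens are divided by the penalty when positive and multiplied by it when negative, so both cases move the token toward lower probability. A quick check with penalty 1.5 on two seen tokens:

import torch

penalties = torch.tensor([[1.0, 1.5, 1.5, 1.0]])   # penalty only where the token was seen
logits = torch.tensor([[2.0, 2.0, -2.0, -1.0]])
penalized = torch.where(logits > 0, logits / penalties, logits * penalties)
print(penalized)                                   # tensor([[ 2.0, 1.3333, -3.0, -1.0]])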
ktransformers/server/balance_serve/inference/sampling/sampler.py (new file, 100 lines)
@ -0,0 +1,100 @@
|
|||
'''
|
||||
Date: 2024-11-14 12:23:45
|
||||
LastEditors: Xie Weiyu ervinxie@qq.com
|
||||
LastEditTime: 2024-11-25 08:59:23
|
||||
'''
|
||||
import logging
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import GenerationConfig
|
||||
|
||||
from flashinfer.sampling import (
|
||||
min_p_sampling_from_probs,
|
||||
top_k_renorm_probs,
|
||||
top_k_top_p_sampling_from_logits,
|
||||
top_p_renorm_probs,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SamplingOptions():
|
||||
# Batched sampling params
|
||||
temperatures: torch.Tensor
|
||||
top_ps: torch.Tensor
|
||||
top_ks: torch.Tensor
|
||||
min_ps: torch.Tensor
|
||||
|
||||
# All requests use greedy sampling
|
||||
is_all_greedy: bool
|
||||
|
||||
# Dispatch in CUDA graph
|
||||
need_min_p_sampling: bool
|
||||
|
||||
def __init__(self, bsz = 1, device = torch.device('cuda'), pretrained_config:GenerationConfig = None, temperatures: torch.Tensor = None, top_ps: torch.Tensor = None):
|
||||
if pretrained_config is None and temperatures is None:
|
||||
self.temperatures = torch.full((bsz, 1), 0, device=device, dtype=torch.float32)
|
||||
self.top_ps = torch.ones((bsz, 1), device=device, dtype=torch.float32)
|
||||
self.top_ks = torch.ones((bsz, 1), device=device, dtype=torch.float32)
|
||||
self.need_min_p_sampling = False
|
||||
self.is_all_greedy = True
|
||||
else:
|
||||
if temperatures is not None:
|
||||
self.temperatures = temperatures.unsqueeze(-1)
|
||||
else:
|
||||
self.temperatures = torch.full((bsz, 1), pretrained_config.temperature, device=device, dtype=torch.float32)
|
||||
|
||||
if top_ps is not None:
|
||||
self.top_ps = top_ps.unsqueeze(-1)
|
||||
else:
|
||||
self.top_ps = torch.full((bsz, 1), pretrained_config.top_p, device=device, dtype=torch.float32)
|
||||
self.top_ks = torch.full((bsz, 1), pretrained_config.top_k, device=device, dtype=torch.float32)
|
||||
self.need_min_p_sampling = False
|
||||
self.is_all_greedy = False
|
||||
|
||||
class Sampler(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_config: SamplingOptions = None,
|
||||
):
|
||||
if sampling_config is None:
|
||||
sampling_config = SamplingOptions()
|
||||
|
||||
logits = logits.contiguous()
|
||||
origin_logits = logits.clone()
|
||||
if sampling_config.is_all_greedy:
|
||||
# Use torch.argmax if all requests use greedy sampling
|
||||
probs = logits
|
||||
batch_next_token_ids = torch.argmax(logits, -1)
|
||||
else:
|
||||
# Post process logits
|
||||
logits.div_(sampling_config.temperatures)
|
||||
max_top_k_round, batch_size = 32, logits.shape[0]
|
||||
if sampling_config.need_min_p_sampling:
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
logits = None
|
||||
del logits
|
||||
probs = top_k_renorm_probs(probs, sampling_config.top_ks)
|
||||
probs = top_p_renorm_probs(probs, sampling_config.top_ps)
|
||||
batch_next_token_ids = min_p_sampling_from_probs(
|
||||
probs, sampling_config.min_ps
|
||||
)
|
||||
temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
|
||||
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
|
||||
else:
|
||||
# TODO: use a different kernel when top_k or top_p is not needed
|
||||
# @TODO get probs
|
||||
probs = logits
|
||||
batch_next_token_ids = top_k_top_p_sampling_from_logits(
|
||||
logits,
|
||||
sampling_config.top_ks,
|
||||
sampling_config.top_ps,
|
||||
filter_apply_order="joint",
|
||||
)
|
||||
temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
|
||||
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
|
||||
|
||||
return batch_next_token_ids.to(torch.int32), probs
|
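When SamplingOptions is built without a GenerationConfig or explicit temperatures it marks the whole batch as greedy, and Sampler then reduces to an argmax over the logits. A usage sketch under those defaults; the import path is assumed, a CUDA device is required, and flashinfer must be importable because sampler.py imports it at module load even though the greedy path never calls it:

import torch
# assumed module path for this sketch; adjust to the actual package layout
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions

logits = torch.randn(2, 32000, device="cuda")                  # (batch, vocab)
sampler = Sampler()
options = SamplingOptions(bsz=2, device=torch.device("cuda"))  # defaults: all-greedy
next_tokens, probs = sampler(logits, options)
print(next_tokens)                                             # argmax token id per request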
ktransformers/server/balance_serve/sched_rpc.py (new file, 213 lines)
@ -0,0 +1,213 @@
|
|||
from datetime import datetime
|
||||
import os
|
||||
from typing import Optional
|
||||
import zmq
|
||||
import pickle
|
||||
import threading
|
||||
import torch.multiprocessing as mp
|
||||
import sys
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
# sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
import pickle
|
||||
import argparse
|
||||
from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings
|
||||
|
||||
|
||||
|
||||
if mp.get_start_method(allow_none=True) is None:
|
||||
print('set start method')
|
||||
mp.set_start_method('spawn')
|
||||
else:
|
||||
print(f'start method already set to {mp.get_start_method(allow_none=True)}')
|
||||
|
||||
|
||||
class SchedulerServer:
|
||||
def __init__(self, settings, main_args):
|
||||
# Create and initialize the Scheduler instance
|
||||
self.sched = sched_ext.create_scheduler(settings)
|
||||
|
||||
# Initialize the ZeroMQ context and sockets
|
||||
self.context = zmq.Context()
|
||||
self.frontend = self.context.socket(zmq.ROUTER)
|
||||
print(f"sched zmq rpc server on port {main_args.sched_port}")
|
||||
self.frontend.bind(f"tcp://*:{main_args.sched_port}")
|
||||
|
||||
# Create the internal DEALER socket used to communicate with worker threads
|
||||
self.backend = self.context.socket(zmq.DEALER)
|
||||
self.backend.bind("inproc://backend")
|
||||
|
||||
# Start the scheduler
|
||||
def run_scheduler(self):
|
||||
self.sched.run()
|
||||
|
||||
# Stop the scheduler
|
||||
def stop_scheduler(self):
|
||||
self.sched.stop()
|
||||
|
||||
# Handle client requests
|
||||
def start_proxy(self):
|
||||
# Use ZMQ's built-in proxy to dispatch frontend requests to the backend workers
|
||||
zmq.proxy(self.frontend, self.backend)
|
||||
|
||||
# Worker thread that processes requests
|
||||
def worker_routine(self):
|
||||
worker = self.context.socket(zmq.REP)
|
||||
worker.connect("inproc://backend")
|
||||
while True:
|
||||
try:
|
||||
# Receive a client request
|
||||
message = worker.recv()
|
||||
data = pickle.loads(message)
|
||||
|
||||
method = data.get('method')
|
||||
params = data.get('params', {})
|
||||
# print(f"Received request: {method}")
|
||||
|
||||
if method == 'add_query':
|
||||
query_add = params.get('query') # passed directly as a QueryAdd object
|
||||
# Add the query
|
||||
query_id = self.sched.add_query(query_add)
|
||||
# Send the response
|
||||
response = {'status': 'ok', 'query_id': query_id}
|
||||
worker.send(pickle.dumps(response))
|
||||
|
||||
elif method == 'cancel_query':
|
||||
query_id = params.get('query_id')
|
||||
# Assumes the Scheduler class implements a cancel method
|
||||
self.sched.cancel(query_id)
|
||||
response = {'status': 'ok'}
|
||||
worker.send(pickle.dumps(response))
|
||||
|
||||
elif method == 'update_last_batch':
|
||||
updates = params.get('updates') # a plain list of QueryUpdate objects
|
||||
|
||||
# Update the last batch
|
||||
batch_todo = self.sched.update_last_batch(updates)
|
||||
|
||||
# Send the batch_todo object directly
|
||||
response = {'status': 'ok', 'batch_todo': batch_todo}
|
||||
# print (batch_todo.query_lengths, batch_todo.query_ids)
|
||||
worker.send(pickle.dumps(response))
|
||||
|
||||
elif method == 'get_inference_context':
|
||||
inference_context = self.sched.get_inference_context()
|
||||
data = {
|
||||
"k_cache":inference_context.k_cache,
|
||||
"v_cache":inference_context.v_cache
|
||||
}
|
||||
print(f"Serializing KVCache")
|
||||
data["k_cache"] = [mp.reductions.reduce_tensor(t) for t in data['k_cache']]
|
||||
data["v_cache"] = [mp.reductions.reduce_tensor(t) for t in data['v_cache']]
|
||||
# print(data)
|
||||
response = {'status': 'ok', 'inference_context': data}
|
||||
|
||||
worker.send(pickle.dumps(response))
|
||||
# response['inference_context'].k_cache[0][0, 0, 0, 0, 0] = 1
|
||||
# print("k_cache update")
|
||||
|
||||
else:
|
||||
# Unknown method
|
||||
response = {'status': 'error', 'message': 'Unknown method'}
|
||||
worker.send(pickle.dumps(response))
|
||||
|
||||
except Exception as e:
|
||||
# Handle the exception and send an error response
|
||||
response = {'status': 'error', 'message': str(e)}
|
||||
worker.send(pickle.dumps(response))
|
||||
|
||||
# Start the RPC service
|
||||
def start_rpc_service(self):
|
||||
try:
|
||||
print("Scheduler RPC service is running...")
|
||||
|
||||
# Run the scheduler in a separate thread
|
||||
threading.Thread(target=self.run_scheduler, daemon=True).start()
|
||||
|
||||
# Start the worker threads
|
||||
for _ in range(10): # adjust the number of worker threads as needed
|
||||
threading.Thread(target=self.worker_routine, daemon=True).start()
|
||||
|
||||
# Start the proxy and begin listening for requests
|
||||
self.start_proxy()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Shutting down scheduler RPC service...")
|
||||
self.stop_rpc_service()
|
||||
|
||||
# Stop the RPC service
|
||||
def stop_rpc_service(self):
|
||||
self.stop_scheduler()
|
||||
self.frontend.close()
|
||||
self.backend.close()
|
||||
self.context.term()
|
||||
|
||||
def start_server(settings, main_args):
|
||||
server = SchedulerServer(settings, main_args)
|
||||
server.start_rpc_service()
|
||||
|
||||
|
||||
# Add async client for webserver
|
||||
class SchedulerClient:
|
||||
def __init__(self, sched_port):
|
||||
address=f'tcp://localhost:{sched_port}'
|
||||
self.address = address
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.REQ)
|
||||
self.socket.connect(self.address)
|
||||
print(f"Connected to server at {self.address}")
|
||||
|
||||
def __del__(self):
|
||||
self.socket.close()
|
||||
self.context.term()
|
||||
|
||||
def send_request(self, method, params=None):
|
||||
if params is None:
|
||||
params = {}
|
||||
request = {
|
||||
'method': method,
|
||||
'params': params
|
||||
}
|
||||
# print(f'send request {request}')
|
||||
self.socket.send(pickle.dumps(request))
|
||||
response = self.socket.recv()
|
||||
# print(response)
|
||||
response = pickle.loads(response)
|
||||
if response.get('status') == 'ok':
|
||||
return response
|
||||
else:
|
||||
raise Exception(f"Error from server: {response.get('message')}")
|
||||
|
||||
def add_query(self, query):
|
||||
response = self.send_request('add_query', {'query': query})
|
||||
return response.get('query_id')
|
||||
|
||||
def cancel_query(self, query_id):
|
||||
self.send_request('cancel_query', {'query_id': query_id})
|
||||
|
||||
def update_last_batch(self, updates):
|
||||
response = self.send_request('update_last_batch', {'updates': updates})
|
||||
# print(f"update_last_batch response {response}")
|
||||
return response.get('batch_todo')
|
||||
|
||||
def rebuild_inferece_context(self,response):
|
||||
data = response.get('inference_context')
|
||||
inference_context = sched_ext.InferenceContext()
|
||||
print('Rebuilding kvcache')
|
||||
inference_context.k_cache = [fn(*args) for fn,args in data['k_cache']]
|
||||
inference_context.v_cache = [fn(*args) for fn,args in data['v_cache']]
|
||||
return inference_context
|
||||
|
||||
def get_inference_context_raw(self):
|
||||
response = self.send_request('get_inference_context')
|
||||
return response
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
with open(args.config, "rb") as f:
|
||||
main_args = pickle.load(f)
|
||||
settings = create_sched_settings(main_args)
|
||||
start_server(settings, main_args)
|
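SchedulerClient is a thin blocking ZeroMQ REQ wrapper over the RPC methods served above. A usage sketch, assuming a sched_rpc server is already listening on the given port; the port number is a placeholder, the empty update_last_batch call is only illustrative, and QueryAdd construction (done by the serving layer) is left out:

# assumed module path for this sketch; adjust to the actual package layout
from ktransformers.server.balance_serve.sched_rpc import SchedulerClient

client = SchedulerClient(sched_port=55555)     # placeholder port; must match the running server
# query = sched_ext.QueryAdd(...)              # built by the serving layer (fields not shown here)
# query_id = client.add_query(query)
batch_todo = client.update_last_batch([])      # an empty update just fetches the next BatchQueryTodo
raw = client.get_inference_context_raw()       # pickled, shared KV-cache tensors
ctx = client.rebuild_inferece_context(raw)     # method name as spelled in sched_rpc.py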
ktransformers/server/balance_serve/settings.py (new file, 73 lines)
@ -0,0 +1,73 @@
|
|||
'''
|
||||
Date: 2024-11-13 09:43:39
|
||||
LastEditors: djw
|
||||
LastEditTime: 2024-11-18 16:41:03
|
||||
'''
|
||||
import sys, os
|
||||
import yaml, json
|
||||
from time import sleep
|
||||
|
||||
|
||||
import sched_ext
|
||||
from transformers import AutoConfig
|
||||
|
||||
def create_sched_settings(args):
|
||||
default_sample_options = sched_ext.SampleOptions()
|
||||
model_name = os.path.basename(os.path.normpath(args.model_dir))
|
||||
input_model_settings = sched_ext.ModelSettings()
|
||||
input_model_settings.model_path = args.model_dir
|
||||
input_model_settings.params_count = int(0)
|
||||
model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
|
||||
input_model_settings.layer_count = model_config.num_hidden_layers
|
||||
input_model_settings.num_k_heads = 1 # model_config["num_key_value_heads"]
|
||||
input_model_settings.k_head_dim = 576
|
||||
input_model_settings.bytes_per_params = 2
|
||||
input_model_settings.bytes_per_kv_cache_element = 2
|
||||
settings = sched_ext.Settings()
|
||||
settings.model_name = model_name
|
||||
settings.quant_type = "BF16"
|
||||
settings.model_settings = input_model_settings
|
||||
settings.page_size = args.page_size
|
||||
settings.gpu_device_count = 1 # tp
|
||||
settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]
|
||||
# settings.gpu_memory_size = args.cache_lens*576*2
|
||||
settings.gpu_memory_size = args.gpu_memory_size
|
||||
settings.memory_utilization_percentage = args.utilization_percentage
|
||||
max_batch_size = args.max_batch_size
|
||||
chunk_size = args.chunk_size
|
||||
|
||||
max_decode_batch_size = max_batch_size - 2
|
||||
|
||||
settings.max_batch_size = max_batch_size
|
||||
settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2
|
||||
settings.sample_options = default_sample_options
|
||||
settings.sched_metrics_port = args.sched_metrics_port
|
||||
settings.gpu_only = args.memory_gpu_only
|
||||
settings.use_self_defined_head_dim = True
|
||||
settings.self_defined_head_dim = 576
|
||||
settings.full_kv_cache_on_each_gpu = True
|
||||
settings.k_cache_on = True
|
||||
settings.v_cache_on = False
|
||||
|
||||
settings.kvc2_root_path = '/mnt/data/persist-kvc'
|
||||
settings.kvc2_config_path = args.kvc2_config_dir
|
||||
settings.memory_pool_size_GB = args.cpu_memory_size_GB
|
||||
settings.evict_count = 40
|
||||
settings.kvc2_metrics_port = args.kvc2_metrics_port
|
||||
settings.load_from_disk = False
|
||||
settings.save_to_disk = True
|
||||
|
||||
|
||||
settings.strategy_name = args.sched_strategy
|
||||
|
||||
settings.auto_derive()
|
||||
return settings
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
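create_sched_settings reserves two of the max_batch_size slots for prefill and splits the remaining chunk budget between them, which is where recommended_chunk_prefill_token_count comes from. Worked numbers for that derivation, using placeholder values rather than project defaults:

# placeholder values for illustration only
max_batch_size = 20
chunk_size = 256

max_prefill_batch_size = 2                                                # fixed: only 2 prefill tasks supported
max_decode_batch_size = max_batch_size - max_prefill_batch_size          # 18
recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2  # 119

print(max_decode_batch_size, recommended_chunk_prefill_token_count)      # 18 119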
|
@ -11,6 +11,7 @@ LastEditTime : 2024-08-12 06:31:14
|
|||
import os
|
||||
import shutil
|
||||
import yaml
|
||||
import psutil
|
||||
|
||||
from ktransformers.server.config.singleton import Singleton
|
||||
from typing import Optional
|
||||
|
@ -33,12 +34,15 @@ class Config(metaclass=Singleton):
|
|||
|
||||
user_path: str = os.path.expanduser("~")
|
||||
localstore_path: str = os.path.join(user_path, ".ktransformers")
|
||||
kvc2_config_dir = os.path.join(localstore_path, "kvc2")
|
||||
config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME)
|
||||
if not os.path.exists(config_yaml):
|
||||
print(f"Can't find config file, {config_yaml}")
|
||||
exit(-1)
|
||||
if not os.path.exists(localstore_path):
|
||||
os.mkdir(localstore_path)
|
||||
if not os.path.exists(kvc2_config_dir):
|
||||
os.mkdir(kvc2_config_dir)
|
||||
if not os.path.exists(config_path):
|
||||
shutil.copyfile(config_yaml, config_path)
|
||||
with open(config_path, "r", encoding="utf-8") as fp:
|
||||
|
@ -60,11 +64,14 @@ class Config(metaclass=Singleton):
|
|||
self.user_path: str = os.path.expanduser("~")
|
||||
self.localstore_path: str = os.path.join(self.user_path, ".ktransformers")
|
||||
# log configs
|
||||
self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"]))
|
||||
self.log_dir = os.path.join(self.localstore_path, cfg["log"]["dir"])
|
||||
if not os.path.exists(self.log_dir):
|
||||
os.mkdir(self.log_dir)
|
||||
self.log_file = cfg["log"]["file"]
|
||||
self.log_level = cfg["log"]["level"]
|
||||
self.backup_count = cfg["log"]["backup_count"]
|
||||
|
||||
self.kvc2_config_dir = os.path.join(self.localstore_path, "kvc2")
|
||||
# server configs
|
||||
self.server: dict = cfg.get("server", {})
|
||||
self.server_ip = self.server.get("ip", "0.0.0.0")
|
||||
|
@ -74,7 +81,7 @@ class Config(metaclass=Singleton):
|
|||
# db configs
|
||||
self.db_configs: dict = cfg.get("db", {})
|
||||
self.db_type = self.db_configs.get("type", "")
|
||||
self.db_host = os.path.join(self.base_path, self.db_configs.get("host", ""))
|
||||
self.db_host = self.localstore_path
|
||||
self.db_port = self.db_configs.get("port", "")
|
||||
self.db_name = self.db_configs.get("database", "")
|
||||
self.db_pool_size = self.db_configs.get("pool_size")
|
||||
|
@ -101,11 +108,6 @@ class Config(metaclass=Singleton):
|
|||
self.optimize_config_path: Optional[str] = self.model.get(
|
||||
"optimize_config_path", None
|
||||
)
|
||||
self.paged = self.model.get("paged", True)
|
||||
|
||||
self.total_context = self.model.get("total_context", 2**18)
|
||||
self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
|
||||
self.chunk_prefill_size = self.model.get("chunk_prefill_size", 8192)
|
||||
|
||||
self.max_new_tokens = self.model.get("max_new_tokens", 2000)
|
||||
self.json_mode = self.model.get("json_mode", False)
|
||||
|
@@ -138,7 +140,6 @@ class Config(metaclass=Singleton):
self.repetition_penalty = self.model.get("repetition_penalty", 1.01)
self.frequency_penalty = self.model.get("frequency_penalty", 0.0)
self.presence_penalty = self.model.get("presence_penalty", 0.0)
self.max_response_tokens = self.model.get("max_response_tokens", 300)
self.response_chunk = self.model.get("response_chunk", 250)
self.no_code_formatting = self.model.get("no_code_formatting", False)
self.cache_8bit = self.model.get("cache_8bit", False)
@@ -155,8 +156,9 @@ class Config(metaclass=Singleton):
self.web_cross_domain: bool = self.web.get("open_cross_domain", True)
self.mount_web: bool = self.web.get("mount", False)

# ext
self.ext: dict = cfg.get("ext", {})
self.cpu_infer = self.ext.get("cpu_infer", 10)
self.cpu_infer = psutil.cpu_count(logical=False) - 3

# file config
self.local_store_configs: dict = cfg.get("local_store", {})
@@ -169,7 +171,6 @@ class Config(metaclass=Singleton):

# long context config
self.long_context_config: dict = cfg.get("long_context", {})
self.chunk_size = self.long_context_config.get("chunk_size", 4096)
self.max_seq_len = self.long_context_config.get("max_seq_len", 32000)
self.block_size = self.long_context_config.get("block_size", 128)
self.local_windows_len = self.long_context_config.get("local_windows_len", 4096)
@@ -187,3 +188,21 @@ class Config(metaclass=Singleton):
# local chat
self.local_chat_config: dict = cfg.get("local_chat", {})
self.prompt_file = self.local_chat_config.get("prompt_file", None)

# asyncserver
self.sched_strategy = cfg["async_server"]["sched_strategy"]
self.sched_port = cfg["async_server"]["sched_port"]
self.sched_metrics_port = cfg["async_server"]["sched_metrics_port"]
self.kvc2_metrics_port = cfg["async_server"]["kvc2_metrics_port"]
self.max_batch_size = cfg["async_server"]["max_batch_size"]
self.page_size = cfg["attn"]["page_size"]
self.chunk_size = cfg["attn"]["chunk_size"]
self.memory_gpu_only = cfg["kvc2"]["gpu_only"]
self.cache_lens = ((self.cache_lens + self.page_size - 1) // self.page_size) * self.page_size
self.gpu_memory_size = 2*576*61*self.cache_lens
self.utilization_percentage = 1.0 #cfg["kvc2"]["utilization_percentage"]
self.cpu_memory_size_GB = cfg["kvc2"]["cpu_memory_size_GB"]
# only support 2 prefill task
self.max_prefill_batch_size = 2
self.max_decode_batch_size = self.max_batch_size - self.max_prefill_batch_size
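The cache_lens line above is the standard round-up-to-a-multiple trick, so the KV cache length always ends on a page boundary. A small worked sketch (the page size and cache length are illustrative values; the 2*576*61 factor is copied verbatim from the hunk and presumably reflects the model's per-token KV footprint):

    def round_up(n: int, multiple: int) -> int:
        # identical arithmetic to the cache_lens line above
        return ((n + multiple - 1) // multiple) * multiple

    page_size = 256                              # illustrative
    cache_lens = round_up(1000, page_size)       # -> 1024
    gpu_memory_size = 2 * 576 * 61 * cache_lens  # factor taken from the diff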
@@ -5,24 +5,20 @@ from fastapi.staticfiles import StaticFiles
import uvicorn.logging
import uvicorn
import sys

import atexit
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
sys.path.insert(0, project_dir)
from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.args import ArgumentParser
from ktransformers.server.config.config import Config
from ktransformers.server.utils.create_interface import create_interface
from ktransformers.server.backend.args import default_args
from ktransformers.server.utils.create_interface import create_interface, GlobalInterface
from fastapi.openapi.utils import get_openapi

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware


from ktransformers.server.api import router, post_db_creation_operations
from ktransformers.server.utils.sql_utils import Base, SQLUtil
from ktransformers.server.config.log import logger

import subprocess
import tempfile

def mount_app_routes(mount_app: FastAPI):
sql_util = SQLUtil()
@@ -34,7 +30,10 @@ def mount_app_routes(mount_app: FastAPI):

def create_app():
cfg = Config()
app = FastAPI()
if(hasattr(GlobalInterface.interface, "lifespan")):
app = FastAPI(lifespan=GlobalInterface.interface.lifespan)
else:
app = FastAPI()
if Config().web_cross_domain:
app.add_middleware(
CORSMiddleware,
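The hasattr(..., "lifespan") check above lets a backend opt into FastAPI's lifespan protocol. For context, a minimal generic lifespan handler looks roughly like this (a sketch, not the project's actual interface):

    from contextlib import asynccontextmanager
    from fastapi import FastAPI

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # startup work (e.g. allocate the backend) runs before yield
        app.state.ready = True
        yield
        # shutdown work (e.g. stop workers, flush caches) runs after yield
        app.state.ready = False

    app = FastAPI(lifespan=lifespan)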
@@ -108,11 +107,32 @@ def main():

arg_parser = ArgumentParser(cfg)

# initialization messages
args = arg_parser.parse_args()
if args.backend_type == "balance_serve":
import pickle
def cleanup():
if sched_process.poll() is None:
sched_process.terminate()

with tempfile.NamedTemporaryFile(delete=False) as temp_file:
pickle.dump(args, temp_file)
temp_file_path = temp_file.name
current_file = __file__
target_file = os.path.join(os.path.dirname(current_file), "balance_serve", "sched_rpc.py")
target_file = os.path.normpath(target_file)
log_path = os.path.join(args.log_dir, "rpc.log")
log = open(log_path, "a")
sched_process = subprocess.Popen(
["python3", target_file, "--config", temp_file_path],
stdout=log,
stderr=log
)
print("sched_rpc started with PID:", sched_process.pid)
atexit.register(cleanup)
create_interface(config=cfg, default_args=cfg)
app = create_app()
custom_openapi(app)
create_interface(config=cfg, default_args=cfg)

run_api(
app=app,
host=args.host,
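The scheduler launch above is the usual Popen-plus-atexit sidecar pattern: start the child with its output redirected to a log file, and register a cleanup hook that terminates it if it is still alive at exit. A stripped-down, self-contained version (command and log path are placeholders):

    import atexit
    import subprocess

    def launch_sidecar(cmd, log_path):
        log = open(log_path, "a")
        proc = subprocess.Popen(cmd, stdout=log, stderr=log)

        def cleanup():
            # terminate only if the child is still running
            if proc.poll() is None:
                proc.terminate()
                proc.wait(timeout=10)
            log.close()

        atexit.register(cleanup)
        return proc

    # usage (placeholder command):
    # launch_sidecar(["python3", "sched_rpc.py", "--config", "/tmp/args.pkl"], "/tmp/rpc.log")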
@@ -121,6 +141,5 @@ def main():
ssl_certfile=args.ssl_certfile,
)


if __name__ == "__main__":
main()
@@ -1,4 +1,4 @@
torch >= 2.3.0,<=2.3.1
torch >= 2.3.0
transformers == 4.43.2
fastapi >= 0.111.0
langchain >= 0.2.0

@@ -11,4 +11,6 @@ build
ninja
wheel
colorlog
fire
fire
zmq
psutil
@@ -2,7 +2,7 @@ from typing import List, Optional
from typing_extensions import Literal
from enum import Enum

from pydantic import BaseModel
from pydantic import BaseModel, Field

from ktransformers.server.schemas.base import Object

@@ -30,8 +30,8 @@ class ChatCompletionCreate(BaseModel):
messages: List[Message]
model : str
stream : bool = False
temperature: Optional[float] = None
top_p: Optional[float] = None
temperature: Optional[float] = Field(default=1.0)
top_p: Optional[float] = Field(default=1.0)

def get_tokenizer_messages(self):
return [m.to_tokenizer_message() for m in self.messages]
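The schema change above swaps Optional[...] = None for explicit Field(default=...) values, so requests that omit temperature or top_p get sampling defaults instead of None. A trimmed-down sketch of the effect (the messages field is dropped here to keep the example self-contained):

    from typing import Optional
    from pydantic import BaseModel, Field

    class ChatCompletionCreate(BaseModel):
        model: str
        stream: bool = False
        temperature: Optional[float] = Field(default=1.0)
        top_p: Optional[float] = Field(default=1.0)

    req = ChatCompletionCreate(model="demo")
    print(req.temperature, req.top_p)  # 1.0 1.0 when the client omits them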
@@ -15,6 +15,7 @@ from ktransformers.server.backend.context_manager import ThreadContextManager
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
from ktransformers.server.backend.interfaces.transformers import TransformersInterface
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface

def create_interface(config: Config, default_args: ConfigArgs):
if config.backend_type=='transformers':
from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface

@@ -22,6 +23,8 @@ def create_interface(config: Config, default_args: ConfigArgs):
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
elif config.backend_type == 'ktransformers':
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
elif config.backend_type == 'balance_serve':
from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface as BackendInterface
else:
raise NotImplementedError(f'{config.backend_type} not implemented')
GlobalInterface.interface = BackendInterface(default_args)

@@ -30,9 +33,9 @@ def create_interface(config: Config, default_args: ConfigArgs):
class GlobalContextManager:
context_manager: ThreadContextManager
class GlobalInterface:
interface: TransformersInterface | KTransformersInterface | ExllamaInterface
interface: TransformersInterface | KTransformersInterface | ExllamaInterface

def get_thread_context_manager() -> ThreadContextManager:
def get_thread_context_manager() -> GlobalContextManager:
return GlobalContextManager.context_manager
def get_interface() -> TransformersInterface | KTransformersInterface | ExllamaInterface:
def get_interface() -> GlobalInterface:
return GlobalInterface.interface
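The branching above defers each backend import until that backend is actually selected, which keeps startup cheap when only one backend is needed. A generic sketch of the same dispatch pattern (module and class names here are illustrative, not the project's layout):

    import importlib

    _BACKENDS = {
        "transformers": ("my_backends.transformers", "TransformersInterface"),
        "balance_serve": ("my_backends.balance_serve", "BalanceServeInterface"),
    }

    def load_backend(name, args):
        try:
            module_name, class_name = _BACKENDS[name]
        except KeyError:
            raise NotImplementedError(f"{name} not implemented")
        # import lazily so unused backends never load their heavy dependencies
        cls = getattr(importlib.import_module(module_name), class_name)
        return cls(args)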