Refactor: restructure repository to focus on kt-kernel and KT-SFT modules (#1581)

* refactor: move legacy code to archive/ directory

  - Moved ktransformers, csrc, third_party, merge_tensors to archive/
  - Moved build scripts and configurations to archive/
  - Kept kt-kernel, KT-SFT, doc, and README files in root
  - Preserved complete git history for all moved files

* refactor: restructure repository to focus on kt-kernel and KT-SFT modules

* fix README

* fix README

* fix README

* fix README

* docs: add performance benchmarks to kt-kernel section

Add comprehensive performance data for kt-kernel to match KT-SFT's presentation:
- AMX kernel optimization: 21.3 TFLOPS (3.9× faster than PyTorch)
- Prefill phase: up to 20× speedup vs baseline
- Decode phase: up to 4× speedup
- NUMA optimization: up to 63% throughput improvement
- Multi-GPU (8×L20): 227.85 tokens/s total throughput with DeepSeek-R1 FP8

Source: https://lmsys.org/blog/2025-10-22-KTransformers/

This provides users with concrete performance metrics for both core modules,
making it easier to understand the capabilities of each component.

* refactor: improve kt-kernel performance data with specific hardware and models

Replace generic performance descriptions with concrete benchmarks:
- Specify exact hardware: 8×L20 GPU + Xeon Gold 6454S, Single/Dual-socket Xeon + AMX
- Include specific models: DeepSeek-R1-0528 (FP8), DeepSeek-V3 (671B)
- Show detailed metrics: total throughput, output throughput, concurrency details
- Match KT-SFT presentation style for consistency

This provides users with actionable performance data they can use to evaluate
hardware requirements and expected performance for their use cases.

* fix README

* docs: clean up performance table and improve formatting

* add pic for README

* refactor: simplify .gitmodules and back up legacy submodules

- Remove 7 legacy submodules from root .gitmodules (archive/third_party/*)
- Keep only 2 active submodules for kt-kernel (llama.cpp, pybind11)
- Back up the complete .gitmodules to archive/.gitmodules
- Add documentation in archive/README.md for researchers who need legacy submodules

This reduces initial clone size by ~500MB and avoids downloading unused dependencies.
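
For illustration, the trimmed root .gitmodules would then contain only the two kt-kernel entries, roughly like the sketch below (the submodule paths and upstream URLs here are assumptions for illustration, not copied from this commit):

    [submodule "kt-kernel/third_party/llama.cpp"]
        path = kt-kernel/third_party/llama.cpp
        url = https://github.com/ggml-org/llama.cpp.git
    [submodule "kt-kernel/third_party/pybind11"]
        path = kt-kernel/third_party/pybind11
        url = https://github.com/pybind/pybind11.git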

* refactor: move doc/ back to root directory

Keep documentation in root for easier access and maintenance.

* refactor: consolidate all images to doc/assets/

- Move kt-kernel/assets/heterogeneous_computing.png to doc/assets/
- Remove KT-SFT/assets/ (images already in doc/assets/)
- Update KT-SFT/README.md image references to ../doc/assets/
- Eliminates ~7.9MB image duplication
- Centralizes all documentation assets in one location

* fix pic path for README
Jiaqi Liao 2025-11-10 17:42:26 +08:00 committed by GitHub
parent 8729435d85
commit 57d14d22bc
510 changed files with 711 additions and 334 deletions


@@ -0,0 +1,665 @@
from typing import Any, AsyncIterator, List, Optional, Set
from ktransformers.models.custom_cache import KVC2StaticCache, KDeepSeekV3Cache, KGQACache
from transformers import (
    AutoTokenizer,
    AutoConfig,
    GenerationConfig,
    StaticCache,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
import torch.distributed as dist
from ktransformers.server.config.config import Config
from ..base import ThreadContext, BackendInterfaceBase
import torch
from ktransformers.server.backend.interfaces.transformers import (
    ConfigArgs,
    default_args,
    TextStreamer,
)
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.config.log import logger
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
from ktransformers.models.custom_modeling_smallthinker import KSmallThinkerForCausalLM
from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM
from ktransformers.models.custom_modeling_qwen3_next import KQwen3NextForCausalLM
from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
from ktransformers.models.configuration_smallthinker import SmallthinkerConfig
from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
try:
    import torch_npu
    use_torch_npu = torch.npu.is_available()
except:
    use_torch_npu = False

if use_torch_npu:
    from ktransformers.models.ascend.custom_ascend_modeling_deepseek_v3 import KNPUDeepseekV3ForCausalLM
    from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel, get_tensor_parallel_group, get_tensor_parallel_size
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util import utils
custom_models = {
    "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
    "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
    "LlamaForCausalLM": LlamaForCausalLM,
    "MixtralForCausalLM": MixtralForCausalLM,
}  # TODO: exclusive?
from ktransformers.server.balance_serve.inference.model_runner import ModelRunner, get_or_create_model_runner  # TODO: get_or_create_model_runner is NPU-specific
from ktransformers.models.configuration_qwen3_next import Qwen3NextConfig
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
from ktransformers.server.balance_serve.sched_rpc import SchedulerClient
from ktransformers.server.balance_serve.settings import sched_ext
from torch.multiprocessing import Queue
import torch.multiprocessing as mp
from multiprocessing.synchronize import Event
import datetime
from ktransformers.server.schemas.endpoints.chat import RawUsage
from ktransformers.server.utils.multi_timer import Profiler
import zmq
import time
import queue
import tempfile
import asyncio
import cProfile
import threading
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
import os
import pickle
import subprocess
import tempfile
import atexit
import signal
ktransformer_rules_dir = (
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
)
default_optimize_rules = {
    # "DeepseekV3ForCausalLM": ktransformer_rules_dir + "Moonlight-16B-A3B-serve.yaml",
    "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-serve.yaml",
    "Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml",
    "SmallThinkerForCausalLM": ktransformer_rules_dir + "Smallthinker-serve.yaml",
    "Glm4MoeForCausalLM": ktransformer_rules_dir + "Glm4Moe-serve.yaml",
    "Qwen3NextForCausalLM": ktransformer_rules_dir + "Qwen3Next-serve.yaml",
}
if use_torch_npu:
    default_optimize_rules["Qwen2MoeForCausalLM"] = ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml"
async def chat_stream(queue: asyncio.Queue, tokenizer: AutoTokenizer):
    streamer = TextStreamer(tokenizer)
    while True:
        token = await queue.get()
        # print(f"Got token: {token}")
        if token is None:
            # str = f'{token}\n\n'
            # str = model.tokenizer.decode(token)
            s = streamer.end()
            if s is not None:
                yield s
            break
        else:
            # text output
            text = tokenizer.decode(token)
            print(text, end="", flush=True)
            # str = model.tokenizer.decode(token)
            yield streamer.put(token)

def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_tokens: torch.Tensor, query_manager: QueryManager = None):
    # print(len(query_updates), generated_tokens.size(0), generated_tokens)
    for i in range(generated_tokens.size(0)):
        # print(generated_tokens[i].item())
        query_updates[i].generated_token = generated_tokens[i].item()
        if not query_manager.query_map[query_updates[i].id].is_prefill:
            pos = query_updates[i].active_position
            if pos < query_manager.query_map[query_updates[i].id].max_length:
                query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]

def report_last_time_performance(profiler: Profiler):
    try:
        tokenize_time = profiler.get_timer_sec('tokenize')
        prefill_time = profiler.get_timer_sec('prefill')
        decode_time = profiler.get_timer_sec('decode')
        prefill_count = profiler.get_counter('prefill')
        decode_count = profiler.get_counter('decode')
        logger.info(f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}')
    except:
        logger.info(f'Performance statistics not recorded')

class Engine:
    sched_client: SchedulerClient
    updates: list[sched_ext.QueryUpdate]
    batch: sched_ext.BatchQueryTodo
    model_runner: ModelRunner
    sampler: Sampler
    query_manager: QueryManager
    cache: KDeepSeekV3Cache | KGQACache | KVC2StaticCache

    def __init__(self, args: ConfigArgs = default_args, generated_token_queue: Queue = None, broadcast_endpoint: str = None, kvcache_event: Event = None):
        self.args = args
        # The child process and the parent process cannot share the Config variables
        for key, value in vars(args).items():
            if value is not None and hasattr(Config(), key):
                setattr(Config(), key, value)
        if use_torch_npu:
            utils.CUR_DEVICE = f"npu:{torch.npu.current_device()}"
            self.device = f"npu:{torch.npu.current_device()}"
        else:
            self.device = self.args.device
        self.sched_client = SchedulerClient(args.sched_port)
        self.updates = []
        print(f"args.architectures: {args.architectures}")

        if args.architectures == "Qwen3MoeForCausalLM":
            config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
        elif args.architectures == "Glm4MoeForCausalLM":
            config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
        elif args.architectures == "SmallThinkerForCausalLM":
            config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
            config._attn_implementation = "eager"
            config.moe_intermediate_size = config.moe_ffn_hidden_size
        elif args.architectures == "Qwen3NextForCausalLM":
            config = Qwen3NextConfig.from_pretrained(args.model_dir, trust_remote_code=True)
        else:
            try:
                config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
            except:
                raise ValueError(f"Model {args.architectures} not supported. Please check your model directory or model name.")

        self.gen_queue = generated_token_queue
        self.debug = False
        self.profiler_cprofile = cProfile.Profile()
        self.cprof_prof_cnt, self.max_cprof_prof_cnt = 0, 8

        with torch.device("meta"):
            if config.architectures[0] == "DeepseekV3ForCausalLM":
                if use_torch_npu:
                    self.cache = KVC2StaticCache(config, args.max_batch_size, self.args.page_size)
                    self.model = KNPUDeepseekV3ForCausalLM(config)
                else:
                    self.cache = KDeepSeekV3Cache(config, self.args.page_size)
                    self.model = KDeepseekV3ForCausalLM(config, self.cache)
            elif config.architectures[0] == "DeepseekV2ForCausalLM":
                self.cache = KDeepSeekV3Cache(config, self.args.page_size)
                self.model = KDeepseekV2ForCausalLM(config, self.cache)
            elif config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
                self.cache = KGQACache(config, self.args.page_size)
                if config.architectures[0] == "Qwen2MoeForCausalLM":
                    self.model = KQwen2MoeForCausalLM(config, self.cache)
                else:
                    self.model = KQwen3MoeForCausalLM(config, self.cache)
            elif config.architectures[0] == "SmallThinkerForCausalLM":
                self.cache = KGQACache(config, self.args.page_size)
                self.model = KSmallThinkerForCausalLM(config, self.cache)
            elif config.architectures[0] == "Glm4MoeForCausalLM":
                self.cache = KGQACache(config, self.args.page_size)
                self.model = KGlm4MoeForCausalLM(config, self.cache)
            elif config.architectures[0] == "Qwen3NextForCausalLM":
                self.cache = KGQACache(config, self.args.page_size)
                self.model = KQwen3NextForCausalLM(config, self.cache)

        context = zmq.Context()
        if use_torch_npu:
            if torch.distributed.get_rank() == 0:
                self.pub_socket = context.socket(zmq.PUB)
                self.pub_socket.bind(f"ipc://{broadcast_endpoint}")
                self.sub_socket = None
            else:
                self.sub_socket = context.socket(zmq.SUB)
                self.sub_socket.connect(f"ipc://{broadcast_endpoint}")
                self.sub_socket.setsockopt_string(zmq.SUBSCRIBE, "")
                self.pub_socket = None
                # time.sleep(1) # make sure all subscribers are ready
        else:
            self.pub_socket = context.socket(zmq.PUB)
            self.pub_socket.bind(f"ipc://{broadcast_endpoint}")

        try:
            generation_config = GenerationConfig.from_pretrained(args.model_dir)
        except:
            generation_config = GenerationConfig(
                max_length=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                do_sample=True
            )

        if args.optimize_config_path is None:
            optimize_config_path = default_optimize_rules[config.architectures[0]]
        else:
            optimize_config_path = args.optimize_config_path

        gguf_path = args.gguf_path
        if gguf_path is None:
            gguf_path = input(
                "please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
                " belong to current model):"
            )
        if use_torch_npu:
            tp_group = get_tensor_parallel_group()
            torch.distributed.barrier(group=tp_group)
        optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
        if use_torch_npu:
            get_absort_weight(self.model, config)  # TODO
            torch.distributed.barrier(group=tp_group)

        self.model.generation_config = generation_config
        if self.model.generation_config.pad_token_id is None:
            self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
        self.model.eval()
        kvcache_event.set()

        # load kvcache
        print(f"Getting inference context from sched_client.")
        inference_context = self.sched_client.get_inference_context_raw()
        print(f"Got inference context, sending it to subscribers.")
        inference_context = self.sched_client.rebuild_inferece_context(inference_context)
        self.cache.load(inference_context)
        print(f"kv_cache loaded successfully.")
        self.block_num = inference_context.k_cache[0].size(1)

        # TODO: ModelRunner differences
        # self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
        # @TODO add config
        if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallThinkerForCausalLM" or config.architectures[0] == "Qwen3NextForCausalLM":
            self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
        else:
            self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
        # self.args.use_cuda_graph indicates whether graph sinking (graph-mode execution) is used
        self.model_runner = get_or_create_model_runner(self.model, self.cache, self.device, self.args.use_cuda_graph, page_size = args.page_size)

        self.sampler = Sampler()
        self.query_manager = QueryManager(device = self.device, page_size = args.page_size)

    def sampling(self, forward_output: ForwardBatchOutput):
        generated_tokens = []
        probs = []
        for i in range(forward_output.num_batchs):
            logit = forward_output.logits[i]
            if hasattr(forward_output, "temperatures"):
                temperatures = forward_output.temperatures[i]
            else:
                temperatures = None
            if hasattr(forward_output, "top_ps"):
                top_ps = forward_output.top_ps[i]
            else:
                top_ps = None
            sample_options = SamplingOptions(logit.size(0), self.device, pretrained_config=self.model.generation_config, temperatures=temperatures, top_ps=top_ps)
            generated_token, prob = self.sampler(logit, sample_options)
            generated_tokens.append(generated_token.clone())
            probs.append(prob.clone())
        generated_tokens, probs = torch.cat(generated_tokens), torch.cat(probs, dim=0)
        return generated_tokens, probs

    def loop(self):
        next_batch = None
        while True:
            self.batch = next_batch
            if self.batch is not None:
                if use_torch_npu:
                    batch_size = 0
                    for i in range(len(self.batch.decode_mini_batches)):
                        batch_size += len(self.batch.decode_mini_batches[i])
                    logger.debug(f"prefill batch: {len(self.batch.prefill_mini_batches)} decode batch: {len(self.batch.decode_mini_batches)} {batch_size} \n")
                    self.model_runner.run_split(self.batch, self.query_manager)
                else:
                    self.model_runner.run(self.batch, self.query_manager)

            if len(self.updates) > 0:
                for q in self.updates:
                    if q.is_prefill == True:
                        continue
                    # print(f"Putting token {q.generated_token} into queue for query id: {q.id}")
                    try:
                        if use_torch_npu:
                            if torch.distributed.get_rank() == 0:
                                self.gen_queue.put((q.id, q.generated_token if q.decode_done == False else None), timeout=5)
                        else:
                            self.gen_queue.put((q.id, q.generated_token if q.decode_done == False else None), timeout=5)
                    except queue.Full:
                        pass  # print("Queue is full after timeout; unable to put more items.")

            if use_torch_npu:
                if torch.distributed.get_rank() == 0:
                    next_batch = self.sched_client.update_last_batch(self.updates)
                    if next_batch.query_ids == []:
                        next_batch = None
                    self.pub_socket.send_pyobj(next_batch)
                else:
                    next_batch = self.sub_socket.recv_pyobj()
            else:
                next_batch = self.sched_client.update_last_batch(self.updates)
                if next_batch.query_ids == []:
                    next_batch = None
                self.pub_socket.send_pyobj(next_batch)
            if next_batch is not None:
                self.query_manager.add_query(next_batch)

            if self.batch is not None:
                self.model_runner.sync()
                # print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms")
                # if self.rank == 0:
                generated_tokens, probs = self.sampling(self.model_runner.output)
                self.updates = self.query_manager.update(self.batch)
                fill_generated_tokens(self.updates, generated_tokens, self.query_manager)
            else:
                self.updates = []

class BalanceServeThreadContext(ThreadContext):
    def get_local_messages(self):
        local_messages = []
        for m in self.messages:
            local_messages.append({"role": m.role.value, "content": m.get_text_content()})

        return local_messages

def init_distributed(rank: int,
                     world_size: int,
                     tp_size: int,
                     master_addr: str = os.getenv("MASTER_ADDR", "127.0.0.1"),
                     master_port: int = os.getenv("MASTER_PORT", "29500"),
                     backend: str = "hccl"):  # TODO csx: is distributed setup only relevant for NPU?
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)
    local_rank, world_size = setup_model_parallel(tp=tp_size)
    return local_rank, world_size

def run_engine(args, token_queue, broadcast_endpoint, event, kvcache_event, rank=None, world_size=None):
    if use_torch_npu:
        init_distributed(rank, world_size, args.tp, backend="hccl")  # TODO: same as above
        import torch.distributed as dist

    engine = Engine(args, token_queue, broadcast_endpoint, kvcache_event)
    if args.use_cuda_graph:
        if 'npu' in engine.device:
            engine.model_runner.warmup_npu()
        else:
            engine.model_runner.warmup()

    if use_torch_npu:
        args.port += torch.distributed.get_rank()

    event.set()
    engine.loop()

class BalanceServeInterface(BackendInterfaceBase):
    use_static_cache: bool = True

    model: Any
    tokenizer: AutoTokenizer

    cache: StaticCache
    generated_ids: torch.Tensor
    seq_length: int

    streamer: TextStreamer

    # thread_related
    last_request_id: Optional[str] = None
    ever_generated_ids: Set[int] = set()

    def __init__(self, args: ConfigArgs = default_args, input_args=None):
        self.args = input_args
        self.queue_map: dict[int, asyncio.Queue] = {}
        self.thread_map: dict[int, int] = {}
        processes = []
        self.broadcast_endpoint = tempfile.NamedTemporaryFile(delete=False).name  # @TODO add to config
        ctx = mp.get_context("spawn")
        self.token_queue = ctx.Queue(maxsize=1000)
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
        self.sched_client = SchedulerClient(args.sched_port)
        self.streamer = TextStreamer(self.tokenizer)

        if use_torch_npu:
            world_size = str(os.getenv("WORLD_SIZE", self.args.tp))
            if not isinstance(world_size, str):
                raise ValueError(f"world_size ({world_size}) must be str")
            start_events = []
            kvcache_events = []
            for rank in range(self.args.tp):
                if int(self.args.device[-1]) > 0:
                    break
                start_event = ctx.Event()
                kvcache_event = ctx.Event()
                p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event,
                                                         kvcache_event, rank, world_size))
                p.start()
                processes.append(p)
                start_events.append(start_event)
                kvcache_events.append(kvcache_event)
            for evt in kvcache_events:
                evt.wait()
            self._engines = processes
        else:
            start_event = ctx.Event()
            kvcache_event = ctx.Event()
            p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event,
                                                     kvcache_event))
            p.start()
            processes.append(p)
            kvcache_event.wait()

        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            args.tp = input_args.tp
            pickle.dump(args, temp_file)
            temp_file_path = temp_file.name

        current_file = __file__
        target_file = os.path.join(os.path.dirname(current_file), "..", "..", "balance_serve", "sched_rpc.py")
        target_file = os.path.normpath(target_file)

        log_path = os.path.join(args.log_dir, "rpc.log")
        log = open(log_path, "a")
        sched_process = subprocess.Popen(
            ["python3", target_file, "--config", temp_file_path],
            stdout=log,
            stderr=log
        )
        print("sched_rpc started with PID:", sched_process.pid)

        def signal_handler(signum, frame):
            print(f"Received signal {signum}, shutting down...")
            cleanup()
            os._exit(0)

        def cleanup():
            print("Cleaning up...")
            for p in processes:
                if p.is_alive():
                    print(f"Terminating subprocess {p.pid}")
                    p.terminate()
                    p.join()
            if sched_process and sched_process.poll() is None:
                print(f"Terminating sched_process {sched_process.pid}")
                sched_process.terminate()
                sched_process.wait()

        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)

        if use_torch_npu:
            for evt in start_events:
                evt.wait()
        else:
            start_event.wait()

    def get_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None,
                   max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None) -> tuple[float, float, float]:
        """Get sampling parameters and handle default values and edge cases"""
        if max_tokens is not None:
            max_completion_tokens = max_tokens
        if max_completion_tokens is None:
            max_completion_tokens = self.args.max_new_tokens
        else:
            max_completion_tokens = min(self.args.max_new_tokens, max_completion_tokens)

        if temperature is None:
            temperature = self.args.temperature
        if top_p is None:
            top_p = self.args.top_p

        if temperature == 0:
            temperature = 0.0001
        if top_p == 0:
            top_p = 0.0001
        return temperature, top_p, max_completion_tokens

    def run_queue_proxy(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self.queue_proxy())

    @asynccontextmanager
    async def lifespan(self, app: FastAPI):
        asyncio.create_task(self.queue_proxy())
        yield

    async def queue_proxy(self):
        print("Queue Proxy Started")
        while True:
            try:
                query_id, token = self.token_queue.get_nowait()
                try:
                    # query id might not be allocated yet
                    self.queue_map[query_id].put_nowait(token)
                    # print(f"Proxy Put token: {token} to queue for query id: {query_id}")
                except asyncio.QueueFull:
                    # print(f"Queue for query id: {query_id} is full, waiting to put: {token}")
                    await self.queue_map[query_id].put(token)
            except queue.Empty:
                # print("no new token")
                # await asyncio.sleep(1)
                await asyncio.sleep(0)

    def tokenize_prompt(self, prompt: str):
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device)
        return input_ids

    def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
        for m in messages:
            if m["role"] == "system":
                logger.warning(f'change {m["role"]} to user')
                m["role"] = "user"
        new_messages = [messages[0]]
        for m in messages[1:]:
            if m["role"] == "user" and new_messages[-1]["role"] == "user":
                logger.warning("merge two adjacent user messages")
                new_messages[-1]["content"] += '\n' + m["content"]
            else:
                new_messages.append(m)
        # input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
        # # drop <think> token in chat template
        # if input_str.endswith('<think>\n'):
        #     input_str = input_str[:-len('<think>\n')]
        input_ids = self.tokenizer.apply_chat_template(new_messages, add_generation_prompt=True, return_tensors="pt").to(self.args.device)
        return input_ids

    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = 0, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
        profiler = Profiler()
        profiler.create_and_start_timer("tokenize")

        if isinstance(local_messages, List):
            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
        elif isinstance(local_messages, str):
            # local_messages = local_messages[0]['content']
            input_ids = self.tokenize_prompt(local_messages)
        else:
            raise ValueError("local_messages should be List or str")

        if Config().user_force_think:
            token_thinks = torch.tensor([self.tokenizer.encode("<think>\n", add_special_tokens=False)], device=input_ids.device)
            if not torch.equal(input_ids[0, -token_thinks.shape[-1]:], token_thinks[-1]):  # TODO: this line is newly added; check whether it affects the GPU path
                input_ids = torch.cat(
                    [input_ids, token_thinks], dim=1
                )

        logger.debug(f"get input ids of shape {input_ids.shape}")
        profiler.pause_timer("tokenize")

        profiler.create_and_start_timer("prefill")
        query_add = sched_ext.QueryAdd()
        query_add.query_token = input_ids[0].tolist()
        query_length = input_ids[0].shape[0]
        query_add.query_length = query_length
        profiler.set_counter("prefill", query_length)
        # @TODO add server
        stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False), self.tokenizer.encode("<|im_end|>")]
        query_add.stop_criteria = stop_criteria
        temperature, top_p, max_new_tokens = self.get_params(temperature, top_p, max_tokens, max_completion_tokens)
        query_add.sample_options.temperature = temperature
        if top_p == 0 or top_p is None:
            top_p = 0.0001
        query_add.sample_options.top_p = top_p
        query_add.estimated_length = min(self.args.cache_lens, query_length + max_new_tokens)

        query_id = self.sched_client.add_query(query_add)
        queue = asyncio.Queue(maxsize=max_new_tokens)
        self.queue_map[query_id] = queue
        self.thread_map[thread_id] = query_id
        is_first_token = True

        async for token in chat_stream(self.queue_map[query_id], self.tokenizer):
            if is_first_token:
                is_first_token = False
                profiler.pause_timer("prefill")
                profiler.create_and_start_timer("decode")
                profiler.set_counter("decode", 0)
                if Config().user_force_think:
                    think = '<think>\n'
                    print(think, end="", flush=True)
                    yield think, None
            else:
                profiler.inc("decode")
            # TODO: pass in rank to avoid duplicate printing
            yield token, None

        profiler.pause_timer("decode")
        report_last_time_performance(profiler)
        yield self.streamer.end(), None
        if profiler.get_counter('decode') >= max_new_tokens - 1:
            yield "", "length"
        else:
            yield "", "stop"

        yield RawUsage(
            tokenize_time=profiler.get_timer_sec('tokenize'),
            prefill_time=profiler.get_timer_sec('prefill'),
            decode_time=profiler.get_timer_sec('decode'),
            prefill_count=profiler.get_counter('prefill'),
            decode_count=profiler.get_counter('decode'),
        )
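

# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the committed file above): inference()
# is an async generator that yields (text, finish_reason) tuples while decoding
# and a final RawUsage record. The helper below is a minimal consumer, assuming
# an already-constructed BalanceServeInterface instance is passed in; the names
# `interface` and "demo-thread" are placeholders for illustration only.
# ----------------------------------------------------------------------------
async def collect_reply(interface: BalanceServeInterface, messages: list) -> str:
    reply = []
    async for item in interface.inference(messages, thread_id="demo-thread"):
        if isinstance(item, RawUsage):
            # Final bookkeeping record: tokenize/prefill/decode timings and counts.
            print(f"prefill tokens: {item.prefill_count}, decode tokens: {item.decode_count}")
        else:
            text, finish_reason = item
            if text:
                reply.append(text)
            if finish_reason is not None:
                # "stop" when a stop criterion was hit, "length" when the
                # max_completion_tokens budget was exhausted.
                print(f"finish reason: {finish_reason}")
    return "".join(reply)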