from asyncio import Queue
from enum import Enum
import sys, os
from typing import AsyncIterator, Dict, List, Optional, Tuple

import torch

from ktransformers.server.config.log import logger
from ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager
from ktransformers.server.crud.assistants.messages import MessageDatabaseManager
from ktransformers.server.crud.assistants.runs import RunsDatabaseManager
from ktransformers.server.crud.assistants.threads import ThreadsDatabaseManager
from ktransformers.server.exceptions import request_error
from ktransformers.server.schemas.assistants.assistants import AssistantObject
from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, Role
from ktransformers.server.schemas.assistants.runs import RunObject
from ktransformers.server.schemas.assistants.threads import ThreadObject
from ktransformers.server.schemas.endpoints.chat import RawUsage
from ktransformers.server.schemas.base import ObjectID, Order
from ktransformers.server.utils.multi_timer import Profiler

from .args import ConfigArgs, default_args

class BackendInterfaceBase:
    '''
    Interface to inference frameworks, e.g. transformers and exllama.
    Subclasses implement __init__ and inference.
    '''

    args: ConfigArgs
    profiler: Profiler = Profiler()

    def __init__(self, args: ConfigArgs = default_args):
        raise NotImplementedError

    async def inference(self, local_messages, request_unique_id: Optional[str]) -> AsyncIterator[str]:
        '''
        inference can be called directly, or by ThreadContext.

        local_messages:
            When called by ThreadContext, local_messages are generated by
            ThreadContext.get_local_messages(); implementations should handle
            each of the possible local_messages formats.
        request_unique_id:
            unique id of the request, useful when using a cache

        return:
            an async iterator of str, for streaming output updates
        '''
        raise NotImplementedError

    def report_last_time_performance(self):
        try:
            tokenize_time = self.profiler.get_timer_sec('tokenize')
            prefill_time = self.profiler.get_timer_sec('prefill')
            decode_time = self.profiler.get_timer_sec('decode')
            prefill_count = self.profiler.get_counter('prefill')
            decode_count = self.profiler.get_counter('decode')

            # throughput (T/s) = token count / elapsed seconds
            logger.info(
                f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. '
                f'Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}'
            )
        except Exception:
            logger.info('Performance statistics not recorded')

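# A minimal sketch of a conforming backend (hypothetical; the real backends
# live elsewhere in the package). Note that although inference is annotated
# AsyncIterator[str], ThreadContext.work() below unpacks (token, finish_reason)
# tuples, or a final RawUsage object, from the iterator, so a backend must
# yield those:
#
#     class EchoBackend(BackendInterfaceBase):
#         def __init__(self, args: ConfigArgs = default_args):
#             self.args = args
#
#         async def inference(self, local_messages, request_unique_id: Optional[str]):
#             for message in local_messages:
#                 yield str(message), None  # (token, finish_reason)
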
class ThreadContext:
    '''
    A thread context holding the assistant logic for one conversation thread.
    '''

    args: ConfigArgs
    # Assistant logic
    assistant: Optional[AssistantObject] = None
    related_threads: List[ThreadObject]
    thread: ThreadObject
    messages: List[MessageObject] = []
    run: RunObject

    interface: Optional[BackendInterfaceBase] = None

    queue: Optional[Queue] = None
    timer: Profiler = Profiler()

    def __init__(self, run: RunObject, interface: BackendInterfaceBase, args: ConfigArgs = default_args) -> None:
        self.args = args
        self.thread_manager = ThreadsDatabaseManager()
        self.message_manager = MessageDatabaseManager()
        self.runs_manager = RunsDatabaseManager()
        self.assistant_manager = AssistantDatabaseManager()
        self.thread = self.thread_manager.db_get_thread_by_id(run.thread_id)
        self.assistant = self.assistant_manager.db_get_assistant_by_id(run.assistant_id)
        self.messages = self.message_manager.db_list_messages_of_thread(run.thread_id, order=Order.ASC)
        logger.debug(f"{len(self.messages)} messages loaded from database")
        self.interface = interface
        self.update_by_run(run, args)

    def get_local_messages(self):
        '''
        Get local messages, used as the input to interface.inference.
        This function is intended for message preprocessing, e.g. applying the chat template.
        '''
        raise NotImplementedError

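    # A minimal sketch of an override (hypothetical; it assumes each
    # MessageObject exposes a role and plain-text content, which this
    # file does not show):
    #
    #     def get_local_messages(self):
    #         return [{"role": m.role.value, "content": m.content}
    #                 for m in self.messages]
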
    def update_by_run(self, run: RunObject, args: ConfigArgs = default_args):
        self.run = run
        self.args = args

    def put_user_message(self, message: MessageObject):
        assert (
            message.role.is_user()
            and message.thread_id == self.thread.id
            and message.status == MessageObject.Status.in_progress
        )
        self.messages.append(message)

    def delete_user_message(self, message_id: ObjectID):
        self.messages = [m for m in self.messages if m.id != message_id]

    async def work(self) -> AsyncIterator:
        logger.debug('start working')
        user_message = self.messages[-1]
        if not user_message.role.is_user():
            raise request_error('user must talk before LLM can talk')
        user_message.status = MessageObject.Status.completed
        user_message.sync_db()

        local_messages = self.get_local_messages()  # must be fetched before we insert reply_message

        response_str_count = 0
        reply_message = self.message_manager.create_message_object(
            self.thread.id,
            self.run.id,
            MessageCreate(role=Role.assistant, content=""),
        )
        reply_message.assistant_id = self.assistant.id
        self.messages.append(reply_message)

        yield reply_message.stream_response_with_event(MessageObject.Status.created)
        yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)
        yield self.run.stream_response_with_event(RunObject.Status.in_progress)

        async for res in self.interface.inference(local_messages, self.thread.id):
            if isinstance(res, RawUsage):
                # the final item of the stream carries token-usage accounting
                raw_usage = res
            else:
                token, finish_reason = res
                if self.run.status == RunObject.Status.cancelling:
                    logger.warning(f'Run {self.run.id} cancelling')
                    break
                yield reply_message.append_message_delta(token)
                response_str_count += 1

        if self.run.status == RunObject.Status.cancelling:
            yield self.run.stream_response_with_event(RunObject.Status.cancelled)
            yield reply_message.stream_response_with_event(MessageObject.Status.incomplete)
        elif self.run.status == RunObject.Status.in_progress:
            yield self.run.stream_response_with_event(RunObject.Status.completed)
            yield reply_message.stream_response_with_event(MessageObject.Status.completed)
        else:
            raise NotImplementedError(f'{self.run.status} should not appear here')

        reply_message.sync_db()
        self.run.sync_db()
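

# A minimal sketch of how a caller might drive work() (hypothetical wiring; the
# actual server routes live elsewhere in the package, and a concrete subclass
# must provide get_local_messages):
#
#     async def stream_run(run: RunObject, interface: BackendInterfaceBase):
#         ctx = MyThreadContext(run, interface)  # MyThreadContext is hypothetical
#         async for event in ctx.work():
#             ...  # forward each streamed event/delta to the HTTP response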