kvcache-ai-ktransformers/archive/ktransformers/server/backend/base.py
Jiaqi Liao 57d14d22bc
Refactor: restructure repository to focus on kt-kernel and KT-SFT modules (#1581)
* refactor: move legacy code to archive/ directory

  - Moved ktransformers, csrc, third_party, merge_tensors to archive/
  - Moved build scripts and configurations to archive/
  - Kept kt-kernel, KT-SFT, doc, and README files in root
  - Preserved complete git history for all moved files

* refactor: restructure repository to focus on kt-kernel and KT-SFT modules

* fix README

* fix README

* fix README

* fix README

* docs: add performance benchmarks to kt-kernel section

Add comprehensive performance data for kt-kernel to match KT-SFT's presentation:
- AMX kernel optimization: 21.3 TFLOPS (3.9× faster than PyTorch)
- Prefill phase: up to 20× speedup vs baseline
- Decode phase: up to 4× speedup
- NUMA optimization: up to 63% throughput improvement
- Multi-GPU (8×L20): 227.85 tokens/s total throughput with DeepSeek-R1 FP8

Source: https://lmsys.org/blog/2025-10-22-KTransformers/

This provides users with concrete performance metrics for both core modules,
making it easier to understand the capabilities of each component.

* refactor: improve kt-kernel performance data with specific hardware and models

Replace generic performance descriptions with concrete benchmarks:
- Specify exact hardware: 8×L20 GPU + Xeon Gold 6454S, Single/Dual-socket Xeon + AMX
- Include specific models: DeepSeek-R1-0528 (FP8), DeepSeek-V3 (671B)
- Show detailed metrics: total throughput, output throughput, concurrency details
- Match KT-SFT presentation style for consistency

This provides users with actionable performance data they can use to evaluate
hardware requirements and expected performance for their use cases.

* fix README

* docs: clean up performance table and improve formatting

* add pic for README

* refactor: simplify .gitmodules and backup legacy submodules

- Remove 7 legacy submodules from root .gitmodules (archive/third_party/*)
- Keep only 2 active submodules for kt-kernel (llama.cpp, pybind11)
- Backup complete .gitmodules to archive/.gitmodules
- Add documentation in archive/README.md for researchers who need legacy submodules

This reduces initial clone size by ~500MB and avoids downloading unused dependencies.

* refactor: move doc/ back to root directory

Keep documentation in root for easier access and maintenance.

* refactor: consolidate all images to doc/assets/

- Move kt-kernel/assets/heterogeneous_computing.png to doc/assets/
- Remove KT-SFT/assets/ (images already in doc/assets/)
- Update KT-SFT/README.md image references to ../doc/assets/
- Eliminates ~7.9MB image duplication
- Centralizes all documentation assets in one location

* fix pic path for README
2025-11-10 17:42:26 +08:00

167 lines
No EOL
6.6 KiB
Python

from asyncio import Queue
from enum import Enum
import sys, os
from typing import AsyncIterator, Dict, List, Optional, Tuple
import torch
from ktransformers.server.config.log import logger
from ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager
from ktransformers.server.crud.assistants.messages import MessageDatabaseManager
from ktransformers.server.crud.assistants.runs import RunsDatabaseManager
from ktransformers.server.crud.assistants.threads import ThreadsDatabaseManager
from ktransformers.server.exceptions import request_error
from ktransformers.server.schemas.assistants.assistants import AssistantObject
from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, Role
from ktransformers.server.schemas.assistants.runs import RunObject
from ktransformers.server.schemas.assistants.threads import ThreadObject
from ktransformers.server.schemas.endpoints.chat import RawUsage
from ktransformers.server.schemas.base import ObjectID, Order
from ktransformers.server.utils.multi_timer import Profiler
from .args import ConfigArgs,default_args
class BackendInterfaceBase:
    '''
    Interface to inference frameworks. e.g. transformers, exllama.
    Implement __init__ and work
    '''

    args: ConfigArgs
    # Shared profiler accumulating tokenize/prefill/decode timers and counters
    # for the most recent request.
    profiler: Profiler = Profiler()

    def __init__(self, args: ConfigArgs = default_args):
        raise NotImplementedError

    async def inference(self, local_messages, request_unique_id: Optional[str]) -> AsyncIterator[str]:
        '''
        work can be called directly, or by ThreadContext

        local_messages:
            When called by ThreadContext, local_messages are generated by
            ThreadContext.get_local_messages(). Please deal with different local_messages
        request_unique_id:
            unique id of different requests, useful when using cache
        return:
            async str output for stream update
        '''
        raise NotImplementedError

    def report_last_time_performance(self):
        '''Log throughput (tokens/s) and per-phase timings recorded by the profiler.'''
        try:
            tokenize_time = self.profiler.get_timer_sec('tokenize')
            prefill_time = self.profiler.get_timer_sec('prefill')
            decode_time = self.profiler.get_timer_sec('decode')
            prefill_count = self.profiler.get_counter('prefill')
            decode_count = self.profiler.get_counter('decode')
            logger.info(
                f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. '
                f'Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}'
            )
        except Exception:
            # Best-effort report: timers/counters may be missing (or a phase
            # time may be zero) if no request has completed yet.
            # Fixed: was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt; also dropped the placeholder-free f-string.
            logger.info('Performance statistics not recorded')
class ThreadContext:
    '''
    A thread context holding assistant logics: binds a run to its thread,
    assistant, message history and backend interface, and streams the
    assistant's reply back as status events and token deltas.
    '''

    args: ConfigArgs
    # Assistant Logic
    assistant: Optional[AssistantObject] = None
    related_threads: List[ThreadObject]
    thread: ThreadObject
    # NOTE: mutable class-level default; __init__ always rebinds self.messages,
    # so instances do not actually share this list.
    messages: List[MessageObject] = []
    run: RunObject
    interface: Optional[BackendInterfaceBase] = None
    queue: Optional[Queue] = None
    timer: Profiler = Profiler()

    def __init__(self, run: RunObject, interface: BackendInterfaceBase, args: ConfigArgs = default_args) -> None:
        '''Load the thread, assistant and message history referenced by `run`.'''
        self.args = args
        self.thread_manager = ThreadsDatabaseManager()
        self.message_manager = MessageDatabaseManager()
        self.runs_manager = RunsDatabaseManager()
        self.assistant_manager = AssistantDatabaseManager()
        self.thread = self.thread_manager.db_get_thread_by_id(run.thread_id)
        self.assistant = self.assistant_manager.db_get_assistant_by_id(run.assistant_id)
        self.messages = self.message_manager.db_list_messages_of_thread(run.thread_id, order=Order.ASC)
        logger.debug(f"{len(self.messages)} messages loaded from database")
        self.interface = interface
        self.update_by_run(run, args)

    def get_local_messages(self):
        '''
        Get local messages, as the input to interface.work
        This function is intended for message preprocessing, e.g. applying a chat template.
        '''
        raise NotImplementedError

    def update_by_run(self, run: RunObject, args: ConfigArgs = default_args):
        '''Rebind this context to `run` (and optionally new `args`).'''
        self.run = run
        self.args = args

    def put_user_message(self, message: MessageObject):
        '''Append an in-progress user message that belongs to this thread.'''
        assert (
            message.role.is_user()
            and message.thread_id == self.thread.id
            and message.status == MessageObject.Status.in_progress
        )
        self.messages.append(message)

    def delete_user_message(self, message_id: ObjectID):
        '''Remove the message with `message_id` from the in-memory history.'''
        self.messages = [m for m in self.messages if m.id != message_id]

    async def work(self) -> AsyncIterator:
        '''
        Run inference for this thread and yield streaming events: message/run
        status transitions, then per-token deltas, then the final status.
        Raises request_error if the latest message is not from the user.
        '''
        logger.debug('start working')
        user_message = self.messages[-1]
        if not user_message.role.is_user():
            raise request_error('user must talk before LLM can talk')
        user_message.status = MessageObject.Status.completed
        user_message.sync_db()
        # Must capture the prompt BEFORE appending the (empty) reply message,
        # or the reply would be included in its own input.
        local_messages = self.get_local_messages()
        response_str_count = 0
        reply_message = self.message_manager.create_message_object(
            self.thread.id,
            self.run.id,
            MessageCreate(role=Role.assistant, content=""),
        )
        reply_message.assistant_id = self.assistant.id
        self.messages.append(reply_message)
        yield reply_message.stream_response_with_event(MessageObject.Status.created)
        yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)
        yield self.run.stream_response_with_event(RunObject.Status.in_progress)
        async for res in self.interface.inference(local_messages, self.thread.id):
            if isinstance(res, RawUsage):
                # Final usage report from the backend; captured but not
                # forwarded downstream here.
                raw_usage = res
            else:
                token, finish_reason = res
                if self.run.status == RunObject.Status.cancelling:
                    # Fixed: logger.warn is a deprecated alias of logger.warning.
                    logger.warning(f'Run {self.run.id} cancelling')
                    break
                yield reply_message.append_message_delta(token)
                response_str_count += 1
        if self.run.status == RunObject.Status.cancelling:
            yield self.run.stream_response_with_event(RunObject.Status.cancelled)
            yield reply_message.stream_response_with_event(MessageObject.Status.incomplete)
        elif self.run.status == RunObject.Status.in_progress:
            yield self.run.stream_response_with_event(RunObject.Status.completed)
            yield reply_message.stream_response_with_event(MessageObject.Status.completed)
        else:
            raise NotImplementedError(f'{self.run.status} should not appear here')
        reply_message.sync_db()
        self.run.sync_db()