update speed test

This commit is contained in:
qiyuxinlin 2025-04-22 07:38:05 +00:00
parent f5287e908a
commit b17ab8653c
4 changed files with 66 additions and 31 deletions

View file

@ -13,16 +13,10 @@ from ktransformers.server.schemas.endpoints.chat import RawUsage, Role
from ktransformers.server.backend.base import BackendInterfaceBase
from ktransformers.server.config.config import Config
from ktransformers.server.config.log import logger
from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk, CompletionUsage
# Define own data structure instead of importing from OpenAI
class CompletionUsage(BaseModel):
    """Token-usage summary attached to a chat completion response.

    Declared locally rather than imported from the OpenAI package. The
    three counters are required; the two ``*_details`` mappings are
    optional and default to ``None``.
    """

    # Required token counters.
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

    # Optional free-form breakdowns; left as None when not supplied.
    prompt_tokens_details: Optional[Dict[str, Any]] = None
    completion_tokens_details: Optional[Dict[str, Any]] = None
class Choice(BaseModel):
index: int
@ -217,6 +211,12 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
completion_tokens=raw_usage.decode_count,
total_tokens=raw_usage.prefill_count + raw_usage.decode_count
)
if create.return_speed:
chunk.usage.prefill_time = res.prefill_time
chunk.usage.decode_time = res.decode_time
else:
chunk.usage.__dict__.pop('prefill_time', None)
chunk.usage.__dict__.pop('decode_time', None)
yield chunk
elif isinstance(res, tuple) and len(res) == 2:
token, finish_reason = res
@ -377,8 +377,15 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
usage = CompletionUsage(
prompt_tokens=raw_usage.prefill_count,
completion_tokens=raw_usage.decode_count,
total_tokens=raw_usage.prefill_count + raw_usage.decode_count
total_tokens=raw_usage.prefill_count + raw_usage.decode_count,
)
if create.return_speed:
usage.prefill_time = res.prefill_time
usage.decode_time = res.decode_time
else:
usage.__dict__.pop('prefill_time', None)
usage.__dict__.pop('decode_time', None)
elif isinstance(res, tuple) and len(res) == 2:
token, finish_reason = res
token = re.sub('|'.join(map(re.escape, too_calls_dict.keys())), lambda m: too_calls_dict[m.group(0)], token)