Merge remote-tracking branch 'origin/main' into check-para

This commit is contained in:
Alisehen 2025-04-23 02:40:14 +00:00
commit f7d939313b
8 changed files with 219 additions and 145 deletions

View file

@ -14,15 +14,10 @@ from ktransformers.server.backend.base import BackendInterfaceBase
from ktransformers.server.config.config import Config
from ktransformers.server.config.log import logger
from fastapi.responses import JSONResponse
from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk, CompletionUsage
# Define own data structure instead of importing from OpenAI
class CompletionUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
prompt_tokens_details: Optional[Dict[str, Any]] = None
completion_tokens_details: Optional[Dict[str, Any]] = None
class Choice(BaseModel):
index: int
@ -267,6 +262,12 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
completion_tokens=raw_usage.decode_count,
total_tokens=raw_usage.prefill_count + raw_usage.decode_count
)
if create.return_speed:
chunk.usage.prefill_time = res.prefill_time
chunk.usage.decode_time = res.decode_time
else:
chunk.usage.__dict__.pop('prefill_time', None)
chunk.usage.__dict__.pop('decode_time', None)
yield chunk
elif isinstance(res, tuple) and len(res) == 2:
token, finish_reason = res
@ -427,8 +428,15 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
usage = CompletionUsage(
prompt_tokens=raw_usage.prefill_count,
completion_tokens=raw_usage.decode_count,
total_tokens=raw_usage.prefill_count + raw_usage.decode_count
total_tokens=raw_usage.prefill_count + raw_usage.decode_count,
)
if create.return_speed:
usage.prefill_time = res.prefill_time
usage.decode_time = res.decode_time
else:
usage.__dict__.pop('prefill_time', None)
usage.__dict__.pop('decode_time', None)
elif isinstance(res, tuple) and len(res) == 2:
token, finish_reason = res
token = re.sub('|'.join(map(re.escape, too_calls_dict.keys())), lambda m: too_calls_dict[m.group(0)], token)