add check-para

This commit is contained in:
Alisehen 2025-04-22 09:30:08 +00:00
parent 485588017b
commit c995bdbbfa
2 changed files with 114 additions and 4 deletions

View file

@ -13,7 +13,7 @@ from ktransformers.server.schemas.endpoints.chat import RawUsage, Role
from ktransformers.server.backend.base import BackendInterfaceBase from ktransformers.server.backend.base import BackendInterfaceBase
from ktransformers.server.config.config import Config from ktransformers.server.config.config import Config
from ktransformers.server.config.log import logger from ktransformers.server.config.log import logger
from fastapi.responses import JSONResponse
from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
# Define own data structure instead of importing from OpenAI # Define own data structure instead of importing from OpenAI
@ -143,7 +143,67 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
# Process messages with tool functionality if needed # Process messages with tool functionality if needed
enhanced_messages = list(create.messages) enhanced_messages = list(create.messages)
if create.model != Config().model_name:
return JSONResponse(
status_code=400,
content={
"error": {
"message": "Model not found",
"code": 404,
"type": "NotFound"
}
})
if create.max_tokens<0 or create.max_completion_tokens<0:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"max_new_tokens must be at least 0, got {create.max_tokens}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.temperature<0 or create.temperature>2:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"temperature must be in [0, 2], got {create.temperature}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.top_p<=0 or create.top_p>1:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"top_p must be in (0, 1], got {create.top_p}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.frequency_penalty<-2 or create.frequency_penalty>2:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"frequency_penalty must be in [-2, 2], got {create.frequency_penalty}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.presence_penalty<-2 or create.presence_penalty>2:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"presence_penalty must be in [-2, 2], got {create.presence_penalty}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
# Check if tools are present # Check if tools are present
has_tools = create.tools and len(create.tools) > 0 has_tools = create.tools and len(create.tools) > 0

View file

@ -7,13 +7,63 @@ from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import stream_response from ktransformers.server.schemas.assistants.streaming import stream_response
from ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject from ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject
from ktransformers.server.schemas.endpoints.chat import RawUsage from ktransformers.server.schemas.endpoints.chat import RawUsage
from fastapi.responses import JSONResponse
from ktransformers.server.config.config import Config
router = APIRouter() router = APIRouter()
@router.post("/completions",tags=['openai']) @router.post("/completions",tags=['openai'])
async def create_completion(request:Request, create:CompletionCreate): async def create_completion(request:Request, create:CompletionCreate):
id = str(uuid4()) id = str(uuid4())
if create.model != Config().model_name:
return JSONResponse(
status_code=400,
content={
"error": {
"message": "Model not found",
"code": 404,
"type": "NotFound"
}
})
if create.max_tokens<0:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"max_new_tokens must be at least 0, got {create.max_tokens}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.max_completion_tokens<0:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"max_new_tokens must be at least 0, got {create.max_completion_tokens}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.temperature<0 or create.temperature>2:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"temperature must be in [0, 2], got {create.temperature}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.top_p<=0 or create.top_p>1:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"top_p must be in (0, 1], got {create.top_p}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
interface = get_interface() interface = get_interface()
print(f'COMPLETION INPUT:----\n{create.prompt}\n----') print(f'COMPLETION INPUT:----\n{create.prompt}\n----')