Fix loading of the default max_new_tokens value

This commit is contained in:
qiyuxinlin 2025-04-25 04:20:12 +00:00
parent 67042d11e3
commit 7af83f9efb
4 changed files with 21 additions and 10 deletions

View file

@@ -138,12 +138,23 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
# Process messages with tool functionality if needed # Process messages with tool functionality if needed
enhanced_messages = list(create.messages) enhanced_messages = list(create.messages)
if create.max_tokens<0 or create.max_completion_tokens<0: if create.max_tokens is not None and create.max_tokens<0:
return JSONResponse( return JSONResponse(
status_code=400, status_code=400,
content={ content={
"object": "error", "object": "error",
"message": f"max_new_tokens must be at least 0, got {create.max_tokens}.", "message": f"max_tokens must be at least 0, got {create.max_tokens}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.max_completion_tokens is not None and create.max_completion_tokens<0:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"max_completion_tokens must be at least 0, got {create.max_completion_tokens}.",
"type": "BadRequestError", "type": "BadRequestError",
"param": None, "param": None,
"code": 400 "code": 400

View file

@@ -14,22 +14,22 @@ router = APIRouter()
@router.post("/completions",tags=['openai']) @router.post("/completions",tags=['openai'])
async def create_completion(request:Request, create:CompletionCreate): async def create_completion(request:Request, create:CompletionCreate):
id = str(uuid4()) id = str(uuid4())
if create.max_tokens<0: if create.max_tokens is not None and create.max_tokens<0:
return JSONResponse( return JSONResponse(
status_code=400, status_code=400,
content={ content={
"object": "error", "object": "error",
"message": f"max_new_tokens must be at least 0, got {create.max_tokens}.", "message": f"max_tokens must be at least 0, got {create.max_tokens}.",
"type": "BadRequestError", "type": "BadRequestError",
"param": None, "param": None,
"code": 400 "code": 400
}) })
if create.max_completion_tokens<0: if create.max_completion_tokens is not None and create.max_completion_tokens<0:
return JSONResponse( return JSONResponse(
status_code=400, status_code=400,
content={ content={
"object": "error", "object": "error",
"message": f"max_new_tokens must be at least 0, got {create.max_completion_tokens}.", "message": f"max_completion_tokens must be at least 0, got {create.max_completion_tokens}.",
"type": "BadRequestError", "type": "BadRequestError",
"param": None, "param": None,
"code": 400 "code": 400

View file

@@ -73,8 +73,8 @@ class ChatCompletionCreate(BaseModel):
stream_options: Optional[Dict[str, Any]] = None stream_options: Optional[Dict[str, Any]] = None
frequency_penalty: float = 0 frequency_penalty: float = 0
presence_penalty: float = 0 presence_penalty: float = 0
max_tokens: Optional[int] = Field(default=Config().max_new_tokens) max_tokens: Optional[int] = Field(default=None)
max_completion_tokens: Optional[int] = Field(default=Config().max_new_tokens) max_completion_tokens: Optional[int] = Field(default=None)
return_speed: Optional[bool] = Field(default=False) return_speed: Optional[bool] = Field(default=False)
def get_tokenizer_messages(self): def get_tokenizer_messages(self):
return [m.to_tokenizer_message() for m in self.messages] return [m.to_tokenizer_message() for m in self.messages]

View file

@@ -10,8 +10,8 @@ class CompletionCreate(BaseModel):
stream: bool = False stream: bool = False
temperature: Optional[float] = Field(default=Config().temperature) temperature: Optional[float] = Field(default=Config().temperature)
top_p: Optional[float] = Field(default=Config().top_p) top_p: Optional[float] = Field(default=Config().top_p)
max_tokens: Optional[int] = Field(default=Config().max_new_tokens) max_tokens: Optional[int] = Field(default=None)
max_completion_tokens: Optional[int] = Field(default=Config().max_new_tokens) max_completion_tokens: Optional[int] = Field(default=None)
def get_tokenizer_messages(self): def get_tokenizer_messages(self):
if isinstance(self.prompt,List): if isinstance(self.prompt,List):