mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 15:29:39 +00:00)
fix load default max_new_tokens

parent 67042d11e3
commit 7af83f9efb

4 changed files with 21 additions and 10 deletions
@@ -138,12 +138,23 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
     # Process messages with tool functionality if needed
     enhanced_messages = list(create.messages)
-    if create.max_tokens<0 or create.max_completion_tokens<0:
+    if create.max_tokens is not None and create.max_tokens<0:
         return JSONResponse(
             status_code=400,
             content={
                 "object": "error",
-                "message": f"max_new_tokens must be at least 0, got {create.max_tokens}.",
+                "message": f"max_tokens must be at least 0, got {create.max_tokens}.",
+                "type": "BadRequestError",
+                "param": None,
+                "code": 400
+            })
+
+    if create.max_completion_tokens is not None and create.max_completion_tokens<0:
+        return JSONResponse(
+            status_code=400,
+            content={
+                "object": "error",
+                "message": f"max_completion_tokens must be at least 0, got {create.max_completion_tokens}.",
                 "type": "BadRequestError",
                 "param": None,
                 "code": 400
             })
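Taken together, the four hunks do two things: `max_tokens` and `max_completion_tokens` now default to `None` instead of a `Config().max_new_tokens` value captured when the schema class is first imported, and each field is range-checked separately so the error message names the field that was actually out of range. The diff does not show where a `None` limit is resolved into a real token budget; a minimal sketch of that request-time fallback, assuming a `Config` object exposing `max_new_tokens` as in the schemas further down, could look like:

from typing import Optional

class Config:
    # Stand-in for the project's Config object; max_new_tokens is the
    # server-wide default generation limit (the value here is hypothetical).
    max_new_tokens = 2048

def resolve_max_new_tokens(max_tokens: Optional[int],
                           max_completion_tokens: Optional[int]) -> int:
    # Prefer an explicit per-request limit; only when the client sent
    # neither field does the configured default apply. Because the schema
    # defaults are now None, this lookup happens per request instead of
    # being frozen into the schema at import time.
    if max_completion_tokens is not None:
        return max_completion_tokens
    if max_tokens is not None:
        return max_tokens
    return Config().max_new_tokens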
@@ -14,22 +14,22 @@ router = APIRouter()
 @router.post("/completions",tags=['openai'])
 async def create_completion(request:Request, create:CompletionCreate):
     id = str(uuid4())
-    if create.max_tokens<0:
+    if create.max_tokens is not None and create.max_tokens<0:
         return JSONResponse(
             status_code=400,
             content={
                 "object": "error",
-                "message": f"max_new_tokens must be at least 0, got {create.max_tokens}.",
+                "message": f"max_tokens must be at least 0, got {create.max_tokens}.",
                 "type": "BadRequestError",
                 "param": None,
                 "code": 400
             })
-    if create.max_completion_tokens<0:
+    if create.max_completion_tokens is not None and create.max_completion_tokens<0:
         return JSONResponse(
             status_code=400,
             content={
                 "object": "error",
-                "message": f"max_new_tokens must be at least 0, got {create.max_completion_tokens}.",
+                "message": f"max_completion_tokens must be at least 0, got {create.max_completion_tokens}.",
                 "type": "BadRequestError",
                 "param": None,
                 "code": 400
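With the route above in place, a negative limit is rejected up front with the OpenAI-style error object built in the handler, while an omitted limit now arrives as `None`. A hypothetical client call (base URL, port, and model name are assumptions, not taken from the diff):

import requests

resp = requests.post(
    "http://localhost:10002/v1/completions",  # hypothetical endpoint
    json={"model": "ktransformers", "prompt": "hello", "max_tokens": -1},
)
print(resp.status_code)        # 400
print(resp.json()["message"])  # max_tokens must be at least 0, got -1.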
@@ -73,8 +73,8 @@ class ChatCompletionCreate(BaseModel):
     stream_options: Optional[Dict[str, Any]] = None
     frequency_penalty: float = 0
     presence_penalty: float = 0
-    max_tokens: Optional[int] = Field(default=Config().max_new_tokens)
-    max_completion_tokens: Optional[int] = Field(default=Config().max_new_tokens)
+    max_tokens: Optional[int] = Field(default=None)
+    max_completion_tokens: Optional[int] = Field(default=None)
     return_speed: Optional[bool] = Field(default=False)
     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
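This schema change is the heart of the fix: `Field(default=Config().max_new_tokens)` evaluates `Config().max_new_tokens` once, while the class body executes at import time, so any configuration loaded or changed afterwards never reaches the schema default. A self-contained sketch of the pitfall (class and attribute names are illustrative, not the project's):

from typing import Optional
from pydantic import BaseModel, Field

class Settings:
    max_new_tokens = 300  # whatever value happens to be set at import time

class Frozen(BaseModel):
    # The default is computed NOW, while the class body runs.
    max_tokens: Optional[int] = Field(default=Settings.max_new_tokens)

Settings.max_new_tokens = 8192  # e.g. a config file is loaded later

print(Frozen().max_tokens)  # 300 -- the stale import-time value, not 8192

With `default=None`, the handlers can consult the live config at request time instead, as sketched after the first hunk.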
@@ -10,8 +10,8 @@ class CompletionCreate(BaseModel):
     stream: bool = False
     temperature: Optional[float] = Field(default=Config().temperature)
     top_p: Optional[float] = Field(default=Config().top_p)
-    max_tokens: Optional[int] = Field(default=Config().max_new_tokens)
-    max_completion_tokens: Optional[int] = Field(default=Config().max_new_tokens)
+    max_tokens: Optional[int] = Field(default=None)
+    max_completion_tokens: Optional[int] = Field(default=None)

     def get_tokenizer_messages(self):
         if isinstance(self.prompt,List):
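The routes above do the bounds check by hand so they can return an OpenAI-shaped error object ("object": "error", "code": 400). An alternative, not what this commit does, would be to push the check into the schema itself with a Pydantic validator, at the cost of getting FastAPI's stock 422 validation payload instead; a sketch assuming Pydantic v2:

from typing import Optional
from pydantic import BaseModel, Field, field_validator

class CompletionCreateSketch(BaseModel):
    # Hypothetical stand-in for CompletionCreate with schema-level checks.
    max_tokens: Optional[int] = Field(default=None)
    max_completion_tokens: Optional[int] = Field(default=None)

    @field_validator("max_tokens", "max_completion_tokens")
    @classmethod
    def _non_negative(cls, v: Optional[int]) -> Optional[int]:
        if v is not None and v < 0:
            raise ValueError("must be at least 0")
        return v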