Revert repetition_penalty as it is not in API spec

This commit is contained in:
lazymio 2025-02-24 21:30:03 +08:00
parent 05ad288453
commit 76487c4dcb
No known key found for this signature in database
GPG key ID: DFF27E34A47CB873
5 changed files with 12 additions and 14 deletions

View file

@@ -28,13 +28,13 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
     if create.stream:
         async def inner():
             chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
-            async for token in interface.inference(input_message,id,create.temperature,create.top_p,create.repetition_penalty):
+            async for token in interface.inference(input_message,id,create.temperature,create.top_p):
                 chunk.set_token(token)
                 yield chunk
         return chat_stream_response(request,inner())
     else:
         comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
         comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
-        async for token in interface.inference(input_message,id,create.temperature,create.top_p,create.repetition_penalty):
+        async for token in interface.inference(input_message,id,create.temperature,create.top_p):
             comp.append_token(token)
         return comp

View file

@@ -20,7 +20,7 @@ async def create_completion(request:Request,create:CompletionCreate):
     if create.stream:
         async def inner():
-            async for token in interface.inference(create.prompt,id,create.temperature,create.top_p,create.repetition_penalty):
+            async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
                 d = {'choices':[{'delta':{'content':token}}]}
                 yield f"data:{json.dumps(d)}\n\n"
             d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
@@ -28,6 +28,6 @@ async def create_completion(request:Request,create:CompletionCreate):
         return stream_response(request,inner())
     else:
         comp = CompletionObject(id=id,object='text_completion',created=int(time()))
-        async for token in interface.inference(create.prompt,id,create.temperature,create.top_p,create.repetition_penalty):
+        async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
             comp.append_token(token)
         return comp

View file

@@ -202,20 +202,18 @@ class TransformersInterface(BackendInterfaceBase):
         self.seq_length += 1
         return self.streamer.put(new_tokens)

-    def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None, repetition_penalty: Optional[float] = None):
+    def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
         if temperature is None:
             temperature = self.args.temperature
         if top_p is None:
             top_p = self.args.top_p
-        if repetition_penalty is None:
-            repetition_penalty = self.args.repetition_penalty
         generation_config, model_kwargs = self.model._prepare_generation_config(
             None, max_length=self.args.max_new_tokens,
             do_sample=True,
             top_k=self.args.top_k,
             top_p=top_p,
             temperature=temperature,
-            repetition_penalty=repetition_penalty # change this to modify generate config
+            repetition_penalty=self.args.repetition_penalty # change this to modify generate config
         )
         self.inputs = inputs
         self.generation_config = generation_config
@@ -261,7 +259,7 @@ class TransformersInterface(BackendInterfaceBase):
         return self.logits_to_token(logits)

     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, repetition_penalty: Optional[float] = None):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
         logger.debug(f"input_ids: {input_ids.shape}")
@@ -329,7 +327,7 @@ class TransformersInterface(BackendInterfaceBase):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
-        self.prepare_logits_wrapper(input_ids, device, temperature, top_p, repetition_penalty)
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
@@ -365,7 +363,7 @@ class TransformersInterface(BackendInterfaceBase):
         self.last_request_id = thread_id
         return True

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, repetition_penalty: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
         if isinstance(local_messages, List):
@@ -392,7 +390,7 @@ class TransformersInterface(BackendInterfaceBase):
             print(think, end="",flush=True)
             yield think
-        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p, repetition_penalty):
+        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
            # output think token after prefill done
            if t is not None:
                print(t, end="",flush=True)

View file

@@ -27,7 +27,7 @@ class ChatCompletionCreate(BaseModel):
     stream : bool = False
     temperature: Optional[float]
     top_p: Optional[float]
-    repetition_penalty: Optional[float]
+    frequency_penalty: Optional[float]

     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]

View file

@@ -11,7 +11,7 @@ class CompletionCreate(BaseModel):
     stream: bool = False
     temperature: Optional[float]
     top_p: Optional[float]
-    repetition_penalty: Optional[float]
+    frequency_penalty: Optional[float]

     def get_tokenizer_messages(self):
         if isinstance(self.prompt,List):