diff --git a/ktransformers/server/api/openai/endpoints/chat.py b/ktransformers/server/api/openai/endpoints/chat.py index 356637c..e5ea636 100644 --- a/ktransformers/server/api/openai/endpoints/chat.py +++ b/ktransformers/server/api/openai/endpoints/chat.py @@ -28,13 +28,13 @@ async def chat_completion(request:Request,create:ChatCompletionCreate): if create.stream: async def inner(): chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time())) - async for token in interface.inference(input_message,id,create.temperature,create.top_p,create.repetition_penalty): + async for token in interface.inference(input_message,id,create.temperature,create.top_p): chunk.set_token(token) yield chunk return chat_stream_response(request,inner()) else: comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time())) comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2) - async for token in interface.inference(input_message,id,create.temperature,create.top_p,create.repetition_penalty): + async for token in interface.inference(input_message,id,create.temperature,create.top_p): comp.append_token(token) return comp diff --git a/ktransformers/server/api/openai/legacy/completions.py b/ktransformers/server/api/openai/legacy/completions.py index 9808c3a..fe250f4 100644 --- a/ktransformers/server/api/openai/legacy/completions.py +++ b/ktransformers/server/api/openai/legacy/completions.py @@ -20,7 +20,7 @@ async def create_completion(request:Request,create:CompletionCreate): if create.stream: async def inner(): - async for token in interface.inference(create.prompt,id,create.temperature,create.top_p,create.repetition_penalty): + async for token in interface.inference(create.prompt,id,create.temperature,create.top_p): d = {'choices':[{'delta':{'content':token}}]} yield f"data:{json.dumps(d)}\n\n" d = {'choices':[{'delta':{'content':''},'finish_reason':''}]} @@ -28,6 +28,6 @@ async def create_completion(request:Request,create:CompletionCreate): return stream_response(request,inner()) else: comp = CompletionObject(id=id,object='text_completion',created=int(time())) - async for token in interface.inference(create.prompt,id,create.temperature,create.top_p,create.repetition_penalty): + async for token in interface.inference(create.prompt,id,create.temperature,create.top_p): comp.append_token(token) return comp diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py index 2674dd1..d2e48a4 100644 --- a/ktransformers/server/backend/interfaces/transformers.py +++ b/ktransformers/server/backend/interfaces/transformers.py @@ -202,20 +202,18 @@ class TransformersInterface(BackendInterfaceBase): self.seq_length += 1 return self.streamer.put(new_tokens) - def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None, repetition_penalty: Optional[float] = None): + def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None): if temperature is None: temperature = self.args.temperature if top_p is None: top_p = self.args.top_p - if repetition_penalty is None: - repetition_penalty = self.args.repetition_penalty generation_config, model_kwargs = self.model._prepare_generation_config( None, max_length=self.args.max_new_tokens, do_sample=True, top_k=self.args.top_k, top_p=top_p, temperature=temperature, - repetition_penalty=repetition_penalty # change this to modify generate config + repetition_penalty=self.args.repetition_penalty # change this to modify generate config ) self.inputs = inputs self.generation_config = generation_config @@ -261,7 +259,7 @@ class TransformersInterface(BackendInterfaceBase): return self.logits_to_token(logits) @torch.no_grad - def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, repetition_penalty: Optional[float] = None): + def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None): input_ids_length = input_ids.shape[-1] logger.debug(f"input_ids: {input_ids.shape}") @@ -329,7 +327,7 @@ class TransformersInterface(BackendInterfaceBase): else: logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0] - self.prepare_logits_wrapper(input_ids, device, temperature, top_p, repetition_penalty) + self.prepare_logits_wrapper(input_ids, device, temperature, top_p) next_token = self.logits_to_token(logits[0, -1, :]) yield self.append_new_tokens(next_token) @@ -365,7 +363,7 @@ class TransformersInterface(BackendInterfaceBase): self.last_request_id = thread_id return True - async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, repetition_penalty: Optional[float] = None): + async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None): self.streamer.reset() self.profiler.create_and_start_timer("tokenize") if isinstance(local_messages, List): @@ -392,7 +390,7 @@ class TransformersInterface(BackendInterfaceBase): print(think, end="",flush=True) yield think - for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p, repetition_penalty): + for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p): # output think token after prefill done if t is not None: print(t, end="",flush=True) diff --git a/ktransformers/server/schemas/endpoints/chat.py b/ktransformers/server/schemas/endpoints/chat.py index b929c4b..5507266 100644 --- a/ktransformers/server/schemas/endpoints/chat.py +++ b/ktransformers/server/schemas/endpoints/chat.py @@ -27,7 +27,7 @@ class ChatCompletionCreate(BaseModel): stream : bool = False temperature: Optional[float] top_p: Optional[float] - repetition_penalty: Optional[float] + frequency_penalty: Optional[float] def get_tokenizer_messages(self): return [m.to_tokenizer_message() for m in self.messages] diff --git a/ktransformers/server/schemas/legacy/completions.py b/ktransformers/server/schemas/legacy/completions.py index c5876d4..ca4b89c 100644 --- a/ktransformers/server/schemas/legacy/completions.py +++ b/ktransformers/server/schemas/legacy/completions.py @@ -11,7 +11,7 @@ class CompletionCreate(BaseModel): stream: bool = False temperature: Optional[float] top_p: Optional[float] - repetition_penalty: Optional[float] + frequency_penalty: Optional[float] def get_tokenizer_messages(self): if isinstance(self.prompt,List):