diff --git a/ktransformers/server/backend/interfaces/balance_serve.py b/ktransformers/server/backend/interfaces/balance_serve.py
index be48c92..aa6e436 100644
--- a/ktransformers/server/backend/interfaces/balance_serve.py
+++ b/ktransformers/server/backend/interfaces/balance_serve.py
@@ -368,6 +368,8 @@ class BalanceServeInterface(BackendInterfaceBase):
         stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")]
         query_add.stop_criteria = stop_criteria
         query_add.sample_options.temperature = temperature
+        if top_p == 0:
+            top_p = 0.0001
         query_add.sample_options.top_p = top_p
         query_add.estimated_length = min(self.args.cache_lens, query_length+self.args.max_new_tokens)
         query_id = self.sched_client.add_query(query_add)
diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index 1460176..d002435 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -208,6 +208,8 @@ class TransformersInterface(BackendInterfaceBase):
             temperature = self.model.generation_config.temperature
         if top_p is None:
             top_p = self.model.generation_config.top_p
+        if top_p == 0:
+            top_p = 0.0001
         generation_config, model_kwargs = self.model._prepare_generation_config(
             None, max_length=self.args.max_new_tokens, do_sample=True,