diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py
index 6de0998..3ce1dda 100644
--- a/ktransformers/server/backend/interfaces/ktransformers.py
+++ b/ktransformers/server/backend/interfaces/ktransformers.py
@@ -29,6 +29,16 @@ class KTransformersInterface(TransformersInterface):
         torch.set_grad_enabled(False)
         self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
         config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
+        # Load the model's generation config if present; otherwise fall back to the CLI args.
+        try:
+            generation_config = GenerationConfig.from_pretrained(args.model_dir)
+        except:
+            generation_config = GenerationConfig(
+                max_length=args.max_new_tokens,
+                temperature=args.temperature,
+                top_p=args.top_p,
+                do_sample=True
+            )
         torch.set_default_dtype(config.torch_dtype)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
@@ -49,7 +59,7 @@ class KTransformersInterface(TransformersInterface):
                 " belong to current model):"
             )
             optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
-
+        self.model.generation_config = generation_config
         self.device_map = self.model.gguf_loader.tensor_device_map
         # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
         self.cache = StaticCache(
@@ -60,16 +70,7 @@ class KTransformersInterface(TransformersInterface):
             dtype=self.model.dtype,
         )
         # logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
-        try:
-            self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
-        except:
-            gen_config = GenerationConfig(
-                max_length=128,
-                temperature=0.7,
-                top_p=0.9,
-                do_sample=True
-            )
-            self.model.generation_config = gen_config
+
         if self.model.generation_config.pad_token_id is None:
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
         self.streamer = TextStreamer(self.tokenizer)
diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index a0821d5..94ad27c 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -203,10 +203,10 @@ class TransformersInterface(BackendInterfaceBase):
         return self.streamer.put(new_tokens)
 
     def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
-        if temperature is None:
-            temperature = self.args.temperature
+        if temperature is None or temperature == 0:
+            temperature = self.model.generation_config.temperature
         if top_p is None:
-            top_p = self.args.top_p
+            top_p = self.model.generation_config.top_p
         generation_config, model_kwargs = self.model._prepare_generation_config(
             None, max_length=self.args.max_new_tokens,
             do_sample=True,
@@ -216,10 +216,9 @@ class TransformersInterface(BackendInterfaceBase):
             repetition_penalty=self.args.repetition_penalty # change this to modify generate config
         )
         self.inputs = inputs
-        self.generation_config = generation_config
         try: # transformers==4.43
             self.logits_warper = (
-                self.model._get_logits_warper(generation_config,device=device)
+                self.model._get_logits_warper(generation_config, device=device)
             )
         except:
             self.logits_warper = (