fix temperature

This commit is contained in:
qiyuxinlin 2025-02-27 21:00:44 +08:00
parent 5e3c6b4f97
commit 22df52e94e
2 changed files with 16 additions and 16 deletions

View file

@ -29,6 +29,16 @@ class KTransformersInterface(TransformersInterface):
torch.set_grad_enabled(False)
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
try:
generation_config = GenerationConfig.from_pretrained(args.model_dir)
except:
generation_config = GenerationConfig(
max_length=args.max_new_tokens,
temperature=args.temperature,
top_p=args.temperature,
do_sample=True
)
torch.set_default_dtype(config.torch_dtype)
if config.architectures[0] == "Qwen2MoeForCausalLM":
config._attn_implementation = "flash_attention_2"
@ -49,7 +59,7 @@ class KTransformersInterface(TransformersInterface):
" belong to current model):"
)
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
self.model.generation_config = generation_config
self.device_map = self.model.gguf_loader.tensor_device_map
# logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
self.cache = StaticCache(
@ -60,16 +70,7 @@ class KTransformersInterface(TransformersInterface):
dtype=self.model.dtype,
)
# logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
try:
self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
except:
gen_config = GenerationConfig(
max_length=128,
temperature=0.7,
top_p=0.9,
do_sample=True
)
self.model.generation_config = gen_config
if self.model.generation_config.pad_token_id is None:
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
self.streamer = TextStreamer(self.tokenizer)