mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-06 12:40:02 +00:00
fix temperature
This commit is contained in:
parent
5e3c6b4f97
commit
22df52e94e
2 changed files with 16 additions and 16 deletions
|
@@ -29,6 +29,16 @@ class KTransformersInterface(TransformersInterface):
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
|
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
|
||||||
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
|
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
|
||||||
|
try:
|
||||||
|
generation_config = GenerationConfig.from_pretrained(args.model_dir)
|
||||||
|
except:
|
||||||
|
generation_config = GenerationConfig(
|
||||||
|
max_length=args.max_new_tokens,
|
||||||
|
temperature=args.temperature,
|
||||||
|
top_p=args.temperature,
|
||||||
|
do_sample=True
|
||||||
|
)
|
||||||
|
|
||||||
torch.set_default_dtype(config.torch_dtype)
|
torch.set_default_dtype(config.torch_dtype)
|
||||||
if config.architectures[0] == "Qwen2MoeForCausalLM":
|
if config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||||
config._attn_implementation = "flash_attention_2"
|
config._attn_implementation = "flash_attention_2"
|
||||||
|
@@ -49,7 +59,7 @@ class KTransformersInterface(TransformersInterface):
|
||||||
" belong to current model):"
|
" belong to current model):"
|
||||||
)
|
)
|
||||||
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
|
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
|
||||||
|
self.model.generation_config = generation_config
|
||||||
self.device_map = self.model.gguf_loader.tensor_device_map
|
self.device_map = self.model.gguf_loader.tensor_device_map
|
||||||
# logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
|
# logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
|
||||||
self.cache = StaticCache(
|
self.cache = StaticCache(
|
||||||
|
@@ -60,16 +70,7 @@ class KTransformersInterface(TransformersInterface):
|
||||||
dtype=self.model.dtype,
|
dtype=self.model.dtype,
|
||||||
)
|
)
|
||||||
# logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
|
# logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
|
||||||
try:
|
|
||||||
self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
|
|
||||||
except:
|
|
||||||
gen_config = GenerationConfig(
|
|
||||||
max_length=128,
|
|
||||||
temperature=0.7,
|
|
||||||
top_p=0.9,
|
|
||||||
do_sample=True
|
|
||||||
)
|
|
||||||
self.model.generation_config = gen_config
|
|
||||||
if self.model.generation_config.pad_token_id is None:
|
if self.model.generation_config.pad_token_id is None:
|
||||||
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
|
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
|
||||||
self.streamer = TextStreamer(self.tokenizer)
|
self.streamer = TextStreamer(self.tokenizer)
|
||||||
|
|
|
@@ -203,10 +203,10 @@ class TransformersInterface(BackendInterfaceBase):
|
||||||
return self.streamer.put(new_tokens)
|
return self.streamer.put(new_tokens)
|
||||||
|
|
||||||
def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
|
def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
|
||||||
if temperature is None:
|
if temperature is None or temperature == 0:
|
||||||
temperature = self.args.temperature
|
temperature = self.model.generation_config.temperature
|
||||||
if top_p is None:
|
if top_p is None:
|
||||||
top_p = self.args.top_p
|
top_p = self.model.generation_config.top_p
|
||||||
generation_config, model_kwargs = self.model._prepare_generation_config(
|
generation_config, model_kwargs = self.model._prepare_generation_config(
|
||||||
None, max_length=self.args.max_new_tokens,
|
None, max_length=self.args.max_new_tokens,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
|
@@ -216,10 +216,9 @@ class TransformersInterface(BackendInterfaceBase):
|
||||||
repetition_penalty=self.args.repetition_penalty # change this to modify generate config
|
repetition_penalty=self.args.repetition_penalty # change this to modify generate config
|
||||||
)
|
)
|
||||||
self.inputs = inputs
|
self.inputs = inputs
|
||||||
self.generation_config = generation_config
|
|
||||||
try: # transformers==4.43
|
try: # transformers==4.43
|
||||||
self.logits_warper = (
|
self.logits_warper = (
|
||||||
self.model._get_logits_warper(generation_config,device=device)
|
self.model._get_logits_warper(generation_config, device=device)
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
self.logits_warper = (
|
self.logits_warper = (
|
||||||
|
|
Loading…
Add table
Reference in a new issue