fix server not accepting yaml path as param; fix server static cache device problem

TangJingqi 2024-08-21 14:19:43 +08:00
parent 4358722891
commit 170b7a6001
3 changed files with 12 additions and 9 deletions


@@ -25,8 +25,10 @@ class KTransformersInterface(TransformersInterface):
         with torch.device("meta"):
             self.model=custom_models[config.architectures[0]](config)
-        optimize_rule_path = default_optimize_rules[config.architectures[0]]
+        if default_args.optimize_config_path is None:
+            optimize_rule_path = default_optimize_rules[config.architectures[0]]
+        else:
+            optimize_rule_path = args.optimize_config_path
         # print(optimize_config)
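
This hunk makes a user-supplied rule file take precedence over the built-in table instead of being silently ignored: the old code always indexed default_optimize_rules, so a yaml path passed as a parameter never took effect. A minimal standalone sketch of the same selection logic, assuming the names from the diff above (default_optimize_rules, args.optimize_config_path); the architecture key and yaml path in the table are illustrative, not taken from the repo:

    from pathlib import Path
    from typing import Optional

    # Built-in architecture -> optimize-rule YAML table (values illustrative).
    default_optimize_rules = {
        "DeepseekV2ForCausalLM": "optimize_rules/DeepSeek-V2-Chat.yaml",
    }

    def resolve_rule_path(arch: str, user_path: Optional[str]) -> str:
        # Fall back to the bundled default only when the user gave no
        # --optimize_config_path; otherwise honor the explicit YAML path.
        if user_path is None:
            return default_optimize_rules[arch]
        if not Path(user_path).is_file():
            raise FileNotFoundError(f"optimize rule file not found: {user_path}")
        return user_path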
@@ -38,10 +40,10 @@ class KTransformersInterface(TransformersInterface):
         optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
-        logger.info(f'{args.model_name} loaded from {args.model_dir} to {args.device}')
-        self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=args.device, dtype=self.model.dtype)
-        logger.info(f'StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}')
+        device_map = self.model.gguf_loader.tensor_device_map
+        logger.info(f'{args.model_name} loaded from {args.model_dir} to {device_map}')
+        self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=device_map, dtype=self.model.dtype)
+        logger.info(f'StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}')
         self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
         if self.model.generation_config.pad_token_id is None:
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
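
The second fix replaces the single args.device with the gguf loader's per-module device map: when layers are split across GPU and CPU, allocating the whole StaticCache on one device forces cross-device copies on every decode step. A minimal sketch of per-layer KV allocation under such a map; the map shape (module key, generate_device field) and the helper are assumptions for illustration, not the transformers StaticCache API:

    import torch

    # Hypothetical device map, shaped loosely like gguf_loader.tensor_device_map.
    device_map = {
        "blk.0.attn": {"generate_device": "cuda:0"},
        "blk.1.attn": {"generate_device": "cpu"},
    }

    def alloc_kv_cache(num_layers, batch, heads, cache_len, head_dim,
                       dtype=torch.float16):
        caches = []
        for layer in range(num_layers):
            # Place each layer's K/V buffers on the device that runs the layer,
            # so decoding never shuttles cache tensors between devices.
            dev = device_map.get(f"blk.{layer}.attn", {}).get("generate_device", "cpu")
            shape = (batch, heads, cache_len, head_dim)
            caches.append((torch.zeros(shape, dtype=dtype, device=dev),
                           torch.zeros(shape, dtype=dtype, device=dev)))
        return caches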
@@ -63,7 +65,7 @@ class KTransformersInterface(TransformersInterface):
             return self.logits_to_token(logits)
         if self.use_static_cache:
-            mask = torch.ones((1,self.seq_length)).to(self.args.device)
+            mask = torch.ones((1,self.seq_length)).to(torch_device)
         logits = self.model(
             self.current_ids,
             cache_position=self.active_cache_position,
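
With a device map in play there is no single self.args.device to target, so the attention mask is moved to torch_device, which the surrounding file presumably derives from the map; the helper below is hypothetical and only sketches one way to pick the device where decoding starts:

    import torch

    def first_generate_device(device_map, default="cpu"):
        # Hypothetical helper: take the device of the first mapped module,
        # so the mask lives where the input embeddings are computed.
        for info in device_map.values():
            dev = info.get("generate_device") if isinstance(info, dict) else info
            if dev:
                return dev
        return default

    # Usage sketch: allocate the mask on that device instead of args.device.
    torch_device = first_generate_device({"blk.0.attn": {"generate_device": "cpu"}})
    mask = torch.ones((1, 16), device=torch_device)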