Mirror of https://github.com/kvcache-ai/ktransformers.git
Merge pull request #48 from Azure-Tang/main
[fix] Fix static cache and server parameter bugs
Commit cbc47d0b68
3 changed files with 12 additions and 9 deletions
@@ -25,8 +25,10 @@ class KTransformersInterface(TransformersInterface):
         with torch.device("meta"):
             self.model=custom_models[config.architectures[0]](config)
+        if default_args.optimize_config_path is None:
+            optimize_rule_path = default_optimize_rules[config.architectures[0]]
+        else:
+            optimize_rule_path = args.optimize_config_path
 
         # print(optimize_config)
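The hunk above makes the optimize-rule file selectable: a user-supplied --optimize_config_path wins, otherwise the server falls back to the built-in rule table keyed by the model architecture. A minimal standalone sketch of that fallback, assuming an illustrative table entry rather than the actual contents of default_optimize_rules:

# Standalone sketch of the fallback logic; the table entry is illustrative,
# not the real default_optimize_rules shipped with ktransformers.
from typing import Optional

default_optimize_rules = {
    "DeepseekV2ForCausalLM": "optimize_rules/DeepSeek-V2-Chat.yaml",  # hypothetical entry
}

def pick_optimize_rule(architecture: str, optimize_config_path: Optional[str]) -> str:
    # Prefer the user-supplied YAML rule file; otherwise use the built-in default.
    if optimize_config_path is None:
        return default_optimize_rules[architecture]
    return optimize_config_path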
@@ -38,10 +40,10 @@ class KTransformersInterface(TransformersInterface):
         optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
 
-        logger.info(f'{args.model_name} loaded from {args.model_dir} to {args.device}')
-        self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=args.device, dtype=self.model.dtype)
-        logger.info(f'StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}')
+        device_map = self.model.gguf_loader.tensor_device_map
+        logger.info(f'{args.model_name} loaded from {args.model_dir} to {device_map}')
+        self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=device_map, dtype=self.model.dtype)
+        logger.info(f'StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}')
         self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
         if self.model.generation_config.pad_token_id is None:
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
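This is the static-cache fix from the PR title: the cache used to be allocated wholesale on the single --device argument, but once the optimize rules shard layers across devices, each layer's KV tensors must live next to that layer. The new code therefore passes the per-module device map recorded by the GGUF loader. A sketch of how such a map can drive cache placement; the key format below is an assumption, not the exact layout of gguf_loader.tensor_device_map:

# Sketch only: the key format is assumed, not the exact layout that
# ktransformers' GGUF loader records in tensor_device_map.
from typing import Dict

device_map: Dict[str, str] = {
    "model.layers.0": "cuda:0",
    "model.layers.1": "cuda:1",
}

def cache_device_for_layer(device_map: Dict[str, str], layer_idx: int) -> str:
    # Keep each layer's KV-cache on the same device as the layer's weights.
    return device_map.get(f"model.layers.{layer_idx}", "cuda:0")

assert cache_device_for_layer(device_map, 1) == "cuda:1"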
@@ -63,7 +65,7 @@ class KTransformersInterface(TransformersInterface):
             return self.logits_to_token(logits)
 
         if self.use_static_cache:
-            mask = torch.ones((1,self.seq_length)).to(self.args.device)
+            mask = torch.ones((1,self.seq_length)).to(torch_device)
             logits = self.model(
                 self.current_ids,
                 cache_position=self.active_cache_position,
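The one-line change above follows from the same device-map rework: the attention mask was still built on the old self.args.device, which can differ from where the sharded model expects its inputs (torch_device is resolved earlier in the method, per the diff context). A small demonstration of the cross-device failure mode this avoids, runnable only on a machine with at least two GPUs:

# Demonstrates the cross-device error the .to(torch_device) fix avoids.
import torch

if torch.cuda.device_count() >= 2:
    weight = torch.ones(4, 4, device="cuda:1")  # layer placed by the device map
    mask = torch.ones(1, 4, device="cuda:0")    # mask built on a hard-coded device
    try:
        _ = mask @ weight
    except RuntimeError as err:
        print(err)  # torch reports tensors on mismatched devices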
@@ -107,9 +107,9 @@ def main():
     parser.add_argument("--web", type=bool, default=False)
     parser.add_argument("--model_name", type=str, default=cfg.model_name)
     parser.add_argument("--model_path", type=str, default=cfg.model_path)
-    parser.add_argument("--device", type=str, default=cfg.model_device)
+    parser.add_argument("--device", type=str, default=cfg.model_device, help="Warning: Abandoning this parameter")
     parser.add_argument("--gguf_path", type=str, default=cfg.gguf_path)
-    parser.add_argument("--optimize_config_path", type=str, required=False)
+    parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
     parser.add_argument("--cpu_infer", type=int, default=cfg.cpu_infer)
     parser.add_argument("--type", type=str, default=cfg.backend_type)
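Two server parameters change here: --device is kept only for backward compatibility (its help text now warns it is being abandoned in favor of the device map), and --optimize_config_path gains an explicit default=None, so the "is None" check in the first hunk reads as a documented contract rather than a reliance on argparse's implicit default. A quick standalone check of that behavior:

# Standalone check of the argparse behavior the server relies on.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--optimize_config_path", default=None, type=str, required=False)

args = parser.parse_args([])  # flag omitted
assert args.optimize_config_path is None  # triggers the built-in rule fallback

args = parser.parse_args(["--optimize_config_path", "my_rules.yaml"])
assert args.optimize_config_path == "my_rules.yaml"  # user-supplied file wins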
@@ -3,3 +3,4 @@ transformers
 numpy
 torch>=2.3.0
 packaging
+cpufeature
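The new cpufeature requirement gives the CPU backend a way to probe host SIMD support at runtime (ktransformers' CPU kernels benefit from AVX2/AVX-512). A minimal probe using the package's CPUFeature dict; exact key names can vary between cpufeature versions, hence the defensive .get() calls:

# Minimal probe of host SIMD support via the cpufeature package.
import cpufeature

features = cpufeature.CPUFeature
print("physical cores:", features.get("num_physical_cores"))
print("AVX2:", features.get("AVX2"))
print("AVX512f:", features.get("AVX512f"))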