diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py
index 8d121d5..f896a90 100644
--- a/ktransformers/server/backend/interfaces/ktransformers.py
+++ b/ktransformers/server/backend/interfaces/ktransformers.py
@@ -25,8 +25,10 @@ class KTransformersInterface(TransformersInterface):
 
         with torch.device("meta"):
             self.model=custom_models[config.architectures[0]](config)
-
-        optimize_rule_path = default_optimize_rules[config.architectures[0]]
+        if args.optimize_config_path is None:
+            optimize_rule_path = default_optimize_rules[config.architectures[0]]
+        else:
+            optimize_rule_path = args.optimize_config_path
 
         # print(optimize_config)
 
@@ -38,10 +40,10 @@ class KTransformersInterface(TransformersInterface):
 
 
         optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
-
-        logger.info(f'{args.model_name} loaded from {args.model_dir} to {args.device}')
-        self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=args.device, dtype=self.model.dtype)
-        logger.info(f'StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}')
+        device_map = self.model.gguf_loader.tensor_device_map
+        logger.info(f'{args.model_name} loaded from {args.model_dir} to {device_map}')
+        self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=device_map, dtype=self.model.dtype)
+        logger.info(f'StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}')
         self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
         if self.model.generation_config.pad_token_id is None:
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
@@ -63,7 +65,7 @@ class KTransformersInterface(TransformersInterface):
             return self.logits_to_token(logits)
 
         if self.use_static_cache:
-            mask = torch.ones((1,self.seq_length)).to(self.args.device)
+            mask = torch.ones((1,self.seq_length)).to(torch_device)
             logits = self.model(
                 self.current_ids,
                 cache_position=self.active_cache_position,
diff --git a/ktransformers/server/main.py b/ktransformers/server/main.py
index 274a6bc..0bb52cc 100644
--- a/ktransformers/server/main.py
+++ b/ktransformers/server/main.py
@@ -107,9 +107,9 @@ def main():
     parser.add_argument("--web", type=bool, default=False)
     parser.add_argument("--model_name", type=str, default=cfg.model_name)
     parser.add_argument("--model_path", type=str, default=cfg.model_path)
-    parser.add_argument("--device", type=str, default=cfg.model_device)
+    parser.add_argument("--device", type=str, default=cfg.model_device, help="Warning: this parameter is deprecated")
    parser.add_argument("--gguf_path", type=str, default=cfg.gguf_path)
-    parser.add_argument("--optimize_config_path", type=str, required=False)
+    parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
     parser.add_argument("--cpu_infer", type=int, default=cfg.cpu_infer)
     parser.add_argument("--type", type=str, default=cfg.backend_type)
 
diff --git a/requirements-local_chat.txt b/requirements-local_chat.txt
index 917cc03..17cb0f1 100644
--- a/requirements-local_chat.txt
+++ b/requirements-local_chat.txt
@@ -3,3 +3,4 @@ transformers
 numpy
 torch>=2.3.0
 packaging
+cpufeature
\ No newline at end of file
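
For clarity, the fallback behaviour the new flag introduces can be sketched as a standalone helper. This is a minimal sketch, not code from the patch: `resolve_optimize_rule_path` is a hypothetical name, while `args.optimize_config_path`, `default_optimize_rules`, and `config.architectures` follow the diff.

```python
# Minimal sketch of the rule-file selection above (hypothetical helper,
# not part of the patch). default_optimize_rules maps an architecture
# name to the optimize rule file bundled with ktransformers.
def resolve_optimize_rule_path(args, config, default_optimize_rules):
    if args.optimize_config_path is None:
        # No --optimize_config_path was given (the flag now defaults to
        # None), so fall back to the default rules for this architecture.
        return default_optimize_rules[config.architectures[0]]
    # A user-supplied rule file takes precedence over the bundled default.
    return args.optimize_config_path
```

With this in place, StaticCache placement follows the per-tensor `gguf_loader.tensor_device_map` produced by the optimize rules rather than the single `--device` value, which is why that flag is now marked as deprecated.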