optimize GPU

This commit is contained in:
Atream 2025-02-21 05:06:57 +00:00
parent cf4da5fd47
commit 7e1fe256c8
8 changed files with 677 additions and 156 deletions


@@ -25,10 +25,10 @@ class KTransformersThreadContext(TransformersThreadContext):
 class KTransformersInterface(TransformersInterface):
     def __init__(self, args: ConfigArgs = default_args):
         self.args = args
-        torch.set_default_dtype(torch.bfloat16)
         torch.set_grad_enabled(False)
         self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
         config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
+        torch.set_default_dtype(config.torch_dtype)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
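The hunk above replaces the hard-coded bfloat16 default dtype with the dtype declared in the checkpoint's own config, so weights are created in their native precision. Below is a minimal, self-contained sketch of that pattern (not the actual KTransformersInterface code; the ConfigArgs plumbing is omitted and the function name load_model_settings is hypothetical):

    # Sketch only: illustrates reading torch_dtype from the model config
    # instead of hard-coding bfloat16, as done in the diff above.
    import torch
    from transformers import AutoConfig, AutoTokenizer

    def load_model_settings(model_dir: str, trust_remote_code: bool = True):
        torch.set_grad_enabled(False)  # inference only

        tokenizer = AutoTokenizer.from_pretrained(
            model_dir, trust_remote_code=trust_remote_code
        )
        config = AutoConfig.from_pretrained(
            model_dir, trust_remote_code=trust_remote_code
        )

        # config.torch_dtype is the dtype the checkpoint was saved in
        # (e.g. torch.bfloat16 or torch.float16); it may be None, so a
        # fallback is added here as an assumption for the sketch.
        torch.set_default_dtype(config.torch_dtype or torch.bfloat16)

        # Mirror the diff: request FlashAttention-2 for Qwen2-MoE models.
        if config.architectures[0] == "Qwen2MoeForCausalLM":
            config._attn_implementation = "flash_attention_2"

        return tokenizer, config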