mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 13:55:27 +00:00
optimize GPU
This commit is contained in:
parent
cf4da5fd47
commit
7e1fe256c8
8 changed files with 677 additions and 156 deletions
|
@ -25,10 +25,10 @@ class KTransformersThreadContext(TransformersThreadContext):
|
|||
class KTransformersInterface(TransformersInterface):
|
||||
def __init__(self, args: ConfigArgs = default_args):
|
||||
self.args = args
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
torch.set_grad_enabled(False)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
|
||||
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
|
||||
torch.set_default_dtype(config.torch_dtype)
|
||||
if config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||
config._attn_implementation = "flash_attention_2"
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue