kvcache-ai-ktransformers/kt-sft/ktransformers/util/grad_wrapper.py
Peilin Li 171578a7ec
Some checks failed
Book-CI / test (push) Has been cancelled
Book-CI / test-1 (push) Has been cancelled
Book-CI / test-2 (push) Has been cancelled
Deploy / deploy (macos-latest) (push) Has been cancelled
Deploy / deploy (ubuntu-latest) (push) Has been cancelled
Deploy / deploy (windows-latest) (push) Has been cancelled
[refactor]: Change named 'KT-SFT' to 'kt-sft' (#1626)
* Change named 'KT-SFT' to 'kt-sft'

* [docs]: update kt-sft name

---------

Co-authored-by: ZiWei Yuan <yzwliam@126.com>
2025-11-17 11:48:42 +08:00

29 lines
932 B
Python

from functools import wraps
import torch, yaml, pathlib
import os, sys
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
sys.path.insert(0, project_dir)
from ktransformers.util.globals import GLOBAL_CONFIG
# print(f"start_sit: {GLOBAL_CONFIG._config['mod']}")
def maybe_no_grad(_func=None):
    """Decorator that disables gradient tracking unless running in SFT mode.

    The current mode is read from ``GLOBAL_CONFIG._config["mod"]`` at each
    call (not at decoration time), so a runtime mode switch takes effect:

      - ``"sft"``: call the wrapped function with gradients enabled (training).
      - anything else (e.g. ``"infer"``): call it under ``torch.no_grad()``.

    Supports both bare ``@maybe_no_grad`` and parenthesized
    ``@maybe_no_grad()`` usage via the ``_func`` sentinel parameter.

    Args:
        _func: The function being decorated when used without parentheses;
            ``None`` when used as ``@maybe_no_grad()``.

    Returns:
        The decorated function (bare usage) or the decorator itself
        (parenthesized usage).
    """
    def decorator(func):
        @wraps(func)  # preserve func's __name__/__doc__ for introspection
        def wrapper(*args, **kwargs):
            if GLOBAL_CONFIG._config["mod"] == "sft":
                return func(*args, **kwargs)
            # Any non-SFT mode (including "infer") runs without grad.
            # Previously an unrecognized mode silently returned None
            # without calling func at all; no-grad is the safe default.
            with torch.no_grad():
                return func(*args, **kwargs)
        return wrapper
    # Bare @maybe_no_grad passes the function directly as _func.
    if _func is None:
        return decorator
    return decorator(_func)