Added new binding fields `quant_k` and `quant_v` to `load_model_inputs`

This commit is contained in:
Concedo 2024-06-03 14:35:59 +08:00
parent 039cc392d1
commit 10a1d628ad
4 changed files with 67 additions and 38 deletions

View file

@ -59,7 +59,9 @@ class load_model_inputs(ctypes.Structure):
("rope_freq_scale", ctypes.c_float),
("rope_freq_base", ctypes.c_float),
("flash_attention", ctypes.c_bool),
("tensor_split", ctypes.c_float * tensor_split_max)]
("tensor_split", ctypes.c_float * tensor_split_max),
("quant_k", ctypes.c_int),
("quant_v", ctypes.c_int)]
class generation_inputs(ctypes.Structure):
_fields_ = [("seed", ctypes.c_int),
@ -294,11 +296,14 @@ def init_library():
os.add_dll_directory(abs_path)
os.add_dll_directory(os.getcwd())
if libname == lib_cublas and "CUDA_PATH" in os.environ:
os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "bin"))
newpath = os.path.join(os.environ["CUDA_PATH"], "bin")
if os.path.exists(newpath):
os.add_dll_directory(newpath)
if libname == lib_hipblas and "HIP_PATH" in os.environ:
os.add_dll_directory(os.path.join(os.environ["HIP_PATH"], "bin"))
if args.debugmode == 1:
print(f"HIP/ROCm SDK at {os.environ['HIP_PATH']} included in .DLL load path")
newpath = os.path.join(os.environ["HIP_PATH"], "bin")
if os.path.exists(newpath):
os.add_dll_directory(newpath)
handle = ctypes.CDLL(os.path.join(dir_path, libname))
handle.load_model.argtypes = [load_model_inputs]
@ -413,6 +418,8 @@ def load_model(model_filename):
inputs.use_smartcontext = args.smartcontext
inputs.use_contextshift = (0 if args.noshift else 1)
inputs.flash_attention = args.flashattention
inputs.quant_k = 0
inputs.quant_v = 0
inputs.blasbatchsize = args.blasbatchsize
inputs.forceversion = args.forceversion
inputs.gpulayers = args.gpulayers