integrated optional (experimental) CLBlast support

This commit is contained in:
Concedo 2023-04-11 23:33:44 +08:00
parent c9f18082fd
commit 23c675b2e6
53 changed files with 22095 additions and 151 deletions

View file

@ -15,7 +15,8 @@ class load_model_inputs(ctypes.Structure):
("f16_kv", ctypes.c_bool),
("model_filename", ctypes.c_char_p),
("n_parts_overwrite", ctypes.c_int),
("use_mmap", ctypes.c_bool)]
("use_mmap", ctypes.c_bool),
("clblast_info", ctypes.c_int)]
class generation_inputs(ctypes.Structure):
_fields_ = [("seed", ctypes.c_int),
@ -34,12 +35,15 @@ class generation_outputs(ctypes.Structure):
handle = None
use_blas = False # if true, uses OpenBLAS for acceleration. libopenblas.dll must exist in the same dir.
use_clblast = False #uses CLBlast instead
def init_library():
global handle, use_blas
global handle, use_blas, use_clblast
libname = ""
if use_blas:
libname = "koboldcpp_blas.dll"
libname = "koboldcpp_openblas.dll"
elif use_clblast:
libname = "koboldcpp_clblast.dll"
else:
libname = "koboldcpp.dll"
@ -63,6 +67,10 @@ def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwr
inputs.n_parts_overwrite = n_parts_overwrite
inputs.f16_kv = True
inputs.use_mmap = use_mmap
clblastids = 0
if args.useclblast:
clblastids = int(args.useclblast[0])*10 + int(args.useclblast[1])
inputs.clblast_info = clblastids
ret = handle.load_model(inputs)
return ret
@ -301,13 +309,19 @@ def RunServerMultiThreaded(addr, port, embedded_kailite = None):
sys.exit(0)
def main(args):
global use_blas
if not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), "libopenblas.dll")) or not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), "koboldcpp_blas.dll")):
print("Warning: libopenblas.dll or koboldcpp_blas.dll not found. Non-BLAS library will be used. Ignore this if you have manually linked with OpenBLAS.")
global use_blas, use_clblast
if not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), "libopenblas.dll")) or not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), "koboldcpp_openblas.dll")):
print("Warning: libopenblas.dll or koboldcpp_openblas.dll not found. Non-BLAS library will be used. Ignore this if you have manually linked with OpenBLAS.")
use_blas = False
elif os.name != 'nt':
print("Prebuilt OpenBLAS binaries only available for windows. Please manually build/link libopenblas from makefile with LLAMA_OPENBLAS=1")
use_blas = False
elif args.useclblast:
if not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), "clblast.dll")) or not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), "koboldcpp_clblast.dll")):
print("Warning: clblast.dll or koboldcpp_clblast.dll not found. Non-BLAS library will be used. Ignore this if you have manually linked with CLBlast.")
else:
print("Attempting to use CLBlast library for faster prompt ingestion. A compatible clblast.dll will be required.")
use_clblast = True
elif not args.noblas:
print("Attempting to use OpenBLAS library for faster prompt ingestion. A compatible libopenblas.dll will be required.")
use_blas = True
@ -397,5 +411,6 @@ if __name__ == '__main__':
parser.add_argument("--stream", help="Uses pseudo streaming", action='store_true')
parser.add_argument("--noblas", help="Do not use OpenBLAS for accelerated prompt ingestion", action='store_true')
parser.add_argument("--nommap", help="If set, do not use mmap to load newer models", action='store_true')
parser.add_argument("--useclblast", help="Use CLBlast instead of OpenBLAS for prompt ingestion. Must specify exactly 2 arguments, platform ID and device ID (e.g. --useclblast 1 0).", type=int, choices=range(0,9), nargs=2)
args = parser.parse_args()
main(args)