diff --git a/koboldcpp.py b/koboldcpp.py index 42b495956..02dc45140 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -361,81 +361,27 @@ def init_library(): global handle, args, libname global lib_default,lib_failsafe,lib_noavx2,lib_clblast,lib_clblast_noavx2,lib_cublas,lib_hipblas,lib_vulkan,lib_vulkan_noavx2 - libname = "" - use_clblast = False #uses CLBlast instead - use_cublas = False #uses cublas instead - use_hipblas = False #uses hipblas instead - use_noavx2 = False #uses no avx2 instructions - use_failsafe = False #uses no intrinsics, failsafe mode - use_vulkan = False #uses vulkan (needs avx2) + libname = lib_default if args.noavx2: - use_noavx2 = True - if args.useclblast: - if not file_exists(lib_clblast_noavx2) or (os.name=='nt' and not file_exists("clblast.dll")): - print("Warning: NoAVX2 CLBlast library file not found. CPU library will be used.") - else: - print("Attempting to use NoAVX2 CLBlast library for faster prompt ingestion. A compatible clblast will be required.") - use_clblast = True - elif (args.usevulkan is not None): - if not file_exists(lib_vulkan_noavx2): - print("Warning: NoAVX2 Vulkan library file not found. CPU library will be used.") - else: - print("Attempting to use NoAVX2 Vulkan library for faster prompt ingestion. A compatible Vulkan will be required.") - use_vulkan = True - else: - if not file_exists(lib_noavx2): - print("Warning: NoAVX2 library file not found. Failsafe library will be used.") - elif (args.usecpu and args.nommap): - use_failsafe = True - print("!!! Attempting to use FAILSAFE MODE !!!") - else: - print("Attempting to use non-avx2 compatibility library.") - elif (args.usecublas is not None): - if not file_exists(lib_cublas) and not file_exists(lib_hipblas): - print("Warning: CuBLAS library file not found. CPU library will be used.") - else: - if file_exists(lib_cublas): - print("Attempting to use CuBLAS library for faster prompt ingestion. A compatible CuBLAS will be required.") - use_cublas = True - elif file_exists(lib_hipblas): - print("Attempting to use hipBLAS library for faster prompt ingestion. A compatible AMD GPU will be required.") - use_hipblas = True - elif (args.usevulkan is not None): - if not file_exists(lib_vulkan): - print("Warning: Vulkan library file not found. CPU library will be used.") - else: - print("Attempting to use Vulkan library for faster prompt ingestion. A compatible Vulkan will be required.") - use_vulkan = True - elif args.useclblast: - if not file_exists(lib_clblast) or (os.name=='nt' and not file_exists("clblast.dll")): - print("Warning: CLBlast library file not found. CPU library will be used.") - else: - print("Attempting to use CLBlast library for faster prompt ingestion. A compatible clblast will be required.") - use_clblast = True - else: - print("Attempting to use CPU library.") - - if use_noavx2: - if use_failsafe: - libname = lib_failsafe - elif use_clblast: + if args.useclblast and file_exists(lib_clblast_noavx2) and (os.name!='nt' or file_exists("clblast.dll")): libname = lib_clblast_noavx2 - elif use_vulkan: + elif (args.usevulkan is not None) and file_exists(lib_vulkan_noavx2): libname = lib_vulkan_noavx2 - else: + elif (args.usecpu and args.nommap) and file_exists(lib_failsafe): + print("!!! Attempting to use FAILSAFE MODE !!!") + libname = lib_failsafe + elif file_exists(lib_noavx2): libname = lib_noavx2 - else: - if use_clblast: - libname = lib_clblast - elif use_cublas: + elif (args.usecublas is not None): + if file_exists(lib_cublas): libname = lib_cublas - elif use_hipblas: + elif file_exists(lib_hipblas): libname = lib_hipblas - elif use_vulkan: - libname = lib_vulkan - else: - libname = lib_default + elif (args.usevulkan is not None) and file_exists(lib_vulkan): + libname = lib_vulkan + elif args.useclblast and file_exists(lib_clblast) and (os.name!='nt' or file_exists("clblast.dll")): + libname = lib_clblast print("Initializing dynamic library: " + libname) dir_path = getdirpath() @@ -947,7 +893,7 @@ def generate(genparams, is_quiet=False, stream_flag=False): images = genparams.get('images', []) max_context_length = genparams.get('max_context_length', maxctx) max_length = genparams.get('max_length', 200) - temperature = genparams.get('temperature', 0.7) + temperature = genparams.get('temperature', 0.75) top_k = genparams.get('top_k', 100) top_a = genparams.get('top_a', 0.0) top_p = genparams.get('top_p', 0.92) @@ -1384,7 +1330,7 @@ def parse_last_logprobs(lastlogprobs): return logprobsdict def transform_genparams(genparams, api_format): - global chatcompl_adapter + global chatcompl_adapter, maxctx #api format 1=basic,2=kai,3=oai,4=oai-chat,5=interrogate,6=ollama,7=ollamachat #alias all nonstandard alternative names for rep pen. rp1 = genparams.get('repeat_penalty', 1.0) @@ -1548,14 +1494,28 @@ ws ::= | " " | "\n" [ \t]{0,20} utfprint("Ollama Context Error: " + str(e)) ollamasysprompt = genparams.get('system', "") ollamabodyprompt = f"{detokstr}{user_message_start}{genparams.get('prompt', '')}{assistant_message_start}" + ollamaopts = genparams.get('options', {}) genparams["stop_sequence"] = genparams.get('stop', []) + if "num_predict" in ollamaopts: + genparams["max_length"] = ollamaopts.get('num_predict', 200) + if "num_ctx" in ollamaopts: + genparams["max_context_length"] = ollamaopts.get('num_ctx', maxctx) + if "temperature" in ollamaopts: + genparams["temperature"] = ollamaopts.get('temperature', 0.75) + if "top_k" in ollamaopts: + genparams["top_k"] = ollamaopts.get('top_k', 100) + if "top_p" in ollamaopts: + genparams["top_p"] = ollamaopts.get('top_p', 0.92) + if "seed" in ollamaopts: + genparams["sampler_seed"] = tryparseint(ollamaopts.get('seed', -1)) + if "stop" in ollamaopts: + genparams["stop_sequence"] = ollamaopts.get('stop', []) genparams["stop_sequence"].append(user_message_start.strip()) genparams["stop_sequence"].append(assistant_message_start.strip()) genparams["trim_stop"] = True genparams["ollamasysprompt"] = ollamasysprompt genparams["ollamabodyprompt"] = ollamabodyprompt genparams["prompt"] = ollamasysprompt + ollamabodyprompt - return genparams class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): @@ -1837,7 +1797,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): status = str(parsed_dict['status'][0]) if 'status' in parsed_dict else "Ready To Generate" prompt = str(parsed_dict['prompt'][0]) if 'prompt' in parsed_dict else "" max_length = int(parsed_dict['max_length'][0]) if 'max_length' in parsed_dict else 100 - temperature = float(parsed_dict['temperature'][0]) if 'temperature' in parsed_dict else 0.7 + temperature = float(parsed_dict['temperature'][0]) if 'temperature' in parsed_dict else 0.75 top_k = int(parsed_dict['top_k'][0]) if 'top_k' in parsed_dict else 100 top_p = float(parsed_dict['top_p'][0]) if 'top_p' in parsed_dict else 0.9 rep_pen = float(parsed_dict['rep_pen'][0]) if 'rep_pen' in parsed_dict else 1.0