mirror of https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 17:44:38 +00:00
- Thread count is set to cpu_count() if that is 6 or fewer, otherwise to cpu_count()-2 (with a floor of 6). This can be forcibly overridden with the --threads parameter. Setting threads=cpu_count() chokes my own PC and slows it down badly, so I'd rather make full usage opt-in.
- Added localmodehost as a URL parameter in Kobold Lite instead, to avoid monkeypatching the embedded Kobold Lite directly. It is parsed via ?localmodehost=(host). Also, your updated klite file has the wrong encoding; it should be UTF-8, as some of the symbols are incorrect, such as the palette icon in settings. Repackaged the new version of Kobold Lite correctly with the changes.
- Reverted to the TK GUI file dialog if no model is provided, because I want to keep it noob-friendly for those who don't know how to use command line args. The file dialog only loads if there are no command line args; if any are present, the GUI will not trigger.
- Modified the argparser to also take positional arguments for backwards compatibility, in addition to the optional argparse flags specified.
- Your code does not work if the embedded Kobold is removed: the embedded KAI variable was not declared in the correct scope, Python f-string formatting cannot be applied to raw byte strings, and the indentation was incorrect when returning the response body. I have corrected all of the above, but please do test all codepaths if possible. (See the sketch below for the f-string/bytes issue.)
- There is a good reason to bind to "" (0.0.0.0) instead of a specific IP: it allows receiving requests from all routable interfaces. I don't know why you need an explicitly defined --host flag, but I will leave it there as an optional parameter, though the default should still be to accept from all interfaces. That way, even if the displayed URL is localhost, connecting via 192.168.x.x will also work, for example.
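
A minimal sketch of the f-string/byte-string pitfall mentioned above (variable names are illustrative, not taken from the actual diff): an f-string always produces a str, and there is no f-prefix for bytes literals, so the formatted text has to be built as a str first and encoded afterwards, which is what the corrected handler below does.

    port = 5001
    # fb"...{port}..." is a SyntaxError; f-strings cannot produce bytes
    body = f"<a href='https://lite.koboldai.net?local=1&port={port}'>use this URL</a>"
    response_body = body.encode()  # format as str first, then encode to UTF-8 bytes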
335 lines
14 KiB
Python
# A hacky little script from Concedo that exposes llama.cpp function bindings
# allowing it to be used via a simulated kobold api endpoint
# it's not very usable as there is a fundamental flaw with llama.cpp
# which causes generation delay to scale linearly with original prompt length.

import ctypes
import os
import argparse
import json, http.server, threading, socket, sys, time

class load_model_inputs(ctypes.Structure):
    _fields_ = [("threads", ctypes.c_int),
                ("max_context_length", ctypes.c_int),
                ("batch_size", ctypes.c_int),
                ("f16_kv", ctypes.c_bool),
                ("model_filename", ctypes.c_char_p),
                ("n_parts_overwrite", ctypes.c_int)]

class generation_inputs(ctypes.Structure):
    _fields_ = [("seed", ctypes.c_int),
                ("prompt", ctypes.c_char_p),
                ("max_context_length", ctypes.c_int),
                ("max_length", ctypes.c_int),
                ("temperature", ctypes.c_float),
                ("top_k", ctypes.c_int),
                ("top_p", ctypes.c_float),
                ("rep_pen", ctypes.c_float),
                ("rep_pen_range", ctypes.c_int)]

class generation_outputs(ctypes.Structure):
    _fields_ = [("status", ctypes.c_int),
                ("text", ctypes.c_char * 16384)]

dir_path = os.path.dirname(os.path.realpath(__file__))
handle = ctypes.CDLL(os.path.join(dir_path, "llamacpp.dll"))

handle.load_model.argtypes = [load_model_inputs]
handle.load_model.restype = ctypes.c_bool
handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] # the extra wchar pointer is apparently needed for OSX to work; unclear why it must be interpreted that way, but it does no harm elsewhere
handle.generate.restype = generation_outputs

def load_model(model_filename, batch_size=8, max_context_length=512, n_parts_overwrite=-1, threads=6):
    inputs = load_model_inputs()
    inputs.model_filename = model_filename.encode("UTF-8")
    inputs.batch_size = batch_size
    inputs.max_context_length = max_context_length #initial value to use for ctx, can be overwritten
    inputs.threads = threads
    inputs.n_parts_overwrite = n_parts_overwrite
    inputs.f16_kv = True
    ret = handle.load_model(inputs)
    return ret

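# A quick usage sketch for load_model (hypothetical model path; assumes the
# llamacpp library was loaded successfully above):
#   if load_model("models/ggml-model-q4_0.bin", threads=4):
#       print("model ready")
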
def generate(prompt, max_length=20, max_context_length=512, temperature=0.8, top_k=100, top_p=0.85, rep_pen=1.1, rep_pen_range=128, seed=-1):
    inputs = generation_inputs()
    outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
    inputs.prompt = prompt.encode("UTF-8")
    inputs.max_context_length = max_context_length # this will resize the context buffer if changed
    inputs.max_length = max_length
    inputs.temperature = temperature
    inputs.top_k = top_k
    inputs.top_p = top_p
    inputs.rep_pen = rep_pen
    inputs.rep_pen_range = rep_pen_range
    inputs.seed = seed
    ret = handle.generate(inputs, outputs)
    if ret.status == 1:
        return ret.text.decode("UTF-8")
    return ""

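# Example call (a sketch; assumes a model was loaded first):
#   text = generate("Once upon a time", max_length=32, temperature=0.7)
#   print(text) # returns "" whenever the DLL reports a non-success status
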
#################################################################
### A hacky simple HTTP server simulating a kobold api by Concedo
### we are intentionally NOT using flask, because we want MINIMAL dependencies
#################################################################
friendlymodelname = "concedo/llamacpp" # local kobold api apparently needs a hardcoded known HF model name
maxctx = 2048
maxlen = 128
modelbusy = False

class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
    sys_version = ""
    server_version = "ConcedoLlamaForKoboldServer"

    def __init__(self, addr, port, embedded_kailite):
        self.addr = addr
        self.port = port
        self.embedded_kailite = embedded_kailite

    def __call__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def do_GET(self):
        global maxctx, maxlen, friendlymodelname
        if self.path in ["/", "/?"] or self.path.startswith(('/?','?')): #it's possible for the root url to have ?params without /
            response_body = ""
            if self.embedded_kailite is None:
                response_body = (f"Embedded Kobold Lite was not found.<br>You will have to connect via the main KoboldAI client, or <a href='https://lite.koboldai.net?local=1&port={self.port}'>use this URL</a> to connect.").encode()
            else:
                response_body = self.embedded_kailite

            self.send_response(200)
            self.send_header('Content-Length', str(len(response_body)))
            self.end_headers()
            self.wfile.write(response_body)
            return

        self.path = self.path.rstrip('/')
        if self.path.endswith(('/api/v1/model', '/api/latest/model')):
            self.send_response(200)
            self.end_headers()
            result = {'result': friendlymodelname}
            self.wfile.write(json.dumps(result).encode())
            return

        if self.path.endswith(('/api/v1/config/max_length', '/api/latest/config/max_length')):
            self.send_response(200)
            self.end_headers()
            self.wfile.write(json.dumps({"value": maxlen}).encode())
            return

        if self.path.endswith(('/api/v1/config/max_context_length', '/api/latest/config/max_context_length')):
            self.send_response(200)
            self.end_headers()
            self.wfile.write(json.dumps({"value": maxctx}).encode())
            return

        if self.path.endswith(('/api/v1/config/soft_prompt', '/api/latest/config/soft_prompt')):
            self.send_response(200)
            self.end_headers()
            self.wfile.write(json.dumps({"value": ""}).encode())
            return

        self.send_response(404)
        self.end_headers()
        rp = 'Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.'
        self.wfile.write(rp.encode())
        return

    def do_POST(self):
        global modelbusy
        content_length = int(self.headers['Content-Length'])
        body = self.rfile.read(content_length)
        basic_api_flag = False
        kai_api_flag = False
        self.path = self.path.rstrip('/')

        if modelbusy:
            self.send_response(503)
            self.end_headers()
            self.wfile.write(json.dumps({"detail": {
                "msg": "Server is busy; please try again later.",
                "type": "service_unavailable",
            }}).encode())
            return

        if self.path.endswith('/request'):
            basic_api_flag = True

        if self.path.endswith(('/api/v1/generate', '/api/latest/generate')):
            kai_api_flag = True

        if basic_api_flag or kai_api_flag:
            genparams = None
            try:
                genparams = json.loads(body)
            except ValueError:
                self.send_response(503)
                self.end_headers()
                return
            print("\nInput: " + json.dumps(genparams))

            modelbusy = True
            if kai_api_flag:
                fullprompt = genparams.get('prompt', "")
            else:
                fullprompt = genparams.get('text', "")
            newprompt = fullprompt

            recvtxt = ""
            if kai_api_flag:
                recvtxt = generate(
                    prompt=newprompt,
                    max_context_length=genparams.get('max_context_length', maxctx),
                    max_length=genparams.get('max_length', 50),
                    temperature=genparams.get('temperature', 0.8),
                    top_k=genparams.get('top_k', 200),
                    top_p=genparams.get('top_p', 0.85),
                    rep_pen=genparams.get('rep_pen', 1.1),
                    rep_pen_range=genparams.get('rep_pen_range', 128),
                    seed=-1
                )
                print("\nOutput: " + recvtxt)
                res = {"results": [{"text": recvtxt}]}
                self.send_response(200)
                self.end_headers()
                self.wfile.write(json.dumps(res).encode())
            else:
                recvtxt = generate(
                    prompt=newprompt,
                    max_length=genparams.get('max', 50),
                    temperature=genparams.get('temperature', 0.8),
                    top_k=genparams.get('top_k', 200),
                    top_p=genparams.get('top_p', 0.85),
                    rep_pen=genparams.get('rep_pen', 1.1),
                    rep_pen_range=genparams.get('rep_pen_range', 128),
                    seed=-1
                )
                print("\nOutput: " + recvtxt)
                res = {"data": {"seqs": [recvtxt]}}
                self.send_response(200)
                self.end_headers()
                self.wfile.write(json.dumps(res).encode())
            modelbusy = False
            return

        self.send_response(404)
        self.end_headers()

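    # Sketch of a matching client call for the KAI generate endpoint handled
    # above (hypothetical host/port, matching the argparse defaults below):
    #   import json, urllib.request
    #   payload = json.dumps({"prompt": "Hello", "max_length": 50}).encode()
    #   req = urllib.request.Request("http://localhost:5001/api/v1/generate", data=payload)
    #   print(json.load(urllib.request.urlopen(req))["results"][0]["text"])
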
    def do_OPTIONS(self):
        self.send_response(200)
        self.end_headers()

    def do_HEAD(self):
        self.send_response(200)
        self.end_headers()

    def end_headers(self):
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', '*')
        self.send_header('Access-Control-Allow-Headers', '*')
        if "/api" in self.path:
            self.send_header('Content-type', 'application/json')
        else:
            self.send_header('Content-type', 'text/html')

        return super(ServerRequestHandler, self).end_headers()


def RunServerMultiThreaded(addr, port, embedded_kailite = None):
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) # allow quick rebinding after a restart
    sock.bind((addr, port))
    sock.listen(5)

    class Thread(threading.Thread):
        def __init__(self, i):
            threading.Thread.__init__(self)
            self.i = i
            self.daemon = True
            self.start()

        def run(self):
            handler = ServerRequestHandler(addr, port, embedded_kailite)
            with http.server.HTTPServer((addr, port), handler, False) as self.httpd:
                try:
                    # the pre-bound socket is shared between all server threads,
                    # so skip the server's own bind/close and substitute no-ops
                    self.httpd.socket = sock
                    self.httpd.server_bind = self.server_close = lambda self: None
                    self.httpd.serve_forever()
                except (KeyboardInterrupt, SystemExit):
                    self.httpd.server_close()
                    sys.exit(0)
                finally:
                    self.httpd.server_close()
                    sys.exit(0)

        def stop(self):
            self.httpd.server_close()

    numThreads = 5
    threadArr = []
    for i in range(numThreads):
        threadArr.append(Thread(i))
    while True:
        try:
            time.sleep(10)
        except KeyboardInterrupt:
            for i in range(numThreads):
                threadArr[i].stop()
            sys.exit(0)

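# Example launch (a sketch mirroring the call made in main() below; host ""
# binds all routable interfaces, and port 5001 is the argparse default):
#   RunServerMultiThreaded("", 5001, embedded_kailite=None)
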
def main(args):
    ggml_selected_file = args.model_file
    embedded_kailite = None
    if not ggml_selected_file:
        #give them a chance to pick a file
        print("Please manually select ggml file:")
        from tkinter.filedialog import askopenfilename
        ggml_selected_file = askopenfilename(title="Select ggml model .bin files")
        if not ggml_selected_file:
            print("\nNo ggml model file was selected. Exiting.")
            time.sleep(1)
            sys.exit(2)

    if not os.path.exists(ggml_selected_file):
        print(f"Cannot find model file: {ggml_selected_file}")
        time.sleep(1)
        sys.exit(2)

    mdl_nparts = sum(1 for n in range(1, 9) if os.path.exists(f"{ggml_selected_file}.{n}")) + 1
    modelname = os.path.abspath(ggml_selected_file)
    print(f"Loading model: {modelname}, Parts: {mdl_nparts}, Threads: {args.threads}")
    loadok = load_model(modelname, 8, maxctx, mdl_nparts, args.threads)
    print("Load Model OK: " + str(loadok))

    if not loadok:
        print("Could not load model: " + modelname)
        sys.exit(3)
    try:
        basepath = os.path.abspath(os.path.dirname(__file__))
        with open(os.path.join(basepath, "klite.embd"), mode='rb') as f:
            embedded_kailite = f.read()
        print("Embedded Kobold Lite loaded.")
    except Exception:
        print("Could not find Kobold Lite. Embedded Kobold Lite will not be available.")

print(f"Starting Kobold HTTP Server on port {args.port}")
|
|
epurl = ""
|
|
if args.host=="":
|
|
epurl = f"http://localhost:{args.port}" + ("?streaming=1" if not args.nostream else "")
|
|
else:
|
|
epurl = f"http://{args.host}:{args.port}?host={args.host}" + ("&streaming=1" if not args.nostream else "")
|
|
|
|
|
|
print(f"Please connect to custom endpoint at {epurl}")
|
|
RunServerMultiThreaded(args.host, args.port, embedded_kailite)
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Kobold llama.cpp server')
    parser.add_argument("model_file", help="Model file to load", nargs="?")
    portgroup = parser.add_mutually_exclusive_group() #we want to be backwards compatible with the unnamed positional args
    portgroup.add_argument("--port", help="Port to listen on", default=5001, type=int)
    portgroup.add_argument("port", help="Port to listen on", default=5001, nargs="?", type=int)
    parser.add_argument("--host", help="Host IP to listen on. If empty, all routable interfaces are accepted.", default="")
    cpu_cores = os.cpu_count() or 6 # os.cpu_count() can return None, so fall back to a sane default
    default_threads = cpu_cores if cpu_cores <= 6 else max(6, cpu_cores - 2)
    parser.add_argument("--threads", help="Use a custom number of threads if specified. Otherwise, uses an amount based on CPU cores", type=int, default=default_threads)
    parser.add_argument("--nostream", help="Disables pseudo streaming", action='store_true')
    args = parser.parse_args()
    main(args)