added preloadstory

Concedo 2023-11-10 13:05:22 +08:00
parent 6870c31933
commit be92cfa125
5 changed files with 65 additions and 7 deletions

koboldcpp.py

@@ -214,6 +214,7 @@ def init_library():
     handle.get_last_eval_time.restype = ctypes.c_float
     handle.get_last_process_time.restype = ctypes.c_float
     handle.get_last_token_count.restype = ctypes.c_int
+    handle.get_total_gens.restype = ctypes.c_int
     handle.get_last_stop_reason.restype = ctypes.c_int
     handle.abort_generate.restype = ctypes.c_bool
     handle.token_count.restype = ctypes.c_int
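
ctypes assumes every foreign function returns a C int unless told otherwise, which is why each exported symbol gets an explicit restype before first use. A minimal sketch of reading the new counter through the same mechanism, assuming a shared library that exports get_total_gens (the library filename below is hypothetical; koboldcpp resolves the right one per platform):

    import ctypes

    handle = ctypes.CDLL("./koboldcpp_default.so")  # hypothetical path

    # Matches the declaration added above; get_total_gens returns the
    # cumulative number of generations since the server started.
    handle.get_total_gens.restype = ctypes.c_int

    print(handle.get_total_gens())
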
@@ -401,6 +402,7 @@ totalgens = 0
 currentusergenkey = "" #store a special key so polled streaming works even in multiuser
 args = None #global args
 gui_layers_untouched = True
+preloaded_story = None
 
 class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
     sys_version = ""
@@ -618,7 +620,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 
     def do_GET(self):
-        global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens
+        global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens, preloaded_story
         self.path = self.path.rstrip('/')
         response_body = None
         content_type = 'application/json'
@@ -658,8 +660,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             lastp = handle.get_last_process_time()
             laste = handle.get_last_eval_time()
             lastc = handle.get_last_token_count()
+            totalgens = handle.get_total_gens()
             stopreason = handle.get_last_stop_reason()
-            response_body = (json.dumps({"last_process":lastp,"last_eval":laste,"last_token_count":lastc, "stop_reason":stopreason, "queue":requestsinqueue, "idle":(0 if modelbusy.locked() else 1)}).encode())
+            response_body = (json.dumps({"last_process":lastp,"last_eval":laste,"last_token_count":lastc, "total_gens":totalgens, "stop_reason":stopreason, "queue":requestsinqueue, "idle":(0 if modelbusy.locked() else 1)}).encode())
 
         elif self.path.endswith('/api/extra/generate/check'):
             pendtxtStr = ""
@@ -677,6 +680,12 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 response_body = (f"KoboldCpp partial API reference can be found at the wiki: https://github.com/LostRuins/koboldcpp/wiki").encode()
             else:
                 response_body = self.embedded_kcpp_docs
+        elif self.path=="/api/extra/preloadstory":
+            if preloaded_story is None:
+                response_body = (json.dumps({}).encode())
+            else:
+                response_body = preloaded_story
+
         elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')):
             self.path = "/api"
             self.send_response(302)
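
The new endpoint returns the preloaded story bytes verbatim, or an empty JSON object when nothing was preloaded, so a frontend needs only a single GET. A minimal client sketch (same default-port assumption as above):

    import json
    import urllib.request

    with urllib.request.urlopen("http://localhost:5001/api/extra/preloadstory") as r:
        story = json.load(r)

    if story:
        print("server is hosting a preloaded story")
    else:
        print("nothing preloaded, got {}")
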
@@ -1008,7 +1017,8 @@ def show_new_gui():
     model_var = ctk.StringVar()
     lora_var = ctk.StringVar()
     lora_base_var = ctk.StringVar()
+    preloadstory_var = ctk.StringVar()
 
     port_var = ctk.StringVar(value=defaultport)
     host_var = ctk.StringVar(value="")
@@ -1404,6 +1414,7 @@ def show_new_gui():
     makefileentry(model_tab, "Model:", "Select GGML Model File", model_var, 1, onchoosefile=autoset_gpu_layers)
     makefileentry(model_tab, "Lora:", "Select Lora File",lora_var, 3)
     makefileentry(model_tab, "Lora Base:", "Select Lora Base File", lora_base_var, 5)
+    makefileentry(model_tab, "Preloaded Story:", "Select Preloaded Story File", preloadstory_var, 7)
 
     # Network Tab
     network_tab = tabcontent["Network"]
@@ -1505,6 +1516,7 @@ def show_new_gui():
         args.model_param = None if model_var.get() == "" else model_var.get()
         args.lora = None if lora_var.get() == "" else ([lora_var.get()] if lora_base_var.get()=="" else [lora_var.get(), lora_base_var.get()])
+        args.preloadstory = None if preloadstory_var.get() == "" else preloadstory_var.get()
 
         args.port_param = defaultport if port_var.get()=="" else int(port_var.get())
         args.host = host_var.get()
@@ -1595,6 +1607,9 @@ def show_new_gui():
             else:
                 lora_var.set(dict["lora"][0])
 
+        if "preloadstory" in dict and dict["preloadstory"]:
+            preloadstory_var.set(dict["preloadstory"])
+
         if "port_param" in dict and dict["port_param"]:
             port_var.set(dict["port_param"])
@@ -1963,6 +1978,7 @@ def unload_libs():
     del handle.get_last_eval_time
     del handle.get_last_process_time
     del handle.get_last_token_count
+    del handle.get_total_gens
     del handle.get_last_stop_reason
     del handle.abort_generate
     del handle.token_count
@@ -2018,6 +2034,17 @@ def main(launch_args,start_server=True):
             time.sleep(3)
             sys.exit(2)
 
+    #try to read story if provided
+    if args.preloadstory:
+        if isinstance(args.preloadstory, str) and os.path.exists(args.preloadstory):
+            print(f"Preloading saved story {args.preloadstory} into server...")
+            with open(args.preloadstory, mode='rb') as f:
+                global preloaded_story
+                preloaded_story = f.read()
+            print("Saved story preloaded.")
+        else:
+            print(f"Warning: Saved story file {args.preloadstory} invalid or not found. No story will be preloaded into server.")
+
     # sanitize and replace the default vanity name. remember me....
     if args.model_param!="":
         newmdldisplayname = os.path.basename(args.model_param)
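
Note that the story file is opened with mode='rb' and stored as raw bytes, so the server never parses it; whatever JSON the frontend saved is echoed back byte-for-byte by /api/extra/preloadstory. A sketch of producing a file the loader will accept (the keys are placeholders, not the real Kobold Lite save schema, which the server treats as opaque):

    import json

    story = {"prompt": "Once upon a time...", "actions": []}  # placeholder keys

    with open("mystory.json", "w", encoding="utf-8") as f:
        json.dump(story, f)
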
@@ -2201,6 +2228,7 @@ if __name__ == '__main__':
     parser.add_argument("--multiuser", help="Runs in multiuser mode, which queues incoming requests instead of blocking them.", action='store_true')
     parser.add_argument("--remotetunnel", help="Uses Cloudflare to create a remote tunnel, allowing you to access koboldcpp remotely over the internet even behind a firewall.", action='store_true')
     parser.add_argument("--foreground", help="Windows only. Sends the terminal to the foreground every time a new prompt is generated. This helps avoid some idle slowdown issues.", action='store_true')
+    parser.add_argument("--preloadstory", help="Configures a prepared story json save file to be hosted on the server, which frontends (such as Kobold Lite) can access over the API.", default="")
 
     # #deprecated hidden args. they do nothing. do not use
     # parser.add_argument("--psutil_set_threads", action='store_true', help=argparse.SUPPRESS)