added basic support for password protection (+2 squashed commit)

Squashed commit:

[ff91ca72] added basic support for password protection

[91b0b208] updated docs
This commit is contained in:
Concedo 2024-03-12 16:05:51 +08:00
parent a69bc44e7a
commit 6c6ad93f01
4 changed files with 192 additions and 49 deletions

View file

@ -609,6 +609,7 @@ friendlymodelname = "inactive"
friendlysdmodelname = "inactive"
fullsdmodelpath = "" #if empty, it's not initialized
mmprojpath = "" #if empty, it's not initialized
password = "" #if empty, no auth key required
maxctx = 2048
maxhordectx = 2048
maxhordelen = 256
@ -898,6 +899,30 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
except Exception as e:
print(e)
def secure_endpoint(self): #returns false if auth fails. caller should exit
    """Check the request's bearer token against the configured password.

    When the module-level ``password`` is empty, no auth key is required and
    the request is always accepted. Otherwise the request must carry an
    ``Authorization: Bearer <token>`` header whose token (after stripping
    surrounding whitespace) equals the password.

    On failure a 401 response with a JSON error body is written to the
    client before returning.

    Returns:
        True if the request may proceed. False if authentication failed; the
        401 response has already been sent, so the caller should return
        without writing anything else.
    """
    if not password:
        return True #no password configured, endpoint is open

    import hmac #local import, only needed on the password protected path

    #look up the header under both spellings, first one present wins
    auth_header = None
    for header_name in ('Authorization', 'authorization'):
        if header_name in self.headers:
            auth_header = self.headers[header_name]
            break

    auth_ok = False
    if auth_header is not None and auth_header.startswith('Bearer '):
        token = auth_header[len('Bearer '):].strip()
        #constant-time comparison so response timing does not leak how much of
        #the key matched. compare on bytes because compare_digest rejects
        #non-ASCII str arguments.
        auth_ok = hmac.compare_digest(
            token.encode('utf-8', 'surrogatepass'),
            password.encode('utf-8', 'surrogatepass'))

    if not auth_ok:
        self.send_response(401)
        self.end_headers(content_type='application/json')
        self.wfile.write(json.dumps({"detail": {
            "error": "Unauthorized",
            "msg": "Authentication key is missing or invalid.",
            "type": "unauthorized",
        }}).encode())
        return False
    return True
def noscript_webui(self):
global modelbusy
import html
@ -978,7 +1003,7 @@ Enter Prompt:<br>
self.wfile.write(finalhtml)
def do_GET(self):
global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens, preloaded_story, exitcounter, currentusergenkey, friendlysdmodelname, fullsdmodelpath, mmprojpath
global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens, preloaded_story, exitcounter, currentusergenkey, friendlysdmodelname, fullsdmodelpath, mmprojpath, password
self.path = self.path.rstrip('/')
response_body = None
content_type = 'application/json'
@ -1018,7 +1043,8 @@ Enter Prompt:<br>
elif self.path.endswith(('/api/extra/version')):
has_txt2img = not (friendlysdmodelname=="inactive" or fullsdmodelpath=="")
has_vision = (mmprojpath!="")
response_body = (json.dumps({"result":"KoboldCpp","version":KcppVersion,"txt2img":has_txt2img,"vision":has_vision}).encode())
has_password = (password!="")
response_body = (json.dumps({"result":"KoboldCpp","version":KcppVersion, "protected":has_password ,"txt2img":has_txt2img,"vision":has_vision}).encode())
elif self.path.endswith(('/api/extra/perf')):
lastp = handle.get_last_process_time()
@ -1031,6 +1057,8 @@ Enter Prompt:<br>
response_body = (json.dumps({"last_process":lastp,"last_eval":laste,"last_token_count":lastc, "last_seed":lastseed, "total_gens":totalgens, "stop_reason":stopreason, "queue":requestsinqueue, "idle":(0 if modelbusy.locked() else 1), "hordeexitcounter":exitcounter, "uptime":uptime}).encode())
elif self.path.endswith('/api/extra/generate/check'):
if not self.secure_endpoint():
return
pendtxtStr = ""
if requestsinqueue==0 and totalgens>0 and currentusergenkey=="":
pendtxt = handle.get_pending_output()
@ -1102,6 +1130,8 @@ Enter Prompt:<br>
response_code = 200
if self.path.endswith(('/api/extra/tokencount')):
if not self.secure_endpoint():
return
try:
genparams = json.loads(body)
countprompt = genparams.get('prompt', "")
@ -1117,6 +1147,8 @@ Enter Prompt:<br>
response_body = (json.dumps({"value": -1}).encode())
elif self.path.endswith('/api/extra/abort'):
if not self.secure_endpoint():
return
multiuserkey = ""
try:
tempbody = json.loads(body)
@ -1125,7 +1157,6 @@ Enter Prompt:<br>
except Exception as e:
multiuserkey = ""
pass
if (multiuserkey=="" and requestsinqueue==0) or (multiuserkey!="" and multiuserkey==currentusergenkey):
ag = handle.abort_generate()
time.sleep(0.1) #short delay before replying
@ -1138,6 +1169,8 @@ Enter Prompt:<br>
response_body = (json.dumps({"success": "false", "done":"false"}).encode())
elif self.path.endswith('/api/extra/generate/check'):
if not self.secure_endpoint():
return
pendtxtStr = ""
multiuserkey = ""
try:
@ -1216,6 +1249,11 @@ Enter Prompt:<br>
is_txt2img = True
if is_txt2img or api_format > 0:
if not is_txt2img and api_format<5:
if not self.secure_endpoint():
return
genparams = None
try:
genparams = json.loads(body)
@ -1410,7 +1448,7 @@ def show_new_gui():
tabs = ctk.CTkFrame(root, corner_radius = 0, width=windowwidth, height=windowheight-50)
tabs.grid(row=0, stick="nsew")
tabnames= ["Quick Launch", "Hardware", "Tokens", "Model", "Network","Image Gen"]
tabnames= ["Quick Launch", "Hardware", "Tokens", "Model Files", "Network", "Horde Worker","Image Gen"]
navbuttons = {}
navbuttonframe = ctk.CTkFrame(tabs, width=100, height=int(tabs.cget("height")))
navbuttonframe.grid(row=0, column=0, padx=2,pady=2)
@ -1501,6 +1539,7 @@ def show_new_gui():
usehorde_var = ctk.IntVar()
ssl_cert_var = ctk.StringVar()
ssl_key_var = ctk.StringVar()
password_var = ctk.StringVar()
sd_model_var = ctk.StringVar()
sd_quick_var = ctk.IntVar(value=0)
@ -1934,7 +1973,7 @@ def show_new_gui():
togglerope(1,1,1)
# Model Tab
model_tab = tabcontent["Model"]
model_tab = tabcontent["Model Files"]
makefileentry(model_tab, "Model:", "Select GGML Model File", model_var, 1, onchoosefile=on_picked_model_file,tooltiptxt="Select a GGUF or GGML model file on disk to be loaded.")
makefileentry(model_tab, "Lora:", "Select Lora File",lora_var, 3,tooltiptxt="Select an optional GGML LoRA adapter to use.\nLeave blank to skip.")
@ -1956,15 +1995,17 @@ def show_new_gui():
makefileentry(network_tab, "SSL Cert:", "Select SSL cert.pem file",ssl_cert_var, 5, width=130 ,filetypes=[("Unencrypted Certificate PEM", "*.pem")], singlerow=True,tooltiptxt="Select your unencrypted .pem SSL certificate file for https.\nCan be generated with OpenSSL.")
makefileentry(network_tab, "SSL Key:", "Select SSL key.pem file", ssl_key_var, 7, width=130, filetypes=[("Unencrypted Key PEM", "*.pem")], singlerow=True,tooltiptxt="Select your unencrypted .pem SSL key file for https.\nCan be generated with OpenSSL.")
makelabelentry(network_tab, "Password: ", password_var, 8, 150,tooltip="Enter a password required to use this instance.\nThis key will be required for all text endpoints.\nImage endpoints are not secured.")
# horde
makelabel(network_tab, "Horde:", 18,0,"Settings for embedded AI Horde worker").grid(pady=10)
# Horde Tab
horde_tab = tabcontent["Horde Worker"]
makelabel(horde_tab, "Horde:", 18,0,"Settings for embedded AI Horde worker").grid(pady=10)
horde_name_entry, horde_name_label = makelabelentry(network_tab, "Horde Model Name:", horde_name_var, 20, 180,"The model name to be displayed on the AI Horde.")
horde_gen_entry, horde_gen_label = makelabelentry(network_tab, "Gen. Length:", horde_gen_var, 21, 50,"The maximum amount to generate per request \nthat this worker will accept jobs for.")
horde_context_entry, horde_context_label = makelabelentry(network_tab, "Max Context:",horde_context_var, 22, 50,"The maximum context length \nthat this worker will accept jobs for.")
horde_apikey_entry, horde_apikey_label = makelabelentry(network_tab, "API Key (If Embedded Worker):",horde_apikey_var, 23, 180,"Your AI Horde API Key that you have registered.")
horde_workername_entry, horde_workername_label = makelabelentry(network_tab, "Horde Worker Name:",horde_workername_var, 24, 180,"Your worker's name to be displayed.")
horde_name_entry, horde_name_label = makelabelentry(horde_tab, "Horde Model Name:", horde_name_var, 20, 180,"The model name to be displayed on the AI Horde.")
horde_gen_entry, horde_gen_label = makelabelentry(horde_tab, "Gen. Length:", horde_gen_var, 21, 50,"The maximum amount to generate per request \nthat this worker will accept jobs for.")
horde_context_entry, horde_context_label = makelabelentry(horde_tab, "Max Context:",horde_context_var, 22, 50,"The maximum context length \nthat this worker will accept jobs for.")
horde_apikey_entry, horde_apikey_label = makelabelentry(horde_tab, "API Key (If Embedded Worker):",horde_apikey_var, 23, 180,"Your AI Horde API Key that you have registered.")
horde_workername_entry, horde_workername_label = makelabelentry(horde_tab, "Horde Worker Name:",horde_workername_var, 24, 180,"Your worker's name to be displayed.")
def togglehorde(a,b,c):
labels = [horde_name_label, horde_gen_label, horde_context_label, horde_apikey_label, horde_workername_label]
@ -1979,7 +2020,7 @@ def show_new_gui():
basefile = os.path.basename(model_var.get())
horde_name_var.set(sanitize_string(os.path.splitext(basefile)[0]))
makecheckbox(network_tab, "Configure for Horde", usehorde_var, 19, command=togglehorde,tooltiptxt="Enable the embedded AI Horde worker.")
makecheckbox(horde_tab, "Configure for Horde", usehorde_var, 19, command=togglehorde,tooltiptxt="Enable the embedded AI Horde worker.")
togglehorde(1,1,1)
# Image Gen Tab
@ -2067,6 +2108,7 @@ def show_new_gui():
args.mmproj = None if mmproj_var.get() == "" else mmproj_var.get()
args.ssl = None if (ssl_cert_var.get() == "" or ssl_key_var.get() == "") else ([ssl_cert_var.get(), ssl_key_var.get()])
args.password = None if (password_var.get() == "") else (password_var.get())
args.port_param = defaultport if port_var.get()=="" else int(port_var.get())
args.host = host_var.get()
@ -2188,6 +2230,9 @@ def show_new_gui():
ssl_cert_var.set(dict["ssl"][0])
ssl_key_var.set(dict["ssl"][1])
if "password" in dict and dict["password"]:
password_var.set(dict["password"])
if "preloadstory" in dict and dict["preloadstory"]:
preloadstory_var.set(dict["preloadstory"])
@ -2634,7 +2679,7 @@ def sanitize_string(input_string):
return sanitized_string
def main(launch_args,start_server=True):
global args, friendlymodelname, friendlysdmodelname, fullsdmodelpath, mmprojpath
global args, friendlymodelname, friendlysdmodelname, fullsdmodelpath, mmprojpath, password
args = launch_args
embedded_kailite = None
embedded_kcpp_docs = None
@ -2769,6 +2814,9 @@ def main(launch_args,start_server=True):
args.mmproj = os.path.abspath(args.mmproj)
mmprojpath = args.mmproj
if args.password and args.password!="":
password = args.password.strip()
if not args.blasthreads or args.blasthreads <= 0:
args.blasthreads = args.threads
@ -3017,5 +3065,6 @@ if __name__ == '__main__':
parser.add_argument("--nocertify", help="Allows insecure SSL connections. Use this if you have cert errors and need to bypass certificate restrictions.", action='store_true')
parser.add_argument("--sdconfig", help="Specify a stable diffusion safetensors model to enable image generation. If quick is specified, force optimal generation settings for speed.",metavar=('[sd_filename]', '[normal|quick|clamped] [threads] [quant|noquant]'), nargs='+')
parser.add_argument("--mmproj", help="Select a multimodal projector file for LLaVA.", default="")
parser.add_argument("--password", help="Enter a password required to use this instance. This key will be required for all text endpoints. Image endpoints are not secured.", default=None)
main(parser.parse_args(),start_server=True)