mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
added basic support for password protection (+2 squashed commit)
Squashed commit: [ff91ca72] added basic support for password protection [91b0b208] updated docs
This commit is contained in:
parent
a69bc44e7a
commit
6c6ad93f01
4 changed files with 192 additions and 49 deletions
77
koboldcpp.py
77
koboldcpp.py
|
@ -609,6 +609,7 @@ friendlymodelname = "inactive"
|
|||
friendlysdmodelname = "inactive"
|
||||
fullsdmodelpath = "" #if empty, it's not initialized
|
||||
mmprojpath = "" #if empty, it's not initialized
|
||||
password = "" #if empty, no auth key required
|
||||
maxctx = 2048
|
||||
maxhordectx = 2048
|
||||
maxhordelen = 256
|
||||
|
@ -898,6 +899,30 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
def secure_endpoint(self):
    """Enforce the optional instance password on the current request.

    When the module-level ``password`` is empty, authentication is disabled
    and the request is always allowed. Otherwise the request must carry an
    ``Authorization: Bearer <token>`` header whose token (surrounding
    whitespace stripped) equals ``password``.

    On failure a 401 response with a JSON error body is written to the
    client, so the caller must stop handling the request immediately.

    Returns:
        True if the request may proceed, False if a 401 was already sent.
    """
    # Function-scope import keeps this block self-contained.
    import hmac

    if not password:
        # No password configured: every endpoint is open.
        return True

    # The handler's header mapping is case-insensitive, so a single lookup
    # also matches a lowercase "authorization" header.
    auth_header = self.headers.get('Authorization')
    if auth_header is not None and auth_header.startswith('Bearer '):
        token = auth_header[len('Bearer '):].strip()
        # Constant-time comparison so response timing does not leak how much
        # of the key matched. Compare bytes: compare_digest() rejects
        # non-ASCII str arguments.
        if hmac.compare_digest(token.encode('utf-8'), password.encode('utf-8')):
            return True

    self.send_response(401)
    self.end_headers(content_type='application/json')
    self.wfile.write(json.dumps({"detail": {
        "error": "Unauthorized",
        "msg": "Authentication key is missing or invalid.",
        "type": "unauthorized",
    }}).encode())
    return False
|
||||
|
||||
def noscript_webui(self):
|
||||
global modelbusy
|
||||
import html
|
||||
|
@ -978,7 +1003,7 @@ Enter Prompt:<br>
|
|||
self.wfile.write(finalhtml)
|
||||
|
||||
def do_GET(self):
|
||||
global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens, preloaded_story, exitcounter, currentusergenkey, friendlysdmodelname, fullsdmodelpath, mmprojpath
|
||||
global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens, preloaded_story, exitcounter, currentusergenkey, friendlysdmodelname, fullsdmodelpath, mmprojpath, password
|
||||
self.path = self.path.rstrip('/')
|
||||
response_body = None
|
||||
content_type = 'application/json'
|
||||
|
@ -1018,7 +1043,8 @@ Enter Prompt:<br>
|
|||
elif self.path.endswith(('/api/extra/version')):
|
||||
has_txt2img = not (friendlysdmodelname=="inactive" or fullsdmodelpath=="")
|
||||
has_vision = (mmprojpath!="")
|
||||
response_body = (json.dumps({"result":"KoboldCpp","version":KcppVersion,"txt2img":has_txt2img,"vision":has_vision}).encode())
|
||||
has_password = (password!="")
|
||||
response_body = (json.dumps({"result":"KoboldCpp","version":KcppVersion, "protected":has_password ,"txt2img":has_txt2img,"vision":has_vision}).encode())
|
||||
|
||||
elif self.path.endswith(('/api/extra/perf')):
|
||||
lastp = handle.get_last_process_time()
|
||||
|
@ -1031,6 +1057,8 @@ Enter Prompt:<br>
|
|||
response_body = (json.dumps({"last_process":lastp,"last_eval":laste,"last_token_count":lastc, "last_seed":lastseed, "total_gens":totalgens, "stop_reason":stopreason, "queue":requestsinqueue, "idle":(0 if modelbusy.locked() else 1), "hordeexitcounter":exitcounter, "uptime":uptime}).encode())
|
||||
|
||||
elif self.path.endswith('/api/extra/generate/check'):
|
||||
if not self.secure_endpoint():
|
||||
return
|
||||
pendtxtStr = ""
|
||||
if requestsinqueue==0 and totalgens>0 and currentusergenkey=="":
|
||||
pendtxt = handle.get_pending_output()
|
||||
|
@ -1102,6 +1130,8 @@ Enter Prompt:<br>
|
|||
response_code = 200
|
||||
|
||||
if self.path.endswith(('/api/extra/tokencount')):
|
||||
if not self.secure_endpoint():
|
||||
return
|
||||
try:
|
||||
genparams = json.loads(body)
|
||||
countprompt = genparams.get('prompt', "")
|
||||
|
@ -1117,6 +1147,8 @@ Enter Prompt:<br>
|
|||
response_body = (json.dumps({"value": -1}).encode())
|
||||
|
||||
elif self.path.endswith('/api/extra/abort'):
|
||||
if not self.secure_endpoint():
|
||||
return
|
||||
multiuserkey = ""
|
||||
try:
|
||||
tempbody = json.loads(body)
|
||||
|
@ -1125,7 +1157,6 @@ Enter Prompt:<br>
|
|||
except Exception as e:
|
||||
multiuserkey = ""
|
||||
pass
|
||||
|
||||
if (multiuserkey=="" and requestsinqueue==0) or (multiuserkey!="" and multiuserkey==currentusergenkey):
|
||||
ag = handle.abort_generate()
|
||||
time.sleep(0.1) #short delay before replying
|
||||
|
@ -1138,6 +1169,8 @@ Enter Prompt:<br>
|
|||
response_body = (json.dumps({"success": "false", "done":"false"}).encode())
|
||||
|
||||
elif self.path.endswith('/api/extra/generate/check'):
|
||||
if not self.secure_endpoint():
|
||||
return
|
||||
pendtxtStr = ""
|
||||
multiuserkey = ""
|
||||
try:
|
||||
|
@ -1216,6 +1249,11 @@ Enter Prompt:<br>
|
|||
is_txt2img = True
|
||||
|
||||
if is_txt2img or api_format > 0:
|
||||
|
||||
if not is_txt2img and api_format<5:
|
||||
if not self.secure_endpoint():
|
||||
return
|
||||
|
||||
genparams = None
|
||||
try:
|
||||
genparams = json.loads(body)
|
||||
|
@ -1410,7 +1448,7 @@ def show_new_gui():
|
|||
|
||||
tabs = ctk.CTkFrame(root, corner_radius = 0, width=windowwidth, height=windowheight-50)
|
||||
tabs.grid(row=0, stick="nsew")
|
||||
tabnames= ["Quick Launch", "Hardware", "Tokens", "Model", "Network","Image Gen"]
|
||||
tabnames= ["Quick Launch", "Hardware", "Tokens", "Model Files", "Network", "Horde Worker","Image Gen"]
|
||||
navbuttons = {}
|
||||
navbuttonframe = ctk.CTkFrame(tabs, width=100, height=int(tabs.cget("height")))
|
||||
navbuttonframe.grid(row=0, column=0, padx=2,pady=2)
|
||||
|
@ -1501,6 +1539,7 @@ def show_new_gui():
|
|||
usehorde_var = ctk.IntVar()
|
||||
ssl_cert_var = ctk.StringVar()
|
||||
ssl_key_var = ctk.StringVar()
|
||||
password_var = ctk.StringVar()
|
||||
|
||||
sd_model_var = ctk.StringVar()
|
||||
sd_quick_var = ctk.IntVar(value=0)
|
||||
|
@ -1934,7 +1973,7 @@ def show_new_gui():
|
|||
togglerope(1,1,1)
|
||||
|
||||
# Model Tab
|
||||
model_tab = tabcontent["Model"]
|
||||
model_tab = tabcontent["Model Files"]
|
||||
|
||||
makefileentry(model_tab, "Model:", "Select GGML Model File", model_var, 1, onchoosefile=on_picked_model_file,tooltiptxt="Select a GGUF or GGML model file on disk to be loaded.")
|
||||
makefileentry(model_tab, "Lora:", "Select Lora File",lora_var, 3,tooltiptxt="Select an optional GGML LoRA adapter to use.\nLeave blank to skip.")
|
||||
|
@ -1956,15 +1995,17 @@ def show_new_gui():
|
|||
|
||||
makefileentry(network_tab, "SSL Cert:", "Select SSL cert.pem file",ssl_cert_var, 5, width=130 ,filetypes=[("Unencrypted Certificate PEM", "*.pem")], singlerow=True,tooltiptxt="Select your unencrypted .pem SSL certificate file for https.\nCan be generated with OpenSSL.")
|
||||
makefileentry(network_tab, "SSL Key:", "Select SSL key.pem file", ssl_key_var, 7, width=130, filetypes=[("Unencrypted Key PEM", "*.pem")], singlerow=True,tooltiptxt="Select your unencrypted .pem SSL key file for https.\nCan be generated with OpenSSL.")
|
||||
makelabelentry(network_tab, "Password: ", password_var, 8, 150,tooltip="Enter a password required to use this instance.\nThis key will be required for all text endpoints.\nImage endpoints are not secured.")
|
||||
|
||||
# horde
|
||||
makelabel(network_tab, "Horde:", 18,0,"Settings for embedded AI Horde worker").grid(pady=10)
|
||||
# Horde Tab
|
||||
horde_tab = tabcontent["Horde Worker"]
|
||||
makelabel(horde_tab, "Horde:", 18,0,"Settings for embedded AI Horde worker").grid(pady=10)
|
||||
|
||||
horde_name_entry, horde_name_label = makelabelentry(network_tab, "Horde Model Name:", horde_name_var, 20, 180,"The model name to be displayed on the AI Horde.")
|
||||
horde_gen_entry, horde_gen_label = makelabelentry(network_tab, "Gen. Length:", horde_gen_var, 21, 50,"The maximum amount to generate per request \nthat this worker will accept jobs for.")
|
||||
horde_context_entry, horde_context_label = makelabelentry(network_tab, "Max Context:",horde_context_var, 22, 50,"The maximum context length \nthat this worker will accept jobs for.")
|
||||
horde_apikey_entry, horde_apikey_label = makelabelentry(network_tab, "API Key (If Embedded Worker):",horde_apikey_var, 23, 180,"Your AI Horde API Key that you have registered.")
|
||||
horde_workername_entry, horde_workername_label = makelabelentry(network_tab, "Horde Worker Name:",horde_workername_var, 24, 180,"Your worker's name to be displayed.")
|
||||
horde_name_entry, horde_name_label = makelabelentry(horde_tab, "Horde Model Name:", horde_name_var, 20, 180,"The model name to be displayed on the AI Horde.")
|
||||
horde_gen_entry, horde_gen_label = makelabelentry(horde_tab, "Gen. Length:", horde_gen_var, 21, 50,"The maximum amount to generate per request \nthat this worker will accept jobs for.")
|
||||
horde_context_entry, horde_context_label = makelabelentry(horde_tab, "Max Context:",horde_context_var, 22, 50,"The maximum context length \nthat this worker will accept jobs for.")
|
||||
horde_apikey_entry, horde_apikey_label = makelabelentry(horde_tab, "API Key (If Embedded Worker):",horde_apikey_var, 23, 180,"Your AI Horde API Key that you have registered.")
|
||||
horde_workername_entry, horde_workername_label = makelabelentry(horde_tab, "Horde Worker Name:",horde_workername_var, 24, 180,"Your worker's name to be displayed.")
|
||||
|
||||
def togglehorde(a,b,c):
|
||||
labels = [horde_name_label, horde_gen_label, horde_context_label, horde_apikey_label, horde_workername_label]
|
||||
|
@ -1979,7 +2020,7 @@ def show_new_gui():
|
|||
basefile = os.path.basename(model_var.get())
|
||||
horde_name_var.set(sanitize_string(os.path.splitext(basefile)[0]))
|
||||
|
||||
makecheckbox(network_tab, "Configure for Horde", usehorde_var, 19, command=togglehorde,tooltiptxt="Enable the embedded AI Horde worker.")
|
||||
makecheckbox(horde_tab, "Configure for Horde", usehorde_var, 19, command=togglehorde,tooltiptxt="Enable the embedded AI Horde worker.")
|
||||
togglehorde(1,1,1)
|
||||
|
||||
# Image Gen Tab
|
||||
|
@ -2067,6 +2108,7 @@ def show_new_gui():
|
|||
args.mmproj = None if mmproj_var.get() == "" else mmproj_var.get()
|
||||
|
||||
args.ssl = None if (ssl_cert_var.get() == "" or ssl_key_var.get() == "") else ([ssl_cert_var.get(), ssl_key_var.get()])
|
||||
args.password = None if (password_var.get() == "") else (password_var.get())
|
||||
|
||||
args.port_param = defaultport if port_var.get()=="" else int(port_var.get())
|
||||
args.host = host_var.get()
|
||||
|
@ -2188,6 +2230,9 @@ def show_new_gui():
|
|||
ssl_cert_var.set(dict["ssl"][0])
|
||||
ssl_key_var.set(dict["ssl"][1])
|
||||
|
||||
if "password" in dict and dict["password"]:
|
||||
password_var.set(dict["password"])
|
||||
|
||||
if "preloadstory" in dict and dict["preloadstory"]:
|
||||
preloadstory_var.set(dict["preloadstory"])
|
||||
|
||||
|
@ -2634,7 +2679,7 @@ def sanitize_string(input_string):
|
|||
return sanitized_string
|
||||
|
||||
def main(launch_args,start_server=True):
|
||||
global args, friendlymodelname, friendlysdmodelname, fullsdmodelpath, mmprojpath
|
||||
global args, friendlymodelname, friendlysdmodelname, fullsdmodelpath, mmprojpath, password
|
||||
args = launch_args
|
||||
embedded_kailite = None
|
||||
embedded_kcpp_docs = None
|
||||
|
@ -2769,6 +2814,9 @@ def main(launch_args,start_server=True):
|
|||
args.mmproj = os.path.abspath(args.mmproj)
|
||||
mmprojpath = args.mmproj
|
||||
|
||||
if args.password and args.password!="":
|
||||
password = args.password.strip()
|
||||
|
||||
if not args.blasthreads or args.blasthreads <= 0:
|
||||
args.blasthreads = args.threads
|
||||
|
||||
|
@ -3017,5 +3065,6 @@ if __name__ == '__main__':
|
|||
parser.add_argument("--nocertify", help="Allows insecure SSL connections. Use this if you have cert errors and need to bypass certificate restrictions.", action='store_true')
|
||||
parser.add_argument("--sdconfig", help="Specify a stable diffusion safetensors model to enable image generation. If quick is specified, force optimal generation settings for speed.",metavar=('[sd_filename]', '[normal|quick|clamped] [threads] [quant|noquant]'), nargs='+')
|
||||
parser.add_argument("--mmproj", help="Select a multimodal projector file for LLaVA.", default="")
|
||||
parser.add_argument("--password", help="Enter a password required to use this instance. This key will be required for all text endpoints. Image endpoints are not secured.", default=None)
|
||||
|
||||
main(parser.parse_args(),start_server=True)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue