allow ipv6 as well

This commit is contained in:
Concedo 2024-08-04 00:53:19 +08:00
parent 9a0976761e
commit 40481abf0c

View file

@ -37,7 +37,7 @@ password = "" #if empty, no auth key required
fullwhispermodelpath = "" #if empty, it's not initialized
maxctx = 4096
maxhordectx = 4096
maxhordelen = 350
maxhordelen = 400
modelbusy = threading.Lock()
requestsinqueue = 0
defaultport = 5001
@ -1126,13 +1126,13 @@ def transform_genparams(genparams, api_format):
if api_format==1:
genparams["prompt"] = genparams.get('text', "")
genparams["top_k"] = int(genparams.get('top_k', 120))
genparams["max_length"] = genparams.get('max', 150)
genparams["max_length"] = genparams.get('max', 180)
elif api_format==2:
pass
elif api_format==3 or api_format==4:
genparams["max_length"] = genparams.get('max_tokens', (350 if api_format==4 else 150))
genparams["max_length"] = genparams.get('max_tokens', (400 if api_format==4 else 180))
presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
genparams["presence_penalty"] = presence_penalty
# openai allows either a string or a list as a stop sequence
@ -1316,7 +1316,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
memory=genparams.get('memory', ""),
images=genparams.get('images', []),
max_context_length=genparams.get('max_context_length', maxctx),
max_length=genparams.get('max_length', 150),
max_length=genparams.get('max_length', 180),
temperature=genparams.get('temperature', 0.7),
top_k=genparams.get('top_k', 100),
top_a=genparams.get('top_a', 0.0),
@ -1571,7 +1571,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
else:
if max_length>512:
max_length = 512
epurl = f"http://127.0.0.1:{args.port}"
epurl = f"http://localhost:{args.port}"
if args.host!="":
epurl = f"http://{args.host}:{args.port}"
gen_payload = {"prompt": prompt,"max_length": max_length,"temperature": temperature,"prompt": prompt,"top_k": top_k,"top_p": top_p,"rep_pen": rep_pen,"ban_eos_token":ban_eos_token}
@ -2021,28 +2021,47 @@ def is_port_in_use(portNum):
try:
import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('127.0.0.1', portNum)) == 0
return s.connect_ex(('localhost', portNum)) == 0
except Exception as ex:
return True
def is_ipv6_supported():
try:
# Attempt to create an IPv6 socket
sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
sock.close()
return True
except socket.error:
return False
def RunServerMultiThreaded(addr, port):
global exitcounter, sslvalid
global embedded_kailite, embedded_kcpp_docs, embedded_kcpp_sdui
if is_port_in_use(port):
print(f"Warning: Port {port} already appears to be in use by another program.")
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
ipv4_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
ipv4_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
ipv6_sock = None
if is_ipv6_supported():
ipv6_sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
ipv6_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if args.ssl and sslvalid:
import ssl
certpath = os.path.abspath(args.ssl[0])
keypath = os.path.abspath(args.ssl[1])
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.load_cert_chain(certfile=certpath, keyfile=keypath)
sock = context.wrap_socket(sock, server_side=True)
ipv4_sock = context.wrap_socket(ipv4_sock, server_side=True)
if ipv6_sock:
ipv6_sock = context.wrap_socket(ipv6_sock, server_side=True)
sock.bind((addr, port))
numThreads = 20
sock.listen(numThreads)
ipv4_sock.bind((addr, port))
ipv4_sock.listen(numThreads)
if ipv6_sock:
ipv6_sock.bind((addr, port))
ipv6_sock.listen(numThreads)
class Thread(threading.Thread):
def __init__(self, i):
@ -2056,7 +2075,11 @@ def RunServerMultiThreaded(addr, port):
handler = ServerRequestHandler(addr, port)
with http.server.HTTPServer((addr, port), handler, False) as self.httpd:
try:
self.httpd.socket = sock
if ipv6_sock:
self.httpd.socket = ipv4_sock if self.i < 16 else ipv6_sock
else:
self.httpd.socket = ipv4_sock
self.httpd.server_bind = self.server_close = lambda self: None
self.httpd.serve_forever()
except (KeyboardInterrupt,SystemExit):
@ -3266,7 +3289,7 @@ def run_horde_worker(args, api_key, worker_name):
from datetime import datetime
import random
global friendlymodelname, maxhordectx, maxhordelen, exitcounter, punishcounter, modelbusy, session_starttime
epurl = f"http://127.0.0.1:{args.port}"
epurl = f"http://localhost:{args.port}"
if args.host!="":
epurl = f"http://{args.host}:{args.port}"
@ -3475,13 +3498,13 @@ def setuptunnel(has_sd):
time.sleep(0.2)
if os.name == 'nt':
print("Starting Cloudflare Tunnel for Windows, please wait...", flush=True)
tunnelproc = subprocess.Popen(f"cloudflared.exe tunnel --url 127.0.0.1:{args.port}", text=True, encoding='utf-8', shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
tunnelproc = subprocess.Popen(f"cloudflared.exe tunnel --url localhost:{args.port}", text=True, encoding='utf-8', shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
elif sys.platform=="darwin":
print("Starting Cloudflare Tunnel for MacOS, please wait...", flush=True)
tunnelproc = subprocess.Popen(f"./cloudflared tunnel --url http://127.0.0.1:{args.port}", text=True, encoding='utf-8', shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
tunnelproc = subprocess.Popen(f"./cloudflared tunnel --url http://localhost:{args.port}", text=True, encoding='utf-8', shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
else:
print("Starting Cloudflare Tunnel for Linux, please wait...", flush=True)
tunnelproc = subprocess.Popen(f"./cloudflared-linux-amd64 tunnel --url http://127.0.0.1:{args.port}", text=True, encoding='utf-8', shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
tunnelproc = subprocess.Popen(f"./cloudflared-linux-amd64 tunnel --url http://localhost:{args.port}", text=True, encoding='utf-8', shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
time.sleep(10)
def tunnel_reader():
nonlocal tunnelproc,tunneloutput,tunnelrawlog
@ -4080,7 +4103,7 @@ def main(launch_args,start_server=True):
epurl = ""
httpsaffix = ("https" if sslvalid else "http")
if args.host=="":
epurl = f"{httpsaffix}://127.0.0.1:{args.port}"
epurl = f"{httpsaffix}://localhost:{args.port}"
else:
epurl = f"{httpsaffix}://{args.host}:{args.port}"
if not args.remotetunnel: