mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
allow ipv6 as well
This commit is contained in:
parent
9a0976761e
commit
40481abf0c
1 changed file with 40 additions and 17 deletions
57
koboldcpp.py
57
koboldcpp.py
|
@ -37,7 +37,7 @@ password = "" #if empty, no auth key required
|
|||
fullwhispermodelpath = "" #if empty, it's not initialized
|
||||
maxctx = 4096
|
||||
maxhordectx = 4096
|
||||
maxhordelen = 350
|
||||
maxhordelen = 400
|
||||
modelbusy = threading.Lock()
|
||||
requestsinqueue = 0
|
||||
defaultport = 5001
|
||||
|
@ -1126,13 +1126,13 @@ def transform_genparams(genparams, api_format):
|
|||
if api_format==1:
|
||||
genparams["prompt"] = genparams.get('text', "")
|
||||
genparams["top_k"] = int(genparams.get('top_k', 120))
|
||||
genparams["max_length"] = genparams.get('max', 150)
|
||||
genparams["max_length"] = genparams.get('max', 180)
|
||||
|
||||
elif api_format==2:
|
||||
pass
|
||||
|
||||
elif api_format==3 or api_format==4:
|
||||
genparams["max_length"] = genparams.get('max_tokens', (350 if api_format==4 else 150))
|
||||
genparams["max_length"] = genparams.get('max_tokens', (400 if api_format==4 else 180))
|
||||
presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
|
||||
genparams["presence_penalty"] = presence_penalty
|
||||
# openai allows either a string or a list as a stop sequence
|
||||
|
@ -1316,7 +1316,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
memory=genparams.get('memory', ""),
|
||||
images=genparams.get('images', []),
|
||||
max_context_length=genparams.get('max_context_length', maxctx),
|
||||
max_length=genparams.get('max_length', 150),
|
||||
max_length=genparams.get('max_length', 180),
|
||||
temperature=genparams.get('temperature', 0.7),
|
||||
top_k=genparams.get('top_k', 100),
|
||||
top_a=genparams.get('top_a', 0.0),
|
||||
|
@ -1571,7 +1571,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
else:
|
||||
if max_length>512:
|
||||
max_length = 512
|
||||
epurl = f"http://127.0.0.1:{args.port}"
|
||||
epurl = f"http://localhost:{args.port}"
|
||||
if args.host!="":
|
||||
epurl = f"http://{args.host}:{args.port}"
|
||||
gen_payload = {"prompt": prompt,"max_length": max_length,"temperature": temperature,"prompt": prompt,"top_k": top_k,"top_p": top_p,"rep_pen": rep_pen,"ban_eos_token":ban_eos_token}
|
||||
|
@ -2021,28 +2021,47 @@ def is_port_in_use(portNum):
|
|||
try:
|
||||
import socket
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
return s.connect_ex(('127.0.0.1', portNum)) == 0
|
||||
return s.connect_ex(('localhost', portNum)) == 0
|
||||
except Exception as ex:
|
||||
return True
|
||||
|
||||
def is_ipv6_supported():
    """Report whether an IPv6 TCP socket can be created on this host.

    Returns:
        bool: True if an ``AF_INET6`` stream socket could be opened; False if
        the interpreter reports no IPv6 support or the OS refuses the socket.
    """
    # Imported locally, as is_port_in_use does, so this helper does not depend
    # on a module-level import being present.
    import socket

    # socket.has_ipv6 is False when Python itself lacks IPv6 support; on such
    # builds AF_INET6 may be missing and would raise AttributeError instead of
    # socket.error, so bail out before touching it.
    if not getattr(socket, "has_ipv6", False):
        return False
    try:
        # Probe only: the context manager closes the socket on every path.
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM):
            return True
    except (socket.error, AttributeError):
        return False
|
||||
|
||||
def RunServerMultiThreaded(addr, port):
|
||||
global exitcounter, sslvalid
|
||||
global embedded_kailite, embedded_kcpp_docs, embedded_kcpp_sdui
|
||||
if is_port_in_use(port):
|
||||
print(f"Warning: Port {port} already appears to be in use by another program.")
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
ipv4_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
ipv4_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
ipv6_sock = None
|
||||
if is_ipv6_supported():
|
||||
ipv6_sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
|
||||
ipv6_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
|
||||
if args.ssl and sslvalid:
|
||||
import ssl
|
||||
certpath = os.path.abspath(args.ssl[0])
|
||||
keypath = os.path.abspath(args.ssl[1])
|
||||
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
context.load_cert_chain(certfile=certpath, keyfile=keypath)
|
||||
sock = context.wrap_socket(sock, server_side=True)
|
||||
ipv4_sock = context.wrap_socket(ipv4_sock, server_side=True)
|
||||
if ipv6_sock:
|
||||
ipv6_sock = context.wrap_socket(ipv6_sock, server_side=True)
|
||||
|
||||
sock.bind((addr, port))
|
||||
numThreads = 20
|
||||
sock.listen(numThreads)
|
||||
ipv4_sock.bind((addr, port))
|
||||
ipv4_sock.listen(numThreads)
|
||||
if ipv6_sock:
|
||||
ipv6_sock.bind((addr, port))
|
||||
ipv6_sock.listen(numThreads)
|
||||
|
||||
class Thread(threading.Thread):
|
||||
def __init__(self, i):
|
||||
|
@ -2056,7 +2075,11 @@ def RunServerMultiThreaded(addr, port):
|
|||
handler = ServerRequestHandler(addr, port)
|
||||
with http.server.HTTPServer((addr, port), handler, False) as self.httpd:
|
||||
try:
|
||||
self.httpd.socket = sock
|
||||
if ipv6_sock:
|
||||
self.httpd.socket = ipv4_sock if self.i < 16 else ipv6_sock
|
||||
else:
|
||||
self.httpd.socket = ipv4_sock
|
||||
|
||||
self.httpd.server_bind = self.server_close = lambda self: None
|
||||
self.httpd.serve_forever()
|
||||
except (KeyboardInterrupt,SystemExit):
|
||||
|
@ -3266,7 +3289,7 @@ def run_horde_worker(args, api_key, worker_name):
|
|||
from datetime import datetime
|
||||
import random
|
||||
global friendlymodelname, maxhordectx, maxhordelen, exitcounter, punishcounter, modelbusy, session_starttime
|
||||
epurl = f"http://127.0.0.1:{args.port}"
|
||||
epurl = f"http://localhost:{args.port}"
|
||||
if args.host!="":
|
||||
epurl = f"http://{args.host}:{args.port}"
|
||||
|
||||
|
@ -3475,13 +3498,13 @@ def setuptunnel(has_sd):
|
|||
time.sleep(0.2)
|
||||
if os.name == 'nt':
|
||||
print("Starting Cloudflare Tunnel for Windows, please wait...", flush=True)
|
||||
tunnelproc = subprocess.Popen(f"cloudflared.exe tunnel --url 127.0.0.1:{args.port}", text=True, encoding='utf-8', shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
|
||||
tunnelproc = subprocess.Popen(f"cloudflared.exe tunnel --url localhost:{args.port}", text=True, encoding='utf-8', shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
|
||||
elif sys.platform=="darwin":
|
||||
print("Starting Cloudflare Tunnel for MacOS, please wait...", flush=True)
|
||||
tunnelproc = subprocess.Popen(f"./cloudflared tunnel --url http://127.0.0.1:{args.port}", text=True, encoding='utf-8', shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
|
||||
tunnelproc = subprocess.Popen(f"./cloudflared tunnel --url http://localhost:{args.port}", text=True, encoding='utf-8', shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
|
||||
else:
|
||||
print("Starting Cloudflare Tunnel for Linux, please wait...", flush=True)
|
||||
tunnelproc = subprocess.Popen(f"./cloudflared-linux-amd64 tunnel --url http://127.0.0.1:{args.port}", text=True, encoding='utf-8', shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
|
||||
tunnelproc = subprocess.Popen(f"./cloudflared-linux-amd64 tunnel --url http://localhost:{args.port}", text=True, encoding='utf-8', shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
|
||||
time.sleep(10)
|
||||
def tunnel_reader():
|
||||
nonlocal tunnelproc,tunneloutput,tunnelrawlog
|
||||
|
@ -4080,7 +4103,7 @@ def main(launch_args,start_server=True):
|
|||
epurl = ""
|
||||
httpsaffix = ("https" if sslvalid else "http")
|
||||
if args.host=="":
|
||||
epurl = f"{httpsaffix}://127.0.0.1:{args.port}"
|
||||
epurl = f"{httpsaffix}://localhost:{args.port}"
|
||||
else:
|
||||
epurl = f"{httpsaffix}://{args.host}:{args.port}"
|
||||
if not args.remotetunnel:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue