mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-24 22:04:26 +00:00
rpc implementation is complete
This commit is contained in:
parent
3520b915f9
commit
4bbbd55be6
5 changed files with 191 additions and 20 deletions
|
|
@ -192,6 +192,11 @@ extern "C"
|
|||
}
|
||||
}
|
||||
|
||||
bool launch_rpc_server(const char * endpoint, const char * devices)
|
||||
{
|
||||
return host_rpc_server(endpoint,devices);
|
||||
}
|
||||
|
||||
bool sd_load_model(const sd_load_model_inputs inputs)
|
||||
{
|
||||
return sdtype_load_model(inputs);
|
||||
|
|
|
|||
2
expose.h
2
expose.h
|
|
@ -85,6 +85,8 @@ struct load_model_inputs
|
|||
const bool quiet = false;
|
||||
const int debugmode = 0;
|
||||
const int continuous_batching_slots = 0;
|
||||
const int rpc_mode = 0; //0=disabled, 1=connect, 2=host
|
||||
const char * rpc_targets = nullptr;
|
||||
};
|
||||
struct generation_inputs
|
||||
{
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@
|
|||
#include "tools/mtmd/llava.h"
|
||||
#include "tools/mtmd/mtmd-audio.h"
|
||||
#include "common/common.h"
|
||||
#include "ggml-rpc.h"
|
||||
|
||||
#if defined(GGML_USE_HIP)
|
||||
// for rocblas_initialize()
|
||||
|
|
@ -2158,6 +2159,76 @@ static float CalcGradientAIRopeFreqBase(float original_rope_base, int n_ctx_trai
|
|||
}
|
||||
}
|
||||
|
||||
bool host_rpc_server(std::string endpoint, std::string devices_str)
|
||||
{
|
||||
llama_backend_init();
|
||||
int num_backends = ggml_backend_reg_count();
|
||||
printf("Number of Backends: %d\n",num_backends);
|
||||
for (size_t i = 0; i < num_backends; i++) {
|
||||
auto * reg = ggml_backend_reg_get(i);
|
||||
printf("Backend %d: %s\n", i, ggml_backend_reg_name(reg));
|
||||
}
|
||||
|
||||
ggml_backend_reg_t reg = ggml_backend_reg_by_name("RPC");
|
||||
if (!reg) {
|
||||
fprintf(stderr, "Error: Failed to find RPC backend\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto start_server_fn = (decltype(ggml_backend_rpc_start_server)*) ggml_backend_reg_get_proc_address(reg, "ggml_backend_rpc_start_server");
|
||||
if (!start_server_fn) {
|
||||
fprintf(stderr, "Failed to obtain RPC backend start server function\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<ggml_backend_dev_t> devices;
|
||||
|
||||
if(devices_str!="") //check if devices is overridden
|
||||
{
|
||||
devices = kcpp_parse_device_list(devices_str);
|
||||
// Remove all nullptr elements
|
||||
devices.erase( std::remove(devices.begin(), devices.end(), nullptr), devices.end());
|
||||
}
|
||||
|
||||
//try dGPU first
|
||||
if (devices.empty()) {
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
|
||||
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
||||
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
|
||||
devices.push_back(dev);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if not, find other non-cpu devices
|
||||
if (devices.empty()) {
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
|
||||
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
||||
if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) {
|
||||
devices.push_back(dev);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If there are no accelerators, fallback to CPU device
|
||||
if (devices.empty()) {
|
||||
ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (dev) {
|
||||
devices.push_back(dev);
|
||||
}
|
||||
}
|
||||
printf("\nUsing %d Devices for this RPC server:",devices.size());
|
||||
for(int i=0;i<devices.size();++i)
|
||||
{
|
||||
printf("\n%d: %s",i,ggml_backend_dev_name(devices[i]));
|
||||
}
|
||||
|
||||
printf("\nNote: It's not advised to expose RPC server to the open internet.\n=====\nStarting RPC server on %s, clients may now connect\n=====\n",endpoint.c_str());
|
||||
|
||||
start_server_fn(endpoint.c_str(), "", 4, devices.size(), devices.data());
|
||||
return true;
|
||||
}
|
||||
|
||||
static void connect_rpc_servers(const std::string & servers) {
|
||||
auto rpc_servers = string_split<std::string>(servers, ',');
|
||||
if (rpc_servers.empty()) {
|
||||
|
|
@ -2492,7 +2563,17 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
auto * reg = ggml_backend_reg_get(i);
|
||||
printf("Backend %d: %s\n", i, ggml_backend_reg_name(reg));
|
||||
}
|
||||
// connect_rpc_servers("127.0.0.1:7777");
|
||||
|
||||
if(inputs.rpc_mode==2) //host mode, not supposed to happen
|
||||
{
|
||||
printf("\nShould not reach here, RPC host does not need to load models.\n");
|
||||
return ModelLoadResult::FAIL;
|
||||
}
|
||||
else if(inputs.rpc_mode==1) //connect
|
||||
{
|
||||
std::string servers = inputs.rpc_targets;
|
||||
connect_rpc_servers(servers);
|
||||
}
|
||||
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
llama_context_params llama_ctx_params = llama_context_default_params();
|
||||
|
|
@ -2677,7 +2758,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
llama_ctx_params.type_k = (inputs.quant_k==4?GGML_TYPE_Q4_0:(inputs.quant_k==3?GGML_TYPE_Q5_1:(inputs.quant_k==2?GGML_TYPE_Q8_0:(inputs.quant_k==1?GGML_TYPE_BF16:GGML_TYPE_F16))));
|
||||
llama_ctx_params.type_v = (inputs.quant_v==4?GGML_TYPE_Q4_0:(inputs.quant_v==3?GGML_TYPE_Q5_1:(inputs.quant_v==2?GGML_TYPE_Q8_0:(inputs.quant_v==1?GGML_TYPE_BF16:GGML_TYPE_F16))));
|
||||
|
||||
|
||||
//apply overrides from autofit
|
||||
float tensor_split_temp[128] = {0}; //temp buffer for autofit
|
||||
std::vector<size_t> fit_params_target = std::vector<size_t>(llama_max_devices(),1024*1024*1024);
|
||||
|
|
|
|||
118
koboldcpp.py
118
koboldcpp.py
|
|
@ -176,6 +176,7 @@ last_non_horde_req_time = time.time()
|
|||
currfinishreason = None
|
||||
zenity_recent_dir = os.getcwd()
|
||||
zenity_permitted = True
|
||||
default_rpc_port = 51001
|
||||
thinkformats = [{"start":"<|channel|>analysis<|message|>","end":"<|start|>assistant<|channel|>final<|message|>"},
|
||||
{"start":"<think>","end":"</think>"},
|
||||
{"start":"<seed:think>","end":"</seed:think>"},
|
||||
|
|
@ -296,7 +297,9 @@ class load_model_inputs(ctypes.Structure):
|
|||
("devices_override", ctypes.c_char_p),
|
||||
("quiet", ctypes.c_bool),
|
||||
("debugmode", ctypes.c_int),
|
||||
("continuous_batching_slots", ctypes.c_int)]
|
||||
("continuous_batching_slots", ctypes.c_int),
|
||||
("rpc_mode", ctypes.c_int),
|
||||
("rpc_targets", ctypes.c_char_p)]
|
||||
|
||||
class generation_inputs(ctypes.Structure):
|
||||
_fields_ = [("seed", ctypes.c_int),
|
||||
|
|
@ -980,6 +983,8 @@ def init_library():
|
|||
handle.last_logprobs.restype = last_logprobs_outputs
|
||||
handle.detokenize.argtypes = [detokenize_inputs]
|
||||
handle.detokenize.restype = ctypes.c_char_p
|
||||
handle.launch_rpc_server.argtypes = [ctypes.c_char_p, ctypes.c_char_p]
|
||||
handle.launch_rpc_server.restype = ctypes.c_bool
|
||||
handle.set_environment_variable.restype = ctypes.c_int
|
||||
handle.set_environment_variable.argtypes = [ctypes.c_char_p, ctypes.c_char_p]
|
||||
|
||||
|
|
@ -1222,7 +1227,7 @@ def old_cpu_check(): #return -1 for pass, 0 if has avx2, 1 if has avx, 2 if has
|
|||
return -1 #cannot determine
|
||||
|
||||
def has_valid_model():
|
||||
return args.model_param or args.sdmodel or args.whispermodel or args.ttsmodel or args.embeddingsmodel or args.musicdiffusion or args.musicllm or args.mcpfile or args.nomodel
|
||||
return args.model_param or args.sdmodel or args.whispermodel or args.ttsmodel or args.embeddingsmodel or args.musicdiffusion or args.musicllm or args.mcpfile or args.nomodel or args.rpcmode=="host"
|
||||
|
||||
def unpack_to_dir(destpath = ""):
|
||||
srcpath = os.path.abspath(os.path.dirname(__file__))
|
||||
|
|
@ -2021,6 +2026,9 @@ def load_model(model_filename):
|
|||
inputs.smartcacheslots = sclimit
|
||||
inputs.pipelineparallel = (not args.nopipelineparallel)
|
||||
inputs.continuous_batching_slots = int(args.continuous_batching) if hasattr(args, "continuous_batching") else 0
|
||||
inputs.rpc_mode = (2 if args.rpcmode=="host" else (1 if args.rpcmode=="connect" else 0))
|
||||
inputs.rpc_targets = (args.rpctargets if args.rpcmode=="connect" else "").encode("UTF-8")
|
||||
|
||||
inputs = set_backend_props(inputs)
|
||||
ret = handle.load_model(inputs)
|
||||
return ret
|
||||
|
|
@ -7651,6 +7659,11 @@ def show_gui():
|
|||
maxrequestsize_var = ctk.StringVar(value=str(32))
|
||||
ratelimit_var = ctk.StringVar(value=str(0))
|
||||
reqtimeout_var = ctk.StringVar(value=str(default_reqtimeout))
|
||||
rpcmode_var = ctk.StringVar(value=str("disabled"))
|
||||
rpcendpoints_var = ctk.StringVar(value="")
|
||||
rpc_host_var = ctk.StringVar(value="0.0.0.0")
|
||||
rpc_port_var = ctk.StringVar(value=str(default_rpc_port))
|
||||
rpc_device_var = ctk.StringVar()
|
||||
|
||||
sd_model_var = ctk.StringVar()
|
||||
sd_lora_var = ctk.StringVar()
|
||||
|
|
@ -8450,23 +8463,61 @@ def show_gui():
|
|||
network_tab = tabcontent["Network"]
|
||||
|
||||
# interfaces
|
||||
makelabelentry(network_tab, "Port: ", port_var, 1, 150,tooltip=f"Select the port to host the KoboldCPP webserver.\n(Defaults to {defaultport})")
|
||||
makelabelentry(network_tab, "Host: ", host_var, 2, 150,tooltip="Select a specific host interface to bind to.\n(Defaults to all)")
|
||||
makelabelentry(network_tab, "Host: ", host_var, row=1, width=150, padx=(50), singleline=True,tooltip="Select a specific host interface to bind to.\n(Defaults to all)")
|
||||
makelabelentry(network_tab, "Port: ", port_var, row=1, width=100, padx=(254), singleline=True,tooltip=f"Select the port to host the KoboldCPP webserver.\n(Defaults to {defaultport})",labelpadx=220)
|
||||
|
||||
makecheckbox(network_tab, "Remote Tunnel", remotetunnel_var, 3,tooltiptxt="Creates a trycloudflare tunnel.\nAllows you to access koboldcpp from other devices over an internet URL.")
|
||||
makecheckbox(network_tab, "Quiet Mode", quietmode, 4,tooltiptxt="Prevents all generation related terminal output from being displayed.")
|
||||
makecheckbox(network_tab, "NoCertify Mode (Insecure)", nocertifymode, 4, 1,tooltiptxt="Allows insecure SSL connections. Use this if you have cert errors and need to bypass certificate restrictions.")
|
||||
makecheckbox(network_tab, "Shared Multiplayer", multiplayer_var, 5,tooltiptxt="Hosts a shared multiplayer session that others can join.")
|
||||
makecheckbox(network_tab, "Enable WebSearch", websearch_var, 5, 1,tooltiptxt="Enable the local search engine proxy so Web Searches can be done.")
|
||||
makecheckbox(network_tab, "Remote Tunnel", remotetunnel_var, 11,tooltiptxt="Creates a trycloudflare tunnel.\nAllows you to access koboldcpp from other devices over an internet URL.")
|
||||
makecheckbox(network_tab, "Quiet Mode", quietmode, 12,tooltiptxt="Prevents all generation related terminal output from being displayed.")
|
||||
makecheckbox(network_tab, "NoCertify Mode (Insecure)", nocertifymode, 12,padx=(200),tooltiptxt="Allows insecure SSL connections. Use this if you have cert errors and need to bypass certificate restrictions.")
|
||||
makecheckbox(network_tab, "Shared Multiplayer", multiplayer_var, 13,tooltiptxt="Hosts a shared multiplayer session that others can join.")
|
||||
makecheckbox(network_tab, "Enable WebSearch", websearch_var, 13,padx=(200),tooltiptxt="Enable the local search engine proxy so Web Searches can be done.")
|
||||
|
||||
makefileentry(network_tab, "SSL Cert:", "Select SSL cert.pem file",ssl_cert_var, 7, width=200 ,filetypes=[("Unencrypted Certificate PEM", "*.pem")], singlerow=True, singlecol=False,tooltiptxt="Select your unencrypted .pem SSL certificate file for https.\nCan be generated with OpenSSL.")
|
||||
makefileentry(network_tab, "SSL Key:", "Select SSL key.pem file", ssl_key_var, 9, width=200, filetypes=[("Unencrypted Key PEM", "*.pem")], singlerow=True, singlecol=False, tooltiptxt="Select your unencrypted .pem SSL key file for https.\nCan be generated with OpenSSL.")
|
||||
makelabelentry(network_tab, "Password: ", password_var, 10, 200,tooltip="Enter a password required to use this instance.\nThis key will be required for all text endpoints.\nImage endpoints are not secured.")
|
||||
makefileentry(network_tab, "SSL Cert:", "Select SSL cert.pem file",ssl_cert_var, 20, width=200,filetypes=[("Unencrypted Certificate PEM", "*.pem")], singlerow=True,tooltiptxt="Select your unencrypted .pem SSL certificate file for https.\nCan be generated with OpenSSL.")
|
||||
makefileentry(network_tab, "SSL Key:", "Select SSL key.pem file", ssl_key_var, 22, width=200, filetypes=[("Unencrypted Key PEM", "*.pem")], singlerow=True, tooltiptxt="Select your unencrypted .pem SSL key file for https.\nCan be generated with OpenSSL.")
|
||||
makelabelentry(network_tab, "Password: ", password_var, 24, 200, padx=(100), singleline=True, tooltip="Enter a password required to use this instance.\nThis key will be required for all text endpoints.\nImage endpoints are not secured.")
|
||||
|
||||
makelabelentry(network_tab, "Multiuser Queue:", multiuser_var, row=20, width=50, tooltip="Maximum queued incoming requests.")
|
||||
makelabelentry(network_tab, "Max Req. Size (MB):", maxrequestsize_var, row=22, width=50, tooltip="Specify a max request payload size. Any requests to the server larger than this size will be dropped. Do not change if unsure.")
|
||||
makelabelentry(network_tab, "IP Rate Limiter (s):", ratelimit_var, row=24, width=50, tooltip="Rate limits each IP to allow a new request once per X seconds. Do not change if unsure.")
|
||||
makelabelentry(network_tab, "Request Timeout (s):", reqtimeout_var, row=26, width=50, tooltip="Timeout in seconds for HTTP requests")
|
||||
makelabelentry(network_tab, "Multiuser Queue:", multiuser_var, row=30, width=50, padx=(120), singleline=True, tooltip="Maximum queued incoming requests.")
|
||||
makelabelentry(network_tab, "Max Req. Size (MB):", maxrequestsize_var, row=30, width=50, padx=(340), singleline=True, tooltip="Specify a max request payload size. Any requests to the server larger than this size will be dropped. Do not change if unsure.",labelpadx=210)
|
||||
makelabelentry(network_tab, "IP Rate Limiter (s):", ratelimit_var, row=34, width=50, padx=(120), singleline=True, tooltip="Rate limits each IP to allow a new request once per X seconds. Do not change if unsure.")
|
||||
makelabelentry(network_tab, "Request Timeout (s):", reqtimeout_var, row=34, width=50, padx=(340), singleline=True, tooltip="Timeout in seconds for HTTP requests",labelpadx=210)
|
||||
|
||||
def togglerpcmode(a,b,c):
    """Show or hide the RPC widgets on the Network tab to match the selected RPC mode.

    The three arguments are unused; they exist so the function can be registered as
    the combobox command and also be invoked directly as ``togglerpcmode(1,1,1)``.

    - "connect": only the endpoint entry and its label are shown.
    - "host": only the host, port and device entries and their labels are shown.
    - anything else: every RPC entry is hidden.

    The description label text is updated to explain the active mode.
    """
    mode = rpcmode_var.get()

    # Widgets grouped by the single mode in which they should be visible.
    widgets_by_mode = {
        "connect": (rpcepbox, rpceplbl),
        "host": (rpchostbox, rpchostlbl, rpcportbox, rpcportlbl, rpcdevicebox, rpcdevicelbl),
    }
    descriptions = {
        "connect": "Connect to one or more remote RPC endpoints (ip:port), comma separated",
        "host": "Run a RPC server that others can connect to. Disables local models and API.",
    }

    for owner_mode, widgets in widgets_by_mode.items():
        for widget in widgets:
            if owner_mode == mode:
                widget.grid()  # restore the widget using its remembered grid options
            else:
                widget.grid_remove()

    rpcdesc.configure(text=descriptions.get(mode, "RPC is disabled and will not be used"))
|
||||
makelabelcombobox(network_tab, "RPC Mode:", rpcmode_var, row=40, padx=(100), width=90, command=togglerpcmode, tooltiptxt="RPC functionality to connect to remote RPC instances or host one, allowing GPUs to be shared over a network.", values=["disabled","connect","host"])
|
||||
rpcdesc = makelabel(network_tab, "RPC is disabled and will not be used", row=42)
|
||||
rpcepbox, rpceplbl = makelabelentry(network_tab, "RPC Endpoints: ", rpcendpoints_var, row=44, padx=(100), width=200, singleline=True,tooltip="Specify a comma separated list of remote RPC endpoints to connect to e.g. 127.0.0.1:51001,127.0.0.1:51002")
|
||||
rpchostbox, rpchostlbl = makelabelentry(network_tab, "RPC Host IP: ", rpc_host_var, row=46, width=150, padx=(100), singleline=True,tooltip="IP address for RPC server to listen on. Use 0.0.0.0 for all interfaces, 127.0.0.1 for localhost only.")
|
||||
rpcportbox, rpcportlbl = makelabelentry(network_tab, "RPC Host Port: ", rpc_port_var, row=46, width=100, padx=(360), singleline=True,tooltip=f"Port for RPC server to listen on. (default:{default_rpc_port})",labelpadx=260)
|
||||
rpcdevicebox, rpcdevicelbl = makelabelentry(network_tab, "RPC Devices: ", rpc_device_var, row=48, padx=(100), width=200, singleline=True,tooltip="Set specific devices to use for RPC server. Comma separated. Overrides normal RPC device choices.")
|
||||
|
||||
|
||||
# Horde Tab
|
||||
|
|
@ -8688,10 +8739,11 @@ def show_gui():
|
|||
togglejinja(1,1,1)
|
||||
toggleadmin(1,1,1)
|
||||
updatejinjathinktoggle(1,1,1)
|
||||
togglerpcmode(1,1,1)
|
||||
|
||||
# launch
|
||||
def guilaunch():
|
||||
if model_var.get() == "" and sd_model_var.get() == "" and whisper_model_var.get() == "" and tts_model_var.get() == "" and embeddings_model_var.get() == "" and musicdiffusion_var.get() == "" and musicllm_var.get() == "" and nomodel.get()!=1:
|
||||
if model_var.get() == "" and sd_model_var.get() == "" and whisper_model_var.get() == "" and tts_model_var.get() == "" and embeddings_model_var.get() == "" and musicdiffusion_var.get() == "" and musicllm_var.get() == "" and nomodel.get()!=1 and rpcmode_var.get()!="host":
|
||||
# prevent launch without at least one valid model
|
||||
givehelp = show_gui_yesnobox("No Models Selected","Error: You need to load at least one AI model to continue.\n\nDo you want help finding a model?")
|
||||
if givehelp == 'yes':
|
||||
|
|
@ -8864,6 +8916,12 @@ def show_gui():
|
|||
if not args.reqtimeout:
|
||||
args.reqtimeout = default_reqtimeout
|
||||
|
||||
args.rpcmode = rpcmode_var.get() if rpcmode_var.get() else "disabled"
|
||||
args.rpchost = rpc_host_var.get() if (args.rpcmode=="host" and rpc_host_var.get()) else "0.0.0.0"
|
||||
args.rpcport = int(rpc_port_var.get()) if (args.rpcmode=="host" and rpc_port_var.get()) else default_rpc_port
|
||||
args.rpctargets = rpcendpoints_var.get() if (args.rpcmode=="connect" and rpcendpoints_var.get()) else ""
|
||||
args.rpcdevice = rpc_device_var.get() if (args.rpcmode=="host" and rpc_device_var.get()) else ""
|
||||
|
||||
if usehorde_var.get() != 0:
|
||||
args.hordemodelname = horde_name_var.get()
|
||||
args.hordegenlen = int(horde_gen_var.get())
|
||||
|
|
@ -9164,6 +9222,11 @@ def show_gui():
|
|||
maxrequestsize_var.set(mydict["maxrequestsize"] if ("maxrequestsize" in mydict and mydict["maxrequestsize"]) else 32)
|
||||
ratelimit_var.set(mydict["ratelimit"] if ("ratelimit" in mydict and mydict["ratelimit"]) else 0)
|
||||
reqtimeout_var.set(mydict["reqtimeout"] if ("reqtimeout" in mydict and mydict["reqtimeout"]) else 0)
|
||||
rpcmode_var.set(mydict["rpcmode"] if ("rpcmode" in mydict and mydict["rpcmode"]) else "disabled")
|
||||
rpc_host_var.set(mydict["rpchost"] if ("rpchost" in mydict and mydict["rpchost"]) else "0.0.0.0")
|
||||
rpc_port_var.set(mydict["rpcport"] if ("rpcport" in mydict and mydict["rpcport"]) else str(default_rpc_port))
|
||||
rpcendpoints_var.set(mydict["rpctargets"] if ("rpctargets" in mydict and mydict["rpctargets"]) else "")
|
||||
rpc_device_var.set(mydict["rpcdevice"] if ("rpcdevice" in mydict and mydict["rpcdevice"]) else "")
|
||||
|
||||
sd_model_var.set(mydict["sdmodel"] if ("sdmodel" in mydict and mydict["sdmodel"]) else "")
|
||||
sd_clamped_var.set(int(mydict["sdclamped"]) if ("sdclamped" in mydict and mydict["sdclamped"]) else 0)
|
||||
|
|
@ -9782,7 +9845,7 @@ def reload_from_new_args(newargs):
|
|||
args.istemplate = False
|
||||
newargs = convert_invalid_args(newargs)
|
||||
for key, value in newargs.items(): #do not overwrite certain values
|
||||
if key not in ["remotetunnel","showgui","port","host","port_param","admin","adminpassword","password","adminunloadtimeout","routermode","admindir","ssl","nocertify","benchmark","prompt","config","baseconfig","downloaddir","onready"]:
|
||||
if key not in ["remotetunnel","showgui","port","host","port_param","admin","adminpassword","password","adminunloadtimeout","routermode","admindir","ssl","nocertify","benchmark","prompt","config","baseconfig","downloaddir","onready","rpcmode","rpchost","rpcport","rpcdevice"]:
|
||||
setattr(args, key, value)
|
||||
setattr(args,"showgui",False)
|
||||
setattr(args,"benchmark",False)
|
||||
|
|
@ -10191,6 +10254,8 @@ def main(launch_args, default_args):
|
|||
#prevent disallowed combos
|
||||
if (args.nomodel or args.benchmark or args.launch or args.admin) and args.cli:
|
||||
exit_with_error(1, "Error: --cli cannot be combined with --launch, --nomodel, --admin or --benchmark")
|
||||
if (args.rpcmode!="connect" and args.rpctargets):
|
||||
exit_with_error(1, "Error: rpctargets can only be used in connect mode")
|
||||
|
||||
args = convert_invalid_args(args)
|
||||
|
||||
|
|
@ -10974,6 +11039,15 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False):
|
|||
print(args)
|
||||
print("==========")
|
||||
|
||||
#if RPC Host mode is specified
|
||||
if args.rpcmode=="host":
|
||||
rpc_endpt = f"{args.rpchost}:{args.rpcport}"
|
||||
rpc_devices = args.rpcdevice
|
||||
print(f"Initialize RPC server as host mode at {rpc_endpt}.")
|
||||
handle.launch_rpc_server(rpc_endpt.encode("UTF-8"),rpc_devices.encode("UTF-8"))
|
||||
time.sleep(1)
|
||||
return
|
||||
|
||||
#handle loading text model
|
||||
if args.model_param:
|
||||
if not os.path.exists(args.model_param):
|
||||
|
|
@ -11730,6 +11804,14 @@ if __name__ == '__main__':
|
|||
embeddingsparsergroup.add_argument("--embeddingsmaxctx", metavar=('[amount]'), help="Overrides the default maximum supported context of an embeddings model (defaults to trained context).", type=int, default=0)
|
||||
embeddingsparsergroup.add_argument("--embeddingsgpu", help="Attempts to offload layers of the embeddings model to GPU. Usually not needed.", action='store_true')
|
||||
|
||||
rpcgroup = parser.add_argument_group('RPC Commands')
|
||||
rpcgroup.add_argument("--rpcmode", help="RPC allows GPUs to be shared over the network. connect=access a remote GPU, host=share your GPU",metavar=('[disabled/connect/host]'), type=str, choices=["disabled","connect","host"], default="disabled")
|
||||
rpcgroupA = rpcgroup.add_mutually_exclusive_group()
|
||||
rpcgroupA.add_argument("--rpcport", metavar=('[portnumber]'), help=f"RPC host mode only. Port for RPC server to listen on (default: {default_rpc_port}).", default=default_rpc_port, type=int)
|
||||
rpcgroup.add_argument("--rpchost", metavar=('[IP address]'), help="RPC host mode only. IP address for RPC server to listen on. Use 0.0.0.0 for all interfaces, 127.0.0.1 for localhost only.", default="0.0.0.0", type=str)
|
||||
rpcgroup.add_argument("--rpcdevice", metavar=('<dev1,dev2,..>'), help="RPC host mode only. Set specific devices to use for RPC server. Comma separated. Overrides normal RPC device choices.", default="")
|
||||
rpcgroupA.add_argument("--rpctargets", metavar=('[remotehost1:port1,remotehost2:port2]'), help="RPC connect mode only. Specify a comma separated list of remote RPC endpoints to connect to e.g. 127.0.0.1:51001,127.0.0.1:51002", default="", type=str)
|
||||
|
||||
admingroup = parser.add_argument_group('Administration Commands')
|
||||
admingroup.add_argument("--admin", help="Enables admin mode, allowing you to unload and reload different configurations or models.", action='store_true')
|
||||
admingroup.add_argument("--adminpassword", metavar=('[password]'), help="Require a password to access admin functions. You are strongly advised to use one for publically accessible instances! Can also be set with env var KCPP_ADMINPASSWORD", default=os.getenv('KCPP_ADMINPASSWORD',None))
|
||||
|
|
|
|||
|
|
@ -99,6 +99,8 @@ std::vector<int> gpttype_get_token_arr(const std::string & input, bool addbos);
|
|||
std::string gpttype_detokenize(const std::vector<int> & input, bool render_special);
|
||||
const std::vector<TopPicksData> gpttype_get_top_picks_data();
|
||||
|
||||
bool host_rpc_server(std::string endpoint, std::string devices);
|
||||
|
||||
bool sdtype_load_model(const sd_load_model_inputs inputs);
|
||||
sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs);
|
||||
sd_generation_outputs sdtype_upscale(const sd_upscale_inputs inputs);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue