Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-10 17:14:36 +00:00)
Added model and config endpoints for sdcpp, and added more samplers. Speed is still not good.
This commit is contained in:
parent 257015bb94
commit e8f4d7b3da

2 changed files with 86 additions and 21 deletions

koboldcpp.py (84 changes)
@@ -478,15 +478,28 @@ def sd_load_model(model_filename):
     ret = handle.sd_load_model(inputs)
     return ret

-def sd_generate(prompt, negative_prompt="", cfg_scale=5, sample_steps=20, seed=-1, sample_method="euler a"):
+def sd_generate(genparams):
     global maxctx, args, currentusergenkey, totalgens, pendingabortkey
+    prompt = genparams.get("prompt", "high quality")
+    negative_prompt = genparams.get("negative_prompt", "")
+    cfg_scale = genparams.get("cfg_scale", 5)
+    sample_steps = genparams.get("steps", 20)
+    seed = genparams.get("seed", -1)
+    sample_method = genparams.get("sampler_name", "euler a")
+
+    #quick mode
+    if args.sdconfig and len(args.sdconfig)>1 and args.sdconfig[1]=="quick":
+        cfg_scale = 1
+        sample_steps = 7
+        sample_method = "dpm++ 2m karras"
+
     inputs = sd_generation_inputs()
     inputs.prompt = prompt.encode("UTF-8")
     inputs.negative_prompt = negative_prompt.encode("UTF-8")
     inputs.cfg_scale = cfg_scale
     inputs.sample_steps = sample_steps
     inputs.seed = seed
-    inputs.sample_method = sample_method.encode("UTF-8")
+    inputs.sample_method = sample_method.lower().encode("UTF-8")
     ret = handle.sd_generate(inputs)
     outstr = ""
     if ret.status==1:
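For reference, a minimal sketch of the dict the refactored sd_generate now takes; the keys mirror the genparams.get() calls above (all values here are hypothetical examples), and --sdconfig's "quick" mode overrides cfg_scale, steps, and sampler regardless of what the client sent:

genparams = {
    "prompt": "a watercolor fox, highly detailed",  # hypothetical example values
    "negative_prompt": "blurry, low quality",
    "cfg_scale": 5,
    "steps": 20,
    "seed": -1,                 # -1 = random seed
    "sampler_name": "euler a",  # lowercased before being handed to sdcpp
}
# image = sd_generate(genparams)  # inside koboldcpp, the result is wrapped as {"images":[gen], ...}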
@@ -512,7 +525,9 @@ def bring_terminal_to_foreground():
 ### A hacky simple HTTP server simulating a kobold api by Concedo
 ### we are intentionally NOT using flask, because we want MINIMAL dependencies
 #################################################################
-friendlymodelname = "concedo/koboldcpp" # local kobold api apparently needs a hardcoded known HF model name
+friendlymodelname = "inactive"
+friendlysdmodelname = "inactive"
+fullsdmodelpath = "" #if empty, it's not initialized
 maxctx = 2048
 maxhordectx = 2048
 maxhordelen = 256
@@ -860,7 +875,7 @@ Enter Prompt:<br>
         self.wfile.write(finalhtml)

     def do_GET(self):
-        global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens, preloaded_story, exitcounter, currentusergenkey
+        global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens, preloaded_story, exitcounter, currentusergenkey, friendlysdmodelname, fullsdmodelpath
         self.path = self.path.rstrip('/')
         response_body = None
         content_type = 'application/json'
@@ -920,6 +935,11 @@ Enter Prompt:<br>
         elif self.path.endswith('/v1/models'):
             response_body = (json.dumps({"object":"list","data":[{"id":friendlymodelname,"object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())

+        elif self.path.endswith('/sdapi/v1/sd-models'):
+            response_body = (json.dumps([{"title":friendlysdmodelname,"model_name":friendlysdmodelname,"hash":"8888888888","sha256":"8888888888888888888888888888888888888888888888888888888888888888","filename":fullsdmodelpath,"config": None}]).encode())
+        elif self.path.endswith('/sdapi/v1/options'):
+            response_body = (json.dumps({"samples_format":"png","sd_model_checkpoint":friendlysdmodelname}).encode())
+
         elif self.path=="/api":
             content_type = 'text/html'
             if self.embedded_kcpp_docs is None:
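These two GET routes imitate the A1111-style endpoints that SD frontends use for model discovery. A quick way to sanity-check them, assuming a local instance on koboldcpp's default port 5001 (the port is an assumption here):

import json, urllib.request

base = "http://localhost:5001"  # assumed default port
for path in ("/sdapi/v1/sd-models", "/sdapi/v1/options"):
    with urllib.request.urlopen(base + path) as resp:
        print(path, "->", json.loads(resp.read()))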
@@ -1044,6 +1064,7 @@ Enter Prompt:<br>
         sse_stream_flag = False

         api_format = 0 #1=basic,2=kai,3=oai,4=oai-chat
+        is_txt2img = False

         if self.path.endswith('/request'):
             api_format = 1
@@ -1061,7 +1082,10 @@ Enter Prompt:<br>
         if self.path.endswith('/v1/chat/completions'):
             api_format = 4

-        if api_format > 0:
+        if self.path.endswith('/sdapi/v1/txt2img'):
+            is_txt2img = True
+
+        if is_txt2img or api_format > 0:
             genparams = None
             try:
                 genparams = json.loads(body)
@@ -1076,25 +1100,43 @@ Enter Prompt:<br>
             if args.foreground:
                 bring_terminal_to_foreground()

-            # Check if streaming chat completions, if so, set stream mode to true
-            if (api_format == 4 or api_format == 3) and "stream" in genparams and genparams["stream"]:
-                sse_stream_flag = True
+            if api_format > 0:#text gen
+                # Check if streaming chat completions, if so, set stream mode to true
+                if (api_format == 4 or api_format == 3) and "stream" in genparams and genparams["stream"]:
+                    sse_stream_flag = True

-            gen = asyncio.run(self.handle_request(genparams, api_format, sse_stream_flag))
+                gen = asyncio.run(self.handle_request(genparams, api_format, sse_stream_flag))

-            try:
-                # Headers are already sent when streaming
-                if not sse_stream_flag:
-                    self.send_response(200)
-                    genresp = (json.dumps(gen).encode())
-                    self.send_header('content-length', str(len(genresp)))
-                    self.end_headers(content_type='application/json')
-                    self.wfile.write(genresp)
-            except Exception as ex:
-                print("Generate: The response could not be sent, maybe connection was terminated?")
-                handle.abort_generate()
-                time.sleep(0.2) #short delay
-                return
+                try:
+                    # Headers are already sent when streaming
+                    if not sse_stream_flag:
+                        self.send_response(200)
+                        genresp = (json.dumps(gen).encode())
+                        self.send_header('content-length', str(len(genresp)))
+                        self.end_headers(content_type='application/json')
+                        self.wfile.write(genresp)
+                except Exception as ex:
+                    if args.debugmode:
+                        print(ex)
+                    print("Generate: The response could not be sent, maybe connection was terminated?")
+                    handle.abort_generate()
+                    time.sleep(0.2) #short delay
+                    return
+
+            elif is_txt2img: #image gen
+                try:
+                    gen = sd_generate(genparams)
+                    genresp = (json.dumps({"images":[gen],"parameters":{},"info":""}).encode())
+                    self.send_response(200)
+                    self.send_header('content-length', str(len(genresp)))
+                    self.end_headers(content_type='application/json')
+                    self.wfile.write(genresp)
+                except Exception as ex:
+                    if args.debugmode:
+                        print(ex)
+                    print("Generate Image: The response could not be sent, maybe connection was terminated?")
+                    time.sleep(0.2) #short delay
+                    return

         finally:
             modelbusy.release()
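Putting the POST path together: a /sdapi/v1/txt2img request bypasses the text-gen branch, goes through sd_generate, and comes back wrapped in an A1111-style {"images": [...]} body. A sketch of a full round trip, assuming localhost:5001 and that images[0] is a base64-encoded image, as the response shape suggests:

import base64, json, urllib.request

payload = {"prompt": "a lighthouse at dusk", "steps": 20, "sampler_name": "euler a"}
req = urllib.request.Request(
    "http://localhost:5001/sdapi/v1/txt2img",  # port is an assumption
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    result = json.loads(resp.read())
with open("out.png", "wb") as f:
    f.write(base64.b64decode(result["images"][0]))  # assumes base64 PNG output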
@@ -2433,7 +2475,7 @@ def sanitize_string(input_string):
     return sanitized_string

 def main(launch_args,start_server=True):
-    global args, friendlymodelname
+    global args, friendlymodelname, friendlysdmodelname, fullsdmodelpath
     args = launch_args
     embedded_kailite = None
     embedded_kcpp_docs = None
@@ -2583,6 +2625,10 @@ def main(launch_args,start_server=True):
             time.sleep(3)
             sys.exit(2)
         imgmodel = os.path.abspath(imgmodel)
+        fullsdmodelpath = imgmodel
+        friendlysdmodelname = os.path.basename(imgmodel)
+        friendlysdmodelname = os.path.splitext(friendlysdmodelname)[0]
+        friendlysdmodelname = sanitize_string(friendlysdmodelname)
         loadok = sd_load_model(imgmodel)
         print("Load Image Model OK: " + str(loadok))
         if not loadok:
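The display name served by /sdapi/v1/sd-models is derived here from the model path. A small illustration of the basename/splitext chain with a hypothetical path (sanitize_string is koboldcpp's own helper):

import os

imgmodel = "/models/sd_xl_base_1.0.safetensors"  # hypothetical path
name = os.path.splitext(os.path.basename(imgmodel))[0]
print(name)  # -> "sd_xl_base_1.0", which then goes through sanitize_string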
Second changed file (the sdcpp C++ adapter):

@@ -127,7 +127,7 @@ static void sd_logger_callback(enum sd_log_level_t level, const char* log, void*

 bool sdtype_load_model(const sd_load_model_inputs inputs) {

-    printf("\nSelected Image Model: %s\n",inputs.model_filename);
+    printf("\nImage Gen - Load Safetensors Image Model: %s\n",inputs.model_filename);

     sd_params = new SDParams();
     sd_params->model_path = inputs.model_filename;
@@ -187,10 +187,29 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
     sd_params->sample_steps = inputs.sample_steps;
     sd_params->seed = inputs.seed;

-    if(inputs.sample_method=="euler a") //all lowercase
+    printf("\nGenerating Image (%d steps)\n",inputs.sample_steps);
+    std::string sampler = inputs.sample_method;
+
+    if(sampler=="euler a") //all lowercase
     {
         sd_params->sample_method = sample_method_t::EULER_A;
     }
+    else if(sampler=="euler")
+    {
+        sd_params->sample_method = sample_method_t::EULER;
+    }
+    else if(sampler=="heun")
+    {
+        sd_params->sample_method = sample_method_t::HEUN;
+    }
+    else if(sampler=="dpm2")
+    {
+        sd_params->sample_method = sample_method_t::DPM2;
+    }
+    else if(sampler=="dpm++ 2m karras" || sampler=="dpm++ 2m")
+    {
+        sd_params->sample_method = sample_method_t::DPMPP2M;
+    }
     else
     {
         sd_params->sample_method = sample_method_t::EULER_A;
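The C++ dispatch above is a plain string match on the lowercased sampler name, with unknown names falling back to euler a. A Python mirror of the same table, which a client could use to validate sampler_name values before sending a request (the mapping is copied from the diff; the helper itself is hypothetical):

SAMPLER_MAP = {
    "euler a": "EULER_A",
    "euler": "EULER",
    "heun": "HEUN",
    "dpm2": "DPM2",
    "dpm++ 2m": "DPMPP2M",
    "dpm++ 2m karras": "DPMPP2M",
}

def resolve_sampler(name: str) -> str:
    # Unknown samplers fall back to euler a, matching the C++ else branch.
    return SAMPLER_MAP.get(name.lower(), "EULER_A")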