added model and config endpoints for sdcpp, added more samplers. speed is still not good

Concedo 2024-02-29 22:56:09 +08:00
parent 257015bb94
commit e8f4d7b3da
2 changed files with 86 additions and 21 deletions
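
For orientation, the commit wires three A1111-style routes into the embedded server: GET /sdapi/v1/sd-models, GET /sdapi/v1/options, and POST /sdapi/v1/txt2img. A minimal client sketch follows — untested, assuming a local instance on koboldcpp's default port 5001 with an image model loaded, and that images come back base64-encoded as in the A1111 API these endpoints emulate:

    import base64, json, urllib.request

    BASE = "http://localhost:5001"  # assumption: default koboldcpp port

    def get_json(path):
        with urllib.request.urlopen(BASE + path) as r:
            return json.load(r)

    print(get_json("/sdapi/v1/sd-models"))  # reports the loaded checkpoint
    print(get_json("/sdapi/v1/options"))    # samples_format + checkpoint name

    payload = json.dumps({
        "prompt": "a photo of a cat",
        "negative_prompt": "blurry",
        "cfg_scale": 5,
        "steps": 20,
        "seed": -1,
        "sampler_name": "Euler A",  # lowercased server-side before dispatch
    }).encode("utf-8")
    req = urllib.request.Request(BASE + "/sdapi/v1/txt2img", data=payload,
                                 headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req) as r:
        img_b64 = json.load(r)["images"][0]
    with open("out.png", "wb") as f:
        f.write(base64.b64decode(img_b64))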

@@ -478,15 +478,28 @@ def sd_load_model(model_filename):
     ret = handle.sd_load_model(inputs)
     return ret

-def sd_generate(prompt, negative_prompt="", cfg_scale=5, sample_steps=20, seed=-1, sample_method="euler a"):
+def sd_generate(genparams):
     global maxctx, args, currentusergenkey, totalgens, pendingabortkey
+    prompt = genparams.get("prompt", "high quality")
+    negative_prompt = genparams.get("negative_prompt", "")
+    cfg_scale = genparams.get("cfg_scale", 5)
+    sample_steps = genparams.get("steps", 20)
+    seed = genparams.get("seed", -1)
+    sample_method = genparams.get("sampler_name", "euler a")
+
+    #quick mode
+    if args.sdconfig and len(args.sdconfig)>1 and args.sdconfig[1]=="quick":
+        cfg_scale = 1
+        sample_steps = 7
+        sample_method = "dpm++ 2m karras"
+
     inputs = sd_generation_inputs()
     inputs.prompt = prompt.encode("UTF-8")
     inputs.negative_prompt = negative_prompt.encode("UTF-8")
     inputs.cfg_scale = cfg_scale
     inputs.sample_steps = sample_steps
     inputs.seed = seed
-    inputs.sample_method = sample_method.encode("UTF-8")
+    inputs.sample_method = sample_method.lower().encode("UTF-8")
     ret = handle.sd_generate(inputs)
     outstr = ""
     if ret.status==1:
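
Worth noting in the hunk above: every field is optional, with the defaults visible in the .get() calls, and a second --sdconfig argument of "quick" overrides whatever the request asked for. A toy trace of that precedence, with values copied from the diff:

    genparams = {"prompt": "a red fox", "steps": 50}

    sample_steps = genparams.get("steps", 20)   # 50: request value wins
    cfg_scale = genparams.get("cfg_scale", 5)   # 5: default used
    quick_mode = True   # i.e. launched with: --sdconfig model.safetensors quick
    if quick_mode:
        cfg_scale, sample_steps, sample_method = 1, 7, "dpm++ 2m karras"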
@@ -512,7 +525,9 @@ def bring_terminal_to_foreground():
 ### A hacky simple HTTP server simulating a kobold api by Concedo
 ### we are intentionally NOT using flask, because we want MINIMAL dependencies
 #################################################################
-friendlymodelname = "concedo/koboldcpp" # local kobold api apparently needs a hardcoded known HF model name
+friendlymodelname = "inactive"
+friendlysdmodelname = "inactive"
+fullsdmodelpath = "" #if empty, it's not initialized
 maxctx = 2048
 maxhordectx = 2048
 maxhordelen = 256
@@ -860,7 +875,7 @@ Enter Prompt:<br>
         self.wfile.write(finalhtml)

     def do_GET(self):
-        global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens, preloaded_story, exitcounter, currentusergenkey
+        global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens, preloaded_story, exitcounter, currentusergenkey, friendlysdmodelname, fullsdmodelpath
         self.path = self.path.rstrip('/')
         response_body = None
         content_type = 'application/json'
@@ -920,6 +935,11 @@ Enter Prompt:<br>
         elif self.path.endswith('/v1/models'):
             response_body = (json.dumps({"object":"list","data":[{"id":friendlymodelname,"object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
+        elif self.path.endswith('/sdapi/v1/sd-models'):
+            response_body = (json.dumps([{"title":friendlysdmodelname,"model_name":friendlysdmodelname,"hash":"8888888888","sha256":"8888888888888888888888888888888888888888888888888888888888888888","filename":fullsdmodelpath,"config": None}]).encode())
+        elif self.path.endswith('/sdapi/v1/options'):
+            response_body = (json.dumps({"samples_format":"png","sd_model_checkpoint":friendlysdmodelname}).encode())

         elif self.path=="/api":
             content_type = 'text/html'
             if self.embedded_kcpp_docs is None:
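
Reconstructed from the json.dumps calls above, the two new GET endpoints answer roughly as follows ("my_model" is a stand-in for the sanitized checkpoint name; the hash fields are hardcoded dummies, not real digests):

    sd_models_response = [{
        "title": "my_model",
        "model_name": "my_model",
        "hash": "8888888888",   # hardcoded placeholder
        "sha256": "8" * 64,     # likewise not a real digest
        "filename": "/path/to/my_model.safetensors",
        "config": None,
    }]
    options_response = {"samples_format": "png", "sd_model_checkpoint": "my_model"}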
@@ -1044,6 +1064,7 @@ Enter Prompt:<br>
             sse_stream_flag = False
             api_format = 0 #1=basic,2=kai,3=oai,4=oai-chat
+            is_txt2img = False

             if self.path.endswith('/request'):
                 api_format = 1
@@ -1061,7 +1082,10 @@ Enter Prompt:<br>
             if self.path.endswith('/v1/chat/completions'):
                 api_format = 4

-            if api_format > 0:
+            if self.path.endswith('/sdapi/v1/txt2img'):
+                is_txt2img = True
+
+            if is_txt2img or api_format > 0:
                 genparams = None
                 try:
                     genparams = json.loads(body)
@@ -1076,6 +1100,7 @@ Enter Prompt:<br>
                 if args.foreground:
                     bring_terminal_to_foreground()

+                if api_format > 0:#text gen
                     # Check if streaming chat completions, if so, set stream mode to true
                     if (api_format == 4 or api_format == 3) and "stream" in genparams and genparams["stream"]:
                         sse_stream_flag = True
@@ -1091,10 +1116,27 @@ Enter Prompt:<br>
                         self.end_headers(content_type='application/json')
                         self.wfile.write(genresp)
                     except Exception as ex:
+                        if args.debugmode:
+                            print(ex)
                         print("Generate: The response could not be sent, maybe connection was terminated?")
                         handle.abort_generate()
                         time.sleep(0.2) #short delay
                     return
+                elif is_txt2img: #image gen
+                    try:
+                        gen = sd_generate(genparams)
+                        genresp = (json.dumps({"images":[gen],"parameters":{},"info":""}).encode())
+                        self.send_response(200)
+                        self.send_header('content-length', str(len(genresp)))
+                        self.end_headers(content_type='application/json')
+                        self.wfile.write(genresp)
+                    except Exception as ex:
+                        if args.debugmode:
+                            print(ex)
+                        print("Generate Image: The response could not be sent, maybe connection was terminated?")
+                        time.sleep(0.2) #short delay
+                    return

         finally:
             modelbusy.release()
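
One asymmetry between the two branches above: on a failed send, the text path calls handle.abort_generate() to stop the backend, while the image path only logs and returns — so a dropped txt2img connection presumably leaves the diffusion run to finish on its own.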
@@ -2433,7 +2475,7 @@ def sanitize_string(input_string):
     return sanitized_string

 def main(launch_args,start_server=True):
-    global args, friendlymodelname
+    global args, friendlymodelname, friendlysdmodelname, fullsdmodelpath
     args = launch_args
     embedded_kailite = None
     embedded_kcpp_docs = None
@@ -2583,6 +2625,10 @@ def main(launch_args,start_server=True):
                 time.sleep(3)
                 sys.exit(2)
             imgmodel = os.path.abspath(imgmodel)
+            fullsdmodelpath = imgmodel
+            friendlysdmodelname = os.path.basename(imgmodel)
+            friendlysdmodelname = os.path.splitext(friendlysdmodelname)[0]
+            friendlysdmodelname = sanitize_string(friendlysdmodelname)
             loadok = sd_load_model(imgmodel)
             print("Load Image Model OK: " + str(loadok))
             if not loadok:
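
A worked example of the name derivation above, with a hypothetical model path:

    import os

    imgmodel = "/models/deliberate_v2.safetensors"  # hypothetical path
    name = os.path.basename(imgmodel)   # "deliberate_v2.safetensors"
    name = os.path.splitext(name)[0]    # "deliberate_v2"
    # sanitize_string() (defined earlier in the file) then cleans the result
    # for safe display, yielding friendlysdmodelname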

@@ -127,7 +127,7 @@ static void sd_logger_callback(enum sd_log_level_t level, const char* log, void*
 bool sdtype_load_model(const sd_load_model_inputs inputs) {
-    printf("\nSelected Image Model: %s\n",inputs.model_filename);
+    printf("\nImage Gen - Load Safetensors Image Model: %s\n",inputs.model_filename);
     sd_params = new SDParams();
     sd_params->model_path = inputs.model_filename;
@@ -187,10 +187,29 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
     sd_params->sample_steps = inputs.sample_steps;
     sd_params->seed = inputs.seed;

-    if(inputs.sample_method=="euler a") //all lowercase
+    printf("\nGenerating Image (%d steps)\n",inputs.sample_steps);
+
+    std::string sampler = inputs.sample_method;
+    if(sampler=="euler a") //all lowercase
     {
         sd_params->sample_method = sample_method_t::EULER_A;
     }
+    else if(sampler=="euler")
+    {
+        sd_params->sample_method = sample_method_t::EULER;
+    }
+    else if(sampler=="heun")
+    {
+        sd_params->sample_method = sample_method_t::HEUN;
+    }
+    else if(sampler=="dpm2")
+    {
+        sd_params->sample_method = sample_method_t::DPM2;
+    }
+    else if(sampler=="dpm++ 2m karras" || sampler=="dpm++ 2m")
+    {
+        sd_params->sample_method = sample_method_t::DPMPP2M;
+    }
     else
     {
         sd_params->sample_method = sample_method_t::EULER_A;
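
Two details of the sampler dispatch are easy to miss. First, the Python side lowercases the name before sending (sample_method.lower()), so the C++ comparisons only need the lowercase spellings, and anything unrecognized still falls back to EULER_A. Second, introducing std::string sampler is more than cosmetic: inputs.sample_method arrives as a C string, so the old inputs.sample_method=="euler a" compared pointers rather than contents and could never match reliably; the std::string overload of == compares the actual text.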