mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
image generation is fully working over api (+1 squashed commits)
Squashed commits: [c98ab0b4] single image generation is working now
This commit is contained in:
parent
e8f4d7b3da
commit
3463688a0e
5 changed files with 190 additions and 76 deletions
|
@ -430,7 +430,7 @@ target_link_libraries(common2 PRIVATE ggml ${LLAMA_EXTRA_LIBS})
|
|||
set_target_properties(common2 PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
add_library(sdtype_adapter
|
||||
sdtype_adapter.cpp)
|
||||
otherarch/sdcpp/sdtype_adapter.cpp)
|
||||
target_include_directories(sdtype_adapter PUBLIC . ./otherarch ./otherarch/tools ./otherarch/sdcpp ./otherarch/sdcpp/thirdparty ./examples ./common)
|
||||
target_compile_features(sdtype_adapter PUBLIC cxx_std_11) # don't bump
|
||||
target_link_libraries(sdtype_adapter PRIVATE common2 ggml ${LLAMA_EXTRA_LIBS})
|
||||
|
|
8
Makefile
8
Makefile
|
@ -481,9 +481,9 @@ expose.o: expose.cpp expose.h
|
|||
|
||||
# sd.cpp objects
|
||||
sdcpp_default.o: otherarch/sdcpp/sdtype_adapter.cpp otherarch/sdcpp/stable-diffusion.h otherarch/sdcpp/stable-diffusion.cpp otherarch/sdcpp/util.cpp otherarch/sdcpp/upscaler.cpp otherarch/sdcpp/model.cpp otherarch/sdcpp/thirdparty/zip.c
|
||||
$(CXX) $(FASTCXXFLAGS) -c $< -o $@
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
sdcpp_cublas.o: otherarch/sdcpp/sdtype_adapter.cpp otherarch/sdcpp/stable-diffusion.h otherarch/sdcpp/stable-diffusion.cpp otherarch/sdcpp/util.cpp otherarch/sdcpp/upscaler.cpp otherarch/sdcpp/model.cpp otherarch/sdcpp/thirdparty/zip.c
|
||||
$(CXX) $(FASTCXXFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
|
||||
$(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
|
||||
|
||||
# idiotic "for easier compilation"
|
||||
GPTTYPE_ADAPTER = gpttype_adapter.cpp otherarch/llama_v2.cpp otherarch/llama_v3.cpp llama.cpp otherarch/utils.cpp otherarch/gptj_v1.cpp otherarch/gptj_v2.cpp otherarch/gptj_v3.cpp otherarch/gpt2_v1.cpp otherarch/gpt2_v2.cpp otherarch/gpt2_v3.cpp otherarch/rwkv_v2.cpp otherarch/rwkv_v3.cpp otherarch/neox_v2.cpp otherarch/neox_v3.cpp otherarch/mpt_v3.cpp ggml.h ggml-cuda.h llama.h otherarch/llama-util.h
|
||||
|
@ -563,7 +563,7 @@ koboldcpp_clblast_noavx2:
|
|||
endif
|
||||
|
||||
ifdef CUBLAS_BUILD
|
||||
koboldcpp_cublas: ggml_v4_cublas.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o common.o gpttype_adapter_cublas.o ggml-quants.o ggml-alloc.o ggml-backend.o grammar-parser.o sdcpp_default.o $(CUBLAS_OBJS) $(OBJS)
|
||||
koboldcpp_cublas: ggml_v4_cublas.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o common.o gpttype_adapter_cublas.o ggml-quants.o ggml-alloc.o ggml-backend.o grammar-parser.o sdcpp_cublas.o $(CUBLAS_OBJS) $(OBJS)
|
||||
$(CUBLAS_BUILD)
|
||||
else
|
||||
koboldcpp_cublas:
|
||||
|
@ -571,7 +571,7 @@ koboldcpp_cublas:
|
|||
endif
|
||||
|
||||
ifdef HIPBLAS_BUILD
|
||||
koboldcpp_hipblas: ggml_v4_cublas.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o common.o gpttype_adapter_cublas.o ggml-quants.o ggml-alloc.o ggml-backend.o grammar-parser.o sdcpp_default.o $(HIP_OBJS) $(OBJS)
|
||||
koboldcpp_hipblas: ggml_v4_cublas.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o common.o gpttype_adapter_cublas.o ggml-quants.o ggml-alloc.o ggml-backend.o grammar-parser.o sdcpp_cublas.o $(HIP_OBJS) $(OBJS)
|
||||
$(HIPBLAS_BUILD)
|
||||
else
|
||||
koboldcpp_hipblas:
|
||||
|
|
7
expose.h
7
expose.h
|
@ -102,6 +102,10 @@ struct token_count_outputs
|
|||
// Parameters handed to the stable-diffusion model loader (sdtype_load_model).
// The Python side mirrors this struct as a ctypes.Structure with the same
// field order, so members must not be reordered, added, or removed on one
// side only.
struct sd_load_model_inputs
{
    // Path of the model file; the loader stores it as the SD model path.
    const char * model_filename;
    // OpenCL selection packed into one integer: 0 means not configured,
    // otherwise 100 + platform*10 + device (the loader keeps the last two
    // digits and splits them into platform and device).
    const int clblast_info = 0;
    // NOTE(review): presumably the CUDA device selection; it is filled in by the
    // Python backend setup, but the loader code visible here does not read it — confirm.
    const int cublas_info = 0;
    // String of Vulkan device ids; the loader treats every character as one id
    // and joins them with commas into GGML_VK_VISIBLE_DEVICES ("0" when empty).
    const char * vulkan_info;
    // Thread count copied into the SD parameters (the loader notes that -1
    // means "use physical cores").
    const int threads;
    // 1 enables the SD log callback and the parameter dump in the loader.
    const int debugmode = 0;
};
|
||||
struct sd_generation_inputs
|
||||
|
@ -116,8 +120,7 @@ struct sd_generation_inputs
|
|||
struct sd_generation_outputs
|
||||
{
|
||||
int status = -1;
|
||||
unsigned int data_length = 0;
|
||||
const char * data;
|
||||
const char * data = "";
|
||||
};
|
||||
|
||||
extern std::string executable_path;
|
||||
|
|
102
koboldcpp.py
102
koboldcpp.py
|
@ -95,6 +95,10 @@ class generation_outputs(ctypes.Structure):
|
|||
|
||||
class sd_load_model_inputs(ctypes.Structure):
    # ctypes mirror of the C++ `sd_load_model_inputs` struct passed by value to
    # handle.sd_load_model. Field order and types must match the C++ declaration.
    #   model_filename - UTF-8 encoded path of the model file
    #   clblast_info   - 0, or 100 + platform*10 + device when CLBlast is selected
    #   cublas_info    - set by set_backend_props  # NOTE(review): exact encoding not visible here - confirm
    #   vulkan_info    - UTF-8 encoded string of Vulkan device ids ("0" by default)
    #   threads        - thread count for image generation
    #   debugmode      - copied from args.debugmode
    _fields_ = [("model_filename", ctypes.c_char_p),
                ("clblast_info", ctypes.c_int),
                ("cublas_info", ctypes.c_int),
                ("vulkan_info", ctypes.c_char_p),
                ("threads", ctypes.c_int),
                ("debugmode", ctypes.c_int)]
|
||||
|
||||
class sd_generation_inputs(ctypes.Structure):
|
||||
|
@ -107,7 +111,6 @@ class sd_generation_inputs(ctypes.Structure):
|
|||
|
||||
class sd_generation_outputs(ctypes.Structure):
|
||||
_fields_ = [("status", ctypes.c_int),
|
||||
("data_length", ctypes.c_uint),
|
||||
("data", ctypes.c_char_p)]
|
||||
|
||||
handle = None
|
||||
|
@ -279,47 +282,12 @@ def init_library():
|
|||
handle.sd_generate.argtypes = [sd_generation_inputs]
|
||||
handle.sd_generate.restype = sd_generation_outputs
|
||||
|
||||
def load_model(model_filename):
|
||||
global args
|
||||
inputs = load_model_inputs()
|
||||
inputs.model_filename = model_filename.encode("UTF-8")
|
||||
inputs.max_context_length = maxctx #initial value to use for ctx, can be overwritten
|
||||
inputs.threads = args.threads
|
||||
inputs.low_vram = (True if (args.usecublas and "lowvram" in args.usecublas) else False)
|
||||
inputs.use_mmq = (True if (args.usecublas and "mmq" in args.usecublas) else False)
|
||||
inputs.use_rowsplit = (True if (args.usecublas and "rowsplit" in args.usecublas) else False)
|
||||
inputs.vulkan_info = "0".encode("UTF-8")
|
||||
inputs.blasthreads = args.blasthreads
|
||||
inputs.use_mmap = (not args.nommap)
|
||||
inputs.use_mlock = args.usemlock
|
||||
inputs.lora_filename = "".encode("UTF-8")
|
||||
inputs.lora_base = "".encode("UTF-8")
|
||||
if args.lora:
|
||||
inputs.lora_filename = args.lora[0].encode("UTF-8")
|
||||
inputs.use_mmap = False
|
||||
if len(args.lora) > 1:
|
||||
inputs.lora_base = args.lora[1].encode("UTF-8")
|
||||
inputs.use_smartcontext = args.smartcontext
|
||||
inputs.use_contextshift = (0 if args.noshift else 1)
|
||||
inputs.blasbatchsize = args.blasbatchsize
|
||||
inputs.forceversion = args.forceversion
|
||||
inputs.gpulayers = args.gpulayers
|
||||
inputs.rope_freq_scale = args.ropeconfig[0]
|
||||
if len(args.ropeconfig)>1:
|
||||
inputs.rope_freq_base = args.ropeconfig[1]
|
||||
else:
|
||||
inputs.rope_freq_base = 10000
|
||||
def set_backend_props(inputs):
|
||||
clblastids = 0
|
||||
if args.useclblast:
|
||||
clblastids = 100 + int(args.useclblast[0])*10 + int(args.useclblast[1])
|
||||
inputs.clblast_info = clblastids
|
||||
|
||||
for n in range(tensor_split_max):
|
||||
if args.tensor_split and n < len(args.tensor_split):
|
||||
inputs.tensor_split[n] = float(args.tensor_split[n])
|
||||
else:
|
||||
inputs.tensor_split[n] = 0
|
||||
|
||||
# we must force an explicit tensor split
|
||||
# otherwise the default will divide equally and multigpu crap will slow it down badly
|
||||
inputs.cublas_info = 0
|
||||
|
@ -356,6 +324,46 @@ def load_model(model_filename):
|
|||
inputs.vulkan_info = s.encode("UTF-8")
|
||||
else:
|
||||
inputs.vulkan_info = "0".encode("UTF-8")
|
||||
return inputs
|
||||
|
||||
def load_model(model_filename):
|
||||
global args
|
||||
inputs = load_model_inputs()
|
||||
inputs.model_filename = model_filename.encode("UTF-8")
|
||||
inputs.max_context_length = maxctx #initial value to use for ctx, can be overwritten
|
||||
inputs.threads = args.threads
|
||||
inputs.low_vram = (True if (args.usecublas and "lowvram" in args.usecublas) else False)
|
||||
inputs.use_mmq = (True if (args.usecublas and "mmq" in args.usecublas) else False)
|
||||
inputs.use_rowsplit = (True if (args.usecublas and "rowsplit" in args.usecublas) else False)
|
||||
inputs.vulkan_info = "0".encode("UTF-8")
|
||||
inputs.blasthreads = args.blasthreads
|
||||
inputs.use_mmap = (not args.nommap)
|
||||
inputs.use_mlock = args.usemlock
|
||||
inputs.lora_filename = "".encode("UTF-8")
|
||||
inputs.lora_base = "".encode("UTF-8")
|
||||
if args.lora:
|
||||
inputs.lora_filename = args.lora[0].encode("UTF-8")
|
||||
inputs.use_mmap = False
|
||||
if len(args.lora) > 1:
|
||||
inputs.lora_base = args.lora[1].encode("UTF-8")
|
||||
inputs.use_smartcontext = args.smartcontext
|
||||
inputs.use_contextshift = (0 if args.noshift else 1)
|
||||
inputs.blasbatchsize = args.blasbatchsize
|
||||
inputs.forceversion = args.forceversion
|
||||
inputs.gpulayers = args.gpulayers
|
||||
inputs.rope_freq_scale = args.ropeconfig[0]
|
||||
if len(args.ropeconfig)>1:
|
||||
inputs.rope_freq_base = args.ropeconfig[1]
|
||||
else:
|
||||
inputs.rope_freq_base = 10000
|
||||
|
||||
for n in range(tensor_split_max):
|
||||
if args.tensor_split and n < len(args.tensor_split):
|
||||
inputs.tensor_split[n] = float(args.tensor_split[n])
|
||||
else:
|
||||
inputs.tensor_split[n] = 0
|
||||
|
||||
inputs = set_backend_props(inputs)
|
||||
|
||||
inputs.executable_path = (getdirpath()+"/").encode("UTF-8")
|
||||
inputs.debugmode = args.debugmode
|
||||
|
@ -475,6 +483,13 @@ def sd_load_model(model_filename):
|
|||
inputs = sd_load_model_inputs()
|
||||
inputs.debugmode = args.debugmode
|
||||
inputs.model_filename = model_filename.encode("UTF-8")
|
||||
thds = args.threads
|
||||
if len(args.sdconfig) > 2:
|
||||
sdt = int(args.sdconfig[2])
|
||||
if sdt > 0:
|
||||
thds = sdt
|
||||
inputs.threads = thds
|
||||
inputs = set_backend_props(inputs)
|
||||
ret = handle.sd_load_model(inputs)
|
||||
return ret
|
||||
|
||||
|
@ -1372,6 +1387,7 @@ def show_new_gui():
|
|||
|
||||
sd_model_var = ctk.StringVar()
|
||||
sd_quick_var = ctk.IntVar(value=0)
|
||||
sd_threads_var = ctk.StringVar()
|
||||
|
||||
def tabbuttonaction(name):
|
||||
for t in tabcontent:
|
||||
|
@ -1849,6 +1865,8 @@ def show_new_gui():
|
|||
images_tab = tabcontent["Image Gen"]
|
||||
makefileentry(images_tab, "Stable Diffusion Model (f16):", "Select Stable Diffusion Model File", sd_model_var, 1, filetypes=[("*.safetensors","*.safetensors")], tooltiptxt="Select a .safetensors Stable Diffusion model file on disk to be loaded.")
|
||||
makecheckbox(images_tab, "Quick Mode (Low Quality)", sd_quick_var, 4,tooltiptxt="Force optimal generation settings for speed.")
|
||||
makelabelentry(images_tab, "Image threads:" , sd_threads_var, 6, 50,"How many threads to use during image generation.\nIf left blank, uses same value as threads.")
|
||||
|
||||
|
||||
# launch
|
||||
def guilaunch():
|
||||
|
@ -1936,7 +1954,7 @@ def show_new_gui():
|
|||
else:
|
||||
args.hordeconfig = None if usehorde_var.get() == 0 else [horde_name_var.get(), horde_gen_var.get(), horde_context_var.get(), horde_apikey_var.get(), horde_workername_var.get()]
|
||||
|
||||
args.sdconfig = None if sd_model_var.get() == "" else [sd_model_var.get(), ("quick" if sd_quick_var.get()==1 else "normal")]
|
||||
args.sdconfig = None if sd_model_var.get() == "" else [sd_model_var.get(), ("quick" if sd_quick_var.get()==1 else "normal"),(int(threads_var.get()) if sd_threads_var.get()=="" else int(sd_threads_var.get()))]
|
||||
|
||||
def import_vars(dict):
|
||||
if "threads" in dict:
|
||||
|
@ -2065,10 +2083,12 @@ def show_new_gui():
|
|||
horde_workername_var.set(dict["hordeconfig"][4])
|
||||
usehorde_var.set("1")
|
||||
|
||||
if "sdconfig" in dict and dict["sdconfig"] and len(dict["sdconfig"]) > 1:
|
||||
if "sdconfig" in dict and dict["sdconfig"] and len(dict["sdconfig"]) > 0:
|
||||
sd_model_var.set(dict["sdconfig"][0])
|
||||
if len(dict["sdconfig"]) > 2:
|
||||
if len(dict["sdconfig"]) > 1:
|
||||
sd_quick_var.set(1 if dict["sdconfig"][1]=="quick" else 0)
|
||||
if len(dict["sdconfig"]) > 2:
|
||||
sd_threads_var.set(str(dict["sdconfig"][2]))
|
||||
|
||||
def save_config():
|
||||
file_type = [("KoboldCpp Settings", "*.kcpps")]
|
||||
|
@ -2845,6 +2865,6 @@ if __name__ == '__main__':
|
|||
parser.add_argument("--quiet", help="Enable quiet mode, which hides generation inputs and outputs in the terminal. Quiet mode is automatically enabled when running --hordeconfig.", action='store_true')
|
||||
parser.add_argument("--ssl", help="Allows all content to be served over SSL instead. A valid UNENCRYPTED SSL cert and key .pem files must be provided", metavar=('[cert_pem]', '[key_pem]'), nargs='+')
|
||||
parser.add_argument("--nocertify", help="Allows insecure SSL connections. Use this if you have cert errors and need to bypass certificate restrictions.", action='store_true')
|
||||
parser.add_argument("--sdconfig", help="Specify a stable diffusion safetensors model to enable image generation. If quick is specified, force optimal generation settings for speed.",metavar=('[sd_filename]', '[normal|quick]'), nargs='+')
|
||||
parser.add_argument("--sdconfig", help="Specify a stable diffusion safetensors model to enable image generation. If quick is specified, force optimal generation settings for speed.",metavar=('[sd_filename]', '[normal|quick] [sd_threads]'), nargs='+')
|
||||
|
||||
main(parser.parse_args(),start_server=True)
|
||||
|
|
|
@ -110,6 +110,30 @@ struct SDParams {
|
|||
//global static vars for SD
|
||||
static SDParams * sd_params = nullptr;
|
||||
static sd_ctx_t * sd_ctx = nullptr;
|
||||
static int sddebugmode = 0;
|
||||
static std::string recent_data = "";
|
||||
|
||||
// Encodes `data_length` bytes starting at `data` as standard base64
// (alphabet A-Z a-z 0-9 + /) with '=' padding, so the output length is always
// a multiple of four characters.
//
// data        - pointer to the raw bytes; may be null only when there is nothing to encode
// data_length - number of bytes to read from `data`
//
// Returns the encoded text, or an empty string when `data` is null or
// `data_length` is zero.
std::string base64_encode(const unsigned char* data, unsigned int data_length) {
    // Static table: the alphabet is no longer rebuilt as a std::string on every call.
    static const char base64_chars[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    std::string encoded;
    if (data == nullptr || data_length == 0) {
        return encoded;
    }

    // Do the length/index arithmetic in size_t so that, where size_t is wider
    // than unsigned int, `total + 2` and `i += 3` cannot wrap for huge inputs.
    const size_t total = data_length;
    encoded.reserve(((total + 2) / 3) * 4);

    for (size_t i = 0; i < total; i += 3) {
        const bool has_second = (i + 1 < total);
        const bool has_third = (i + 2 < total);

        // Pack up to three input bytes into one 24-bit group; missing bytes count as zero.
        const unsigned int triple = (static_cast<unsigned int>(data[i]) << 16)
                                  | (has_second ? (static_cast<unsigned int>(data[i + 1]) << 8) : 0u)
                                  | (has_third ? static_cast<unsigned int>(data[i + 2]) : 0u);

        // The first two sextets always carry real data.
        encoded.push_back(base64_chars[(triple >> 18) & 0x3F]);
        encoded.push_back(base64_chars[(triple >> 12) & 0x3F]);
        // The last two sextets are replaced by padding when their source bytes are absent.
        encoded.push_back(has_second ? base64_chars[(triple >> 6) & 0x3F] : '=');
        encoded.push_back(has_third ? base64_chars[triple & 0x3F] : '=');
    }
    return encoded;
}
|
||||
|
||||
static void sd_logger_callback(enum sd_log_level_t level, const char* log, void* data) {
|
||||
SDParams* params = (SDParams*)data;
|
||||
|
@ -125,22 +149,71 @@ static void sd_logger_callback(enum sd_log_level_t level, const char* log, void*
|
|||
}
|
||||
}
|
||||
|
||||
static std::string sdplatformenv, sddeviceenv, sdvulkandeviceenv;
|
||||
bool sdtype_load_model(const sd_load_model_inputs inputs) {
|
||||
|
||||
printf("\nImage Gen - Load Safetensors Image Model: %s\n",inputs.model_filename);
|
||||
|
||||
//duplicated from expose.cpp
|
||||
int cl_parseinfo = inputs.clblast_info; //first digit is whether configured, second is platform, third is devices
|
||||
std::string usingclblast = "GGML_OPENCL_CONFIGURED="+std::to_string(cl_parseinfo>0?1:0);
|
||||
putenv((char*)usingclblast.c_str());
|
||||
cl_parseinfo = cl_parseinfo%100; //keep last 2 digits
|
||||
int platform = cl_parseinfo/10;
|
||||
int devices = cl_parseinfo%10;
|
||||
sdplatformenv = "GGML_OPENCL_PLATFORM="+std::to_string(platform);
|
||||
sddeviceenv = "GGML_OPENCL_DEVICE="+std::to_string(devices);
|
||||
putenv((char*)sdplatformenv.c_str());
|
||||
putenv((char*)sddeviceenv.c_str());
|
||||
std::string vulkan_info_raw = inputs.vulkan_info;
|
||||
std::string vulkan_info_str = "";
|
||||
for (size_t i = 0; i < vulkan_info_raw.length(); ++i) {
|
||||
vulkan_info_str += vulkan_info_raw[i];
|
||||
if (i < vulkan_info_raw.length() - 1) {
|
||||
vulkan_info_str += ",";
|
||||
}
|
||||
}
|
||||
if(vulkan_info_str=="")
|
||||
{
|
||||
vulkan_info_str = "0";
|
||||
}
|
||||
sdvulkandeviceenv = "GGML_VK_VISIBLE_DEVICES="+vulkan_info_str;
|
||||
putenv((char*)sdvulkandeviceenv.c_str());
|
||||
|
||||
sd_params = new SDParams();
|
||||
sd_params->model_path = inputs.model_filename;
|
||||
sd_params->wtype = SD_TYPE_F16;
|
||||
sd_params->n_threads = -1; //use physical cores
|
||||
sd_params->n_threads = inputs.threads; //if -1 use physical cores
|
||||
sd_params->input_path = ""; //unused
|
||||
sd_params->batch_count = 1;
|
||||
|
||||
if(inputs.debugmode==1)
|
||||
sddebugmode = inputs.debugmode;
|
||||
|
||||
if(sddebugmode==1)
|
||||
{
|
||||
sd_set_log_callback(sd_logger_callback, (void*)sd_params);
|
||||
}
|
||||
|
||||
bool vae_decode_only = false;
|
||||
bool free_param = false;
|
||||
if(inputs.debugmode==1)
|
||||
{
|
||||
printf("\nMODEL:%s\nVAE:%s\nTAESD:%s\nCNET:%s\nLORA:%s\nEMBD:%s\nVAE_DEC:%d\nVAE_TILE:%d\nFREE_PARAM:%d\nTHREADS:%d\nWTYPE:%d\nRNGTYPE:%d\nSCHED:%d\nCNETCPU:%d\n\n",
|
||||
sd_params->model_path.c_str(),
|
||||
sd_params->vae_path.c_str(),
|
||||
sd_params->taesd_path.c_str(),
|
||||
sd_params->controlnet_path.c_str(),
|
||||
sd_params->lora_model_dir.c_str(),
|
||||
sd_params->embeddings_path.c_str(),
|
||||
vae_decode_only,
|
||||
sd_params->vae_tiling,
|
||||
free_param,
|
||||
sd_params->n_threads,
|
||||
sd_params->wtype,
|
||||
sd_params->rng_type,
|
||||
sd_params->schedule,
|
||||
sd_params->control_net_cpu);
|
||||
}
|
||||
|
||||
sd_ctx = new_sd_ctx(sd_params->model_path.c_str(),
|
||||
sd_params->vae_path.c_str(),
|
||||
|
@ -150,7 +223,7 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
|
|||
sd_params->embeddings_path.c_str(),
|
||||
vae_decode_only,
|
||||
sd_params->vae_tiling,
|
||||
true,
|
||||
free_param,
|
||||
sd_params->n_threads,
|
||||
sd_params->wtype,
|
||||
sd_params->rng_type,
|
||||
|
@ -169,12 +242,12 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
|
|||
sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
|
||||
{
|
||||
sd_generation_outputs output;
|
||||
|
||||
if(sd_ctx == nullptr || sd_params == nullptr)
|
||||
{
|
||||
printf("\nError: KCPP SD is not initialized!\n");
|
||||
output.data = nullptr;
|
||||
output.data = "";
|
||||
output.status = 0;
|
||||
output.data_length = 0;
|
||||
return output;
|
||||
}
|
||||
uint8_t * input_image_buffer = NULL;
|
||||
|
@ -188,6 +261,7 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
|
|||
sd_params->seed = inputs.seed;
|
||||
|
||||
printf("\nGenerating Image (%d steps)\n",inputs.sample_steps);
|
||||
fflush(stdout);
|
||||
std::string sampler = inputs.sample_method;
|
||||
|
||||
if(sampler=="euler a") //all lowercase
|
||||
|
@ -216,6 +290,23 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
|
|||
}
|
||||
|
||||
if (sd_params->mode == TXT2IMG) {
|
||||
|
||||
if(sddebugmode==1)
|
||||
{
|
||||
printf("\nPROMPT:%s\nNPROMPT:%s\nCLPSKP:%d\nCFGSCLE:%f\nW:%d\nH:%d\nSM:%d\nSTEP:%d\nSEED:%d\nBATCH:%d\nCIMG:%d\nCSTR:%f\n\n",
|
||||
sd_params->prompt.c_str(),
|
||||
sd_params->negative_prompt.c_str(),
|
||||
sd_params->clip_skip,
|
||||
sd_params->cfg_scale,
|
||||
sd_params->width,
|
||||
sd_params->height,
|
||||
sd_params->sample_method,
|
||||
sd_params->sample_steps,
|
||||
sd_params->seed,
|
||||
sd_params->batch_count,
|
||||
control_image,
|
||||
sd_params->control_strength);
|
||||
}
|
||||
results = txt2img(sd_ctx,
|
||||
sd_params->prompt.c_str(),
|
||||
sd_params->negative_prompt.c_str(),
|
||||
|
@ -251,31 +342,31 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
|
|||
|
||||
if (results == NULL) {
|
||||
printf("\nKCPP SD generate failed!\n");
|
||||
output.data = nullptr;
|
||||
output.data = "";
|
||||
output.status = 0;
|
||||
output.data_length = 0;
|
||||
return output;
|
||||
}
|
||||
|
||||
|
||||
size_t last = sd_params->output_path.find_last_of(".");
|
||||
std::string dummy_name = last != std::string::npos ? sd_params->output_path.substr(0, last) : sd_params->output_path;
|
||||
for (int i = 0; i < sd_params->batch_count; i++) {
|
||||
if (results[i].data == NULL) {
|
||||
continue;
|
||||
}
|
||||
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ".png" : dummy_name + ".png";
|
||||
stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
|
||||
results[i].data, 0, "Made By KoboldCpp");
|
||||
printf("save result image to '%s'\n", final_image_path.c_str());
|
||||
|
||||
int out_data_len;
|
||||
unsigned char * png = stbi_write_png_to_mem(results[i].data, 0, results[i].width, results[i].height, results[i].channel, &out_data_len, "");
|
||||
if (png != NULL)
|
||||
{
|
||||
recent_data = base64_encode(png,out_data_len);
|
||||
free(png);
|
||||
}
|
||||
|
||||
free(results[i].data);
|
||||
results[i].data = NULL;
|
||||
}
|
||||
|
||||
free(results);
|
||||
|
||||
output.data = nullptr;
|
||||
output.data = recent_data.c_str();
|
||||
output.status = 1;
|
||||
output.data_length = 0;
|
||||
return output;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue