diff --git a/otherarch/sdcpp/main.cpp b/otherarch/sdcpp/main.cpp index 9ff38f60a..9499e0c34 100644 --- a/otherarch/sdcpp/main.cpp +++ b/otherarch/sdcpp/main.cpp @@ -54,7 +54,6 @@ const char* modes_str[] = { "txt2img", "img2img", "img2vid", - "edit", "convert", }; @@ -62,7 +61,6 @@ enum SDMode { TXT2IMG, IMG2IMG, IMG2VID, - EDIT, CONVERT, MODE_COUNT }; @@ -88,7 +86,8 @@ struct SDParams { std::string input_path; std::string mask_path; std::string control_image_path; - std::vector ref_image_paths; + + std::vector kontext_image_paths; std::string prompt; std::string negative_prompt; @@ -154,10 +153,6 @@ void print_params(SDParams params) { printf(" init_img: %s\n", params.input_path.c_str()); printf(" mask_img: %s\n", params.mask_path.c_str()); printf(" control_image: %s\n", params.control_image_path.c_str()); - printf(" ref_images_paths:\n"); - for (auto& path : params.ref_image_paths) { - printf(" %s\n", path.c_str()); - }; printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false"); printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false"); printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false"); @@ -212,7 +207,6 @@ void print_usage(int argc, const char* argv[]) { printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n"); printf(" --mask [MASK] path to the mask image, required by img2img with mask\n"); printf(" --control-image [IMAGE] path to image condition, control net\n"); - printf(" -r, --ref_image [PATH] reference image for Flux Kontext models (can be used multiple times) \n"); printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n"); printf(" -p, --prompt [PROMPT] the prompt to render\n"); printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n"); @@ -248,8 +242,9 @@ void print_usage(int argc, const char* argv[]) { printf(" This might crash if it is not supported by the backend.\n"); printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n"); printf(" --canny apply canny preprocessor (edge detection)\n"); - printf(" --color colors the logging tags according to level\n"); + printf(" --color Colors the logging tags according to level\n"); printf(" -v, --verbose print extra info\n"); + printf(" -ki, --kontext_img [PATH] Reference image for Flux Kontext models (can be used multiple times) \n"); } void parse_args(int argc, const char** argv, SDParams& params) { @@ -634,12 +629,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.skip_layer_end = std::stof(argv[i]); - } else if (arg == "-r" || arg == "--ref-image") { + } else if (arg == "-ki" || arg == "--kontext-img") { if (++i >= argc) { invalid_arg = true; break; } - params.ref_image_paths.push_back(argv[i]); + params.kontext_image_paths.push_back(argv[i]); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); print_usage(argc, argv); @@ -668,13 +663,7 @@ void parse_args(int argc, const char** argv, SDParams& params) { } if ((params.mode == IMG2IMG || params.mode == IMG2VID) && params.input_path.length() == 0) { - fprintf(stderr, "error: when using the img2img/img2vid mode, the following arguments are required: init-img\n"); - print_usage(argc, argv); - exit(1); - } - - if (params.mode == EDIT && params.ref_image_paths.size() == 0) { - fprintf(stderr, "error: when using the edit mode, the following arguments are required: ref-image\n"); + fprintf(stderr, "error: when using the img2img mode, the following arguments are required: init-img\n"); print_usage(argc, argv); exit(1); } @@ -838,12 +827,43 @@ int main(int argc, const char* argv[]) { fprintf(stderr, "SVD support is broken, do not use it!!!\n"); return 1; } + bool vae_decode_only = true; + + std::vector kontext_imgs; + for (auto& path : params.kontext_image_paths) { + vae_decode_only = false; + int c = 0; + int width = 0; + int height = 0; + uint8_t* image_buffer = stbi_load(path.c_str(), &width, &height, &c, 3); + if (image_buffer == NULL) { + fprintf(stderr, "load image from '%s' failed\n", path.c_str()); + return 1; + } + if (c < 3) { + fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c); + free(image_buffer); + return 1; + } + if (width <= 0) { + fprintf(stderr, "error: the width of image must be greater than 0\n"); + free(image_buffer); + return 1; + } + if (height <= 0) { + fprintf(stderr, "error: the height of image must be greater than 0\n"); + free(image_buffer); + return 1; + } + kontext_imgs.push_back({(uint32_t)width, + (uint32_t)height, + 3, + image_buffer}); + } - bool vae_decode_only = true; uint8_t* input_image_buffer = NULL; uint8_t* control_image_buffer = NULL; uint8_t* mask_image_buffer = NULL; - std::vector ref_images; if (params.mode == IMG2IMG || params.mode == IMG2VID) { vae_decode_only = false; @@ -895,37 +915,6 @@ int main(int argc, const char* argv[]) { free(input_image_buffer); input_image_buffer = resized_image_buffer; } - } else if (params.mode == EDIT) { - vae_decode_only = false; - for (auto& path : params.ref_image_paths) { - int c = 0; - int width = 0; - int height = 0; - uint8_t* image_buffer = stbi_load(path.c_str(), &width, &height, &c, 3); - if (image_buffer == NULL) { - fprintf(stderr, "load image from '%s' failed\n", path.c_str()); - return 1; - } - if (c < 3) { - fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c); - free(image_buffer); - return 1; - } - if (width <= 0) { - fprintf(stderr, "error: the width of image must be greater than 0\n"); - free(image_buffer); - return 1; - } - if (height <= 0) { - fprintf(stderr, "error: the height of image must be greater than 0\n"); - free(image_buffer); - return 1; - } - ref_images.push_back({(uint32_t)width, - (uint32_t)height, - 3, - image_buffer}); - } } sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(), @@ -1012,12 +1001,14 @@ int main(int argc, const char* argv[]) { params.style_ratio, params.normalize_input, params.input_id_images_path.c_str(), + kontext_imgs.data(), kontext_imgs.size(), params.skip_layers.data(), params.skip_layers.size(), params.slg_scale, params.skip_layer_start, - params.skip_layer_end); - } else if (params.mode == IMG2IMG || params.mode == IMG2VID) { + params.skip_layer_end, + std::vector()); + } else { sd_image_t input_image = {(uint32_t)params.width, (uint32_t)params.height, 3, @@ -1081,38 +1072,14 @@ int main(int argc, const char* argv[]) { params.style_ratio, params.normalize_input, params.input_id_images_path.c_str(), + kontext_imgs.data(), kontext_imgs.size(), params.skip_layers.data(), params.skip_layers.size(), params.slg_scale, params.skip_layer_start, - params.skip_layer_end); + params.skip_layer_end, + std::vector()); } - } else { // EDIT - results = edit(sd_ctx, - ref_images.data(), - ref_images.size(), - params.prompt.c_str(), - params.negative_prompt.c_str(), - params.clip_skip, - params.cfg_scale, - params.guidance, - params.eta, - params.width, - params.height, - params.sample_method, - params.sample_steps, - params.strength, - params.seed, - params.batch_count, - control_image, - params.control_strength, - params.style_ratio, - params.normalize_input, - params.skip_layers.data(), - params.skip_layers.size(), - params.slg_scale, - params.skip_layer_start, - params.skip_layer_end); } if (results == NULL) { @@ -1150,11 +1117,11 @@ int main(int argc, const char* argv[]) { std::string dummy_name, ext, lc_ext; bool is_jpg; - size_t last = params.output_path.find_last_of("."); + size_t last = params.output_path.find_last_of("."); size_t last_path = std::min(params.output_path.find_last_of("/"), params.output_path.find_last_of("\\")); - if (last != std::string::npos // filename has extension - && (last_path == std::string::npos || last > last_path)) { + if (last != std::string::npos // filename has extension + && (last_path == std::string::npos || last > last_path)) { dummy_name = params.output_path.substr(0, last); ext = lc_ext = params.output_path.substr(last); std::transform(ext.begin(), ext.end(), lc_ext.begin(), ::tolower); @@ -1162,7 +1129,7 @@ int main(int argc, const char* argv[]) { } else { dummy_name = params.output_path; ext = lc_ext = ""; - is_jpg = false; + is_jpg = false; } // appending ".png" to absent or unknown extension if (!is_jpg && lc_ext != ".png") { @@ -1174,7 +1141,7 @@ int main(int argc, const char* argv[]) { continue; } std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ext : dummy_name + ext; - if(is_jpg) { + if (is_jpg) { stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, results[i].data, 90); printf("save result JPEG image to '%s'\n", final_image_path.c_str());