diff --git a/koboldcpp.py b/koboldcpp.py index f831af684..fa934cf1d 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -3772,7 +3772,7 @@ Change Mode
if friendlysdmodelname=="inactive" or fullsdmodelpath=="": response_body = (json.dumps([]).encode()) else: - response_body = (json.dumps([{"name":name,"label":name} for name in ["default","discrete","karras","exponential","ays","gits","sgm_uniform","simple","smoothstep","lcm"]]).encode()) + response_body = (json.dumps([{"name":name,"label":name} for name in ["default","discrete","karras","exponential","ays","gits","sgm_uniform","simple","smoothstep","kl_optimal","lcm"]]).encode()) elif clean_path.endswith('/sdapi/v1/latent-upscale-modes'): response_body = (json.dumps([]).encode()) elif clean_path.endswith('/sdapi/v1/upscalers'): diff --git a/otherarch/sdcpp/common/common.hpp b/otherarch/sdcpp/common/common.hpp index 34076799a..6c32de030 100644 --- a/otherarch/sdcpp/common/common.hpp +++ b/otherarch/sdcpp/common/common.hpp @@ -20,18 +20,22 @@ namespace fs = std::filesystem; #include "stable-diffusion.h" -// #define STB_IMAGE_IMPLEMENTATION -//#define STB_IMAGE_STATIC +namespace { // kcpp + +#define STB_IMAGE_IMPLEMENTATION +#define STB_IMAGE_STATIC #include "stb_image.h" #define STB_IMAGE_WRITE_IMPLEMENTATION -//#define STB_IMAGE_WRITE_STATIC +#define STB_IMAGE_WRITE_STATIC #include "stb_image_write.h" -// #define STB_IMAGE_RESIZE_IMPLEMENTATION -//#define STB_IMAGE_RESIZE_STATIC +#define STB_IMAGE_RESIZE_IMPLEMENTATION +#define STB_IMAGE_RESIZE_STATIC #include "stb_image_resize.h" +} + #define SAFE_STR(s) ((s) ? (s) : "") #define BOOL_STR(b) ((b) ? "true" : "false") @@ -87,6 +91,114 @@ static std::string argv_to_utf8(int index, const char** argv) { #endif +static void print_utf8(FILE* stream, const char* utf8) { + if (!utf8) + return; + +#ifdef _WIN32 + HANDLE h = (stream == stderr) + ? GetStdHandle(STD_ERROR_HANDLE) + : GetStdHandle(STD_OUTPUT_HANDLE); + + int wlen = MultiByteToWideChar(CP_UTF8, 0, utf8, -1, NULL, 0); + if (wlen <= 0) + return; + + wchar_t* wbuf = (wchar_t*)malloc(wlen * sizeof(wchar_t)); + MultiByteToWideChar(CP_UTF8, 0, utf8, -1, wbuf, wlen); + + DWORD written; + WriteConsoleW(h, wbuf, wlen - 1, &written, NULL); + + free(wbuf); +#else + fputs(utf8, stream); +#endif +} + +static std::string sd_basename(const std::string& path) { + size_t pos = path.find_last_of('/'); + if (pos != std::string::npos) { + return path.substr(pos + 1); + } + pos = path.find_last_of('\\'); + if (pos != std::string::npos) { + return path.substr(pos + 1); + } + return path; +} + +static void log_print(enum sd_log_level_t level, const char* log, bool verbose, bool color) { + int tag_color; + const char* level_str; + FILE* out_stream = (level == SD_LOG_ERROR) ? stderr : stdout; + + if (!log || (!verbose && level <= SD_LOG_DEBUG)) { + return; + } + + switch (level) { + case SD_LOG_DEBUG: + tag_color = 37; + level_str = "DEBUG"; + break; + case SD_LOG_INFO: + tag_color = 34; + level_str = "INFO"; + break; + case SD_LOG_WARN: + tag_color = 35; + level_str = "WARN"; + break; + case SD_LOG_ERROR: + tag_color = 31; + level_str = "ERROR"; + break; + default: /* Potential future-proofing */ + tag_color = 33; + level_str = "?????"; + break; + } + + if (color) { + fprintf(out_stream, "\033[%d;1m[%-5s]\033[0m ", tag_color, level_str); + } else { + fprintf(out_stream, "[%-5s] ", level_str); + } + print_utf8(out_stream, log); + fflush(out_stream); +} + +#define LOG_BUFFER_SIZE 4096 + +static bool log_verbose = false; +static bool log_color = false; + +static void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...) { + va_list args; + va_start(args, format); + + static char log_buffer[LOG_BUFFER_SIZE + 1]; + int written = snprintf(log_buffer, LOG_BUFFER_SIZE, "%s:%-4d - ", sd_basename(file).c_str(), line); + + if (written >= 0 && written < LOG_BUFFER_SIZE) { + vsnprintf(log_buffer + written, LOG_BUFFER_SIZE - written, format, args); + } + size_t len = strlen(log_buffer); + if (log_buffer[len - 1] != '\n') { + strncat(log_buffer, "\n", LOG_BUFFER_SIZE - len); + } + + log_print(level, log_buffer, log_verbose, log_color); + + va_end(args); +} + +#define LOG_DEBUG(format, ...) log_printf(SD_LOG_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__) +#define LOG_INFO(format, ...) log_printf(SD_LOG_INFO, __FILE__, __LINE__, format, ##__VA_ARGS__) +#define LOG_WARN(format, ...) log_printf(SD_LOG_WARN, __FILE__, __LINE__, format, ##__VA_ARGS__) +#define LOG_ERROR(format, ...) log_printf(SD_LOG_ERROR, __FILE__, __LINE__, format, ##__VA_ARGS__) + struct StringOption { std::string short_name; std::string long_name; @@ -296,11 +408,11 @@ static bool parse_options(int argc, const char** argv, const std::vector(embedding_vec.size()), photo_maker_path.c_str(), @@ -864,7 +979,7 @@ static bool is_absolute_path(const std::string& p) { struct SDGenerationParams { std::string prompt; - std::string prompt_with_lora; // for metadata record only + std::string prompt_with_lora; // for metadata record only std::string negative_prompt; int clip_skip = -1; // <= 0 represents unspecified int width = 512; @@ -1116,8 +1231,8 @@ struct SDGenerationParams { const char* arg = argv[index]; sample_params.sample_method = str_to_sample_method(arg); if (sample_params.sample_method == SAMPLE_METHOD_COUNT) { - fprintf(stderr, "error: invalid sample method %s\n", - arg); + LOG_ERROR("error: invalid sample method %s", + arg); return -1; } return 1; @@ -1130,8 +1245,8 @@ struct SDGenerationParams { const char* arg = argv[index]; high_noise_sample_params.sample_method = str_to_sample_method(arg); if (high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) { - fprintf(stderr, "error: invalid high noise sample method %s\n", - arg); + LOG_ERROR("error: invalid high noise sample method %s", + arg); return -1; } return 1; @@ -1144,8 +1259,8 @@ struct SDGenerationParams { const char* arg = argv[index]; sample_params.scheduler = str_to_scheduler(arg); if (sample_params.scheduler == SCHEDULER_COUNT) { - fprintf(stderr, "error: invalid scheduler %s\n", - arg); + LOG_ERROR("error: invalid scheduler %s", + arg); return -1; } return 1; @@ -1226,17 +1341,17 @@ struct SDGenerationParams { try { custom_sigmas.push_back(std::stof(item)); } catch (const std::invalid_argument& e) { - fprintf(stderr, "error: invalid float value '%s' in --sigmas\n", item.c_str()); + LOG_ERROR("error: invalid float value '%s' in --sigmas", item.c_str()); return -1; } catch (const std::out_of_range& e) { - fprintf(stderr, "error: float value '%s' out of range in --sigmas\n", item.c_str()); + LOG_ERROR("error: float value '%s' out of range in --sigmas", item.c_str()); return -1; } } } if (custom_sigmas.empty() && !sigmas_str.empty()) { - fprintf(stderr, "error: could not parse any sigma values from '%s'\n", argv[index]); + LOG_ERROR("error: could not parse any sigma values from '%s'", argv[index]); return -1; } return 1; @@ -1299,7 +1414,7 @@ struct SDGenerationParams { on_high_noise_sample_method_arg}, {"", "--scheduler", - "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], default: discrete", + "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm], default: discrete", on_scheduler_arg}, {"", "--sigmas", @@ -1332,7 +1447,7 @@ struct SDGenerationParams { try { j = json::parse(json_str); } catch (...) { - fprintf(stderr, "json parse failed %s\n", json_str.c_str()); + LOG_ERROR("json parse failed %s", json_str.c_str()); return false; } @@ -1441,7 +1556,7 @@ struct SDGenerationParams { } } if (!found) { - printf("can not found lora %s\n", final_path.lexically_normal().string().c_str()); + LOG_WARN("can not found lora %s", final_path.lexically_normal().string().c_str()); tmp = m.suffix().str(); prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only); continue; @@ -1480,17 +1595,17 @@ struct SDGenerationParams { bool process_and_check(SDMode mode, const std::string& lora_model_dir) { prompt_with_lora = prompt; if (width <= 0) { - fprintf(stderr, "error: the width must be greater than 0\n"); + LOG_ERROR("error: the width must be greater than 0\n"); return false; } if (height <= 0) { - fprintf(stderr, "error: the height must be greater than 0\n"); + LOG_ERROR("error: the height must be greater than 0\n"); return false; } if (sample_params.sample_steps <= 0) { - fprintf(stderr, "error: the sample_steps must be greater than 0\n"); + LOG_ERROR("error: the sample_steps must be greater than 0\n"); return false; } @@ -1499,7 +1614,7 @@ struct SDGenerationParams { } if (strength < 0.f || strength > 1.f) { - fprintf(stderr, "error: can only work with strength in [0.0, 1.0]\n"); + LOG_ERROR("error: can only work with strength in [0.0, 1.0]\n"); return false; } @@ -1521,31 +1636,31 @@ struct SDGenerationParams { }; trim(token); if (token.empty()) { - fprintf(stderr, "error: invalid easycache option '%s'\n", easycache_option.c_str()); + LOG_ERROR("error: invalid easycache option '%s'", easycache_option.c_str()); return false; } if (idx >= 3) { - fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n"); + LOG_ERROR("error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n"); return false; } try { values[idx] = std::stof(token); } catch (const std::exception&) { - fprintf(stderr, "error: invalid easycache value '%s'\n", token.c_str()); + LOG_ERROR("error: invalid easycache value '%s'", token.c_str()); return false; } idx++; } if (idx != 3) { - fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n"); + LOG_ERROR("error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n"); return false; } if (values[0] < 0.0f) { - fprintf(stderr, "error: easycache threshold must be non-negative\n"); + LOG_ERROR("error: easycache threshold must be non-negative\n"); return false; } if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) { - fprintf(stderr, "error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n"); + LOG_ERROR("error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n"); return false; } easycache_params.enabled = true; @@ -1585,7 +1700,7 @@ struct SDGenerationParams { if (mode == UPSCALE) { if (init_image_path.length() == 0) { - fprintf(stderr, "error: upscale mode needs an init image (--init-img)\n"); + LOG_ERROR("error: upscale mode needs an init image (--init-img)\n"); return false; } } @@ -1700,13 +1815,13 @@ uint8_t* load_image_common(bool from_memory, image_buffer = (uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel); } if (image_buffer == nullptr) { - fprintf(stderr, "load image from '%s' failed\n", image_path); + LOG_ERROR("load image from '%s' failed", image_path); return nullptr; } if (c < expected_channel) { fprintf(stderr, "the number of channels for the input image must be >= %d," - "but got %d channels, image_path = %s\n", + "but got %d channels, image_path = %s", expected_channel, c, image_path); @@ -1714,12 +1829,12 @@ uint8_t* load_image_common(bool from_memory, return nullptr; } if (width <= 0) { - fprintf(stderr, "error: the width of image must be greater than 0, image_path = %s\n", image_path); + LOG_ERROR("error: the width of image must be greater than 0, image_path = %s", image_path); free(image_buffer); return nullptr; } if (height <= 0) { - fprintf(stderr, "error: the height of image must be greater than 0, image_path = %s\n", image_path); + LOG_ERROR("error: the height of image must be greater than 0, image_path = %s", image_path); free(image_buffer); return nullptr; } @@ -1741,10 +1856,10 @@ uint8_t* load_image_common(bool from_memory, } if (crop_x != 0 || crop_y != 0) { - printf("crop input image from %dx%d to %dx%d, image_path = %s\n", width, height, crop_w, crop_h, image_path); + LOG_INFO("crop input image from %dx%d to %dx%d, image_path = %s", width, height, crop_w, crop_h, image_path); uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * expected_channel); if (cropped_image_buffer == nullptr) { - fprintf(stderr, "error: allocate memory for crop\n"); + LOG_ERROR("error: allocate memory for crop\n"); free(image_buffer); return nullptr; } @@ -1760,13 +1875,13 @@ uint8_t* load_image_common(bool from_memory, image_buffer = cropped_image_buffer; } - printf("resize input image from %dx%d to %dx%d\n", width, height, expected_width, expected_height); + LOG_INFO("resize input image from %dx%d to %dx%d", width, height, expected_width, expected_height); int resized_height = expected_height; int resized_width = expected_width; uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * expected_channel); if (resized_image_buffer == nullptr) { - fprintf(stderr, "error: allocate memory for resize input image\n"); + LOG_ERROR("error: allocate memory for resize input image\n"); free(image_buffer); return nullptr; } diff --git a/otherarch/sdcpp/denoiser.hpp b/otherarch/sdcpp/denoiser.hpp index 32f402786..fc5230d7b 100644 --- a/otherarch/sdcpp/denoiser.hpp +++ b/otherarch/sdcpp/denoiser.hpp @@ -347,6 +347,41 @@ struct SmoothStepScheduler : SigmaScheduler { } }; +// Implementation adapted from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608 +struct KLOptimalScheduler : SigmaScheduler { + std::vector get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { + std::vector sigmas; + + if (n == 0) { + return sigmas; + } + if (n == 1) { + sigmas.push_back(sigma_max); + sigmas.push_back(0.0f); + return sigmas; + } + + float alpha_min = std::atan(sigma_min); + float alpha_max = std::atan(sigma_max); + + for (uint32_t i = 0; i < n; ++i) { + // t goes from 0.0 to 1.0 + float t = static_cast(i) / static_cast(n-1); + + // Interpolate in the angle domain + float angle = t * alpha_min + (1.0f - t) * alpha_max; + + // Convert back to sigma + sigmas.push_back(std::tan(angle)); + } + + // Append the final zero to sigma + sigmas.push_back(0.0f); + + return sigmas; + } +}; + struct Denoiser { virtual float sigma_min() = 0; virtual float sigma_max() = 0; @@ -392,6 +427,10 @@ struct Denoiser { LOG_INFO("get_sigmas with SmoothStep scheduler"); scheduler = std::make_shared(); break; + case KL_OPTIMAL_SCHEDULER: + LOG_INFO("get_sigmas with KL Optimal scheduler"); + scheduler = std::make_shared(); + break; case LCM_SCHEDULER: LOG_INFO("get_sigmas with LCM scheduler"); scheduler = std::make_shared(); diff --git a/otherarch/sdcpp/flux.hpp b/otherarch/sdcpp/flux.hpp index 1df2874ae..7ce263569 100644 --- a/otherarch/sdcpp/flux.hpp +++ b/otherarch/sdcpp/flux.hpp @@ -744,6 +744,8 @@ namespace Flux { int64_t nerf_mlp_ratio = 4; int64_t nerf_depth = 4; int64_t nerf_max_freqs = 8; + bool use_x0 = false; + bool use_patch_size_32 = false; }; struct FluxParams { @@ -781,7 +783,7 @@ namespace Flux { Flux(FluxParams params) : params(params) { if (params.version == VERSION_CHROMA_RADIANCE) { - std::pair kernel_size = {(int)params.patch_size, (int)params.patch_size}; + std::pair kernel_size = {16, 16}; std::pair stride = kernel_size; blocks["img_in_patch"] = std::make_shared(params.in_channels, @@ -1044,6 +1046,15 @@ namespace Flux { return img; } + struct ggml_tensor* _apply_x0_residual(GGMLRunnerContext* ctx, + struct ggml_tensor* predicted, + struct ggml_tensor* noisy, + struct ggml_tensor* timesteps) { + auto x = ggml_sub(ctx->ggml_ctx, noisy, predicted); + x = ggml_div(ctx->ggml_ctx, x, timesteps); + return x; + } + struct ggml_tensor* forward_chroma_radiance(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* timestep, @@ -1068,6 +1079,13 @@ namespace Flux { auto img = pad_to_patch_size(ctx->ggml_ctx, x); auto orig_img = img; + if (params.chroma_radiance_params.use_patch_size_32) { + // It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable + // Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch? + // img = F.interpolate(img, size=(H//2, W//2), mode="nearest") + img = ggml_interpolate(ctx->ggml_ctx, img, W / 2, H / 2, C, x->ne[3], GGML_SCALE_MODE_BILINEAR); + } + auto img_in_patch = std::dynamic_pointer_cast(blocks["img_in_patch"]); img = img_in_patch->forward(ctx, img); // [N, hidden_size, H/patch_size, W/patch_size] @@ -1104,6 +1122,10 @@ namespace Flux { out = nerf_final_layer_conv->forward(ctx, img_dct); // [N, C, H, W] + if (params.chroma_radiance_params.use_x0) { + out = _apply_x0_residual(ctx, out, orig_img, timestep); + } + return out; } @@ -1290,6 +1312,15 @@ namespace Flux { // not schnell flux_params.guidance_embed = true; } + if (tensor_name.find("__x0__") != std::string::npos) { + LOG_DEBUG("using x0 prediction"); + flux_params.chroma_radiance_params.use_x0 = true; + } + if (tensor_name.find("__32x32__") != std::string::npos) { + LOG_DEBUG("using patch size 32 prediction"); + flux_params.chroma_radiance_params.use_patch_size_32 = true; + flux_params.patch_size = 32; + } if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { // Chroma flux_params.is_chroma = true; diff --git a/otherarch/sdcpp/ggml_extend.hpp b/otherarch/sdcpp/ggml_extend.hpp index f76aaef42..7b355d3ee 100644 --- a/otherarch/sdcpp/ggml_extend.hpp +++ b/otherarch/sdcpp/ggml_extend.hpp @@ -848,8 +848,6 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input, LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y); LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor); - GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2 - int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x); int non_tile_overlap_x = p_tile_size_x - tile_overlap_x; diff --git a/otherarch/sdcpp/main.cpp b/otherarch/sdcpp/main.cpp index e37434b71..a1b92f77e 100644 --- a/otherarch/sdcpp/main.cpp +++ b/otherarch/sdcpp/main.cpp @@ -31,7 +31,6 @@ struct SDCliParams { std::string output_path = "output.png"; bool verbose = false; - bool version = false; bool canny_preprocess = false; preview_t preview_method = PREVIEW_NONE; @@ -74,10 +73,6 @@ struct SDCliParams { "--verbose", "print extra info", true, &verbose}, - {"", - "--version", - "print stable-diffusion.cpp version", - true, &version}, {"", "--color", "colors the logging tags according to level", @@ -106,9 +101,8 @@ struct SDCliParams { } } if (mode_found == -1) { - fprintf(stderr, - "error: invalid mode %s, must be one of [%s]\n", - mode_c_str, SD_ALL_MODES_STR); + LOG_ERROR("error: invalid mode %s, must be one of [%s]\n", + mode_c_str, SD_ALL_MODES_STR); exit(1); } mode = (SDMode)mode_found; @@ -128,8 +122,7 @@ struct SDCliParams { } } if (preview_found == -1) { - fprintf(stderr, "error: preview method %s\n", - preview); + LOG_ERROR("error: preview method %s", preview); return -1; } preview_method = (preview_t)preview_found; @@ -161,7 +154,7 @@ struct SDCliParams { bool process_and_check() { if (output_path.length() == 0) { - fprintf(stderr, "error: the following arguments are required: output_path\n"); + LOG_ERROR("error: the following arguments are required: output_path"); return false; } @@ -219,18 +212,6 @@ void parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextP } } -static std::string sd_basename(const std::string& path) { - size_t pos = path.find_last_of('/'); - if (pos != std::string::npos) { - return path.substr(pos + 1); - } - pos = path.find_last_of('\\'); - if (pos != std::string::npos) { - return path.substr(pos + 1); - } - return path; -} - std::string get_image_params(const SDCliParams& cli_params, const SDContextParams& ctx_params, const SDGenerationParams& gen_params, int64_t seed) { std::string parameter_string = gen_params.prompt_with_lora + "\n"; if (gen_params.negative_prompt.size() != 0) { @@ -288,47 +269,9 @@ std::string get_image_params(const SDCliParams& cli_params, const SDContextParam return parameter_string; } -/* Enables Printing the log level tag in color using ANSI escape codes */ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) { SDCliParams* cli_params = (SDCliParams*)data; - int tag_color; - const char* level_str; - FILE* out_stream = (level == SD_LOG_ERROR) ? stderr : stdout; - - if (!log || (!cli_params->verbose && level <= SD_LOG_DEBUG)) { - return; - } - - switch (level) { - case SD_LOG_DEBUG: - tag_color = 37; - level_str = "DEBUG"; - break; - case SD_LOG_INFO: - tag_color = 34; - level_str = "INFO"; - break; - case SD_LOG_WARN: - tag_color = 35; - level_str = "WARN"; - break; - case SD_LOG_ERROR: - tag_color = 31; - level_str = "ERROR"; - break; - default: /* Potential future-proofing */ - tag_color = 33; - level_str = "?????"; - break; - } - - if (cli_params->color == true) { - fprintf(out_stream, "\033[%d;1m[%-5s]\033[0m ", tag_color, level_str); - } else { - fprintf(out_stream, "[%-5s] ", level_str); - } - fputs(log, out_stream); - fflush(out_stream); + log_print(level, log, cli_params->verbose, cli_params->color); } bool load_images_from_dir(const std::string dir, @@ -338,7 +281,7 @@ bool load_images_from_dir(const std::string dir, int max_image_num = 0, bool verbose = false) { if (!fs::exists(dir) || !fs::is_directory(dir)) { - fprintf(stderr, "'%s' is not a valid directory\n", dir.c_str()); + LOG_ERROR("'%s' is not a valid directory\n", dir.c_str()); return false; } @@ -360,14 +303,12 @@ bool load_images_from_dir(const std::string dir, std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); if (ext == ".jpg" || ext == ".jpeg" || ext == ".png" || ext == ".bmp") { - if (verbose) { - printf("load image %zu from '%s'\n", images.size(), path.c_str()); - } + LOG_DEBUG("load image %zu from '%s'", images.size(), path.c_str()); int width = 0; int height = 0; uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height, expected_width, expected_height); if (image_buffer == nullptr) { - fprintf(stderr, "load image from '%s' failed\n", path.c_str()); + LOG_ERROR("load image from '%s' failed", path.c_str()); return false; } @@ -408,9 +349,6 @@ int main(int argc, const char* argv[]) { SDGenerationParams gen_params; parse_args(argc, argv, cli_params, ctx_params, gen_params); - if (cli_params.verbose || cli_params.version) { - std::cout << version_string() << "\n"; - } if (gen_params.video_frames > 4) { size_t last_dot_pos = cli_params.preview_path.find_last_of("."); std::string base_path = cli_params.preview_path; @@ -429,6 +367,8 @@ int main(int argc, const char* argv[]) { cli_params.preview_fps /= 4; sd_set_log_callback(sd_log_cb, (void*)&cli_params); + log_verbose = cli_params.verbose; + log_color = cli_params.color; sd_set_preview_callback(step_callback, cli_params.preview_method, cli_params.preview_interval, @@ -436,12 +376,11 @@ int main(int argc, const char* argv[]) { cli_params.preview_noisy, (void*)&cli_params); - if (cli_params.verbose) { - printf("%s", sd_get_system_info()); - printf("%s\n", cli_params.to_string().c_str()); - printf("%s\n", ctx_params.to_string().c_str()); - printf("%s\n", gen_params.to_string().c_str()); - } + LOG_DEBUG("version: %s", version_string().c_str()); + LOG_DEBUG("%s", sd_get_system_info()); + LOG_DEBUG("%s", cli_params.to_string().c_str()); + LOG_DEBUG("%s", ctx_params.to_string().c_str()); + LOG_DEBUG("%s", gen_params.to_string().c_str()); if (cli_params.mode == CONVERT) { bool success = convert(ctx_params.model_path.c_str(), @@ -450,17 +389,16 @@ int main(int argc, const char* argv[]) { ctx_params.wtype, ctx_params.tensor_type_rules.c_str()); if (!success) { - fprintf(stderr, - "convert '%s'/'%s' to '%s' failed\n", - ctx_params.model_path.c_str(), - ctx_params.vae_path.c_str(), - cli_params.output_path.c_str()); + LOG_ERROR("convert '%s'/'%s' to '%s' failed", + ctx_params.model_path.c_str(), + ctx_params.vae_path.c_str(), + cli_params.output_path.c_str()); return 1; } else { - printf("convert '%s'/'%s' to '%s' success\n", - ctx_params.model_path.c_str(), - ctx_params.vae_path.c_str(), - cli_params.output_path.c_str()); + LOG_INFO("convert '%s'/'%s' to '%s' success", + ctx_params.model_path.c_str(), + ctx_params.vae_path.c_str(), + cli_params.output_path.c_str()); return 0; } } @@ -503,7 +441,7 @@ int main(int argc, const char* argv[]) { int height = 0; init_image.data = load_image_from_file(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height); if (init_image.data == nullptr) { - fprintf(stderr, "load image from '%s' failed\n", gen_params.init_image_path.c_str()); + LOG_ERROR("load image from '%s' failed", gen_params.init_image_path.c_str()); release_all_resources(); return 1; } @@ -516,7 +454,7 @@ int main(int argc, const char* argv[]) { int height = 0; end_image.data = load_image_from_file(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height); if (end_image.data == nullptr) { - fprintf(stderr, "load image from '%s' failed\n", gen_params.end_image_path.c_str()); + LOG_ERROR("load image from '%s' failed", gen_params.end_image_path.c_str()); release_all_resources(); return 1; } @@ -528,7 +466,7 @@ int main(int argc, const char* argv[]) { int height = 0; mask_image.data = load_image_from_file(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1); if (mask_image.data == nullptr) { - fprintf(stderr, "load image from '%s' failed\n", gen_params.mask_image_path.c_str()); + LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str()); release_all_resources(); return 1; } @@ -536,7 +474,7 @@ int main(int argc, const char* argv[]) { mask_image.data = (uint8_t*)malloc(gen_params.width * gen_params.height); memset(mask_image.data, 255, gen_params.width * gen_params.height); if (mask_image.data == nullptr) { - fprintf(stderr, "malloc mask image failed\n"); + LOG_ERROR("malloc mask image failed"); release_all_resources(); return 1; } @@ -547,7 +485,7 @@ int main(int argc, const char* argv[]) { int height = 0; control_image.data = load_image_from_file(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height); if (control_image.data == nullptr) { - fprintf(stderr, "load image from '%s' failed\n", gen_params.control_image_path.c_str()); + LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str()); release_all_resources(); return 1; } @@ -568,7 +506,7 @@ int main(int argc, const char* argv[]) { int height = 0; uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height); if (image_buffer == nullptr) { - fprintf(stderr, "load image from '%s' failed\n", path.c_str()); + LOG_ERROR("load image from '%s' failed", path.c_str()); release_all_resources(); return 1; } @@ -616,7 +554,7 @@ int main(int argc, const char* argv[]) { num_results = 1; results = (sd_image_t*)calloc(num_results, sizeof(sd_image_t)); if (results == nullptr) { - printf("failed to allocate results array\n"); + LOG_INFO("failed to allocate results array"); release_all_resources(); return 1; } @@ -627,7 +565,7 @@ int main(int argc, const char* argv[]) { sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params); if (sd_ctx == nullptr) { - printf("new_sd_ctx_t failed\n"); + LOG_INFO("new_sd_ctx_t failed"); release_all_resources(); return 1; } @@ -641,7 +579,7 @@ int main(int argc, const char* argv[]) { } if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) { - gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx); + gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx, gen_params.sample_params.sample_method); } if (cli_params.mode == IMG_GEN) { @@ -704,7 +642,7 @@ int main(int argc, const char* argv[]) { } if (results == nullptr) { - printf("generate failed\n"); + LOG_ERROR("generate failed"); free_sd_ctx(sd_ctx); return 1; } @@ -721,7 +659,7 @@ int main(int argc, const char* argv[]) { gen_params.upscale_tile_size); if (upscaler_ctx == nullptr) { - printf("new_upscaler_ctx failed\n"); + LOG_ERROR("new_upscaler_ctx failed"); } else { for (int i = 0; i < num_results; i++) { if (results[i].data == nullptr) { @@ -731,7 +669,7 @@ int main(int argc, const char* argv[]) { for (int u = 0; u < gen_params.upscale_repeats; ++u) { sd_image_t upscaled_image = upscale(upscaler_ctx, current_image, upscale_factor); if (upscaled_image.data == nullptr) { - printf("upscale failed\n"); + LOG_ERROR("upscale failed"); break; } free(current_image.data); @@ -749,8 +687,8 @@ int main(int argc, const char* argv[]) { std::error_code ec; fs::create_directories(out_dir, ec); // OK if already exists if (ec) { - fprintf(stderr, "failed to create directory '%s': %s\n", - out_dir.string().c_str(), ec.message().c_str()); + LOG_ERROR("failed to create directory '%s': %s", + out_dir.string().c_str(), ec.message().c_str()); return 1; } } @@ -780,7 +718,7 @@ int main(int argc, const char* argv[]) { vid_output_path = base_path + ".avi"; } create_mjpg_avi_from_sd_images(vid_output_path.c_str(), results, num_results, gen_params.fps); - printf("save result MJPG AVI video to '%s'\n", vid_output_path.c_str()); + LOG_INFO("save result MJPG AVI video to '%s'\n", vid_output_path.c_str()); } else { // appending ".png" to absent or unknown extension if (!is_jpg && file_ext_lower != ".png") { @@ -796,11 +734,11 @@ int main(int argc, const char* argv[]) { if (is_jpg) { write_ok = stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, results[i].data, 90, get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + i).c_str()); - printf("save result JPEG image to '%s' (%s)\n", final_image_path.c_str(), write_ok == 0 ? "failure" : "success"); + LOG_INFO("save result JPEG image to '%s' (%s)", final_image_path.c_str(), write_ok == 0 ? "failure" : "success"); } else { write_ok = stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, results[i].data, 0, get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + i).c_str()); - printf("save result PNG image to '%s' (%s)\n", final_image_path.c_str(), write_ok == 0 ? "failure" : "success"); + LOG_INFO("save result PNG image to '%s' (%s)", final_image_path.c_str(), write_ok == 0 ? "failure" : "success"); } } } diff --git a/otherarch/sdcpp/model.cpp b/otherarch/sdcpp/model.cpp index 682a72c81..f2edfb9e2 100644 --- a/otherarch/sdcpp/model.cpp +++ b/otherarch/sdcpp/model.cpp @@ -1574,6 +1574,11 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread i64_to_i32_vec((int64_t*)read_buf, (int32_t*)target_buf, tensor_storage.nelements()); } if (tensor_storage.type != dst_tensor->type) { + if (convert_buf == nullptr) { + LOG_ERROR("read tensor data failed: too less memory for conversion"); + failed = true; + return; + } convert_tensor((void*)target_buf, tensor_storage.type, convert_buf, @@ -1786,6 +1791,13 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type // tensor_storage.ne[0], tensor_storage.ne[1], tensor_storage.ne[2], tensor_storage.ne[3], // tensor->n_dims, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + if (!tensor->data) { + GGML_ASSERT(ggml_nelements(tensor) == 0); + // avoid crashing the gguf writer by setting a dummy pointer for zero-sized tensors + LOG_DEBUG("setting dummy pointer for zero-sized tensor %s", name.c_str()); + tensor->data = ggml_get_mem_buffer(ggml_ctx); + } + *dst_tensor = tensor; gguf_add_tensor(gguf_ctx, tensor); diff --git a/otherarch/sdcpp/name_conversion.cpp b/otherarch/sdcpp/name_conversion.cpp index 8b521486d..6a8ae72c0 100644 --- a/otherarch/sdcpp/name_conversion.cpp +++ b/otherarch/sdcpp/name_conversion.cpp @@ -835,6 +835,7 @@ std::string convert_sep_to_dot(std::string name) { "proj_out", "transformer_blocks", "single_transformer_blocks", + "single_blocks", "diffusion_model", "cond_stage_model", "first_stage_model", @@ -876,7 +877,18 @@ std::string convert_sep_to_dot(std::string name) { "ff_context", "norm_added_q", "norm_added_v", - "to_add_out"}; + "to_add_out", + "txt_mod", + "img_mod", + "txt_mlp", + "img_mlp", + "proj_mlp", + "wi_0", + "wi_1", + "norm1_context", + "ff_context", + "x_embedder", + }; // record the positions of underscores that should NOT be replaced std::unordered_set protected_positions; @@ -1020,12 +1032,14 @@ std::string convert_tensor_name(std::string name, SDVersion version) { } } - if (sd_version_is_unet(version) || is_lycoris_underline) { + // LOG_DEBUG("name %s %d", name.c_str(), version); + + if (sd_version_is_unet(version) || sd_version_is_flux(version) || is_lycoris_underline) { name = convert_sep_to_dot(name); } } - std::vector> prefix_map = { + std::unordered_map prefix_map = { {"diffusion_model.", "model.diffusion_model."}, {"unet.", "model.diffusion_model."}, {"transformer.", "model.diffusion_model."}, // dit @@ -1040,8 +1054,13 @@ std::string convert_tensor_name(std::string name, SDVersion version) { // {"te2.text_model.encoder.layers.", "cond_stage_model.1.model.transformer.resblocks."}, {"te2.", "cond_stage_model.1.transformer."}, {"te1.", "cond_stage_model.transformer."}, + {"te3.", "text_encoders.t5xxl.transformer."}, }; + if (sd_version_is_flux(version)) { + prefix_map["te1."] = "text_encoders.clip_l.transformer."; + } + replace_with_prefix_map(name, prefix_map); // diffusion model diff --git a/otherarch/sdcpp/stable-diffusion.cpp b/otherarch/sdcpp/stable-diffusion.cpp index c802eace2..1b56669bf 100644 --- a/otherarch/sdcpp/stable-diffusion.cpp +++ b/otherarch/sdcpp/stable-diffusion.cpp @@ -136,7 +136,6 @@ public: std::map tensors; - std::string lora_model_dir; // lora_name => multiplier std::unordered_map curr_lora_state; @@ -219,7 +218,6 @@ public: n_threads = sd_ctx_params->n_threads; vae_decode_only = sd_ctx_params->vae_decode_only; free_params_immediately = sd_ctx_params->free_params_immediately; - lora_model_dir = SAFE_STR(sd_ctx_params->lora_model_dir); taesd_path = SAFE_STR(sd_ctx_params->taesd_path); use_tiny_autoencoder = taesd_path.size() > 0; offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu; @@ -418,6 +416,14 @@ public: { to_replace = "taesd_3.embd"; } + else if(version == VERSION_WAN2_2_TI2V) + { + to_replace = "taesd_w22.embd"; + } + else if(sd_version_is_wan(version)||sd_version_is_qwen_image(version)) + { + to_replace = "taesd_w21.embd"; + } if(to_replace!="") { @@ -432,6 +438,12 @@ public: taesd_path_fixed = ""; use_tiny_autoencoder = false; } + if (use_tiny_autoencoder && !file_exists(taesd_path_fixed)) + { + printf("\nCannot use TAESD: \"%s\" not found. TAESD Disabled!\n", taesd_path_fixed); + taesd_path_fixed = ""; + use_tiny_autoencoder = false; + } } ggml_type wtype = (int)sd_ctx_params->wtype < std::min(SD_TYPE_COUNT, GGML_TYPE_COUNT) @@ -663,6 +675,9 @@ public: if (sd_ctx_params->diffusion_flash_attn) { LOG_INFO("Using flash attention in the diffusion model"); diffusion_model->set_flash_attn_enabled(true); + if (high_noise_diffusion_model) { + high_noise_diffusion_model->set_flash_attn_enabled(true); + } } cond_stage_model->alloc_params_buffer(); @@ -688,14 +703,27 @@ public: } if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) { - first_stage_model = std::make_shared(vae_backend, - offload_params_to_cpu, - tensor_storage_map, - "first_stage_model", - vae_decode_only, - version); - first_stage_model->alloc_params_buffer(); - first_stage_model->get_param_tensors(tensors, "first_stage_model"); + if (!use_tiny_autoencoder) { + first_stage_model = std::make_shared(vae_backend, + offload_params_to_cpu, + tensor_storage_map, + "first_stage_model", + vae_decode_only, + version); + first_stage_model->alloc_params_buffer(); + first_stage_model->get_param_tensors(tensors, "first_stage_model"); + } else { + tae_first_stage = std::make_shared(vae_backend, + offload_params_to_cpu, + tensor_storage_map, + "decoder", + vae_decode_only, + version); + if (sd_ctx_params->vae_conv_direct) { + LOG_INFO("Using Conv2d direct in the tae model"); + tae_first_stage->set_conv2d_direct_enabled(true); + } + } } else if (version == VERSION_CHROMA_RADIANCE) { first_stage_model = std::make_shared(vae_backend, offload_params_to_cpu); @@ -722,14 +750,13 @@ public: } first_stage_model->alloc_params_buffer(); first_stage_model->get_param_tensors(tensors, "first_stage_model"); - } - if (use_tiny_autoencoder) { - tae_first_stage = std::make_shared(vae_backend, - offload_params_to_cpu, - tensor_storage_map, - "decoder.layers", - vae_decode_only, - version); + } else if (use_tiny_autoencoder) { + tae_first_stage = std::make_shared(vae_backend, + offload_params_to_cpu, + tensor_storage_map, + "decoder.layers", + vae_decode_only, + version); if (sd_ctx_params->vae_conv_direct) { LOG_INFO("Using Conv2d direct in the tae model"); tae_first_stage->set_conv2d_direct_enabled(true); @@ -823,6 +850,8 @@ public: if (stacked_id) { ignore_tensors.insert("pmid.unet."); } + ignore_tensors.insert("model.diffusion_model.__x0__"); + ignore_tensors.insert("model.diffusion_model.__32x32__"); if (vae_decode_only) { ignore_tensors.insert("first_stage_model.encoder"); @@ -957,6 +986,7 @@ public: } } else if (sd_version_is_flux(version)) { pred_type = FLUX_FLOW_PRED; + if (flow_shift == INFINITY) { flow_shift = 1.0f; // TODO: validate for (const auto& [name, tensor_storage] : tensor_storage_map) { @@ -1612,6 +1642,17 @@ public: std::vector skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count); float cfg_scale = guidance.txt_cfg; + if (cfg_scale < 1.f) { + if (cfg_scale == 0.f) { + // Diffusers follow the convention from the original paper + // (https://arxiv.org/abs/2207.12598v1), so many distilled model docs + // recommend 0 as guidance; warn the user that it'll disable prompt folowing + LOG_WARN("unconditioned mode, images won't follow the prompt (use cfg-scale=1 for distilled models)"); + } else { + LOG_WARN("cfg value out of expected range may produce unexpected results"); + } + } + float img_cfg_scale = std::isfinite(guidance.img_cfg) ? guidance.img_cfg : guidance.txt_cfg; float slg_scale = guidance.slg.scale; @@ -2527,6 +2568,7 @@ const char* scheduler_to_str[] = { "sgm_uniform", "simple", "smoothstep", + "kl_optimal", "lcm", }; @@ -2664,7 +2706,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { "vae_path: %s\n" "taesd_path: %s\n" "control_net_path: %s\n" - "lora_model_dir: %s\n" "photo_maker_path: %s\n" "tensor_type_rules: %s\n" "vae_decode_only: %s\n" @@ -2694,7 +2735,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { SAFE_STR(sd_ctx_params->vae_path), SAFE_STR(sd_ctx_params->taesd_path), SAFE_STR(sd_ctx_params->control_net_path), - SAFE_STR(sd_ctx_params->lora_model_dir), SAFE_STR(sd_ctx_params->photo_maker_path), SAFE_STR(sd_ctx_params->tensor_type_rules), BOOL_STR(sd_ctx_params->vae_decode_only), @@ -2893,13 +2933,16 @@ enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) { return EULER_A_SAMPLE_METHOD; } -enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx) { +enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_method_t sample_method) { if (sd_ctx != nullptr && sd_ctx->sd != nullptr) { auto edm_v_denoiser = std::dynamic_pointer_cast(sd_ctx->sd->denoiser); if (edm_v_denoiser) { return EXPONENTIAL_SCHEDULER; } } + if (sample_method == LCM_SAMPLE_METHOD) { + return LCM_SCHEDULER; + } return DISCRETE_SCHEDULER; } @@ -3334,9 +3377,13 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g LOG_WARN("sample_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps); } } else { + scheduler_t scheduler = sd_img_gen_params->sample_params.scheduler; + if (scheduler == SCHEDULER_COUNT) { + scheduler = sd_get_default_scheduler(sd_ctx, sample_method); + } sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps, sd_ctx->sd->get_image_seq_len(height, width), - sd_img_gen_params->sample_params.scheduler, + scheduler, sd_ctx->sd->version); } @@ -3619,9 +3666,13 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s } } } else { + scheduler_t scheduler = sd_vid_gen_params->sample_params.scheduler; + if (scheduler == SCHEDULER_COUNT) { + scheduler = sd_get_default_scheduler(sd_ctx, sample_method); + } sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps, 0, - sd_vid_gen_params->sample_params.scheduler, + scheduler, sd_ctx->sd->version); } @@ -3746,7 +3797,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1); ggml_set_f32(denoise_mask, 1.f); - sd_ctx->sd->process_latent_out(init_latent); + if (!sd_ctx->sd->use_tiny_autoencoder) + sd_ctx->sd->process_latent_out(init_latent); ggml_ext_tensor_iter(init_image_latent, [&](ggml_tensor* t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { float value = ggml_ext_tensor_get_f32(t, i0, i1, i2, i3); @@ -3756,7 +3808,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s } }); - sd_ctx->sd->process_latent_in(init_latent); + if (!sd_ctx->sd->use_tiny_autoencoder) + sd_ctx->sd->process_latent_in(init_latent); int64_t t2 = ggml_time_ms(); LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1); @@ -3979,7 +4032,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true); int64_t t5 = ggml_time_ms(); LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000); - if (sd_ctx->sd->free_params_immediately) { + if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) { sd_ctx->sd->first_stage_model->free_params_buffer(); } diff --git a/otherarch/sdcpp/stable-diffusion.h b/otherarch/sdcpp/stable-diffusion.h index e4abc8dcd..adb65a1d2 100644 --- a/otherarch/sdcpp/stable-diffusion.h +++ b/otherarch/sdcpp/stable-diffusion.h @@ -60,6 +60,7 @@ enum scheduler_t { SGM_UNIFORM_SCHEDULER, SIMPLE_SCHEDULER, SMOOTHSTEP_SCHEDULER, + KL_OPTIMAL_SCHEDULER, LCM_SCHEDULER, SCHEDULER_COUNT }; @@ -168,7 +169,6 @@ typedef struct { const char* vae_path; const char* taesd_path; const char* control_net_path; - const char* lora_model_dir; const sd_embedding_t* embeddings; uint32_t embedding_count; const char* photo_maker_path; @@ -335,7 +335,7 @@ SD_API void sd_sample_params_init(sd_sample_params_t* sample_params); SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params); SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx); -SD_API enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx); +SD_API enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_method_t sample_method); SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params); SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params); diff --git a/otherarch/sdcpp/tae.hpp b/otherarch/sdcpp/tae.hpp index 7f3ca449a..5da76e692 100644 --- a/otherarch/sdcpp/tae.hpp +++ b/otherarch/sdcpp/tae.hpp @@ -162,6 +162,311 @@ public: } }; +class TPool : public UnaryBlock { + int stride; + +public: + TPool(int channels, int stride) + : stride(stride) { + blocks["conv"] = std::shared_ptr(new Conv2d(channels * stride, channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false)); + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { + auto conv = std::dynamic_pointer_cast(blocks["conv"]); + auto h = x; + if (stride != 1) { + h = ggml_reshape_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2] * stride, h->ne[3] / stride); + } + h = conv->forward(ctx, h); + return h; + } +}; + +class TGrow : public UnaryBlock { + int stride; + +public: + TGrow(int channels, int stride) + : stride(stride) { + blocks["conv"] = std::shared_ptr(new Conv2d(channels, channels * stride, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false)); + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { + auto conv = std::dynamic_pointer_cast(blocks["conv"]); + auto h = conv->forward(ctx, x); + if (stride != 1) { + h = ggml_reshape_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2] / stride, h->ne[3] * stride); + } + return h; + } +}; + +class MemBlock : public GGMLBlock { + bool has_skip_conv = false; + +public: + MemBlock(int channels, int out_channels) + : has_skip_conv(channels != out_channels) { + blocks["conv.0"] = std::shared_ptr(new Conv2d(channels * 2, out_channels, {3, 3}, {1, 1}, {1, 1})); + blocks["conv.2"] = std::shared_ptr(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1})); + blocks["conv.4"] = std::shared_ptr(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1})); + if (has_skip_conv) { + blocks["skip"] = std::shared_ptr(new Conv2d(channels, out_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false)); + } + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* past) { + // x: [n, channels, h, w] + auto conv0 = std::dynamic_pointer_cast(blocks["conv.0"]); + auto conv1 = std::dynamic_pointer_cast(blocks["conv.2"]); + auto conv2 = std::dynamic_pointer_cast(blocks["conv.4"]); + + auto h = ggml_concat(ctx->ggml_ctx, x, past, 2); + h = conv0->forward(ctx, h); + h = ggml_relu_inplace(ctx->ggml_ctx, h); + h = conv1->forward(ctx, h); + h = ggml_relu_inplace(ctx->ggml_ctx, h); + h = conv2->forward(ctx, h); + + auto skip = x; + if (has_skip_conv) { + auto skip_conv = std::dynamic_pointer_cast(blocks["skip"]); + skip = skip_conv->forward(ctx, x); + } + h = ggml_add_inplace(ctx->ggml_ctx, h, skip); + h = ggml_relu_inplace(ctx->ggml_ctx, h); + return h; + } +}; + +struct ggml_tensor* patchify(struct ggml_context* ctx, + struct ggml_tensor* x, + int64_t patch_size, + int64_t b = 1) { + // x: [f, b*c, h*q, w*r] + // return: [f, b*c*r*q, h, w] + if (patch_size == 1) { + return x; + } + int64_t r = patch_size; + int64_t q = patch_size; + + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + int64_t C = x->ne[2]; + int64_t f = x->ne[3]; + + int64_t w = W / r; + int64_t h = H / q; + + x = ggml_reshape_4d(ctx, x, W, q, h, C * f); // [W, q, h, C*f] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [W, h, q, C*f] + x = ggml_reshape_4d(ctx, x, r, w, h, q * C * f); // [r, w, h, q*C*f] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [w, h, r, q*C*f] + x = ggml_reshape_4d(ctx, x, w, h, r * q * C, f); // [f, b*c*r*q, h, w] + + return x; +} + +struct ggml_tensor* unpatchify(struct ggml_context* ctx, + struct ggml_tensor* x, + int64_t patch_size, + int64_t b = 1) { + // x: [f, b*c*r*q, h, w] + // return: [f, b*c, h*q, w*r] + if (patch_size == 1) { + return x; + } + int64_t r = patch_size; + int64_t q = patch_size; + int64_t c = x->ne[2] / b / q / r; + int64_t f = x->ne[3]; + int64_t h = x->ne[1]; + int64_t w = x->ne[0]; + + x = ggml_reshape_4d(ctx, x, w, h, r, q * c * b * f); // [q*c*b*f, r, h, w] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [r, w, h, q*c*b*f] + x = ggml_reshape_4d(ctx, x, r * w, h, q, c * b * f); // [c*b*f, q, h, r*w] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [r*w, q, h, c*b*f] + x = ggml_reshape_4d(ctx, x, r * w, q * h, c * b, f); + + return x; +} + +class TinyVideoEncoder : public UnaryBlock { + int in_channels = 3; + int hidden = 64; + int z_channels = 4; + int num_blocks = 3; + int num_layers = 3; + int patch_size = 1; + +public: + TinyVideoEncoder(int z_channels = 4, int patch_size = 1) + : z_channels(z_channels), patch_size(patch_size) { + int index = 0; + blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(in_channels * patch_size * patch_size, hidden, {3, 3}, {1, 1}, {1, 1})); + index++; // nn.ReLU() + for (int i = 0; i < num_layers; i++) { + int stride = i == num_layers - 1 ? 1 : 2; + blocks[std::to_string(index++)] = std::shared_ptr(new TPool(hidden, stride)); + blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(hidden, hidden, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false)); + for (int j = 0; j < num_blocks; j++) { + blocks[std::to_string(index++)] = std::shared_ptr(new MemBlock(hidden, hidden)); + } + } + blocks[std::to_string(index)] = std::shared_ptr(new Conv2d(hidden, z_channels, {3, 3}, {1, 1}, {1, 1})); + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override { + auto first_conv = std::dynamic_pointer_cast(blocks["0"]); + + if (patch_size > 1) { + z = patchify(ctx->ggml_ctx, z, patch_size, 1); + } + + auto h = first_conv->forward(ctx, z); + h = ggml_relu_inplace(ctx->ggml_ctx, h); + + int index = 2; + for (int i = 0; i < num_layers; i++) { + auto pool = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); + auto conv = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); + + h = pool->forward(ctx, h); + h = conv->forward(ctx, h); + for (int j = 0; j < num_blocks; j++) { + auto block = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); + auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0); + mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0); + h = block->forward(ctx, h, mem); + } + } + auto last_conv = std::dynamic_pointer_cast(blocks[std::to_string(index)]); + h = last_conv->forward(ctx, h); + return h; + } +}; + +class TinyVideoDecoder : public UnaryBlock { + int z_channels = 4; + int out_channels = 3; + int num_blocks = 3; + static const int num_layers = 3; + int channels[num_layers + 1] = {256, 128, 64, 64}; + int patch_size = 1; + +public: + TinyVideoDecoder(int z_channels = 4, int patch_size = 1) + : z_channels(z_channels), patch_size(patch_size) { + int index = 1; // Clamp() + blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(z_channels, channels[0], {3, 3}, {1, 1}, {1, 1})); + index++; // nn.ReLU() + for (int i = 0; i < num_layers; i++) { + int stride = i == 0 ? 1 : 2; + for (int j = 0; j < num_blocks; j++) { + blocks[std::to_string(index++)] = std::shared_ptr(new MemBlock(channels[i], channels[i])); + } + index++; // nn.Upsample() + blocks[std::to_string(index++)] = std::shared_ptr(new TGrow(channels[i], stride)); + blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels[i], channels[i + 1], {3, 3}, {1, 1}, {1, 1}, {1, 1}, false)); + } + index++; // nn.ReLU() + blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels[num_layers], out_channels * patch_size * patch_size, {3, 3}, {1, 1}, {1, 1})); + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override { + auto first_conv = std::dynamic_pointer_cast(blocks["1"]); + + // Clamp() + auto h = ggml_scale_inplace(ctx->ggml_ctx, + ggml_tanh_inplace(ctx->ggml_ctx, + ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)), + 3.0f); + + h = first_conv->forward(ctx, h); + h = ggml_relu_inplace(ctx->ggml_ctx, h); + int index = 3; + for (int i = 0; i < num_layers; i++) { + for (int j = 0; j < num_blocks; j++) { + auto block = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); + auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0); + mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0); + h = block->forward(ctx, h, mem); + } + // upsample + index++; + h = ggml_upscale(ctx->ggml_ctx, h, 2, GGML_SCALE_MODE_NEAREST); + auto block = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); + h = block->forward(ctx, h); + block = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); + h = block->forward(ctx, h); + } + h = ggml_relu_inplace(ctx->ggml_ctx, h); + + auto last_conv = std::dynamic_pointer_cast(blocks[std::to_string(++index)]); + h = last_conv->forward(ctx, h); + if (patch_size > 1) { + h = unpatchify(ctx->ggml_ctx, h, patch_size, 1); + } + // shape(W, H, 3, 3 + T) => shape(W, H, 3, T) + h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - 3, h->nb[1], h->nb[2], h->nb[3], 3 * h->nb[3]); + return h; + } +}; + +class TAEHV : public GGMLBlock { +protected: + bool decode_only; + SDVersion version; + +public: + TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2) + : decode_only(decode_only), version(version) { + int z_channels = 16; + int patch = 1; + if (version == VERSION_WAN2_2_TI2V) { + z_channels = 48; + patch = 2; + } + blocks["decoder"] = std::shared_ptr(new TinyVideoDecoder(z_channels, patch)); + if (!decode_only) { + blocks["encoder"] = std::shared_ptr(new TinyVideoEncoder(z_channels, patch)); + } + } + + struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { + auto decoder = std::dynamic_pointer_cast(blocks["decoder"]); + if (sd_version_is_wan(version)) { + // (W, H, C, T) -> (W, H, T, C) + z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 0, 1, 3, 2)); + } + auto result = decoder->forward(ctx, z); + if (sd_version_is_wan(version)) { + // (W, H, C, T) -> (W, H, T, C) + result = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2)); + } + return result; + } + + struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) { + auto encoder = std::dynamic_pointer_cast(blocks["encoder"]); + // (W, H, T, C) -> (W, H, C, T) + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); + int64_t num_frames = x->ne[3]; + if (num_frames % 4) { + // pad to multiple of 4 at the end + auto last_frame = ggml_view_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], x->ne[2], 1, x->nb[1], x->nb[2], x->nb[3], (num_frames - 1) * x->nb[3]); + for (int i = 0; i < 4 - num_frames % 4; i++) { + x = ggml_concat(ctx->ggml_ctx, x, last_frame, 3); + } + } + x = encoder->forward(ctx, x); + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); + return x; + } +}; + class TAESD : public GGMLBlock { protected: bool decode_only; @@ -192,18 +497,30 @@ public: }; struct TinyAutoEncoder : public GGMLRunner { + TinyAutoEncoder(ggml_backend_t backend, bool offload_params_to_cpu) + : GGMLRunner(backend, offload_params_to_cpu) {} + virtual bool compute(const int n_threads, + struct ggml_tensor* z, + bool decode_graph, + struct ggml_tensor** output, + struct ggml_context* output_ctx = nullptr) = 0; + + virtual bool load_from_file(const std::string& file_path, int n_threads) = 0; +}; + +struct TinyImageAutoEncoder : public TinyAutoEncoder { TAESD taesd; bool decode_only = false; - TinyAutoEncoder(ggml_backend_t backend, - bool offload_params_to_cpu, - const String2TensorStorage& tensor_storage_map, - const std::string prefix, - bool decoder_only = true, - SDVersion version = VERSION_SD1) + TinyImageAutoEncoder(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map, + const std::string prefix, + bool decoder_only = true, + SDVersion version = VERSION_SD1) : decode_only(decoder_only), taesd(decoder_only, version), - GGMLRunner(backend, offload_params_to_cpu) { + TinyAutoEncoder(backend, offload_params_to_cpu) { taesd.init(params_ctx, tensor_storage_map, prefix); } @@ -260,4 +577,73 @@ struct TinyAutoEncoder : public GGMLRunner { } }; +struct TinyVideoAutoEncoder : public TinyAutoEncoder { + TAEHV taehv; + bool decode_only = false; + + TinyVideoAutoEncoder(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map, + const std::string prefix, + bool decoder_only = true, + SDVersion version = VERSION_WAN2) + : decode_only(decoder_only), + taehv(decoder_only, version), + TinyAutoEncoder(backend, offload_params_to_cpu) { + taehv.init(params_ctx, tensor_storage_map, prefix); + } + + std::string get_desc() override { + return "taehv"; + } + + bool load_from_file(const std::string& file_path, int n_threads) { + LOG_INFO("loading taehv from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false"); + alloc_params_buffer(); + std::map taehv_tensors; + taehv.get_param_tensors(taehv_tensors); + std::set ignore_tensors; + if (decode_only) { + ignore_tensors.insert("encoder."); + } + + ModelLoader model_loader; + if (!model_loader.init_from_file(file_path)) { + LOG_ERROR("init taehv model loader from file failed: '%s'", file_path.c_str()); + return false; + } + + bool success = model_loader.load_tensors(taehv_tensors, ignore_tensors, n_threads); + + if (!success) { + LOG_ERROR("load tae tensors from model loader failed"); + return false; + } + + LOG_INFO("taehv model loaded"); + return success; + } + + struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) { + struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); + z = to_backend(z); + auto runner_ctx = get_context(); + struct ggml_tensor* out = decode_graph ? taehv.decode(&runner_ctx, z) : taehv.encode(&runner_ctx, z); + ggml_build_forward_expand(gf, out); + return gf; + } + + bool compute(const int n_threads, + struct ggml_tensor* z, + bool decode_graph, + struct ggml_tensor** output, + struct ggml_context* output_ctx = nullptr) { + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph(z, decode_graph); + }; + + return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + } +}; + #endif // __TAE_HPP__ \ No newline at end of file