From 715a76ce5f778e236b90c99b195f162c1fb2f39c Mon Sep 17 00:00:00 2001 From: Wagner Bruna Date: Wed, 21 Jan 2026 09:49:35 -0300 Subject: [PATCH] sd: sync to master-480-b87fe13 (#1932) --- otherarch/sdcpp/flux.hpp | 28 ++++++++++++++++++++-------- otherarch/sdcpp/model.cpp | 7 +++++++ otherarch/sdcpp/model.h | 3 ++- otherarch/sdcpp/stable-diffusion.cpp | 3 ++- otherarch/sdcpp/unet.hpp | 7 +++++-- 5 files changed, 36 insertions(+), 12 deletions(-) diff --git a/otherarch/sdcpp/flux.hpp b/otherarch/sdcpp/flux.hpp index 9826fadee..77a65c557 100644 --- a/otherarch/sdcpp/flux.hpp +++ b/otherarch/sdcpp/flux.hpp @@ -748,7 +748,7 @@ namespace Flux { int nerf_depth = 4; int nerf_max_freqs = 8; bool use_x0 = false; - bool use_patch_size_32 = false; + bool fake_patch_size_x2 = false; }; struct FluxParams { @@ -786,8 +786,11 @@ namespace Flux { Flux(FluxParams params) : params(params) { if (params.version == VERSION_CHROMA_RADIANCE) { - std::pair kernel_size = {16, 16}; - std::pair stride = kernel_size; + std::pair kernel_size = {params.patch_size, params.patch_size}; + if (params.chroma_radiance_params.fake_patch_size_x2) { + kernel_size = {params.patch_size / 2, params.patch_size / 2}; + } + std::pair stride = kernel_size; blocks["img_in_patch"] = std::make_shared(params.in_channels, params.hidden_size, @@ -1082,7 +1085,7 @@ namespace Flux { auto img = pad_to_patch_size(ctx, x); auto orig_img = img; - if (params.chroma_radiance_params.use_patch_size_32) { + if (params.chroma_radiance_params.fake_patch_size_x2) { // It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable // Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch? 
// img = F.interpolate(img, size=(H//2, W//2), mode="nearest") @@ -1303,7 +1306,8 @@ namespace Flux { flux_params.ref_index_scale = 10.f; flux_params.use_mlp_silu_act = true; } - int64_t head_dim = 0; + int64_t head_dim = 0; + int64_t actual_radiance_patch_size = -1; for (auto pair : tensor_storage_map) { std::string tensor_name = pair.first; if (!starts_with(tensor_name, prefix)) @@ -1316,9 +1320,12 @@ namespace Flux { flux_params.chroma_radiance_params.use_x0 = true; } if (tensor_name.find("__32x32__") != std::string::npos) { - LOG_DEBUG("using patch size 32 prediction"); - flux_params.chroma_radiance_params.use_patch_size_32 = true; - flux_params.patch_size = 32; + LOG_DEBUG("using patch size 32"); + flux_params.patch_size = 32; + } + if (tensor_name.find("img_in_patch.weight") != std::string::npos) { + actual_radiance_patch_size = pair.second.ne[0]; + LOG_DEBUG("actual radiance patch size: %d", static_cast<int>(actual_radiance_patch_size)); } if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { // Chroma @@ -1351,6 +1358,11 @@ namespace Flux { head_dim = pair.second.ne[0]; } } + if (actual_radiance_patch_size > 0 && actual_radiance_patch_size != flux_params.patch_size) { + GGML_ASSERT(flux_params.patch_size == 2 * actual_radiance_patch_size); + LOG_DEBUG("using fake x2 patch size"); + flux_params.chroma_radiance_params.fake_patch_size_x2 = true; + } flux_params.num_heads = static_cast(flux_params.hidden_size / head_dim); diff --git a/otherarch/sdcpp/model.cpp b/otherarch/sdcpp/model.cpp index 4cbc258cf..dedd31268 100644 --- a/otherarch/sdcpp/model.cpp +++ b/otherarch/sdcpp/model.cpp @@ -1068,6 +1068,7 @@ SDVersion ModelLoader::get_sd_version() { int64_t patch_embedding_channels = 0; bool has_img_emb = false; bool has_middle_block_1 = false; + bool has_output_block_311 = false; bool has_output_block_71 = false; for (auto& [name, tensor_storage] : tensor_storage_map) { @@ -1128,6 +1129,9 @@ SDVersion ModelLoader::get_sd_version() { 
tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) { has_middle_block_1 = true; } + if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos) { + has_output_block_311 = true; + } if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) { has_output_block_71 = true; } @@ -1166,6 +1170,9 @@ SDVersion ModelLoader::get_sd_version() { return VERSION_SDXL_PIX2PIX; } if (!has_middle_block_1) { + if (!has_output_block_311) { + return VERSION_SDXL_VEGA; + } return VERSION_SDXL_SSD1B; } return VERSION_SDXL; diff --git a/otherarch/sdcpp/model.h b/otherarch/sdcpp/model.h index 1029fb569..1dd07130d 100644 --- a/otherarch/sdcpp/model.h +++ b/otherarch/sdcpp/model.h @@ -32,6 +32,7 @@ enum SDVersion { VERSION_SDXL, VERSION_SDXL_INPAINT, VERSION_SDXL_PIX2PIX, + VERSION_SDXL_VEGA, VERSION_SDXL_SSD1B, VERSION_SVD, VERSION_SD3, @@ -66,7 +67,7 @@ static inline bool sd_version_is_sd2(SDVersion version) { } static inline bool sd_version_is_sdxl(SDVersion version) { - if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B) { + if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B || version == VERSION_SDXL_VEGA) { return true; } return false; diff --git a/otherarch/sdcpp/stable-diffusion.cpp b/otherarch/sdcpp/stable-diffusion.cpp index 07e016ce6..b143883c0 100644 --- a/otherarch/sdcpp/stable-diffusion.cpp +++ b/otherarch/sdcpp/stable-diffusion.cpp @@ -37,6 +37,7 @@ const char* model_version_to_str[] = { "SDXL", "SDXL Inpaint", "SDXL Instruct-Pix2Pix", + "SDXL (Vega)", "SDXL (SSD1B)", "SVD", "SD3.x", @@ -763,7 +764,7 @@ public: LOG_INFO("Using Conv2d direct in the vae model"); first_stage_model->set_conv2d_direct_enabled(true); } - if (version == VERSION_SDXL && + if (sd_version_is_sdxl(version) && 
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) { float vae_conv_2d_scale = 1.f / 32.f; LOG_WARN( diff --git a/otherarch/sdcpp/unet.hpp b/otherarch/sdcpp/unet.hpp index 9fe24e243..6e15e1f45 100644 --- a/otherarch/sdcpp/unet.hpp +++ b/otherarch/sdcpp/unet.hpp @@ -201,6 +201,9 @@ public: num_head_channels = 64; num_heads = -1; use_linear_projection = true; + if (version == VERSION_SDXL_VEGA) { + transformer_depth = {1, 1, 2}; + } } else if (version == VERSION_SVD) { in_channels = 8; out_channels = 4; @@ -319,7 +322,7 @@ public: } if (!tiny_unet) { blocks["middle_block.0"] = std::shared_ptr(get_resblock(ch, time_embed_dim, ch)); - if (version != VERSION_SDXL_SSD1B) { + if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) { blocks["middle_block.1"] = std::shared_ptr(get_attention_layer(ch, n_head, d_head, @@ -520,7 +523,7 @@ public: // middle_block if (!tiny_unet) { h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] - if (version != VERSION_SDXL_SSD1B) { + if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) { h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8] h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] }