diff --git a/otherarch/sdcpp/mmdit.hpp b/otherarch/sdcpp/mmdit.hpp
index 3a278dac7..132a4cdd7 100644
--- a/otherarch/sdcpp/mmdit.hpp
+++ b/otherarch/sdcpp/mmdit.hpp
@@ -252,6 +252,7 @@ struct DismantledBlock : public GGMLBlock {
 public:
     int64_t num_heads;
     bool pre_only;
+    bool self_attn;
 
 public:
     DismantledBlock(int64_t hidden_size,
@@ -259,14 +260,19 @@ public:
                     float mlp_ratio       = 4.0,
                     std::string qk_norm   = "",
                     bool qkv_bias         = false,
-                    bool pre_only         = false)
-        : num_heads(num_heads), pre_only(pre_only) {
+                    bool pre_only         = false,
+                    bool self_attn        = false)
+        : num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
         // rmsnorm is always false
         // scale_mod_only is always false
         // swiglu is always false
         blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
         blocks["attn"]  = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));
+        if (self_attn) {
+            blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false));
+        }
+
         if (!pre_only) {
             blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
             int64_t mlp_hidden_dim = (int64_t)(hidden_size * mlp_ratio);
@@ -277,9 +283,52 @@ public:
         if (pre_only) {
             n_mods = 2;
         }
+        if (self_attn) {
+            n_mods = 9;
+        }
         blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, n_mods * hidden_size));
     }
 
+    std::tuple<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention_x(struct ggml_context* ctx,
+                                                                                                                                     struct ggml_tensor* x,
+                                                                                                                                     struct ggml_tensor* c) {
+        GGML_ASSERT(self_attn);
+        // x: [N, n_token, hidden_size]
+        // c: [N, hidden_size]
+        auto norm1              = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
+        auto attn               = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
+        auto attn2              = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
+        auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
+
+        int64_t n_mods = 9;
+        auto m         = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c));  // [N, n_mods * hidden_size]
+        m              = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]);  // [N, n_mods, hidden_size]
+        m              = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3));     // [n_mods, N, hidden_size]
+
+        int64_t offset = m->nb[1] * m->ne[1];
+        auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0);  // [N, hidden_size]
+        auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1);  // [N, hidden_size]
+        auto gate_msa  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2);  // [N, hidden_size]
+
+        auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3);  // [N, hidden_size]
+        auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4);  // [N, hidden_size]
+        auto gate_mlp  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5);  // [N, hidden_size]
+
+        auto shift_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6);  // [N, hidden_size]
+        auto scale_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7);  // [N, hidden_size]
+        auto gate_msa2  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8);  // [N, hidden_size]
+
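+        // Chunk layout of m, one [N, hidden_size] slab per modulation vector:
+        //   0-2: shift/scale/gate for the first self-attention (msa)
+        //   3-5: shift/scale/gate for the MLP
+        //   6-8: shift/scale/gate for the second self-attention (msa2)
+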
+        auto x_norm = norm1->forward(ctx, x);
+
+        auto attn_in = modulate(ctx, x_norm, shift_msa, scale_msa);
+        auto qkv     = attn->pre_attention(ctx, attn_in);
+
+        auto attn2_in = modulate(ctx, x_norm, shift_msa2, scale_msa2);
+        auto qkv2     = attn2->pre_attention(ctx, attn2_in);
+
+        return {qkv, qkv2, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2}};
+    }
+
     std::pair<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention(struct ggml_context* ctx,
                                                                                                 struct ggml_tensor* x,
                                                                                                 struct ggml_tensor* c) {
@@ -319,6 +368,44 @@ public:
         }
     }
 
+    struct ggml_tensor* post_attention_x(struct ggml_context* ctx,
+                                         struct ggml_tensor* attn_out,
+                                         struct ggml_tensor* attn2_out,
+                                         struct ggml_tensor* x,
+                                         struct ggml_tensor* gate_msa,
+                                         struct ggml_tensor* shift_mlp,
+                                         struct ggml_tensor* scale_mlp,
+                                         struct ggml_tensor* gate_mlp,
+                                         struct ggml_tensor* gate_msa2) {
+        // attn_out:  [N, n_token, hidden_size]
+        // attn2_out: [N, n_token, hidden_size]
+        // x:         [N, n_token, hidden_size]
+        // gate_msa:  [N, hidden_size]
+        // shift_mlp: [N, hidden_size]
+        // scale_mlp: [N, hidden_size]
+        // gate_mlp:  [N, hidden_size]
+        // gate_msa2: [N, hidden_size]
+        // return:    [N, n_token, hidden_size]
+        GGML_ASSERT(!pre_only);
+
+        auto attn  = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
+        auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
+        auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
+        auto mlp   = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);
+
+        gate_msa  = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]);     // [N, 1, hidden_size]
+        gate_mlp  = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]);     // [N, 1, hidden_size]
+        gate_msa2 = ggml_reshape_3d(ctx, gate_msa2, gate_msa2->ne[0], 1, gate_msa2->ne[1]);  // [N, 1, hidden_size]
+
+        attn_out  = attn->post_attention(ctx, attn_out);
+        attn2_out = attn2->post_attention(ctx, attn2_out);
+
+        x            = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa));
+        x            = ggml_add(ctx, x, ggml_mul(ctx, attn2_out, gate_msa2));
+        auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
+        x            = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp));
+
+        return x;
+    }
+
     struct ggml_tensor* post_attention(struct ggml_context* ctx,
                                        struct ggml_tensor* attn_out,
                                        struct ggml_tensor* x,
@@ -357,29 +444,52 @@ public:
         // return: [N, n_token, hidden_size]
         auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
 
+        if (self_attn) {
+            auto qkv_intermediates = pre_attention_x(ctx, x, c);
+            // pre_attention_x returns a tuple {qkv, qkv2, intermediates}, not a pair
+            auto qkv           = std::get<0>(qkv_intermediates);
+            auto qkv2          = std::get<1>(qkv_intermediates);
+            auto intermediates = std::get<2>(qkv_intermediates);
 
-        auto qkv_intermediates = pre_attention(ctx, x, c);
-        auto qkv               = qkv_intermediates.first;
-        auto intermediates     = qkv_intermediates.second;
+            auto attn_out  = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads);     // [N, n_token, dim]
+            auto attn2_out = ggml_nn_attention_ext(ctx, qkv2[0], qkv2[1], qkv2[2], num_heads);  // [N, n_token, dim]
+            x              = post_attention_x(ctx,
+                                              attn_out,
+                                              attn2_out,
+                                              intermediates[0],
+                                              intermediates[1],
+                                              intermediates[2],
+                                              intermediates[3],
+                                              intermediates[4],
+                                              intermediates[5]);
+            return x;  // [N, n_token, dim]
+        } else {
+            auto qkv_intermediates = pre_attention(ctx, x, c);
+            auto qkv               = qkv_intermediates.first;
+            auto intermediates     = qkv_intermediates.second;
 
-        auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads);  // [N, n_token, dim]
-        x             = post_attention(ctx,
-                                       attn_out,
-                                       intermediates[0],
-                                       intermediates[1],
-                                       intermediates[2],
-                                       intermediates[3],
-                                       intermediates[4]);
-        return x;  // [N, n_token, dim]
+            auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads);  // [N, n_token, dim]
+            x             = post_attention(ctx,
+                                           attn_out,
+                                           intermediates[0],
+                                           intermediates[1],
+                                           intermediates[2],
+                                           intermediates[3],
+                                           intermediates[4]);
+            return x;  // [N, n_token, dim]
+        }
     }
 };
 
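+// With x_block->self_attn set (MMDiT-X), block_mixing() below runs an extra,
+// x-only self-attention from x_qkv2 and folds it in via post_attention_x();
+// the joint attention over the concatenated [context; x] streams is unchanged.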
-__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*> block_mixing(struct ggml_context* ctx,
-                                                                                   struct ggml_tensor* context,
-                                                                                   struct ggml_tensor* x,
-                                                                                   struct ggml_tensor* c,
-                                                                                   std::shared_ptr<DismantledBlock> context_block,
-                                                                                   std::shared_ptr<DismantledBlock> x_block) {
+__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
+block_mixing(struct ggml_context* ctx,
+             struct ggml_tensor* context,
+             struct ggml_tensor* x,
+             struct ggml_tensor* c,
+             std::shared_ptr<DismantledBlock> context_block,
+             std::shared_ptr<DismantledBlock> x_block) {
     // context: [N, n_context, hidden_size]
     // x: [N, n_token, hidden_size]
     // c: [N, hidden_size]
@@ -387,10 +497,18 @@ __STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*> block_mixi
     auto context_qkv           = context_qkv_intermediates.first;
     auto context_intermediates = context_qkv_intermediates.second;
 
-    auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c);
-    auto x_qkv               = x_qkv_intermediates.first;
-    auto x_intermediates     = x_qkv_intermediates.second;
+    std::vector<struct ggml_tensor*> x_qkv, x_qkv2, x_intermediates;
+    if (x_block->self_attn) {
+        auto x_qkv_intermediates = x_block->pre_attention_x(ctx, x, c);
+        x_qkv                    = std::get<0>(x_qkv_intermediates);
+        x_qkv2                   = std::get<1>(x_qkv_intermediates);
+        x_intermediates          = std::get<2>(x_qkv_intermediates);
+    } else {
+        auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c);
+        x_qkv                    = x_qkv_intermediates.first;
+        x_intermediates          = x_qkv_intermediates.second;
+    }
 
     std::vector<struct ggml_tensor*> qkv;
     for (int i = 0; i < 3; i++) {
         qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
@@ -429,13 +547,27 @@ __STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*> block_mixi
         context = NULL;
     }
 
-    x = x_block->post_attention(ctx,
-                                x_attn,
-                                x_intermediates[0],
-                                x_intermediates[1],
-                                x_intermediates[2],
-                                x_intermediates[3],
-                                x_intermediates[4]);
+    if (x_block->self_attn) {
+        auto attn2 = ggml_nn_attention_ext(ctx, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads);  // [N, n_token, hidden_size]
+
+        x = x_block->post_attention_x(ctx,
+                                      x_attn,
+                                      attn2,
+                                      x_intermediates[0],
+                                      x_intermediates[1],
+                                      x_intermediates[2],
+                                      x_intermediates[3],
+                                      x_intermediates[4],
+                                      x_intermediates[5]);
+    } else {
+        x = x_block->post_attention(ctx,
+                                    x_attn,
+                                    x_intermediates[0],
+                                    x_intermediates[1],
+                                    x_intermediates[2],
+                                    x_intermediates[3],
+                                    x_intermediates[4]);
+    }
 
     return {context, x};
 }
@@ -447,9 +579,10 @@ public:
                float mlp_ratio     = 4.0,
                std::string qk_norm = "",
                bool qkv_bias       = false,
-               bool pre_only       = false) {
+               bool pre_only       = false,
+               bool self_attn_x    = false) {
         blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only));
-        blocks["x_block"]       = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false));
+        blocks["x_block"]       = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x));
     }
 
     std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
@@ -507,6 +640,7 @@ protected:
     int64_t input_size      = -1;
     int64_t patch_size      = 2;
     int64_t in_channels     = 16;
+    int64_t d_self          = -1;  // >= 0 for MMDiT-X
     int64_t depth           = 24;
     float mlp_ratio         = 4.0f;
     int64_t adm_in_channels = 2048;
@@ -561,6 +695,20 @@ public:
             context_size             = 4096;
             context_embedder_out_dim = 2432;
             qk_norm                  = "rms";
+        } else if (version == VERSION_SD3_5_2B) {
+            input_size               = -1;
+            patch_size               = 2;
+            in_channels              = 16;
+            depth                    = 24;
+            d_self                   = 12;
+            mlp_ratio                = 4.0f;
+            adm_in_channels          = 2048;
+            out_channels             = 16;
+            pos_embed_max_size       = 384;
+            num_patchs               = 147456;
+            context_size             = 4096;
+            context_embedder_out_dim = 1536;
+            qk_norm                  = "rms";
         }
         int64_t default_out_channels = in_channels;
         hidden_size                  = 64 * depth;
@@ -581,15 +729,17 @@ public:
                                                                                            mlp_ratio,
                                                                                            qk_norm,
                                                                                            true,
-                                                                                           i == depth - 1));
+                                                                                           i == depth - 1,
+                                                                                           i <= d_self));
         }
 
         blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels));
     }
 
-    struct ggml_tensor* cropped_pos_embed(struct ggml_context* ctx,
-                                          int64_t h,
-                                          int64_t w) {
+    struct ggml_tensor*
+    cropped_pos_embed(struct ggml_context* ctx,
+                      int64_t h,
+                      int64_t w) {
         auto pos_embed = params["pos_embed"];
 
         h = (h + 1) / patch_size;
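Note (not part of the patch): post_attention_x() above applies the gated residual updates in a fixed order — both attention branches first, then the MLP on the twice-updated activations. A minimal scalar sketch of that order, with made-up toy values standing in for the [N, n_token, hidden_size] tensors:

    // toy scalars in place of ggml tensors; the update order mirrors post_attention_x
    #include <cstdio>

    int main() {
        float x = 1.0f;                                              // residual stream
        float attn_out = 0.5f, attn2_out = 0.25f, mlp_out = 0.125f;  // branch outputs (made up)
        float gate_msa = 1.0f, gate_msa2 = 1.0f, gate_mlp = 1.0f;    // adaLN gates (made up)

        x = x + gate_msa * attn_out;    // first self-attention branch
        x = x + gate_msa2 * attn2_out;  // second (MMDiT-X) self-attention branch
        // the real code derives mlp_out from the twice-updated x:
        //   mlp(modulate(norm2(x), shift_mlp, scale_mlp))
        x = x + gate_mlp * mlp_out;     // MLP branch comes last
        printf("x = %.3f\n", x);        // prints x = 1.875
        return 0;
    }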
blocks["final_layer"] = std::shared_ptr(new FinalLayer(hidden_size, patch_size, out_channels)); } - struct ggml_tensor* cropped_pos_embed(struct ggml_context* ctx, - int64_t h, - int64_t w) { + struct ggml_tensor* + cropped_pos_embed(struct ggml_context* ctx, + int64_t h, + int64_t w) { auto pos_embed = params["pos_embed"]; h = (h + 1) / patch_size; diff --git a/otherarch/sdcpp/model.cpp b/otherarch/sdcpp/model.cpp index 6a391e90f..dfaff874d 100644 --- a/otherarch/sdcpp/model.cpp +++ b/otherarch/sdcpp/model.cpp @@ -1376,6 +1376,9 @@ SDVersion ModelLoader::get_sd_version() { if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { is_flux = true; } + if (tensor_storage.name.find("joint_blocks.0.x_block.attn2.ln_q.weight") != std::string::npos) { + return VERSION_SD3_5_2B; + } if (tensor_storage.name.find("joint_blocks.37.x_block.attn.ln_q.weight") != std::string::npos) { return VERSION_SD3_5_8B; } diff --git a/otherarch/sdcpp/model.h b/otherarch/sdcpp/model.h index 4efbdf813..041245e37 100644 --- a/otherarch/sdcpp/model.h +++ b/otherarch/sdcpp/model.h @@ -26,6 +26,7 @@ enum SDVersion { VERSION_FLUX_DEV, VERSION_FLUX_SCHNELL, VERSION_SD3_5_8B, + VERSION_SD3_5_2B, VERSION_COUNT, }; diff --git a/otherarch/sdcpp/stable-diffusion.cpp b/otherarch/sdcpp/stable-diffusion.cpp index 1aa78e36f..abc732834 100644 --- a/otherarch/sdcpp/stable-diffusion.cpp +++ b/otherarch/sdcpp/stable-diffusion.cpp @@ -31,7 +31,8 @@ const char* model_version_to_str[] = { "SD3 2B", "Flux Dev", "Flux Schnell", - "SD3.5 8B"}; + "SD3.5 8B", + "SD3.5 2B"}; const char* sampling_methods_str[] = { "Euler A", @@ -294,7 +295,7 @@ public: "try specifying SDXL VAE FP16 Fix with the --vae parameter. " "You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors"); } - } else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { + } else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { scale_factor = 1.5305f; } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { scale_factor = 0.3611; @@ -317,7 +318,7 @@ public: } else { clip_backend = backend; bool use_t5xxl = false; - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { use_t5xxl = true; } if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) { @@ -328,7 +329,7 @@ public: LOG_INFO("CLIP: Using CPU backend"); clip_backend = ggml_backend_cpu_init(); } - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { @@ -526,7 +527,7 @@ public: is_using_v_parameterization = true; } - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { LOG_INFO("running in FLOW mode"); denoiser = std::make_shared(); } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { @@ -986,7 +987,7 @@ public: if (use_tiny_autoencoder) { C = 4; } 
diff --git a/otherarch/sdcpp/stable-diffusion.cpp b/otherarch/sdcpp/stable-diffusion.cpp
index 1aa78e36f..abc732834 100644
--- a/otherarch/sdcpp/stable-diffusion.cpp
+++ b/otherarch/sdcpp/stable-diffusion.cpp
@@ -31,7 +31,8 @@ const char* model_version_to_str[] = {
     "SD3 2B",
     "Flux Dev",
     "Flux Schnell",
-    "SD3.5 8B"};
+    "SD3.5 8B",
+    "SD3.5 2B"};
 
 const char* sampling_methods_str[] = {
     "Euler A",
@@ -294,7 +295,7 @@ public:
                          "try specifying SDXL VAE FP16 Fix with the --vae parameter. "
                          "You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors");
             }
-        } else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
+        } else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
            scale_factor = 1.5305f;
        } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
            scale_factor = 0.3611;
@@ -317,7 +318,7 @@ public:
         } else {
             clip_backend   = backend;
             bool use_t5xxl = false;
-            if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
+            if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
                 use_t5xxl = true;
             }
             if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {
@@ -328,7 +329,7 @@ public:
                 LOG_INFO("CLIP: Using CPU backend");
                 clip_backend = ggml_backend_cpu_init();
             }
-            if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
+            if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
                 cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
                 diffusion_model  = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
             } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
@@ -526,7 +527,7 @@ public:
             is_using_v_parameterization = true;
         }
 
-        if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
+        if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
             LOG_INFO("running in FLOW mode");
             denoiser = std::make_shared<DiscreteFlowDenoiser>();
         } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
@@ -986,7 +987,7 @@ public:
         if (use_tiny_autoencoder) {
             C = 4;
         } else {
-            if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
+            if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
                 C = 32;
             } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
                 C = 32;
@@ -1325,7 +1326,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
     // Sample
     std::vector<struct ggml_tensor*> final_latents;  // collect latents to decode
     int C = 4;
-    if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
+    if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
         C = 16;
     } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
         C = 16;
@@ -1438,7 +1439,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
 
     struct ggml_init_params params;
     params.mem_size = static_cast<size_t>(10 * 1024 * 1024);  // 10 MB
-    if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
+    if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
         params.mem_size *= 3;
     }
     if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
@@ -1464,7 +1465,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
     std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
 
     int C = 4;
-    if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
+    if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
         C = 16;
     } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
         C = 16;
@@ -1472,7 +1473,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
     int W = width / 8;
     int H = height / 8;
     ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
-    if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
+    if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
         ggml_set_f32(init_latent, 0.0609f);
     } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
         ggml_set_f32(init_latent, 0.1159f);
@@ -1533,7 +1534,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
 
     struct ggml_init_params params;
     params.mem_size = static_cast<size_t>(10 * 1024 * 1024);  // 10 MB
-    if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
+    if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
         params.mem_size *= 2;
     }
     if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
@@ -1571,7 +1572,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
     } else {
         init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
     }
-    print_ggml_tensor(init_latent, true);
+    // print_ggml_tensor(init_latent, true);
     size_t t1 = ggml_time_ms();
     LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
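Note (not part of the patch): this change grows the same `version == VERSION_SD3_2B || ...` chain in roughly ten places across stable-diffusion.cpp. A small predicate over the existing SDVersion enum would keep the next SD3 variant to a one-line change; hypothetical sketch (the helper name is made up):

    static inline bool sd_version_is_sd3(SDVersion version) {
        return version == VERSION_SD3_2B || version == VERSION_SD3_5_8B ||
               version == VERSION_SD3_5_2B;
    }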
diff --git a/otherarch/sdcpp/vae.hpp b/otherarch/sdcpp/vae.hpp
index 42b694cd5..50ddf7529 100644
--- a/otherarch/sdcpp/vae.hpp
+++ b/otherarch/sdcpp/vae.hpp
@@ -457,7 +457,7 @@ public:
                bool use_video_decoder = false,
                SDVersion version      = VERSION_SD1)
         : decode_only(decode_only), use_video_decoder(use_video_decoder) {
-        if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
+        if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
             dd_config.z_channels = 16;
             use_quant            = false;
         }
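Note (not part of the patch): the VERSION_SD3_5_2B configuration is internally consistent — hidden_size = 64 * depth = 64 * 24 = 1536, matching context_embedder_out_dim = 1536; num_patchs = pos_embed_max_size^2 = 384^2 = 147456; and with d_self = 12, the `i <= d_self` condition gives the first 13 of 24 joint blocks a dual-attention x_block.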