sd 3.5 medium

This commit is contained in:
Concedo 2024-11-03 23:27:06 +08:00
parent f32a874966
commit 5233e8ed1d
5 changed files with 204 additions and 49 deletions

View file

@ -252,6 +252,7 @@ struct DismantledBlock : public GGMLBlock {
public: public:
int64_t num_heads; int64_t num_heads;
bool pre_only; bool pre_only;
bool self_attn;
public: public:
DismantledBlock(int64_t hidden_size, DismantledBlock(int64_t hidden_size,
@ -259,14 +260,19 @@ public:
float mlp_ratio = 4.0, float mlp_ratio = 4.0,
std::string qk_norm = "", std::string qk_norm = "",
bool qkv_bias = false, bool qkv_bias = false,
bool pre_only = false) bool pre_only = false,
: num_heads(num_heads), pre_only(pre_only) { bool self_attn = false)
: num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
// rmsnorm is always Flase // rmsnorm is always Flase
// scale_mod_only is always Flase // scale_mod_only is always Flase
// swiglu is always Flase // swiglu is always Flase
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false)); blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only)); blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));
if (self_attn) {
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false));
}
if (!pre_only) { if (!pre_only) {
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false)); blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
int64_t mlp_hidden_dim = (int64_t)(hidden_size * mlp_ratio); int64_t mlp_hidden_dim = (int64_t)(hidden_size * mlp_ratio);
@ -277,9 +283,52 @@ public:
if (pre_only) { if (pre_only) {
n_mods = 2; n_mods = 2;
} }
if (self_attn) {
n_mods = 9;
}
blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, n_mods * hidden_size)); blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, n_mods * hidden_size));
} }
std::tuple<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention_x(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* c) {
GGML_ASSERT(self_attn);
// x: [N, n_token, hidden_size]
// c: [N, hidden_size]
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
int64_t n_mods = 9;
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, n_mods * hidden_size]
m = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
int64_t offset = m->nb[1] * m->ne[1];
auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
auto gate_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
auto gate_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
auto shift_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6); // [N, hidden_size]
auto scale_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7); // [N, hidden_size]
auto gate_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8); // [N, hidden_size]
auto x_norm = norm1->forward(ctx, x);
auto attn_in = modulate(ctx, x_norm, shift_msa, scale_msa);
auto qkv = attn->pre_attention(ctx, attn_in);
auto attn2_in = modulate(ctx, x_norm, shift_msa2, scale_msa2);
auto qkv2 = attn2->pre_attention(ctx, attn2_in);
return {qkv, qkv2, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2}};
}
std::pair<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention(struct ggml_context* ctx, std::pair<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention(struct ggml_context* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* c) { struct ggml_tensor* c) {
@ -319,6 +368,44 @@ public:
} }
} }
struct ggml_tensor* post_attention_x(struct ggml_context* ctx,
struct ggml_tensor* attn_out,
struct ggml_tensor* attn2_out,
struct ggml_tensor* x,
struct ggml_tensor* gate_msa,
struct ggml_tensor* shift_mlp,
struct ggml_tensor* scale_mlp,
struct ggml_tensor* gate_mlp,
struct ggml_tensor* gate_msa2) {
// attn_out: [N, n_token, hidden_size]
// x: [N, n_token, hidden_size]
// gate_msa: [N, hidden_size]
// shift_mlp: [N, hidden_size]
// scale_mlp: [N, hidden_size]
// gate_mlp: [N, hidden_size]
// return: [N, n_token, hidden_size]
GGML_ASSERT(!pre_only);
auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
auto mlp = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);
gate_msa = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size]
gate_mlp = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size]
gate_msa2 = ggml_reshape_3d(ctx, gate_msa2, gate_msa2->ne[0], 1, gate_msa2->ne[1]); // [N, 1, hidden_size]
attn_out = attn->post_attention(ctx, attn_out);
attn2_out = attn2->post_attention(ctx, attn2_out);
x = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa));
x = ggml_add(ctx, x, ggml_mul(ctx, attn2_out, gate_msa2));
auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
x = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp));
return x;
}
struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* post_attention(struct ggml_context* ctx,
struct ggml_tensor* attn_out, struct ggml_tensor* attn_out,
struct ggml_tensor* x, struct ggml_tensor* x,
@ -357,29 +444,52 @@ public:
// return: [N, n_token, hidden_size] // return: [N, n_token, hidden_size]
auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]); auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
if (self_attn) {
auto qkv_intermediates = pre_attention_x(ctx, x, c);
// auto qkv = qkv_intermediates.first;
// auto intermediates = qkv_intermediates.second;
// no longer a pair, but a tuple
auto qkv = std::get<0>(qkv_intermediates);
auto qkv2 = std::get<1>(qkv_intermediates);
auto intermediates = std::get<2>(qkv_intermediates);
auto qkv_intermediates = pre_attention(ctx, x, c); auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
auto qkv = qkv_intermediates.first; auto attn2_out = ggml_nn_attention_ext(ctx, qkv2[0], qkv2[1], qkv2[2], num_heads); // [N, n_token, dim]
auto intermediates = qkv_intermediates.second; x = post_attention_x(ctx,
attn_out,
attn2_out,
intermediates[0],
intermediates[1],
intermediates[2],
intermediates[3],
intermediates[4],
intermediates[5]);
return x; // [N, n_token, dim]
} else {
auto qkv_intermediates = pre_attention(ctx, x, c);
auto qkv = qkv_intermediates.first;
auto intermediates = qkv_intermediates.second;
auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim] auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
x = post_attention(ctx, x = post_attention(ctx,
attn_out, attn_out,
intermediates[0], intermediates[0],
intermediates[1], intermediates[1],
intermediates[2], intermediates[2],
intermediates[3], intermediates[3],
intermediates[4]); intermediates[4]);
return x; // [N, n_token, dim] return x; // [N, n_token, dim]
}
} }
}; };
__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*> block_mixing(struct ggml_context* ctx, __STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
struct ggml_tensor* context, block_mixing(struct ggml_context* ctx,
struct ggml_tensor* x, struct ggml_tensor* context,
struct ggml_tensor* c, struct ggml_tensor* x,
std::shared_ptr<DismantledBlock> context_block, struct ggml_tensor* c,
std::shared_ptr<DismantledBlock> x_block) { std::shared_ptr<DismantledBlock> context_block,
std::shared_ptr<DismantledBlock> x_block) {
// context: [N, n_context, hidden_size] // context: [N, n_context, hidden_size]
// x: [N, n_token, hidden_size] // x: [N, n_token, hidden_size]
// c: [N, hidden_size] // c: [N, hidden_size]
@ -387,10 +497,18 @@ __STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*> block_mixi
auto context_qkv = context_qkv_intermediates.first; auto context_qkv = context_qkv_intermediates.first;
auto context_intermediates = context_qkv_intermediates.second; auto context_intermediates = context_qkv_intermediates.second;
auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c); std::vector<ggml_tensor*> x_qkv, x_qkv2, x_intermediates;
auto x_qkv = x_qkv_intermediates.first;
auto x_intermediates = x_qkv_intermediates.second;
if (x_block->self_attn) {
auto x_qkv_intermediates = x_block->pre_attention_x(ctx, x, c);
x_qkv = std::get<0>(x_qkv_intermediates);
x_qkv2 = std::get<1>(x_qkv_intermediates);
x_intermediates = std::get<2>(x_qkv_intermediates);
} else {
auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c);
x_qkv = x_qkv_intermediates.first;
x_intermediates = x_qkv_intermediates.second;
}
std::vector<struct ggml_tensor*> qkv; std::vector<struct ggml_tensor*> qkv;
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1)); qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
@ -429,13 +547,27 @@ __STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*> block_mixi
context = NULL; context = NULL;
} }
x = x_block->post_attention(ctx, if (x_block->self_attn) {
x_attn, auto attn2 = ggml_nn_attention_ext(ctx, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads); // [N, n_token, hidden_size]
x_intermediates[0],
x_intermediates[1], x = x_block->post_attention_x(ctx,
x_intermediates[2], x_attn,
x_intermediates[3], attn2,
x_intermediates[4]); x_intermediates[0],
x_intermediates[1],
x_intermediates[2],
x_intermediates[3],
x_intermediates[4],
x_intermediates[5]);
} else {
x = x_block->post_attention(ctx,
x_attn,
x_intermediates[0],
x_intermediates[1],
x_intermediates[2],
x_intermediates[3],
x_intermediates[4]);
}
return {context, x}; return {context, x};
} }
@ -447,9 +579,10 @@ public:
float mlp_ratio = 4.0, float mlp_ratio = 4.0,
std::string qk_norm = "", std::string qk_norm = "",
bool qkv_bias = false, bool qkv_bias = false,
bool pre_only = false) { bool pre_only = false,
bool self_attn_x = false) {
blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only)); blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only));
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false)); blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x));
} }
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx, std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
@ -507,6 +640,7 @@ protected:
int64_t input_size = -1; int64_t input_size = -1;
int64_t patch_size = 2; int64_t patch_size = 2;
int64_t in_channels = 16; int64_t in_channels = 16;
int64_t d_self = -1; // >=0 for MMdiT-X
int64_t depth = 24; int64_t depth = 24;
float mlp_ratio = 4.0f; float mlp_ratio = 4.0f;
int64_t adm_in_channels = 2048; int64_t adm_in_channels = 2048;
@ -561,6 +695,20 @@ public:
context_size = 4096; context_size = 4096;
context_embedder_out_dim = 2432; context_embedder_out_dim = 2432;
qk_norm = "rms"; qk_norm = "rms";
} else if (version == VERSION_SD3_5_2B) {
input_size = -1;
patch_size = 2;
in_channels = 16;
depth = 24;
d_self = 12;
mlp_ratio = 4.0f;
adm_in_channels = 2048;
out_channels = 16;
pos_embed_max_size = 384;
num_patchs = 147456;
context_size = 4096;
context_embedder_out_dim = 1536;
qk_norm = "rms";
} }
int64_t default_out_channels = in_channels; int64_t default_out_channels = in_channels;
hidden_size = 64 * depth; hidden_size = 64 * depth;
@ -581,15 +729,17 @@ public:
mlp_ratio, mlp_ratio,
qk_norm, qk_norm,
true, true,
i == depth - 1)); i == depth - 1,
i <= d_self));
} }
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels)); blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels));
} }
struct ggml_tensor* cropped_pos_embed(struct ggml_context* ctx, struct ggml_tensor*
int64_t h, cropped_pos_embed(struct ggml_context* ctx,
int64_t w) { int64_t h,
int64_t w) {
auto pos_embed = params["pos_embed"]; auto pos_embed = params["pos_embed"];
h = (h + 1) / patch_size; h = (h + 1) / patch_size;

View file

@ -1376,6 +1376,9 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
is_flux = true; is_flux = true;
} }
if (tensor_storage.name.find("joint_blocks.0.x_block.attn2.ln_q.weight") != std::string::npos) {
return VERSION_SD3_5_2B;
}
if (tensor_storage.name.find("joint_blocks.37.x_block.attn.ln_q.weight") != std::string::npos) { if (tensor_storage.name.find("joint_blocks.37.x_block.attn.ln_q.weight") != std::string::npos) {
return VERSION_SD3_5_8B; return VERSION_SD3_5_8B;
} }

View file

@ -26,6 +26,7 @@ enum SDVersion {
VERSION_FLUX_DEV, VERSION_FLUX_DEV,
VERSION_FLUX_SCHNELL, VERSION_FLUX_SCHNELL,
VERSION_SD3_5_8B, VERSION_SD3_5_8B,
VERSION_SD3_5_2B,
VERSION_COUNT, VERSION_COUNT,
}; };

View file

@ -31,7 +31,8 @@ const char* model_version_to_str[] = {
"SD3 2B", "SD3 2B",
"Flux Dev", "Flux Dev",
"Flux Schnell", "Flux Schnell",
"SD3.5 8B"}; "SD3.5 8B",
"SD3.5 2B"};
const char* sampling_methods_str[] = { const char* sampling_methods_str[] = {
"Euler A", "Euler A",
@ -294,7 +295,7 @@ public:
"try specifying SDXL VAE FP16 Fix with the --vae parameter. " "try specifying SDXL VAE FP16 Fix with the --vae parameter. "
"You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors"); "You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors");
} }
} else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { } else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
scale_factor = 1.5305f; scale_factor = 1.5305f;
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
scale_factor = 0.3611; scale_factor = 0.3611;
@ -317,7 +318,7 @@ public:
} else { } else {
clip_backend = backend; clip_backend = backend;
bool use_t5xxl = false; bool use_t5xxl = false;
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
use_t5xxl = true; use_t5xxl = true;
} }
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) { if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {
@ -328,7 +329,7 @@ public:
LOG_INFO("CLIP: Using CPU backend"); LOG_INFO("CLIP: Using CPU backend");
clip_backend = ggml_backend_cpu_init(); clip_backend = ggml_backend_cpu_init();
} }
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype); cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version); diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
@ -526,7 +527,7 @@ public:
is_using_v_parameterization = true; is_using_v_parameterization = true;
} }
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
LOG_INFO("running in FLOW mode"); LOG_INFO("running in FLOW mode");
denoiser = std::make_shared<DiscreteFlowDenoiser>(); denoiser = std::make_shared<DiscreteFlowDenoiser>();
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
@ -986,7 +987,7 @@ public:
if (use_tiny_autoencoder) { if (use_tiny_autoencoder) {
C = 4; C = 4;
} else { } else {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
C = 32; C = 32;
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
C = 32; C = 32;
@ -1325,7 +1326,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
// Sample // Sample
std::vector<struct ggml_tensor*> final_latents; // collect latents to decode std::vector<struct ggml_tensor*> final_latents; // collect latents to decode
int C = 4; int C = 4;
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
C = 16; C = 16;
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
C = 16; C = 16;
@ -1438,7 +1439,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
struct ggml_init_params params; struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
params.mem_size *= 3; params.mem_size *= 3;
} }
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
@ -1464,7 +1465,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps); std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
int C = 4; int C = 4;
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
C = 16; C = 16;
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
C = 16; C = 16;
@ -1472,7 +1473,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
int W = width / 8; int W = width / 8;
int H = height / 8; int H = height / 8;
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
ggml_set_f32(init_latent, 0.0609f); ggml_set_f32(init_latent, 0.0609f);
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
ggml_set_f32(init_latent, 0.1159f); ggml_set_f32(init_latent, 0.1159f);
@ -1533,7 +1534,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
struct ggml_init_params params; struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
params.mem_size *= 2; params.mem_size *= 2;
} }
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
@ -1571,7 +1572,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
} else { } else {
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
} }
print_ggml_tensor(init_latent, true); // print_ggml_tensor(init_latent, true);
size_t t1 = ggml_time_ms(); size_t t1 = ggml_time_ms();
LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);

View file

@ -457,7 +457,7 @@ public:
bool use_video_decoder = false, bool use_video_decoder = false,
SDVersion version = VERSION_SD1) SDVersion version = VERSION_SD1)
: decode_only(decode_only), use_video_decoder(use_video_decoder) { : decode_only(decode_only), use_video_decoder(use_video_decoder) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
dd_config.z_channels = 16; dd_config.z_channels = 16;
use_quant = false; use_quant = false;
} }