updated sdcpp prepare for inpaint

fixed img2img (+1 squashed commits)

Squashed commits:

[42c48f14] try update sdcpp, feels kind of buggy
This commit is contained in:
Concedo 2025-04-08 23:47:12 +08:00
parent ebf924c5d1
commit fea3b2bd4a
18 changed files with 1850 additions and 271 deletions

View file

@ -166,6 +166,7 @@ public:
// ldm.modules.diffusionmodules.openaimodel.UNetModel
class UnetModelBlock : public GGMLBlock {
protected:
static std::map<std::string, enum ggml_type> empty_tensor_types;
SDVersion version = VERSION_SD1;
// network hparams
int in_channels = 4;
@ -183,13 +184,13 @@ public:
int model_channels = 320;
int adm_in_channels = 2816; // only for VERSION_SDXL/SVD
UnetModelBlock(SDVersion version = VERSION_SD1, bool flash_attn = false)
UnetModelBlock(SDVersion version = VERSION_SD1, std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types, bool flash_attn = false)
: version(version) {
if (version == VERSION_SD2) {
if (sd_version_is_sd2(version)) {
context_dim = 1024;
num_head_channels = 64;
num_heads = -1;
} else if (version == VERSION_SDXL) {
} else if (sd_version_is_sdxl(version)) {
context_dim = 2048;
attention_resolutions = {4, 2};
channel_mult = {1, 2, 4};
@ -204,6 +205,10 @@ public:
num_head_channels = 64;
num_heads = -1;
}
if (sd_version_is_inpaint(version)) {
in_channels = 9;
}
// dims is always 2
// use_temporal_attention is always True for SVD
@ -211,7 +216,7 @@ public:
// time_embed_1 is nn.SiLU()
blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
if (version == VERSION_SDXL || version == VERSION_SVD) {
if (sd_version_is_sdxl(version) || version == VERSION_SVD) {
blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
// label_emb_1 is nn.SiLU()
blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
@ -536,7 +541,7 @@ struct UNetModelRunner : public GGMLRunner {
const std::string prefix,
SDVersion version = VERSION_SD1,
bool flash_attn = false)
: GGMLRunner(backend), unet(version, flash_attn) {
: GGMLRunner(backend), unet(version, tensor_types, flash_attn) {
unet.init(params_ctx, tensor_types, prefix);
}
@ -566,6 +571,7 @@ struct UNetModelRunner : public GGMLRunner {
context = to_backend(context);
y = to_backend(y);
timesteps = to_backend(timesteps);
c_concat = to_backend(c_concat);
for (int i = 0; i < controls.size(); i++) {
controls[i] = to_backend(controls[i]);
@ -651,4 +657,4 @@ struct UNetModelRunner : public GGMLRunner {
}
};
#endif // __UNET_HPP__
#endif // __UNET_HPP__