mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 09:34:37 +00:00
updated sdcpp prepare for inpaint
fixed img2img (+1 squashed commits) Squashed commits: [42c48f14] try update sdcpp, feels kind of buggy
This commit is contained in:
parent
ebf924c5d1
commit
fea3b2bd4a
18 changed files with 1850 additions and 271 deletions
|
@ -166,6 +166,7 @@ public:
|
|||
// ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
class UnetModelBlock : public GGMLBlock {
|
||||
protected:
|
||||
static std::map<std::string, enum ggml_type> empty_tensor_types;
|
||||
SDVersion version = VERSION_SD1;
|
||||
// network hparams
|
||||
int in_channels = 4;
|
||||
|
@ -183,13 +184,13 @@ public:
|
|||
int model_channels = 320;
|
||||
int adm_in_channels = 2816; // only for VERSION_SDXL/SVD
|
||||
|
||||
UnetModelBlock(SDVersion version = VERSION_SD1, bool flash_attn = false)
|
||||
UnetModelBlock(SDVersion version = VERSION_SD1, std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types, bool flash_attn = false)
|
||||
: version(version) {
|
||||
if (version == VERSION_SD2) {
|
||||
if (sd_version_is_sd2(version)) {
|
||||
context_dim = 1024;
|
||||
num_head_channels = 64;
|
||||
num_heads = -1;
|
||||
} else if (version == VERSION_SDXL) {
|
||||
} else if (sd_version_is_sdxl(version)) {
|
||||
context_dim = 2048;
|
||||
attention_resolutions = {4, 2};
|
||||
channel_mult = {1, 2, 4};
|
||||
|
@ -204,6 +205,10 @@ public:
|
|||
num_head_channels = 64;
|
||||
num_heads = -1;
|
||||
}
|
||||
if (sd_version_is_inpaint(version)) {
|
||||
in_channels = 9;
|
||||
}
|
||||
|
||||
// dims is always 2
|
||||
// use_temporal_attention is always True for SVD
|
||||
|
||||
|
@ -211,7 +216,7 @@ public:
|
|||
// time_embed_1 is nn.SiLU()
|
||||
blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
|
||||
|
||||
if (version == VERSION_SDXL || version == VERSION_SVD) {
|
||||
if (sd_version_is_sdxl(version) || version == VERSION_SVD) {
|
||||
blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
|
||||
// label_emb_1 is nn.SiLU()
|
||||
blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
|
||||
|
@ -536,7 +541,7 @@ struct UNetModelRunner : public GGMLRunner {
|
|||
const std::string prefix,
|
||||
SDVersion version = VERSION_SD1,
|
||||
bool flash_attn = false)
|
||||
: GGMLRunner(backend), unet(version, flash_attn) {
|
||||
: GGMLRunner(backend), unet(version, tensor_types, flash_attn) {
|
||||
unet.init(params_ctx, tensor_types, prefix);
|
||||
}
|
||||
|
||||
|
@ -566,6 +571,7 @@ struct UNetModelRunner : public GGMLRunner {
|
|||
context = to_backend(context);
|
||||
y = to_backend(y);
|
||||
timesteps = to_backend(timesteps);
|
||||
c_concat = to_backend(c_concat);
|
||||
|
||||
for (int i = 0; i < controls.size(); i++) {
|
||||
controls[i] = to_backend(controls[i]);
|
||||
|
@ -651,4 +657,4 @@ struct UNetModelRunner : public GGMLRunner {
|
|||
}
|
||||
};
|
||||
|
||||
#endif // __UNET_HPP__
|
||||
#endif // __UNET_HPP__
|
Loading…
Add table
Add a link
Reference in a new issue