resync and updated sdcpp for flux and sd3 support

This commit is contained in:
Concedo 2024-11-03 22:03:16 +08:00
parent 33721615b5
commit f32a874966
30 changed files with 2434248 additions and 1729 deletions

View file

@ -14,7 +14,7 @@
*/
class ControlNetBlock : public GGMLBlock {
protected:
SDVersion version = VERSION_1_x;
SDVersion version = VERSION_SD1;
// network hparams
int in_channels = 4;
int out_channels = 4;
@ -26,19 +26,19 @@ protected:
int time_embed_dim = 1280; // model_channels*4
int num_heads = 8;
int num_head_channels = -1; // channels // num_heads
int context_dim = 768; // 1024 for VERSION_2_x, 2048 for VERSION_XL
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
public:
int model_channels = 320;
int adm_in_channels = 2816; // only for VERSION_XL
int adm_in_channels = 2816; // only for VERSION_SDXL
ControlNetBlock(SDVersion version = VERSION_1_x)
ControlNetBlock(SDVersion version = VERSION_SD1)
: version(version) {
if (version == VERSION_2_x) {
if (version == VERSION_SD2) {
context_dim = 1024;
num_head_channels = 64;
num_heads = -1;
} else if (version == VERSION_XL) {
} else if (version == VERSION_SDXL) {
context_dim = 2048;
attention_resolutions = {4, 2};
channel_mult = {1, 2, 4};
@ -58,7 +58,7 @@ public:
// time_embed_1 is nn.SiLU()
blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
if (version == VERSION_XL || version == VERSION_SVD) {
if (version == VERSION_SDXL || version == VERSION_SVD) {
blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
// label_emb_1 is nn.SiLU()
blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
@ -306,8 +306,8 @@ public:
}
};
struct ControlNet : public GGMLModule {
SDVersion version = VERSION_1_x;
struct ControlNet : public GGMLRunner {
SDVersion version = VERSION_SD1;
ControlNetBlock control_net;
ggml_backend_buffer_t control_buffer = NULL; // keep control output tensors in backend memory
@ -318,8 +318,8 @@ struct ControlNet : public GGMLModule {
ControlNet(ggml_backend_t backend,
ggml_type wtype,
SDVersion version = VERSION_1_x)
: GGMLModule(backend, wtype), control_net(version) {
SDVersion version = VERSION_SD1)
: GGMLRunner(backend, wtype), control_net(version) {
control_net.init(params_ctx, wtype);
}
@ -369,14 +369,6 @@ struct ControlNet : public GGMLModule {
return "control_net";
}
size_t get_params_mem_size() {
return control_net.get_params_mem_size();
}
size_t get_params_num() {
return control_net.get_params_num();
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
control_net.get_param_tensors(tensors, prefix);
}
@ -434,7 +426,7 @@ struct ControlNet : public GGMLModule {
return build_graph(x, hint, timesteps, context, y);
};
GGMLModule::compute(get_graph, n_threads, false, output, output_ctx);
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
guided_hint_cached = true;
}