resync and updated sdcpp for flux and sd3 support

This commit is contained in:
Concedo 2024-11-03 22:03:16 +08:00
parent 33721615b5
commit f32a874966
30 changed files with 2434248 additions and 1729 deletions

View file

@ -166,7 +166,7 @@ public:
// ldm.modules.diffusionmodules.openaimodel.UNetModel
class UnetModelBlock : public GGMLBlock {
protected:
SDVersion version = VERSION_1_x;
SDVersion version = VERSION_SD1;
// network hparams
int in_channels = 4;
int out_channels = 4;
@ -177,19 +177,19 @@ protected:
int time_embed_dim = 1280; // model_channels*4
int num_heads = 8;
int num_head_channels = -1; // channels // num_heads
int context_dim = 768; // 1024 for VERSION_2_x, 2048 for VERSION_XL
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
public:
int model_channels = 320;
int adm_in_channels = 2816; // only for VERSION_XL/SVD
int adm_in_channels = 2816; // only for VERSION_SDXL/SVD
UnetModelBlock(SDVersion version = VERSION_1_x)
UnetModelBlock(SDVersion version = VERSION_SD1)
: version(version) {
if (version == VERSION_2_x) {
if (version == VERSION_SD2) {
context_dim = 1024;
num_head_channels = 64;
num_heads = -1;
} else if (version == VERSION_XL) {
} else if (version == VERSION_SDXL) {
context_dim = 2048;
attention_resolutions = {4, 2};
channel_mult = {1, 2, 4};
@ -211,7 +211,7 @@ public:
// time_embed_1 is nn.SiLU()
blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
if (version == VERSION_XL || version == VERSION_SVD) {
if (version == VERSION_SDXL || version == VERSION_SVD) {
blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
// label_emb_1 is nn.SiLU()
blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
@ -528,14 +528,13 @@ public:
}
};
struct UNetModel : public GGMLModule {
SDVersion version = VERSION_1_x;
struct UNetModelRunner : public GGMLRunner {
UnetModelBlock unet;
UNetModel(ggml_backend_t backend,
ggml_type wtype,
SDVersion version = VERSION_1_x)
: GGMLModule(backend, wtype), unet(version) {
UNetModelRunner(ggml_backend_t backend,
ggml_type wtype,
SDVersion version = VERSION_SD1)
: GGMLRunner(backend, wtype), unet(version) {
unet.init(params_ctx, wtype);
}
@ -543,14 +542,6 @@ struct UNetModel : public GGMLModule {
return "unet";
}
size_t get_params_mem_size() {
return unet.get_params_mem_size();
}
size_t get_params_num() {
return unet.get_params_num();
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
unet.get_param_tensors(tensors, prefix);
}
@ -613,7 +604,7 @@ struct UNetModel : public GGMLModule {
return build_graph(x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength);
};
GGMLModule::compute(get_graph, n_threads, false, output, output_ctx);
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
void test() {
@ -655,7 +646,7 @@ struct UNetModel : public GGMLModule {
print_ggml_tensor(out);
LOG_DEBUG("unet test done in %dms", t1 - t0);
}
};
}
};
#endif // __UNET_HPP__