Resync and update sdcpp for Flux and SD3 support

This commit is contained in:
Concedo 2024-11-03 22:03:16 +08:00
parent 33721615b5
commit f32a874966
30 changed files with 2434248 additions and 1729 deletions

View file

@ -6,7 +6,7 @@
/*================================================== AutoEncoderKL ===================================================*/
#define VAE_GRAPH_SIZE 10240
#define VAE_GRAPH_SIZE 20480
class ResnetBlock : public UnaryBlock {
protected:
@ -439,6 +439,7 @@ class AutoencodingEngine : public GGMLBlock {
protected:
bool decode_only = true;
bool use_video_decoder = false;
bool use_quant = true;
int embed_dim = 4;
struct {
int z_channels = 4;
@ -453,15 +454,23 @@ protected:
public:
AutoencodingEngine(bool decode_only = true,
bool use_video_decoder = false)
bool use_video_decoder = false,
SDVersion version = VERSION_SD1)
: decode_only(decode_only), use_video_decoder(use_video_decoder) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
dd_config.z_channels = 16;
use_quant = false;
}
if (use_video_decoder) {
use_quant = false;
}
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(dd_config.ch,
dd_config.out_ch,
dd_config.ch_mult,
dd_config.num_res_blocks,
dd_config.z_channels,
use_video_decoder));
if (!use_video_decoder) {
if (use_quant) {
blocks["post_quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(dd_config.z_channels,
embed_dim,
{1, 1}));
@ -473,7 +482,7 @@ public:
dd_config.in_channels,
dd_config.z_channels,
dd_config.double_z));
if (!use_video_decoder) {
if (use_quant) {
int factor = dd_config.double_z ? 2 : 1;
blocks["quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(embed_dim * factor,
@ -485,7 +494,7 @@ public:
struct ggml_tensor* decode(struct ggml_context* ctx, struct ggml_tensor* z) {
// z: [N, z_channels, h, w]
if (!use_video_decoder) {
if (use_quant) {
auto post_quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]);
z = post_quant_conv->forward(ctx, z); // [N, z_channels, h, w]
}
@ -502,7 +511,7 @@ public:
auto encoder = std::dynamic_pointer_cast<Encoder>(blocks["encoder"]);
auto h = encoder->forward(ctx, x); // [N, 2*z_channels, h/8, w/8]
if (!use_video_decoder) {
if (use_quant) {
auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]);
h = quant_conv->forward(ctx, h); // [N, 2*embed_dim, h/8, w/8]
}
@ -510,15 +519,16 @@ public:
}
};
struct AutoEncoderKL : public GGMLModule {
struct AutoEncoderKL : public GGMLRunner {
bool decode_only = true;
AutoencodingEngine ae;
AutoEncoderKL(ggml_backend_t backend,
ggml_type wtype,
bool decode_only = false,
bool use_video_decoder = false)
: decode_only(decode_only), ae(decode_only, use_video_decoder), GGMLModule(backend, wtype) {
bool use_video_decoder = false,
SDVersion version = VERSION_SD1)
: decode_only(decode_only), ae(decode_only, use_video_decoder, version), GGMLRunner(backend, wtype) {
ae.init(params_ctx, wtype);
}
@ -526,14 +536,6 @@ struct AutoEncoderKL : public GGMLModule {
return "vae";
}
// Returns the parameter memory size reported by the wrapped AutoencodingEngine
// member `ae`; this method adds no computation of its own.
// NOTE(review): the unit of the returned size is not visible here — confirm
// against AutoencodingEngine / GGMLBlock.
size_t get_params_mem_size() {
return ae.get_params_mem_size();
}
// Returns the parameter count reported by the wrapped AutoencodingEngine
// member `ae`; pure delegation with no extra logic.
size_t get_params_num() {
return ae.get_params_num();
}
// Forwards to ae.get_param_tensors, passing through the caller's `tensors`
// map and `prefix` unchanged.
//   tensors: map from name to ggml_tensor*, passed by non-const reference.
//   prefix:  string handed to the engine as-is.
// NOTE(review): presumably the engine inserts its tensors into the map under
// names derived from `prefix` — confirm in the block implementation.
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
ae.get_param_tensors(tensors, prefix);
}
@ -560,7 +562,7 @@ struct AutoEncoderKL : public GGMLModule {
};
// ggml_set_f32(z, 0.5f);
// print_ggml_tensor(z);
GGMLModule::compute(get_graph, n_threads, true, output, output_ctx);
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
}
void test() {