resync and updated sdcpp for flux and sd3 support

This commit is contained in:
Concedo 2024-11-03 22:03:16 +08:00
parent 33721615b5
commit f32a874966
30 changed files with 2434248 additions and 1729 deletions

View file

@ -183,7 +183,7 @@ public:
}
};
struct TinyAutoEncoder : public GGMLModule {
struct TinyAutoEncoder : public GGMLRunner {
TAESD taesd;
bool decode_only = false;
@ -192,7 +192,7 @@ struct TinyAutoEncoder : public GGMLModule {
bool decoder_only = true)
: decode_only(decoder_only),
taesd(decode_only),
GGMLModule(backend, wtype) {
GGMLRunner(backend, wtype) {
taesd.init(params_ctx, wtype);
}
@ -200,16 +200,8 @@ struct TinyAutoEncoder : public GGMLModule {
return "taesd";
}
size_t get_params_mem_size() {
return taesd.get_params_mem_size();
}
size_t get_params_num() {
return taesd.get_params_num();
}
bool load_from_file(const std::string& file_path) {
LOG_INFO("loading taesd from '%s'", file_path.c_str());
LOG_INFO("loading taesd from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false");
alloc_params_buffer();
std::map<std::string, ggml_tensor*> taesd_tensors;
taesd.get_param_tensors(taesd_tensors);
@ -252,7 +244,7 @@ struct TinyAutoEncoder : public GGMLModule {
return build_graph(z, decode_graph);
};
GGMLModule::compute(get_graph, n_threads, false, output, output_ctx);
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
};