mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 01:24:36 +00:00
resync and updated sdcpp for flux and sd3 support
This commit is contained in:
parent
33721615b5
commit
f32a874966
30 changed files with 2434248 additions and 1729 deletions
|
@ -6,7 +6,7 @@
|
|||
|
||||
/*================================================== AutoEncoderKL ===================================================*/
|
||||
|
||||
#define VAE_GRAPH_SIZE 10240
|
||||
#define VAE_GRAPH_SIZE 20480
|
||||
|
||||
class ResnetBlock : public UnaryBlock {
|
||||
protected:
|
||||
|
@ -439,6 +439,7 @@ class AutoencodingEngine : public GGMLBlock {
|
|||
protected:
|
||||
bool decode_only = true;
|
||||
bool use_video_decoder = false;
|
||||
bool use_quant = true;
|
||||
int embed_dim = 4;
|
||||
struct {
|
||||
int z_channels = 4;
|
||||
|
@ -453,15 +454,23 @@ protected:
|
|||
|
||||
public:
|
||||
AutoencodingEngine(bool decode_only = true,
|
||||
bool use_video_decoder = false)
|
||||
bool use_video_decoder = false,
|
||||
SDVersion version = VERSION_SD1)
|
||||
: decode_only(decode_only), use_video_decoder(use_video_decoder) {
|
||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
|
||||
dd_config.z_channels = 16;
|
||||
use_quant = false;
|
||||
}
|
||||
if (use_video_decoder) {
|
||||
use_quant = false;
|
||||
}
|
||||
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(dd_config.ch,
|
||||
dd_config.out_ch,
|
||||
dd_config.ch_mult,
|
||||
dd_config.num_res_blocks,
|
||||
dd_config.z_channels,
|
||||
use_video_decoder));
|
||||
if (!use_video_decoder) {
|
||||
if (use_quant) {
|
||||
blocks["post_quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(dd_config.z_channels,
|
||||
embed_dim,
|
||||
{1, 1}));
|
||||
|
@ -473,7 +482,7 @@ public:
|
|||
dd_config.in_channels,
|
||||
dd_config.z_channels,
|
||||
dd_config.double_z));
|
||||
if (!use_video_decoder) {
|
||||
if (use_quant) {
|
||||
int factor = dd_config.double_z ? 2 : 1;
|
||||
|
||||
blocks["quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(embed_dim * factor,
|
||||
|
@ -485,7 +494,7 @@ public:
|
|||
|
||||
struct ggml_tensor* decode(struct ggml_context* ctx, struct ggml_tensor* z) {
|
||||
// z: [N, z_channels, h, w]
|
||||
if (!use_video_decoder) {
|
||||
if (use_quant) {
|
||||
auto post_quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]);
|
||||
z = post_quant_conv->forward(ctx, z); // [N, z_channels, h, w]
|
||||
}
|
||||
|
@ -502,7 +511,7 @@ public:
|
|||
auto encoder = std::dynamic_pointer_cast<Encoder>(blocks["encoder"]);
|
||||
|
||||
auto h = encoder->forward(ctx, x); // [N, 2*z_channels, h/8, w/8]
|
||||
if (!use_video_decoder) {
|
||||
if (use_quant) {
|
||||
auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]);
|
||||
h = quant_conv->forward(ctx, h); // [N, 2*embed_dim, h/8, w/8]
|
||||
}
|
||||
|
@ -510,15 +519,16 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
struct AutoEncoderKL : public GGMLModule {
|
||||
struct AutoEncoderKL : public GGMLRunner {
|
||||
bool decode_only = true;
|
||||
AutoencodingEngine ae;
|
||||
|
||||
AutoEncoderKL(ggml_backend_t backend,
|
||||
ggml_type wtype,
|
||||
bool decode_only = false,
|
||||
bool use_video_decoder = false)
|
||||
: decode_only(decode_only), ae(decode_only, use_video_decoder), GGMLModule(backend, wtype) {
|
||||
bool use_video_decoder = false,
|
||||
SDVersion version = VERSION_SD1)
|
||||
: decode_only(decode_only), ae(decode_only, use_video_decoder, version), GGMLRunner(backend, wtype) {
|
||||
ae.init(params_ctx, wtype);
|
||||
}
|
||||
|
||||
|
@ -526,14 +536,6 @@ struct AutoEncoderKL : public GGMLModule {
|
|||
return "vae";
|
||||
}
|
||||
|
||||
size_t get_params_mem_size() {
|
||||
return ae.get_params_mem_size();
|
||||
}
|
||||
|
||||
size_t get_params_num() {
|
||||
return ae.get_params_num();
|
||||
}
|
||||
|
||||
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
|
||||
ae.get_param_tensors(tensors, prefix);
|
||||
}
|
||||
|
@ -560,7 +562,7 @@ struct AutoEncoderKL : public GGMLModule {
|
|||
};
|
||||
// ggml_set_f32(z, 0.5f);
|
||||
// print_ggml_tensor(z);
|
||||
GGMLModule::compute(get_graph, n_threads, true, output, output_ctx);
|
||||
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
|
||||
}
|
||||
|
||||
void test() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue