mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 09:34:37 +00:00
resync and updated sdcpp for flux and sd3 support
This commit is contained in:
parent
33721615b5
commit
f32a874966
30 changed files with 2434248 additions and 1729 deletions
|
@ -14,7 +14,7 @@
|
|||
*/
|
||||
class ControlNetBlock : public GGMLBlock {
|
||||
protected:
|
||||
SDVersion version = VERSION_1_x;
|
||||
SDVersion version = VERSION_SD1;
|
||||
// network hparams
|
||||
int in_channels = 4;
|
||||
int out_channels = 4;
|
||||
|
@ -26,19 +26,19 @@ protected:
|
|||
int time_embed_dim = 1280; // model_channels*4
|
||||
int num_heads = 8;
|
||||
int num_head_channels = -1; // channels // num_heads
|
||||
int context_dim = 768; // 1024 for VERSION_2_x, 2048 for VERSION_XL
|
||||
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
|
||||
|
||||
public:
|
||||
int model_channels = 320;
|
||||
int adm_in_channels = 2816; // only for VERSION_XL
|
||||
int adm_in_channels = 2816; // only for VERSION_SDXL
|
||||
|
||||
ControlNetBlock(SDVersion version = VERSION_1_x)
|
||||
ControlNetBlock(SDVersion version = VERSION_SD1)
|
||||
: version(version) {
|
||||
if (version == VERSION_2_x) {
|
||||
if (version == VERSION_SD2) {
|
||||
context_dim = 1024;
|
||||
num_head_channels = 64;
|
||||
num_heads = -1;
|
||||
} else if (version == VERSION_XL) {
|
||||
} else if (version == VERSION_SDXL) {
|
||||
context_dim = 2048;
|
||||
attention_resolutions = {4, 2};
|
||||
channel_mult = {1, 2, 4};
|
||||
|
@ -58,7 +58,7 @@ public:
|
|||
// time_embed_1 is nn.SiLU()
|
||||
blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
|
||||
|
||||
if (version == VERSION_XL || version == VERSION_SVD) {
|
||||
if (version == VERSION_SDXL || version == VERSION_SVD) {
|
||||
blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
|
||||
// label_emb_1 is nn.SiLU()
|
||||
blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
|
||||
|
@ -306,8 +306,8 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
struct ControlNet : public GGMLModule {
|
||||
SDVersion version = VERSION_1_x;
|
||||
struct ControlNet : public GGMLRunner {
|
||||
SDVersion version = VERSION_SD1;
|
||||
ControlNetBlock control_net;
|
||||
|
||||
ggml_backend_buffer_t control_buffer = NULL; // keep control output tensors in backend memory
|
||||
|
@ -318,8 +318,8 @@ struct ControlNet : public GGMLModule {
|
|||
|
||||
ControlNet(ggml_backend_t backend,
|
||||
ggml_type wtype,
|
||||
SDVersion version = VERSION_1_x)
|
||||
: GGMLModule(backend, wtype), control_net(version) {
|
||||
SDVersion version = VERSION_SD1)
|
||||
: GGMLRunner(backend, wtype), control_net(version) {
|
||||
control_net.init(params_ctx, wtype);
|
||||
}
|
||||
|
||||
|
@ -369,14 +369,6 @@ struct ControlNet : public GGMLModule {
|
|||
return "control_net";
|
||||
}
|
||||
|
||||
size_t get_params_mem_size() {
|
||||
return control_net.get_params_mem_size();
|
||||
}
|
||||
|
||||
size_t get_params_num() {
|
||||
return control_net.get_params_num();
|
||||
}
|
||||
|
||||
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
|
||||
control_net.get_param_tensors(tensors, prefix);
|
||||
}
|
||||
|
@ -434,7 +426,7 @@ struct ControlNet : public GGMLModule {
|
|||
return build_graph(x, hint, timesteps, context, y);
|
||||
};
|
||||
|
||||
GGMLModule::compute(get_graph, n_threads, false, output, output_ctx);
|
||||
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
|
||||
guided_hint_cached = true;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue