mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-08 18:30:50 +00:00
sd: sync to master-480-b87fe13 (#1932)
This commit is contained in:
parent
28091dec43
commit
715a76ce5f
5 changed files with 36 additions and 12 deletions
|
|
@ -748,7 +748,7 @@ namespace Flux {
|
|||
int nerf_depth = 4;
|
||||
int nerf_max_freqs = 8;
|
||||
bool use_x0 = false;
|
||||
bool use_patch_size_32 = false;
|
||||
bool fake_patch_size_x2 = false;
|
||||
};
|
||||
|
||||
struct FluxParams {
|
||||
|
|
@ -786,8 +786,11 @@ namespace Flux {
|
|||
Flux(FluxParams params)
|
||||
: params(params) {
|
||||
if (params.version == VERSION_CHROMA_RADIANCE) {
|
||||
std::pair<int, int> kernel_size = {16, 16};
|
||||
std::pair<int, int> stride = kernel_size;
|
||||
std::pair<int, int> kernel_size = {params.patch_size, params.patch_size};
|
||||
if (params.chroma_radiance_params.fake_patch_size_x2) {
|
||||
kernel_size = {params.patch_size / 2, params.patch_size / 2};
|
||||
}
|
||||
std::pair<int, int> stride = kernel_size;
|
||||
|
||||
blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
|
||||
params.hidden_size,
|
||||
|
|
@ -1082,7 +1085,7 @@ namespace Flux {
|
|||
auto img = pad_to_patch_size(ctx, x);
|
||||
auto orig_img = img;
|
||||
|
||||
if (params.chroma_radiance_params.use_patch_size_32) {
|
||||
if (params.chroma_radiance_params.fake_patch_size_x2) {
|
||||
// It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable
|
||||
// Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch?
|
||||
// img = F.interpolate(img, size=(H//2, W//2), mode="nearest")
|
||||
|
|
@ -1303,7 +1306,8 @@ namespace Flux {
|
|||
flux_params.ref_index_scale = 10.f;
|
||||
flux_params.use_mlp_silu_act = true;
|
||||
}
|
||||
int64_t head_dim = 0;
|
||||
int64_t head_dim = 0;
|
||||
int64_t actual_radiance_patch_size = -1;
|
||||
for (auto pair : tensor_storage_map) {
|
||||
std::string tensor_name = pair.first;
|
||||
if (!starts_with(tensor_name, prefix))
|
||||
|
|
@ -1316,9 +1320,12 @@ namespace Flux {
|
|||
flux_params.chroma_radiance_params.use_x0 = true;
|
||||
}
|
||||
if (tensor_name.find("__32x32__") != std::string::npos) {
|
||||
LOG_DEBUG("using patch size 32 prediction");
|
||||
flux_params.chroma_radiance_params.use_patch_size_32 = true;
|
||||
flux_params.patch_size = 32;
|
||||
LOG_DEBUG("using patch size 32");
|
||||
flux_params.patch_size = 32;
|
||||
}
|
||||
if (tensor_name.find("img_in_patch.weight") != std::string::npos) {
|
||||
actual_radiance_patch_size = pair.second.ne[0];
|
||||
LOG_DEBUG("actual radiance patch size: %d", actual_radiance_patch_size);
|
||||
}
|
||||
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
|
||||
// Chroma
|
||||
|
|
@ -1351,6 +1358,11 @@ namespace Flux {
|
|||
head_dim = pair.second.ne[0];
|
||||
}
|
||||
}
|
||||
if (actual_radiance_patch_size > 0 && actual_radiance_patch_size != flux_params.patch_size) {
|
||||
GGML_ASSERT(flux_params.patch_size == 2 * actual_radiance_patch_size);
|
||||
LOG_DEBUG("using fake x2 patch size");
|
||||
flux_params.chroma_radiance_params.fake_patch_size_x2 = true;
|
||||
}
|
||||
|
||||
flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);
|
||||
|
||||
|
|
|
|||
|
|
@ -1068,6 +1068,7 @@ SDVersion ModelLoader::get_sd_version() {
|
|||
int64_t patch_embedding_channels = 0;
|
||||
bool has_img_emb = false;
|
||||
bool has_middle_block_1 = false;
|
||||
bool has_output_block_311 = false;
|
||||
bool has_output_block_71 = false;
|
||||
|
||||
for (auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
|
|
@ -1128,6 +1129,9 @@ SDVersion ModelLoader::get_sd_version() {
|
|||
tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) {
|
||||
has_middle_block_1 = true;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos) {
|
||||
has_output_block_311 = true;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) {
|
||||
has_output_block_71 = true;
|
||||
}
|
||||
|
|
@ -1166,6 +1170,9 @@ SDVersion ModelLoader::get_sd_version() {
|
|||
return VERSION_SDXL_PIX2PIX;
|
||||
}
|
||||
if (!has_middle_block_1) {
|
||||
if (!has_output_block_311) {
|
||||
return VERSION_SDXL_VEGA;
|
||||
}
|
||||
return VERSION_SDXL_SSD1B;
|
||||
}
|
||||
return VERSION_SDXL;
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ enum SDVersion {
|
|||
VERSION_SDXL,
|
||||
VERSION_SDXL_INPAINT,
|
||||
VERSION_SDXL_PIX2PIX,
|
||||
VERSION_SDXL_VEGA,
|
||||
VERSION_SDXL_SSD1B,
|
||||
VERSION_SVD,
|
||||
VERSION_SD3,
|
||||
|
|
@ -66,7 +67,7 @@ static inline bool sd_version_is_sd2(SDVersion version) {
|
|||
}
|
||||
|
||||
static inline bool sd_version_is_sdxl(SDVersion version) {
|
||||
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B) {
|
||||
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B || version == VERSION_SDXL_VEGA) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ const char* model_version_to_str[] = {
|
|||
"SDXL",
|
||||
"SDXL Inpaint",
|
||||
"SDXL Instruct-Pix2Pix",
|
||||
"SDXL (Vega)",
|
||||
"SDXL (SSD1B)",
|
||||
"SVD",
|
||||
"SD3.x",
|
||||
|
|
@ -763,7 +764,7 @@ public:
|
|||
LOG_INFO("Using Conv2d direct in the vae model");
|
||||
first_stage_model->set_conv2d_direct_enabled(true);
|
||||
}
|
||||
if (version == VERSION_SDXL &&
|
||||
if (sd_version_is_sdxl(version) &&
|
||||
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) {
|
||||
float vae_conv_2d_scale = 1.f / 32.f;
|
||||
LOG_WARN(
|
||||
|
|
|
|||
|
|
@ -201,6 +201,9 @@ public:
|
|||
num_head_channels = 64;
|
||||
num_heads = -1;
|
||||
use_linear_projection = true;
|
||||
if (version == VERSION_SDXL_VEGA) {
|
||||
transformer_depth = {1, 1, 2};
|
||||
}
|
||||
} else if (version == VERSION_SVD) {
|
||||
in_channels = 8;
|
||||
out_channels = 4;
|
||||
|
|
@ -319,7 +322,7 @@ public:
|
|||
}
|
||||
if (!tiny_unet) {
|
||||
blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
|
||||
if (version != VERSION_SDXL_SSD1B) {
|
||||
if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) {
|
||||
blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
|
||||
n_head,
|
||||
d_head,
|
||||
|
|
@ -520,7 +523,7 @@ public:
|
|||
// middle_block
|
||||
if (!tiny_unet) {
|
||||
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
||||
if (version != VERSION_SDXL_SSD1B) {
|
||||
if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) {
|
||||
h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
||||
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue