Mirror of https://github.com/LostRuins/koboldcpp.git, synced 2025-09-11 09:34:37 +00:00

Commit fea3b2bd4a (parent ebf924c5d1)

updated sdcpp prepare for inpaint

fixed img2img (+1 squashed commits)
Squashed commits:
[42c48f14] try update sdcpp, feels kind of buggy

18 changed files with 1850 additions and 271 deletions
@@ -546,7 +546,7 @@ protected:
     int64_t num_positions;

     void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-        enum ggml_type token_wtype = (tensor_types.find(prefix + "token_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "token_embedding.weight"] : GGML_TYPE_F32;
+        enum ggml_type token_wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "token_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "token_embedding.weight"] : GGML_TYPE_F32;
         enum ggml_type position_wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32;

         params["token_embedding.weight"] = ggml_new_tensor_2d(ctx, token_wtype, embed_dim, vocab_size);
@@ -52,6 +52,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
     std::string trigger_word = "img";  // should be user settable
     std::string embd_dir;
     int32_t num_custom_embeddings = 0;
+    int32_t num_custom_embeddings_2 = 0;
     std::vector<uint8_t> token_embed_custom;
     std::vector<std::string> readed_embeddings;

@@ -61,18 +62,18 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
                                       SDVersion version = VERSION_SD1,
                                       PMVersion pv      = PM_VERSION_1,
                                       int clip_skip     = -1)
-        : version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir) {
+        : version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) {
         if (clip_skip <= 0) {
             clip_skip = 1;
-            if (version == VERSION_SD2 || version == VERSION_SDXL) {
+            if (sd_version_is_sd2(version) || sd_version_is_sdxl(version)) {
                 clip_skip = 2;
             }
         }
-        if (version == VERSION_SD1) {
+        if (sd_version_is_sd1(version)) {
             text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip);
-        } else if (version == VERSION_SD2) {
+        } else if (sd_version_is_sd2(version)) {
             text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, clip_skip);
-        } else if (version == VERSION_SDXL) {
+        } else if (sd_version_is_sdxl(version)) {
             text_model  = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
             text_model2 = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
         }
@@ -80,35 +81,35 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {

     void set_clip_skip(int clip_skip) {
         text_model->set_clip_skip(clip_skip);
-        if (version == VERSION_SDXL) {
+        if (sd_version_is_sdxl(version)) {
             text_model2->set_clip_skip(clip_skip);
         }
     }

     void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
         text_model->get_param_tensors(tensors, "cond_stage_model.transformer.text_model");
-        if (version == VERSION_SDXL) {
+        if (sd_version_is_sdxl(version)) {
             text_model2->get_param_tensors(tensors, "cond_stage_model.1.transformer.text_model");
         }
     }

     void alloc_params_buffer() {
         text_model->alloc_params_buffer();
-        if (version == VERSION_SDXL) {
+        if (sd_version_is_sdxl(version)) {
             text_model2->alloc_params_buffer();
         }
     }

     void free_params_buffer() {
         text_model->free_params_buffer();
-        if (version == VERSION_SDXL) {
+        if (sd_version_is_sdxl(version)) {
             text_model2->free_params_buffer();
         }
     }

     size_t get_params_buffer_size() {
         size_t buffer_size = text_model->get_params_buffer_size();
-        if (version == VERSION_SDXL) {
+        if (sd_version_is_sdxl(version)) {
             buffer_size += text_model2->get_params_buffer_size();
         }
         return buffer_size;
@@ -131,18 +132,31 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
         params.no_alloc               = false;
         struct ggml_context* embd_ctx = ggml_init(params);
         struct ggml_tensor* embd      = NULL;
-        int64_t hidden_size           = text_model->model.hidden_size;
+        struct ggml_tensor* embd2     = NULL;
         auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) {
-            if (tensor_storage.ne[0] != hidden_size) {
-                LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], hidden_size);
-                return false;
-            }
-            embd        = ggml_new_tensor_2d(embd_ctx, tensor_storage.type, hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
-            *dst_tensor = embd;
+            if (tensor_storage.ne[0] != text_model->model.hidden_size) {
+                if (text_model2) {
+                    if (tensor_storage.ne[0] == text_model2->model.hidden_size) {
+                        embd2       = ggml_new_tensor_2d(embd_ctx, tensor_storage.type, text_model2->model.hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
+                        *dst_tensor = embd2;
+                    } else {
+                        LOG_DEBUG("embedding wrong hidden size, got %i, expected %i or %i", tensor_storage.ne[0], text_model->model.hidden_size, text_model2->model.hidden_size);
+                        return false;
+                    }
+                } else {
+                    LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], text_model->model.hidden_size);
+                    return false;
+                }
+            } else {
+                embd        = ggml_new_tensor_2d(embd_ctx, tensor_storage.type, text_model->model.hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
+                *dst_tensor = embd;
+            }
             return true;
         };
         model_loader.load_tensors(on_load, NULL);
         readed_embeddings.push_back(embd_name);
+        if (embd) {
+            int64_t hidden_size = text_model->model.hidden_size;
             token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd));
             memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * hidden_size * ggml_type_size(embd->type)),
                    embd->data,
@@ -153,6 +167,20 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
                 num_custom_embeddings++;
             }
             LOG_DEBUG("embedding '%s' applied, custom embeddings: %i", embd_name.c_str(), num_custom_embeddings);
+        }
+        if (embd2) {
+            int64_t hidden_size = text_model2->model.hidden_size;
+            token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd2));
+            memcpy((void*)(token_embed_custom.data() + num_custom_embeddings_2 * hidden_size * ggml_type_size(embd2->type)),
+                   embd2->data,
+                   ggml_nbytes(embd2));
+            for (int i = 0; i < embd2->ne[1]; i++) {
+                bpe_tokens.push_back(text_model2->model.vocab_size + num_custom_embeddings_2);
+                // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
+                num_custom_embeddings_2++;
+            }
+            LOG_DEBUG("embedding '%s' applied, custom embeddings: %i (text model 2)", embd_name.c_str(), num_custom_embeddings_2);
+        }
         return true;
     }

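The two hunks above repeat the same bookkeeping for the second text encoder: every accepted embedding row is appended to one flat byte buffer (token_embed_custom) and assigned a virtual token id just past the end of the vocabulary. A minimal standalone sketch of that pattern, with illustrative names (append_embeddings, buf, count, rows, dim, vocab_size are not from the patch):

    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Append `rows` embedding vectors of `dim` floats to `buf` (which is assumed
    // to already hold exactly `count` such rows) and return the virtual token
    // ids assigned to them (vocab_size + running count).
    static std::vector<int> append_embeddings(std::vector<uint8_t>& buf, int& count,
                                              const float* data, int rows, int dim,
                                              int vocab_size) {
        const size_t row_bytes = (size_t)dim * sizeof(float);
        buf.resize(buf.size() + (size_t)rows * row_bytes);
        std::memcpy(buf.data() + (size_t)count * row_bytes, data, (size_t)rows * row_bytes);
        std::vector<int> ids;
        for (int i = 0; i < rows; i++) {
            ids.push_back(vocab_size + count++);
        }
        return ids;
    }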
@@ -402,7 +430,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
         auto input_ids                 = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
         struct ggml_tensor* input_ids2 = NULL;
         size_t max_token_idx           = 0;
-        if (version == VERSION_SDXL) {
+        if (sd_version_is_sdxl(version)) {
             auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), tokenizer.EOS_TOKEN_ID);
             if (it != chunk_tokens.end()) {
                 std::fill(std::next(it), chunk_tokens.end(), 0);
@@ -427,7 +455,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
                             false,
                             &chunk_hidden_states1,
                             work_ctx);
-        if (version == VERSION_SDXL) {
+        if (sd_version_is_sdxl(version)) {
             text_model2->compute(n_threads,
                                  input_ids2,
                                  0,
@@ -486,7 +514,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
                                    ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);

         ggml_tensor* vec = NULL;
-        if (version == VERSION_SDXL) {
+        if (sd_version_is_sdxl(version)) {
             int out_dim = 256;
             vec         = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, adm_in_channels);
             // [0:1280]
@@ -34,11 +34,11 @@ public:

     ControlNetBlock(SDVersion version = VERSION_SD1)
         : version(version) {
-        if (version == VERSION_SD2) {
+        if (sd_version_is_sd2(version)) {
             context_dim       = 1024;
             num_head_channels = 64;
             num_heads         = -1;
-        } else if (version == VERSION_SDXL) {
+        } else if (sd_version_is_sdxl(version)) {
             context_dim           = 2048;
             attention_resolutions = {4, 2};
             channel_mult          = {1, 2, 4};

@@ -58,7 +58,7 @@ public:
         // time_embed_1 is nn.SiLU()
         blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));

-        if (version == VERSION_SDXL || version == VERSION_SVD) {
+        if (sd_version_is_sdxl(version) || version == VERSION_SVD) {
             blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
             // label_emb_1 is nn.SiLU()
             blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
@@ -474,7 +474,8 @@ static void sample_k_diffusion(sample_method_t method,
                                ggml_context* work_ctx,
                                ggml_tensor* x,
                                std::vector<float> sigmas,
-                               std::shared_ptr<RNG> rng) {
+                               std::shared_ptr<RNG> rng,
+                               float eta) {
     size_t steps = sigmas.size() - 1;
     // sample_euler_ancestral
     switch (method) {
@@ -1005,6 +1006,374 @@ static void sample_k_diffusion(sample_method_t method,
                 }
             }
         } break;
+        case DDIM_TRAILING:  // Denoising Diffusion Implicit Models
+                             // with the "trailing" timestep spacing
+        {
+            // See J. Song et al., "Denoising Diffusion Implicit
+            // Models", arXiv:2010.02502 [cs.LG]
+            //
+            // DDIM itself needs alphas_cumprod (DDPM, J. Ho et al.,
+            // arXiv:2006.11239 [cs.LG] with k-diffusion's start and
+            // end beta) (which unfortunately k-diffusion's data
+            // structure hides from the denoiser), and the sigmas are
+            // also needed to invert the behavior of CompVisDenoiser
+            // (k-diffusion's LMSDiscreteScheduler)
+            float beta_start = 0.00085f;
+            float beta_end   = 0.0120f;
+            std::vector<double> alphas_cumprod;
+            std::vector<double> compvis_sigmas;
+
+            alphas_cumprod.reserve(TIMESTEPS);
+            compvis_sigmas.reserve(TIMESTEPS);
+            for (int i = 0; i < TIMESTEPS; i++) {
+                alphas_cumprod[i] =
+                    (i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
+                    (1.0f -
+                     std::pow(sqrtf(beta_start) +
+                                  (sqrtf(beta_end) - sqrtf(beta_start)) *
+                                      ((float)i / (TIMESTEPS - 1)), 2));
+                compvis_sigmas[i] =
+                    std::sqrt((1 - alphas_cumprod[i]) /
+                              alphas_cumprod[i]);
+            }
+
+            struct ggml_tensor* pred_original_sample =
+                ggml_dup_tensor(work_ctx, x);
+            struct ggml_tensor* variance_noise =
+                ggml_dup_tensor(work_ctx, x);
+
+            for (int i = 0; i < steps; i++) {
+                // The "trailing" DDIM timestep, see S. Lin et al.,
+                // "Common Diffusion Noise Schedules and Sample Steps
+                // are Flawed", arXiv:2305.08891 [cs], p. 4, Table
+                // 2. Most variables below follow Diffusers naming
+                //
+                // Diffuser naming vs. Song et al. (2010), p. 5, (12)
+                // and p. 16, (16) (<variable name> -> <name in
+                // paper>):
+                //
+                // - pred_noise_t -> epsilon_theta^(t)(x_t)
+                // - pred_original_sample -> f_theta^(t)(x_t) or x_0
+                // - std_dev_t -> sigma_t (not the LMS sigma)
+                // - eta -> eta (set to 0 at the moment)
+                // - pred_sample_direction -> "direction pointing to
+                //   x_t"
+                // - pred_prev_sample -> "x_t-1"
+                int timestep =
+                    roundf(TIMESTEPS -
+                           i * ((float)TIMESTEPS / steps)) - 1;
+                // 1. get previous step value (=t-1)
+                int prev_timestep = timestep - TIMESTEPS / steps;
+                // The sigma here is chosen to cause the
+                // CompVisDenoiser to produce t = timestep
+                float sigma = compvis_sigmas[timestep];
+                if (i == 0) {
+                    // The function add_noise initializes x to
+                    // Diffusers' latents * sigma (as in Diffusers'
+                    // pipeline) or sample * sigma (Diffusers'
+                    // scheduler), where this sigma = init_noise_sigma
+                    // in Diffusers. For DDPM and DDIM however,
+                    // init_noise_sigma = 1. But the k-diffusion
+                    // model() also evaluates F_theta(c_in(sigma) x;
+                    // ...) instead of the bare U-net F_theta, with
+                    // c_in = 1 / sqrt(sigma^2 + 1), as defined in
+                    // T. Karras et al., "Elucidating the Design Space
+                    // of Diffusion-Based Generative Models",
+                    // arXiv:2206.00364 [cs.CV], p. 3, Table 1. Hence
+                    // the first call has to be prescaled as x <- x /
+                    // (c_in * sigma) with the k-diffusion pipeline
+                    // and CompVisDenoiser.
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] *= std::sqrt(sigma * sigma + 1) /
+                                    sigma;
+                    }
+                }
+                else {
+                    // For the subsequent steps after the first one,
+                    // at this point x = latents or x = sample, and
+                    // needs to be prescaled with x <- sample / c_in
+                    // to compensate for model() applying the scale
+                    // c_in before the U-net F_theta
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] *= std::sqrt(sigma * sigma + 1);
+                    }
+                }
+                // Note (also noise_pred in Diffuser's pipeline)
+                // model_output = model() is the D(x, sigma) as
+                // defined in Karras et al. (2022), p. 3, Table 1 and
+                // p. 8 (7), compare also p. 38 (226) therein.
+                struct ggml_tensor* model_output =
+                    model(x, sigma, i + 1);
+                // Here model_output is still the k-diffusion denoiser
+                // output, not the U-net output F_theta(c_in(sigma) x;
+                // ...) in Karras et al. (2022), whereas Diffusers'
+                // model_output is F_theta(...). Recover the actual
+                // model_output, which is also referred to as the
+                // "Karras ODE derivative" d or d_cur in several
+                // samplers above.
+                {
+                    float* vec_x = (float*)x->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_model_output[j] =
+                            (vec_x[j] - vec_model_output[j]) *
+                            (1 / sigma);
+                    }
+                }
+                // 2. compute alphas, betas
+                float alpha_prod_t = alphas_cumprod[timestep];
+                // Note final_alpha_cumprod = alphas_cumprod[0] due to
+                // trailing timestep spacing
+                float alpha_prod_t_prev = prev_timestep >= 0 ?
+                    alphas_cumprod[prev_timestep] : alphas_cumprod[0];
+                float beta_prod_t = 1 - alpha_prod_t;
+                // 3. compute predicted original sample from predicted
+                // noise also called "predicted x_0" of formula (12)
+                // from https://arxiv.org/pdf/2010.02502.pdf
+                {
+                    float* vec_x = (float*)x->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    float* vec_pred_original_sample =
+                        (float*)pred_original_sample->data;
+                    // Note the substitution of latents or sample = x
+                    // * c_in = x / sqrt(sigma^2 + 1)
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_pred_original_sample[j] =
+                            (vec_x[j] / std::sqrt(sigma * sigma + 1) -
+                             std::sqrt(beta_prod_t) *
+                                 vec_model_output[j]) *
+                            (1 / std::sqrt(alpha_prod_t));
+                    }
+                }
+                // Assuming the "epsilon" prediction type, where below
+                // pred_epsilon = model_output is inserted, and is not
+                // defined/copied explicitly.
+                //
+                // 5. compute variance: "sigma_t(eta)" -> see formula
+                // (16)
+                //
+                // sigma_t = sqrt((1 - alpha_t-1)/(1 - alpha_t)) *
+                // sqrt(1 - alpha_t/alpha_t-1)
+                float beta_prod_t_prev = 1 - alpha_prod_t_prev;
+                float variance = (beta_prod_t_prev / beta_prod_t) *
+                    (1 - alpha_prod_t / alpha_prod_t_prev);
+                float std_dev_t = eta * std::sqrt(variance);
+                // 6. compute "direction pointing to x_t" of formula
+                // (12) from https://arxiv.org/pdf/2010.02502.pdf
+                // 7. compute x_t without "random noise" of formula
+                // (12) from https://arxiv.org/pdf/2010.02502.pdf
+                {
+                    float* vec_model_output = (float*)model_output->data;
+                    float* vec_pred_original_sample =
+                        (float*)pred_original_sample->data;
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        // Two step inner loop without an explicit
+                        // tensor
+                        float pred_sample_direction =
+                            std::sqrt(1 - alpha_prod_t_prev -
+                                      std::pow(std_dev_t, 2)) *
+                            vec_model_output[j];
+                        vec_x[j] = std::sqrt(alpha_prod_t_prev) *
+                                       vec_pred_original_sample[j] +
+                                   pred_sample_direction;
+                    }
+                }
+                if (eta > 0) {
+                    ggml_tensor_set_f32_randn(variance_noise, rng);
+                    float* vec_variance_noise =
+                        (float*)variance_noise->data;
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] += std_dev_t * vec_variance_noise[j];
+                    }
+                }
+                // See the note above: x = latents or sample here, and
+                // is not scaled by the c_in. For the final output
+                // this is correct, but for subsequent iterations, x
+                // needs to be prescaled again, since k-diffusion's
+                // model() differs from the bare U-net F_theta by the
+                // factor c_in.
+            }
+        } break;
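As a cross-check on the comments in the DDIM_TRAILING case above, the schedule and update it implements can be written out (here alphas_cumprod is denoted by a bar over alpha, and T is TIMESTEPS):

    \bar\alpha_i = \prod_{k=0}^{i} \Big( 1 - \big( \sqrt{\beta_{\mathrm{start}}} + (\sqrt{\beta_{\mathrm{end}}} - \sqrt{\beta_{\mathrm{start}}}) \tfrac{k}{T-1} \big)^2 \Big),
    \qquad
    \sigma_i^{\mathrm{cv}} = \sqrt{\frac{1 - \bar\alpha_i}{\bar\alpha_i}}

Each step then applies eq. (12) of Song et al. (2020), with their alpha read as the cumulative product:

    \hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta^{(t)}(x_t)}{\sqrt{\bar\alpha_t}},
    \qquad
    x_{t_{\mathrm{prev}}} = \sqrt{\bar\alpha_{t_{\mathrm{prev}}}}\,\hat{x}_0
        + \sqrt{1-\bar\alpha_{t_{\mathrm{prev}}}-\sigma_t^2}\;\epsilon_\theta^{(t)}(x_t)
        + \sigma_t\,\epsilon_t

with \sigma_t(\eta) = \eta\,\sqrt{(1-\bar\alpha_{t_{\mathrm{prev}}})/(1-\bar\alpha_t)}\,\sqrt{1-\bar\alpha_t/\bar\alpha_{t_{\mathrm{prev}}}} from formula (16); these terms correspond to pred_original_sample, pred_sample_direction, and std_dev_t in the code.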
+        case TCD:  // Strategic Stochastic Sampling (Algorithm 4) in
+                   // Trajectory Consistency Distillation
+        {
+            // See J. Zheng et al., "Trajectory Consistency
+            // Distillation: Improved Latent Consistency Distillation
+            // by Semi-Linear Consistency Function with Trajectory
+            // Mapping", arXiv:2402.19159 [cs.CV]
+            float beta_start = 0.00085f;
+            float beta_end   = 0.0120f;
+            std::vector<double> alphas_cumprod;
+            std::vector<double> compvis_sigmas;
+
+            alphas_cumprod.reserve(TIMESTEPS);
+            compvis_sigmas.reserve(TIMESTEPS);
+            for (int i = 0; i < TIMESTEPS; i++) {
+                alphas_cumprod[i] =
+                    (i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
+                    (1.0f -
+                     std::pow(sqrtf(beta_start) +
+                                  (sqrtf(beta_end) - sqrtf(beta_start)) *
+                                      ((float)i / (TIMESTEPS - 1)), 2));
+                compvis_sigmas[i] =
+                    std::sqrt((1 - alphas_cumprod[i]) /
+                              alphas_cumprod[i]);
+            }
+            int original_steps = 50;
+
+            struct ggml_tensor* pred_original_sample =
+                ggml_dup_tensor(work_ctx, x);
+            struct ggml_tensor* noise =
+                ggml_dup_tensor(work_ctx, x);
+
+            for (int i = 0; i < steps; i++) {
+                // Analytic form for TCD timesteps
+                int timestep = TIMESTEPS - 1 -
+                    (TIMESTEPS / original_steps) *
+                        (int)floor(i * ((float)original_steps / steps));
+                // 1. get previous step value
+                int prev_timestep = i >= steps - 1 ? 0 :
+                    TIMESTEPS - 1 - (TIMESTEPS / original_steps) *
+                        (int)floor((i + 1) *
+                                   ((float)original_steps / steps));
+                // Here timestep_s is tau_n' in Algorithm 4. The _s
+                // notation appears to be that from C. Lu,
+                // "DPM-Solver: A Fast ODE Solver for Diffusion
+                // Probabilistic Model Sampling in Around 10 Steps",
+                // arXiv:2206.00927 [cs.LG], but this notation is not
+                // continued in Algorithm 4, where _n' is used.
+                int timestep_s =
+                    (int)floor((1 - eta) * prev_timestep);
+                // Begin k-diffusion specific workaround for
+                // evaluating F_theta(x; ...) from D(x, sigma), same
+                // as in DDIM (and see there for detailed comments)
+                float sigma = compvis_sigmas[timestep];
+                if (i == 0) {
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] *= std::sqrt(sigma * sigma + 1) /
+                                    sigma;
+                    }
+                }
+                else {
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] *= std::sqrt(sigma * sigma + 1);
+                    }
+                }
+                struct ggml_tensor* model_output =
+                    model(x, sigma, i + 1);
+                {
+                    float* vec_x = (float*)x->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_model_output[j] =
+                            (vec_x[j] - vec_model_output[j]) *
+                            (1 / sigma);
+                    }
+                }
+                // 2. compute alphas, betas
+                //
+                // When comparing TCD with DDPM/DDIM note that Zheng
+                // et al. (2024) follows the DPM-Solver notation for
+                // alpha. One can find the following comment in the
+                // original DPM-Solver code
+                // (https://github.com/LuChengTHU/dpm-solver/):
+                // "**Important**: Please pay special attention for
+                // the args for `alphas_cumprod`: The `alphas_cumprod`
+                // is the \hat{alpha_n} arrays in the notations of
+                // DDPM. [...] Therefore, the notation \hat{alpha_n}
+                // is different from the notation alpha_t in
+                // DPM-Solver. In fact, we have alpha_{t_n} =
+                // \sqrt{\hat{alpha_n}}, [...]"
+                float alpha_prod_t = alphas_cumprod[timestep];
+                float beta_prod_t  = 1 - alpha_prod_t;
+                // Note final_alpha_cumprod = alphas_cumprod[0] since
+                // TCD is always "trailing"
+                float alpha_prod_t_prev = prev_timestep >= 0 ?
+                    alphas_cumprod[prev_timestep] : alphas_cumprod[0];
+                // The subscript _s quantities are the only portion of
+                // this section (2) unique to TCD
+                float alpha_prod_s = alphas_cumprod[timestep_s];
+                float beta_prod_s  = 1 - alpha_prod_s;
+                // 3. Compute the predicted noised sample x_s based on
+                // the model parameterization
+                //
+                // This section is also exactly the same as DDIM
+                {
+                    float* vec_x = (float*)x->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    float* vec_pred_original_sample =
+                        (float*)pred_original_sample->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_pred_original_sample[j] =
+                            (vec_x[j] / std::sqrt(sigma * sigma + 1) -
+                             std::sqrt(beta_prod_t) *
+                                 vec_model_output[j]) *
+                            (1 / std::sqrt(alpha_prod_t));
+                    }
+                }
+                // This consistency function step can be difficult to
+                // decipher from Algorithm 4, as it is simply stated
+                // using a consistency function. This step is the
+                // modified DDIM, i.e. p. 8 (32) in Zheng et
+                // al. (2024), with eta set to 0 (see the paragraph
+                // immediately thereafter that states this somewhat
+                // obliquely).
+                {
+                    float* vec_pred_original_sample =
+                        (float*)pred_original_sample->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        // Substituting x = pred_noised_sample and
+                        // pred_epsilon = model_output
+                        vec_x[j] =
+                            std::sqrt(alpha_prod_s) *
+                                vec_pred_original_sample[j] +
+                            std::sqrt(beta_prod_s) *
+                                vec_model_output[j];
+                    }
+                }
+                // 4. Sample and inject noise z ~ N(0, I) for
+                // MultiStep Inference. Noise is not used on the final
+                // timestep of the timestep schedule. This also means
+                // that noise is not used for one-step sampling. Eta
+                // (referred to as "gamma" in the paper) was
+                // introduced to control the stochasticity in every
+                // step. When eta = 0, it represents deterministic
+                // sampling, whereas eta = 1 indicates full stochastic
+                // sampling.
+                if (eta > 0 && i != steps - 1) {
+                    // In this case, x is still pred_noised_sample,
+                    // continue in-place
+                    ggml_tensor_set_f32_randn(noise, rng);
+                    float* vec_x     = (float*)x->data;
+                    float* vec_noise = (float*)noise->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        // Corresponding to (35) in Zheng et
+                        // al. (2024), substituting x =
+                        // pred_noised_sample
+                        vec_x[j] =
+                            std::sqrt(alpha_prod_t_prev /
+                                      alpha_prod_s) *
+                                vec_x[j] +
+                            std::sqrt(1 - alpha_prod_t_prev /
+                                              alpha_prod_s) *
+                                vec_noise[j];
+                    }
+                }
+            }
+        } break;
+
         default:
             LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);
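The TCD case shares that skeleton; its distinguishing steps, matching timestep_s and the two in-place updates above, are the strategic timestep \tau = \lfloor (1-\eta)\, t_{\mathrm{prev}} \rfloor, the consistency jump of eq. (32) in Zheng et al. (2024) with eta set to 0,

    x_\tau = \sqrt{\bar\alpha_\tau}\,\hat{x}_0 + \sqrt{1-\bar\alpha_\tau}\;\epsilon_\theta(x_t),

and, on every step but the last, the noise re-injection of eq. (35):

    x_{t_{\mathrm{prev}}} = \sqrt{\bar\alpha_{t_{\mathrm{prev}}}/\bar\alpha_\tau}\;x_\tau
        + \sqrt{1-\bar\alpha_{t_{\mathrm{prev}}}/\bar\alpha_\tau}\;z,
    \qquad z \sim \mathcal{N}(0, I).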
@@ -133,8 +133,9 @@ struct FluxModel : public DiffusionModel {

     FluxModel(ggml_backend_t backend,
               std::map<std::string, enum ggml_type>& tensor_types,
+              SDVersion version = VERSION_FLUX,
               bool flash_attn   = false)
-        : flux(backend, tensor_types, "model.diffusion_model", flash_attn) {
+        : flux(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
     }

     void alloc_params_buffer() {

@@ -174,7 +175,7 @@ struct FluxModel : public DiffusionModel {
                  struct ggml_tensor** output     = NULL,
                  struct ggml_context* output_ctx = NULL,
                  std::vector<int> skip_layers    = std::vector<int>()) {
-        return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx, skip_layers);
+        return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, output, output_ctx, skip_layers);
     }
 };
@@ -490,6 +490,7 @@ namespace Flux {

     struct FluxParams {
         int64_t in_channels    = 64;
+        int64_t out_channels   = 64;
         int64_t vec_in_dim     = 768;
         int64_t context_in_dim = 4096;
         int64_t hidden_size    = 3072;
@@ -642,7 +643,6 @@ namespace Flux {
         Flux() {}
         Flux(FluxParams params)
             : params(params) {
-            int64_t out_channels = params.in_channels;
             int64_t pe_dim       = params.hidden_size / params.num_heads;

             blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
@@ -669,7 +669,7 @@ namespace Flux {
                                                                            params.flash_attn));
             }

-            blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, out_channels));
+            blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, params.out_channels));
         }

         struct ggml_tensor* patchify(struct ggml_context* ctx,
@@ -789,6 +789,7 @@ namespace Flux {
                                     struct ggml_tensor* x,
                                     struct ggml_tensor* timestep,
                                     struct ggml_tensor* context,
+                                    struct ggml_tensor* c_concat,
                                     struct ggml_tensor* y,
                                     struct ggml_tensor* guidance,
                                     struct ggml_tensor* pe,

@@ -797,6 +798,7 @@ namespace Flux {
             // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
             // timestep: (N,) tensor of diffusion timesteps
             // context: (N, L, D)
+            // c_concat: NULL, or (N, C+M, H, W) for Fill
             // y: (N, adm_in_channels) tensor of class labels
             // guidance: (N,)
             // pe: (L, d_head/2, 2, 2)
@@ -806,6 +808,7 @@ namespace Flux {

             int64_t W          = x->ne[0];
             int64_t H          = x->ne[1];
+            int64_t C          = x->ne[2];
             int64_t patch_size = 2;
             int pad_h          = (patch_size - H % patch_size) % patch_size;
             int pad_w          = (patch_size - W % patch_size) % patch_size;
@@ -814,6 +817,19 @@ namespace Flux {
             // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
             auto img = patchify(ctx, x, patch_size);  // [N, h*w, C * patch_size * patch_size]

+            if (c_concat != NULL) {
+                ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
+                ggml_tensor* mask   = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
+
+                masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
+                mask   = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
+
+                masked = patchify(ctx, masked, patch_size);
+                mask   = patchify(ctx, mask, patch_size);
+
+                img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
+            }
+
             auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers);  // [N, h*w, C * patch_size * patch_size]

             // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
@@ -834,12 +850,16 @@ namespace Flux {
         FluxRunner(ggml_backend_t backend,
                    std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
                    const std::string prefix                            = "",
+                   SDVersion version                                   = VERSION_FLUX,
                    bool flash_attn                                     = false)
             : GGMLRunner(backend) {
             flux_params.flash_attn          = flash_attn;
             flux_params.guidance_embed      = false;
             flux_params.depth               = 0;
             flux_params.depth_single_blocks = 0;
+            if (version == VERSION_FLUX_FILL) {
+                flux_params.in_channels = 384;
+            }
             for (auto pair : tensor_types) {
                 std::string tensor_name = pair.first;
                 if (tensor_name.find("model.diffusion_model.") == std::string::npos)
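The Fill-specific flux_params.in_channels = 384 is consistent with the concatenation done in forward() above, assuming the usual 16-channel Flux latent (C = 16): patchify() multiplies channels by patch_size^2 = 4, so the latent contributes 16 * 4 = 64 features per token, the masked-image latent another 64, and the 8 * 8 = 64 mask channels contribute 64 * 4 = 256, giving 64 + 64 + 256 = 384.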
@@ -886,6 +906,7 @@ namespace Flux {
         struct ggml_cgraph* build_graph(struct ggml_tensor* x,
                                         struct ggml_tensor* timesteps,
                                         struct ggml_tensor* context,
+                                        struct ggml_tensor* c_concat,
                                         struct ggml_tensor* y,
                                         struct ggml_tensor* guidance,
                                         std::vector<int> skip_layers = std::vector<int>()) {
@@ -894,6 +915,9 @@ namespace Flux {

             x       = to_backend(x);
             context = to_backend(context);
+            if (c_concat != NULL) {
+                c_concat = to_backend(c_concat);
+            }
             y         = to_backend(y);
             timesteps = to_backend(timesteps);
             if (flux_params.guidance_embed) {
@@ -913,6 +937,7 @@ namespace Flux {
                                              x,
                                              timesteps,
                                              context,
+                                             c_concat,
                                              y,
                                              guidance,
                                              pe,
@@ -927,6 +952,7 @@ namespace Flux {
                      struct ggml_tensor* x,
                      struct ggml_tensor* timesteps,
                      struct ggml_tensor* context,
+                     struct ggml_tensor* c_concat,
                      struct ggml_tensor* y,
                      struct ggml_tensor* guidance,
                      struct ggml_tensor** output = NULL,
@@ -938,7 +964,7 @@ namespace Flux {
             // y: [N, adm_in_channels] or [1, adm_in_channels]
             // guidance: [N, ]
             auto get_graph = [&]() -> struct ggml_cgraph* {
-                return build_graph(x, timesteps, context, y, guidance, skip_layers);
+                return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers);
             };

             GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@@ -978,7 +1004,7 @@ namespace Flux {
             struct ggml_tensor* out = NULL;

             int t0 = ggml_time_ms();
-            compute(8, x, timesteps, context, y, guidance, &out, work_ctx);
+            compute(8, x, timesteps, context, NULL, y, guidance, &out, work_ctx);
             int t1 = ggml_time_ms();

             print_ggml_tensor(out);
@@ -52,6 +52,71 @@
 #define __STATIC_INLINE__ static inline
 #endif

+// n-mode tensor-matrix product
+// example: 2-mode product
+// A: [ne03, k, ne01, ne00]
+// B: k rows, m columns => [k, m]
+// result is [ne03, m, ne01, ne00]
+__STATIC_INLINE__ struct ggml_tensor* ggml_mul_n_mode(struct ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b, int mode = 0) {
+    // reshape A
+    // swap 0th and nth axis
+    a       = ggml_cont(ctx, ggml_permute(ctx, a, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0));
+    int ne1 = a->ne[1];
+    int ne2 = a->ne[2];
+    int ne3 = a->ne[3];
+    // make 2D
+    a = ggml_cont(ctx, ggml_reshape_2d(ctx, a, a->ne[0], (ne3 * ne2 * ne1)));

+    struct ggml_tensor* result = ggml_cont(ctx, ggml_transpose(ctx, ggml_mul_mat(ctx, a, b)));

+    // reshape output (same shape as a after permutation except first dim)
+    result = ggml_reshape_4d(ctx, result, result->ne[0], ne1, ne2, ne3);
+    // swap back 0th and nth axis
+    result = ggml_permute(ctx, result, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0);
+    return result;
+}
+
+__STATIC_INLINE__ struct ggml_tensor* ggml_merge_lora(ggml_context* ctx, struct ggml_tensor* lora_down, struct ggml_tensor* lora_up, struct ggml_tensor* lora_mid = NULL) {
+    struct ggml_tensor* updown;
+    // flatten lora tensors to multiply them
+    int64_t lora_up_rows  = lora_up->ne[ggml_n_dims(lora_up) - 1];
+    lora_up               = ggml_reshape_2d(ctx, lora_up, ggml_nelements(lora_up) / lora_up_rows, lora_up_rows);
+    auto lora_down_n_dims = ggml_n_dims(lora_down);
+    // assume n_dims should always be a multiple of 2 (otherwise rank 1 doesn't work)
+    lora_down_n_dims       = (lora_down_n_dims + lora_down_n_dims % 2);
+    int64_t lora_down_rows = lora_down->ne[lora_down_n_dims - 1];
+    lora_down              = ggml_reshape_2d(ctx, lora_down, ggml_nelements(lora_down) / lora_down_rows, lora_down_rows);
+
+    // ggml_mul_mat requires tensor b transposed
+    lora_down = ggml_cont(ctx, ggml_transpose(ctx, lora_down));
+    if (lora_mid == NULL) {
+        updown = ggml_mul_mat(ctx, lora_up, lora_down);
+        updown = ggml_cont(ctx, ggml_transpose(ctx, updown));
+    } else {
+        // undoing tucker decomposition for conv layers.
+        // lora_mid  has shape (3, 3, Rank, Rank)
+        // lora_down has shape (Rank, In, 1, 1)
+        // lora_up   has shape (Rank, Out, 1, 1)
+        // conv layer shape is (3, 3, Out, In)
+        updown = ggml_mul_n_mode(ctx, ggml_mul_n_mode(ctx, lora_mid, lora_down, 3), lora_up, 2);
+        updown = ggml_cont(ctx, updown);
+    }
+    return updown;
+}
+
+// Kronecker product
+// [ne03,ne02,ne01,ne00] x [ne13,ne12,ne11,ne10] => [ne03*ne13,ne02*ne12,ne01*ne11,ne00*ne10]
+__STATIC_INLINE__ struct ggml_tensor* ggml_kronecker(ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b) {
+    return ggml_mul(ctx,
+                    ggml_upscale_ext(ctx,
+                                     a,
+                                     a->ne[0] * b->ne[0],
+                                     a->ne[1] * b->ne[1],
+                                     a->ne[2] * b->ne[2],
+                                     a->ne[3] * b->ne[3]),
+                    b);
+}
+
 __STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void* user_data) {
     (void)level;
     (void)user_data;
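Read mathematically (loosely, setting aside ggml's reversed dimension order), ggml_merge_lora rebuilds the LoRA weight delta, using the n-mode product to undo the Tucker decomposition for convolution LoRAs, and ggml_kronecker is the textbook Kronecker product realized by nearest-neighbour upscaling of a followed by an elementwise multiply; for B of shape p x q:

    \Delta W = W_{\mathrm{up}} W_{\mathrm{down}},
    \qquad
    \Delta W_{\mathrm{conv}} = (\tau \times_3 W_{\mathrm{down}}) \times_2 W_{\mathrm{up}},
    \qquad
    (A \otimes B)_{ij} = A_{\lfloor i/p \rfloor,\, \lfloor j/q \rfloor}\; B_{i \bmod p,\; j \bmod q}.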
@@ -290,6 +355,44 @@ __STATIC_INLINE__ void sd_image_to_tensor(const uint8_t* image_data,
         }
     }
 }

+__STATIC_INLINE__ void sd_mask_to_tensor(const uint8_t* image_data,
+                                         struct ggml_tensor* output,
+                                         bool scale = true) {
+    int64_t width    = output->ne[0];
+    int64_t height   = output->ne[1];
+    int64_t channels = output->ne[2];
+    GGML_ASSERT(channels == 1 && output->type == GGML_TYPE_F32);
+    for (int iy = 0; iy < height; iy++) {
+        for (int ix = 0; ix < width; ix++) {
+            float value = *(image_data + iy * width * channels + ix);
+            if (scale) {
+                value /= 255.f;
+            }
+            ggml_tensor_set_f32(output, value, ix, iy);
+        }
+    }
+}
+
+__STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
+                                     struct ggml_tensor* mask,
+                                     struct ggml_tensor* output) {
+    int64_t width    = output->ne[0];
+    int64_t height   = output->ne[1];
+    int64_t channels = output->ne[2];
+    GGML_ASSERT(output->type == GGML_TYPE_F32);
+    for (int ix = 0; ix < width; ix++) {
+        for (int iy = 0; iy < height; iy++) {
+            float m = ggml_tensor_get_f32(mask, ix, iy);
+            m       = round(m);  // inpaint models need binary masks
+            ggml_tensor_set_f32(mask, m, ix, iy);
+            for (int k = 0; k < channels; k++) {
+                float value = (1 - m) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
+                ggml_tensor_set_f32(output, value, ix, iy, k);
+            }
+        }
+    }
+}
+
 __STATIC_INLINE__ void sd_mul_images_to_tensor(const uint8_t* image_data,
                                                struct ggml_tensor* output,
                                                int idx,
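A quick sanity check on sd_apply_mask: since m is rounded to 0 or 1, the blend (1 - m) * (v - 0.5) + 0.5 leaves a pixel value v unchanged where m = 0 and collapses it to exactly 0.5 (neutral gray) where m = 1, which is presumably the placeholder value the inpainting path expects in masked-out regions.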
@@ -951,8 +1054,8 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
 }

 /* SDXL with LoRA requires more space */
-#define MAX_PARAMS_TENSOR_NUM 15360
-#define MAX_GRAPH_SIZE 15360
+#define MAX_PARAMS_TENSOR_NUM 32768
+#define MAX_GRAPH_SIZE 32768

 struct GGMLRunner {
 protected:
@@ -329,21 +329,21 @@ const std::vector<std::vector<float>> GITS_NOISE_1_50 = {
 };

 const std::vector<const std::vector<std::vector<float>>*> GITS_NOISE = {
-    { &GITS_NOISE_0_80 },
-    { &GITS_NOISE_0_85 },
-    { &GITS_NOISE_0_90 },
-    { &GITS_NOISE_0_95 },
-    { &GITS_NOISE_1_00 },
-    { &GITS_NOISE_1_05 },
-    { &GITS_NOISE_1_10 },
-    { &GITS_NOISE_1_15 },
-    { &GITS_NOISE_1_20 },
-    { &GITS_NOISE_1_25 },
-    { &GITS_NOISE_1_30 },
-    { &GITS_NOISE_1_35 },
-    { &GITS_NOISE_1_40 },
-    { &GITS_NOISE_1_45 },
-    { &GITS_NOISE_1_50 }
+    &GITS_NOISE_0_80,
+    &GITS_NOISE_0_85,
+    &GITS_NOISE_0_90,
+    &GITS_NOISE_0_95,
+    &GITS_NOISE_1_00,
+    &GITS_NOISE_1_05,
+    &GITS_NOISE_1_10,
+    &GITS_NOISE_1_15,
+    &GITS_NOISE_1_20,
+    &GITS_NOISE_1_25,
+    &GITS_NOISE_1_30,
+    &GITS_NOISE_1_35,
+    &GITS_NOISE_1_40,
+    &GITS_NOISE_1_45,
+    &GITS_NOISE_1_50
 };

 #endif  // GITS_NOISE_INL
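The GITS_NOISE change is a cleanup: the table is a std::vector of pointers, so each element is just a pointer such as &GITS_NOISE_0_80, and the previous { &... } spelling wrapped each scalar element in a redundant braced list.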
@@ -6,6 +6,90 @@
 #define LORA_GRAPH_SIZE 10240

 struct LoraModel : public GGMLRunner {
+    enum lora_t {
+        REGULAR      = 0,
+        DIFFUSERS    = 1,
+        DIFFUSERS_2  = 2,
+        DIFFUSERS_3  = 3,
+        TRANSFORMERS = 4,
+        LORA_TYPE_COUNT
+    };
+
+    const std::string lora_ups[LORA_TYPE_COUNT] = {
+        ".lora_up",
+        "_lora.up",
+        ".lora_B",
+        ".lora.up",
+        ".lora_linear_layer.up",
+    };
+
+    const std::string lora_downs[LORA_TYPE_COUNT] = {
+        ".lora_down",
+        "_lora.down",
+        ".lora_A",
+        ".lora.down",
+        ".lora_linear_layer.down",
+    };
+
+    const std::string lora_pre[LORA_TYPE_COUNT] = {
+        "lora.",
+        "",
+        "",
+        "",
+        "",
+    };
+
+    const std::map<std::string, std::string> alt_names = {
+        // mmdit
+        {"final_layer.adaLN_modulation.1", "norm_out.linear"},
+        {"pos_embed", "pos_embed.proj"},
+        {"final_layer.linear", "proj_out"},
+        {"y_embedder.mlp.0", "time_text_embed.text_embedder.linear_1"},
+        {"y_embedder.mlp.2", "time_text_embed.text_embedder.linear_2"},
+        {"t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1"},
+        {"t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2"},
+        {"x_block.mlp.fc1", "ff.net.0.proj"},
+        {"x_block.mlp.fc2", "ff.net.2"},
+        {"context_block.mlp.fc1", "ff_context.net.0.proj"},
+        {"context_block.mlp.fc2", "ff_context.net.2"},
+        {"x_block.adaLN_modulation.1", "norm1.linear"},
+        {"context_block.adaLN_modulation.1", "norm1_context.linear"},
+        {"context_block.attn.proj", "attn.to_add_out"},
+        {"x_block.attn.proj", "attn.to_out.0"},
+        {"x_block.attn2.proj", "attn2.to_out.0"},
+        // flux
+        // singlestream
+        {"linear2", "proj_out"},
+        {"modulation.lin", "norm.linear"},
+        // doublestream
+        {"txt_attn.proj", "attn.to_add_out"},
+        {"img_attn.proj", "attn.to_out.0"},
+        {"txt_mlp.0", "ff_context.net.0.proj"},
+        {"txt_mlp.2", "ff_context.net.2"},
+        {"img_mlp.0", "ff.net.0.proj"},
+        {"img_mlp.2", "ff.net.2"},
+        {"txt_mod.lin", "norm1_context.linear"},
+        {"img_mod.lin", "norm1.linear"},
+    };
+
+    const std::map<std::string, std::string> qkv_prefixes = {
+        // mmdit
+        {"context_block.attn.qkv", "attn.add_"},  // suffix "_proj"
+        {"x_block.attn.qkv", "attn.to_"},
+        {"x_block.attn2.qkv", "attn2.to_"},
+        // flux
+        // doublestream
+        {"txt_attn.qkv", "attn.add_"},  // suffix "_proj"
+        {"img_attn.qkv", "attn.to_"},
+    };
+    const std::map<std::string, std::string> qkvm_prefixes = {
+        // flux
+        // singlestream
+        {"linear1", ""},
+    };
+
+    const std::string* type_fingerprints = lora_ups;
+
     float multiplier = 1.0f;
     std::map<std::string, struct ggml_tensor*> lora_tensors;
     std::string file_path;
@@ -14,6 +98,7 @@ struct LoraModel : public GGMLRunner {
     bool applied                    = false;
     std::vector<int> zero_index_vec = {0};
     ggml_tensor* zero_index         = NULL;
+    enum lora_t type                = REGULAR;

     LoraModel(ggml_backend_t backend,
               const std::string& file_path = "",
@@ -44,6 +129,13 @@ struct LoraModel : public GGMLRunner {
             // LOG_INFO("skipping LoRA tensor '%s'", name.c_str());
             return true;
         }
+        // LOG_INFO("%s", name.c_str());
+        for (int i = 0; i < LORA_TYPE_COUNT; i++) {
+            if (name.find(type_fingerprints[i]) != std::string::npos) {
+                type = (lora_t)i;
+                break;
+            }
+        }

         if (dry_run) {
             struct ggml_tensor* real = ggml_new_tensor(params_ctx,
@@ -61,10 +153,12 @@ struct LoraModel : public GGMLRunner {

         model_loader.load_tensors(on_new_tensor_cb, backend);
         alloc_params_buffer();
+        // exit(0);
         dry_run = false;
         model_loader.load_tensors(on_new_tensor_cb, backend);

+        LOG_DEBUG("lora type: \"%s\"/\"%s\"", lora_downs[type].c_str(), lora_ups[type].c_str());
+
         LOG_DEBUG("finished loading lora");
         return true;
     }
@ -76,7 +170,74 @@ struct LoraModel : public GGMLRunner {
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_cgraph* build_lora_graph(std::map<std::string, struct ggml_tensor*> model_tensors) {
|
std::vector<std::string> to_lora_keys(std::string blk_name, SDVersion version) {
|
||||||
|
std::vector<std::string> keys;
|
||||||
|
// if (!sd_version_is_sd3(version) || blk_name != "model.diffusion_model.pos_embed") {
|
||||||
|
size_t k_pos = blk_name.find(".weight");
|
||||||
|
if (k_pos == std::string::npos) {
|
||||||
|
return keys;
|
||||||
|
}
|
||||||
|
blk_name = blk_name.substr(0, k_pos);
|
||||||
|
// }
|
||||||
|
keys.push_back(blk_name);
|
||||||
|
keys.push_back("lora." + blk_name);
|
||||||
|
if (sd_version_is_dit(version)) {
|
||||||
|
if (blk_name.find("model.diffusion_model") != std::string::npos) {
|
||||||
|
blk_name.replace(blk_name.find("model.diffusion_model"), sizeof("model.diffusion_model") - 1, "transformer");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (blk_name.find(".single_blocks") != std::string::npos) {
|
||||||
|
blk_name.replace(blk_name.find(".single_blocks"), sizeof(".single_blocks") - 1, ".single_transformer_blocks");
|
||||||
|
}
|
||||||
|
if (blk_name.find(".double_blocks") != std::string::npos) {
|
||||||
|
blk_name.replace(blk_name.find(".double_blocks"), sizeof(".double_blocks") - 1, ".transformer_blocks");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (blk_name.find(".joint_blocks") != std::string::npos) {
|
||||||
|
blk_name.replace(blk_name.find(".joint_blocks"), sizeof(".joint_blocks") - 1, ".transformer_blocks");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (blk_name.find("text_encoders.clip_l") != std::string::npos) {
|
||||||
|
blk_name.replace(blk_name.find("text_encoders.clip_l"), sizeof("text_encoders.clip_l") - 1, "cond_stage_model");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto& item : alt_names) {
|
||||||
|
size_t match = blk_name.find(item.first);
|
||||||
|
if (match != std::string::npos) {
|
||||||
|
blk_name = blk_name.substr(0, match) + item.second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (const auto& prefix : qkv_prefixes) {
|
||||||
|
size_t match = blk_name.find(prefix.first);
|
||||||
|
if (match != std::string::npos) {
|
||||||
|
std::string split_blk = "SPLIT|" + blk_name.substr(0, match) + prefix.second;
|
||||||
|
keys.push_back(split_blk);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (const auto& prefix : qkvm_prefixes) {
|
||||||
|
size_t match = blk_name.find(prefix.first);
|
||||||
|
if (match != std::string::npos) {
|
||||||
|
std::string split_blk = "SPLIT_L|" + blk_name.substr(0, match) + prefix.second;
|
||||||
|
keys.push_back(split_blk);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
keys.push_back(blk_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> ret;
|
||||||
|
for (std::string& key : keys) {
|
||||||
|
ret.push_back(key);
|
||||||
|
replace_all_chars(key, '.', '_');
|
||||||
|
// fix for some sdxl lora, like lcm-lora-xl
|
||||||
|
if (key == "model_diffusion_model_output_blocks_2_2_conv") {
|
||||||
|
ret.push_back("model_diffusion_model_output_blocks_2_1_conv");
|
||||||
|
}
|
||||||
|
ret.push_back(key);
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
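To make the remapping concrete, here is the kind of candidate list the function produces for a Flux block (illustrative; the exact contents follow the code above, and the "SPLIT|"/"SPLIT_L|" entries appear only when a fused qkv(+mlp) projection may need to be matched against split LoRA tensors):

    // to_lora_keys("model.diffusion_model.double_blocks.0.img_attn.qkv.weight", VERSION_FLUX)
    // yields roughly:
    //   "model.diffusion_model.double_blocks.0.img_attn.qkv"
    //   "lora.model.diffusion_model.double_blocks.0.img_attn.qkv"
    //   "transformer.transformer_blocks.0.img_attn.qkv"   (after the rename passes)
    // plus an underscore-separated variant of each, tried in turn against
    // the LoRA file's tensor names with each convention's prefix/suffix.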
+
+   struct ggml_cgraph* build_lora_graph(std::map<std::string, struct ggml_tensor*> model_tensors, SDVersion version) {
        struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, LORA_GRAPH_SIZE, false);

        zero_index = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_I32, 1);

@@ -88,69 +249,547 @@ struct LoraModel : public GGMLRunner {
            std::string k_tensor       = it.first;
            struct ggml_tensor* weight = model_tensors[it.first];

-           size_t k_pos = k_tensor.find(".weight");
-           if (k_pos == std::string::npos) {
+           std::vector<std::string> keys = to_lora_keys(k_tensor, version);
+           if (keys.size() == 0)
+               continue;
+
+           for (auto& key : keys) {
+               bool is_qkv_split = starts_with(key, "SPLIT|");
+               if (is_qkv_split) {
+                   key = key.substr(sizeof("SPLIT|") - 1);
+               }
+               bool is_qkvm_split = starts_with(key, "SPLIT_L|");
+               if (is_qkvm_split) {
+                   key = key.substr(sizeof("SPLIT_L|") - 1);
+               }
+               struct ggml_tensor* updown = NULL;
+               float scale_value          = 1.0f;
+               std::string fk             = lora_pre[type] + key;
+               if (lora_tensors.find(fk + ".hada_w1_a") != lora_tensors.end()) {
+                   // LoHa mode
+
+                   // TODO: split qkv convention for LoHas (is it ever used?)
+                   if (is_qkv_split || is_qkvm_split) {
+                       LOG_ERROR("Split qkv isn't supported for LoHa models.");
+                       break;
+                   }
+                   std::string alpha_name = "";
+
+                   ggml_tensor* hada_1_mid  = NULL;  // tau for tucker decomposition
+                   ggml_tensor* hada_1_up   = NULL;
+                   ggml_tensor* hada_1_down = NULL;
+
+                   ggml_tensor* hada_2_mid  = NULL;  // tau for tucker decomposition
+                   ggml_tensor* hada_2_up   = NULL;
+                   ggml_tensor* hada_2_down = NULL;
+
+                   std::string hada_1_mid_name  = "";
+                   std::string hada_1_down_name = "";
+                   std::string hada_1_up_name   = "";
+
+                   std::string hada_2_mid_name  = "";
+                   std::string hada_2_down_name = "";
+                   std::string hada_2_up_name   = "";
+
+                   hada_1_down_name = fk + ".hada_w1_b";
+                   hada_1_up_name   = fk + ".hada_w1_a";
+                   hada_1_mid_name  = fk + ".hada_t1";
+                   if (lora_tensors.find(hada_1_down_name) != lora_tensors.end()) {
+                       hada_1_down = to_f32(compute_ctx, lora_tensors[hada_1_down_name]);
+                   }
+                   if (lora_tensors.find(hada_1_up_name) != lora_tensors.end()) {
+                       hada_1_up = to_f32(compute_ctx, lora_tensors[hada_1_up_name]);
+                   }
+                   if (lora_tensors.find(hada_1_mid_name) != lora_tensors.end()) {
+                       hada_1_mid = to_f32(compute_ctx, lora_tensors[hada_1_mid_name]);
+                       applied_lora_tensors.insert(hada_1_mid_name);
+                       hada_1_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_1_up));
+                   }
+
+                   hada_2_down_name = fk + ".hada_w2_b";
+                   hada_2_up_name   = fk + ".hada_w2_a";
+                   hada_2_mid_name  = fk + ".hada_t2";
+                   if (lora_tensors.find(hada_2_down_name) != lora_tensors.end()) {
+                       hada_2_down = to_f32(compute_ctx, lora_tensors[hada_2_down_name]);
+                   }
+                   if (lora_tensors.find(hada_2_up_name) != lora_tensors.end()) {
+                       hada_2_up = to_f32(compute_ctx, lora_tensors[hada_2_up_name]);
+                   }
+                   if (lora_tensors.find(hada_2_mid_name) != lora_tensors.end()) {
+                       hada_2_mid = to_f32(compute_ctx, lora_tensors[hada_2_mid_name]);
+                       applied_lora_tensors.insert(hada_2_mid_name);
+                       hada_2_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_2_up));
+                   }
+
+                   alpha_name = fk + ".alpha";
+
+                   applied_lora_tensors.insert(hada_1_down_name);
+                   applied_lora_tensors.insert(hada_1_up_name);
+                   applied_lora_tensors.insert(hada_2_down_name);
+                   applied_lora_tensors.insert(hada_2_up_name);
+
+                   applied_lora_tensors.insert(alpha_name);
+                   if (hada_1_up == NULL || hada_1_down == NULL || hada_2_up == NULL || hada_2_down == NULL) {
                        continue;
                    }
-           k_tensor = k_tensor.substr(0, k_pos);
-           replace_all_chars(k_tensor, '.', '_');
-           // LOG_DEBUG("k_tensor %s", k_tensor.c_str());
-           std::string lora_up_name = "lora." + k_tensor + ".lora_up.weight";
-           if (lora_tensors.find(lora_up_name) == lora_tensors.end()) {
-               if (k_tensor == "model_diffusion_model_output_blocks_2_2_conv") {
-                   // fix for some sdxl lora, like lcm-lora-xl
-                   k_tensor     = "model_diffusion_model_output_blocks_2_1_conv";
-                   lora_up_name = "lora." + k_tensor + ".lora_up.weight";
+                   struct ggml_tensor* updown_1 = ggml_merge_lora(compute_ctx, hada_1_down, hada_1_up, hada_1_mid);
+                   struct ggml_tensor* updown_2 = ggml_merge_lora(compute_ctx, hada_2_down, hada_2_up, hada_2_mid);
+                   updown                       = ggml_mul_inplace(compute_ctx, updown_1, updown_2);
+
+                   // calc_scale
+                   // TODO: .dora_scale?
+                   int64_t rank = hada_1_down->ne[ggml_n_dims(hada_1_down) - 1];
+                   if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
+                       float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
+                       scale_value = alpha / rank;
                    }
                }
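In math form, the LoHa branch above computes the following weight delta (a sketch; when the optional hada_t1/hada_t2 tensors are present, ggml_merge_lora folds in the Tucker core, with the up factor transposed as in the code):

    \Delta W = \frac{\alpha}{r}\,\bigl[(W_{1a} W_{1b}) \odot (W_{2a} W_{2b})\bigr]

where \odot is the element-wise (Hadamard) product and r is the rank read off the last dimension of hada_w1_b.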
+               } else if (lora_tensors.find(fk + ".lokr_w1") != lora_tensors.end() || lora_tensors.find(fk + ".lokr_w1_a") != lora_tensors.end()) {
+                   // LoKr mode
+
+                   // TODO: split qkv convention for LoKrs (is it ever used?)
+                   if (is_qkv_split || is_qkvm_split) {
+                       LOG_ERROR("Split qkv isn't supported for LoKr models.");
+                       break;
                    }
+
-                   std::string lora_down_name = "lora." + k_tensor + ".lora_down.weight";
-                   std::string alpha_name     = "lora." + k_tensor + ".alpha";
-                   std::string scale_name     = "lora." + k_tensor + ".scale";
+                   std::string alpha_name = fk + ".alpha";
+
+                   ggml_tensor* lokr_w1 = NULL;
+                   ggml_tensor* lokr_w2 = NULL;
+
+                   std::string lokr_w1_name = "";
+                   std::string lokr_w2_name = "";
+
+                   lokr_w1_name = fk + ".lokr_w1";
+                   lokr_w2_name = fk + ".lokr_w2";
+
+                   if (lora_tensors.find(lokr_w1_name) != lora_tensors.end()) {
+                       lokr_w1 = to_f32(compute_ctx, lora_tensors[lokr_w1_name]);
+                       applied_lora_tensors.insert(lokr_w1_name);
+                   } else {
+                       ggml_tensor* down     = NULL;
+                       ggml_tensor* up       = NULL;
+                       std::string down_name = lokr_w1_name + "_b";
+                       std::string up_name   = lokr_w1_name + "_a";
+                       if (lora_tensors.find(down_name) != lora_tensors.end()) {
+                           // w1 should not be low rank normally, sometimes w1 and w2 are swapped
+                           down = to_f32(compute_ctx, lora_tensors[down_name]);
+                           applied_lora_tensors.insert(down_name);
+
+                           int64_t rank = down->ne[ggml_n_dims(down) - 1];
+                           if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
+                               float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
+                               scale_value = alpha / rank;
+                           }
+                       }
+                       if (lora_tensors.find(up_name) != lora_tensors.end()) {
+                           up = to_f32(compute_ctx, lora_tensors[up_name]);
+                           applied_lora_tensors.insert(up_name);
+                       }
+                       lokr_w1 = ggml_merge_lora(compute_ctx, down, up);
+                   }
+                   if (lora_tensors.find(lokr_w2_name) != lora_tensors.end()) {
+                       lokr_w2 = to_f32(compute_ctx, lora_tensors[lokr_w2_name]);
+                       applied_lora_tensors.insert(lokr_w2_name);
+                   } else {
+                       ggml_tensor* down     = NULL;
+                       ggml_tensor* up       = NULL;
+                       std::string down_name = lokr_w2_name + "_b";
+                       std::string up_name   = lokr_w2_name + "_a";
+                       if (lora_tensors.find(down_name) != lora_tensors.end()) {
+                           down = to_f32(compute_ctx, lora_tensors[down_name]);
+                           applied_lora_tensors.insert(down_name);
+
+                           int64_t rank = down->ne[ggml_n_dims(down) - 1];
+                           if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
+                               float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
+                               scale_value = alpha / rank;
+                           }
+                       }
+                       if (lora_tensors.find(up_name) != lora_tensors.end()) {
+                           up = to_f32(compute_ctx, lora_tensors[up_name]);
+                           applied_lora_tensors.insert(up_name);
+                       }
+                       lokr_w2 = ggml_merge_lora(compute_ctx, down, up);
+                   }
+
+                   // Technically it might be unused, but I believe it's the expected behavior
+                   applied_lora_tensors.insert(alpha_name);
+
+                   updown = ggml_kronecker(compute_ctx, lokr_w1, lokr_w2);
+
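LoKr instead composes the delta as a Kronecker product of two factors, either of which may itself be stored in factored low-rank form ("_a"/"_b" pairs):

    \Delta W = s\,(W_1 \otimes W_2)

where, matching the code above, the scale s = \alpha / r is only picked up when a factor is found in its low-rank form; otherwise s stays 1.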
+               } else {
+                   // LoRA mode
+                   ggml_tensor* lora_mid  = NULL;  // tau for tucker decomposition
                    ggml_tensor* lora_up   = NULL;
                    ggml_tensor* lora_down = NULL;

+                   std::string alpha_name         = "";
+                   std::string scale_name         = "";
+                   std::string split_q_scale_name = "";
+                   std::string lora_mid_name      = "";
+                   std::string lora_down_name     = "";
+                   std::string lora_up_name       = "";
+
+                   if (is_qkv_split) {
+                       std::string suffix = "";
+                       auto split_q_d_name = fk + "q" + suffix + lora_downs[type] + ".weight";
+
+                       if (lora_tensors.find(split_q_d_name) == lora_tensors.end()) {
+                           suffix         = "_proj";
+                           split_q_d_name = fk + "q" + suffix + lora_downs[type] + ".weight";
+                       }
+                       if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
+                           // print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1]
+                           // find qkv and mlp up parts in LoRA model
+                           auto split_k_d_name = fk + "k" + suffix + lora_downs[type] + ".weight";
+                           auto split_v_d_name = fk + "v" + suffix + lora_downs[type] + ".weight";
+
+                           auto split_q_u_name = fk + "q" + suffix + lora_ups[type] + ".weight";
+                           auto split_k_u_name = fk + "k" + suffix + lora_ups[type] + ".weight";
+                           auto split_v_u_name = fk + "v" + suffix + lora_ups[type] + ".weight";
+
+                           auto split_q_scale_name = fk + "q" + suffix + ".scale";
+                           auto split_k_scale_name = fk + "k" + suffix + ".scale";
+                           auto split_v_scale_name = fk + "v" + suffix + ".scale";
+
+                           auto split_q_alpha_name = fk + "q" + suffix + ".alpha";
+                           auto split_k_alpha_name = fk + "k" + suffix + ".alpha";
+                           auto split_v_alpha_name = fk + "v" + suffix + ".alpha";
+
+                           ggml_tensor* lora_q_down = NULL;
+                           ggml_tensor* lora_q_up   = NULL;
+                           ggml_tensor* lora_k_down = NULL;
+                           ggml_tensor* lora_k_up   = NULL;
+                           ggml_tensor* lora_v_down = NULL;
+                           ggml_tensor* lora_v_up   = NULL;
+
+                           lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]);
+
+                           if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) {
+                               lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
+                           }
+
+                           if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) {
+                               lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]);
+                           }
+
+                           if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) {
+                               lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]);
+                           }
+
+                           if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) {
+                               lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]);
+                           }
+
+                           if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) {
+                               lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]);
+                           }
+
+                           float q_rank = lora_q_up->ne[0];
+                           float k_rank = lora_k_up->ne[0];
+                           float v_rank = lora_v_up->ne[0];
+
+                           float lora_q_scale = 1;
+                           float lora_k_scale = 1;
+                           float lora_v_scale = 1;
+
+                           if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) {
+                               lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]);
+                               applied_lora_tensors.insert(split_q_scale_name);
+                           }
+                           if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) {
+                               lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]);
+                               applied_lora_tensors.insert(split_k_scale_name);
+                           }
+                           if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) {
+                               lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]);
+                               applied_lora_tensors.insert(split_v_scale_name);
+                           }
+
+                           if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) {
+                               float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]);
+                               applied_lora_tensors.insert(split_q_alpha_name);
+                               lora_q_scale = lora_q_alpha / q_rank;
+                           }
+                           if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) {
+                               float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]);
+                               applied_lora_tensors.insert(split_k_alpha_name);
+                               lora_k_scale = lora_k_alpha / k_rank;
+                           }
+                           if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) {
+                               float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]);
+                               applied_lora_tensors.insert(split_v_alpha_name);
+                               lora_v_scale = lora_v_alpha / v_rank;
+                           }
+
+                           ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale);
+                           ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale);
+                           ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale);
+
+                           // print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1]
+                           // print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1]
+                           // print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1]
+                           // print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1]
+                           // print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1]
+                           // print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1]
+
+                           // these need to be stitched together this way:
+                           // |q_up,0   ,0   |
+                           // |0   ,k_up,0   |
+                           // |0   ,0   ,v_up|
+                           // (q_down,k_down,v_down) . (q   ,k   ,v)
+
+                           // up_concat will be [9216, R*3, 1, 1]
+                           // down_concat will be [R*3, 3072, 1, 1]
+                           ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), lora_v_down, 1);
+
+                           ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up);
+                           ggml_scale(compute_ctx, z, 0);
+                           ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1);
+
+                           ggml_tensor* q_up = ggml_concat(compute_ctx, lora_q_up, zz, 1);
+                           ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), z, 1);
+                           ggml_tensor* v_up = ggml_concat(compute_ctx, zz, lora_v_up, 1);
+                           // print_ggml_tensor(q_up, true); //[R, 9216, 1, 1]
+                           // print_ggml_tensor(k_up, true); //[R, 9216, 1, 1]
+                           // print_ggml_tensor(v_up, true); //[R, 9216, 1, 1]
+                           ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), v_up, 0);
+                           // print_ggml_tensor(lora_up_concat, true); //[R*3, 9216, 1, 1]
+
+                           lora_down = ggml_cont(compute_ctx, lora_down_concat);
+                           lora_up   = ggml_cont(compute_ctx, lora_up_concat);
+
+                           applied_lora_tensors.insert(split_q_u_name);
+                           applied_lora_tensors.insert(split_k_u_name);
+                           applied_lora_tensors.insert(split_v_u_name);
+
+                           applied_lora_tensors.insert(split_q_d_name);
+                           applied_lora_tensors.insert(split_k_d_name);
+                           applied_lora_tensors.insert(split_v_d_name);
+                       }
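The zero-padding above builds a block-diagonal up factor, so a single matrix product patches the fused qkv weight in one pass; schematically (a sketch of the intent, not code from the commit):

    \Delta W_{qkv} =
    \begin{pmatrix} B_q & 0 & 0 \\ 0 & B_k & 0 \\ 0 & 0 & B_v \end{pmatrix}
    \begin{pmatrix} A_q \\ A_k \\ A_v \end{pmatrix}

with the A_* the (row-concatenated) lora_down factors and the B_* the lora_up factors.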
+                   } else if (is_qkvm_split) {
+                       auto split_q_d_name = fk + "attn.to_q" + lora_downs[type] + ".weight";
+                       if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
+                           // print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1]
+                           // find qkv and mlp up parts in LoRA model
+                           auto split_k_d_name = fk + "attn.to_k" + lora_downs[type] + ".weight";
+                           auto split_v_d_name = fk + "attn.to_v" + lora_downs[type] + ".weight";
+
+                           auto split_q_u_name = fk + "attn.to_q" + lora_ups[type] + ".weight";
+                           auto split_k_u_name = fk + "attn.to_k" + lora_ups[type] + ".weight";
+                           auto split_v_u_name = fk + "attn.to_v" + lora_ups[type] + ".weight";
+
+                           auto split_m_d_name = fk + "proj_mlp" + lora_downs[type] + ".weight";
+                           auto split_m_u_name = fk + "proj_mlp" + lora_ups[type] + ".weight";
+
+                           auto split_q_scale_name = fk + "attn.to_q" + ".scale";
+                           auto split_k_scale_name = fk + "attn.to_k" + ".scale";
+                           auto split_v_scale_name = fk + "attn.to_v" + ".scale";
+                           auto split_m_scale_name = fk + "proj_mlp" + ".scale";
+
+                           auto split_q_alpha_name = fk + "attn.to_q" + ".alpha";
+                           auto split_k_alpha_name = fk + "attn.to_k" + ".alpha";
+                           auto split_v_alpha_name = fk + "attn.to_v" + ".alpha";
+                           auto split_m_alpha_name = fk + "proj_mlp" + ".alpha";
+
+                           ggml_tensor* lora_q_down = NULL;
+                           ggml_tensor* lora_q_up   = NULL;
+                           ggml_tensor* lora_k_down = NULL;
+                           ggml_tensor* lora_k_up   = NULL;
+                           ggml_tensor* lora_v_down = NULL;
+                           ggml_tensor* lora_v_up   = NULL;
+
+                           ggml_tensor* lora_m_down = NULL;
+                           ggml_tensor* lora_m_up   = NULL;
+
+                           lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
+
+                           if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
+                               lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]);
+                           }
+
+                           if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) {
+                               lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
+                           }
+
+                           if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) {
+                               lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]);
+                           }
+
+                           if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) {
+                               lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]);
+                           }
+
+                           if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) {
+                               lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]);
+                           }
+
+                           if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) {
+                               lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]);
+                           }
+
+                           if (lora_tensors.find(split_m_d_name) != lora_tensors.end()) {
+                               lora_m_down = to_f32(compute_ctx, lora_tensors[split_m_d_name]);
+                           }
+
+                           if (lora_tensors.find(split_m_u_name) != lora_tensors.end()) {
+                               lora_m_up = to_f32(compute_ctx, lora_tensors[split_m_u_name]);
+                           }
+
+                           float q_rank = lora_q_up->ne[0];
+                           float k_rank = lora_k_up->ne[0];
+                           float v_rank = lora_v_up->ne[0];
+                           float m_rank = lora_v_up->ne[0];
+
+                           float lora_q_scale = 1;
+                           float lora_k_scale = 1;
+                           float lora_v_scale = 1;
+                           float lora_m_scale = 1;
+
+                           if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) {
+                               lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]);
+                               applied_lora_tensors.insert(split_q_scale_name);
+                           }
+                           if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) {
+                               lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]);
+                               applied_lora_tensors.insert(split_k_scale_name);
+                           }
+                           if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) {
+                               lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]);
+                               applied_lora_tensors.insert(split_v_scale_name);
+                           }
+                           if (lora_tensors.find(split_m_scale_name) != lora_tensors.end()) {
+                               lora_m_scale = ggml_backend_tensor_get_f32(lora_tensors[split_m_scale_name]);
+                               applied_lora_tensors.insert(split_m_scale_name);
+                           }
+
+                           if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) {
+                               float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]);
+                               applied_lora_tensors.insert(split_q_alpha_name);
+                               lora_q_scale = lora_q_alpha / q_rank;
+                           }
+                           if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) {
+                               float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]);
+                               applied_lora_tensors.insert(split_k_alpha_name);
+                               lora_k_scale = lora_k_alpha / k_rank;
+                           }
+                           if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) {
+                               float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]);
+                               applied_lora_tensors.insert(split_v_alpha_name);
+                               lora_v_scale = lora_v_alpha / v_rank;
+                           }
+                           if (lora_tensors.find(split_m_alpha_name) != lora_tensors.end()) {
+                               float lora_m_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_m_alpha_name]);
+                               applied_lora_tensors.insert(split_m_alpha_name);
+                               lora_m_scale = lora_m_alpha / m_rank;
+                           }
+
+                           ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale);
+                           ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale);
+                           ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale);
+                           ggml_scale_inplace(compute_ctx, lora_m_down, lora_m_scale);
+
+                           // print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1]
+                           // print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1]
+                           // print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1]
+                           // print_ggml_tensor(lora_m_down, true); //[3072, R, 1, 1]
+                           // print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1]
+                           // print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1]
+                           // print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1]
+                           // print_ggml_tensor(lora_m_up, true); //[R, 12288, 1, 1]
+
+                           // these need to be stitched together this way:
+                           // |q_up,0   ,0   ,0   |
+                           // |0   ,k_up,0   ,0   |
+                           // |0   ,0   ,v_up,0   |
+                           // |0   ,0   ,0   ,m_up|
+                           // (q_down,k_down,v_down,m_down) . (q   ,k   ,v   ,m)
+
+                           // up_concat will be [21504, R*4, 1, 1]
+                           // down_concat will be [R*4, 3072, 1, 1]
+
+                           ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), ggml_concat(compute_ctx, lora_v_down, lora_m_down, 1), 1);
+                           // print_ggml_tensor(lora_down_concat, true); //[3072, R*4, 1, 1]
+
+                           // this also means that if rank is bigger than 672, it is less memory efficient to do it this way (should be fine)
+                           // print_ggml_tensor(lora_q_up, true); //[3072, R, 1, 1]
+                           ggml_tensor* z     = ggml_dup_tensor(compute_ctx, lora_q_up);
+                           ggml_tensor* mlp_z = ggml_dup_tensor(compute_ctx, lora_m_up);
+                           ggml_scale(compute_ctx, z, 0);
+                           ggml_scale(compute_ctx, mlp_z, 0);
+                           ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1);
+
+                           ggml_tensor* q_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_up, zz, 1), mlp_z, 1);
+                           ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), ggml_concat(compute_ctx, z, mlp_z, 1), 1);
+                           ggml_tensor* v_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, lora_v_up, 1), mlp_z, 1);
+                           ggml_tensor* m_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, z, 1), lora_m_up, 1);
+                           // print_ggml_tensor(q_up, true); //[R, 21504, 1, 1]
+                           // print_ggml_tensor(k_up, true); //[R, 21504, 1, 1]
+                           // print_ggml_tensor(v_up, true); //[R, 21504, 1, 1]
+                           // print_ggml_tensor(m_up, true); //[R, 21504, 1, 1]
+
+                           ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), ggml_concat(compute_ctx, v_up, m_up, 0), 0);
+                           // print_ggml_tensor(lora_up_concat, true); //[R*4, 21504, 1, 1]
+
+                           lora_down = ggml_cont(compute_ctx, lora_down_concat);
+                           lora_up   = ggml_cont(compute_ctx, lora_up_concat);
+
+                           applied_lora_tensors.insert(split_q_u_name);
+                           applied_lora_tensors.insert(split_k_u_name);
+                           applied_lora_tensors.insert(split_v_u_name);
+                           applied_lora_tensors.insert(split_m_u_name);
+
+                           applied_lora_tensors.insert(split_q_d_name);
+                           applied_lora_tensors.insert(split_k_d_name);
+                           applied_lora_tensors.insert(split_v_d_name);
+                           applied_lora_tensors.insert(split_m_d_name);
+                       }
+                   } else {
+                       lora_up_name   = fk + lora_ups[type] + ".weight";
+                       lora_down_name = fk + lora_downs[type] + ".weight";
+                       lora_mid_name  = fk + ".lora_mid.weight";
+
+                       alpha_name = fk + ".alpha";
+                       scale_name = fk + ".scale";
+
                        if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
-                           lora_up = lora_tensors[lora_up_name];
+                           lora_up = to_f32(compute_ctx, lora_tensors[lora_up_name]);
                        }

                        if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
-                           lora_down = lora_tensors[lora_down_name];
+                           lora_down = to_f32(compute_ctx, lora_tensors[lora_down_name]);
                        }

-                       if (lora_up == NULL || lora_down == NULL) {
-                           continue;
+                       if (lora_tensors.find(lora_mid_name) != lora_tensors.end()) {
+                           lora_mid = to_f32(compute_ctx, lora_tensors[lora_mid_name]);
+                           applied_lora_tensors.insert(lora_mid_name);
                        }

                        applied_lora_tensors.insert(lora_up_name);
                        applied_lora_tensors.insert(lora_down_name);
                        applied_lora_tensors.insert(alpha_name);
                        applied_lora_tensors.insert(scale_name);
+                   }
+
-                   // calc_cale
-                   int64_t dim       = lora_down->ne[ggml_n_dims(lora_down) - 1];
-                   float scale_value = 1.0f;
+                   if (lora_up == NULL || lora_down == NULL) {
+                       continue;
+                   }
+                   // calc_scale
+                   // TODO: .dora_scale?
+                   int64_t rank = lora_down->ne[ggml_n_dims(lora_down) - 1];
                    if (lora_tensors.find(scale_name) != lora_tensors.end()) {
                        scale_value = ggml_backend_tensor_get_f32(lora_tensors[scale_name]);
                    } else if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
                        float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
-                       scale_value = alpha / dim;
+                       scale_value = alpha / rank;
+                   }
+
+                   updown = ggml_merge_lora(compute_ctx, lora_down, lora_up, lora_mid);
                }
                scale_value *= multiplier;

-               // flat lora tensors to multiply it
-               int64_t lora_up_rows = lora_up->ne[ggml_n_dims(lora_up) - 1];
-               lora_up              = ggml_reshape_2d(compute_ctx, lora_up, ggml_nelements(lora_up) / lora_up_rows, lora_up_rows);
-               auto lora_down_n_dims = ggml_n_dims(lora_down);
-               // assume n_dims should always be a multiple of 2 (otherwise rank 1 doesn't work)
-               lora_down_n_dims = (lora_down_n_dims + lora_down_n_dims % 2);
-               int64_t lora_down_rows = lora_down->ne[lora_down_n_dims - 1];
-
-               // ggml_mul_mat requires tensor b transposed
-               lora_down                  = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, lora_down));
-               struct ggml_tensor* updown = ggml_mul_mat(compute_ctx, lora_up, lora_down);
-               updown                     = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, updown));
                updown                     = ggml_reshape(compute_ctx, updown, weight);
                GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight));
                updown = ggml_scale_inplace(compute_ctx, updown, scale_value);
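Whichever branch produced updown, the patch finally applied per tensor is the same; with m the user-set multiplier and s the per-tensor scale (from .scale, \alpha/r, or 1):

    W' = W + m \cdot s \cdot \Delta W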
@@ -166,15 +805,18 @@ struct LoraModel : public GGMLRunner {
                }
                // final_weight = ggml_add_inplace(compute_ctx, weight, updown); // apply directly
                ggml_build_forward_expand(gf, final_weight);
+               break;
+           }
        }

        size_t total_lora_tensors_count   = 0;
        size_t applied_lora_tensors_count = 0;

        for (auto& kv : lora_tensors) {
            total_lora_tensors_count++;
            if (applied_lora_tensors.find(kv.first) == applied_lora_tensors.end()) {
-               LOG_WARN("unused lora tensor %s", kv.first.c_str());
+               LOG_WARN("unused lora tensor |%s|", kv.first.c_str());
+               print_ggml_tensor(kv.second, true);
+               // exit(0);
            } else {
                applied_lora_tensors_count++;
            }

@@ -193,9 +835,9 @@ struct LoraModel : public GGMLRunner {
        return gf;
    }

-   void apply(std::map<std::string, struct ggml_tensor*> model_tensors, int n_threads) {
+   void apply(std::map<std::string, struct ggml_tensor*> model_tensors, SDVersion version, int n_threads) {
        auto get_graph = [&]() -> struct ggml_cgraph* {
-           return build_lora_graph(model_tensors);
+           return build_lora_graph(model_tensors, version);
        };
        GGMLRunner::compute(get_graph, n_threads, true);
    }
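A caller-side sketch of the changed signature (names and surrounding setup are hypothetical; the real call sites live in stable-diffusion.cpp and are not part of this diff):

    // The model version is now threaded through so build_lora_graph()
    // can remap DiT tensor names via to_lora_keys().
    LoraModel lora(backend, "/path/to/lora.safetensors");  // illustrative constructor
    lora.multiplier = 0.8f;
    lora.apply(model_tensors, VERSION_FLUX, n_threads);    // was: apply(model_tensors, n_threads)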

@@ -11,6 +11,7 @@
#include "stable-diffusion.h"

#define STB_IMAGE_IMPLEMENTATION
+#define STB_IMAGE_STATIC
#include "stb_image.h"

#define STB_IMAGE_WRITE_IMPLEMENTATION
@@ -18,6 +19,7 @@
#include "stb_image_write.h"

#define STB_IMAGE_RESIZE_IMPLEMENTATION
+#define STB_IMAGE_RESIZE_STATIC
#include "stb_image_resize.h"

const char* rng_type_to_str[] = {
@@ -37,6 +39,8 @@ const char* sample_method_str[] = {
    "ipndm",
    "ipndm_v",
    "lcm",
+   "ddim_trailing",
+   "tcd",
};

// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
@@ -83,6 +87,7 @@ struct SDParams {
    std::string lora_model_dir;
    std::string output_path = "output.png";
    std::string input_path;
+   std::string mask_path;
    std::string control_image_path;

    std::string prompt;
@@ -90,6 +95,7 @@ struct SDParams {
    float min_cfg     = 1.0f;
    float cfg_scale   = 7.0f;
    float guidance    = 3.5f;
+   float eta         = 0.f;
    float style_ratio = 20.f;
    int clip_skip     = -1;  // <= 0 represents unspecified
    int width         = 512;
@@ -120,9 +126,9 @@ struct SDParams {
    int upscale_repeats = 1;

    std::vector<int> skip_layers = {7, 8, 9};
-   float slg_scale        = 0.;
-   float skip_layer_start = 0.01;
-   float skip_layer_end   = 0.2;
+   float slg_scale        = 0.f;
+   float skip_layer_start = 0.01f;
+   float skip_layer_end   = 0.2f;
};

void print_params(SDParams params) {
@@ -146,6 +152,7 @@ void print_params(SDParams params) {
    printf("    normalize input image : %s\n", params.normalize_input ? "true" : "false");
    printf("    output_path: %s\n", params.output_path.c_str());
    printf("    init_img: %s\n", params.input_path.c_str());
+   printf("    mask_img: %s\n", params.mask_path.c_str());
    printf("    control_image: %s\n", params.control_image_path.c_str());
    printf("    clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
    printf("    controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
@@ -158,6 +165,7 @@ void print_params(SDParams params) {
    printf("    cfg_scale: %.2f\n", params.cfg_scale);
    printf("    slg_scale: %.2f\n", params.slg_scale);
    printf("    guidance: %.2f\n", params.guidance);
+   printf("    eta: %.2f\n", params.eta);
    printf("    clip_skip: %d\n", params.clip_skip);
    printf("    width: %d\n", params.width);
    printf("    height: %d\n", params.height);
@@ -198,16 +206,19 @@ void print_usage(int argc, const char* argv[]) {
    printf("  If not specified, the default is the type of the weight file\n");
    printf("  --lora-model-dir [DIR]  lora model directory\n");
    printf("  -i, --init-img [IMAGE]  path to the input image, required by img2img\n");
+   printf("  --mask [MASK]  path to the mask image, required by img2img with mask\n");
    printf("  --control-image [IMAGE]  path to image condition, control net\n");
    printf("  -o, --output OUTPUT  path to write result image to (default: ./output.png)\n");
    printf("  -p, --prompt [PROMPT]  the prompt to render\n");
    printf("  -n, --negative-prompt PROMPT  the negative prompt (default: \"\")\n");
    printf("  --cfg-scale SCALE  unconditional guidance scale: (default: 7.0)\n");
+   printf("  --guidance SCALE  guidance scale for img2img (default: 3.5)\n");
    printf("  --slg-scale SCALE  skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
    printf("  0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
-   printf("  --skip_layers LAYERS  Layers to skip for SLG steps: (default: [7,8,9])\n");
-   printf("  --skip_layer_start START  SLG enabling point: (default: 0.01)\n");
-   printf("  --skip_layer_end END  SLG disabling point: (default: 0.2)\n");
+   printf("  --eta SCALE  eta in DDIM, only for DDIM and TCD: (default: 0)\n");
+   printf("  --skip-layers LAYERS  Layers to skip for SLG steps: (default: [7,8,9])\n");
+   printf("  --skip-layer-start START  SLG enabling point: (default: 0.01)\n");
+   printf("  --skip-layer-end END  SLG disabling point: (default: 0.2)\n");
    printf("  SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n");
    printf("  --strength STRENGTH  strength for noising/unnoising (default: 0.75)\n");
    printf("  --style-ratio STYLE-RATIO  strength for keeping input identity (default: 20%%)\n");
@@ -215,7 +226,7 @@ void print_usage(int argc, const char* argv[]) {
    printf("  1.0 corresponds to full destruction of information in init image\n");
    printf("  -H, --height H  image height, in pixel space (default: 512)\n");
    printf("  -W, --width W  image width, in pixel space (default: 512)\n");
-   printf("  --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm}\n");
+   printf("  --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n");
    printf("  sampling method (default: \"euler_a\")\n");
    printf("  --steps STEPS  number of sample steps (default: 20)\n");
    printf("  --rng {std_default, cuda}  RNG (default: cuda)\n");
@@ -382,6 +393,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
                break;
            }
            params.input_path = argv[i];
+       } else if (arg == "--mask") {
+           if (++i >= argc) {
+               invalid_arg = true;
+               break;
+           }
+           params.mask_path = argv[i];
        } else if (arg == "--control-image") {
            if (++i >= argc) {
                invalid_arg = true;
@@ -428,6 +445,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
                break;
            }
            params.guidance = std::stof(argv[i]);
+       } else if (arg == "--eta") {
+           if (++i >= argc) {
+               invalid_arg = true;
+               break;
+           }
+           params.eta = std::stof(argv[i]);
        } else if (arg == "--strength") {
            if (++i >= argc) {
                invalid_arg = true;
@@ -707,6 +730,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
        parameter_string += "Skip layer end: " + std::to_string(params.skip_layer_end) + ", ";
    }
    parameter_string += "Guidance: " + std::to_string(params.guidance) + ", ";
+   parameter_string += "Eta: " + std::to_string(params.eta) + ", ";
    parameter_string += "Seed: " + std::to_string(seed) + ", ";
    parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
    parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
@@ -801,6 +825,8 @@ int main(int argc, const char* argv[]) {
    bool vae_decode_only        = true;
    uint8_t* input_image_buffer = NULL;
    uint8_t* control_image_buffer = NULL;
+   uint8_t* mask_image_buffer  = NULL;

    if (params.mode == IMG2IMG || params.mode == IMG2VID) {
        vae_decode_only = false;

@@ -905,6 +931,18 @@ int main(int argc, const char* argv[]) {
        }
    }

+   std::vector<uint8_t> default_mask_image_vec(params.width * params.height, 255);
+   if (params.mask_path != "") {
+       int c             = 0;
+       mask_image_buffer = stbi_load(params.mask_path.c_str(), &params.width, &params.height, &c, 1);
+   } else {
+       mask_image_buffer = default_mask_image_vec.data();
+   }
+   sd_image_t mask_image = {(uint32_t)params.width,
+                            (uint32_t)params.height,
+                            1,
+                            mask_image_buffer};
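With these defaults, plain img2img becomes the fully-masked special case of inpainting: the all-255 mask marks every pixel as repaintable, so runs without --mask behave as before. An illustrative invocation (flag names per the usage text above; the binary name may differ):

    sd --mode img2img -i init.png --mask mask.png -p "a cozy cabin" --strength 0.75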

    sd_image_t* results;
    if (params.mode == TXT2IMG) {
        results = txt2img(sd_ctx,
@@ -913,6 +951,7 @@ int main(int argc, const char* argv[]) {
                          params.clip_skip,
                          params.cfg_scale,
                          params.guidance,
+                         params.eta,
                          params.width,
                          params.height,
                          params.sample_method,
@@ -974,11 +1013,13 @@ int main(int argc, const char* argv[]) {
        } else {
            results = img2img(sd_ctx,
                              input_image,
+                             mask_image,
                              params.prompt.c_str(),
                              params.negative_prompt.c_str(),
                              params.clip_skip,
                              params.cfg_scale,
                              params.guidance,
+                             params.eta,
                              params.width,
                              params.height,
                              params.sample_method,
@@ -1032,16 +1073,41 @@ int main(int argc, const char* argv[]) {
        }
    }

+   std::string dummy_name, ext, lc_ext;
+   bool is_jpg;
    size_t last = params.output_path.find_last_of(".");
-   std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path;
+   size_t last_path = std::min(params.output_path.find_last_of("/"),
+                               params.output_path.find_last_of("\\"));
+   if (last != std::string::npos  // filename has extension
+       && (last_path == std::string::npos || last > last_path)) {
+       dummy_name = params.output_path.substr(0, last);
+       ext = lc_ext = params.output_path.substr(last);
+       std::transform(ext.begin(), ext.end(), lc_ext.begin(), ::tolower);
+       is_jpg = lc_ext == ".jpg" || lc_ext == ".jpeg" || lc_ext == ".jpe";
+   } else {
+       dummy_name = params.output_path;
+       ext = lc_ext = "";
+       is_jpg       = false;
+   }
+   // appending ".png" to absent or unknown extension
+   if (!is_jpg && lc_ext != ".png") {
+       dummy_name += ext;
+       ext = ".png";
+   }
    for (int i = 0; i < params.batch_count; i++) {
        if (results[i].data == NULL) {
            continue;
        }
-       std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ".png" : dummy_name + ".png";
+       std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ext : dummy_name + ext;
+       if (is_jpg) {
+           stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
+                          results[i].data, 90, get_image_params(params, params.seed + i).c_str());
+           printf("save result JPEG image to '%s'\n", final_image_path.c_str());
+       } else {
            stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
                           results[i].data, 0, get_image_params(params, params.seed + i).c_str());
-           printf("save result image to '%s'\n", final_image_path.c_str());
+           printf("save result PNG image to '%s'\n", final_image_path.c_str());
+       }
        free(results[i].data);
        results[i].data = NULL;
    }
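Tracing the new logic: `-o out.jpg` now writes a JPEG at quality 90, an unrecognized extension such as `-o out.webp` is kept with `.png` appended (producing `out.webp.png`), and a batch of three with `-o img.png` yields `img.png`, `img_2.png`, `img_3.png`.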

@@ -572,6 +572,26 @@ std::string convert_tensor_name(std::string name) {
    return new_name;
}

+void add_preprocess_tensor_storage_types(std::map<std::string, enum ggml_type>& tensor_storages_types, std::string name, enum ggml_type type) {
+   std::string new_name = convert_tensor_name(name);
+
+   if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_weight")) {
+       size_t prefix_size = new_name.find("attn.in_proj_weight");
+       std::string prefix = new_name.substr(0, prefix_size);
+       tensor_storages_types[prefix + "self_attn.q_proj.weight"] = type;
+       tensor_storages_types[prefix + "self_attn.k_proj.weight"] = type;
+       tensor_storages_types[prefix + "self_attn.v_proj.weight"] = type;
+   } else if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_bias")) {
+       size_t prefix_size = new_name.find("attn.in_proj_bias");
+       std::string prefix = new_name.substr(0, prefix_size);
+       tensor_storages_types[prefix + "self_attn.q_proj.bias"] = type;
+       tensor_storages_types[prefix + "self_attn.k_proj.bias"] = type;
+       tensor_storages_types[prefix + "self_attn.v_proj.bias"] = type;
+   } else {
+       tensor_storages_types[new_name] = type;
+   }
+}
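A quick illustration of the new splitting (the layer path is illustrative, and convert_tensor_name also rewrites the prefix, which is elided here):

    std::map<std::string, enum ggml_type> types;
    add_preprocess_tensor_storage_types(
        types, "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight",
        GGML_TYPE_F16);
    // Instead of one fused in_proj entry, `types` now maps the converted
    // prefix's self_attn.q_proj.weight / k_proj.weight / v_proj.weight,
    // each to GGML_TYPE_F16.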

void preprocess_tensor(TensorStorage tensor_storage,
                       std::vector<TensorStorage>& processed_tensor_storages) {
    std::vector<TensorStorage> result;

@@ -942,7 +962,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
        GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes());

        tensor_storages.push_back(tensor_storage);
-       tensor_storages_types[tensor_storage.name] = tensor_storage.type;
+       add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
    }

    gguf_free(ctx_gguf_);

@@ -1087,7 +1107,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
        }

        tensor_storages.push_back(tensor_storage);
-       tensor_storages_types[tensor_storage.name] = tensor_storage.type;
+       add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);

        // LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str());
    }

@@ -1418,7 +1438,7 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
        // printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str());
        reader.tensor_storage.name = prefix + reader.tensor_storage.name;
        tensor_storages.push_back(reader.tensor_storage);
-       tensor_storages_types[reader.tensor_storage.name] = reader.tensor_storage.type;
+       add_preprocess_tensor_storage_types(tensor_storages_types, reader.tensor_storage.name, reader.tensor_storage.type);

        // LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
        // reset

@@ -1483,24 +1503,49 @@ bool ModelLoader::has_diffusion_model_tensors()
}

SDVersion ModelLoader::get_sd_version() {
-   TensorStorage token_embedding_weight;
+   TensorStorage token_embedding_weight, input_block_weight;
+   bool input_block_checked = false;
+
+   bool has_multiple_encoders = false;
+   bool is_unet               = false;
+
+   bool is_xl   = false;
+   bool is_flux = false;
+
+#define found_family (is_xl || is_flux)
    for (auto& tensor_storage : tensor_storages) {
+       if (!found_family) {
            if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
-               return VERSION_FLUX;
+               is_flux = true;
+               if (input_block_checked) {
+                   break;
+               }
            }
            if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
                return VERSION_SD3;
            }
-           if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) {
-               return VERSION_SDXL;
+           if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
+               is_unet = true;
+               if (has_multiple_encoders) {
+                   is_xl = true;
+                   if (input_block_checked) {
+                       break;
+                   }
+               }
+           }
+           if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
+               has_multiple_encoders = true;
+               if (is_unet) {
+                   is_xl = true;
+                   if (input_block_checked) {
+                       break;
+                   }
+               }
            }
-           if (tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
-               return VERSION_SDXL;
-           }
            if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
                return VERSION_SVD;
            }
+       }
        if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
            tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
            tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
@@ -1510,11 +1555,39 @@ SDVersion ModelLoader::get_sd_version() {
            token_embedding_weight = tensor_storage;
            // break;
        }
+       if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight") {
+           input_block_weight  = tensor_storage;
+           input_block_checked = true;
+           if (found_family) {
+               break;
+           }
+       }
+   }
+   bool is_inpaint = input_block_weight.ne[2] == 9;
+   if (is_xl) {
+       if (is_inpaint) {
+           return VERSION_SDXL_INPAINT;
+       }
+       return VERSION_SDXL;
+   }
+
+   if (is_flux) {
+       is_inpaint = input_block_weight.ne[0] == 384;
+       if (is_inpaint) {
+           return VERSION_FLUX_FILL;
+       }
+       return VERSION_FLUX;
    }

    if (token_embedding_weight.ne[0] == 768) {
+       if (is_inpaint) {
+           return VERSION_SD1_INPAINT;
+       }
        return VERSION_SD1;
    } else if (token_embedding_weight.ne[0] == 1024) {
+       if (is_inpaint) {
+           return VERSION_SD2_INPAINT;
+       }
        return VERSION_SD2;
    }
    return VERSION_COUNT;
|
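The probe above no longer returns early for UNet/Flux families; it records which family it saw plus the shape of the first input-block weight, and decides the inpaint variant afterwards. A minimal sketch of that final decision, assuming only the tensor shapes the diff itself inspects (the 4 + 1 + 4 = 9 channel breakdown for SD/SDXL inpaint UNets is background knowledge, not spelled out in the diff):

#include <cstdint>

enum class Family { SD1, SD2, SDXL, Flux };

// ne[] is the ggml shape of model.diffusion_model.input_blocks.0.0.weight
// (UNet families) or model.diffusion_model.img_in.weight (Flux).
bool looks_like_inpaint(Family family, const int64_t ne[4]) {
    if (family == Family::Flux) {
        return ne[0] == 384;  // Flux Fill's wider img_in projection
    }
    return ne[2] == 9;        // 4 latent + 1 mask + 4 masked-latent channels
}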
@@ -1607,11 +1680,20 @@ ggml_type ModelLoader::get_vae_wtype() {
 void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
     for (auto& pair : tensor_storages_types) {
         if (prefix.size() < 1 || pair.first.substr(0, prefix.size()) == prefix) {
+            bool found = false;
             for (auto& tensor_storage : tensor_storages) {
-                if (tensor_storage.name == pair.first) {
-                    if (tensor_should_be_converted(tensor_storage, wtype)) {
-                        pair.second = wtype;
-                    }
+                std::map<std::string, ggml_type> temp;
+                add_preprocess_tensor_storage_types(temp, tensor_storage.name, tensor_storage.type);
+                for (auto& preprocessed_name : temp) {
+                    if (preprocessed_name.first == pair.first) {
+                        if (tensor_should_be_converted(tensor_storage, wtype)) {
+                            pair.second = wtype;
+                        }
+                        found = true;
+                        break;
+                    }
+                }
+                if (found) {
                     break;
                 }
             }
         }
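The override loop now resolves each stored tensor name through the same preprocessing/renaming step used at load time before comparing it against the type-map key, so a prefix override still matches after tensors are renamed. The prefix test itself, in isolation (illustrative, not the repo's code):

#include <string>

// An empty prefix overrides every tensor's stored type; otherwise only
// keys starting with the prefix (e.g. "vae.") are overridden.
bool prefix_matches(const std::string& key, const std::string& prefix) {
    return prefix.empty() || key.compare(0, prefix.size(), prefix) == 0;
}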
@@ -1720,9 +1802,11 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
             }
             return true;
         };

+        int tensor_count = 0;
+        int64_t t1       = ggml_time_ms();
         for (auto& tensor_storage : processed_tensor_storages) {
             if (tensor_storage.file_index != file_index) {
+                ++tensor_count;
                 continue;
             }
             ggml_tensor* dst_tensor = NULL;
@@ -1734,6 +1818,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
             }

             if (dst_tensor == NULL) {
+                ++tensor_count;
                 continue;
             }

@@ -1800,6 +1885,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                     ggml_backend_tensor_set(dst_tensor, convert_buffer.data(), 0, ggml_nbytes(dst_tensor));
                 }
             }
+            int64_t t2 = ggml_time_ms();
+            pretty_progress(++tensor_count, processed_tensor_storages.size(), (t2 - t1) / 1000.0f);
+            t1 = t2;
         }

         if (zip != NULL) {
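The loading loop now counts every tensor, including ones it skips (wrong file index, no destination), so the progress bar still reaches 100%. A hypothetical minimal analogue of that per-tensor reporting:

#include <cstdio>

void report_tensor_progress(int done, int total, float seconds_since_last) {
    std::printf("\r[%d/%d] %.2fs since last tensor", done, total, seconds_since_last);
    if (done == total) std::printf("\n");
    std::fflush(stdout);
}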
@@ -1866,9 +1954,6 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
         if (pair.first.find("cond_stage_model.transformer.text_model.encoder.layers.23") != std::string::npos) {
             continue;
         }
-        if (pair.first.find("alphas_cumprod") != std::string::npos) {
-            continue;
-        }

         if (pair.first.find("alphas_cumprod") != std::string::npos) {
             continue;
@@ -14,21 +14,26 @@
 #include "ggml.h"
 #include "json.hpp"
 #include "zip.h"
+#include "gguf.h"

 #define SD_MAX_DIMS 5

 enum SDVersion {
     VERSION_SD1,
+    VERSION_SD1_INPAINT,
     VERSION_SD2,
+    VERSION_SD2_INPAINT,
     VERSION_SDXL,
+    VERSION_SDXL_INPAINT,
     VERSION_SVD,
     VERSION_SD3,
     VERSION_FLUX,
+    VERSION_FLUX_FILL,
     VERSION_COUNT,
 };

 static inline bool sd_version_is_flux(SDVersion version) {
-    if (version == VERSION_FLUX) {
+    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
         return true;
     }
     return false;
@@ -41,6 +46,34 @@ static inline bool sd_version_is_sd3(SDVersion version) {
     return false;
 }

+static inline bool sd_version_is_sd1(SDVersion version) {
+    if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT) {
+        return true;
+    }
+    return false;
+}
+
+static inline bool sd_version_is_sd2(SDVersion version) {
+    if (version == VERSION_SD2 || version == VERSION_SD2_INPAINT) {
+        return true;
+    }
+    return false;
+}
+
+static inline bool sd_version_is_sdxl(SDVersion version) {
+    if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT) {
+        return true;
+    }
+    return false;
+}
+
+static inline bool sd_version_is_inpaint(SDVersion version) {
+    if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
+        return true;
+    }
+    return false;
+}
+
 static inline bool sd_version_is_dit(SDVersion version) {
     if (sd_version_is_flux(version) || sd_version_is_sd3(version)) {
         return true;
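With the *_INPAINT values inserted between the existing entries, direct `version == VERSION_SDXL` comparisons would silently miss the inpaint variants, which is exactly what the family predicates guard against. A usage sketch (hypothetical call site, assuming model.h is included):

// Branch on the family, not on concrete enum values, so the new inpaint
// variants are covered automatically.
const char* family_name(SDVersion v) {
    if (sd_version_is_sd1(v))  return "SD 1.x";
    if (sd_version_is_sd2(v))  return "SD 2.x";
    if (sd_version_is_sdxl(v)) return "SDXL";
    if (sd_version_is_flux(v)) return "Flux";
    return "other";
}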
@@ -62,6 +62,7 @@ struct SDParams {
     std::string lora_model_dir;
     std::string output_path = "output.png";
     std::string input_path;
+    std::string mask_path;
     std::string control_image_path;

     std::string prompt;
@@ -69,6 +70,7 @@ struct SDParams {
     float min_cfg     = 1.0f;
     float cfg_scale   = 7.0f;
     float guidance    = 3.5f;
+    float eta         = 0.f;
     float style_ratio = 20.f;
     int clip_skip = -1; // <= 0 represents unspecified
     int width     = 512;
@@ -99,9 +101,9 @@ struct SDParams {
     int upscale_repeats = 1;

     std::vector<int> skip_layers = {7, 8, 9};
-    float slg_scale        = 0.;
-    float skip_layer_start = 0.01;
-    float skip_layer_end   = 0.2;
+    float slg_scale        = 0.f;
+    float skip_layer_start = 0.01f;
+    float skip_layer_end   = 0.2f;
 };

 //shared
@@ -113,6 +115,7 @@ static sd_ctx_t * sd_ctx = nullptr;
 static int sddebugmode = 0;
 static std::string recent_data = "";
 static uint8_t * input_image_buffer = NULL;
+static uint8_t * input_mask_buffer = NULL;

 static std::string sdplatformenv, sddeviceenv, sdvulkandeviceenv;
 static bool notiling = false;
@@ -317,6 +320,7 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
    std::string cleanprompt = clean_input_prompt(inputs.prompt);
    std::string cleannegprompt = clean_input_prompt(inputs.negative_prompt);
    std::string img2img_data = std::string(inputs.init_images);
+   std::string img2img_mask = "";
    std::string sampler = inputs.sample_method;

    sd_params->prompt = cleanprompt;
@@ -351,6 +355,10 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
        newheight = newheight - (newheight%64);
        sd_params->width = newwidth;
        sd_params->height = newheight;
+       if(!sd_is_quiet && sddebugmode==1)
+       {
+           printf("\nDownscale to %dx%d as %d > %d\n",newwidth,newheight,biggestdim,reslimit);
+       }
    }
    bool dotile = (sd_params->width>768 || sd_params->height>768) && !notiling;
    set_sd_vae_tiling(sd_ctx,dotile); //changes vae tiling, prevents memory related crash/oom
@@ -358,11 +366,14 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
    //for img2img
    sd_image_t input_image = {0,0,0,nullptr};
    std::vector<uint8_t> image_buffer;
+   std::vector<uint8_t> image_mask_buffer;
    int nx, ny, nc;
+   int nx2, ny2, nc2;
    int img2imgW = sd_params->width; //for img2img input
    int img2imgH = sd_params->height;
    int img2imgC = 3; // Assuming RGB image
    std::vector<uint8_t> resized_image_buf(img2imgW * img2imgH * img2imgC);
+   std::vector<uint8_t> resized_mask_buf(img2imgW * img2imgH * img2imgC);

    std::string ts = get_timestamp_str();
    if(!sd_is_quiet)
@@ -429,6 +440,7 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
        sd_params->clip_skip,
        sd_params->cfg_scale,
        sd_params->guidance,
+       sd_params->eta,
        sd_params->width,
        sd_params->height,
        sd_params->sample_method,
@@ -461,6 +473,11 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
            stbi_image_free(input_image_buffer);
            input_image_buffer = nullptr;
        }
+       if(input_mask_buffer!=nullptr) //just in time free old buffer
+       {
+           stbi_image_free(input_mask_buffer);
+           input_mask_buffer = nullptr;
+       }

        input_image_buffer = stbi_load_from_memory(image_buffer.data(), image_buffer.size(), &nx, &ny, &nc, 3);

@@ -486,11 +503,34 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
            return output;
        }

+       if(img2img_mask!="")
+       {
+           image_mask_buffer = kcpp_base64_decode(img2img_mask);
+           input_mask_buffer = stbi_load_from_memory(image_mask_buffer.data(), image_mask_buffer.size(), &nx2, &ny2, &nc2, 3);
+           // Resize the image
+           int resok = stbir_resize_uint8(input_mask_buffer, nx, ny, 0, resized_mask_buf.data(), img2imgW, img2imgH, 0, img2imgC);
+           if (!resok) {
+               printf("\nKCPP SD: resize image failed!\n");
+               output.data = "";
+               output.status = 0;
+               return output;
+           }
+       }
+
        input_image.width = img2imgW;
        input_image.height = img2imgH;
        input_image.channel = img2imgC;
        input_image.data = resized_image_buf.data();
+
+       uint8_t* mask_image_buffer = NULL;
+       std::vector<uint8_t> default_mask_image_vec(img2imgW * img2imgH * img2imgC, 255);
+       if (img2img_mask != "") {
+           mask_image_buffer = resized_mask_buf.data();
+       } else {
+           mask_image_buffer = default_mask_image_vec.data();
+       }
+       sd_image_t mask_image = { (uint32_t) img2imgW, (uint32_t) img2imgH, 1, mask_image_buffer };
+
        if(!sd_is_quiet && sddebugmode==1)
        {
            printf("\nIMG2IMG PROMPT:%s\nNPROMPT:%s\nCLPSKP:%d\nCFGSCLE:%f\nW:%d\nH:%d\nSM:%d\nSTEP:%d\nSEED:%d\nBATCH:%d\nCIMG:%p\nSTR:%f\n\n",
@@ -510,11 +550,13 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)

        results = img2img(sd_ctx,
                          input_image,
+                         mask_image,
                          sd_params->prompt.c_str(),
                          sd_params->negative_prompt.c_str(),
                          sd_params->clip_skip,
                          sd_params->cfg_scale,
                          sd_params->guidance,
+                         sd_params->eta,
                          sd_params->width,
                          sd_params->height,
                          sd_params->sample_method,
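When no mask is supplied, the adapter falls back to an all-255 mask, which makes an inpaint model behave like plain img2img (every pixel may be repainted). Note also that the mask resize above passes the init image's nx/ny rather than the mask's own nx2/ny2, which assumes both images share dimensions. A self-contained sketch of the fallback; the diff sizes its default buffer at three channels while the sd_image_t it builds declares one, so a standalone version only needs width*height bytes:

#include <cstdint>
#include <vector>

// 255 = "repaint here"; a fully white, single-channel mask.
std::vector<uint8_t> make_default_mask(uint32_t w, uint32_t h) {
    return std::vector<uint8_t>(static_cast<size_t>(w) * h, 255);
}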
@@ -25,11 +25,15 @@ static float pending_apply_lora_power = 1.0f;

 const char* model_version_to_str[] = {
     "SD 1.x",
+    "SD 1.x Inpaint",
     "SD 2.x",
+    "SD 2.x Inpaint",
     "SDXL",
+    "SDXL Inpaint",
     "SVD",
     "SD3.x",
-    "Flux"};
+    "Flux",
+    "Flux Fill"};

 const char* sampling_methods_str[] = {
     "Euler A",
@@ -42,6 +46,8 @@ const char* sampling_methods_str[] = {
     "iPNDM",
     "iPNDM_v",
     "LCM",
+    "DDIM \"trailing\"",
+    "TCD"
 };

 /*================================================== Helper Functions ================================================*/
@@ -302,7 +308,7 @@ public:
             model_loader.set_wtype_override(wtype);
         }

-        if (version == VERSION_SDXL) {
+        if (sd_version_is_sdxl(version)) {
             vae_wtype = GGML_TYPE_F32;
             model_loader.set_wtype_override(GGML_TYPE_F32, "vae.");
         }
@@ -314,7 +320,7 @@ public:

         LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));

-        if (version == VERSION_SDXL) {
+        if (sd_version_is_sdxl(version)) {
             scale_factor = 0.13025f;
             if (vae_path.size() == 0 && taesd_path_fixed.size() == 0) {
                 LOG_WARN(
@@ -368,7 +374,7 @@ public:
             diffusion_model = std::make_shared<MMDiTModel>(backend, model_loader.tensor_storages_types);
         } else if (sd_version_is_flux(version)) {
             cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, model_loader.tensor_storages_types);
-            diffusion_model  = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, diffusion_flash_attn);
+            diffusion_model  = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn);
         } else {
             if (id_embeddings_path.find("v2") != std::string::npos) {
                 cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, model_loader.tensor_storages_types, embeddings_path, version, PM_VERSION_2);
@@ -556,8 +562,12 @@ public:

         // check is_using_v_parameterization_for_sd2
         bool is_using_v_parameterization = false;
-        if (version == VERSION_SD2) {
-            if (is_using_v_parameterization_for_sd2(ctx)) {
+        if (sd_version_is_sd2(version)) {
+            if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
+                is_using_v_parameterization = true;
+            }
+        } else if (sd_version_is_sdxl(version)) {
+            if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) {
                 is_using_v_parameterization = true;
             }
         } else if (version == VERSION_SVD) {
@@ -631,7 +641,7 @@ public:
         return true;
     }

-    bool is_using_v_parameterization_for_sd2(ggml_context* work_ctx) {
+    bool is_using_v_parameterization_for_sd2(ggml_context* work_ctx, bool is_inpaint = false) {
         struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1);
         ggml_set_f32(x_t, 0.5);
         struct ggml_tensor* c = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 1024, 2, 1, 1);
@@ -639,9 +649,15 @@ public:

         struct ggml_tensor* timesteps = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1);
         ggml_set_f32(timesteps, 999);
+
+        struct ggml_tensor* concat = is_inpaint ? ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 5, 1) : NULL;
+        if (concat != NULL) {
+            ggml_set_f32(concat, 0);
+        }
+
         int64_t t0              = ggml_time_ms();
         struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t);
-        diffusion_model->compute(n_threads, x_t, timesteps, c, NULL, NULL, NULL, -1, {}, 0.f, &out);
+        diffusion_model->compute(n_threads, x_t, timesteps, c, concat, NULL, NULL, -1, {}, 0.f, &out);
         diffusion_model->free_compute_buffer();

         double result = 0.f;
@@ -683,7 +699,7 @@ public:
         }

         lora.multiplier = multiplier;
-        lora.apply(tensors, n_threads);
+        lora.apply(tensors, version, n_threads);
         lora.free_params_buffer();

         int64_t t1 = ggml_time_ms();
@@ -713,7 +729,8 @@ public:
         }

         lora.multiplier = multiplier;
-        lora.apply(tensors, n_threads);
+        // TODO: send version?
+        lora.apply(tensors, version, n_threads);
         lora.free_params_buffer();

         int64_t t1 = ggml_time_ms();
@@ -729,19 +746,20 @@ public:
         for (auto& kv : lora_state) {
             const std::string& lora_name = kv.first;
             float multiplier             = kv.second;
-
-            if (curr_lora_state.find(lora_name) != curr_lora_state.end()) {
-                float curr_multiplier = curr_lora_state[lora_name];
-                float multiplier_diff = multiplier - curr_multiplier;
-                if (multiplier_diff != 0.f) {
-                    lora_state_diff[lora_name] = multiplier_diff;
-                }
-            } else {
-                lora_state_diff[lora_name] = multiplier;
-            }
+            lora_state_diff[lora_name] += multiplier;
+        }
+        for (auto& kv : curr_lora_state) {
+            const std::string& lora_name = kv.first;
+            float curr_multiplier        = kv.second;
+            lora_state_diff[lora_name] -= curr_multiplier;
         }

-        LOG_INFO("Attempting to apply %lu LoRAs", lora_state.size());
+        size_t rm = lora_state_diff.size() - lora_state.size();
+        if (rm != 0) {
+            LOG_INFO("Attempting to apply %lu LoRAs (removing %lu applied LoRAs)", lora_state.size(), rm);
+        } else {
+            LOG_INFO("Attempting to apply %lu LoRAs", lora_state.size());
+        }

         for (auto& kv : lora_state_diff) {
             apply_lora(kv.first, kv.second);
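The rewritten bookkeeping treats LoRA application as a signed delta: add every requested multiplier, subtract every currently applied one, and apply whatever remains, so negative entries back a previously applied LoRA out. A standalone sketch of the same arithmetic:

#include <map>
#include <string>

std::map<std::string, float> lora_delta(const std::map<std::string, float>& wanted,
                                        const std::map<std::string, float>& applied) {
    std::map<std::string, float> diff;
    for (const auto& kv : wanted)  diff[kv.first] += kv.second;
    for (const auto& kv : applied) diff[kv.first] -= kv.second;
    return diff;  // like the code above, zero deltas are kept and applied as no-ops
}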
@@ -848,6 +866,7 @@ public:
                            float min_cfg,
                            float cfg_scale,
                            float guidance,
+                           float eta,
                            sample_method_t method,
                            const std::vector<float>& sigmas,
                            int start_merge_step,
@@ -855,7 +874,20 @@ public:
                            std::vector<int> skip_layers = {},
                            float slg_scale        = 0,
                            float skip_layer_start = 0.01,
-                           float skip_layer_end   = 0.2) {
+                           float skip_layer_end   = 0.2,
+                           ggml_tensor* noise_mask = nullptr) {
+        LOG_DEBUG("Sample");
+        struct ggml_init_params params;
+        size_t data_size = ggml_row_size(init_latent->type, init_latent->ne[0]);
+        for (int i = 1; i < 4; i++) {
+            data_size *= init_latent->ne[i];
+        }
+        data_size += 1024;
+        params.mem_size   = data_size * 3;
+        params.mem_buffer = NULL;
+        params.no_alloc   = false;
+        ggml_context* tmp_ctx = ggml_init(params);
+
         size_t steps = sigmas.size() - 1;
         // noise = load_tensor_from_file(work_ctx, "./rand0.bin");
         // print_ggml_tensor(noise);
@@ -1014,10 +1046,23 @@ public:
                 pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f);
                 // LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
             }
+            if (noise_mask != nullptr) {
+                for (int64_t x = 0; x < denoised->ne[0]; x++) {
+                    for (int64_t y = 0; y < denoised->ne[1]; y++) {
+                        float mask = ggml_tensor_get_f32(noise_mask, x, y);
+                        for (int64_t k = 0; k < denoised->ne[2]; k++) {
+                            float init = ggml_tensor_get_f32(init_latent, x, y, k);
+                            float den  = ggml_tensor_get_f32(denoised, x, y, k);
+                            ggml_tensor_set_f32(denoised, init + mask * (den - init), x, y, k);
+                        }
+                    }
+                }
+            }
+
             return denoised;
         };

-        sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng);
+        sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta);

         x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
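The per-step mask blend above keeps unmasked latents pinned to the original image. Written out per element, with m in [0,1] taken from the downscaled mask:

    out = init + m * (denoised - init)

so m = 1 (masked) takes the freshly denoised value and m = 0 leaves the init latent untouched. As a standalone helper:

float blend_latent(float init, float denoised, float m) {
    return init + m * (denoised - init);  // lerp(init, denoised, m)
}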
@@ -1234,6 +1279,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
                            int clip_skip,
                            float cfg_scale,
                            float guidance,
+                           float eta,
                            int width,
                            int height,
                            enum sample_method_t sample_method,
@@ -1248,7 +1294,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
                            std::vector<int> skip_layers = {},
                            float slg_scale        = 0,
                            float skip_layer_start = 0.01,
-                           float skip_layer_end   = 0.2) {
+                           float skip_layer_end   = 0.2,
+                           ggml_tensor* masked_image = NULL) {
     if (seed < 0) {
         // Generally, when using the provided command line, the seed is always >0.
         // However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@@ -1294,7 +1341,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
     if (sd_ctx->sd->stacked_id) {
         if (!sd_ctx->sd->pmid_lora->applied) {
             t0 = ggml_time_ms();
-            sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->n_threads);
+            sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->version, sd_ctx->sd->n_threads);
             t1 = ggml_time_ms();
             sd_ctx->sd->pmid_lora->applied = true;
             LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
@@ -1404,7 +1451,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
     SDCondition uncond;
     if (cfg_scale != 1.0) {
         bool force_zero_embeddings = false;
-        if (sd_ctx->sd->version == VERSION_SDXL && negative_prompt.size() == 0) {
+        if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0) {
             force_zero_embeddings = true;
         }
         uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
@@ -1441,6 +1488,39 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
     int W = width / 8;
     int H = height / 8;
     LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
+    ggml_tensor* noise_mask = nullptr;
+    if (sd_version_is_inpaint(sd_ctx->sd->version)) {
+        if (masked_image == NULL) {
+            int64_t mask_channels = 1;
+            if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
+                mask_channels = 8 * 8;  // flatten the whole mask
+            }
+            // no mask, set the whole image as masked
+            masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
+            for (int64_t x = 0; x < masked_image->ne[0]; x++) {
+                for (int64_t y = 0; y < masked_image->ne[1]; y++) {
+                    if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
+                        // TODO: this might be wrong
+                        for (int64_t c = 0; c < init_latent->ne[2]; c++) {
+                            ggml_tensor_set_f32(masked_image, 0, x, y, c);
+                        }
+                        for (int64_t c = init_latent->ne[2]; c < masked_image->ne[2]; c++) {
+                            ggml_tensor_set_f32(masked_image, 1, x, y, c);
+                        }
+                    } else {
+                        ggml_tensor_set_f32(masked_image, 1, x, y, 0);
+                        for (int64_t c = 1; c < masked_image->ne[2]; c++) {
+                            ggml_tensor_set_f32(masked_image, 0, x, y, c);
+                        }
+                    }
+                }
+            }
+        }
+        cond.c_concat   = masked_image;
+        uncond.c_concat = masked_image;
+    } else {
+        noise_mask = masked_image;
+    }
     for (int b = 0; b < batch_count; b++) {
         int64_t sampling_start = ggml_time_ms();
         int64_t cur_seed       = seed + b;
@@ -1469,6 +1549,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
                                                     cfg_scale,
                                                     cfg_scale,
                                                     guidance,
+                                                    eta,
                                                     sample_method,
                                                     sigmas,
                                                     start_merge_step,
@@ -1476,7 +1557,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
                                                     skip_layers,
                                                     slg_scale,
                                                     skip_layer_start,
-                                                    skip_layer_end);
+                                                    skip_layer_end,
+                                                    noise_mask);

         // struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
         // print_ggml_tensor(x_0);
         int64_t sampling_end = ggml_time_ms();
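As I read the layout built above, c_concat for SD/SDXL inpaint models is [mask | masked-image latent] at latent resolution, 1 + 4 = 5 channels, while FLUX Fill instead appends an 8x8-flattened mask (64 channels) after the image latent; the maskless default marks everything as masked. An illustrative shape helper under that reading:

#include <cstdint>

int64_t c_concat_channels(bool flux_fill, int64_t latent_channels) {
    int64_t mask_channels = flux_fill ? 8 * 8 : 1;
    return mask_channels + latent_channels;  // 1 + 4 = 5 for SD/SDXL inpaint
}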
@@ -1532,6 +1615,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
                     int clip_skip,
                     float cfg_scale,
                     float guidance,
+                    float eta,
                     int width,
                     int height,
                     enum sample_method_t sample_method,
@@ -1598,6 +1682,10 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
         ggml_set_f32(init_latent, 0.f);
     }

+    if (sd_version_is_inpaint(sd_ctx->sd->version)) {
+        LOG_WARN("This is an inpainting model, this should only be used in img2img mode with a mask");
+    }
+
     sd_image_t* result_images = generate_image(sd_ctx,
                                                work_ctx,
                                                init_latent,
@@ -1606,6 +1694,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
                                                clip_skip,
                                                cfg_scale,
                                                guidance,
+                                               eta,
                                                width,
                                                height,
                                                sample_method,
@@ -1631,11 +1720,13 @@

 sd_image_t* img2img(sd_ctx_t* sd_ctx,
                     sd_image_t init_image,
+                    sd_image_t mask,
                     const char* prompt_c_str,
                     const char* negative_prompt_c_str,
                     int clip_skip,
                     float cfg_scale,
                     float guidance,
+                    float eta,
                     int width,
                     int height,
                     sample_method_t sample_method,
@@ -1670,7 +1761,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
     if (sd_ctx->sd->stacked_id) {
         params.mem_size += static_cast<size_t>(10 * 1024 * 1024);  // 10 MB
     }
-    params.mem_size += width * height * 3 * sizeof(float) * 2;
+    params.mem_size += width * height * 3 * sizeof(float) * 3;
     params.mem_size *= batch_count;
     params.mem_buffer = NULL;
     params.no_alloc   = false;
@@ -1691,7 +1782,70 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
     sd_ctx->sd->rng->manual_seed(seed);

     ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
+    ggml_tensor* mask_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 1, 1);
+
+    sd_mask_to_tensor(mask.data, mask_img);
+
     sd_image_to_tensor(init_image.data, init_img);
+
+    ggml_tensor* masked_image;
+
+    if (sd_version_is_inpaint(sd_ctx->sd->version)) {
+        int64_t mask_channels = 1;
+        if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
+            mask_channels = 8 * 8;  // flatten the whole mask
+        }
+        ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
+        sd_apply_mask(init_img, mask_img, masked_img);
+        ggml_tensor* masked_image_0 = NULL;
+        if (!sd_ctx->sd->use_tiny_autoencoder) {
+            ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
+            masked_image_0       = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
+        } else {
+            masked_image_0 = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
+        }
+        masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_image_0->ne[0], masked_image_0->ne[1], mask_channels + masked_image_0->ne[2], 1);
+        for (int ix = 0; ix < masked_image_0->ne[0]; ix++) {
+            for (int iy = 0; iy < masked_image_0->ne[1]; iy++) {
+                int mx = ix * 8;
+                int my = iy * 8;
+                if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
+                    for (int k = 0; k < masked_image_0->ne[2]; k++) {
+                        float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
+                        ggml_tensor_set_f32(masked_image, v, ix, iy, k);
+                    }
+                    // "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image
+                    for (int x = 0; x < 8; x++) {
+                        for (int y = 0; y < 8; y++) {
+                            float m = ggml_tensor_get_f32(mask_img, mx + x, my + y);
+                            // TODO: check if the way the mask is flattened is correct (is it supposed to be x*8+y or x+8*y?)
+                            // python code was using "b (h 8) (w 8) -> b (8 8) h w"
+                            ggml_tensor_set_f32(masked_image, m, ix, iy, masked_image_0->ne[2] + x * 8 + y);
+                        }
+                    }
+                } else {
+                    float m = ggml_tensor_get_f32(mask_img, mx, my);
+                    ggml_tensor_set_f32(masked_image, m, ix, iy, 0);
+                    for (int k = 0; k < masked_image_0->ne[2]; k++) {
+                        float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
+                        ggml_tensor_set_f32(masked_image, v, ix, iy, k + mask_channels);
+                    }
+                }
+            }
+        }
+    } else {
+        // LOG_WARN("Inpainting with a base model is not great");
+        masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1);
+        for (int ix = 0; ix < masked_image->ne[0]; ix++) {
+            for (int iy = 0; iy < masked_image->ne[1]; iy++) {
+                int mx = ix * 8;
+                int my = iy * 8;
+                float m = ggml_tensor_get_f32(mask_img, mx, my);
+                ggml_tensor_set_f32(masked_image, m, ix, iy);
+            }
+        }
+    }

     ggml_tensor* init_latent = NULL;
     if (!sd_ctx->sd->use_tiny_autoencoder) {
         ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
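For FLUX Fill, each latent cell (ix, iy) carries its VAE-encoded image channels followed by the 64 raw mask pixels of the 8x8 block it covers; the flattening order used is x*8 + y, mirroring the einops pattern "b (h 8) (w 8) -> b (8 8) h w" that the in-code TODO still questions. The index arithmetic in isolation:

// Channel index of mask pixel (x, y), with x and y in [0, 8), stored after
// the image latent channels of the same cell.
int flattened_mask_channel(int latent_channels, int x, int y) {
    return latent_channels + x * 8 + y;
}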
@@ -1705,6 +1859,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,

     std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
     size_t t_enc              = static_cast<size_t>(sample_steps * strength);
+    if (t_enc == sample_steps)
+        t_enc--;
     LOG_INFO("target t_enc is %zu steps", t_enc);
     std::vector<float> sigma_sched;
     sigma_sched.assign(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end());
@@ -1717,6 +1873,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
                                                clip_skip,
                                                cfg_scale,
                                                guidance,
+                                               eta,
                                                width,
                                                height,
                                                sample_method,
@@ -1731,11 +1888,12 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
                                                skip_layers_vec,
                                                slg_scale,
                                                skip_layer_start,
-                                               skip_layer_end);
+                                               skip_layer_end,
+                                               masked_image);

     size_t t2 = ggml_time_ms();

-    LOG_INFO("img2img completed in %.2fs", (t1 - t0) * 1.0f / 1000);
+    LOG_INFO("img2img completed in %.2fs", (t2 - t0) * 1.0f / 1000);

     return result_images;
 }
@@ -1829,6 +1987,7 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
                                                min_cfg,
                                                cfg_scale,
                                                0.f,
+                                               0.f,
                                                sample_method,
                                                sigmas,
                                                -1,
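img2img only runs the tail of the sigma schedule: with strength in [0,1], the last t_enc of sample_steps sigmas are used, and the new clamp keeps strength = 1.0 from indexing past the start of the schedule. A sketch of the step selection, assuming the same semantics:

#include <cstddef>

size_t encode_steps(size_t sample_steps, float strength) {
    size_t t_enc = static_cast<size_t>(sample_steps * strength);
    if (t_enc == sample_steps)
        t_enc--;  // the clamp added above: sigma_sched starts at
                  // sigmas[sample_steps - t_enc - 1], which must stay >= 0
    return t_enc;
}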
@@ -44,6 +44,8 @@ enum sample_method_t {
     IPNDM,
     IPNDM_V,
     LCM,
+    DDIM_TRAILING,
+    TCD,
     N_SAMPLE_METHODS
 };

@@ -161,6 +163,7 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
                            int clip_skip,
                            float cfg_scale,
                            float guidance,
+                           float eta,
                            int width,
                            int height,
                            enum sample_method_t sample_method,
@@ -180,11 +183,13 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,

 SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
                            sd_image_t init_image,
+                           sd_image_t mask_image,
                            const char* prompt,
                            const char* negative_prompt,
                            int clip_skip,
                            float cfg_scale,
                            float guidance,
+                           float eta,
                            int width,
                            int height,
                            enum sample_method_t sample_method,
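A hypothetical caller sketch for the extended public API (not part of the diff): the new mask argument is a single-channel image, and an all-255 mask reproduces the old maskless img2img behavior, matching the adapter's fallback. The aggregate field order follows the sd_image_t usage shown in the diff (width, height, channel, data):

#include <cstdint>
#include <vector>
#include "stable-diffusion.h"

sd_image_t full_white_mask(uint32_t w, uint32_t h, std::vector<uint8_t>& backing) {
    backing.assign(static_cast<size_t>(w) * h, 255);
    return sd_image_t{w, h, 1, backing.data()};
}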
@@ -201,7 +201,7 @@ struct TinyAutoEncoder : public GGMLRunner {
                     bool decoder_only = true,
                     SDVersion version = VERSION_SD1)
         : decode_only(decoder_only),
-          taesd(decode_only, version),
+          taesd(decoder_only, version),
           GGMLRunner(backend) {
         taesd.init(params_ctx, tensor_types, prefix);
     }
otherarch/sdcpp/thirdparty/stb_image_write.h (vendored, 24 changed lines)
@@ -177,7 +177,7 @@ STBIWDEF int stbi_write_png(char const *filename, int w, int h, int comp, const
 STBIWDEF int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data);
 STBIWDEF int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data);
 STBIWDEF int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data);
-STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality);
+STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality, const char* parameters = NULL);

 #ifdef STBIW_WINDOWS_UTF8
 STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input);
@@ -1412,7 +1412,7 @@ static int stbiw__jpg_processDU(stbi__write_context *s, int *bitBuf, int *bitCnt
    return DU[0];
 }

-static int stbi_write_jpg_core(stbi__write_context *s, int width, int height, int comp, const void* data, int quality) {
+static int stbi_write_jpg_core(stbi__write_context *s, int width, int height, int comp, const void* data, int quality, const char* parameters) {
    // Constants that don't pollute global namespace
    static const unsigned char std_dc_luminance_nrcodes[] = {0,0,1,5,1,1,1,1,1,1,0,0,0,0,0,0,0};
    static const unsigned char std_dc_luminance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11};
@@ -1521,6 +1521,20 @@ static int stbi_write_jpg_core(stbi__write_context *s, int width, int height, in
       s->func(s->context, (void*)YTable, sizeof(YTable));
       stbiw__putc(s, 1);
       s->func(s->context, UVTable, sizeof(UVTable));
+
+      // comment block with parameters of generation
+      if(parameters != NULL) {
+         stbiw__putc(s, 0xFF /* marker prefix */ );
+         stbiw__putc(s, 0xFE /* COM (comment) marker */ );
+         size_t param_length = std::min(2 + strlen("parameters") + 1 + strlen(parameters) + 1, (size_t) 0xFFFF);
+         stbiw__putc(s, param_length >> 8); // no need to mask, length < 65536
+         stbiw__putc(s, param_length & 0xFF);
+         s->func(s->context, (void*)"parameters", strlen("parameters") + 1); // write the key and its NUL terminator
+         s->func(s->context, (void*)parameters, std::min(param_length, (size_t) 65534) - 2 - strlen("parameters") - 1);
+         if(param_length > 65534) stbiw__putc(s, 0); // always zero-terminate for safety
+         if(param_length & 1) stbiw__putc(s, 0xFF); // pad to even length
+      }
+
       s->func(s->context, (void*)head1, sizeof(head1));
       s->func(s->context, (void*)(std_dc_luminance_nrcodes+1), sizeof(std_dc_luminance_nrcodes)-1);
       s->func(s->context, (void*)std_dc_luminance_values, sizeof(std_dc_luminance_values));
@@ -1625,16 +1639,16 @@ STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x,
 {
    stbi__write_context s = { 0 };
    stbi__start_write_callbacks(&s, func, context);
-   return stbi_write_jpg_core(&s, x, y, comp, (void *) data, quality);
+   return stbi_write_jpg_core(&s, x, y, comp, (void *) data, quality, NULL);
 }


 #ifndef STBI_WRITE_NO_STDIO
-STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality)
+STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality, const char* parameters)
 {
    stbi__write_context s = { 0 };
    if (stbi__start_write_file(&s,filename)) {
-      int r = stbi_write_jpg_core(&s, x, y, comp, data, quality);
+      int r = stbi_write_jpg_core(&s, x, y, comp, data, quality, parameters);
       stbi__end_write_file(&s);
       return r;
    } else
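The block added to the vendored writer emits a standard JPEG COM segment right after the quantization tables: marker bytes 0xFF 0xFE, then a big-endian 16-bit length that counts the length field itself plus the payload "parameters\0<text>\0", capped at 0xFFFF. This lets generated JPEGs carry the same "parameters" generation metadata that PNG writers commonly store in a text chunk. The length arithmetic in isolation:

#include <algorithm>
#include <cstring>

size_t com_segment_length(const char* parameters) {
    // 2 (length field) + "parameters" + NUL + parameter text + NUL
    size_t len = 2 + std::strlen("parameters") + 1 + std::strlen(parameters) + 1;
    return std::min(len, (size_t)0xFFFF);
}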
@@ -166,6 +166,7 @@ public:
 // ldm.modules.diffusionmodules.openaimodel.UNetModel
 class UnetModelBlock : public GGMLBlock {
 protected:
+    static std::map<std::string, enum ggml_type> empty_tensor_types;
     SDVersion version = VERSION_SD1;
     // network hparams
     int in_channels = 4;
@@ -183,13 +184,13 @@ public:
     int model_channels  = 320;
     int adm_in_channels = 2816;  // only for VERSION_SDXL/SVD

-    UnetModelBlock(SDVersion version = VERSION_SD1, bool flash_attn = false)
+    UnetModelBlock(SDVersion version = VERSION_SD1, std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types, bool flash_attn = false)
         : version(version) {
-        if (version == VERSION_SD2) {
+        if (sd_version_is_sd2(version)) {
             context_dim       = 1024;
             num_head_channels = 64;
             num_heads         = -1;
-        } else if (version == VERSION_SDXL) {
+        } else if (sd_version_is_sdxl(version)) {
             context_dim           = 2048;
             attention_resolutions = {4, 2};
             channel_mult          = {1, 2, 4};
@@ -204,6 +205,10 @@ public:
             num_head_channels = 64;
             num_heads         = -1;
         }
+        if (sd_version_is_inpaint(version)) {
+            in_channels = 9;
+        }

         // dims is always 2
         // use_temporal_attention is always True for SVD
@@ -211,7 +216,7 @@ public:
         // time_embed_1 is nn.SiLU()
         blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));

-        if (version == VERSION_SDXL || version == VERSION_SVD) {
+        if (sd_version_is_sdxl(version) || version == VERSION_SVD) {
             blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
             // label_emb_1 is nn.SiLU()
             blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
@@ -536,7 +541,7 @@ struct UNetModelRunner : public GGMLRunner {
                     const std::string prefix,
                     SDVersion version = VERSION_SD1,
                     bool flash_attn   = false)
-        : GGMLRunner(backend), unet(version, flash_attn) {
+        : GGMLRunner(backend), unet(version, tensor_types, flash_attn) {
         unet.init(params_ctx, tensor_types, prefix);
     }
@@ -566,6 +571,7 @@ struct UNetModelRunner : public GGMLRunner {
         context   = to_backend(context);
         y         = to_backend(y);
         timesteps = to_backend(timesteps);
+        c_concat  = to_backend(c_concat);

         for (int i = 0; i < controls.size(); i++) {
             controls[i] = to_backend(controls[i]);
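Setting in_channels = 9 only widens the first convolution; the extra five channels arrive through c_concat (now moved to the backend alongside the other inputs) concatenated with the 4-channel noisy latent at the UNet entry. A sketch of that entry concat, assuming ggml's ggml_concat(ctx, a, b, dim) signature:

#include "ggml.h"

// x:        [W, H, 4, N] noisy latent
// c_concat: [W, H, 5, N] mask + masked-image latent
// result:   [W, H, 9, N] fed to input_blocks.0.0
struct ggml_tensor* unet_inpaint_input(struct ggml_context* ctx,
                                       struct ggml_tensor* x,
                                       struct ggml_tensor* c_concat) {
    return ggml_concat(ctx, x, c_concat, 2);  // concat along the channel dim
}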