Fixed SD to work on larger images by adding VAE tiling; also limit the resolution for SD1.5

This commit is contained in:
Concedo 2024-11-04 23:26:15 +08:00
parent f153a14daf
commit 5b90eeaf17
3 changed files with 35 additions and 2 deletions

View file

@ -297,6 +297,7 @@ std::string clean_input_prompt(const std::string& input) {
return result; return result;
} }
sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs) sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
{ {
sd_generation_outputs output; sd_generation_outputs output;
@ -331,12 +332,29 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
sd_params->clip_skip = inputs.clip_skip; sd_params->clip_skip = inputs.clip_skip;
sd_params->mode = (img2img_data==""?SDMode::TXT2IMG:SDMode::IMG2IMG); sd_params->mode = (img2img_data==""?SDMode::TXT2IMG:SDMode::IMG2IMG);
//ensure unsupported dimensions are fixed
int biggestdim = std::max(sd_params->width,sd_params->height);
auto loadedsdver = get_loaded_sd_version(sd_ctx);
int reslimit = (loadedsdver==SDVersion::VERSION_SD1 || loadedsdver==SDVersion::VERSION_SD2)?832:1024;
if(biggestdim > reslimit)
{
float scaler = (float)biggestdim / (float)reslimit;
int newwidth = (int)((float)sd_params->width / scaler);
int newheight = (int)((float)sd_params->height / scaler);
newwidth = newwidth - (newwidth%64);
newheight = newheight - (newheight%64);
sd_params->width = newwidth;
sd_params->height = newheight;
}
bool dotile = (sd_params->width>768 || sd_params->height>768);
set_sd_vae_tiling(sd_ctx,dotile); //changes vae tiling, prevents memory related crash/oom
//for img2img //for img2img
sd_image_t input_image = {0,0,0,nullptr}; sd_image_t input_image = {0,0,0,nullptr};
std::vector<uint8_t> image_buffer; std::vector<uint8_t> image_buffer;
int nx, ny, nc; int nx, ny, nc;
int img2imgW = inputs.width; //for img2img input int img2imgW = sd_params->width; //for img2img input
int img2imgH = inputs.height; int img2imgH = sd_params->height;
int img2imgC = 3; // Assuming RGB image int img2imgC = 3; // Assuming RGB image
std::vector<uint8_t> resized_image_buf(img2imgW * img2imgH * img2imgC); std::vector<uint8_t> resized_image_buf(img2imgW * img2imgH * img2imgC);
@ -397,6 +415,8 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
control_image, control_image,
sd_params->control_strength); sd_params->control_strength);
} }
results = txt2img(sd_ctx, results = txt2img(sd_ctx,
sd_params->prompt.c_str(), sd_params->prompt.c_str(),
sd_params->negative_prompt.c_str(), sd_params->negative_prompt.c_str(),

View file

@ -1054,6 +1054,16 @@ struct sd_ctx_t {
StableDiffusionGGML* sd = NULL; StableDiffusionGGML* sd = NULL;
}; };
// Turns VAE tiling on or off for the model instance held by `ctx`.
// `ctx` and `ctx->sd` are dereferenced unconditionally, so both must be valid.
void set_sd_vae_tiling(sd_ctx_t* ctx, bool tiling)
{
    auto* backend = ctx->sd;
    backend->vae_tiling = tiling;
}
// Reports the `version` field of the model instance held by `ctx`, converted to int.
// `ctx` and `ctx->sd` are dereferenced unconditionally, so both must be valid.
int get_loaded_sd_version(sd_ctx_t* ctx)
{
    const int loaded_version = ctx->sd->version;
    return loaded_version;
}
sd_ctx_t* new_sd_ctx(const char* model_path_c_str, sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
const char* clip_l_path_c_str, const char* clip_l_path_c_str,
const char* clip_g_path_c_str, const char* clip_g_path_c_str,

View file

@ -122,6 +122,9 @@ typedef struct {
typedef struct sd_ctx_t sd_ctx_t; typedef struct sd_ctx_t sd_ctx_t;
// Sets the vae_tiling flag on the model owned by ctx. ctx is dereferenced without a null check.
SD_API void set_sd_vae_tiling(sd_ctx_t* ctx, bool tiling);
// Returns the version value stored on the model owned by ctx, as an int. ctx is dereferenced without a null check.
SD_API int get_loaded_sd_version(sd_ctx_t* ctx);
SD_API sd_ctx_t* new_sd_ctx(const char* model_path, SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
const char* clip_l_path, const char* clip_l_path,
const char* clip_g_path, const char* clip_g_path,