mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-04-28 11:40:43 +00:00
sd: sync to master-431-23fce0b (#1893)
* sd: sync to master-427-78e15bd * add kl_optimal to the available schedulers list * more robust workaround to avoid stb linkage issues * sd: sync to master-431-23fce0b * add TAEHV support and disable TAE if the model isn't found
This commit is contained in:
parent
27c53099f4
commit
44ce1a80b3
11 changed files with 787 additions and 196 deletions
|
|
@ -3772,7 +3772,7 @@ Change Mode<br>
|
|||
if friendlysdmodelname=="inactive" or fullsdmodelpath=="":
|
||||
response_body = (json.dumps([]).encode())
|
||||
else:
|
||||
response_body = (json.dumps([{"name":name,"label":name} for name in ["default","discrete","karras","exponential","ays","gits","sgm_uniform","simple","smoothstep","lcm"]]).encode())
|
||||
response_body = (json.dumps([{"name":name,"label":name} for name in ["default","discrete","karras","exponential","ays","gits","sgm_uniform","simple","smoothstep","kl_optimal","lcm"]]).encode())
|
||||
elif clean_path.endswith('/sdapi/v1/latent-upscale-modes'):
|
||||
response_body = (json.dumps([]).encode())
|
||||
elif clean_path.endswith('/sdapi/v1/upscalers'):
|
||||
|
|
|
|||
|
|
@ -20,18 +20,22 @@ namespace fs = std::filesystem;
|
|||
|
||||
#include "stable-diffusion.h"
|
||||
|
||||
// #define STB_IMAGE_IMPLEMENTATION
|
||||
//#define STB_IMAGE_STATIC
|
||||
namespace { // kcpp
|
||||
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#define STB_IMAGE_STATIC
|
||||
#include "stb_image.h"
|
||||
|
||||
#define STB_IMAGE_WRITE_IMPLEMENTATION
|
||||
//#define STB_IMAGE_WRITE_STATIC
|
||||
#define STB_IMAGE_WRITE_STATIC
|
||||
#include "stb_image_write.h"
|
||||
|
||||
// #define STB_IMAGE_RESIZE_IMPLEMENTATION
|
||||
//#define STB_IMAGE_RESIZE_STATIC
|
||||
#define STB_IMAGE_RESIZE_IMPLEMENTATION
|
||||
#define STB_IMAGE_RESIZE_STATIC
|
||||
#include "stb_image_resize.h"
|
||||
|
||||
}
|
||||
|
||||
#define SAFE_STR(s) ((s) ? (s) : "")
|
||||
#define BOOL_STR(b) ((b) ? "true" : "false")
|
||||
|
||||
|
|
@ -87,6 +91,114 @@ static std::string argv_to_utf8(int index, const char** argv) {
|
|||
|
||||
#endif
|
||||
|
||||
static void print_utf8(FILE* stream, const char* utf8) {
|
||||
if (!utf8)
|
||||
return;
|
||||
|
||||
#ifdef _WIN32
|
||||
HANDLE h = (stream == stderr)
|
||||
? GetStdHandle(STD_ERROR_HANDLE)
|
||||
: GetStdHandle(STD_OUTPUT_HANDLE);
|
||||
|
||||
int wlen = MultiByteToWideChar(CP_UTF8, 0, utf8, -1, NULL, 0);
|
||||
if (wlen <= 0)
|
||||
return;
|
||||
|
||||
wchar_t* wbuf = (wchar_t*)malloc(wlen * sizeof(wchar_t));
|
||||
MultiByteToWideChar(CP_UTF8, 0, utf8, -1, wbuf, wlen);
|
||||
|
||||
DWORD written;
|
||||
WriteConsoleW(h, wbuf, wlen - 1, &written, NULL);
|
||||
|
||||
free(wbuf);
|
||||
#else
|
||||
fputs(utf8, stream);
|
||||
#endif
|
||||
}
|
||||
|
||||
static std::string sd_basename(const std::string& path) {
|
||||
size_t pos = path.find_last_of('/');
|
||||
if (pos != std::string::npos) {
|
||||
return path.substr(pos + 1);
|
||||
}
|
||||
pos = path.find_last_of('\\');
|
||||
if (pos != std::string::npos) {
|
||||
return path.substr(pos + 1);
|
||||
}
|
||||
return path;
|
||||
}
|
||||
|
||||
static void log_print(enum sd_log_level_t level, const char* log, bool verbose, bool color) {
|
||||
int tag_color;
|
||||
const char* level_str;
|
||||
FILE* out_stream = (level == SD_LOG_ERROR) ? stderr : stdout;
|
||||
|
||||
if (!log || (!verbose && level <= SD_LOG_DEBUG)) {
|
||||
return;
|
||||
}
|
||||
|
||||
switch (level) {
|
||||
case SD_LOG_DEBUG:
|
||||
tag_color = 37;
|
||||
level_str = "DEBUG";
|
||||
break;
|
||||
case SD_LOG_INFO:
|
||||
tag_color = 34;
|
||||
level_str = "INFO";
|
||||
break;
|
||||
case SD_LOG_WARN:
|
||||
tag_color = 35;
|
||||
level_str = "WARN";
|
||||
break;
|
||||
case SD_LOG_ERROR:
|
||||
tag_color = 31;
|
||||
level_str = "ERROR";
|
||||
break;
|
||||
default: /* Potential future-proofing */
|
||||
tag_color = 33;
|
||||
level_str = "?????";
|
||||
break;
|
||||
}
|
||||
|
||||
if (color) {
|
||||
fprintf(out_stream, "\033[%d;1m[%-5s]\033[0m ", tag_color, level_str);
|
||||
} else {
|
||||
fprintf(out_stream, "[%-5s] ", level_str);
|
||||
}
|
||||
print_utf8(out_stream, log);
|
||||
fflush(out_stream);
|
||||
}
|
||||
|
||||
#define LOG_BUFFER_SIZE 4096
|
||||
|
||||
static bool log_verbose = false;
|
||||
static bool log_color = false;
|
||||
|
||||
static void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...) {
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
|
||||
static char log_buffer[LOG_BUFFER_SIZE + 1];
|
||||
int written = snprintf(log_buffer, LOG_BUFFER_SIZE, "%s:%-4d - ", sd_basename(file).c_str(), line);
|
||||
|
||||
if (written >= 0 && written < LOG_BUFFER_SIZE) {
|
||||
vsnprintf(log_buffer + written, LOG_BUFFER_SIZE - written, format, args);
|
||||
}
|
||||
size_t len = strlen(log_buffer);
|
||||
if (log_buffer[len - 1] != '\n') {
|
||||
strncat(log_buffer, "\n", LOG_BUFFER_SIZE - len);
|
||||
}
|
||||
|
||||
log_print(level, log_buffer, log_verbose, log_color);
|
||||
|
||||
va_end(args);
|
||||
}
|
||||
|
||||
#define LOG_DEBUG(format, ...) log_printf(SD_LOG_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__)
|
||||
#define LOG_INFO(format, ...) log_printf(SD_LOG_INFO, __FILE__, __LINE__, format, ##__VA_ARGS__)
|
||||
#define LOG_WARN(format, ...) log_printf(SD_LOG_WARN, __FILE__, __LINE__, format, ##__VA_ARGS__)
|
||||
#define LOG_ERROR(format, ...) log_printf(SD_LOG_ERROR, __FILE__, __LINE__, format, ##__VA_ARGS__)
|
||||
|
||||
struct StringOption {
|
||||
std::string short_name;
|
||||
std::string long_name;
|
||||
|
|
@ -296,11 +408,11 @@ static bool parse_options(int argc, const char** argv, const std::vector<ArgOpti
|
|||
}
|
||||
|
||||
if (invalid_arg) {
|
||||
fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
|
||||
LOG_ERROR("error: invalid parameter for argument: %s", arg.c_str());
|
||||
return false;
|
||||
}
|
||||
if (!found_arg) {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
LOG_ERROR("error: unknown argument: %s", arg.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
@ -407,6 +519,10 @@ struct SDContextParams {
|
|||
"--taesd",
|
||||
"path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)",
|
||||
&taesd_path},
|
||||
{"",
|
||||
"--tae",
|
||||
"alias of --taesd",
|
||||
&taesd_path},
|
||||
{"",
|
||||
"--control-net",
|
||||
"path to control net model",
|
||||
|
|
@ -511,8 +627,8 @@ struct SDContextParams {
|
|||
const char* arg = argv[index];
|
||||
wtype = str_to_sd_type(arg);
|
||||
if (wtype == SD_TYPE_COUNT) {
|
||||
fprintf(stderr, "error: invalid weight format %s\n",
|
||||
arg);
|
||||
LOG_ERROR("error: invalid weight format %s",
|
||||
arg);
|
||||
return -1;
|
||||
}
|
||||
return 1;
|
||||
|
|
@ -525,8 +641,8 @@ struct SDContextParams {
|
|||
const char* arg = argv[index];
|
||||
rng_type = str_to_rng_type(arg);
|
||||
if (rng_type == RNG_TYPE_COUNT) {
|
||||
fprintf(stderr, "error: invalid rng type %s\n",
|
||||
arg);
|
||||
LOG_ERROR("error: invalid rng type %s",
|
||||
arg);
|
||||
return -1;
|
||||
}
|
||||
return 1;
|
||||
|
|
@ -539,8 +655,8 @@ struct SDContextParams {
|
|||
const char* arg = argv[index];
|
||||
sampler_rng_type = str_to_rng_type(arg);
|
||||
if (sampler_rng_type == RNG_TYPE_COUNT) {
|
||||
fprintf(stderr, "error: invalid sampler rng type %s\n",
|
||||
arg);
|
||||
LOG_ERROR("error: invalid sampler rng type %s",
|
||||
arg);
|
||||
return -1;
|
||||
}
|
||||
return 1;
|
||||
|
|
@ -553,8 +669,8 @@ struct SDContextParams {
|
|||
const char* arg = argv[index];
|
||||
prediction = str_to_prediction(arg);
|
||||
if (prediction == PREDICTION_COUNT) {
|
||||
fprintf(stderr, "error: invalid prediction type %s\n",
|
||||
arg);
|
||||
LOG_ERROR("error: invalid prediction type %s",
|
||||
arg);
|
||||
return -1;
|
||||
}
|
||||
return 1;
|
||||
|
|
@ -567,8 +683,8 @@ struct SDContextParams {
|
|||
const char* arg = argv[index];
|
||||
lora_apply_mode = str_to_lora_apply_mode(arg);
|
||||
if (lora_apply_mode == LORA_APPLY_MODE_COUNT) {
|
||||
fprintf(stderr, "error: invalid lora apply model %s\n",
|
||||
arg);
|
||||
LOG_ERROR("error: invalid lora apply model %s",
|
||||
arg);
|
||||
return -1;
|
||||
}
|
||||
return 1;
|
||||
|
|
@ -692,13 +808,13 @@ struct SDContextParams {
|
|||
|
||||
bool process_and_check(SDMode mode) {
|
||||
if (mode != UPSCALE && model_path.length() == 0 && diffusion_model_path.length() == 0) {
|
||||
fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n");
|
||||
LOG_ERROR("error: the following arguments are required: model_path/diffusion_model\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (mode == UPSCALE) {
|
||||
if (esrgan_path.length() == 0) {
|
||||
fprintf(stderr, "error: upscale mode needs an upscaler model (--upscale-model)\n");
|
||||
LOG_ERROR("error: upscale mode needs an upscaler model (--upscale-model)\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
@ -797,7 +913,6 @@ struct SDContextParams {
|
|||
vae_path.c_str(),
|
||||
taesd_path.c_str(),
|
||||
control_net_path.c_str(),
|
||||
lora_model_dir.c_str(),
|
||||
embedding_vec.data(),
|
||||
static_cast<uint32_t>(embedding_vec.size()),
|
||||
photo_maker_path.c_str(),
|
||||
|
|
@ -864,7 +979,7 @@ static bool is_absolute_path(const std::string& p) {
|
|||
|
||||
struct SDGenerationParams {
|
||||
std::string prompt;
|
||||
std::string prompt_with_lora; // for metadata record only
|
||||
std::string prompt_with_lora; // for metadata record only
|
||||
std::string negative_prompt;
|
||||
int clip_skip = -1; // <= 0 represents unspecified
|
||||
int width = 512;
|
||||
|
|
@ -1116,8 +1231,8 @@ struct SDGenerationParams {
|
|||
const char* arg = argv[index];
|
||||
sample_params.sample_method = str_to_sample_method(arg);
|
||||
if (sample_params.sample_method == SAMPLE_METHOD_COUNT) {
|
||||
fprintf(stderr, "error: invalid sample method %s\n",
|
||||
arg);
|
||||
LOG_ERROR("error: invalid sample method %s",
|
||||
arg);
|
||||
return -1;
|
||||
}
|
||||
return 1;
|
||||
|
|
@ -1130,8 +1245,8 @@ struct SDGenerationParams {
|
|||
const char* arg = argv[index];
|
||||
high_noise_sample_params.sample_method = str_to_sample_method(arg);
|
||||
if (high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) {
|
||||
fprintf(stderr, "error: invalid high noise sample method %s\n",
|
||||
arg);
|
||||
LOG_ERROR("error: invalid high noise sample method %s",
|
||||
arg);
|
||||
return -1;
|
||||
}
|
||||
return 1;
|
||||
|
|
@ -1144,8 +1259,8 @@ struct SDGenerationParams {
|
|||
const char* arg = argv[index];
|
||||
sample_params.scheduler = str_to_scheduler(arg);
|
||||
if (sample_params.scheduler == SCHEDULER_COUNT) {
|
||||
fprintf(stderr, "error: invalid scheduler %s\n",
|
||||
arg);
|
||||
LOG_ERROR("error: invalid scheduler %s",
|
||||
arg);
|
||||
return -1;
|
||||
}
|
||||
return 1;
|
||||
|
|
@ -1226,17 +1341,17 @@ struct SDGenerationParams {
|
|||
try {
|
||||
custom_sigmas.push_back(std::stof(item));
|
||||
} catch (const std::invalid_argument& e) {
|
||||
fprintf(stderr, "error: invalid float value '%s' in --sigmas\n", item.c_str());
|
||||
LOG_ERROR("error: invalid float value '%s' in --sigmas", item.c_str());
|
||||
return -1;
|
||||
} catch (const std::out_of_range& e) {
|
||||
fprintf(stderr, "error: float value '%s' out of range in --sigmas\n", item.c_str());
|
||||
LOG_ERROR("error: float value '%s' out of range in --sigmas", item.c_str());
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (custom_sigmas.empty() && !sigmas_str.empty()) {
|
||||
fprintf(stderr, "error: could not parse any sigma values from '%s'\n", argv[index]);
|
||||
LOG_ERROR("error: could not parse any sigma values from '%s'", argv[index]);
|
||||
return -1;
|
||||
}
|
||||
return 1;
|
||||
|
|
@ -1299,7 +1414,7 @@ struct SDGenerationParams {
|
|||
on_high_noise_sample_method_arg},
|
||||
{"",
|
||||
"--scheduler",
|
||||
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], default: discrete",
|
||||
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm], default: discrete",
|
||||
on_scheduler_arg},
|
||||
{"",
|
||||
"--sigmas",
|
||||
|
|
@ -1332,7 +1447,7 @@ struct SDGenerationParams {
|
|||
try {
|
||||
j = json::parse(json_str);
|
||||
} catch (...) {
|
||||
fprintf(stderr, "json parse failed %s\n", json_str.c_str());
|
||||
LOG_ERROR("json parse failed %s", json_str.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -1441,7 +1556,7 @@ struct SDGenerationParams {
|
|||
}
|
||||
}
|
||||
if (!found) {
|
||||
printf("can not found lora %s\n", final_path.lexically_normal().string().c_str());
|
||||
LOG_WARN("can not found lora %s", final_path.lexically_normal().string().c_str());
|
||||
tmp = m.suffix().str();
|
||||
prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only);
|
||||
continue;
|
||||
|
|
@ -1480,17 +1595,17 @@ struct SDGenerationParams {
|
|||
bool process_and_check(SDMode mode, const std::string& lora_model_dir) {
|
||||
prompt_with_lora = prompt;
|
||||
if (width <= 0) {
|
||||
fprintf(stderr, "error: the width must be greater than 0\n");
|
||||
LOG_ERROR("error: the width must be greater than 0\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (height <= 0) {
|
||||
fprintf(stderr, "error: the height must be greater than 0\n");
|
||||
LOG_ERROR("error: the height must be greater than 0\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (sample_params.sample_steps <= 0) {
|
||||
fprintf(stderr, "error: the sample_steps must be greater than 0\n");
|
||||
LOG_ERROR("error: the sample_steps must be greater than 0\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -1499,7 +1614,7 @@ struct SDGenerationParams {
|
|||
}
|
||||
|
||||
if (strength < 0.f || strength > 1.f) {
|
||||
fprintf(stderr, "error: can only work with strength in [0.0, 1.0]\n");
|
||||
LOG_ERROR("error: can only work with strength in [0.0, 1.0]\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -1521,31 +1636,31 @@ struct SDGenerationParams {
|
|||
};
|
||||
trim(token);
|
||||
if (token.empty()) {
|
||||
fprintf(stderr, "error: invalid easycache option '%s'\n", easycache_option.c_str());
|
||||
LOG_ERROR("error: invalid easycache option '%s'", easycache_option.c_str());
|
||||
return false;
|
||||
}
|
||||
if (idx >= 3) {
|
||||
fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
|
||||
LOG_ERROR("error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
|
||||
return false;
|
||||
}
|
||||
try {
|
||||
values[idx] = std::stof(token);
|
||||
} catch (const std::exception&) {
|
||||
fprintf(stderr, "error: invalid easycache value '%s'\n", token.c_str());
|
||||
LOG_ERROR("error: invalid easycache value '%s'", token.c_str());
|
||||
return false;
|
||||
}
|
||||
idx++;
|
||||
}
|
||||
if (idx != 3) {
|
||||
fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
|
||||
LOG_ERROR("error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
|
||||
return false;
|
||||
}
|
||||
if (values[0] < 0.0f) {
|
||||
fprintf(stderr, "error: easycache threshold must be non-negative\n");
|
||||
LOG_ERROR("error: easycache threshold must be non-negative\n");
|
||||
return false;
|
||||
}
|
||||
if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) {
|
||||
fprintf(stderr, "error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n");
|
||||
LOG_ERROR("error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n");
|
||||
return false;
|
||||
}
|
||||
easycache_params.enabled = true;
|
||||
|
|
@ -1585,7 +1700,7 @@ struct SDGenerationParams {
|
|||
|
||||
if (mode == UPSCALE) {
|
||||
if (init_image_path.length() == 0) {
|
||||
fprintf(stderr, "error: upscale mode needs an init image (--init-img)\n");
|
||||
LOG_ERROR("error: upscale mode needs an init image (--init-img)\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
@ -1700,13 +1815,13 @@ uint8_t* load_image_common(bool from_memory,
|
|||
image_buffer = (uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel);
|
||||
}
|
||||
if (image_buffer == nullptr) {
|
||||
fprintf(stderr, "load image from '%s' failed\n", image_path);
|
||||
LOG_ERROR("load image from '%s' failed", image_path);
|
||||
return nullptr;
|
||||
}
|
||||
if (c < expected_channel) {
|
||||
fprintf(stderr,
|
||||
"the number of channels for the input image must be >= %d,"
|
||||
"but got %d channels, image_path = %s\n",
|
||||
"but got %d channels, image_path = %s",
|
||||
expected_channel,
|
||||
c,
|
||||
image_path);
|
||||
|
|
@ -1714,12 +1829,12 @@ uint8_t* load_image_common(bool from_memory,
|
|||
return nullptr;
|
||||
}
|
||||
if (width <= 0) {
|
||||
fprintf(stderr, "error: the width of image must be greater than 0, image_path = %s\n", image_path);
|
||||
LOG_ERROR("error: the width of image must be greater than 0, image_path = %s", image_path);
|
||||
free(image_buffer);
|
||||
return nullptr;
|
||||
}
|
||||
if (height <= 0) {
|
||||
fprintf(stderr, "error: the height of image must be greater than 0, image_path = %s\n", image_path);
|
||||
LOG_ERROR("error: the height of image must be greater than 0, image_path = %s", image_path);
|
||||
free(image_buffer);
|
||||
return nullptr;
|
||||
}
|
||||
|
|
@ -1741,10 +1856,10 @@ uint8_t* load_image_common(bool from_memory,
|
|||
}
|
||||
|
||||
if (crop_x != 0 || crop_y != 0) {
|
||||
printf("crop input image from %dx%d to %dx%d, image_path = %s\n", width, height, crop_w, crop_h, image_path);
|
||||
LOG_INFO("crop input image from %dx%d to %dx%d, image_path = %s", width, height, crop_w, crop_h, image_path);
|
||||
uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * expected_channel);
|
||||
if (cropped_image_buffer == nullptr) {
|
||||
fprintf(stderr, "error: allocate memory for crop\n");
|
||||
LOG_ERROR("error: allocate memory for crop\n");
|
||||
free(image_buffer);
|
||||
return nullptr;
|
||||
}
|
||||
|
|
@ -1760,13 +1875,13 @@ uint8_t* load_image_common(bool from_memory,
|
|||
image_buffer = cropped_image_buffer;
|
||||
}
|
||||
|
||||
printf("resize input image from %dx%d to %dx%d\n", width, height, expected_width, expected_height);
|
||||
LOG_INFO("resize input image from %dx%d to %dx%d", width, height, expected_width, expected_height);
|
||||
int resized_height = expected_height;
|
||||
int resized_width = expected_width;
|
||||
|
||||
uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * expected_channel);
|
||||
if (resized_image_buffer == nullptr) {
|
||||
fprintf(stderr, "error: allocate memory for resize input image\n");
|
||||
LOG_ERROR("error: allocate memory for resize input image\n");
|
||||
free(image_buffer);
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -347,6 +347,41 @@ struct SmoothStepScheduler : SigmaScheduler {
|
|||
}
|
||||
};
|
||||
|
||||
// Implementation adapted from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
|
||||
struct KLOptimalScheduler : SigmaScheduler {
|
||||
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
|
||||
std::vector<float> sigmas;
|
||||
|
||||
if (n == 0) {
|
||||
return sigmas;
|
||||
}
|
||||
if (n == 1) {
|
||||
sigmas.push_back(sigma_max);
|
||||
sigmas.push_back(0.0f);
|
||||
return sigmas;
|
||||
}
|
||||
|
||||
float alpha_min = std::atan(sigma_min);
|
||||
float alpha_max = std::atan(sigma_max);
|
||||
|
||||
for (uint32_t i = 0; i < n; ++i) {
|
||||
// t goes from 0.0 to 1.0
|
||||
float t = static_cast<float>(i) / static_cast<float>(n-1);
|
||||
|
||||
// Interpolate in the angle domain
|
||||
float angle = t * alpha_min + (1.0f - t) * alpha_max;
|
||||
|
||||
// Convert back to sigma
|
||||
sigmas.push_back(std::tan(angle));
|
||||
}
|
||||
|
||||
// Append the final zero to sigma
|
||||
sigmas.push_back(0.0f);
|
||||
|
||||
return sigmas;
|
||||
}
|
||||
};
|
||||
|
||||
struct Denoiser {
|
||||
virtual float sigma_min() = 0;
|
||||
virtual float sigma_max() = 0;
|
||||
|
|
@ -392,6 +427,10 @@ struct Denoiser {
|
|||
LOG_INFO("get_sigmas with SmoothStep scheduler");
|
||||
scheduler = std::make_shared<SmoothStepScheduler>();
|
||||
break;
|
||||
case KL_OPTIMAL_SCHEDULER:
|
||||
LOG_INFO("get_sigmas with KL Optimal scheduler");
|
||||
scheduler = std::make_shared<KLOptimalScheduler>();
|
||||
break;
|
||||
case LCM_SCHEDULER:
|
||||
LOG_INFO("get_sigmas with LCM scheduler");
|
||||
scheduler = std::make_shared<LCMScheduler>();
|
||||
|
|
|
|||
|
|
@ -744,6 +744,8 @@ namespace Flux {
|
|||
int64_t nerf_mlp_ratio = 4;
|
||||
int64_t nerf_depth = 4;
|
||||
int64_t nerf_max_freqs = 8;
|
||||
bool use_x0 = false;
|
||||
bool use_patch_size_32 = false;
|
||||
};
|
||||
|
||||
struct FluxParams {
|
||||
|
|
@ -781,7 +783,7 @@ namespace Flux {
|
|||
Flux(FluxParams params)
|
||||
: params(params) {
|
||||
if (params.version == VERSION_CHROMA_RADIANCE) {
|
||||
std::pair<int, int> kernel_size = {(int)params.patch_size, (int)params.patch_size};
|
||||
std::pair<int, int> kernel_size = {16, 16};
|
||||
std::pair<int, int> stride = kernel_size;
|
||||
|
||||
blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
|
||||
|
|
@ -1044,6 +1046,15 @@ namespace Flux {
|
|||
return img;
|
||||
}
|
||||
|
||||
struct ggml_tensor* _apply_x0_residual(GGMLRunnerContext* ctx,
|
||||
struct ggml_tensor* predicted,
|
||||
struct ggml_tensor* noisy,
|
||||
struct ggml_tensor* timesteps) {
|
||||
auto x = ggml_sub(ctx->ggml_ctx, noisy, predicted);
|
||||
x = ggml_div(ctx->ggml_ctx, x, timesteps);
|
||||
return x;
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward_chroma_radiance(GGMLRunnerContext* ctx,
|
||||
struct ggml_tensor* x,
|
||||
struct ggml_tensor* timestep,
|
||||
|
|
@ -1068,6 +1079,13 @@ namespace Flux {
|
|||
auto img = pad_to_patch_size(ctx->ggml_ctx, x);
|
||||
auto orig_img = img;
|
||||
|
||||
if (params.chroma_radiance_params.use_patch_size_32) {
|
||||
// It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable
|
||||
// Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch?
|
||||
// img = F.interpolate(img, size=(H//2, W//2), mode="nearest")
|
||||
img = ggml_interpolate(ctx->ggml_ctx, img, W / 2, H / 2, C, x->ne[3], GGML_SCALE_MODE_BILINEAR);
|
||||
}
|
||||
|
||||
auto img_in_patch = std::dynamic_pointer_cast<Conv2d>(blocks["img_in_patch"]);
|
||||
|
||||
img = img_in_patch->forward(ctx, img); // [N, hidden_size, H/patch_size, W/patch_size]
|
||||
|
|
@ -1104,6 +1122,10 @@ namespace Flux {
|
|||
|
||||
out = nerf_final_layer_conv->forward(ctx, img_dct); // [N, C, H, W]
|
||||
|
||||
if (params.chroma_radiance_params.use_x0) {
|
||||
out = _apply_x0_residual(ctx, out, orig_img, timestep);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
|
|
@ -1290,6 +1312,15 @@ namespace Flux {
|
|||
// not schnell
|
||||
flux_params.guidance_embed = true;
|
||||
}
|
||||
if (tensor_name.find("__x0__") != std::string::npos) {
|
||||
LOG_DEBUG("using x0 prediction");
|
||||
flux_params.chroma_radiance_params.use_x0 = true;
|
||||
}
|
||||
if (tensor_name.find("__32x32__") != std::string::npos) {
|
||||
LOG_DEBUG("using patch size 32 prediction");
|
||||
flux_params.chroma_radiance_params.use_patch_size_32 = true;
|
||||
flux_params.patch_size = 32;
|
||||
}
|
||||
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
|
||||
// Chroma
|
||||
flux_params.is_chroma = true;
|
||||
|
|
|
|||
|
|
@ -848,8 +848,6 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
|
|||
LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
|
||||
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);
|
||||
|
||||
GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2
|
||||
|
||||
int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x);
|
||||
int non_tile_overlap_x = p_tile_size_x - tile_overlap_x;
|
||||
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ struct SDCliParams {
|
|||
std::string output_path = "output.png";
|
||||
|
||||
bool verbose = false;
|
||||
bool version = false;
|
||||
bool canny_preprocess = false;
|
||||
|
||||
preview_t preview_method = PREVIEW_NONE;
|
||||
|
|
@ -74,10 +73,6 @@ struct SDCliParams {
|
|||
"--verbose",
|
||||
"print extra info",
|
||||
true, &verbose},
|
||||
{"",
|
||||
"--version",
|
||||
"print stable-diffusion.cpp version",
|
||||
true, &version},
|
||||
{"",
|
||||
"--color",
|
||||
"colors the logging tags according to level",
|
||||
|
|
@ -106,9 +101,8 @@ struct SDCliParams {
|
|||
}
|
||||
}
|
||||
if (mode_found == -1) {
|
||||
fprintf(stderr,
|
||||
"error: invalid mode %s, must be one of [%s]\n",
|
||||
mode_c_str, SD_ALL_MODES_STR);
|
||||
LOG_ERROR("error: invalid mode %s, must be one of [%s]\n",
|
||||
mode_c_str, SD_ALL_MODES_STR);
|
||||
exit(1);
|
||||
}
|
||||
mode = (SDMode)mode_found;
|
||||
|
|
@ -128,8 +122,7 @@ struct SDCliParams {
|
|||
}
|
||||
}
|
||||
if (preview_found == -1) {
|
||||
fprintf(stderr, "error: preview method %s\n",
|
||||
preview);
|
||||
LOG_ERROR("error: preview method %s", preview);
|
||||
return -1;
|
||||
}
|
||||
preview_method = (preview_t)preview_found;
|
||||
|
|
@ -161,7 +154,7 @@ struct SDCliParams {
|
|||
|
||||
bool process_and_check() {
|
||||
if (output_path.length() == 0) {
|
||||
fprintf(stderr, "error: the following arguments are required: output_path\n");
|
||||
LOG_ERROR("error: the following arguments are required: output_path");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -219,18 +212,6 @@ void parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextP
|
|||
}
|
||||
}
|
||||
|
||||
static std::string sd_basename(const std::string& path) {
|
||||
size_t pos = path.find_last_of('/');
|
||||
if (pos != std::string::npos) {
|
||||
return path.substr(pos + 1);
|
||||
}
|
||||
pos = path.find_last_of('\\');
|
||||
if (pos != std::string::npos) {
|
||||
return path.substr(pos + 1);
|
||||
}
|
||||
return path;
|
||||
}
|
||||
|
||||
std::string get_image_params(const SDCliParams& cli_params, const SDContextParams& ctx_params, const SDGenerationParams& gen_params, int64_t seed) {
|
||||
std::string parameter_string = gen_params.prompt_with_lora + "\n";
|
||||
if (gen_params.negative_prompt.size() != 0) {
|
||||
|
|
@ -288,47 +269,9 @@ std::string get_image_params(const SDCliParams& cli_params, const SDContextParam
|
|||
return parameter_string;
|
||||
}
|
||||
|
||||
/* Enables Printing the log level tag in color using ANSI escape codes */
|
||||
void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
|
||||
SDCliParams* cli_params = (SDCliParams*)data;
|
||||
int tag_color;
|
||||
const char* level_str;
|
||||
FILE* out_stream = (level == SD_LOG_ERROR) ? stderr : stdout;
|
||||
|
||||
if (!log || (!cli_params->verbose && level <= SD_LOG_DEBUG)) {
|
||||
return;
|
||||
}
|
||||
|
||||
switch (level) {
|
||||
case SD_LOG_DEBUG:
|
||||
tag_color = 37;
|
||||
level_str = "DEBUG";
|
||||
break;
|
||||
case SD_LOG_INFO:
|
||||
tag_color = 34;
|
||||
level_str = "INFO";
|
||||
break;
|
||||
case SD_LOG_WARN:
|
||||
tag_color = 35;
|
||||
level_str = "WARN";
|
||||
break;
|
||||
case SD_LOG_ERROR:
|
||||
tag_color = 31;
|
||||
level_str = "ERROR";
|
||||
break;
|
||||
default: /* Potential future-proofing */
|
||||
tag_color = 33;
|
||||
level_str = "?????";
|
||||
break;
|
||||
}
|
||||
|
||||
if (cli_params->color == true) {
|
||||
fprintf(out_stream, "\033[%d;1m[%-5s]\033[0m ", tag_color, level_str);
|
||||
} else {
|
||||
fprintf(out_stream, "[%-5s] ", level_str);
|
||||
}
|
||||
fputs(log, out_stream);
|
||||
fflush(out_stream);
|
||||
log_print(level, log, cli_params->verbose, cli_params->color);
|
||||
}
|
||||
|
||||
bool load_images_from_dir(const std::string dir,
|
||||
|
|
@ -338,7 +281,7 @@ bool load_images_from_dir(const std::string dir,
|
|||
int max_image_num = 0,
|
||||
bool verbose = false) {
|
||||
if (!fs::exists(dir) || !fs::is_directory(dir)) {
|
||||
fprintf(stderr, "'%s' is not a valid directory\n", dir.c_str());
|
||||
LOG_ERROR("'%s' is not a valid directory\n", dir.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -360,14 +303,12 @@ bool load_images_from_dir(const std::string dir,
|
|||
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
|
||||
|
||||
if (ext == ".jpg" || ext == ".jpeg" || ext == ".png" || ext == ".bmp") {
|
||||
if (verbose) {
|
||||
printf("load image %zu from '%s'\n", images.size(), path.c_str());
|
||||
}
|
||||
LOG_DEBUG("load image %zu from '%s'", images.size(), path.c_str());
|
||||
int width = 0;
|
||||
int height = 0;
|
||||
uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height, expected_width, expected_height);
|
||||
if (image_buffer == nullptr) {
|
||||
fprintf(stderr, "load image from '%s' failed\n", path.c_str());
|
||||
LOG_ERROR("load image from '%s' failed", path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -408,9 +349,6 @@ int main(int argc, const char* argv[]) {
|
|||
SDGenerationParams gen_params;
|
||||
|
||||
parse_args(argc, argv, cli_params, ctx_params, gen_params);
|
||||
if (cli_params.verbose || cli_params.version) {
|
||||
std::cout << version_string() << "\n";
|
||||
}
|
||||
if (gen_params.video_frames > 4) {
|
||||
size_t last_dot_pos = cli_params.preview_path.find_last_of(".");
|
||||
std::string base_path = cli_params.preview_path;
|
||||
|
|
@ -429,6 +367,8 @@ int main(int argc, const char* argv[]) {
|
|||
cli_params.preview_fps /= 4;
|
||||
|
||||
sd_set_log_callback(sd_log_cb, (void*)&cli_params);
|
||||
log_verbose = cli_params.verbose;
|
||||
log_color = cli_params.color;
|
||||
sd_set_preview_callback(step_callback,
|
||||
cli_params.preview_method,
|
||||
cli_params.preview_interval,
|
||||
|
|
@ -436,12 +376,11 @@ int main(int argc, const char* argv[]) {
|
|||
cli_params.preview_noisy,
|
||||
(void*)&cli_params);
|
||||
|
||||
if (cli_params.verbose) {
|
||||
printf("%s", sd_get_system_info());
|
||||
printf("%s\n", cli_params.to_string().c_str());
|
||||
printf("%s\n", ctx_params.to_string().c_str());
|
||||
printf("%s\n", gen_params.to_string().c_str());
|
||||
}
|
||||
LOG_DEBUG("version: %s", version_string().c_str());
|
||||
LOG_DEBUG("%s", sd_get_system_info());
|
||||
LOG_DEBUG("%s", cli_params.to_string().c_str());
|
||||
LOG_DEBUG("%s", ctx_params.to_string().c_str());
|
||||
LOG_DEBUG("%s", gen_params.to_string().c_str());
|
||||
|
||||
if (cli_params.mode == CONVERT) {
|
||||
bool success = convert(ctx_params.model_path.c_str(),
|
||||
|
|
@ -450,17 +389,16 @@ int main(int argc, const char* argv[]) {
|
|||
ctx_params.wtype,
|
||||
ctx_params.tensor_type_rules.c_str());
|
||||
if (!success) {
|
||||
fprintf(stderr,
|
||||
"convert '%s'/'%s' to '%s' failed\n",
|
||||
ctx_params.model_path.c_str(),
|
||||
ctx_params.vae_path.c_str(),
|
||||
cli_params.output_path.c_str());
|
||||
LOG_ERROR("convert '%s'/'%s' to '%s' failed",
|
||||
ctx_params.model_path.c_str(),
|
||||
ctx_params.vae_path.c_str(),
|
||||
cli_params.output_path.c_str());
|
||||
return 1;
|
||||
} else {
|
||||
printf("convert '%s'/'%s' to '%s' success\n",
|
||||
ctx_params.model_path.c_str(),
|
||||
ctx_params.vae_path.c_str(),
|
||||
cli_params.output_path.c_str());
|
||||
LOG_INFO("convert '%s'/'%s' to '%s' success",
|
||||
ctx_params.model_path.c_str(),
|
||||
ctx_params.vae_path.c_str(),
|
||||
cli_params.output_path.c_str());
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
|
@ -503,7 +441,7 @@ int main(int argc, const char* argv[]) {
|
|||
int height = 0;
|
||||
init_image.data = load_image_from_file(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height);
|
||||
if (init_image.data == nullptr) {
|
||||
fprintf(stderr, "load image from '%s' failed\n", gen_params.init_image_path.c_str());
|
||||
LOG_ERROR("load image from '%s' failed", gen_params.init_image_path.c_str());
|
||||
release_all_resources();
|
||||
return 1;
|
||||
}
|
||||
|
|
@ -516,7 +454,7 @@ int main(int argc, const char* argv[]) {
|
|||
int height = 0;
|
||||
end_image.data = load_image_from_file(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height);
|
||||
if (end_image.data == nullptr) {
|
||||
fprintf(stderr, "load image from '%s' failed\n", gen_params.end_image_path.c_str());
|
||||
LOG_ERROR("load image from '%s' failed", gen_params.end_image_path.c_str());
|
||||
release_all_resources();
|
||||
return 1;
|
||||
}
|
||||
|
|
@ -528,7 +466,7 @@ int main(int argc, const char* argv[]) {
|
|||
int height = 0;
|
||||
mask_image.data = load_image_from_file(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1);
|
||||
if (mask_image.data == nullptr) {
|
||||
fprintf(stderr, "load image from '%s' failed\n", gen_params.mask_image_path.c_str());
|
||||
LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str());
|
||||
release_all_resources();
|
||||
return 1;
|
||||
}
|
||||
|
|
@ -536,7 +474,7 @@ int main(int argc, const char* argv[]) {
|
|||
mask_image.data = (uint8_t*)malloc(gen_params.width * gen_params.height);
|
||||
memset(mask_image.data, 255, gen_params.width * gen_params.height);
|
||||
if (mask_image.data == nullptr) {
|
||||
fprintf(stderr, "malloc mask image failed\n");
|
||||
LOG_ERROR("malloc mask image failed");
|
||||
release_all_resources();
|
||||
return 1;
|
||||
}
|
||||
|
|
@ -547,7 +485,7 @@ int main(int argc, const char* argv[]) {
|
|||
int height = 0;
|
||||
control_image.data = load_image_from_file(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height);
|
||||
if (control_image.data == nullptr) {
|
||||
fprintf(stderr, "load image from '%s' failed\n", gen_params.control_image_path.c_str());
|
||||
LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
|
||||
release_all_resources();
|
||||
return 1;
|
||||
}
|
||||
|
|
@ -568,7 +506,7 @@ int main(int argc, const char* argv[]) {
|
|||
int height = 0;
|
||||
uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height);
|
||||
if (image_buffer == nullptr) {
|
||||
fprintf(stderr, "load image from '%s' failed\n", path.c_str());
|
||||
LOG_ERROR("load image from '%s' failed", path.c_str());
|
||||
release_all_resources();
|
||||
return 1;
|
||||
}
|
||||
|
|
@ -616,7 +554,7 @@ int main(int argc, const char* argv[]) {
|
|||
num_results = 1;
|
||||
results = (sd_image_t*)calloc(num_results, sizeof(sd_image_t));
|
||||
if (results == nullptr) {
|
||||
printf("failed to allocate results array\n");
|
||||
LOG_INFO("failed to allocate results array");
|
||||
release_all_resources();
|
||||
return 1;
|
||||
}
|
||||
|
|
@ -627,7 +565,7 @@ int main(int argc, const char* argv[]) {
|
|||
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
|
||||
|
||||
if (sd_ctx == nullptr) {
|
||||
printf("new_sd_ctx_t failed\n");
|
||||
LOG_INFO("new_sd_ctx_t failed");
|
||||
release_all_resources();
|
||||
return 1;
|
||||
}
|
||||
|
|
@ -641,7 +579,7 @@ int main(int argc, const char* argv[]) {
|
|||
}
|
||||
|
||||
if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) {
|
||||
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx);
|
||||
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx, gen_params.sample_params.sample_method);
|
||||
}
|
||||
|
||||
if (cli_params.mode == IMG_GEN) {
|
||||
|
|
@ -704,7 +642,7 @@ int main(int argc, const char* argv[]) {
|
|||
}
|
||||
|
||||
if (results == nullptr) {
|
||||
printf("generate failed\n");
|
||||
LOG_ERROR("generate failed");
|
||||
free_sd_ctx(sd_ctx);
|
||||
return 1;
|
||||
}
|
||||
|
|
@ -721,7 +659,7 @@ int main(int argc, const char* argv[]) {
|
|||
gen_params.upscale_tile_size);
|
||||
|
||||
if (upscaler_ctx == nullptr) {
|
||||
printf("new_upscaler_ctx failed\n");
|
||||
LOG_ERROR("new_upscaler_ctx failed");
|
||||
} else {
|
||||
for (int i = 0; i < num_results; i++) {
|
||||
if (results[i].data == nullptr) {
|
||||
|
|
@ -731,7 +669,7 @@ int main(int argc, const char* argv[]) {
|
|||
for (int u = 0; u < gen_params.upscale_repeats; ++u) {
|
||||
sd_image_t upscaled_image = upscale(upscaler_ctx, current_image, upscale_factor);
|
||||
if (upscaled_image.data == nullptr) {
|
||||
printf("upscale failed\n");
|
||||
LOG_ERROR("upscale failed");
|
||||
break;
|
||||
}
|
||||
free(current_image.data);
|
||||
|
|
@ -749,8 +687,8 @@ int main(int argc, const char* argv[]) {
|
|||
std::error_code ec;
|
||||
fs::create_directories(out_dir, ec); // OK if already exists
|
||||
if (ec) {
|
||||
fprintf(stderr, "failed to create directory '%s': %s\n",
|
||||
out_dir.string().c_str(), ec.message().c_str());
|
||||
LOG_ERROR("failed to create directory '%s': %s",
|
||||
out_dir.string().c_str(), ec.message().c_str());
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
|
@ -780,7 +718,7 @@ int main(int argc, const char* argv[]) {
|
|||
vid_output_path = base_path + ".avi";
|
||||
}
|
||||
create_mjpg_avi_from_sd_images(vid_output_path.c_str(), results, num_results, gen_params.fps);
|
||||
printf("save result MJPG AVI video to '%s'\n", vid_output_path.c_str());
|
||||
LOG_INFO("save result MJPG AVI video to '%s'\n", vid_output_path.c_str());
|
||||
} else {
|
||||
// appending ".png" to absent or unknown extension
|
||||
if (!is_jpg && file_ext_lower != ".png") {
|
||||
|
|
@ -796,11 +734,11 @@ int main(int argc, const char* argv[]) {
|
|||
if (is_jpg) {
|
||||
write_ok = stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
|
||||
results[i].data, 90, get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + i).c_str());
|
||||
printf("save result JPEG image to '%s' (%s)\n", final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
|
||||
LOG_INFO("save result JPEG image to '%s' (%s)", final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
|
||||
} else {
|
||||
write_ok = stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
|
||||
results[i].data, 0, get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + i).c_str());
|
||||
printf("save result PNG image to '%s' (%s)\n", final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
|
||||
LOG_INFO("save result PNG image to '%s' (%s)", final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1574,6 +1574,11 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
|
|||
i64_to_i32_vec((int64_t*)read_buf, (int32_t*)target_buf, tensor_storage.nelements());
|
||||
}
|
||||
if (tensor_storage.type != dst_tensor->type) {
|
||||
if (convert_buf == nullptr) {
|
||||
LOG_ERROR("read tensor data failed: too less memory for conversion");
|
||||
failed = true;
|
||||
return;
|
||||
}
|
||||
convert_tensor((void*)target_buf,
|
||||
tensor_storage.type,
|
||||
convert_buf,
|
||||
|
|
@ -1786,6 +1791,13 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
|
|||
// tensor_storage.ne[0], tensor_storage.ne[1], tensor_storage.ne[2], tensor_storage.ne[3],
|
||||
// tensor->n_dims, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
|
||||
|
||||
if (!tensor->data) {
|
||||
GGML_ASSERT(ggml_nelements(tensor) == 0);
|
||||
// avoid crashing the gguf writer by setting a dummy pointer for zero-sized tensors
|
||||
LOG_DEBUG("setting dummy pointer for zero-sized tensor %s", name.c_str());
|
||||
tensor->data = ggml_get_mem_buffer(ggml_ctx);
|
||||
}
|
||||
|
||||
*dst_tensor = tensor;
|
||||
|
||||
gguf_add_tensor(gguf_ctx, tensor);
|
||||
|
|
|
|||
|
|
@ -835,6 +835,7 @@ std::string convert_sep_to_dot(std::string name) {
|
|||
"proj_out",
|
||||
"transformer_blocks",
|
||||
"single_transformer_blocks",
|
||||
"single_blocks",
|
||||
"diffusion_model",
|
||||
"cond_stage_model",
|
||||
"first_stage_model",
|
||||
|
|
@ -876,7 +877,18 @@ std::string convert_sep_to_dot(std::string name) {
|
|||
"ff_context",
|
||||
"norm_added_q",
|
||||
"norm_added_v",
|
||||
"to_add_out"};
|
||||
"to_add_out",
|
||||
"txt_mod",
|
||||
"img_mod",
|
||||
"txt_mlp",
|
||||
"img_mlp",
|
||||
"proj_mlp",
|
||||
"wi_0",
|
||||
"wi_1",
|
||||
"norm1_context",
|
||||
"ff_context",
|
||||
"x_embedder",
|
||||
};
|
||||
|
||||
// record the positions of underscores that should NOT be replaced
|
||||
std::unordered_set<size_t> protected_positions;
|
||||
|
|
@ -1020,12 +1032,14 @@ std::string convert_tensor_name(std::string name, SDVersion version) {
|
|||
}
|
||||
}
|
||||
|
||||
if (sd_version_is_unet(version) || is_lycoris_underline) {
|
||||
// LOG_DEBUG("name %s %d", name.c_str(), version);
|
||||
|
||||
if (sd_version_is_unet(version) || sd_version_is_flux(version) || is_lycoris_underline) {
|
||||
name = convert_sep_to_dot(name);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> prefix_map = {
|
||||
std::unordered_map<std::string, std::string> prefix_map = {
|
||||
{"diffusion_model.", "model.diffusion_model."},
|
||||
{"unet.", "model.diffusion_model."},
|
||||
{"transformer.", "model.diffusion_model."}, // dit
|
||||
|
|
@ -1040,8 +1054,13 @@ std::string convert_tensor_name(std::string name, SDVersion version) {
|
|||
// {"te2.text_model.encoder.layers.", "cond_stage_model.1.model.transformer.resblocks."},
|
||||
{"te2.", "cond_stage_model.1.transformer."},
|
||||
{"te1.", "cond_stage_model.transformer."},
|
||||
{"te3.", "text_encoders.t5xxl.transformer."},
|
||||
};
|
||||
|
||||
if (sd_version_is_flux(version)) {
|
||||
prefix_map["te1."] = "text_encoders.clip_l.transformer.";
|
||||
}
|
||||
|
||||
replace_with_prefix_map(name, prefix_map);
|
||||
|
||||
// diffusion model
|
||||
|
|
|
|||
|
|
@ -136,7 +136,6 @@ public:
|
|||
|
||||
std::map<std::string, struct ggml_tensor*> tensors;
|
||||
|
||||
std::string lora_model_dir;
|
||||
// lora_name => multiplier
|
||||
std::unordered_map<std::string, float> curr_lora_state;
|
||||
|
||||
|
|
@ -219,7 +218,6 @@ public:
|
|||
n_threads = sd_ctx_params->n_threads;
|
||||
vae_decode_only = sd_ctx_params->vae_decode_only;
|
||||
free_params_immediately = sd_ctx_params->free_params_immediately;
|
||||
lora_model_dir = SAFE_STR(sd_ctx_params->lora_model_dir);
|
||||
taesd_path = SAFE_STR(sd_ctx_params->taesd_path);
|
||||
use_tiny_autoencoder = taesd_path.size() > 0;
|
||||
offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu;
|
||||
|
|
@ -418,6 +416,14 @@ public:
|
|||
{
|
||||
to_replace = "taesd_3.embd";
|
||||
}
|
||||
else if(version == VERSION_WAN2_2_TI2V)
|
||||
{
|
||||
to_replace = "taesd_w22.embd";
|
||||
}
|
||||
else if(sd_version_is_wan(version)||sd_version_is_qwen_image(version))
|
||||
{
|
||||
to_replace = "taesd_w21.embd";
|
||||
}
|
||||
|
||||
if(to_replace!="")
|
||||
{
|
||||
|
|
@ -432,6 +438,12 @@ public:
|
|||
taesd_path_fixed = "";
|
||||
use_tiny_autoencoder = false;
|
||||
}
|
||||
if (use_tiny_autoencoder && !file_exists(taesd_path_fixed))
|
||||
{
|
||||
printf("\nCannot use TAESD: \"%s\" not found. TAESD Disabled!\n", taesd_path_fixed);
|
||||
taesd_path_fixed = "";
|
||||
use_tiny_autoencoder = false;
|
||||
}
|
||||
}
|
||||
|
||||
ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
|
||||
|
|
@ -663,6 +675,9 @@ public:
|
|||
if (sd_ctx_params->diffusion_flash_attn) {
|
||||
LOG_INFO("Using flash attention in the diffusion model");
|
||||
diffusion_model->set_flash_attn_enabled(true);
|
||||
if (high_noise_diffusion_model) {
|
||||
high_noise_diffusion_model->set_flash_attn_enabled(true);
|
||||
}
|
||||
}
|
||||
|
||||
cond_stage_model->alloc_params_buffer();
|
||||
|
|
@ -688,14 +703,27 @@ public:
|
|||
}
|
||||
|
||||
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
|
||||
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
|
||||
offload_params_to_cpu,
|
||||
tensor_storage_map,
|
||||
"first_stage_model",
|
||||
vae_decode_only,
|
||||
version);
|
||||
first_stage_model->alloc_params_buffer();
|
||||
first_stage_model->get_param_tensors(tensors, "first_stage_model");
|
||||
if (!use_tiny_autoencoder) {
|
||||
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
|
||||
offload_params_to_cpu,
|
||||
tensor_storage_map,
|
||||
"first_stage_model",
|
||||
vae_decode_only,
|
||||
version);
|
||||
first_stage_model->alloc_params_buffer();
|
||||
first_stage_model->get_param_tensors(tensors, "first_stage_model");
|
||||
} else {
|
||||
tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend,
|
||||
offload_params_to_cpu,
|
||||
tensor_storage_map,
|
||||
"decoder",
|
||||
vae_decode_only,
|
||||
version);
|
||||
if (sd_ctx_params->vae_conv_direct) {
|
||||
LOG_INFO("Using Conv2d direct in the tae model");
|
||||
tae_first_stage->set_conv2d_direct_enabled(true);
|
||||
}
|
||||
}
|
||||
} else if (version == VERSION_CHROMA_RADIANCE) {
|
||||
first_stage_model = std::make_shared<FakeVAE>(vae_backend,
|
||||
offload_params_to_cpu);
|
||||
|
|
@ -722,14 +750,13 @@ public:
|
|||
}
|
||||
first_stage_model->alloc_params_buffer();
|
||||
first_stage_model->get_param_tensors(tensors, "first_stage_model");
|
||||
}
|
||||
if (use_tiny_autoencoder) {
|
||||
tae_first_stage = std::make_shared<TinyAutoEncoder>(vae_backend,
|
||||
offload_params_to_cpu,
|
||||
tensor_storage_map,
|
||||
"decoder.layers",
|
||||
vae_decode_only,
|
||||
version);
|
||||
} else if (use_tiny_autoencoder) {
|
||||
tae_first_stage = std::make_shared<TinyImageAutoEncoder>(vae_backend,
|
||||
offload_params_to_cpu,
|
||||
tensor_storage_map,
|
||||
"decoder.layers",
|
||||
vae_decode_only,
|
||||
version);
|
||||
if (sd_ctx_params->vae_conv_direct) {
|
||||
LOG_INFO("Using Conv2d direct in the tae model");
|
||||
tae_first_stage->set_conv2d_direct_enabled(true);
|
||||
|
|
@ -823,6 +850,8 @@ public:
|
|||
if (stacked_id) {
|
||||
ignore_tensors.insert("pmid.unet.");
|
||||
}
|
||||
ignore_tensors.insert("model.diffusion_model.__x0__");
|
||||
ignore_tensors.insert("model.diffusion_model.__32x32__");
|
||||
|
||||
if (vae_decode_only) {
|
||||
ignore_tensors.insert("first_stage_model.encoder");
|
||||
|
|
@ -957,6 +986,7 @@ public:
|
|||
}
|
||||
} else if (sd_version_is_flux(version)) {
|
||||
pred_type = FLUX_FLOW_PRED;
|
||||
|
||||
if (flow_shift == INFINITY) {
|
||||
flow_shift = 1.0f; // TODO: validate
|
||||
for (const auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
|
|
@ -1612,6 +1642,17 @@ public:
|
|||
std::vector<int> skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count);
|
||||
|
||||
float cfg_scale = guidance.txt_cfg;
|
||||
if (cfg_scale < 1.f) {
|
||||
if (cfg_scale == 0.f) {
|
||||
// Diffusers follow the convention from the original paper
|
||||
// (https://arxiv.org/abs/2207.12598v1), so many distilled model docs
|
||||
// recommend 0 as guidance; warn the user that it'll disable prompt folowing
|
||||
LOG_WARN("unconditioned mode, images won't follow the prompt (use cfg-scale=1 for distilled models)");
|
||||
} else {
|
||||
LOG_WARN("cfg value out of expected range may produce unexpected results");
|
||||
}
|
||||
}
|
||||
|
||||
float img_cfg_scale = std::isfinite(guidance.img_cfg) ? guidance.img_cfg : guidance.txt_cfg;
|
||||
float slg_scale = guidance.slg.scale;
|
||||
|
||||
|
|
@ -2527,6 +2568,7 @@ const char* scheduler_to_str[] = {
|
|||
"sgm_uniform",
|
||||
"simple",
|
||||
"smoothstep",
|
||||
"kl_optimal",
|
||||
"lcm",
|
||||
};
|
||||
|
||||
|
|
@ -2664,7 +2706,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
|||
"vae_path: %s\n"
|
||||
"taesd_path: %s\n"
|
||||
"control_net_path: %s\n"
|
||||
"lora_model_dir: %s\n"
|
||||
"photo_maker_path: %s\n"
|
||||
"tensor_type_rules: %s\n"
|
||||
"vae_decode_only: %s\n"
|
||||
|
|
@ -2694,7 +2735,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
|||
SAFE_STR(sd_ctx_params->vae_path),
|
||||
SAFE_STR(sd_ctx_params->taesd_path),
|
||||
SAFE_STR(sd_ctx_params->control_net_path),
|
||||
SAFE_STR(sd_ctx_params->lora_model_dir),
|
||||
SAFE_STR(sd_ctx_params->photo_maker_path),
|
||||
SAFE_STR(sd_ctx_params->tensor_type_rules),
|
||||
BOOL_STR(sd_ctx_params->vae_decode_only),
|
||||
|
|
@ -2893,13 +2933,16 @@ enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
|
|||
return EULER_A_SAMPLE_METHOD;
|
||||
}
|
||||
|
||||
enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx) {
|
||||
enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_method_t sample_method) {
|
||||
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
|
||||
auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd->denoiser);
|
||||
if (edm_v_denoiser) {
|
||||
return EXPONENTIAL_SCHEDULER;
|
||||
}
|
||||
}
|
||||
if (sample_method == LCM_SAMPLE_METHOD) {
|
||||
return LCM_SCHEDULER;
|
||||
}
|
||||
return DISCRETE_SCHEDULER;
|
||||
}
|
||||
|
||||
|
|
@ -3334,9 +3377,13 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
|||
LOG_WARN("sample_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps);
|
||||
}
|
||||
} else {
|
||||
scheduler_t scheduler = sd_img_gen_params->sample_params.scheduler;
|
||||
if (scheduler == SCHEDULER_COUNT) {
|
||||
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
|
||||
}
|
||||
sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps,
|
||||
sd_ctx->sd->get_image_seq_len(height, width),
|
||||
sd_img_gen_params->sample_params.scheduler,
|
||||
scheduler,
|
||||
sd_ctx->sd->version);
|
||||
}
|
||||
|
||||
|
|
@ -3619,9 +3666,13 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||
}
|
||||
}
|
||||
} else {
|
||||
scheduler_t scheduler = sd_vid_gen_params->sample_params.scheduler;
|
||||
if (scheduler == SCHEDULER_COUNT) {
|
||||
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
|
||||
}
|
||||
sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps,
|
||||
0,
|
||||
sd_vid_gen_params->sample_params.scheduler,
|
||||
scheduler,
|
||||
sd_ctx->sd->version);
|
||||
}
|
||||
|
||||
|
|
@ -3746,7 +3797,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||
denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
|
||||
ggml_set_f32(denoise_mask, 1.f);
|
||||
|
||||
sd_ctx->sd->process_latent_out(init_latent);
|
||||
if (!sd_ctx->sd->use_tiny_autoencoder)
|
||||
sd_ctx->sd->process_latent_out(init_latent);
|
||||
|
||||
ggml_ext_tensor_iter(init_image_latent, [&](ggml_tensor* t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||
float value = ggml_ext_tensor_get_f32(t, i0, i1, i2, i3);
|
||||
|
|
@ -3756,7 +3808,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||
}
|
||||
});
|
||||
|
||||
sd_ctx->sd->process_latent_in(init_latent);
|
||||
if (!sd_ctx->sd->use_tiny_autoencoder)
|
||||
sd_ctx->sd->process_latent_in(init_latent);
|
||||
|
||||
int64_t t2 = ggml_time_ms();
|
||||
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
|
||||
|
|
@ -3979,7 +4032,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||
struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true);
|
||||
int64_t t5 = ggml_time_ms();
|
||||
LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000);
|
||||
if (sd_ctx->sd->free_params_immediately) {
|
||||
if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) {
|
||||
sd_ctx->sd->first_stage_model->free_params_buffer();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ enum scheduler_t {
|
|||
SGM_UNIFORM_SCHEDULER,
|
||||
SIMPLE_SCHEDULER,
|
||||
SMOOTHSTEP_SCHEDULER,
|
||||
KL_OPTIMAL_SCHEDULER,
|
||||
LCM_SCHEDULER,
|
||||
SCHEDULER_COUNT
|
||||
};
|
||||
|
|
@ -168,7 +169,6 @@ typedef struct {
|
|||
const char* vae_path;
|
||||
const char* taesd_path;
|
||||
const char* control_net_path;
|
||||
const char* lora_model_dir;
|
||||
const sd_embedding_t* embeddings;
|
||||
uint32_t embedding_count;
|
||||
const char* photo_maker_path;
|
||||
|
|
@ -335,7 +335,7 @@ SD_API void sd_sample_params_init(sd_sample_params_t* sample_params);
|
|||
SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params);
|
||||
|
||||
SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx);
|
||||
SD_API enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx);
|
||||
SD_API enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_method_t sample_method);
|
||||
|
||||
SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
|
||||
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
|
||||
|
|
|
|||
|
|
@ -162,6 +162,311 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class TPool : public UnaryBlock {
|
||||
int stride;
|
||||
|
||||
public:
|
||||
TPool(int channels, int stride)
|
||||
: stride(stride) {
|
||||
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels * stride, channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||
auto conv = std::dynamic_pointer_cast<UnaryBlock>(blocks["conv"]);
|
||||
auto h = x;
|
||||
if (stride != 1) {
|
||||
h = ggml_reshape_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2] * stride, h->ne[3] / stride);
|
||||
}
|
||||
h = conv->forward(ctx, h);
|
||||
return h;
|
||||
}
|
||||
};
|
||||
|
||||
class TGrow : public UnaryBlock {
|
||||
int stride;
|
||||
|
||||
public:
|
||||
TGrow(int channels, int stride)
|
||||
: stride(stride) {
|
||||
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels * stride, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||
auto conv = std::dynamic_pointer_cast<UnaryBlock>(blocks["conv"]);
|
||||
auto h = conv->forward(ctx, x);
|
||||
if (stride != 1) {
|
||||
h = ggml_reshape_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2] / stride, h->ne[3] * stride);
|
||||
}
|
||||
return h;
|
||||
}
|
||||
};
|
||||
|
||||
class MemBlock : public GGMLBlock {
|
||||
bool has_skip_conv = false;
|
||||
|
||||
public:
|
||||
MemBlock(int channels, int out_channels)
|
||||
: has_skip_conv(channels != out_channels) {
|
||||
blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels * 2, out_channels, {3, 3}, {1, 1}, {1, 1}));
|
||||
blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
|
||||
blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
|
||||
if (has_skip_conv) {
|
||||
blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* past) {
|
||||
// x: [n, channels, h, w]
|
||||
auto conv0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]);
|
||||
auto conv1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.2"]);
|
||||
auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
|
||||
|
||||
auto h = ggml_concat(ctx->ggml_ctx, x, past, 2);
|
||||
h = conv0->forward(ctx, h);
|
||||
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
||||
h = conv1->forward(ctx, h);
|
||||
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
||||
h = conv2->forward(ctx, h);
|
||||
|
||||
auto skip = x;
|
||||
if (has_skip_conv) {
|
||||
auto skip_conv = std::dynamic_pointer_cast<Conv2d>(blocks["skip"]);
|
||||
skip = skip_conv->forward(ctx, x);
|
||||
}
|
||||
h = ggml_add_inplace(ctx->ggml_ctx, h, skip);
|
||||
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
||||
return h;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_tensor* patchify(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x,
|
||||
int64_t patch_size,
|
||||
int64_t b = 1) {
|
||||
// x: [f, b*c, h*q, w*r]
|
||||
// return: [f, b*c*r*q, h, w]
|
||||
if (patch_size == 1) {
|
||||
return x;
|
||||
}
|
||||
int64_t r = patch_size;
|
||||
int64_t q = patch_size;
|
||||
|
||||
int64_t W = x->ne[0];
|
||||
int64_t H = x->ne[1];
|
||||
int64_t C = x->ne[2];
|
||||
int64_t f = x->ne[3];
|
||||
|
||||
int64_t w = W / r;
|
||||
int64_t h = H / q;
|
||||
|
||||
x = ggml_reshape_4d(ctx, x, W, q, h, C * f); // [W, q, h, C*f]
|
||||
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [W, h, q, C*f]
|
||||
x = ggml_reshape_4d(ctx, x, r, w, h, q * C * f); // [r, w, h, q*C*f]
|
||||
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [w, h, r, q*C*f]
|
||||
x = ggml_reshape_4d(ctx, x, w, h, r * q * C, f); // [f, b*c*r*q, h, w]
|
||||
|
||||
return x;
|
||||
}
|
||||
|
||||
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x,
|
||||
int64_t patch_size,
|
||||
int64_t b = 1) {
|
||||
// x: [f, b*c*r*q, h, w]
|
||||
// return: [f, b*c, h*q, w*r]
|
||||
if (patch_size == 1) {
|
||||
return x;
|
||||
}
|
||||
int64_t r = patch_size;
|
||||
int64_t q = patch_size;
|
||||
int64_t c = x->ne[2] / b / q / r;
|
||||
int64_t f = x->ne[3];
|
||||
int64_t h = x->ne[1];
|
||||
int64_t w = x->ne[0];
|
||||
|
||||
x = ggml_reshape_4d(ctx, x, w, h, r, q * c * b * f); // [q*c*b*f, r, h, w]
|
||||
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [r, w, h, q*c*b*f]
|
||||
x = ggml_reshape_4d(ctx, x, r * w, h, q, c * b * f); // [c*b*f, q, h, r*w]
|
||||
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [r*w, q, h, c*b*f]
|
||||
x = ggml_reshape_4d(ctx, x, r * w, q * h, c * b, f);
|
||||
|
||||
return x;
|
||||
}
|
||||
|
||||
class TinyVideoEncoder : public UnaryBlock {
|
||||
int in_channels = 3;
|
||||
int hidden = 64;
|
||||
int z_channels = 4;
|
||||
int num_blocks = 3;
|
||||
int num_layers = 3;
|
||||
int patch_size = 1;
|
||||
|
||||
public:
|
||||
TinyVideoEncoder(int z_channels = 4, int patch_size = 1)
|
||||
: z_channels(z_channels), patch_size(patch_size) {
|
||||
int index = 0;
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels * patch_size * patch_size, hidden, {3, 3}, {1, 1}, {1, 1}));
|
||||
index++; // nn.ReLU()
|
||||
for (int i = 0; i < num_layers; i++) {
|
||||
int stride = i == num_layers - 1 ? 1 : 2;
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TPool(hidden, stride));
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(hidden, hidden, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
|
||||
for (int j = 0; j < num_blocks; j++) {
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(hidden, hidden));
|
||||
}
|
||||
}
|
||||
blocks[std::to_string(index)] = std::shared_ptr<GGMLBlock>(new Conv2d(hidden, z_channels, {3, 3}, {1, 1}, {1, 1}));
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override {
|
||||
auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["0"]);
|
||||
|
||||
if (patch_size > 1) {
|
||||
z = patchify(ctx->ggml_ctx, z, patch_size, 1);
|
||||
}
|
||||
|
||||
auto h = first_conv->forward(ctx, z);
|
||||
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
||||
|
||||
int index = 2;
|
||||
for (int i = 0; i < num_layers; i++) {
|
||||
auto pool = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(index++)]);
|
||||
auto conv = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(index++)]);
|
||||
|
||||
h = pool->forward(ctx, h);
|
||||
h = conv->forward(ctx, h);
|
||||
for (int j = 0; j < num_blocks; j++) {
|
||||
auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
|
||||
auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0);
|
||||
mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0);
|
||||
h = block->forward(ctx, h, mem);
|
||||
}
|
||||
}
|
||||
auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string(index)]);
|
||||
h = last_conv->forward(ctx, h);
|
||||
return h;
|
||||
}
|
||||
};
|
||||
|
||||
class TinyVideoDecoder : public UnaryBlock {
|
||||
int z_channels = 4;
|
||||
int out_channels = 3;
|
||||
int num_blocks = 3;
|
||||
static const int num_layers = 3;
|
||||
int channels[num_layers + 1] = {256, 128, 64, 64};
|
||||
int patch_size = 1;
|
||||
|
||||
public:
|
||||
TinyVideoDecoder(int z_channels = 4, int patch_size = 1)
|
||||
: z_channels(z_channels), patch_size(patch_size) {
|
||||
int index = 1; // Clamp()
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels[0], {3, 3}, {1, 1}, {1, 1}));
|
||||
index++; // nn.ReLU()
|
||||
for (int i = 0; i < num_layers; i++) {
|
||||
int stride = i == 0 ? 1 : 2;
|
||||
for (int j = 0; j < num_blocks; j++) {
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(channels[i], channels[i]));
|
||||
}
|
||||
index++; // nn.Upsample()
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TGrow(channels[i], stride));
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels[i], channels[i + 1], {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
|
||||
}
|
||||
index++; // nn.ReLU()
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels[num_layers], out_channels * patch_size * patch_size, {3, 3}, {1, 1}, {1, 1}));
|
||||
}
|
||||
|
||||
    // Decodes a video latent z into RGB frames.
    // Mirrors a Python nn.Sequential: blocks are looked up by their string-ified
    // sequential index, and parameter-free layers (ReLU, Upsample) only bump the
    // index without a block lookup (see the matching index++ skips in the
    // constructor).
    // Input z: 4-D latent; output: (W, H, 3, T) frames with the first 3
    // warm-up time slices dropped.
    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override {
        auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["1"]);

        // Clamp(): soft-clamp the latent to [-3, 3] via 3 * tanh(z / 3)
        auto h = ggml_scale_inplace(ctx->ggml_ctx,
                                    ggml_tanh_inplace(ctx->ggml_ctx,
                                                      ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
                                    3.0f);

        h = first_conv->forward(ctx, h);
        h = ggml_relu_inplace(ctx->ggml_ctx, h);
        int index = 3;  // indices 0-2: Clamp, first Conv2d, ReLU (handled above)
        for (int i = 0; i < num_layers; i++) {
            for (int j = 0; j < num_blocks; j++) {
                auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
                // Build the temporal "memory" input: pad one zero slice at the
                // front of dim 3 (time), then view the first T slices, so frame t
                // of mem is frame t-1 of h (zeros for t == 0).
                auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0);
                mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0);
                h = block->forward(ctx, h, mem);
            }
            // upsample
            index++;  // skip the parameter-free nn.Upsample slot
            h = ggml_upscale(ctx->ggml_ctx, h, 2, GGML_SCALE_MODE_NEAREST);
            // per-layer TGrow followed by a Conv2d (see constructor)
            auto block = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(index++)]);
            h = block->forward(ctx, h);
            block = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(index++)]);
            h = block->forward(ctx, h);
        }
        h = ggml_relu_inplace(ctx->ggml_ctx, h);

        // pre-increment skips the nn.ReLU() slot to land on the final Conv2d
        auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string(++index)]);
        h = last_conv->forward(ctx, h);
        if (patch_size > 1) {
            // final conv emits out_channels * patch_size^2; rearrange back to pixels
            h = unpatchify(ctx->ggml_ctx, h, patch_size, 1);
        }
        // shape(W, H, 3, 3 + T) => shape(W, H, 3, T)
        // drop the 3 leading warm-up frames produced by the temporal blocks
        h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - 3, h->nb[1], h->nb[2], h->nb[3], 3 * h->nb[3]);
        return h;
    }
|
||||
};
|
||||
|
||||
class TAEHV : public GGMLBlock {
|
||||
protected:
|
||||
bool decode_only;
|
||||
SDVersion version;
|
||||
|
||||
public:
|
||||
TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2)
|
||||
: decode_only(decode_only), version(version) {
|
||||
int z_channels = 16;
|
||||
int patch = 1;
|
||||
if (version == VERSION_WAN2_2_TI2V) {
|
||||
z_channels = 48;
|
||||
patch = 2;
|
||||
}
|
||||
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoDecoder(z_channels, patch));
|
||||
if (!decode_only) {
|
||||
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoEncoder(z_channels, patch));
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
|
||||
auto decoder = std::dynamic_pointer_cast<TinyVideoDecoder>(blocks["decoder"]);
|
||||
if (sd_version_is_wan(version)) {
|
||||
// (W, H, C, T) -> (W, H, T, C)
|
||||
z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 0, 1, 3, 2));
|
||||
}
|
||||
auto result = decoder->forward(ctx, z);
|
||||
if (sd_version_is_wan(version)) {
|
||||
// (W, H, C, T) -> (W, H, T, C)
|
||||
result = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||
auto encoder = std::dynamic_pointer_cast<TinyVideoEncoder>(blocks["encoder"]);
|
||||
// (W, H, T, C) -> (W, H, C, T)
|
||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
|
||||
int64_t num_frames = x->ne[3];
|
||||
if (num_frames % 4) {
|
||||
// pad to multiple of 4 at the end
|
||||
auto last_frame = ggml_view_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], x->ne[2], 1, x->nb[1], x->nb[2], x->nb[3], (num_frames - 1) * x->nb[3]);
|
||||
for (int i = 0; i < 4 - num_frames % 4; i++) {
|
||||
x = ggml_concat(ctx->ggml_ctx, x, last_frame, 3);
|
||||
}
|
||||
}
|
||||
x = encoder->forward(ctx, x);
|
||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
class TAESD : public GGMLBlock {
|
||||
protected:
|
||||
bool decode_only;
|
||||
|
|
@ -192,18 +497,30 @@ public:
|
|||
};
|
||||
|
||||
// Abstract interface shared by the tiny (TAE) auto-encoder runners: the image
// variant (TinyImageAutoEncoder/TAESD) and the video variant
// (TinyVideoAutoEncoder/TAEHV).
// NOTE(review): assumes GGMLRunner declares a virtual destructor so deleting
// through a TinyAutoEncoder* is safe — confirm in the GGMLRunner definition.
struct TinyAutoEncoder : public GGMLRunner {
    TinyAutoEncoder(ggml_backend_t backend, bool offload_params_to_cpu)
        : GGMLRunner(backend, offload_params_to_cpu) {}
    // Runs the decode graph (decode_graph == true) or the encode graph on z,
    // writing the result tensor to *output (allocated in output_ctx when
    // given). Returns true on success.
    virtual bool compute(const int n_threads,
                         struct ggml_tensor* z,
                         bool decode_graph,
                         struct ggml_tensor** output,
                         struct ggml_context* output_ctx = nullptr) = 0;

    // Loads the TAE weights from file_path using n_threads; returns true on
    // success.
    virtual bool load_from_file(const std::string& file_path, int n_threads) = 0;
};
|
||||
|
||||
struct TinyImageAutoEncoder : public TinyAutoEncoder {
|
||||
TAESD taesd;
|
||||
bool decode_only = false;
|
||||
|
||||
TinyAutoEncoder(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2TensorStorage& tensor_storage_map,
|
||||
const std::string prefix,
|
||||
bool decoder_only = true,
|
||||
SDVersion version = VERSION_SD1)
|
||||
TinyImageAutoEncoder(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2TensorStorage& tensor_storage_map,
|
||||
const std::string prefix,
|
||||
bool decoder_only = true,
|
||||
SDVersion version = VERSION_SD1)
|
||||
: decode_only(decoder_only),
|
||||
taesd(decoder_only, version),
|
||||
GGMLRunner(backend, offload_params_to_cpu) {
|
||||
TinyAutoEncoder(backend, offload_params_to_cpu) {
|
||||
taesd.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
|
||||
|
|
@ -260,4 +577,73 @@ struct TinyAutoEncoder : public GGMLRunner {
|
|||
}
|
||||
};
|
||||
|
||||
struct TinyVideoAutoEncoder : public TinyAutoEncoder {
|
||||
TAEHV taehv;
|
||||
bool decode_only = false;
|
||||
|
||||
TinyVideoAutoEncoder(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2TensorStorage& tensor_storage_map,
|
||||
const std::string prefix,
|
||||
bool decoder_only = true,
|
||||
SDVersion version = VERSION_WAN2)
|
||||
: decode_only(decoder_only),
|
||||
taehv(decoder_only, version),
|
||||
TinyAutoEncoder(backend, offload_params_to_cpu) {
|
||||
taehv.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
|
||||
std::string get_desc() override {
|
||||
return "taehv";
|
||||
}
|
||||
|
||||
bool load_from_file(const std::string& file_path, int n_threads) {
|
||||
LOG_INFO("loading taehv from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false");
|
||||
alloc_params_buffer();
|
||||
std::map<std::string, ggml_tensor*> taehv_tensors;
|
||||
taehv.get_param_tensors(taehv_tensors);
|
||||
std::set<std::string> ignore_tensors;
|
||||
if (decode_only) {
|
||||
ignore_tensors.insert("encoder.");
|
||||
}
|
||||
|
||||
ModelLoader model_loader;
|
||||
if (!model_loader.init_from_file(file_path)) {
|
||||
LOG_ERROR("init taehv model loader from file failed: '%s'", file_path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
bool success = model_loader.load_tensors(taehv_tensors, ignore_tensors, n_threads);
|
||||
|
||||
if (!success) {
|
||||
LOG_ERROR("load tae tensors from model loader failed");
|
||||
return false;
|
||||
}
|
||||
|
||||
LOG_INFO("taehv model loaded");
|
||||
return success;
|
||||
}
|
||||
|
||||
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
|
||||
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
||||
z = to_backend(z);
|
||||
auto runner_ctx = get_context();
|
||||
struct ggml_tensor* out = decode_graph ? taehv.decode(&runner_ctx, z) : taehv.encode(&runner_ctx, z);
|
||||
ggml_build_forward_expand(gf, out);
|
||||
return gf;
|
||||
}
|
||||
|
||||
bool compute(const int n_threads,
|
||||
struct ggml_tensor* z,
|
||||
bool decode_graph,
|
||||
struct ggml_tensor** output,
|
||||
struct ggml_context* output_ctx = nullptr) {
|
||||
auto get_graph = [&]() -> struct ggml_cgraph* {
|
||||
return build_graph(z, decode_graph);
|
||||
};
|
||||
|
||||
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
|
||||
}
|
||||
};
|
||||
|
||||
#endif // __TAE_HPP__
|
||||
Loading…
Add table
Add a link
Reference in a new issue