Concedo 2025-06-20 11:13:04 +08:00
parent b925bbfc6d
commit 175c99081e
2 changed files with 226 additions and 29 deletions

View file

@ -1139,18 +1139,111 @@ public:
decode ? 3 : C,
x->ne[3]); // channels
int64_t t0 = ggml_time_ms();
// TODO: args instead of env for tile size / overlap?
// Tile overlap used for VAE tiling. Defaults to 0.5 and can be overridden with
// the SD_TILE_OVERLAP environment variable; the accepted range is [0.0, 0.5].
float tile_overlap = 0.5f;
const char* SD_TILE_OVERLAP = getenv("SD_TILE_OVERLAP");
if (SD_TILE_OVERLAP != nullptr) {
    try {
        const float requested_overlap = std::stof(std::string(SD_TILE_OVERLAP));
        if (requested_overlap >= 0.0f && requested_overlap <= 0.5f) {
            tile_overlap = requested_overlap;
        } else if (requested_overlap < 0.0f) {
            LOG_WARN("SD_TILE_OVERLAP too low, setting it to 0.0");
            tile_overlap = 0.0f;
        } else if (requested_overlap > 0.5f) {
            LOG_WARN("SD_TILE_OVERLAP too high, setting it to 0.5");
            tile_overlap = 0.5f;
        } else {
            // Only NaN fails all three comparisons above. std::stof parses
            // "nan" successfully, so it has to be rejected here or it would
            // silently propagate into the tile size computation.
            LOG_WARN("SD_TILE_OVERLAP is invalid, keeping the default");
        }
    } catch (const std::invalid_argument&) {
        LOG_WARN("SD_TILE_OVERLAP is invalid, keeping the default");
    } catch (const std::out_of_range&) {
        LOG_WARN("SD_TILE_OVERLAP is out of range, keeping the default");
    }
}
// Tile size used for VAE tiling. Defaults to 32x32 and can be overridden with
// the SD_TILE_SIZE environment variable.
int tile_size_x = 32;
int tile_size_y = 32;
const char* SD_TILE_SIZE = getenv("SD_TILE_SIZE");
if (SD_TILE_SIZE != nullptr) {
    // format is AxB, or just A (equivalent to AxA)
    // A and B can be integers (tile size) or floating point
    // floating point <= 1 means simple fraction of the latent dimension
    // floating point > 1 means number of tiles across that dimension
    // a single number gets applied to both
    auto get_tile_factor = [tile_overlap](const std::string& factor_str) {
        float factor = std::stof(factor_str);
        if (factor > 1.0) {
            // n tiles of relative size t overlapping by a fraction o cover
            // t * (n - n*o + o) of the dimension; solve for t with coverage 1.
            factor = 1 / (factor - factor * tile_overlap + tile_overlap);
        }
        return factor;
    };
    const int latent_x = W / (decode ? 1 : 8);
    const int latent_y = H / (decode ? 1 : 8);
    const int min_tile_dimension = 4;

    // Turns one component of the setting into a tile size for a dimension of
    // `latent` units. Throws std::invalid_argument / std::out_of_range on
    // malformed input, exactly like std::stoi / std::stof.
    auto parse_tile_dimension = [&get_tile_factor](const std::string& dim_str, int latent) -> int {
        if (dim_str.find('.') == std::string::npos) {
            return std::stoi(dim_str);
        }
        const float scaled = std::round(latent * get_tile_factor(dim_str));
        if (std::isnan(scaled)) {
            throw std::invalid_argument("SD_TILE_SIZE component is not a number");
        }
        // Saturate before converting: a float -> int conversion of a value
        // outside the int range is undefined behaviour. The caller clamps to
        // [min_tile_dimension, latent] anyway, so the outcome is unchanged.
        if (scaled <= 0.0f) {
            return 0;
        }
        if (scaled >= static_cast<float>(latent)) {
            return latent;
        }
        return static_cast<int>(scaled);
    };

    const std::string sd_tile_size_str = SD_TILE_SIZE;
    const size_t x_pos = sd_tile_size_str.find('x');
    try {
        const bool has_two_parts = x_pos != std::string::npos;
        const std::string tile_x_str = has_two_parts ? sd_tile_size_str.substr(0, x_pos) : sd_tile_size_str;
        const std::string tile_y_str = has_two_parts ? sd_tile_size_str.substr(x_pos + 1) : sd_tile_size_str;
        // Parse both components before touching the outputs so that a failure
        // in either one leaves both tile sizes at their defaults.
        const int tmp_x = parse_tile_dimension(tile_x_str, latent_x);
        const int tmp_y = parse_tile_dimension(tile_y_str, latent_y);
        tile_size_x = std::max(std::min(tmp_x, latent_x), min_tile_dimension);
        tile_size_y = std::max(std::min(tmp_y, latent_y), min_tile_dimension);
    } catch (const std::invalid_argument&) {
        LOG_WARN("SD_TILE_SIZE is invalid, keeping the default");
    } catch (const std::out_of_range&) {
        LOG_WARN("SD_TILE_SIZE is out of range, keeping the default");
    }
}
if (!decode) {
    // TODO: also use an arg for this one?
    // When encoding, scale the tile size to keep the compute buffer size
    // consistent; the product is truncated back to an integer.
    const double encode_tile_scale = 1.30539;
    tile_size_x = static_cast<int>(tile_size_x * encode_tile_scale);
    tile_size_y = static_cast<int>(tile_size_y * encode_tile_scale);
}
if (!use_tiny_autoencoder) {
if (decode) {
ggml_tensor_scale(x, 1.0f / scale_factor);
} else {
ggml_tensor_scale_input(x);
}
if (vae_tiling && decode) { // TODO: support tiling vae encode
if (vae_tiling) {
if (SD_TILE_SIZE != nullptr) {
LOG_INFO("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
}
if (SD_TILE_OVERLAP != nullptr) {
LOG_INFO("VAE Tile overlap: %.2f", tile_overlap);
}
// split the input into tiles (see tile_size_x / tile_size_y) and compute in several steps
// on_tiling is the per-tile callback handed to the tiling helper: it runs the
// first-stage model on one input tile and writes the result into `out`.
// The `init` flag is part of the callback signature but is not used here.
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
first_stage_model->compute(n_threads, in, decode, &out);
};
sd_tiling(x, result, 8, 32, 0.5f, on_tiling);
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
} else {
first_stage_model->compute(n_threads, x, decode, &result);
}
@ -1160,7 +1253,7 @@ public:
}
} else {
//koboldcpp never use tiling with taesd
if (false && vae_tiling && decode) { // TODO: support tiling vae encode
if (false && vae_tiling) { // TODO: support tiling vae encode
// split latent in 64x64 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
tae_first_stage->compute(n_threads, in, decode, &out);