sd: sync to master-646-0baf721

This commit is contained in:
Wagner Bruna 2026-05-22 23:15:29 -03:00
parent 8427efb4c6
commit 3ec404b2cb
3 changed files with 365 additions and 91 deletions

View file

@ -6,8 +6,10 @@
#include <cstdlib>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "common_dit.hpp"
#include "ggml_extend.hpp"
@ -26,6 +28,10 @@ namespace LTXVUpsampler {
bool spatial_upsample = true;
bool temporal_upsample = false;
bool rational_resampler = false;
float spatial_scale = 2.f;
int spatial_up_num = 2;
int spatial_down_den = 1;
int temporal_up_factor = 1;
};
static inline bool has_tensor(const String2TensorStorage& tensor_storage_map,
@ -33,14 +39,21 @@ namespace LTXVUpsampler {
return tensor_storage_map.find(name) != tensor_storage_map.end();
}
// Looks up the tensor stored under `name` and returns its extent along `axis`
// (an index into the storage entry's `ne` array).
// Returns `fallback` when `axis` is outside [0, GGML_MAX_DIMS) or when no
// tensor with that name exists in `tensor_storage_map`.
static inline int64_t get_tensor_ne(const String2TensorStorage& tensor_storage_map,
                                    const std::string& name,
                                    int axis,
                                    int64_t fallback) {
    const bool axis_in_range = axis >= 0 && axis < GGML_MAX_DIMS;
    if (!axis_in_range) {
        return fallback;
    }
    const auto entry = tensor_storage_map.find(name);
    if (entry == tensor_storage_map.end()) {
        return fallback;
    }
    return entry->second.ne[axis];
}
// Convenience wrapper around get_tensor_ne for axis 0.
// Returns ne[0] of the tensor stored under `name`, or `fallback` when the
// tensor is not present in `tensor_storage_map`.
static inline int64_t get_tensor_ne0(const String2TensorStorage& tensor_storage_map,
                                     const std::string& name,
                                     int64_t fallback) {
    // Delegate so the lookup and fallback rules live in exactly one place;
    // a second, hand-rolled lookup here would make this call unreachable.
    return get_tensor_ne(tensor_storage_map, name, 0, fallback);
}
static inline int count_module_blocks(const String2TensorStorage& tensor_storage_map,
@ -71,8 +84,36 @@ namespace LTXVUpsampler {
if (detected_blocks > 0) {
config.num_blocks_per_stage = detected_blocks;
}
config.spatial_upsample = has_tensor(tensor_storage_map, "upsampler.0.weight");
config.temporal_upsample = has_tensor(tensor_storage_map, "temporal_upsampler.0.weight");
config.rational_resampler = has_tensor(tensor_storage_map, "upsampler.conv.weight");
int64_t upsampler_out_channels = get_tensor_ne0(tensor_storage_map, "upsampler.0.bias", 0);
config.spatial_upsample = config.rational_resampler || upsampler_out_channels == 4 * config.mid_channels;
config.temporal_upsample = upsampler_out_channels == 2 * config.mid_channels;
if (config.temporal_upsample) {
config.temporal_up_factor = 2;
}
if (config.rational_resampler) {
int64_t out_channels = get_tensor_ne(tensor_storage_map,
"upsampler.conv.weight",
3,
config.mid_channels * 9);
if (config.mid_channels > 0 && out_channels % config.mid_channels == 0) {
int64_t ratio = out_channels / config.mid_channels;
int num = static_cast<int>(std::round(std::sqrt(static_cast<double>(ratio))));
if (num > 0 && static_cast<int64_t>(num) * num == ratio) {
config.spatial_up_num = num;
}
}
if (config.spatial_up_num == 3) {
config.spatial_down_den = 2;
config.spatial_scale = 1.5f;
} else if (config.spatial_up_num == 4) {
config.spatial_down_den = 1;
config.spatial_scale = 4.f;
} else {
config.spatial_down_den = 1;
config.spatial_scale = static_cast<float>(config.spatial_up_num);
}
}
return config;
}
@ -160,16 +201,135 @@ namespace LTXVUpsampler {
: upscale_factor(upscale_factor) {}
// Spatial pixel shuffle for an arbitrary positive integer factor p = upscale_factor.
// x: [b*f, c*p*p, h, w] -> [b*f, c, h*p, w*p]
// The permute/reshape pair must run exactly once: it flattens the spatial
// grid into a token axis so DiT::unpatchify can scatter each p*p channel
// group back into the enlarged image.
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
    GGML_ASSERT(upscale_factor > 0);
    int64_t h = x->ne[1];
    int64_t w = x->ne[0];
    // Checked before the permute, while ne[2] is still the channel axis:
    // the channels must split evenly into p*p groups.
    GGML_ASSERT(x->ne[2] % (upscale_factor * upscale_factor) == 0);

    x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 2, 0, 1, 3));  // [b*f, h, w, c*p1*p2]
    x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]);          // [b*f, h*w, c*p1*p2]
    return DiT::unpatchify(ctx->ggml_ctx, x, h, w, upscale_factor, upscale_factor, true);
}
};
// Pixel shuffle along the frame axis: every group of `upscale_factor`
// consecutive channels is unfolded into `upscale_factor` consecutive frames.
class TemporalPixelShuffleND : public UnaryBlock {
protected:
    int upscale_factor;

public:
    explicit TemporalPixelShuffleND(int upscale_factor)
        : upscale_factor(upscale_factor) {}

    // x: [b, c*p, f, h, w] -> [b, c, f*p, h, w]   (p = upscale_factor)
    // Asserts that p is positive and that ne[3] is divisible by p.
    ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
        GGML_ASSERT(upscale_factor > 0);
        GGML_ASSERT(x->ne[3] % upscale_factor == 0);

        const int64_t width        = x->ne[0];
        const int64_t height       = x->ne[1];
        const int64_t frames       = x->ne[2];
        const int64_t out_channels = x->ne[3] / upscale_factor;

        // Split ne[3] into (p, c) with p varying fastest, keeping h*w fused.
        ggml_tensor* split = ggml_reshape_4d(ctx->ggml_ctx,
                                             ggml_ext_cont(ctx->ggml_ctx, x),
                                             width * height,
                                             frames,
                                             upscale_factor,
                                             out_channels);
        // Swap the frame and p axes so p becomes the fast index within the frame axis.
        ggml_tensor* interleaved = ggml_ext_cont(ctx->ggml_ctx,
                                                 ggml_permute(ctx->ggml_ctx, split, 0, 2, 1, 3));
        return ggml_reshape_4d(ctx->ggml_ctx,
                               interleaved,
                               width,
                               height,
                               frames * upscale_factor,
                               out_channels);
    }
};
// Depthwise 5x5 blur followed by strided subsampling.
// The kernel is the outer product of the binomial row [1, 4, 6, 4, 1]
// normalised by 256, replicated once per channel. It is generated here on the
// host and copied into the tensor by load_fixed_tensors().
// With stride == 1 the block is an identity and allocates no kernel.
class BlurDownsample : public GGMLBlock {
protected:
    int64_t channels;
    int stride;
    ggml_tensor* kernel = nullptr;
    std::vector<float> kernel_data;

    // Creates the kernel tensor (named `prefix + "kernel"`) and fills the
    // host-side copy of its values. Does nothing when stride == 1.
    void init_params(ggml_context* ctx,
                     const String2TensorStorage& tensor_storage_map = {},
                     const std::string prefix = "") override {
        SD_UNUSED(tensor_storage_map);
        if (stride != 1) {
            kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 5, 5, 1, channels);
            const std::string tensor_name = prefix + "kernel";
            ggml_set_name(kernel, tensor_name.c_str());

            static const float weights[5] = {1.f, 4.f, 6.f, 4.f, 1.f};
            kernel_data.resize(static_cast<size_t>(5 * 5 * channels));
            // Values are laid out channel-major, then row, then column.
            size_t cursor = 0;
            for (int64_t ch = 0; ch < channels; ++ch) {
                for (int tap = 0; tap < 25; ++tap) {
                    const int row = tap / 5;
                    const int col = tap % 5;
                    kernel_data[cursor++] = weights[row] * weights[col] / 256.f;
                }
            }
        }
    }

public:
    BlurDownsample(int64_t channels, int stride)
        : channels(channels),
          stride(stride) {
        GGML_ASSERT(stride >= 1);
    }

    // Copies the host-side kernel values into the kernel tensor.
    // No-op when no kernel was created or no values were generated.
    void load_fixed_tensors() {
        const bool has_kernel = kernel != nullptr && !kernel_data.empty();
        if (has_kernel) {
            ggml_backend_tensor_set(kernel, kernel_data.data(), 0, kernel_data.size() * sizeof(float));
        }
    }

    // Applies the depthwise blur with padding 2 and the configured stride.
    // Returns `x` unchanged when stride == 1; otherwise asserts that the
    // kernel exists and that x->ne[2] equals `channels`.
    ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
        if (stride == 1) {
            return x;
        }
        GGML_ASSERT(kernel != nullptr);
        GGML_ASSERT(x->ne[2] == channels);
        return ctx->conv2d_direct_enabled
                   ? ggml_conv_2d_dw_direct(ctx->ggml_ctx, kernel, x, stride, stride, 2, 2, 1, 1)
                   : ggml_conv_2d_dw(ctx->ggml_ctx, kernel, x, stride, stride, 2, 2, 1, 1);
    }
};
// Spatial resampler built from three stages: a 3x3 convolution that expands
// the channels by num*num, a pixel shuffle by `num`, and a BlurDownsample
// with stride `den`.
class SpatialRationalResampler : public GGMLBlock {
protected:
    int64_t mid_channels;
    int num;
    int den;

public:
    // Asserts num >= 1 and den >= 1, then registers the three sub-blocks
    // under the keys "conv", "pixel_shuffle" and "blur_down".
    SpatialRationalResampler(int64_t mid_channels, int num, int den)
        : mid_channels(mid_channels),
          num(num),
          den(den) {
        GGML_ASSERT(num >= 1);
        GGML_ASSERT(den >= 1);

        blocks["conv"] = std::shared_ptr<GGMLBlock>(
            new Conv2d(mid_channels, num * num * mid_channels, {3, 3}, {1, 1}, {1, 1}));
        blocks["pixel_shuffle"] = std::shared_ptr<GGMLBlock>(new PixelShuffleND(num));
        blocks["blur_down"]     = std::shared_ptr<GGMLBlock>(new BlurDownsample(mid_channels, den));
    }

    // Forwards to the blur stage so it can copy its generated kernel values
    // into its tensor.
    void load_fixed_tensors() {
        std::dynamic_pointer_cast<BlurDownsample>(blocks["blur_down"])->load_fixed_tensors();
    }

    // Runs conv -> pixel shuffle -> blur/downsample on per-frame images and
    // restores the incoming axis order afterwards.
    ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
        auto expand  = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
        auto shuffle = std::dynamic_pointer_cast<PixelShuffleND>(blocks["pixel_shuffle"]);
        auto blur    = std::dynamic_pointer_cast<BlurDownsample>(blocks["blur_down"]);

        // rearrange(x, "b c f h w -> (b f) c h w")
        ggml_tensor* frames = ggml_ext_cont(ctx->ggml_ctx,
                                            ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
        frames = expand->forward(ctx, frames);
        frames = shuffle->forward(ctx, frames);
        frames = blur->forward(ctx, frames);

        // Same axis swap again, undoing the rearrange above.
        return ggml_ext_cont(ctx->ggml_ctx,
                             ggml_ext_torch_permute(ctx->ggml_ctx, frames, 0, 1, 3, 2));
    }
};
class LatentUpsampler : public GGMLBlock {
public:
LatentUpsamplerConfig config;
@ -177,9 +337,7 @@ namespace LTXVUpsampler {
explicit LatentUpsampler(LatentUpsamplerConfig config)
: config(std::move(config)) {
GGML_ASSERT(this->config.dims == 3);
GGML_ASSERT(this->config.spatial_upsample);
GGML_ASSERT(!this->config.temporal_upsample);
GGML_ASSERT(!this->config.rational_resampler);
GGML_ASSERT(this->config.spatial_upsample || this->config.temporal_upsample);
blocks["initial_conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(this->config.in_channels,
this->config.mid_channels,
@ -190,12 +348,25 @@ namespace LTXVUpsampler {
for (int i = 0; i < this->config.num_blocks_per_stage; ++i) {
blocks["res_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResBlock(this->config.mid_channels, this->config.dims));
}
blocks["upsampler.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(this->config.mid_channels,
4 * this->config.mid_channels,
{3, 3},
{1, 1},
{1, 1}));
blocks["upsampler.1"] = std::shared_ptr<GGMLBlock>(new PixelShuffleND(2));
if (this->config.rational_resampler) {
blocks["upsampler"] = std::shared_ptr<GGMLBlock>(new SpatialRationalResampler(this->config.mid_channels,
this->config.spatial_up_num,
this->config.spatial_down_den));
} else if (this->config.temporal_upsample) {
blocks["upsampler.0"] = std::shared_ptr<GGMLBlock>(new Conv3d(this->config.mid_channels,
this->config.temporal_up_factor * this->config.mid_channels,
{3, 3, 3},
{1, 1, 1},
{1, 1, 1}));
blocks["upsampler.1"] = std::shared_ptr<GGMLBlock>(new TemporalPixelShuffleND(this->config.temporal_up_factor));
} else {
blocks["upsampler.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(this->config.mid_channels,
4 * this->config.mid_channels,
{3, 3},
{1, 1},
{1, 1}));
blocks["upsampler.1"] = std::shared_ptr<GGMLBlock>(new PixelShuffleND(2));
}
for (int i = 0; i < this->config.num_blocks_per_stage; ++i) {
blocks["post_upsample_res_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResBlock(this->config.mid_channels, this->config.dims));
}
@ -207,13 +378,11 @@ namespace LTXVUpsampler {
}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
// x: [b*c, f, h, w]
// return: [b*c, f, h*2, w*2]
auto initial_conv = std::dynamic_pointer_cast<Conv3d>(blocks["initial_conv"]);
auto initial_norm = std::dynamic_pointer_cast<VideoGroupNorm>(blocks["initial_norm"]);
auto upsample_conv = std::dynamic_pointer_cast<Conv2d>(blocks["upsampler.0"]);
auto pixel_shuffle = std::dynamic_pointer_cast<PixelShuffleND>(blocks["upsampler.1"]);
auto final_conv = std::dynamic_pointer_cast<Conv3d>(blocks["final_conv"]);
// x: [b, c, f, h, w]
// return: [b, c, scaled_f, scaled_h, scaled_w]
auto initial_conv = std::dynamic_pointer_cast<Conv3d>(blocks["initial_conv"]);
auto initial_norm = std::dynamic_pointer_cast<VideoGroupNorm>(blocks["initial_norm"]);
auto final_conv = std::dynamic_pointer_cast<Conv3d>(blocks["final_conv"]);
x = initial_conv->forward(ctx, x);
x = initial_norm->forward(ctx, x);
@ -226,11 +395,25 @@ namespace LTXVUpsampler {
sd::ggml_graph_cut::mark_graph_cut(x, "ltx_latent_upsampler.res_blocks." + std::to_string(i), "x");
}
// rearrange(x, "b c f h w -> (b f) c h w"),
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // [b*f, c, h, w]
x = upsample_conv->forward(ctx, x); // [b*f, c*4, h, w]
x = pixel_shuffle->forward(ctx, x); // [b*f, c, h*2, w*2]
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // [b*c, f, h, w]
if (config.rational_resampler) {
auto upsampler = std::dynamic_pointer_cast<SpatialRationalResampler>(blocks["upsampler"]);
x = upsampler->forward(ctx, x);
} else if (config.temporal_upsample) {
auto upsample_conv = std::dynamic_pointer_cast<Conv3d>(blocks["upsampler.0"]);
auto pixel_shuffle = std::dynamic_pointer_cast<TemporalPixelShuffleND>(blocks["upsampler.1"]);
x = upsample_conv->forward(ctx, x); // [b, c*2, f, h, w]
x = pixel_shuffle->forward(ctx, x); // [b, c, f*2, h, w]
x = ggml_ext_slice(ctx->ggml_ctx, x, 2, 1, x->ne[2]); // x[:, :, 1:, :, :]
} else {
auto upsample_conv = std::dynamic_pointer_cast<Conv2d>(blocks["upsampler.0"]);
auto pixel_shuffle = std::dynamic_pointer_cast<PixelShuffleND>(blocks["upsampler.1"]);
// rearrange(x, "b c f h w -> (b f) c h w"),
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // [b*f, c, h, w]
x = upsample_conv->forward(ctx, x); // [b*f, c*4, h, w]
x = pixel_shuffle->forward(ctx, x); // [b*f, c, h*2, w*2]
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // [b*c, f, h, w]
}
sd::ggml_graph_cut::mark_graph_cut(x, "ltx_latent_upsampler.spatial_up", "x");
for (int i = 0; i < config.num_blocks_per_stage; ++i) {
@ -243,6 +426,14 @@ namespace LTXVUpsampler {
sd::ggml_graph_cut::mark_graph_cut(x, "ltx_latent_upsampler.final", "x");
return x;
}
// Fills tensors whose values are generated in code rather than loaded by name.
// Only the rational-resampler configuration has any; every other
// configuration returns without doing anything.
void load_fixed_tensors() {
    if (config.rational_resampler) {
        auto resampler = std::dynamic_pointer_cast<SpatialRationalResampler>(blocks["upsampler"]);
        resampler->load_fixed_tensors();
    }
}
};
struct LatentUpsamplerRunner : public GGMLRunner {
@ -265,20 +456,24 @@ namespace LTXVUpsampler {
}
const auto& tensor_storage_map = model_loader.get_tensor_storage_map();
bool has_regular_upsampler = has_tensor(tensor_storage_map, "upsampler.0.weight");
bool has_rational_spatial = has_tensor(tensor_storage_map, "upsampler.conv.weight");
if (!has_tensor(tensor_storage_map, "post_upsample_res_blocks.0.conv2.bias") ||
!has_tensor(tensor_storage_map, "upsampler.0.weight")) {
LOG_ERROR("unsupported LTX latent upsampler weights: expected spatial upsampler tensors");
(!has_regular_upsampler && !has_rational_spatial)) {
LOG_ERROR("unsupported LTX latent upsampler weights: expected upsampler tensors");
return false;
}
LatentUpsamplerConfig config = detect_config_from_weights(tensor_storage_map);
if (config.dims != 3 || !config.spatial_upsample || config.temporal_upsample ||
config.rational_resampler) {
LOG_ERROR("unsupported LTX latent upsampler config: dims=%d spatial=%d temporal=%d rational=%d",
if (config.dims != 3 || (!config.spatial_upsample && !config.temporal_upsample) ||
config.spatial_up_num < 1 || config.spatial_down_den < 1 || config.temporal_up_factor < 1) {
LOG_ERROR("unsupported LTX latent upsampler config: dims=%d spatial=%d temporal=%d rational=%d scale=%.3f temporal_factor=%d",
config.dims,
config.spatial_upsample,
config.temporal_upsample,
config.rational_resampler);
config.rational_resampler,
config.spatial_scale,
config.temporal_up_factor);
return false;
}
@ -291,15 +486,23 @@ namespace LTXVUpsampler {
std::map<std::string, ggml_tensor*> tensors;
model->get_param_tensors(tensors);
if (!model_loader.load_tensors(tensors, {}, n_threads)) {
std::set<std::string> ignore_tensors;
if (config.rational_resampler) {
ignore_tensors.insert("upsampler.blur_down.kernel");
}
if (!model_loader.load_tensors(tensors, ignore_tensors, n_threads)) {
LOG_ERROR("load LTX latent upsampler tensors failed");
return false;
}
model->load_fixed_tensors();
LOG_INFO("LTX latent upsampler loaded: in_channels=%" PRId64 ", mid_channels=%" PRId64 ", blocks=%d",
LOG_INFO("LTX latent upsampler loaded: in_channels=%" PRId64 ", mid_channels=%" PRId64 ", blocks=%d, scale=%.3f, temporal_factor=%d, rational=%d",
config.in_channels,
config.mid_channels,
config.num_blocks_per_stage);
config.num_blocks_per_stage,
config.spatial_scale,
config.temporal_up_factor,
config.rational_resampler);
return true;
}

View file

@ -153,7 +153,7 @@ namespace LTXVAE {
GGML_ASSERT(x->ne[2] >= temporal_pad);
int end_idx = x->ne[2] - temporal_pad;
int end_idx = (int)x->ne[2] - temporal_pad;
int start_idx = std::max(end_idx - pad, 0);
// Save a contiguous copy of the last `pad` frames so the large `x`

View file

@ -963,7 +963,7 @@ public:
LOG_INFO("using TAE for preview");
preview_vae = create_tae();
preview_vae->set_max_graph_vram_bytes(max_graph_vram_bytes);
get_param_tensors_p(first_stage_model, vae_mmap, "vae");
get_param_tensors_p(preview_vae, vae_mmap, "tae");
}
}
@ -1835,23 +1835,18 @@ public:
std::function<void(int, int, sd_image_t*, bool, void*)> step_callback,
void* step_callback_data,
bool is_noisy) {
bool is_video = preview_latent_tensor_is_video(latents);
uint32_t dim = is_video ? static_cast<uint32_t>(latents.shape()[3]) : static_cast<uint32_t>(latents.shape()[2]);
int channels = get_latent_channel();
auto _latents = channels != dim ? is_video ? sd::ops::slice(latents, 3, 0, channels)
: sd::ops::slice(latents, 2, 0, channels)
: latents;
if (preview_mode == PREVIEW_PROJ) {
sd::Tensor<float> _latents = latents;
int patch_sz = 1;
const float(*latent_rgb_proj)[3] = nullptr;
float* latent_rgb_bias = nullptr;
bool is_video = preview_latent_tensor_is_video(latents);
uint32_t dim = is_video ? static_cast<uint32_t>(latents.shape()[3]) : static_cast<uint32_t>(latents.shape()[2]);
if (version == VERSION_LTXAV) {
if (is_video) {
_latents = sd::ops::slice(_latents, 3, 0, 128);
} else {
_latents = sd::ops::slice(_latents, 2, 0, 128);
}
dim = 128;
}
if (dim == 128) {
if (channels == 128) {
if (sd_version_uses_flux2_vae(version)) {
latent_rgb_proj = flux2_latent_rgb_proj;
latent_rgb_bias = flux2_latent_rgb_bias;
@ -1863,7 +1858,7 @@ public:
LOG_WARN("No latent to RGB projection known for this model");
return;
}
} else if (dim == 48) {
} else if (channels == 48) {
if (sd_version_is_wan(version)) {
latent_rgb_proj = wan_22_latent_rgb_proj;
latent_rgb_bias = wan_22_latent_rgb_bias;
@ -1871,7 +1866,7 @@ public:
LOG_WARN("No latent to RGB projection known for this model");
return;
}
} else if (dim == 16) {
} else if (channels == 16) {
if (sd_version_is_sd3(version)) {
latent_rgb_proj = sd3_latent_rgb_proj;
latent_rgb_bias = sd3_latent_rgb_bias;
@ -1885,7 +1880,7 @@ public:
LOG_WARN("No latent to RGB projection known for this model");
return;
}
} else if (dim == 4) {
} else if (channels == 4) {
if (sd_version_is_sdxl(version)) {
latent_rgb_proj = sdxl_latent_rgb_proj;
latent_rgb_bias = sdxl_latent_rgb_bias;
@ -1896,8 +1891,8 @@ public:
LOG_WARN("No latent to RGB projection known for this model");
return;
}
} else if (dim != 3) {
LOG_WARN("No latent to RGB projection known for this model");
} else if (channels != 3) {
LOG_WARN("No latent to RGB projection known for this model (dim = %d)", dim);
return;
}
@ -1922,14 +1917,13 @@ public:
if (preview_mode == PREVIEW_VAE || preview_mode == PREVIEW_TAE) {
sd::Tensor<float> vae_latents;
sd::Tensor<float> decoded;
bool is_video = preview_latent_tensor_is_video(latents);
if (preview_vae) {
preview_vae->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
vae_latents = preview_vae->diffusion_to_vae_latents(latents);
vae_latents = preview_vae->diffusion_to_vae_latents(_latents);
decoded = preview_vae->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true);
} else {
first_stage_model->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
vae_latents = first_stage_model->diffusion_to_vae_latents(latents);
vae_latents = first_stage_model->diffusion_to_vae_latents(_latents);
decoded = first_stage_model->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true);
}
if (decoded.empty()) {
@ -2384,19 +2378,41 @@ public:
int vae_scale_factor = get_vae_scale_factor();
int W = width / vae_scale_factor;
int H = height / vae_scale_factor;
int T = frames;
if (sd_version_is_ltxav(version)) {
T = ((T - 1) / 8) + 1;
} else if (sd_version_is_wan(version)) {
T = ((T - 1) / 4) + 1;
}
int C = get_latent_channel();
int T = video_frames_to_latent_frames(frames);
int C = get_latent_channel();
if (video) {
return sd::zeros<float>({W, H, T, C, 1});
}
return sd::zeros<float>({W, H, C, 1});
}
// Converts a pixel-space frame count into the number of latent frames.
// LTXAV versions use ((frames - 1) / 8) + 1, WAN versions use
// ((frames - 1) / 4) + 1, and every other version returns `frames` unchanged.
int video_frames_to_latent_frames(int frames) {
    if (sd_version_is_ltxav(version)) {
        return ((frames - 1) / 8) + 1;
    }
    if (sd_version_is_wan(version)) {
        return ((frames - 1) / 4) + 1;
    }
    return frames;
}
// Converts a latent frame count back into a pixel-space frame count.
// LTXAV versions use (latent_frames - 1) * 8 + 1, WAN versions use
// (latent_frames - 1) * 4 + 1, and other versions return the count unchanged.
// Non-positive counts are returned as-is.
int latent_frames_to_video_frames(int latent_frames) {
    if (latent_frames <= 0) {
        return latent_frames;
    }
    int temporal_stride = 1;
    if (sd_version_is_ltxav(version)) {
        temporal_stride = 8;
    } else if (sd_version_is_wan(version)) {
        temporal_stride = 4;
    }
    // With a stride of 1 this reduces to latent_frames itself.
    return (latent_frames - 1) * temporal_stride + 1;
}
// Round-trips `frames` through video_frames_to_latent_frames and
// latent_frames_to_video_frames, returning the frame count that results.
int align_video_frames(int frames) {
    const int latent_frames = video_frames_to_latent_frames(frames);
    return latent_frames_to_video_frames(latent_frames);
}
sd::Tensor<float> encode_to_vae_latents(const sd::Tensor<float>& x) {
auto latents = first_stage_model->encode(n_threads, x, vae_tiling_params, circular_x, circular_y);
if (latents.empty()) {
@ -3248,16 +3264,12 @@ struct GenerationRequest {
}
GenerationRequest(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params) {
prompt = SAFE_STR(sd_vid_gen_params->prompt);
negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt);
width = sd_vid_gen_params->width;
height = sd_vid_gen_params->height;
requested_frames = std::max(1, sd_vid_gen_params->video_frames);
if (sd_version_is_ltxav(sd_ctx->sd->version)) {
frames = ((requested_frames - 1 + 7) / 8) * 8 + 1;
} else {
frames = (requested_frames - 1) / 4 * 4 + 1;
}
prompt = SAFE_STR(sd_vid_gen_params->prompt);
negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt);
width = sd_vid_gen_params->width;
height = sd_vid_gen_params->height;
requested_frames = std::max(1, sd_vid_gen_params->video_frames);
frames = sd_ctx->sd->align_video_frames(requested_frames);
clip_skip = sd_vid_gen_params->clip_skip;
fps = std::max(1, sd_vid_gen_params->fps);
vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
@ -3815,6 +3827,30 @@ static sd::Tensor<float> unpack_ltxav_audio_latent(const sd::Tensor<float>& pack
return audio_latent;
}
// Builds a zero-filled LTXAV audio latent of shape
// {16 frequency bins, audio_length, 8 channels, 1}.
// Returns an empty tensor when `audio_length` is not positive.
static sd::Tensor<float> make_ltxav_empty_audio_latent(int audio_length) {
    constexpr int kFrequencyBins = 16;
    constexpr int kChannels      = 8;
    if (audio_length > 0) {
        return sd::zeros<float>({kFrequencyBins, audio_length, kChannels, 1});
    }
    return {};
}
// Returns a zero-filled audio latent of length `target_audio_length` whose
// leading entries along axis 1 are copied from `audio_latent`.
// Up to min(source length, target length) entries are copied.
// The result is empty when `target_audio_length` is not positive, and all
// zeros when `audio_latent` is empty. A non-empty source must be 3-D or 4-D.
static sd::Tensor<float> resize_ltxav_audio_latent(const sd::Tensor<float>& audio_latent,
                                                   int target_audio_length) {
    auto resized = make_ltxav_empty_audio_latent(target_audio_length);
    const bool nothing_to_copy = resized.empty() || audio_latent.empty();
    if (!nothing_to_copy) {
        GGML_ASSERT(audio_latent.dim() == 3 || audio_latent.dim() == 4);
        const int source_length = static_cast<int>(audio_latent.shape()[1]);
        const int copy_length   = std::min(source_length, target_audio_length);
        if (copy_length > 0) {
            auto head = sd::ops::slice(audio_latent, 1, 0, copy_length);
            sd::ops::slice_assign(&resized, 1, 0, copy_length, head);
        }
    }
    return resized;
}
static int get_ltxav_num_audio_latents(int frames, int fps) {
GGML_ASSERT(frames > 0);
GGML_ASSERT(fps > 0);
@ -4644,10 +4680,8 @@ static std::optional<ImageGenerationLatents> prepare_video_generation_latents(sd
}
if (sd_version_is_ltxav(sd_ctx->sd->version)) {
constexpr int kLtxavAudioFrequencyBins = 16;
constexpr int kLtxavAudioChannels = 8;
latents.audio_length = get_ltxav_num_audio_latents(request->frames, request->fps);
latents.audio_latent = sd::zeros<float>({kLtxavAudioFrequencyBins, latents.audio_length, kLtxavAudioChannels, 1});
latents.audio_length = get_ltxav_num_audio_latents(request->frames, request->fps);
latents.audio_latent = make_ltxav_empty_audio_latent(latents.audio_length);
}
if (sd_version_is_ltxav(sd_ctx->sd->version)) {
@ -4997,9 +5031,9 @@ static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx,
(int)vid.shape()[1],
(int)vid.shape()[2],
(int)vid.shape()[3]);
if (request.requested_frames > 0 &&
vid.shape()[2] > request.requested_frames) {
vid = sd::ops::slice(vid, 2, 0, request.requested_frames);
if (request.frames > 0 &&
vid.shape()[2] > request.frames) {
vid = sd::ops::slice(vid, 2, 0, request.frames);
}
sd_image_t* result_images = (sd_image_t*)calloc(vid.shape()[2], sizeof(sd_image_t));
@ -5040,7 +5074,7 @@ static sd::Tensor<float> upscale_ltx_spatial_video_latent(sd_ctx_t* sd_ctx,
audio_latent = unpack_ltxav_audio_latent(packed_latent, audio_length, latent_channels);
}
LOG_INFO("LTX latent spatial upscale: latent %dx%dx%dx%d -> x2",
LOG_INFO("LTX latent spatial upscale: latent %dx%dx%dx%d -> model output",
(int)video_latent.shape()[0],
(int)video_latent.shape()[1],
(int)video_latent.shape()[2],
@ -5366,9 +5400,46 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
LOG_INFO("LTX latent spatial upscale completed, taking %.2fs",
(upscale_end - upscale_start) * 1.0f / 1000);
x_t = std::move(upscaled_latent);
hires_request.width = static_cast<int>(x_t.shape()[0]) * hires_request.vae_scale_factor;
hires_request.height = static_cast<int>(x_t.shape()[1]) * hires_request.vae_scale_factor;
x_t = std::move(upscaled_latent);
hires_request.width = static_cast<int>(x_t.shape()[0]) * hires_request.vae_scale_factor;
hires_request.height = static_cast<int>(x_t.shape()[1]) * hires_request.vae_scale_factor;
int upscaled_latent_frames = static_cast<int>(x_t.shape()[2]);
int upscaled_frames = sd_ctx->sd->latent_frames_to_video_frames(upscaled_latent_frames);
if (upscaled_frames != hires_request.frames) {
LOG_INFO("LTX latent upsampler output latent frames %d, frames %d -> %d",
upscaled_latent_frames,
hires_request.frames,
upscaled_frames);
hires_request.frames = upscaled_frames;
}
if (sd_version_is_ltxav(sd_ctx->sd->version) && latents.audio_length > 0) {
int target_audio_length = get_ltxav_num_audio_latents(hires_request.frames, hires_request.fps);
if (target_audio_length != latents.audio_length) {
int latent_channels = sd_ctx->sd->get_latent_channel();
sd::Tensor<float> video_latent = x_t;
sd::Tensor<float> audio_latent = latents.audio_latent;
if (x_t.shape()[3] > latent_channels) {
video_latent = sd::ops::slice(x_t, 3, 0, latent_channels);
audio_latent = unpack_ltxav_audio_latent(x_t, latents.audio_length, latent_channels);
}
audio_latent = resize_ltxav_audio_latent(audio_latent, target_audio_length);
if (audio_latent.empty()) {
LOG_ERROR("failed to resize LTX audio latent for latent upscale: %d -> %d",
latents.audio_length,
target_audio_length);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->diffusion_model->free_params_buffer();
}
return false;
}
x_t = pack_ltxav_audio_and_video_latents(video_latent, audio_latent);
latents.audio_latent = std::move(audio_latent);
LOG_INFO("LTX audio latent length adjusted for latent upscale: %d -> %d",
latents.audio_length,
target_audio_length);
latents.audio_length = target_audio_length;
}
}
if ((request.hires.target_width > 0 || request.hires.target_height > 0) &&
(request.hires.target_width != hires_request.width || request.hires.target_height != hires_request.height)) {
LOG_WARN("LTX latent spatial upsampler output is %dx%d; ignoring hires target %dx%d",