diff --git a/otherarch/sdcpp/ltx_latent_upscaler.hpp b/otherarch/sdcpp/ltx_latent_upscaler.hpp index 93254454d..ea4a830c6 100644 --- a/otherarch/sdcpp/ltx_latent_upscaler.hpp +++ b/otherarch/sdcpp/ltx_latent_upscaler.hpp @@ -6,8 +6,10 @@ #include #include #include +#include #include #include +#include #include "common_dit.hpp" #include "ggml_extend.hpp" @@ -26,6 +28,10 @@ namespace LTXVUpsampler { bool spatial_upsample = true; bool temporal_upsample = false; bool rational_resampler = false; + float spatial_scale = 2.f; + int spatial_up_num = 2; + int spatial_down_den = 1; + int temporal_up_factor = 1; }; static inline bool has_tensor(const String2TensorStorage& tensor_storage_map, @@ -33,14 +39,21 @@ namespace LTXVUpsampler { return tensor_storage_map.find(name) != tensor_storage_map.end(); } + static inline int64_t get_tensor_ne(const String2TensorStorage& tensor_storage_map, + const std::string& name, + int axis, + int64_t fallback) { + auto it = tensor_storage_map.find(name); + if (it == tensor_storage_map.end() || axis < 0 || axis >= GGML_MAX_DIMS) { + return fallback; + } + return it->second.ne[axis]; + } + static inline int64_t get_tensor_ne0(const String2TensorStorage& tensor_storage_map, const std::string& name, int64_t fallback) { - auto it = tensor_storage_map.find(name); - if (it == tensor_storage_map.end()) { - return fallback; - } - return it->second.ne[0]; + return get_tensor_ne(tensor_storage_map, name, 0, fallback); } static inline int count_module_blocks(const String2TensorStorage& tensor_storage_map, @@ -71,8 +84,36 @@ namespace LTXVUpsampler { if (detected_blocks > 0) { config.num_blocks_per_stage = detected_blocks; } - config.spatial_upsample = has_tensor(tensor_storage_map, "upsampler.0.weight"); - config.temporal_upsample = has_tensor(tensor_storage_map, "temporal_upsampler.0.weight"); + config.rational_resampler = has_tensor(tensor_storage_map, "upsampler.conv.weight"); + int64_t upsampler_out_channels = get_tensor_ne0(tensor_storage_map, "upsampler.0.bias", 0); + config.spatial_upsample = config.rational_resampler || upsampler_out_channels == 4 * config.mid_channels; + config.temporal_upsample = upsampler_out_channels == 2 * config.mid_channels; + if (config.temporal_upsample) { + config.temporal_up_factor = 2; + } + if (config.rational_resampler) { + int64_t out_channels = get_tensor_ne(tensor_storage_map, + "upsampler.conv.weight", + 3, + config.mid_channels * 9); + if (config.mid_channels > 0 && out_channels % config.mid_channels == 0) { + int64_t ratio = out_channels / config.mid_channels; + int num = static_cast(std::round(std::sqrt(static_cast(ratio)))); + if (num > 0 && static_cast(num) * num == ratio) { + config.spatial_up_num = num; + } + } + if (config.spatial_up_num == 3) { + config.spatial_down_den = 2; + config.spatial_scale = 1.5f; + } else if (config.spatial_up_num == 4) { + config.spatial_down_den = 1; + config.spatial_scale = 4.f; + } else { + config.spatial_down_den = 1; + config.spatial_scale = static_cast(config.spatial_up_num); + } + } return config; } @@ -160,16 +201,135 @@ namespace LTXVUpsampler { : upscale_factor(upscale_factor) {} ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { - GGML_ASSERT(upscale_factor == 2); + GGML_ASSERT(upscale_factor > 0); int64_t h = x->ne[1]; int64_t w = x->ne[0]; - // x: [b*f, c*4, h, w] -> [b*f, c, h*2, w*2] - x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 2, 0, 1, 3)); // [b*f, h, w, c*4] - x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]); // [b*f, h*w, c*4] + GGML_ASSERT(x->ne[2] % (upscale_factor * upscale_factor) == 0); + // x: [b*f, c*p1*p2, h, w] -> [b*f, c, h*p1, w*p2] + x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 2, 0, 1, 3)); // [b*f, h, w, c*p1*p2] + x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]); // [b*f, h*w, c*p1*p2] return DiT::unpatchify(ctx->ggml_ctx, x, h, w, upscale_factor, upscale_factor, true); } }; + class TemporalPixelShuffleND : public UnaryBlock { + protected: + int upscale_factor; + + public: + explicit TemporalPixelShuffleND(int upscale_factor) + : upscale_factor(upscale_factor) {} + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { + GGML_ASSERT(upscale_factor > 0); + GGML_ASSERT(x->ne[3] % upscale_factor == 0); + const int64_t W = x->ne[0]; + const int64_t H = x->ne[1]; + const int64_t F = x->ne[2]; + const int64_t C = x->ne[3] / upscale_factor; + + // x: [b, c*p, f, h, w] -> [b, c, f*p, h, w] + x = ggml_ext_cont(ctx->ggml_ctx, x); + x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, F, upscale_factor, C); + x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); + return ggml_reshape_4d(ctx->ggml_ctx, x, W, H, F * upscale_factor, C); + } + }; + + class BlurDownsample : public GGMLBlock { + protected: + int64_t channels; + int stride; + ggml_tensor* kernel = nullptr; + std::vector kernel_data; + + void init_params(ggml_context* ctx, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "") override { + SD_UNUSED(tensor_storage_map); + if (stride == 1) { + return; + } + kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 5, 5, 1, channels); + std::string name = prefix + "kernel"; + ggml_set_name(kernel, name.c_str()); + + static const float binomial[5] = {1.f, 4.f, 6.f, 4.f, 1.f}; + kernel_data.resize(static_cast(5 * 5 * channels)); + for (int64_t c = 0; c < channels; ++c) { + for (int y = 0; y < 5; ++y) { + for (int x = 0; x < 5; ++x) { + kernel_data[static_cast(x + 5 * (y + 5 * c))] = + binomial[y] * binomial[x] / 256.f; + } + } + } + } + + public: + BlurDownsample(int64_t channels, int stride) + : channels(channels), + stride(stride) { + GGML_ASSERT(stride >= 1); + } + + void load_fixed_tensors() { + if (kernel == nullptr || kernel_data.empty()) { + return; + } + ggml_backend_tensor_set(kernel, kernel_data.data(), 0, kernel_data.size() * sizeof(float)); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + if (stride == 1) { + return x; + } + GGML_ASSERT(kernel != nullptr); + GGML_ASSERT(x->ne[2] == channels); + if (ctx->conv2d_direct_enabled) { + return ggml_conv_2d_dw_direct(ctx->ggml_ctx, kernel, x, stride, stride, 2, 2, 1, 1); + } + return ggml_conv_2d_dw(ctx->ggml_ctx, kernel, x, stride, stride, 2, 2, 1, 1); + } + }; + + class SpatialRationalResampler : public GGMLBlock { + protected: + int64_t mid_channels; + int num; + int den; + + public: + SpatialRationalResampler(int64_t mid_channels, int num, int den) + : mid_channels(mid_channels), + num(num), + den(den) { + GGML_ASSERT(num >= 1); + GGML_ASSERT(den >= 1); + blocks["conv"] = std::shared_ptr(new Conv2d(mid_channels, num * num * mid_channels, {3, 3}, {1, 1}, {1, 1})); + blocks["pixel_shuffle"] = std::shared_ptr(new PixelShuffleND(num)); + blocks["blur_down"] = std::shared_ptr(new BlurDownsample(mid_channels, den)); + } + + void load_fixed_tensors() { + auto blur_down = std::dynamic_pointer_cast(blocks["blur_down"]); + blur_down->load_fixed_tensors(); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto conv = std::dynamic_pointer_cast(blocks["conv"]); + auto pixel_shuffle = std::dynamic_pointer_cast(blocks["pixel_shuffle"]); + auto blur_down = std::dynamic_pointer_cast(blocks["blur_down"]); + + // rearrange(x, "b c f h w -> (b f) c h w") + x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); + x = conv->forward(ctx, x); + x = pixel_shuffle->forward(ctx, x); + x = blur_down->forward(ctx, x); + return ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); + } + }; + class LatentUpsampler : public GGMLBlock { public: LatentUpsamplerConfig config; @@ -177,9 +337,7 @@ namespace LTXVUpsampler { explicit LatentUpsampler(LatentUpsamplerConfig config) : config(std::move(config)) { GGML_ASSERT(this->config.dims == 3); - GGML_ASSERT(this->config.spatial_upsample); - GGML_ASSERT(!this->config.temporal_upsample); - GGML_ASSERT(!this->config.rational_resampler); + GGML_ASSERT(this->config.spatial_upsample || this->config.temporal_upsample); blocks["initial_conv"] = std::shared_ptr(new Conv3d(this->config.in_channels, this->config.mid_channels, @@ -190,12 +348,25 @@ namespace LTXVUpsampler { for (int i = 0; i < this->config.num_blocks_per_stage; ++i) { blocks["res_blocks." + std::to_string(i)] = std::shared_ptr(new ResBlock(this->config.mid_channels, this->config.dims)); } - blocks["upsampler.0"] = std::shared_ptr(new Conv2d(this->config.mid_channels, - 4 * this->config.mid_channels, - {3, 3}, - {1, 1}, - {1, 1})); - blocks["upsampler.1"] = std::shared_ptr(new PixelShuffleND(2)); + if (this->config.rational_resampler) { + blocks["upsampler"] = std::shared_ptr(new SpatialRationalResampler(this->config.mid_channels, + this->config.spatial_up_num, + this->config.spatial_down_den)); + } else if (this->config.temporal_upsample) { + blocks["upsampler.0"] = std::shared_ptr(new Conv3d(this->config.mid_channels, + this->config.temporal_up_factor * this->config.mid_channels, + {3, 3, 3}, + {1, 1, 1}, + {1, 1, 1})); + blocks["upsampler.1"] = std::shared_ptr(new TemporalPixelShuffleND(this->config.temporal_up_factor)); + } else { + blocks["upsampler.0"] = std::shared_ptr(new Conv2d(this->config.mid_channels, + 4 * this->config.mid_channels, + {3, 3}, + {1, 1}, + {1, 1})); + blocks["upsampler.1"] = std::shared_ptr(new PixelShuffleND(2)); + } for (int i = 0; i < this->config.num_blocks_per_stage; ++i) { blocks["post_upsample_res_blocks." + std::to_string(i)] = std::shared_ptr(new ResBlock(this->config.mid_channels, this->config.dims)); } @@ -207,13 +378,11 @@ namespace LTXVUpsampler { } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { - // x: [b*c, f, h, w] - // return: [b*c, f, h*2, w*2] - auto initial_conv = std::dynamic_pointer_cast(blocks["initial_conv"]); - auto initial_norm = std::dynamic_pointer_cast(blocks["initial_norm"]); - auto upsample_conv = std::dynamic_pointer_cast(blocks["upsampler.0"]); - auto pixel_shuffle = std::dynamic_pointer_cast(blocks["upsampler.1"]); - auto final_conv = std::dynamic_pointer_cast(blocks["final_conv"]); + // x: [b, c, f, h, w] + // return: [b, c, scaled_f, scaled_h, scaled_w] + auto initial_conv = std::dynamic_pointer_cast(blocks["initial_conv"]); + auto initial_norm = std::dynamic_pointer_cast(blocks["initial_norm"]); + auto final_conv = std::dynamic_pointer_cast(blocks["final_conv"]); x = initial_conv->forward(ctx, x); x = initial_norm->forward(ctx, x); @@ -226,11 +395,25 @@ namespace LTXVUpsampler { sd::ggml_graph_cut::mark_graph_cut(x, "ltx_latent_upsampler.res_blocks." + std::to_string(i), "x"); } - // rearrange(x, "b c f h w -> (b f) c h w"), - x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // [b*f, c, h, w] - x = upsample_conv->forward(ctx, x); // [b*f, c*4, h, w] - x = pixel_shuffle->forward(ctx, x); // [b*f, c, h*2, w*2] - x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // [b*c, f, h, w] + if (config.rational_resampler) { + auto upsampler = std::dynamic_pointer_cast(blocks["upsampler"]); + x = upsampler->forward(ctx, x); + } else if (config.temporal_upsample) { + auto upsample_conv = std::dynamic_pointer_cast(blocks["upsampler.0"]); + auto pixel_shuffle = std::dynamic_pointer_cast(blocks["upsampler.1"]); + x = upsample_conv->forward(ctx, x); // [b, c*2, f, h, w] + x = pixel_shuffle->forward(ctx, x); // [b, c, f*2, h, w] + x = ggml_ext_slice(ctx->ggml_ctx, x, 2, 1, x->ne[2]); // x[:, :, 1:, :, :] + } else { + auto upsample_conv = std::dynamic_pointer_cast(blocks["upsampler.0"]); + auto pixel_shuffle = std::dynamic_pointer_cast(blocks["upsampler.1"]); + + // rearrange(x, "b c f h w -> (b f) c h w"), + x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // [b*f, c, h, w] + x = upsample_conv->forward(ctx, x); // [b*f, c*4, h, w] + x = pixel_shuffle->forward(ctx, x); // [b*f, c, h*2, w*2] + x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // [b*c, f, h, w] + } sd::ggml_graph_cut::mark_graph_cut(x, "ltx_latent_upsampler.spatial_up", "x"); for (int i = 0; i < config.num_blocks_per_stage; ++i) { @@ -243,6 +426,14 @@ namespace LTXVUpsampler { sd::ggml_graph_cut::mark_graph_cut(x, "ltx_latent_upsampler.final", "x"); return x; } + + void load_fixed_tensors() { + if (!config.rational_resampler) { + return; + } + auto upsampler = std::dynamic_pointer_cast(blocks["upsampler"]); + upsampler->load_fixed_tensors(); + } }; struct LatentUpsamplerRunner : public GGMLRunner { @@ -265,20 +456,24 @@ namespace LTXVUpsampler { } const auto& tensor_storage_map = model_loader.get_tensor_storage_map(); + bool has_regular_upsampler = has_tensor(tensor_storage_map, "upsampler.0.weight"); + bool has_rational_spatial = has_tensor(tensor_storage_map, "upsampler.conv.weight"); if (!has_tensor(tensor_storage_map, "post_upsample_res_blocks.0.conv2.bias") || - !has_tensor(tensor_storage_map, "upsampler.0.weight")) { - LOG_ERROR("unsupported LTX latent upsampler weights: expected spatial upsampler tensors"); + (!has_regular_upsampler && !has_rational_spatial)) { + LOG_ERROR("unsupported LTX latent upsampler weights: expected upsampler tensors"); return false; } LatentUpsamplerConfig config = detect_config_from_weights(tensor_storage_map); - if (config.dims != 3 || !config.spatial_upsample || config.temporal_upsample || - config.rational_resampler) { - LOG_ERROR("unsupported LTX latent upsampler config: dims=%d spatial=%d temporal=%d rational=%d", + if (config.dims != 3 || (!config.spatial_upsample && !config.temporal_upsample) || + config.spatial_up_num < 1 || config.spatial_down_den < 1 || config.temporal_up_factor < 1) { + LOG_ERROR("unsupported LTX latent upsampler config: dims=%d spatial=%d temporal=%d rational=%d scale=%.3f temporal_factor=%d", config.dims, config.spatial_upsample, config.temporal_upsample, - config.rational_resampler); + config.rational_resampler, + config.spatial_scale, + config.temporal_up_factor); return false; } @@ -291,15 +486,23 @@ namespace LTXVUpsampler { std::map tensors; model->get_param_tensors(tensors); - if (!model_loader.load_tensors(tensors, {}, n_threads)) { + std::set ignore_tensors; + if (config.rational_resampler) { + ignore_tensors.insert("upsampler.blur_down.kernel"); + } + if (!model_loader.load_tensors(tensors, ignore_tensors, n_threads)) { LOG_ERROR("load LTX latent upsampler tensors failed"); return false; } + model->load_fixed_tensors(); - LOG_INFO("LTX latent upsampler loaded: in_channels=%" PRId64 ", mid_channels=%" PRId64 ", blocks=%d", + LOG_INFO("LTX latent upsampler loaded: in_channels=%" PRId64 ", mid_channels=%" PRId64 ", blocks=%d, scale=%.3f, temporal_factor=%d, rational=%d", config.in_channels, config.mid_channels, - config.num_blocks_per_stage); + config.num_blocks_per_stage, + config.spatial_scale, + config.temporal_up_factor, + config.rational_resampler); return true; } diff --git a/otherarch/sdcpp/ltx_vae.hpp b/otherarch/sdcpp/ltx_vae.hpp index 756741a43..5fdce6c28 100644 --- a/otherarch/sdcpp/ltx_vae.hpp +++ b/otherarch/sdcpp/ltx_vae.hpp @@ -153,7 +153,7 @@ namespace LTXVAE { GGML_ASSERT(x->ne[2] >= temporal_pad); - int end_idx = x->ne[2] - temporal_pad; + int end_idx = (int)x->ne[2] - temporal_pad; int start_idx = std::max(end_idx - pad, 0); // Save a contiguous copy of the last `pad` frames so the large `x` diff --git a/otherarch/sdcpp/stable-diffusion.cpp b/otherarch/sdcpp/stable-diffusion.cpp index 4643a919f..b7723f876 100644 --- a/otherarch/sdcpp/stable-diffusion.cpp +++ b/otherarch/sdcpp/stable-diffusion.cpp @@ -963,7 +963,7 @@ public: LOG_INFO("using TAE for preview"); preview_vae = create_tae(); preview_vae->set_max_graph_vram_bytes(max_graph_vram_bytes); - get_param_tensors_p(first_stage_model, vae_mmap, "vae"); + get_param_tensors_p(preview_vae, vae_mmap, "tae"); } } @@ -1835,23 +1835,18 @@ public: std::function step_callback, void* step_callback_data, bool is_noisy) { + bool is_video = preview_latent_tensor_is_video(latents); + uint32_t dim = is_video ? static_cast(latents.shape()[3]) : static_cast(latents.shape()[2]); + int channels = get_latent_channel(); + auto _latents = channels != dim ? is_video ? sd::ops::slice(latents, 3, 0, channels) + : sd::ops::slice(latents, 2, 0, channels) + : latents; if (preview_mode == PREVIEW_PROJ) { - sd::Tensor _latents = latents; int patch_sz = 1; const float(*latent_rgb_proj)[3] = nullptr; float* latent_rgb_bias = nullptr; - bool is_video = preview_latent_tensor_is_video(latents); - uint32_t dim = is_video ? static_cast(latents.shape()[3]) : static_cast(latents.shape()[2]); - if (version == VERSION_LTXAV) { - if (is_video) { - _latents = sd::ops::slice(_latents, 3, 0, 128); - } else { - _latents = sd::ops::slice(_latents, 2, 0, 128); - } - dim = 128; - } - if (dim == 128) { + if (channels == 128) { if (sd_version_uses_flux2_vae(version)) { latent_rgb_proj = flux2_latent_rgb_proj; latent_rgb_bias = flux2_latent_rgb_bias; @@ -1863,7 +1858,7 @@ public: LOG_WARN("No latent to RGB projection known for this model"); return; } - } else if (dim == 48) { + } else if (channels == 48) { if (sd_version_is_wan(version)) { latent_rgb_proj = wan_22_latent_rgb_proj; latent_rgb_bias = wan_22_latent_rgb_bias; @@ -1871,7 +1866,7 @@ public: LOG_WARN("No latent to RGB projection known for this model"); return; } - } else if (dim == 16) { + } else if (channels == 16) { if (sd_version_is_sd3(version)) { latent_rgb_proj = sd3_latent_rgb_proj; latent_rgb_bias = sd3_latent_rgb_bias; @@ -1885,7 +1880,7 @@ public: LOG_WARN("No latent to RGB projection known for this model"); return; } - } else if (dim == 4) { + } else if (channels == 4) { if (sd_version_is_sdxl(version)) { latent_rgb_proj = sdxl_latent_rgb_proj; latent_rgb_bias = sdxl_latent_rgb_bias; @@ -1896,8 +1891,8 @@ public: LOG_WARN("No latent to RGB projection known for this model"); return; } - } else if (dim != 3) { - LOG_WARN("No latent to RGB projection known for this model"); + } else if (channels != 3) { + LOG_WARN("No latent to RGB projection known for this model (dim = %d)", dim); return; } @@ -1922,14 +1917,13 @@ public: if (preview_mode == PREVIEW_VAE || preview_mode == PREVIEW_TAE) { sd::Tensor vae_latents; sd::Tensor decoded; - bool is_video = preview_latent_tensor_is_video(latents); if (preview_vae) { preview_vae->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling); - vae_latents = preview_vae->diffusion_to_vae_latents(latents); + vae_latents = preview_vae->diffusion_to_vae_latents(_latents); decoded = preview_vae->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true); } else { first_stage_model->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling); - vae_latents = first_stage_model->diffusion_to_vae_latents(latents); + vae_latents = first_stage_model->diffusion_to_vae_latents(_latents); decoded = first_stage_model->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true); } if (decoded.empty()) { @@ -2384,19 +2378,41 @@ public: int vae_scale_factor = get_vae_scale_factor(); int W = width / vae_scale_factor; int H = height / vae_scale_factor; - int T = frames; - if (sd_version_is_ltxav(version)) { - T = ((T - 1) / 8) + 1; - } else if (sd_version_is_wan(version)) { - T = ((T - 1) / 4) + 1; - } - int C = get_latent_channel(); + int T = video_frames_to_latent_frames(frames); + int C = get_latent_channel(); if (video) { return sd::zeros({W, H, T, C, 1}); } return sd::zeros({W, H, C, 1}); } + int video_frames_to_latent_frames(int frames) { + int latent_frames = frames; + if (sd_version_is_ltxav(version)) { + latent_frames = ((frames - 1) / 8) + 1; + } else if (sd_version_is_wan(version)) { + latent_frames = ((frames - 1) / 4) + 1; + } + return latent_frames; + } + + int latent_frames_to_video_frames(int latent_frames) { + if (latent_frames <= 0) { + return latent_frames; + } + if (sd_version_is_ltxav(version)) { + return (latent_frames - 1) * 8 + 1; + } + if (sd_version_is_wan(version)) { + return (latent_frames - 1) * 4 + 1; + } + return latent_frames; + } + + int align_video_frames(int frames) { + return latent_frames_to_video_frames(video_frames_to_latent_frames(frames)); + } + sd::Tensor encode_to_vae_latents(const sd::Tensor& x) { auto latents = first_stage_model->encode(n_threads, x, vae_tiling_params, circular_x, circular_y); if (latents.empty()) { @@ -3248,16 +3264,12 @@ struct GenerationRequest { } GenerationRequest(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params) { - prompt = SAFE_STR(sd_vid_gen_params->prompt); - negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt); - width = sd_vid_gen_params->width; - height = sd_vid_gen_params->height; - requested_frames = std::max(1, sd_vid_gen_params->video_frames); - if (sd_version_is_ltxav(sd_ctx->sd->version)) { - frames = ((requested_frames - 1 + 7) / 8) * 8 + 1; - } else { - frames = (requested_frames - 1) / 4 * 4 + 1; - } + prompt = SAFE_STR(sd_vid_gen_params->prompt); + negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt); + width = sd_vid_gen_params->width; + height = sd_vid_gen_params->height; + requested_frames = std::max(1, sd_vid_gen_params->video_frames); + frames = sd_ctx->sd->align_video_frames(requested_frames); clip_skip = sd_vid_gen_params->clip_skip; fps = std::max(1, sd_vid_gen_params->fps); vae_scale_factor = sd_ctx->sd->get_vae_scale_factor(); @@ -3815,6 +3827,30 @@ static sd::Tensor unpack_ltxav_audio_latent(const sd::Tensor& pack return audio_latent; } +static sd::Tensor make_ltxav_empty_audio_latent(int audio_length) { + if (audio_length <= 0) { + return {}; + } + constexpr int kLtxavAudioFrequencyBins = 16; + constexpr int kLtxavAudioChannels = 8; + return sd::zeros({kLtxavAudioFrequencyBins, audio_length, kLtxavAudioChannels, 1}); +} + +static sd::Tensor resize_ltxav_audio_latent(const sd::Tensor& audio_latent, + int target_audio_length) { + auto resized = make_ltxav_empty_audio_latent(target_audio_length); + if (resized.empty() || audio_latent.empty()) { + return resized; + } + GGML_ASSERT(audio_latent.dim() == 3 || audio_latent.dim() == 4); + int copy_length = std::min(static_cast(audio_latent.shape()[1]), target_audio_length); + if (copy_length > 0) { + auto copied = sd::ops::slice(audio_latent, 1, 0, copy_length); + sd::ops::slice_assign(&resized, 1, 0, copy_length, copied); + } + return resized; +} + static int get_ltxav_num_audio_latents(int frames, int fps) { GGML_ASSERT(frames > 0); GGML_ASSERT(fps > 0); @@ -4644,10 +4680,8 @@ static std::optional prepare_video_generation_latents(sd } if (sd_version_is_ltxav(sd_ctx->sd->version)) { - constexpr int kLtxavAudioFrequencyBins = 16; - constexpr int kLtxavAudioChannels = 8; - latents.audio_length = get_ltxav_num_audio_latents(request->frames, request->fps); - latents.audio_latent = sd::zeros({kLtxavAudioFrequencyBins, latents.audio_length, kLtxavAudioChannels, 1}); + latents.audio_length = get_ltxav_num_audio_latents(request->frames, request->fps); + latents.audio_latent = make_ltxav_empty_audio_latent(latents.audio_length); } if (sd_version_is_ltxav(sd_ctx->sd->version)) { @@ -4997,9 +5031,9 @@ static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx, (int)vid.shape()[1], (int)vid.shape()[2], (int)vid.shape()[3]); - if (request.requested_frames > 0 && - vid.shape()[2] > request.requested_frames) { - vid = sd::ops::slice(vid, 2, 0, request.requested_frames); + if (request.frames > 0 && + vid.shape()[2] > request.frames) { + vid = sd::ops::slice(vid, 2, 0, request.frames); } sd_image_t* result_images = (sd_image_t*)calloc(vid.shape()[2], sizeof(sd_image_t)); @@ -5040,7 +5074,7 @@ static sd::Tensor upscale_ltx_spatial_video_latent(sd_ctx_t* sd_ctx, audio_latent = unpack_ltxav_audio_latent(packed_latent, audio_length, latent_channels); } - LOG_INFO("LTX latent spatial upscale: latent %dx%dx%dx%d -> x2", + LOG_INFO("LTX latent spatial upscale: latent %dx%dx%dx%d -> model output", (int)video_latent.shape()[0], (int)video_latent.shape()[1], (int)video_latent.shape()[2], @@ -5366,9 +5400,46 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, LOG_INFO("LTX latent spatial upscale completed, taking %.2fs", (upscale_end - upscale_start) * 1.0f / 1000); - x_t = std::move(upscaled_latent); - hires_request.width = static_cast(x_t.shape()[0]) * hires_request.vae_scale_factor; - hires_request.height = static_cast(x_t.shape()[1]) * hires_request.vae_scale_factor; + x_t = std::move(upscaled_latent); + hires_request.width = static_cast(x_t.shape()[0]) * hires_request.vae_scale_factor; + hires_request.height = static_cast(x_t.shape()[1]) * hires_request.vae_scale_factor; + int upscaled_latent_frames = static_cast(x_t.shape()[2]); + int upscaled_frames = sd_ctx->sd->latent_frames_to_video_frames(upscaled_latent_frames); + if (upscaled_frames != hires_request.frames) { + LOG_INFO("LTX latent upsampler output latent frames %d, frames %d -> %d", + upscaled_latent_frames, + hires_request.frames, + upscaled_frames); + hires_request.frames = upscaled_frames; + } + if (sd_version_is_ltxav(sd_ctx->sd->version) && latents.audio_length > 0) { + int target_audio_length = get_ltxav_num_audio_latents(hires_request.frames, hires_request.fps); + if (target_audio_length != latents.audio_length) { + int latent_channels = sd_ctx->sd->get_latent_channel(); + sd::Tensor video_latent = x_t; + sd::Tensor audio_latent = latents.audio_latent; + if (x_t.shape()[3] > latent_channels) { + video_latent = sd::ops::slice(x_t, 3, 0, latent_channels); + audio_latent = unpack_ltxav_audio_latent(x_t, latents.audio_length, latent_channels); + } + audio_latent = resize_ltxav_audio_latent(audio_latent, target_audio_length); + if (audio_latent.empty()) { + LOG_ERROR("failed to resize LTX audio latent for latent upscale: %d -> %d", + latents.audio_length, + target_audio_length); + if (sd_ctx->sd->free_params_immediately) { + sd_ctx->sd->diffusion_model->free_params_buffer(); + } + return false; + } + x_t = pack_ltxav_audio_and_video_latents(video_latent, audio_latent); + latents.audio_latent = std::move(audio_latent); + LOG_INFO("LTX audio latent length adjusted for latent upscale: %d -> %d", + latents.audio_length, + target_audio_length); + latents.audio_length = target_audio_length; + } + } if ((request.hires.target_width > 0 || request.hires.target_height > 0) && (request.hires.target_width != hires_request.width || request.hires.target_height != hires_request.height)) { LOG_WARN("LTX latent spatial upsampler output is %dx%d; ignoring hires target %dx%d",