diff --git a/otherarch/sdcpp/common/common.cpp b/otherarch/sdcpp/common/common.cpp index f32d0c6ff..519e8aae6 100644 --- a/otherarch/sdcpp/common/common.cpp +++ b/otherarch/sdcpp/common/common.cpp @@ -835,6 +835,10 @@ ArgOptions SDGenerationParams::get_options() { "--extra-sample-args", "extra sampler/scheduler args, key=value list. lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma", &extra_sample_args}, + {"", + "--extra-tiling-args", + "extra VAE tiling args, key=value list. LTX video VAE supports temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)", + &extra_tiling_args}, }; options.int_options = { @@ -1780,6 +1784,9 @@ bool SDGenerationParams::from_json_str( if (tiling_json.contains("rel_size_y") && tiling_json["rel_size_y"].is_number()) { vae_tiling_params.rel_size_y = tiling_json["rel_size_y"]; } + if (tiling_json.contains("extra_tiling_args") && tiling_json["extra_tiling_args"].is_string()) { + extra_tiling_args = tiling_json["extra_tiling_args"].get(); + } } if (!parse_lora_json_field(j, lora_path_resolver, lora_map, high_noise_lora_map)) { @@ -2002,6 +2009,8 @@ bool SDGenerationParams::initialize_cache_params() { } bool SDGenerationParams::resolve(const std::string& lora_model_dir, const std::string& hires_upscalers_dir, bool strict) { + vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str(); + if (high_noise_sample_params.sample_steps <= 0) { high_noise_sample_params.sample_steps = -1; } @@ -2188,6 +2197,7 @@ sd_img_gen_params_t SDGenerationParams::to_sd_img_gen_params_t() { sample_params.custom_sigmas_count = static_cast(custom_sigmas.size()); sample_params.extra_sample_args = extra_sample_args.empty() ? nullptr : extra_sample_args.c_str(); high_noise_sample_params.extra_sample_args = high_noise_extra_sample_args.empty() ? nullptr : high_noise_extra_sample_args.c_str(); + vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str(); cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str(); sd_pm_params_t pm_params = { @@ -2261,6 +2271,7 @@ sd_vid_gen_params_t SDGenerationParams::to_sd_vid_gen_params_t() { sample_params.custom_sigmas_count = static_cast(custom_sigmas.size()); sample_params.extra_sample_args = extra_sample_args.empty() ? nullptr : extra_sample_args.c_str(); high_noise_sample_params.extra_sample_args = high_noise_extra_sample_args.empty() ? nullptr : high_noise_extra_sample_args.c_str(); + vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str(); cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str(); params.loras = lora_vec.empty() ? nullptr : lora_vec.data(); @@ -2386,7 +2397,8 @@ std::string SDGenerationParams::to_string() const { << vae_tiling_params.tile_size_y << ", " << vae_tiling_params.target_overlap << ", " << vae_tiling_params.rel_size_x << ", " - << vae_tiling_params.rel_size_y << " },\n" + << vae_tiling_params.rel_size_y << ", " + << "\"" << extra_tiling_args << "\" },\n" << "}"; return oss.str(); } @@ -2565,14 +2577,18 @@ std::string build_sdcpp_image_metadata_json(const SDContextParams& ctx_params, }; } - if (gen_params.vae_tiling_params.enabled) { + if (gen_params.vae_tiling_params.enabled || + gen_params.vae_tiling_params.temporal_tiling || + !gen_params.extra_tiling_args.empty()) { root["vae_tiling"] = { {"enabled", gen_params.vae_tiling_params.enabled}, + {"temporal_tiling", gen_params.vae_tiling_params.temporal_tiling}, {"tile_size_x", gen_params.vae_tiling_params.tile_size_x}, {"tile_size_y", gen_params.vae_tiling_params.tile_size_y}, {"target_overlap", gen_params.vae_tiling_params.target_overlap}, {"rel_size_x", gen_params.vae_tiling_params.rel_size_x}, {"rel_size_y", gen_params.vae_tiling_params.rel_size_y}, + {"extra_tiling_args", gen_params.extra_tiling_args}, }; } diff --git a/otherarch/sdcpp/common/common.h b/otherarch/sdcpp/common/common.h index d526ca3a5..ca367f7ee 100644 --- a/otherarch/sdcpp/common/common.h +++ b/otherarch/sdcpp/common/common.h @@ -189,7 +189,8 @@ struct SDGenerationParams { int video_frames = 1; int fps = 16; float vace_strength = 1.f; - sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f}; + sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr}; + std::string extra_tiling_args; std::string pm_id_images_dir; std::string pm_id_embed_path; diff --git a/otherarch/sdcpp/denoiser.hpp b/otherarch/sdcpp/denoiser.hpp index ee2ef380c..3d884bc79 100644 --- a/otherarch/sdcpp/denoiser.hpp +++ b/otherarch/sdcpp/denoiser.hpp @@ -496,84 +496,26 @@ struct LTX2Scheduler : SigmaScheduler { parse_extra_sample_args(extra_sample_args); } - static std::string trim(std::string value) { - const char* whitespace = " \t\r\n"; - size_t begin = value.find_first_not_of(whitespace); - if (begin == std::string::npos) { - return ""; - } - size_t end = value.find_last_not_of(whitespace); - return value.substr(begin, end - begin + 1); - } - void parse_extra_sample_args(const char* extra_sample_args) { - if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') { - return; - } - - std::string raw(extra_sample_args); - size_t start = 0; - auto parse_arg = [&](const std::string& item) { - std::string token = trim(item); - if (token.empty()) { - return; - } - size_t eq = token.find('='); - if (eq == std::string::npos) { - LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str()); - return; - } - - std::string key = trim(token.substr(0, eq)); - std::string value = trim(token.substr(eq + 1)); - auto parse_float = [&](float* out) -> bool { - try { - size_t consumed = 0; - float parsed = std::stof(value, &consumed); - if (!trim(value.substr(consumed)).empty()) { - return false; - } - *out = parsed; - return true; - } catch (const std::exception&) { - return false; + for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "ltx2 scheduler arg")) { + if (key == "max_shift") { + if (!parse_strict_float(value, max_shift)) { + LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str()); } - }; - try { - if (key == "max_shift") { - if (!parse_float(&max_shift)) { - LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str()); - } - } else if (key == "base_shift") { - if (!parse_float(&base_shift)) { - LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str()); - } - } else if (key == "terminal") { - if (!parse_float(&terminal)) { - LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str()); - } - } else if (key == "stretch") { - std::string v = value; - std::transform(v.begin(), v.end(), v.begin(), [](unsigned char c) { return static_cast(std::tolower(c)); }); - if (v == "1" || v == "true" || v == "yes" || v == "on") { - stretch = true; - } else if (v == "0" || v == "false" || v == "no" || v == "off") { - stretch = false; - } else { - LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str()); - } - } else { - LOG_WARN("ignoring unknown ltx2 scheduler arg '%s'", key.c_str()); + } else if (key == "base_shift") { + if (!parse_strict_float(value, base_shift)) { + LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str()); } - } catch (const std::exception&) { - LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str()); - } - }; - - for (size_t pos = 0; pos <= raw.size(); ++pos) { - if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') { - parse_arg(raw.substr(start, pos - start)); - start = pos + 1; + } else if (key == "terminal") { + if (!parse_strict_float(value, terminal)) { + LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str()); + } + } else if (key == "stretch") { + if (!parse_strict_bool(value, stretch)) { + LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str()); + } + } else { + LOG_WARN("ignoring unknown ltx2 scheduler arg '%s'", key.c_str()); } } } @@ -1276,7 +1218,7 @@ static sd::Tensor sample_dpmpp_2m_v2(denoise_cb_t model, return x; } -using SamplerExtraArgs = std::vector>; +using SamplerExtraArgs = KeyValueArgs; static sd::Tensor sample_lcm(denoise_cb_t model, sd::Tensor x, @@ -1296,15 +1238,8 @@ static sd::Tensor sample_lcm(denoise_cb_t model, for (const auto& [key, value] : extra_sample_args) { float parsed = 0.0f; - try { - size_t consumed = 0; - parsed = std::stof(value, &consumed); - if (trim(value.substr(consumed)).size() != 0) { - LOG_WARN("ignoring invalid lcm extra sample arg '%s'", key.c_str()); - continue; - } - } catch (const std::exception&) { - LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str()); + if (!parse_strict_float(value, parsed)) { + LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str(), value.c_str()); continue; } if (key == "noise_clip_std") { @@ -1861,15 +1796,8 @@ static sd::Tensor sample_gradient_estimation(denoise_cb_t model, for (const auto& [key, value] : extra_sample_args) { float parsed = 0.0f; - try { - size_t consumed = 0; - parsed = std::stof(value, &consumed); - if (trim(value.substr(consumed)).size() != 0) { - LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str()); - continue; - } - } catch (const std::exception&) { - LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str()); + if (!parse_strict_float(value, parsed)) { + LOG_WARN("ignoring invalid euler_ge extra sample arg '%s=%s'", key.c_str(), value.c_str()); continue; } if (key == "gamma") { @@ -1916,46 +1844,6 @@ static sd::Tensor sample_gradient_estimation(denoise_cb_t model, return x; } -static SamplerExtraArgs parse_sampler_args(const char* extra_sample_args) { - SamplerExtraArgs pairs; - - if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') { - return pairs; - } - - auto trim = [](std::string value) -> std::string { - const char* whitespace = " \t\r\n"; - size_t begin = value.find_first_not_of(whitespace); - if (begin == std::string::npos) { - return ""; - } - size_t end = value.find_last_not_of(whitespace); - return value.substr(begin, end - begin + 1); - }; - - std::string raw(extra_sample_args); - size_t start = 0; - - for (size_t pos = 0; pos <= raw.size(); ++pos) { - if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') { - std::string item = raw.substr(start, pos - start); - std::string token = trim(item); - - if (!token.empty()) { - size_t eq = token.find('='); - if (eq != std::string::npos) { - std::string key = trim(token.substr(0, eq)); - std::string value = trim(token.substr(eq + 1)); - pairs.emplace_back(std::move(key), std::move(value)); - } - } - start = pos + 1; - } - } - - return pairs; -} - // k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t static sd::Tensor sample_k_diffusion(sample_method_t method, denoise_cb_t model, @@ -1965,7 +1853,7 @@ static sd::Tensor sample_k_diffusion(sample_method_t method, float eta, bool is_flow_denoiser, const char* extra_sample_args) { - SamplerExtraArgs extra_args = parse_sampler_args(extra_sample_args); + SamplerExtraArgs extra_args = parse_key_value_args(extra_sample_args, "extra sample arg"); switch (method) { case EULER_A_SAMPLE_METHOD: return sample_euler_ancestral(model, std::move(x), sigmas, rng, is_flow_denoiser, eta); diff --git a/otherarch/sdcpp/ggml_extend.hpp b/otherarch/sdcpp/ggml_extend.hpp index 9407b7f44..a7d84514d 100644 --- a/otherarch/sdcpp/ggml_extend.hpp +++ b/otherarch/sdcpp/ggml_extend.hpp @@ -1602,6 +1602,23 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) { return num; } +__STATIC_INLINE__ ggml_tensor* ggml_ext_vec_concat(ggml_context* ctx, + std::vector& tensors, + int dim) { + while (tensors.size() > 1) { + std::vector next_level; + for (size_t i = 0; i < tensors.size(); i += 2) { + if (i + 1 < tensors.size()) { + next_level.push_back(ggml_concat(ctx, tensors[i], tensors[i + 1], dim)); + } else { + next_level.push_back(tensors[i]); + } + } + tensors = std::move(next_level); + } + return tensors[0]; +} + /* SDXL with LoRA requires more space */ #define MAX_PARAMS_TENSOR_NUM 32768 #define MAX_GRAPH_SIZE 327680 @@ -3139,6 +3156,163 @@ public: } }; +class Conv2d_grouped : public UnaryBlock { +protected: + int64_t in_channels; + int64_t out_channels; + int groups; + std::pair kernel_size; + std::pair stride; + std::pair padding; + std::pair dilation; + bool bias; + float scale = 1.f; + std::string prefix; + + void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override { + this->prefix = prefix; + enum ggml_type wtype = GGML_TYPE_F16; + params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels / groups, out_channels); + if (bias) { + enum ggml_type wtype = GGML_TYPE_F32; + params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_channels); + } + } + +public: + Conv2d_grouped(int64_t in_channels, + int64_t out_channels, + int groups, + std::pair kernel_size, + std::pair stride = {1, 1}, + std::pair padding = {0, 0}, + std::pair dilation = {1, 1}, + bool bias = true) + : in_channels(in_channels), + out_channels(out_channels), + groups(groups), + kernel_size(kernel_size), + stride(stride), + padding(padding), + dilation(dilation), + bias(bias) {} + + void set_scale(float scale_value) { + scale = scale_value; + } + + std::string get_desc() { + return "Conv2d_grouped"; + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + ggml_tensor* w = params["weight"]; + ggml_tensor* b = nullptr; + if (bias) { + b = params["bias"]; + } + + if (groups == 1) { + if (ctx->weight_adapter) { + WeightAdapter::ForwardParams forward_params; + forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_CONV2D; + forward_params.conv2d.s0 = stride.second; + forward_params.conv2d.s1 = stride.first; + forward_params.conv2d.p0 = padding.second; + forward_params.conv2d.p1 = padding.first; + forward_params.conv2d.d0 = dilation.second; + forward_params.conv2d.d1 = dilation.first; + forward_params.conv2d.direct = ctx->conv2d_direct_enabled; + forward_params.conv2d.circular_x = ctx->circular_x_enabled; + forward_params.conv2d.circular_y = ctx->circular_y_enabled; + forward_params.conv2d.scale = scale; + return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, ctx->backend, x, w, b, prefix, forward_params); + } + return ggml_ext_conv_2d(ctx->ggml_ctx, x, w, b, + stride.second, stride.first, + padding.second, padding.first, + dilation.second, dilation.first, + ctx->conv2d_direct_enabled, + ctx->circular_x_enabled, + ctx->circular_y_enabled, + scale); + } + + if (groups == in_channels && groups == out_channels) { + ggml_tensor* res; + if (ctx->conv2d_direct_enabled) { + res = ggml_conv_2d_dw_direct(ctx->ggml_ctx, x, w, + stride.second, stride.first, + padding.second, padding.first, + dilation.second, dilation.first); + } else { + res = ggml_conv_2d_dw(ctx->ggml_ctx, x, w, + stride.second, stride.first, + padding.second, padding.first, + dilation.second, dilation.first); + } + if (b) { + res = ggml_add(ctx->ggml_ctx, res, b); + } + return res; + } + + int64_t ic_g = in_channels / groups; + int64_t oc_g = out_channels / groups; + + std::vector out_slices(groups); + + for (int i = 0; i < groups; ++i) { + size_t x_offset = i * ic_g * x->nb[2]; + ggml_tensor* x_i = ggml_view_4d(ctx->ggml_ctx, x, + x->ne[0], x->ne[1], ic_g, x->ne[3], + x->nb[1], x->nb[2], x->nb[3], + x_offset); + + size_t w_offset = i * oc_g * w->nb[3]; + ggml_tensor* w_i = ggml_view_4d(ctx->ggml_ctx, w, + w->ne[0], w->ne[1], w->ne[2], oc_g, + w->nb[1], w->nb[2], w->nb[3], + w_offset); + + ggml_tensor* b_i = nullptr; + if (b) { + size_t b_offset = i * oc_g * b->nb[0]; + b_i = ggml_view_1d(ctx->ggml_ctx, b, oc_g, b_offset); + } + + if (ctx->weight_adapter) { + WeightAdapter::ForwardParams forward_params; + forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_CONV2D; + forward_params.conv2d.s0 = stride.second; + forward_params.conv2d.s1 = stride.first; + forward_params.conv2d.p0 = padding.second; + forward_params.conv2d.p1 = padding.first; + forward_params.conv2d.d0 = dilation.second; + forward_params.conv2d.d1 = dilation.first; + forward_params.conv2d.direct = ctx->conv2d_direct_enabled; + forward_params.conv2d.circular_x = ctx->circular_x_enabled; + forward_params.conv2d.circular_y = ctx->circular_y_enabled; + forward_params.conv2d.scale = scale; + out_slices[i] = ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, ctx->backend, x_i, w_i, b_i, prefix, forward_params); + } else { + out_slices[i] = ggml_ext_conv_2d(ctx->ggml_ctx, x_i, w_i, b_i, + stride.second, stride.first, + padding.second, padding.first, + dilation.second, dilation.first, + ctx->conv2d_direct_enabled, + ctx->circular_x_enabled, + ctx->circular_y_enabled, + scale); + } + } + + ggml_tensor* out = ggml_ext_vec_concat(ctx->ggml_ctx, out_slices, 2); + + return out; + } +}; + class Conv3d : public UnaryBlock { protected: int64_t in_channels; diff --git a/otherarch/sdcpp/ltx_audio_vae.h b/otherarch/sdcpp/ltx_audio_vae.h index aad8e0f87..a160338f4 100644 --- a/otherarch/sdcpp/ltx_audio_vae.h +++ b/otherarch/sdcpp/ltx_audio_vae.h @@ -2,6 +2,7 @@ #define __SD_LTX_AUDIO_VAE_H__ #include +#include #include #include #include @@ -171,90 +172,59 @@ namespace LTXV { } }; - static sd::Tensor squeeze_trailing_singleton_dims(sd::Tensor tensor) { - while (tensor.dim() > 0 && tensor.shape().back() == 1) { - tensor = tensor.squeeze(static_cast(tensor.dim() - 1)); - } - return tensor; - } + static ggml_tensor* compute_log_mel_spectrogram(GGMLRunnerContext* runner_ctx, + ggml_tensor* waveform, + ggml_tensor* forward_basis, + ggml_tensor* mel_basis, + int hop_length) { + auto ctx = runner_ctx->ggml_ctx; + GGML_ASSERT(ctx != nullptr); + GGML_ASSERT(waveform != nullptr); + GGML_ASSERT(forward_basis != nullptr); + GGML_ASSERT(mel_basis != nullptr); + GGML_ASSERT(waveform->type == GGML_TYPE_F32); + GGML_ASSERT(forward_basis->type == GGML_TYPE_F32); + GGML_ASSERT(mel_basis->type == GGML_TYPE_F32); + GGML_ASSERT(forward_basis->ne[1] == 1); - static sd::Tensor normalize_waveform_for_host(sd::Tensor waveform) { - waveform = squeeze_trailing_singleton_dims(std::move(waveform)); - if (waveform.empty()) { - return waveform; - } - if (waveform.dim() == 1) { - return waveform.reshape({waveform.shape()[0], 1, 1}); - } - if (waveform.dim() == 2) { - return waveform.reshape({waveform.shape()[0], waveform.shape()[1], 1}); - } - if (waveform.dim() == 3) { - return waveform; - } - throw std::runtime_error("Unsupported waveform rank for host processing: rank=" + std::to_string(waveform.dim())); - } + const int64_t time = waveform->ne[0]; + const int64_t channels = waveform->ne[1]; + const int64_t batch = waveform->ne[2]; + const int64_t filter_len = forward_basis->ne[0]; + const int64_t stft_channels = forward_basis->ne[2]; + const int64_t n_freqs = stft_channels / 2; + const int64_t n_mels = mel_basis->ne[1]; + const int64_t left_pad = std::max(0, filter_len - hop_length); + const int64_t padded_time = time + left_pad; + const int64_t frame_count = padded_time < filter_len ? 0 : 1 + (padded_time - filter_len) / hop_length; - static sd::Tensor load_param_tensor_f32(ggml_tensor* tensor) { - GGML_ASSERT(tensor != nullptr); - return squeeze_trailing_singleton_dims(sd::make_sd_tensor_from_ggml(tensor)); - } + GGML_ASSERT(stft_channels % 2 == 0); + GGML_ASSERT(mel_basis->ne[0] == n_freqs); + GGML_ASSERT(waveform->ne[3] == 1); + GGML_ASSERT(frame_count > 0); - static sd::Tensor compute_log_mel_spectrogram(const sd::Tensor& waveform_in, - const sd::Tensor& forward_basis, - const sd::Tensor& mel_basis, - int hop_length) { - auto waveform = normalize_waveform_for_host(waveform_in); - GGML_ASSERT(forward_basis.dim() >= 3); - GGML_ASSERT(mel_basis.dim() >= 2); - - const int64_t time = waveform.shape()[0]; - const int64_t channels = waveform.shape()[1]; - const int64_t batch = waveform.shape()[2]; - const int64_t filter_len = forward_basis.shape()[0]; - const int64_t basis_freq2 = forward_basis.shape().back(); - const int64_t n_freqs = basis_freq2 / 2; - const int64_t n_mels = mel_basis.shape()[1]; - const int64_t left_pad = std::max(0, filter_len - hop_length); - const int64_t padded_time = time + left_pad; - const int64_t frame_count = padded_time < filter_len ? 0 : 1 + (padded_time - filter_len) / hop_length; - - sd::Tensor log_mel({n_mels, frame_count, channels, batch}); - std::vector padded(static_cast(padded_time), 0.0f); - std::vector magnitude(static_cast(n_freqs), 0.0f); - - for (int64_t b = 0; b < batch; ++b) { - for (int64_t c = 0; c < channels; ++c) { - std::fill(padded.begin(), padded.end(), 0.0f); - for (int64_t t = 0; t < time; ++t) { - padded[static_cast(t + left_pad)] = waveform.index(t, c, b); - } - - for (int64_t frame = 0; frame < frame_count; ++frame) { - const int64_t frame_offset = frame * hop_length; - for (int64_t f = 0; f < n_freqs; ++f) { - double real = 0.0; - double imag = 0.0; - for (int64_t k = 0; k < filter_len; ++k) { - const float sample = padded[static_cast(frame_offset + k)]; - real += static_cast(sample) * static_cast(forward_basis.index(k, 0, f)); - imag += static_cast(sample) * static_cast(forward_basis.index(k, 0, f + n_freqs)); - } - magnitude[static_cast(f)] = static_cast(std::sqrt(real * real + imag * imag)); - } - - for (int64_t m = 0; m < n_mels; ++m) { - double mel_value = 0.0; - for (int64_t f = 0; f < n_freqs; ++f) { - mel_value += static_cast(mel_basis.index(f, m)) * static_cast(magnitude[static_cast(f)]); - } - log_mel.index(m, frame, c, b) = static_cast(std::log(std::max(mel_value, 1e-5))); - } - } - } + auto x = ggml_reshape_3d(ctx, waveform, time, 1, channels * batch); + if (left_pad > 0) { + x = ggml_pad_ext(ctx, x, static_cast(left_pad), 0, 0, 0, 0, 0, 0, 0); } - return log_mel; + auto frames = ggml_conv_1d(ctx, forward_basis, x, hop_length, 0, 1); + GGML_ASSERT(frames->ne[0] == frame_count); + GGML_ASSERT(frames->ne[1] == stft_channels); + GGML_ASSERT(frames->ne[2] == channels * batch); + + auto real = ggml_ext_slice(ctx, frames, 1, 0, n_freqs); + auto imag = ggml_ext_slice(ctx, frames, 1, n_freqs, stft_channels); + auto magnitude = ggml_sqrt(ctx, + ggml_add(ctx, + ggml_sqr(ctx, real), + ggml_sqr(ctx, imag))); + + magnitude = ggml_cont(ctx, ggml_permute(ctx, magnitude, 1, 0, 2, 3)); + auto mel = ggml_mul_mat(ctx, mel_basis, magnitude); + mel = ggml_log(ctx, ggml_clamp(ctx, mel, 1e-5f, std::numeric_limits::max())); + + return ggml_reshape_4d(ctx, mel, n_mels, frame_count, channels, batch); } static std::vector build_hann_resample_filter(int ratio) { @@ -276,75 +246,6 @@ namespace LTXV { return filter; } - static sd::Tensor upsample_waveform_hann(const sd::Tensor& waveform_in, int ratio) { - auto waveform = normalize_waveform_for_host(waveform_in); - if (ratio <= 1) { - return waveform; - } - - const int lowpass_filter_width = 6; - const double rolloff = 0.99; - const int width = static_cast(std::ceil(static_cast(lowpass_filter_width) / rolloff)); - const int kernel_size = 2 * width * ratio + 1; - const int pad = width; - const int pad_left = 2 * width * ratio; - const int pad_right = kernel_size - ratio; - const int64_t time = waveform.shape()[0]; - const int64_t channels = waveform.shape()[1]; - const int64_t batch = waveform.shape()[2]; - const int64_t padded_time = time + 2 * pad; - const int64_t conv_out_time = (padded_time - 1) * ratio + kernel_size; - const int64_t cropped_time = conv_out_time - pad_left - pad_right; - auto filter = build_hann_resample_filter(ratio); - - sd::Tensor output({cropped_time, channels, batch}); - std::vector padded(static_cast(padded_time), 0.0f); - std::vector conv_out(static_cast(conv_out_time), 0.0f); - - for (int64_t b = 0; b < batch; ++b) { - for (int64_t c = 0; c < channels; ++c) { - std::fill(padded.begin(), padded.end(), 0.0f); - const float first = waveform.index(0, c, b); - const float last = waveform.index(time - 1, c, b); - for (int i = 0; i < pad; ++i) { - padded[static_cast(i)] = first; - padded[static_cast(pad + time + i)] = last; - } - for (int64_t t = 0; t < time; ++t) { - padded[static_cast(pad + t)] = waveform.index(t, c, b); - } - - std::fill(conv_out.begin(), conv_out.end(), 0.0f); - for (int64_t t = 0; t < padded_time; ++t) { - const double sample = static_cast(padded[static_cast(t)]) * ratio; - const int64_t out_base = t * ratio; - for (int k = 0; k < kernel_size; ++k) { - conv_out[static_cast(out_base + k)] += static_cast(sample * filter[static_cast(k)]); - } - } - - for (int64_t t = 0; t < cropped_time; ++t) { - output.index(t, c, b) = conv_out[static_cast(t + pad_left)]; - } - } - } - - return output; - } - - static sd::Tensor crop_waveform_samples(const sd::Tensor& waveform_in, int64_t target_samples) { - auto waveform = normalize_waveform_for_host(waveform_in); - if (waveform.shape()[0] == target_samples) { - return waveform; - } - if (waveform.shape()[0] > target_samples) { - return sd::ops::slice(waveform, 0, 0, target_samples); - } - sd::Tensor output({target_samples, waveform.shape()[1], waveform.shape()[2]}); - sd::ops::slice_assign(&output, 0, 0, waveform.shape()[0], waveform); - return output; - } - static ggml_type audio_conv_weight_type(ggml_type type) { return type == GGML_TYPE_BF16 ? GGML_TYPE_F16 : type; } @@ -413,22 +314,101 @@ namespace LTXV { return ggml_reshape_4d(ctx, out, out->ne[0], out->ne[1], 1, 1); } + static ggml_tensor* reverse_1d_filter(ggml_context* ctx, ggml_tensor* filter) { + GGML_ASSERT(ctx != nullptr); + GGML_ASSERT(filter != nullptr); + GGML_ASSERT(filter->ne[1] == 1); + GGML_ASSERT(filter->ne[2] == 1); + GGML_ASSERT(filter->ne[3] == 1); + + ggml_tensor* reversed = nullptr; + for (int64_t k = filter->ne[0] - 1; k >= 0; --k) { + auto slice = ggml_ext_slice(ctx, filter, 0, k, k + 1); + reversed = reversed == nullptr ? slice : ggml_concat(ctx, reversed, slice, 0); + } + return reversed; + } + static ggml_tensor* depthwise_conv_transpose1d(ggml_context* ctx, ggml_tensor* x, ggml_tensor* filter, int stride) { GGML_ASSERT(x->ne[2] == 1 && x->ne[3] == 1); GGML_ASSERT(filter->ne[1] == 1); + GGML_ASSERT(filter->ne[2] == 1 && filter->ne[3] == 1); - ggml_tensor* out = nullptr; - for (int64_t c = 0; c < x->ne[1]; ++c) { - auto xi = ggml_ext_slice(ctx, x, 1, c, c + 1); - auto yi = ggml_conv_transpose_1d(ctx, filter, xi, stride, 0, 1); - yi = ggml_ext_scale(ctx, yi, static_cast(stride)); - yi = ggml_reshape_4d(ctx, yi, yi->ne[0], 1, 1, 1); - out = out == nullptr ? yi : ggml_concat(ctx, out, yi, 1); + const int64_t time = x->ne[0]; + const int64_t channels = x->ne[1]; + const int64_t kernel_size = filter->ne[0]; + const int64_t out_time = (time - 1) * stride + kernel_size; + + auto x_flat = ggml_reshape_3d(ctx, x, 1, time, channels); + if (stride > 1) { + auto zero_unit = ggml_ext_scale(ctx, x_flat, 0.0f); + auto zero_tail = zero_unit; + for (int i = 1; i < stride - 1; ++i) { + zero_tail = ggml_concat(ctx, zero_tail, zero_unit, 0); + } + x_flat = ggml_concat(ctx, x_flat, zero_tail, 0); } - return out; + x_flat = ggml_reshape_3d(ctx, x_flat, time * stride, 1, channels); + + auto reversed_filter = reverse_1d_filter(ctx, filter); + auto out = ggml_conv_1d(ctx, reversed_filter, x_flat, 1, static_cast(kernel_size - 1), 1); + if (out->ne[0] > out_time) { + out = ggml_ext_slice(ctx, out, 0, 0, out_time); + } + GGML_ASSERT(out->ne[0] == out_time); + GGML_ASSERT(out->ne[1] == 1); + GGML_ASSERT(out->ne[2] == channels); + + out = ggml_ext_scale(ctx, out, static_cast(stride)); + return ggml_reshape_4d(ctx, out, out_time, channels, 1, 1); + } + + static ggml_tensor* upsample_waveform_hann(GGMLRunnerContext* runner_ctx, + ggml_tensor* waveform, + ggml_tensor* filter, + int ratio) { + auto ctx = runner_ctx->ggml_ctx; + GGML_ASSERT(ctx != nullptr); + GGML_ASSERT(waveform != nullptr); + GGML_ASSERT(filter != nullptr); + GGML_ASSERT(waveform->ne[3] == 1); + if (ratio <= 1) { + return waveform; + } + + const int lowpass_filter_width = 6; + const double rolloff = 0.99; + const int width = static_cast(std::ceil(static_cast(lowpass_filter_width) / rolloff)); + const int kernel_size = 2 * width * ratio + 1; + const int pad = width; + const int pad_left = 2 * width * ratio; + const int pad_right = kernel_size - ratio; + const int64_t time = waveform->ne[0]; + const int64_t channels = waveform->ne[1]; + const int64_t batch = waveform->ne[2]; + + GGML_ASSERT(filter->ne[0] == kernel_size); + + auto x = ggml_reshape_3d(ctx, waveform, time, channels * batch, 1); + x = replicate_pad_1d(runner_ctx, x, pad, pad); + x = depthwise_conv_transpose1d(ctx, x, filter, ratio); + x = ggml_ext_slice(ctx, x, 0, pad_left, x->ne[0] - pad_right); + return ggml_reshape_3d(ctx, x, x->ne[0], channels, batch); + } + + static ggml_tensor* crop_waveform_samples(ggml_context* ctx, + ggml_tensor* waveform, + int64_t target_samples) { + GGML_ASSERT(ctx != nullptr); + GGML_ASSERT(waveform != nullptr); + if (waveform->ne[0] == target_samples) { + return waveform; + } + GGML_ASSERT(waveform->ne[0] > target_samples); + return ggml_ext_slice(ctx, waveform, 0, 0, target_samples); } struct PixelNorm2D : public UnaryBlock { @@ -950,41 +930,66 @@ namespace LTXV { } } - ggml_tensor* decode_to_mel(GGMLRunnerContext* ctx, - ggml_tensor* latent, - int target_time, - int target_freq) { - auto mean = params["audio_vae.per_channel_statistics.mean-of-means"]; - auto stddev = params["audio_vae.per_channel_statistics.std-of-means"]; - auto decoder = std::dynamic_pointer_cast(blocks["audio_vae.decoder"]); - return decoder->forward(ctx, latent, mean, stddev, target_time, target_freq); - } + ggml_tensor* decode(GGMLRunnerContext* ctx, + ggml_tensor* latent, + ggml_tensor* bwe_skip_filter) { + int target_time = static_cast(latent->ne[1]) * config.latent_downsample_factor() - + (config.latent_downsample_factor() - 1); + int target_freq = config.mel_bins; - ggml_tensor* run_vocoder(GGMLRunnerContext* ctx, ggml_tensor* mel) { - auto vocoder = std::dynamic_pointer_cast(blocks["vocoder.vocoder"]); - return vocoder->forward(ctx, mel); - } + auto decoder = std::dynamic_pointer_cast(blocks["audio_vae.decoder"]); + auto mean = params["audio_vae.per_channel_statistics.mean-of-means"]; + auto stddev = params["audio_vae.per_channel_statistics.std-of-means"]; + auto mel = decoder->forward(ctx, latent, mean, stddev, target_time, target_freq); + auto vocoder = std::dynamic_pointer_cast(blocks["vocoder.vocoder"]); + auto waveform = vocoder->forward(ctx, mel); - ggml_tensor* run_bwe_generator(GGMLRunnerContext* ctx, ggml_tensor* mel) { - GGML_ASSERT(config.has_bwe); - auto bwe_generator = std::dynamic_pointer_cast(blocks["vocoder.bwe_generator"]); - return bwe_generator->forward(ctx, mel); - } + if (config.has_bwe) { + GGML_ASSERT(bwe_skip_filter != nullptr); + const int bwe_ratio = config.bwe_output_sample_rate / config.bwe_input_sample_rate; + const int64_t low_time = waveform->ne[0]; + const int64_t out_time = low_time * bwe_ratio; + int64_t remainder = low_time % config.bwe_hop_length; + auto bwe_waveform = waveform; + if (remainder != 0) { + bwe_waveform = ggml_pad_ext(ctx->ggml_ctx, + bwe_waveform, + 0, + static_cast(config.bwe_hop_length - remainder), + 0, + 0, + 0, + 0, + 0, + 0); + } - ggml_tensor* mel_basis_tensor() const { - auto iter = params.find("vocoder.mel_stft.mel_basis"); - return iter == params.end() ? nullptr : iter->second; - } + auto mel_basis = params["vocoder.mel_stft.mel_basis"]; + auto stft_basis = params["vocoder.mel_stft.stft_fn.forward_basis"]; + GGML_ASSERT(mel_basis != nullptr && stft_basis != nullptr); + auto bwe_mel = compute_log_mel_spectrogram(ctx, bwe_waveform, stft_basis, mel_basis, config.bwe_hop_length); + auto bwe_generator = std::dynamic_pointer_cast(blocks["vocoder.bwe_generator"]); + auto residual = bwe_generator->forward(ctx, bwe_mel); - ggml_tensor* stft_forward_basis_tensor() const { - auto iter = params.find("vocoder.mel_stft.stft_fn.forward_basis"); - return iter == params.end() ? nullptr : iter->second; + auto skip = upsample_waveform_hann(ctx, + bwe_waveform, + bwe_skip_filter, + bwe_ratio); + waveform = ggml_clamp(ctx->ggml_ctx, + ggml_add(ctx->ggml_ctx, residual, skip), + -1.0f, + 1.0f); + waveform = crop_waveform_samples(ctx->ggml_ctx, waveform, out_time); + } + + return waveform; } }; struct LTXAudioVAERunner : public GGMLRunner { LTXAudioVAEConfig config; LTXAudioVAE model; + sd::Tensor bwe_skip_filter_tensor; LTXAudioVAERunner(ggml_backend_t backend, ggml_backend_t params_backend, @@ -994,6 +999,10 @@ namespace LTXV { config(LTXAudioVAEConfig::detect_from_weights(tensor_storage_map)), model(config) { model.init(params_ctx, tensor_storage_map, prefix); + if (config.has_bwe) { + const int bwe_ratio = config.bwe_output_sample_rate / config.bwe_input_sample_rate; + bwe_skip_filter_tensor = sd::Tensor::from_vector(build_hann_resample_filter(bwe_ratio)); + } } void get_param_tensors(std::map& tensors, const std::string prefix) { @@ -1008,77 +1017,22 @@ namespace LTXV { return "ltx_audio_vae"; } - ggml_cgraph* build_base_graph(const sd::Tensor& latent_tensor) { - auto latent = make_input(latent_tensor); - int target_time = static_cast(latent_tensor.shape()[1]) * config.latent_downsample_factor() - - (config.latent_downsample_factor() - 1); - int target_freq = config.mel_bins; - - ggml_cgraph* gf = new_graph_custom(655360); - auto runner_ctx = GGMLRunner::get_context(); - auto mel = model.decode_to_mel(&runner_ctx, latent, target_time, target_freq); - auto waveform = model.run_vocoder(&runner_ctx, mel); - ggml_build_forward_expand(gf, waveform); - return gf; - } - - ggml_cgraph* build_bwe_graph(const sd::Tensor& mel_tensor) { - auto mel = make_input(mel_tensor); - ggml_cgraph* gf = new_graph_custom(655360); - auto runner_ctx = GGMLRunner::get_context(); - auto residual = model.run_bwe_generator(&runner_ctx, mel); - ggml_build_forward_expand(gf, residual); - return gf; - } - - sd::Tensor compute_base_waveform(int n_threads, - const sd::Tensor& latent_tensor) { - auto get_graph = [&]() -> ggml_cgraph* { - return build_base_graph(latent_tensor); - }; - return restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, false), 4); - } - - sd::Tensor compute_bwe_residual(int n_threads, - const sd::Tensor& mel_tensor) { - auto get_graph = [&]() -> ggml_cgraph* { - return build_bwe_graph(mel_tensor); - }; - return restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, false), 4); - } - sd::Tensor decode(int n_threads, const sd::Tensor& latent_tensor) { - auto waveform = compute_base_waveform(n_threads, latent_tensor); - if (!config.has_bwe || waveform.empty()) { - return waveform; - } - - auto waveform_host = normalize_waveform_for_host(waveform); - const int64_t low_time = waveform_host.shape()[0]; - const int64_t out_time = low_time * config.bwe_output_sample_rate / config.bwe_input_sample_rate; - int64_t remainder = low_time % config.bwe_hop_length; - if (remainder != 0) { - sd::Tensor padded({low_time + (config.bwe_hop_length - remainder), waveform_host.shape()[1], waveform_host.shape()[2]}); - sd::ops::slice_assign(&padded, 0, 0, low_time, waveform_host); - waveform_host = std::move(padded); - } - - auto mel_basis_tensor = model.mel_basis_tensor(); - auto stft_basis_tensor = model.stft_forward_basis_tensor(); - GGML_ASSERT(mel_basis_tensor != nullptr && stft_basis_tensor != nullptr); - auto mel_basis = load_param_tensor_f32(mel_basis_tensor); - auto forward_basis = load_param_tensor_f32(stft_basis_tensor); - auto bwe_mel = compute_log_mel_spectrogram(waveform_host, forward_basis, mel_basis, config.bwe_hop_length); - auto residual_raw = compute_bwe_residual(n_threads, bwe_mel); - if (residual_raw.empty()) { - return waveform; - } - auto residual = normalize_waveform_for_host(residual_raw); - auto skip = upsample_waveform_hann(waveform_host, config.bwe_output_sample_rate / config.bwe_input_sample_rate); - auto combined = sd::ops::clamp(residual + skip, -1.0f, 1.0f); - auto cropped = crop_waveform_samples(combined, out_time); - return restore_trailing_singleton_dims(cropped, 4); + int64_t t0 = ggml_time_ms(); + auto get_graph = [&]() -> ggml_cgraph* { + auto latent = make_input(latent_tensor); + ggml_tensor* bwe_skip_filter = config.has_bwe ? make_input(bwe_skip_filter_tensor) : nullptr; + ggml_cgraph* gf = new_graph_custom(655360); + auto runner_ctx = GGMLRunner::get_context(); + auto waveform = model.decode(&runner_ctx, latent, bwe_skip_filter); + ggml_build_forward_expand(gf, waveform); + return gf; + }; + auto result = restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, false), 4); + int64_t t1 = ggml_time_ms(); + LOG_INFO("ltx audio vae decode completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); + return result; } void test(const std::string& input_path) { diff --git a/otherarch/sdcpp/ltx_vae.hpp b/otherarch/sdcpp/ltx_vae.hpp index 751995860..756741a43 100644 --- a/otherarch/sdcpp/ltx_vae.hpp +++ b/otherarch/sdcpp/ltx_vae.hpp @@ -1,6 +1,7 @@ #ifndef __SD_LTX_VAE_HPP__ #define __SD_LTX_VAE_HPP__ +#include #include #include #include @@ -143,16 +144,25 @@ namespace LTXVAE { std::vector& feat_map, int& feat_idx, int chunk_idx, - bool causal = true) { + bool causal = true, + int temporal_pad = 0) { auto conv = std::dynamic_pointer_cast(blocks["conv"]); const int pad = causal ? (time_kernel_size - 1) : (time_kernel_size - 1) / 2; ggml_tensor* prev = (feat_idx < (int)feat_map.size()) ? feat_map[feat_idx] : nullptr; + GGML_ASSERT(x->ne[2] >= temporal_pad); + + int end_idx = x->ne[2] - temporal_pad; + int start_idx = std::max(end_idx - pad, 0); + // Save a contiguous copy of the last `pad` frames so the large `x` // tensor is not kept alive across iterations by a dangling view. - if (feat_idx < (int)feat_map.size() && pad > 0 && x->ne[2] >= pad) { - auto slice = ggml_ext_slice(ctx->ggml_ctx, x, 2, x->ne[2] - pad, x->ne[2]); + if (feat_idx < (int)feat_map.size() && end_idx - start_idx > 0) { + GGML_ASSERT(start_idx >= 0); + GGML_ASSERT(end_idx > 0); + + auto slice = ggml_ext_slice(ctx->ggml_ctx, x, 2, start_idx, end_idx); feat_map[feat_idx] = ggml_cont(ctx->ggml_ctx, slice); } feat_idx++; @@ -284,7 +294,8 @@ namespace LTXVAE { bool causal, std::vector& feat_map, int& feat_idx, - int chunk_idx) { + int chunk_idx, + int temporal_pad = 0) { auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); auto conv1 = std::dynamic_pointer_cast(blocks["conv1"]); auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); @@ -311,14 +322,14 @@ namespace LTXVAE { h = apply_scale_shift(ctx->ggml_ctx, h, scale1, shift1); } h = ggml_silu_inplace(ctx->ggml_ctx, h); - h = conv1->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal); + h = conv1->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal, temporal_pad); h = norm2->forward(ctx, h); if (timestep_conditioning) { h = apply_scale_shift(ctx->ggml_ctx, h, scale2, shift2); } h = ggml_silu_inplace(ctx->ggml_ctx, h); - h = conv2->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal); + h = conv2->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal, temporal_pad); return ggml_add(ctx->ggml_ctx, h, x); } @@ -367,7 +378,8 @@ namespace LTXVAE { bool causal, std::vector& feat_map, int& feat_idx, - int chunk_idx) { + int chunk_idx, + int temporal_pad = 0) { ggml_tensor* timestep_embed = nullptr; if (timestep_conditioning) { GGML_ASSERT(timestep != nullptr); @@ -376,7 +388,7 @@ namespace LTXVAE { } for (int i = 0; i < num_layers; i++) { auto resnet = std::dynamic_pointer_cast(blocks["res_blocks." + std::to_string(i)]); - x = resnet->forward(ctx, x, timestep_embed, causal, feat_map, feat_idx, chunk_idx); + x = resnet->forward(ctx, x, timestep_embed, causal, feat_map, feat_idx, chunk_idx, temporal_pad); } return x; } @@ -437,7 +449,8 @@ namespace LTXVAE { bool causal, std::vector& feat_map, int& feat_idx, - int chunk_idx) { + int chunk_idx, + int temporal_pad = 0) { auto conv = std::dynamic_pointer_cast(blocks["conv"]); bool drop_first = (chunk_idx == 0) && (factor_t > 1); @@ -453,7 +466,7 @@ namespace LTXVAE { x_in = res; } - x = conv->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal); + x = conv->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal, temporal_pad); x = depth_to_space_3d(ctx->ggml_ctx, x, get_output_channels(), factor_t, factor_s, drop_first); if (residual) { x = ggml_add(ctx->ggml_ctx, x, x_in); @@ -986,7 +999,8 @@ namespace LTXVAE { ggml_tensor* timestep, std::vector& feat_map, int& feat_idx, - int chunk_idx) { + int chunk_idx, + int& temporal_pad) { auto conv_in = std::dynamic_pointer_cast(blocks["conv_in"]); auto conv_norm_out = std::dynamic_pointer_cast(blocks["conv_norm_out"]); auto conv_out = std::dynamic_pointer_cast(blocks["conv_out"]); @@ -998,7 +1012,7 @@ namespace LTXVAE { } // conv_in with feat_map for left temporal context - x = conv_in->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder); + x = conv_in->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder, temporal_pad); // up_blocks int block_idx = 0; @@ -1006,12 +1020,13 @@ namespace LTXVAE { auto mid_block = std::dynamic_pointer_cast(blocks["up_blocks." + std::to_string(block_idx)]); if (mid_block) { x = mid_block->forward(ctx, x, scaled_timestep, causal_decoder, - feat_map, feat_idx, chunk_idx); + feat_map, feat_idx, chunk_idx, temporal_pad); } else { auto upsample = std::dynamic_pointer_cast( blocks["up_blocks." + std::to_string(block_idx)]); x = upsample->forward(ctx, x, causal_decoder, - feat_map, feat_idx, chunk_idx); + feat_map, feat_idx, chunk_idx, temporal_pad); + temporal_pad *= upsample->factor_t; } block_idx++; } @@ -1028,7 +1043,7 @@ namespace LTXVAE { x = apply_scale_shift(ctx->ggml_ctx, x, scale, shift); } x = ggml_silu_inplace(ctx->ggml_ctx, x); - x = conv_out->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder); + x = conv_out->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder, temporal_pad); return x; } }; @@ -1084,7 +1099,9 @@ namespace LTXVAE { // tensors can be freed by GGML before the next iteration starts. ggml_tensor* decode_tiled(GGMLRunnerContext* ctx, ggml_tensor* z, - ggml_tensor* timestep) { + ggml_tensor* timestep, + int temporal_window_size = 1, + int temporal_tile_overlap = 0) { auto decoder = std::dynamic_pointer_cast(blocks["decoder"]); auto processor = std::dynamic_pointer_cast(blocks["per_channel_statistics"]); auto latents = processor->un_normalize(ctx, z); @@ -1099,18 +1116,69 @@ namespace LTXVAE { // 128 slots is generous enough for any supported decoder configuration. std::vector feat_map(128, nullptr); + // Ensure window size is at least 1 + int window = std::max(1, temporal_window_size); + int overlap = std::max(0, temporal_tile_overlap); + + if (overlap >= window) { + LOG_WARN("temporal_tile_overlap (%d) is greater than or equal to temporal_tile_frames (%d), adjusting values to avoid empty decode windows", + overlap, window); + overlap = window - 1; + } + LOG_DEBUG("Using temporal tiling: temporal_tile_frames = %d, temporal_tile_overlap = %d, total frames = %d, resulting in %d tiles", + window, + overlap, + (int)T, + (T + window - overlap - 1) / (window - overlap)); ggml_tensor* out = nullptr; - for (int i = 0; i < (int)T; i++) { + for (int i = 0; i < (int)T - overlap; i += (window - overlap)) { int feat_idx = 0; - auto z_i = ggml_ext_slice(ctx->ggml_ctx, latents, 2, i, i + 1); - auto out_i = decoder->forward_tiled_frame(ctx, z_i, timestep, - feat_map, feat_idx, i); - out = (out == nullptr) ? out_i : ggml_concat(ctx->ggml_ctx, out, out_i, 2); + + // Calculate the end index for the current temporal chunk + int end_i = std::min((int)T, i + window); + if (end_i >= (int)T) { + overlap = 0; // avoid overlap issues in the last chunk + } + + int chunk_overlap = overlap; // modified by forward_tiled_frame temporal inflation + + auto z_chunk = ggml_ext_slice(ctx->ggml_ctx, latents, 2, i, end_i); + + auto out_chunk = decoder->forward_tiled_frame(ctx, z_chunk, timestep, + feat_map, feat_idx, i, chunk_overlap); + + // discard overlap frames if it's not the final chunk + if (overlap > 0 && end_i < (int)T) { + out_chunk = ggml_ext_slice(ctx->ggml_ctx, out_chunk, 2, 0, out_chunk->ne[2] - chunk_overlap); + } + + out = (out == nullptr) ? out_chunk : ggml_concat(ctx->ggml_ctx, out, out_chunk, 2); } return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1); } + ggml_tensor* decode_tiled_chunk(GGMLRunnerContext* ctx, + ggml_tensor* z, + ggml_tensor* timestep, + std::vector& feat_map, + int chunk_idx, + int temporal_tile_overlap, + int& feat_idx) { + auto decoder = std::dynamic_pointer_cast(blocks["decoder"]); + auto processor = std::dynamic_pointer_cast(blocks["per_channel_statistics"]); + auto latents = processor->un_normalize(ctx, z); + + feat_idx = 0; + int chunk_overlap = temporal_tile_overlap; // modified by forward_tiled_frame temporal inflation + auto out_chunk = decoder->forward_tiled_frame(ctx, latents, timestep, + feat_map, feat_idx, chunk_idx, chunk_overlap); + if (chunk_overlap > 0) { + out_chunk = ggml_ext_slice(ctx->ggml_ctx, out_chunk, 2, 0, out_chunk->ne[2] - chunk_overlap); + } + return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out_chunk, patch_size, 1); + } + ggml_tensor* encode(GGMLRunnerContext* ctx, ggml_tensor* x) { GGML_ASSERT(!decode_only); @@ -1140,8 +1208,13 @@ namespace LTXVAE { } // namespace LTXVAE struct LTXVideoVAE : public VAE { + static constexpr int DEFAULT_TEMPORAL_TILE_FRAMES = 4; + static constexpr int DEFAULT_TEMPORAL_TILE_OVERLAP = 1; + bool decode_only; bool temporal_tiling_enabled = false; + int temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES; + int temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP; int ltx_vae_version; bool timestep_conditioning; int patch_size; @@ -1178,10 +1251,64 @@ struct LTXVideoVAE : public VAE { temporal_tiling_enabled = enabled; } + void set_tiling_params(const sd_tiling_params_t& params) override { + temporal_tiling_enabled = params.temporal_tiling; + temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES; + temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP; + + for (const auto& [key, value] : parse_key_value_args(params.extra_tiling_args, "LTX VAE extra tiling arg")) { + int parsed = 0; + if (!parse_strict_int(value, parsed)) { + LOG_WARN("ignoring invalid LTX VAE extra tiling arg '%s=%s'", key.c_str(), value.c_str()); + } else if (key == "temporal_tile_frames") { + temporal_tile_frames = std::max(1, parsed); + } else if (key == "temporal_tile_overlap") { + temporal_tile_overlap = std::max(0, parsed); + } else { + LOG_WARN("ignoring unknown LTX VAE extra tiling arg '%s'", key.c_str()); + } + } + } + void get_param_tensors(std::map& tensors, const std::string prefix) override { vae.get_param_tensors(tensors, prefix); } + struct TemporalTilePlan { + int frames = 1; + int overlap = 0; + int stride = 1; + int num_tiles = 1; + }; + + TemporalTilePlan resolve_temporal_tile_plan(int64_t total_frames) const { + TemporalTilePlan plan; + plan.frames = std::max(1, temporal_tile_frames); + plan.overlap = std::max(0, temporal_tile_overlap); + + if (plan.overlap >= plan.frames) { + LOG_WARN("temporal_tile_overlap (%d) is greater than or equal to temporal_tile_frames (%d), adjusting values to avoid empty decode windows", + plan.overlap, + plan.frames); + plan.overlap = plan.frames - 1; + } + if (total_frames > 1 && plan.overlap >= total_frames) { + LOG_WARN("temporal_tile_overlap (%d) is greater than or equal to total latent frames (%lld), adjusting values to decode at least one tile", + plan.overlap, + (long long)total_frames); + plan.overlap = static_cast(total_frames - 1); + } + + plan.stride = std::max(1, plan.frames - plan.overlap); + int64_t tiled_frames = std::max(1, total_frames - plan.overlap); + plan.num_tiles = total_frames > 0 ? static_cast((tiled_frames + plan.stride - 1) / plan.stride) : 0; + return plan; + } + + std::string temporal_feat_cache_name(size_t feat_idx) const { + return "ltx_vae_temporal_feat:" + std::to_string(feat_idx); + } + ggml_cgraph* build_graph(const sd::Tensor& z_tensor, bool decode_graph) { ggml_cgraph* gf = new_graph_custom(20480); ggml_tensor* z = make_input(z_tensor); @@ -1192,18 +1319,97 @@ struct LTXVideoVAE : public VAE { auto runner_ctx = get_context(); ggml_tensor* out; - bool use_tiled = decode_graph && temporal_tiling_enabled && - z_tensor.dim() == 5 && z_tensor.shape()[2] > 1; - if (use_tiled) { - out = vae.decode_tiled(&runner_ctx, z, timestep); - } else { - out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z); - } + out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z); ggml_build_forward_expand(gf, out); return gf; } + ggml_cgraph* build_temporal_tile_graph(const sd::Tensor& z_chunk_tensor, + int chunk_idx, + int chunk_overlap) { + ggml_cgraph* gf = new_graph_custom(20480); + ggml_tensor* z = make_input(z_chunk_tensor); + ggml_tensor* timestep = nullptr; + if (timestep_conditioning) { + timestep = make_input(decode_timestep_tensor); + } + + std::vector feat_map(128, nullptr); + for (size_t feat_idx = 0; feat_idx < feat_map.size(); ++feat_idx) { + feat_map[feat_idx] = get_cache_tensor_by_name(temporal_feat_cache_name(feat_idx)); + } + + auto runner_ctx = get_context(); + int feat_count = 0; + ggml_tensor* out = vae.decode_tiled_chunk(&runner_ctx, + z, + timestep, + feat_map, + chunk_idx, + chunk_overlap, + feat_count); + + for (int feat_idx = 0; feat_idx < feat_count && feat_idx < static_cast(feat_map.size()); ++feat_idx) { + ggml_tensor* feat_cache = feat_map[static_cast(feat_idx)]; + if (feat_cache != nullptr) { + cache(temporal_feat_cache_name(static_cast(feat_idx)), feat_cache); + ggml_build_forward_expand(gf, feat_cache); + } + } + + ggml_build_forward_expand(gf, out); + return gf; + } + + sd::Tensor decode_temporal_tiled_streaming(const int n_threads, + const sd::Tensor& input, + size_t expected_dim) { + const int64_t total_frames = input.shape()[2]; + TemporalTilePlan plan = resolve_temporal_tile_plan(total_frames); + + LOG_DEBUG("Using streaming temporal tiling: temporal_tile_frames=%d, temporal_tile_overlap=%d, total latent frames=%lld, resulting in %d tiles", + plan.frames, + plan.overlap, + (long long)total_frames, + plan.num_tiles); + + free_cache_ctx_and_buffer(); + cache_tensor_map.clear(); + + sd::Tensor output; + for (int64_t start = 0; start < total_frames - plan.overlap; start += plan.stride) { + const int64_t end = std::min(total_frames, start + plan.frames); + const int chunk_overlap = end < total_frames ? plan.overlap : 0; + auto z_chunk = sd::ops::slice(input, 2, start, end); + + LOG_DEBUG("LTX VAE temporal tile %lld/%d: latent frames [%lld, %lld), overlap=%d", + (long long)(start / plan.stride + 1), + plan.num_tiles, + (long long)start, + (long long)end, + chunk_overlap); + + auto get_graph = [&]() -> ggml_cgraph* { + return build_temporal_tile_graph(z_chunk, + static_cast(start), + chunk_overlap); + }; + auto chunk = restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, true), + expected_dim); + if (chunk.empty()) { + free_cache_ctx_and_buffer(); + cache_tensor_map.clear(); + return {}; + } + output = output.empty() ? std::move(chunk) : sd::ops::concat(output, chunk, 2); + } + + free_cache_ctx_and_buffer(); + cache_tensor_map.clear(); + return output; + } + ggml_cgraph* build_latent_statistics_graph(const sd::Tensor& z_tensor, bool normalize) { ggml_cgraph* gf = new_graph_custom(1024); ggml_tensor* z = make_input(z_tensor); @@ -1239,6 +1445,9 @@ struct LTXVideoVAE : public VAE { input = sd::ops::slice(input, 2, 0, cropped_t); } } + if (decode_graph && temporal_tiling_enabled && input.dim() == 5 && input.shape()[2] > 1) { + return decode_temporal_tiled_streaming(n_threads, input, expected_dim); + } auto get_graph = [&]() -> ggml_cgraph* { return build_graph(input, decode_graph); }; diff --git a/otherarch/sdcpp/stable-diffusion.cpp b/otherarch/sdcpp/stable-diffusion.cpp index a5fbe89fe..19d8c0223 100644 --- a/otherarch/sdcpp/stable-diffusion.cpp +++ b/otherarch/sdcpp/stable-diffusion.cpp @@ -155,7 +155,7 @@ public: bool kcpp_lora_cache_populate = false; std::string taesd_path; - sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0}; + sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0, nullptr}; bool offload_params_to_cpu = false; float max_vram = 0.f; bool use_pmid = false; @@ -2921,7 +2921,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) { sd_img_gen_params->batch_count = 1; sd_img_gen_params->control_strength = 0.9f; sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f}; - sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f}; + sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr}; sd_cache_params_init(&sd_img_gen_params->cache); sd_hires_params_init(&sd_img_gen_params->hires); } @@ -2950,7 +2950,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { "increase_ref_index: %s\n" "control_strength: %.2f\n" "photo maker: {style_strength = %.2f, id_images_count = %d, id_embed_path = %s}\n" - "VAE tiling: %s (temporal=%s)\n" + "VAE tiling: %s (temporal=%s, extra_tiling_args=%s)\n" "hires: {enabled=%s, upscaler=%s, model_path=%s, scale=%.2f, target=%dx%d, steps=%d, denoising_strength=%.2f}\n", SAFE_STR(sd_img_gen_params->prompt), SAFE_STR(sd_img_gen_params->negative_prompt), @@ -2970,6 +2970,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { SAFE_STR(sd_img_gen_params->pm_params.id_embed_path), BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled), BOOL_STR(sd_img_gen_params->vae_tiling_params.temporal_tiling), + SAFE_STR(sd_img_gen_params->vae_tiling_params.extra_tiling_args), BOOL_STR(sd_img_gen_params->hires.enabled), sd_hires_upscaler_name(sd_img_gen_params->hires.upscaler), SAFE_STR(sd_img_gen_params->hires.model_path), @@ -3007,7 +3008,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) { sd_vid_gen_params->fps = 16; sd_vid_gen_params->moe_boundary = 0.875f; sd_vid_gen_params->vace_strength = 1.f; - sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f}; + sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr}; sd_vid_gen_params->hires.enabled = false; sd_vid_gen_params->hires.upscaler = SD_HIRES_UPSCALER_LATENT; sd_vid_gen_params->hires.scale = 2.f; @@ -5460,14 +5461,24 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, sd_ctx->sd->diffusion_model->free_params_buffer(); } + int64_t latent_end = ggml_time_ms(); + LOG_INFO("generating latent video completed, taking %.2fs", (latent_end - latent_start) * 1.0f / 1000); + sd_audio_t* generated_audio = nullptr; if (sd_version_is_ltxav(sd_ctx->sd->version) && latents.audio_length > 0 && sd_ctx->sd->audio_vae_model != nullptr) { + int64_t audio_latent_decode_start = ggml_time_ms(); + auto audio_latent = unpack_ltxav_audio_latent(final_latent, latents.audio_length, sd_ctx->sd->get_latent_channel()); if (!audio_latent.empty()) { + LOG_DEBUG("decode audio latent %dx%dx%dx%d", + (int)audio_latent.shape()[0], + (int)audio_latent.shape()[1], + (int)audio_latent.shape()[2], + (int)audio_latent.shape()[3]); auto waveform = sd_ctx->sd->decode_ltx_audio_latent(audio_latent); if (!waveform.empty()) { generated_audio = waveform_to_sd_audio(sd_ctx->sd, waveform); @@ -5475,6 +5486,8 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, LOG_WARN("LTX audio latent decode failed; continuing with silent video output"); } } + int64_t audio_latent_decode_end = ggml_time_ms(); + LOG_INFO("decoding audio latent completed, taking %.2fs", (audio_latent_decode_end - audio_latent_decode_start) * 1.0f / 1000); } if (latents.video_conditioning_frame_count > 0) { @@ -5487,9 +5500,6 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, final_latent = sd::ops::slice(final_latent, 2, latents.ref_image_num, final_latent.shape()[2]); } - int64_t latent_end = ggml_time_ms(); - LOG_INFO("generating latent video completed, taking %.2fs", (latent_end - latent_start) * 1.0f / 1000); - auto result = decode_video_outputs(sd_ctx, latent_upscale_enabled ? hires_request : request, final_latent, num_frames_out); if (result == nullptr) { free_sd_audio(generated_audio); diff --git a/otherarch/sdcpp/stable-diffusion.h b/otherarch/sdcpp/stable-diffusion.h index 3ae44addf..f8b2c2f59 100644 --- a/otherarch/sdcpp/stable-diffusion.h +++ b/otherarch/sdcpp/stable-diffusion.h @@ -160,6 +160,7 @@ typedef struct { float target_overlap; float rel_size_x; float rel_size_y; + const char* extra_tiling_args; } sd_tiling_params_t; typedef struct { diff --git a/otherarch/sdcpp/tae.hpp b/otherarch/sdcpp/tae.hpp index c7217ceb1..62f14c9ee 100644 --- a/otherarch/sdcpp/tae.hpp +++ b/otherarch/sdcpp/tae.hpp @@ -259,10 +259,54 @@ public: } }; -ggml_tensor* patchify(ggml_context* ctx, - ggml_tensor* x, - int64_t patch_size, - int64_t b = 1) { +class WideMemBlock : public GGMLBlock { + bool has_skip_conv = false; + +public: + WideMemBlock(int channels, int out_channels) + : has_skip_conv(channels != out_channels) { + int groups = std::max(1, out_channels / 64); + blocks["conv.0"] = std::shared_ptr(new Conv2d(channels * 2, out_channels, {1, 1}, {1, 1})); + blocks["conv.2"] = std::shared_ptr(new Conv2d_grouped(out_channels, out_channels, groups, {3, 3}, {1, 1}, {1, 1})); + blocks["conv.4"] = std::shared_ptr(new Conv2d(out_channels, out_channels, {1, 1}, {1, 1})); + blocks["conv.6"] = std::shared_ptr(new Conv2d_grouped(out_channels, out_channels, groups, {3, 3}, {1, 1}, {1, 1})); + if (has_skip_conv) { + blocks["skip"] = std::shared_ptr(new Conv2d(channels, out_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false)); + } + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* past) { + // x: [n, channels, h, w] + auto conv0 = std::dynamic_pointer_cast(blocks["conv.0"]); + auto conv1 = std::dynamic_pointer_cast(blocks["conv.2"]); + auto conv2 = std::dynamic_pointer_cast(blocks["conv.4"]); + auto conv3 = std::dynamic_pointer_cast(blocks["conv.6"]); + + auto h = ggml_concat(ctx->ggml_ctx, x, past, 2); + h = conv0->forward(ctx, h); + h = ggml_relu_inplace(ctx->ggml_ctx, h); + h = conv1->forward(ctx, h); + h = ggml_relu_inplace(ctx->ggml_ctx, h); + h = conv2->forward(ctx, h); + h = ggml_relu_inplace(ctx->ggml_ctx, h); + h = conv3->forward(ctx, h); + + auto skip = x; + if (has_skip_conv) { + auto skip_conv = std::dynamic_pointer_cast(blocks["skip"]); + skip = skip_conv->forward(ctx, x); + } + h = ggml_add_inplace(ctx->ggml_ctx, h, skip); + h = ggml_relu_inplace(ctx->ggml_ctx, h); + return h; + } +}; + +ggml_tensor* +patchify(ggml_context* ctx, + ggml_tensor* x, + int64_t patch_size, + int64_t b = 1) { // x: [f, b*c, h*q, w*r] // return: [f, b*c*r*q, h, w] if (patch_size == 1) { @@ -325,7 +369,6 @@ public: int t_downscale = 1; TinyVideoEncoder(int z_channels = 4, int patch_size = 1, std::vector time_downscale = {true, true, false}) : z_channels(z_channels), patch_size(patch_size) { - // self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool)) t_downscale = 1; for (bool downscale : time_downscale) { if (downscale) { @@ -384,11 +427,18 @@ class TinyVideoDecoder : public UnaryBlock { int channels[num_layers + 1] = {256, 128, 64, 64}; int patch_size = 1; int t_upscale = 1; + bool is_wide = false; public: - TinyVideoDecoder(int z_channels = 4, int patch_size = 1, std::vector time_upscale = {false, true, true}) - : z_channels(z_channels), patch_size(patch_size) { + TinyVideoDecoder(int z_channels = 4, int patch_size = 1, std::vector time_upscale = {false, true, true}, bool is_wide = false) + : z_channels(z_channels), patch_size(patch_size), is_wide(is_wide) { t_upscale = 1; + if (is_wide) { + channels[0] = 1024; + channels[1] = 512; + channels[2] = 256; + } + for (bool upscale : time_upscale) { if (upscale) { t_upscale *= 2; @@ -400,7 +450,11 @@ public: for (int i = 0; i < num_layers; i++) { int stride = time_upscale[i] ? 2 : 1; for (int j = 0; j < num_blocks; j++) { - blocks[std::to_string(index++)] = std::shared_ptr(new MemBlock(channels[i], channels[i])); + if (is_wide) { + blocks[std::to_string(index++)] = std::shared_ptr(new WideMemBlock(channels[i], channels[i])); + } else { + blocks[std::to_string(index++)] = std::shared_ptr(new MemBlock(channels[i], channels[i])); + } } index++; // nn.Upsample() blocks[std::to_string(index++)] = std::shared_ptr(new TGrow(channels[i], stride)); @@ -425,10 +479,15 @@ public: int index = 3; for (int i = 0; i < num_layers; i++) { for (int j = 0; j < num_blocks; j++) { - auto block = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); - auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0); - mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0); - h = block->forward(ctx, h, mem); + auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0); + mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0); + if (is_wide) { + auto block = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); + h = block->forward(ctx, h, mem); + } else { + auto block = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); + h = block->forward(ctx, h, mem); + } } // upsample index++; @@ -455,6 +514,7 @@ class TAEHV : public GGMLBlock { protected: bool decode_only; SDVersion version; + bool is_wide; public: int z_channels = 16; @@ -462,8 +522,8 @@ public: std::vector time_upscale = {false, true, true}; public: - TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2) - : decode_only(decode_only), version(version) { + TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2, bool is_wide = false) + : decode_only(decode_only), version(version), is_wide(is_wide) { int patch = 1; if (version == VERSION_WAN2_2_TI2V) { z_channels = 48; @@ -474,7 +534,7 @@ public: time_downscale = {true, true, true}; time_upscale = {true, true, true}; } - blocks["decoder"] = std::shared_ptr(new TinyVideoDecoder(z_channels, patch, time_upscale)); + blocks["decoder"] = std::shared_ptr(new TinyVideoDecoder(z_channels, patch, time_upscale, is_wide)); if (!decode_only) { blocks["encoder"] = std::shared_ptr(new TinyVideoEncoder(z_channels, patch, time_downscale)); } @@ -623,6 +683,7 @@ struct TinyImageAutoEncoder : public VAE { struct TinyVideoAutoEncoder : public VAE { TAEHV taehv; bool decode_only = false; + bool is_wide = false; TinyVideoAutoEncoder(ggml_backend_t backend, ggml_backend_t params_backend, @@ -631,8 +692,14 @@ struct TinyVideoAutoEncoder : public VAE { bool decoder_only = true, SDVersion version = VERSION_WAN2) : decode_only(decoder_only), - taehv(decoder_only, version), VAE(version, backend, params_backend) { + for (auto tensor_storage : tensor_storage_map) { + if (tensor_storage.first.find(prefix + ".3.conv.6.weight") != std::string::npos) { + is_wide = true; + break; + } + } + taehv = TAEHV(decoder_only, version, is_wide); scale_input = false; taehv.init(params_ctx, tensor_storage_map, prefix); } @@ -663,7 +730,8 @@ struct TinyVideoAutoEncoder : public VAE { } ggml_cgraph* build_graph(const sd::Tensor& z_tensor, bool decode_graph) { - ggml_cgraph* gf = ggml_new_graph(compute_ctx); + ggml_cgraph* gf = decode_graph && is_wide ? ggml_new_graph_custom(compute_ctx, 4096, false) + : ggml_new_graph(compute_ctx); ggml_tensor* z = make_input(z_tensor); auto runner_ctx = get_context(); ggml_tensor* out = decode_graph ? taehv.decode(&runner_ctx, z) : taehv.encode(&runner_ctx, z); diff --git a/otherarch/sdcpp/util.cpp b/otherarch/sdcpp/util.cpp index da369fac6..30ebed2f3 100644 --- a/otherarch/sdcpp/util.cpp +++ b/otherarch/sdcpp/util.cpp @@ -1,8 +1,10 @@ #include "util.h" #include +#include #include #include #include +#include #include #include #include @@ -420,6 +422,88 @@ std::vector split_string(const std::string& str, char delimiter) { return result; } +KeyValueArgs parse_key_value_args(const char* args, const char* context) { + KeyValueArgs pairs; + + if (args == nullptr || args[0] == '\0') { + return pairs; + } + + std::string raw(args); + size_t start = 0; + for (size_t pos = 0; pos <= raw.size(); ++pos) { + if (pos != raw.size() && raw[pos] != ',' && raw[pos] != ';') { + continue; + } + + std::string token = trim(raw.substr(start, pos - start)); + if (!token.empty()) { + size_t eq = token.find('='); + if (eq == std::string::npos) { + const char* log_context = context ? context : "key=value arg"; + LOG_WARN("ignoring malformed %s '%s'", log_context, token.c_str()); + } else { + std::string key = trim(token.substr(0, eq)); + std::string value = trim(token.substr(eq + 1)); + pairs.emplace_back(std::move(key), std::move(value)); + } + } + + start = pos + 1; + } + + return pairs; +} + +KeyValueArgs parse_key_value_args(const std::string& args, const char* context) { + return parse_key_value_args(args.c_str(), context); +} + +bool parse_strict_float(const std::string& text, float& value) { + try { + size_t consumed = 0; + float parsed = std::stof(text, &consumed); + if (!trim(text.substr(consumed)).empty()) { + return false; + } + value = parsed; + return true; + } catch (const std::exception&) { + return false; + } +} + +bool parse_strict_int(const std::string& text, int& value) { + try { + size_t consumed = 0; + int parsed = std::stoi(text, &consumed); + if (!trim(text.substr(consumed)).empty()) { + return false; + } + value = parsed; + return true; + } catch (const std::exception&) { + return false; + } +} + +bool parse_strict_bool(const std::string& text, bool& value) { + std::string lowered = trim(text); + std::transform(lowered.begin(), lowered.end(), lowered.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + + if (lowered == "1" || lowered == "true" || lowered == "yes" || lowered == "on") { + value = true; + return true; + } + if (lowered == "0" || lowered == "false" || lowered == "no" || lowered == "off") { + value = false; + return true; + } + return false; +} + // { kcpp static int sdloglevel = 0; //-1 = hide all, 0 = normal, 1 = showall static bool sdquiet = false; diff --git a/otherarch/sdcpp/util.h b/otherarch/sdcpp/util.h index 21dacc997..5cde617cb 100644 --- a/otherarch/sdcpp/util.h +++ b/otherarch/sdcpp/util.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include "ggml-backend.h" @@ -68,6 +69,15 @@ protected: std::string path_join(const std::string& p1, const std::string& p2); std::vector split_string(const std::string& str, char delimiter); + +using KeyValueArgs = std::vector>; + +KeyValueArgs parse_key_value_args(const char* args, const char* context = "key=value arg"); +KeyValueArgs parse_key_value_args(const std::string& args, const char* context = "key=value arg"); +bool parse_strict_float(const std::string& text, float& value); +bool parse_strict_int(const std::string& text, int& value); +bool parse_strict_bool(const std::string& text, bool& value); + void pretty_progress(int step, int steps, float time); void pretty_bytes_progress(int step, int steps, uint64_t bytes_processed, float elapsed_seconds); diff --git a/otherarch/sdcpp/vae.hpp b/otherarch/sdcpp/vae.hpp index d7e0fdee1..cc4cd967f 100644 --- a/otherarch/sdcpp/vae.hpp +++ b/otherarch/sdcpp/vae.hpp @@ -167,6 +167,7 @@ public: int64_t t0 = ggml_time_ms(); sd::Tensor input = x; sd::Tensor output; + set_tiling_params(tiling_params); if (tiling_params.enabled) { const int scale_factor = get_scale_factor(); @@ -216,6 +217,9 @@ public: virtual void get_param_tensors(std::map& tensors, const std::string prefix) = 0; virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); }; virtual void set_temporal_tiling_enabled(bool enabled) { SD_UNUSED(enabled); }; + virtual void set_tiling_params(const sd_tiling_params_t& params) { + set_temporal_tiling_enabled(params.temporal_tiling); + }; }; struct FakeVAE : public VAE {