sd: sync to master-642-3a8788c

This commit is contained in:
Wagner Bruna 2026-05-21 20:46:14 -03:00
parent f27795cef0
commit e7f386ceb6
12 changed files with 859 additions and 440 deletions

View file

@ -835,6 +835,10 @@ ArgOptions SDGenerationParams::get_options() {
"--extra-sample-args",
"extra sampler/scheduler args, key=value list. lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma",
&extra_sample_args},
{"",
"--extra-tiling-args",
"extra VAE tiling args, key=value list. LTX video VAE supports temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)",
&extra_tiling_args},
};
options.int_options = {
@ -1780,6 +1784,9 @@ bool SDGenerationParams::from_json_str(
if (tiling_json.contains("rel_size_y") && tiling_json["rel_size_y"].is_number()) {
vae_tiling_params.rel_size_y = tiling_json["rel_size_y"];
}
if (tiling_json.contains("extra_tiling_args") && tiling_json["extra_tiling_args"].is_string()) {
extra_tiling_args = tiling_json["extra_tiling_args"].get<std::string>();
}
}
if (!parse_lora_json_field(j, lora_path_resolver, lora_map, high_noise_lora_map)) {
@ -2002,6 +2009,8 @@ bool SDGenerationParams::initialize_cache_params() {
}
bool SDGenerationParams::resolve(const std::string& lora_model_dir, const std::string& hires_upscalers_dir, bool strict) {
vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str();
if (high_noise_sample_params.sample_steps <= 0) {
high_noise_sample_params.sample_steps = -1;
}
@ -2188,6 +2197,7 @@ sd_img_gen_params_t SDGenerationParams::to_sd_img_gen_params_t() {
sample_params.custom_sigmas_count = static_cast<int>(custom_sigmas.size());
sample_params.extra_sample_args = extra_sample_args.empty() ? nullptr : extra_sample_args.c_str();
high_noise_sample_params.extra_sample_args = high_noise_extra_sample_args.empty() ? nullptr : high_noise_extra_sample_args.c_str();
vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str();
cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str();
sd_pm_params_t pm_params = {
@ -2261,6 +2271,7 @@ sd_vid_gen_params_t SDGenerationParams::to_sd_vid_gen_params_t() {
sample_params.custom_sigmas_count = static_cast<int>(custom_sigmas.size());
sample_params.extra_sample_args = extra_sample_args.empty() ? nullptr : extra_sample_args.c_str();
high_noise_sample_params.extra_sample_args = high_noise_extra_sample_args.empty() ? nullptr : high_noise_extra_sample_args.c_str();
vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str();
cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str();
params.loras = lora_vec.empty() ? nullptr : lora_vec.data();
@ -2386,7 +2397,8 @@ std::string SDGenerationParams::to_string() const {
<< vae_tiling_params.tile_size_y << ", "
<< vae_tiling_params.target_overlap << ", "
<< vae_tiling_params.rel_size_x << ", "
<< vae_tiling_params.rel_size_y << " },\n"
<< vae_tiling_params.rel_size_y << ", "
<< "\"" << extra_tiling_args << "\" },\n"
<< "}";
return oss.str();
}
@ -2565,14 +2577,18 @@ std::string build_sdcpp_image_metadata_json(const SDContextParams& ctx_params,
};
}
if (gen_params.vae_tiling_params.enabled) {
if (gen_params.vae_tiling_params.enabled ||
gen_params.vae_tiling_params.temporal_tiling ||
!gen_params.extra_tiling_args.empty()) {
root["vae_tiling"] = {
{"enabled", gen_params.vae_tiling_params.enabled},
{"temporal_tiling", gen_params.vae_tiling_params.temporal_tiling},
{"tile_size_x", gen_params.vae_tiling_params.tile_size_x},
{"tile_size_y", gen_params.vae_tiling_params.tile_size_y},
{"target_overlap", gen_params.vae_tiling_params.target_overlap},
{"rel_size_x", gen_params.vae_tiling_params.rel_size_x},
{"rel_size_y", gen_params.vae_tiling_params.rel_size_y},
{"extra_tiling_args", gen_params.extra_tiling_args},
};
}

View file

@ -189,7 +189,8 @@ struct SDGenerationParams {
int video_frames = 1;
int fps = 16;
float vace_strength = 1.f;
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
std::string extra_tiling_args;
std::string pm_id_images_dir;
std::string pm_id_embed_path;

View file

@ -496,84 +496,26 @@ struct LTX2Scheduler : SigmaScheduler {
parse_extra_sample_args(extra_sample_args);
}
static std::string trim(std::string value) {
const char* whitespace = " \t\r\n";
size_t begin = value.find_first_not_of(whitespace);
if (begin == std::string::npos) {
return "";
}
size_t end = value.find_last_not_of(whitespace);
return value.substr(begin, end - begin + 1);
}
void parse_extra_sample_args(const char* extra_sample_args) {
if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') {
return;
}
std::string raw(extra_sample_args);
size_t start = 0;
auto parse_arg = [&](const std::string& item) {
std::string token = trim(item);
if (token.empty()) {
return;
}
size_t eq = token.find('=');
if (eq == std::string::npos) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
return;
}
std::string key = trim(token.substr(0, eq));
std::string value = trim(token.substr(eq + 1));
auto parse_float = [&](float* out) -> bool {
try {
size_t consumed = 0;
float parsed = std::stof(value, &consumed);
if (!trim(value.substr(consumed)).empty()) {
return false;
}
*out = parsed;
return true;
} catch (const std::exception&) {
return false;
for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "ltx2 scheduler arg")) {
if (key == "max_shift") {
if (!parse_strict_float(value, max_shift)) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
}
};
try {
if (key == "max_shift") {
if (!parse_float(&max_shift)) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
}
} else if (key == "base_shift") {
if (!parse_float(&base_shift)) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
}
} else if (key == "terminal") {
if (!parse_float(&terminal)) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
}
} else if (key == "stretch") {
std::string v = value;
std::transform(v.begin(), v.end(), v.begin(), [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
if (v == "1" || v == "true" || v == "yes" || v == "on") {
stretch = true;
} else if (v == "0" || v == "false" || v == "no" || v == "off") {
stretch = false;
} else {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
}
} else {
LOG_WARN("ignoring unknown ltx2 scheduler arg '%s'", key.c_str());
} else if (key == "base_shift") {
if (!parse_strict_float(value, base_shift)) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
}
} catch (const std::exception&) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
}
};
for (size_t pos = 0; pos <= raw.size(); ++pos) {
if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') {
parse_arg(raw.substr(start, pos - start));
start = pos + 1;
} else if (key == "terminal") {
if (!parse_strict_float(value, terminal)) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
}
} else if (key == "stretch") {
if (!parse_strict_bool(value, stretch)) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
}
} else {
LOG_WARN("ignoring unknown ltx2 scheduler arg '%s'", key.c_str());
}
}
}
@ -1276,7 +1218,7 @@ static sd::Tensor<float> sample_dpmpp_2m_v2(denoise_cb_t model,
return x;
}
using SamplerExtraArgs = std::vector<std::pair<std::string, std::string>>;
using SamplerExtraArgs = KeyValueArgs;
static sd::Tensor<float> sample_lcm(denoise_cb_t model,
sd::Tensor<float> x,
@ -1296,15 +1238,8 @@ static sd::Tensor<float> sample_lcm(denoise_cb_t model,
for (const auto& [key, value] : extra_sample_args) {
float parsed = 0.0f;
try {
size_t consumed = 0;
parsed = std::stof(value, &consumed);
if (trim(value.substr(consumed)).size() != 0) {
LOG_WARN("ignoring invalid lcm extra sample arg '%s'", key.c_str());
continue;
}
} catch (const std::exception&) {
LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str());
if (!parse_strict_float(value, parsed)) {
LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str(), value.c_str());
continue;
}
if (key == "noise_clip_std") {
@ -1861,15 +1796,8 @@ static sd::Tensor<float> sample_gradient_estimation(denoise_cb_t model,
for (const auto& [key, value] : extra_sample_args) {
float parsed = 0.0f;
try {
size_t consumed = 0;
parsed = std::stof(value, &consumed);
if (trim(value.substr(consumed)).size() != 0) {
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str());
continue;
}
} catch (const std::exception&) {
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str());
if (!parse_strict_float(value, parsed)) {
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s=%s'", key.c_str(), value.c_str());
continue;
}
if (key == "gamma") {
@ -1916,46 +1844,6 @@ static sd::Tensor<float> sample_gradient_estimation(denoise_cb_t model,
return x;
}
static SamplerExtraArgs parse_sampler_args(const char* extra_sample_args) {
SamplerExtraArgs pairs;
if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') {
return pairs;
}
auto trim = [](std::string value) -> std::string {
const char* whitespace = " \t\r\n";
size_t begin = value.find_first_not_of(whitespace);
if (begin == std::string::npos) {
return "";
}
size_t end = value.find_last_not_of(whitespace);
return value.substr(begin, end - begin + 1);
};
std::string raw(extra_sample_args);
size_t start = 0;
for (size_t pos = 0; pos <= raw.size(); ++pos) {
if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') {
std::string item = raw.substr(start, pos - start);
std::string token = trim(item);
if (!token.empty()) {
size_t eq = token.find('=');
if (eq != std::string::npos) {
std::string key = trim(token.substr(0, eq));
std::string value = trim(token.substr(eq + 1));
pairs.emplace_back(std::move(key), std::move(value));
}
}
start = pos + 1;
}
}
return pairs;
}
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
denoise_cb_t model,
@ -1965,7 +1853,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
float eta,
bool is_flow_denoiser,
const char* extra_sample_args) {
SamplerExtraArgs extra_args = parse_sampler_args(extra_sample_args);
SamplerExtraArgs extra_args = parse_key_value_args(extra_sample_args, "extra sample arg");
switch (method) {
case EULER_A_SAMPLE_METHOD:
return sample_euler_ancestral(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);

View file

@ -1602,6 +1602,23 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
return num;
}
__STATIC_INLINE__ ggml_tensor* ggml_ext_vec_concat(ggml_context* ctx,
std::vector<ggml_tensor*>& tensors,
int dim) {
while (tensors.size() > 1) {
std::vector<ggml_tensor*> next_level;
for (size_t i = 0; i < tensors.size(); i += 2) {
if (i + 1 < tensors.size()) {
next_level.push_back(ggml_concat(ctx, tensors[i], tensors[i + 1], dim));
} else {
next_level.push_back(tensors[i]);
}
}
tensors = std::move(next_level);
}
return tensors[0];
}
/* SDXL with LoRA requires more space */
#define MAX_PARAMS_TENSOR_NUM 32768
#define MAX_GRAPH_SIZE 327680
@ -3139,6 +3156,163 @@ public:
}
};
class Conv2d_grouped : public UnaryBlock {
protected:
int64_t in_channels;
int64_t out_channels;
int groups;
std::pair<int, int> kernel_size;
std::pair<int, int> stride;
std::pair<int, int> padding;
std::pair<int, int> dilation;
bool bias;
float scale = 1.f;
std::string prefix;
void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
this->prefix = prefix;
enum ggml_type wtype = GGML_TYPE_F16;
params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels / groups, out_channels);
if (bias) {
enum ggml_type wtype = GGML_TYPE_F32;
params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_channels);
}
}
public:
Conv2d_grouped(int64_t in_channels,
int64_t out_channels,
int groups,
std::pair<int, int> kernel_size,
std::pair<int, int> stride = {1, 1},
std::pair<int, int> padding = {0, 0},
std::pair<int, int> dilation = {1, 1},
bool bias = true)
: in_channels(in_channels),
out_channels(out_channels),
groups(groups),
kernel_size(kernel_size),
stride(stride),
padding(padding),
dilation(dilation),
bias(bias) {}
void set_scale(float scale_value) {
scale = scale_value;
}
std::string get_desc() {
return "Conv2d_grouped";
}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
ggml_tensor* w = params["weight"];
ggml_tensor* b = nullptr;
if (bias) {
b = params["bias"];
}
if (groups == 1) {
if (ctx->weight_adapter) {
WeightAdapter::ForwardParams forward_params;
forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_CONV2D;
forward_params.conv2d.s0 = stride.second;
forward_params.conv2d.s1 = stride.first;
forward_params.conv2d.p0 = padding.second;
forward_params.conv2d.p1 = padding.first;
forward_params.conv2d.d0 = dilation.second;
forward_params.conv2d.d1 = dilation.first;
forward_params.conv2d.direct = ctx->conv2d_direct_enabled;
forward_params.conv2d.circular_x = ctx->circular_x_enabled;
forward_params.conv2d.circular_y = ctx->circular_y_enabled;
forward_params.conv2d.scale = scale;
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, ctx->backend, x, w, b, prefix, forward_params);
}
return ggml_ext_conv_2d(ctx->ggml_ctx, x, w, b,
stride.second, stride.first,
padding.second, padding.first,
dilation.second, dilation.first,
ctx->conv2d_direct_enabled,
ctx->circular_x_enabled,
ctx->circular_y_enabled,
scale);
}
if (groups == in_channels && groups == out_channels) {
ggml_tensor* res;
if (ctx->conv2d_direct_enabled) {
res = ggml_conv_2d_dw_direct(ctx->ggml_ctx, x, w,
stride.second, stride.first,
padding.second, padding.first,
dilation.second, dilation.first);
} else {
res = ggml_conv_2d_dw(ctx->ggml_ctx, x, w,
stride.second, stride.first,
padding.second, padding.first,
dilation.second, dilation.first);
}
if (b) {
res = ggml_add(ctx->ggml_ctx, res, b);
}
return res;
}
int64_t ic_g = in_channels / groups;
int64_t oc_g = out_channels / groups;
std::vector<ggml_tensor*> out_slices(groups);
for (int i = 0; i < groups; ++i) {
size_t x_offset = i * ic_g * x->nb[2];
ggml_tensor* x_i = ggml_view_4d(ctx->ggml_ctx, x,
x->ne[0], x->ne[1], ic_g, x->ne[3],
x->nb[1], x->nb[2], x->nb[3],
x_offset);
size_t w_offset = i * oc_g * w->nb[3];
ggml_tensor* w_i = ggml_view_4d(ctx->ggml_ctx, w,
w->ne[0], w->ne[1], w->ne[2], oc_g,
w->nb[1], w->nb[2], w->nb[3],
w_offset);
ggml_tensor* b_i = nullptr;
if (b) {
size_t b_offset = i * oc_g * b->nb[0];
b_i = ggml_view_1d(ctx->ggml_ctx, b, oc_g, b_offset);
}
if (ctx->weight_adapter) {
WeightAdapter::ForwardParams forward_params;
forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_CONV2D;
forward_params.conv2d.s0 = stride.second;
forward_params.conv2d.s1 = stride.first;
forward_params.conv2d.p0 = padding.second;
forward_params.conv2d.p1 = padding.first;
forward_params.conv2d.d0 = dilation.second;
forward_params.conv2d.d1 = dilation.first;
forward_params.conv2d.direct = ctx->conv2d_direct_enabled;
forward_params.conv2d.circular_x = ctx->circular_x_enabled;
forward_params.conv2d.circular_y = ctx->circular_y_enabled;
forward_params.conv2d.scale = scale;
out_slices[i] = ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, ctx->backend, x_i, w_i, b_i, prefix, forward_params);
} else {
out_slices[i] = ggml_ext_conv_2d(ctx->ggml_ctx, x_i, w_i, b_i,
stride.second, stride.first,
padding.second, padding.first,
dilation.second, dilation.first,
ctx->conv2d_direct_enabled,
ctx->circular_x_enabled,
ctx->circular_y_enabled,
scale);
}
}
ggml_tensor* out = ggml_ext_vec_concat(ctx->ggml_ctx, out_slices, 2);
return out;
}
};
class Conv3d : public UnaryBlock {
protected:
int64_t in_channels;

View file

@ -2,6 +2,7 @@
#define __SD_LTX_AUDIO_VAE_H__
#include <cmath>
#include <limits>
#include <numeric>
#include <string>
#include <vector>
@ -171,90 +172,59 @@ namespace LTXV {
}
};
static sd::Tensor<float> squeeze_trailing_singleton_dims(sd::Tensor<float> tensor) {
while (tensor.dim() > 0 && tensor.shape().back() == 1) {
tensor = tensor.squeeze(static_cast<size_t>(tensor.dim() - 1));
}
return tensor;
}
static ggml_tensor* compute_log_mel_spectrogram(GGMLRunnerContext* runner_ctx,
ggml_tensor* waveform,
ggml_tensor* forward_basis,
ggml_tensor* mel_basis,
int hop_length) {
auto ctx = runner_ctx->ggml_ctx;
GGML_ASSERT(ctx != nullptr);
GGML_ASSERT(waveform != nullptr);
GGML_ASSERT(forward_basis != nullptr);
GGML_ASSERT(mel_basis != nullptr);
GGML_ASSERT(waveform->type == GGML_TYPE_F32);
GGML_ASSERT(forward_basis->type == GGML_TYPE_F32);
GGML_ASSERT(mel_basis->type == GGML_TYPE_F32);
GGML_ASSERT(forward_basis->ne[1] == 1);
static sd::Tensor<float> normalize_waveform_for_host(sd::Tensor<float> waveform) {
waveform = squeeze_trailing_singleton_dims(std::move(waveform));
if (waveform.empty()) {
return waveform;
}
if (waveform.dim() == 1) {
return waveform.reshape({waveform.shape()[0], 1, 1});
}
if (waveform.dim() == 2) {
return waveform.reshape({waveform.shape()[0], waveform.shape()[1], 1});
}
if (waveform.dim() == 3) {
return waveform;
}
throw std::runtime_error("Unsupported waveform rank for host processing: rank=" + std::to_string(waveform.dim()));
}
const int64_t time = waveform->ne[0];
const int64_t channels = waveform->ne[1];
const int64_t batch = waveform->ne[2];
const int64_t filter_len = forward_basis->ne[0];
const int64_t stft_channels = forward_basis->ne[2];
const int64_t n_freqs = stft_channels / 2;
const int64_t n_mels = mel_basis->ne[1];
const int64_t left_pad = std::max<int64_t>(0, filter_len - hop_length);
const int64_t padded_time = time + left_pad;
const int64_t frame_count = padded_time < filter_len ? 0 : 1 + (padded_time - filter_len) / hop_length;
static sd::Tensor<float> load_param_tensor_f32(ggml_tensor* tensor) {
GGML_ASSERT(tensor != nullptr);
return squeeze_trailing_singleton_dims(sd::make_sd_tensor_from_ggml<float>(tensor));
}
GGML_ASSERT(stft_channels % 2 == 0);
GGML_ASSERT(mel_basis->ne[0] == n_freqs);
GGML_ASSERT(waveform->ne[3] == 1);
GGML_ASSERT(frame_count > 0);
static sd::Tensor<float> compute_log_mel_spectrogram(const sd::Tensor<float>& waveform_in,
const sd::Tensor<float>& forward_basis,
const sd::Tensor<float>& mel_basis,
int hop_length) {
auto waveform = normalize_waveform_for_host(waveform_in);
GGML_ASSERT(forward_basis.dim() >= 3);
GGML_ASSERT(mel_basis.dim() >= 2);
const int64_t time = waveform.shape()[0];
const int64_t channels = waveform.shape()[1];
const int64_t batch = waveform.shape()[2];
const int64_t filter_len = forward_basis.shape()[0];
const int64_t basis_freq2 = forward_basis.shape().back();
const int64_t n_freqs = basis_freq2 / 2;
const int64_t n_mels = mel_basis.shape()[1];
const int64_t left_pad = std::max<int64_t>(0, filter_len - hop_length);
const int64_t padded_time = time + left_pad;
const int64_t frame_count = padded_time < filter_len ? 0 : 1 + (padded_time - filter_len) / hop_length;
sd::Tensor<float> log_mel({n_mels, frame_count, channels, batch});
std::vector<float> padded(static_cast<size_t>(padded_time), 0.0f);
std::vector<float> magnitude(static_cast<size_t>(n_freqs), 0.0f);
for (int64_t b = 0; b < batch; ++b) {
for (int64_t c = 0; c < channels; ++c) {
std::fill(padded.begin(), padded.end(), 0.0f);
for (int64_t t = 0; t < time; ++t) {
padded[static_cast<size_t>(t + left_pad)] = waveform.index(t, c, b);
}
for (int64_t frame = 0; frame < frame_count; ++frame) {
const int64_t frame_offset = frame * hop_length;
for (int64_t f = 0; f < n_freqs; ++f) {
double real = 0.0;
double imag = 0.0;
for (int64_t k = 0; k < filter_len; ++k) {
const float sample = padded[static_cast<size_t>(frame_offset + k)];
real += static_cast<double>(sample) * static_cast<double>(forward_basis.index(k, 0, f));
imag += static_cast<double>(sample) * static_cast<double>(forward_basis.index(k, 0, f + n_freqs));
}
magnitude[static_cast<size_t>(f)] = static_cast<float>(std::sqrt(real * real + imag * imag));
}
for (int64_t m = 0; m < n_mels; ++m) {
double mel_value = 0.0;
for (int64_t f = 0; f < n_freqs; ++f) {
mel_value += static_cast<double>(mel_basis.index(f, m)) * static_cast<double>(magnitude[static_cast<size_t>(f)]);
}
log_mel.index(m, frame, c, b) = static_cast<float>(std::log(std::max(mel_value, 1e-5)));
}
}
}
auto x = ggml_reshape_3d(ctx, waveform, time, 1, channels * batch);
if (left_pad > 0) {
x = ggml_pad_ext(ctx, x, static_cast<int>(left_pad), 0, 0, 0, 0, 0, 0, 0);
}
return log_mel;
auto frames = ggml_conv_1d(ctx, forward_basis, x, hop_length, 0, 1);
GGML_ASSERT(frames->ne[0] == frame_count);
GGML_ASSERT(frames->ne[1] == stft_channels);
GGML_ASSERT(frames->ne[2] == channels * batch);
auto real = ggml_ext_slice(ctx, frames, 1, 0, n_freqs);
auto imag = ggml_ext_slice(ctx, frames, 1, n_freqs, stft_channels);
auto magnitude = ggml_sqrt(ctx,
ggml_add(ctx,
ggml_sqr(ctx, real),
ggml_sqr(ctx, imag)));
magnitude = ggml_cont(ctx, ggml_permute(ctx, magnitude, 1, 0, 2, 3));
auto mel = ggml_mul_mat(ctx, mel_basis, magnitude);
mel = ggml_log(ctx, ggml_clamp(ctx, mel, 1e-5f, std::numeric_limits<float>::max()));
return ggml_reshape_4d(ctx, mel, n_mels, frame_count, channels, batch);
}
static std::vector<float> build_hann_resample_filter(int ratio) {
@ -276,75 +246,6 @@ namespace LTXV {
return filter;
}
static sd::Tensor<float> upsample_waveform_hann(const sd::Tensor<float>& waveform_in, int ratio) {
auto waveform = normalize_waveform_for_host(waveform_in);
if (ratio <= 1) {
return waveform;
}
const int lowpass_filter_width = 6;
const double rolloff = 0.99;
const int width = static_cast<int>(std::ceil(static_cast<double>(lowpass_filter_width) / rolloff));
const int kernel_size = 2 * width * ratio + 1;
const int pad = width;
const int pad_left = 2 * width * ratio;
const int pad_right = kernel_size - ratio;
const int64_t time = waveform.shape()[0];
const int64_t channels = waveform.shape()[1];
const int64_t batch = waveform.shape()[2];
const int64_t padded_time = time + 2 * pad;
const int64_t conv_out_time = (padded_time - 1) * ratio + kernel_size;
const int64_t cropped_time = conv_out_time - pad_left - pad_right;
auto filter = build_hann_resample_filter(ratio);
sd::Tensor<float> output({cropped_time, channels, batch});
std::vector<float> padded(static_cast<size_t>(padded_time), 0.0f);
std::vector<float> conv_out(static_cast<size_t>(conv_out_time), 0.0f);
for (int64_t b = 0; b < batch; ++b) {
for (int64_t c = 0; c < channels; ++c) {
std::fill(padded.begin(), padded.end(), 0.0f);
const float first = waveform.index(0, c, b);
const float last = waveform.index(time - 1, c, b);
for (int i = 0; i < pad; ++i) {
padded[static_cast<size_t>(i)] = first;
padded[static_cast<size_t>(pad + time + i)] = last;
}
for (int64_t t = 0; t < time; ++t) {
padded[static_cast<size_t>(pad + t)] = waveform.index(t, c, b);
}
std::fill(conv_out.begin(), conv_out.end(), 0.0f);
for (int64_t t = 0; t < padded_time; ++t) {
const double sample = static_cast<double>(padded[static_cast<size_t>(t)]) * ratio;
const int64_t out_base = t * ratio;
for (int k = 0; k < kernel_size; ++k) {
conv_out[static_cast<size_t>(out_base + k)] += static_cast<float>(sample * filter[static_cast<size_t>(k)]);
}
}
for (int64_t t = 0; t < cropped_time; ++t) {
output.index(t, c, b) = conv_out[static_cast<size_t>(t + pad_left)];
}
}
}
return output;
}
static sd::Tensor<float> crop_waveform_samples(const sd::Tensor<float>& waveform_in, int64_t target_samples) {
auto waveform = normalize_waveform_for_host(waveform_in);
if (waveform.shape()[0] == target_samples) {
return waveform;
}
if (waveform.shape()[0] > target_samples) {
return sd::ops::slice(waveform, 0, 0, target_samples);
}
sd::Tensor<float> output({target_samples, waveform.shape()[1], waveform.shape()[2]});
sd::ops::slice_assign(&output, 0, 0, waveform.shape()[0], waveform);
return output;
}
static ggml_type audio_conv_weight_type(ggml_type type) {
return type == GGML_TYPE_BF16 ? GGML_TYPE_F16 : type;
}
@ -413,22 +314,101 @@ namespace LTXV {
return ggml_reshape_4d(ctx, out, out->ne[0], out->ne[1], 1, 1);
}
static ggml_tensor* reverse_1d_filter(ggml_context* ctx, ggml_tensor* filter) {
GGML_ASSERT(ctx != nullptr);
GGML_ASSERT(filter != nullptr);
GGML_ASSERT(filter->ne[1] == 1);
GGML_ASSERT(filter->ne[2] == 1);
GGML_ASSERT(filter->ne[3] == 1);
ggml_tensor* reversed = nullptr;
for (int64_t k = filter->ne[0] - 1; k >= 0; --k) {
auto slice = ggml_ext_slice(ctx, filter, 0, k, k + 1);
reversed = reversed == nullptr ? slice : ggml_concat(ctx, reversed, slice, 0);
}
return reversed;
}
static ggml_tensor* depthwise_conv_transpose1d(ggml_context* ctx,
ggml_tensor* x,
ggml_tensor* filter,
int stride) {
GGML_ASSERT(x->ne[2] == 1 && x->ne[3] == 1);
GGML_ASSERT(filter->ne[1] == 1);
GGML_ASSERT(filter->ne[2] == 1 && filter->ne[3] == 1);
ggml_tensor* out = nullptr;
for (int64_t c = 0; c < x->ne[1]; ++c) {
auto xi = ggml_ext_slice(ctx, x, 1, c, c + 1);
auto yi = ggml_conv_transpose_1d(ctx, filter, xi, stride, 0, 1);
yi = ggml_ext_scale(ctx, yi, static_cast<float>(stride));
yi = ggml_reshape_4d(ctx, yi, yi->ne[0], 1, 1, 1);
out = out == nullptr ? yi : ggml_concat(ctx, out, yi, 1);
const int64_t time = x->ne[0];
const int64_t channels = x->ne[1];
const int64_t kernel_size = filter->ne[0];
const int64_t out_time = (time - 1) * stride + kernel_size;
auto x_flat = ggml_reshape_3d(ctx, x, 1, time, channels);
if (stride > 1) {
auto zero_unit = ggml_ext_scale(ctx, x_flat, 0.0f);
auto zero_tail = zero_unit;
for (int i = 1; i < stride - 1; ++i) {
zero_tail = ggml_concat(ctx, zero_tail, zero_unit, 0);
}
x_flat = ggml_concat(ctx, x_flat, zero_tail, 0);
}
return out;
x_flat = ggml_reshape_3d(ctx, x_flat, time * stride, 1, channels);
auto reversed_filter = reverse_1d_filter(ctx, filter);
auto out = ggml_conv_1d(ctx, reversed_filter, x_flat, 1, static_cast<int>(kernel_size - 1), 1);
if (out->ne[0] > out_time) {
out = ggml_ext_slice(ctx, out, 0, 0, out_time);
}
GGML_ASSERT(out->ne[0] == out_time);
GGML_ASSERT(out->ne[1] == 1);
GGML_ASSERT(out->ne[2] == channels);
out = ggml_ext_scale(ctx, out, static_cast<float>(stride));
return ggml_reshape_4d(ctx, out, out_time, channels, 1, 1);
}
static ggml_tensor* upsample_waveform_hann(GGMLRunnerContext* runner_ctx,
ggml_tensor* waveform,
ggml_tensor* filter,
int ratio) {
auto ctx = runner_ctx->ggml_ctx;
GGML_ASSERT(ctx != nullptr);
GGML_ASSERT(waveform != nullptr);
GGML_ASSERT(filter != nullptr);
GGML_ASSERT(waveform->ne[3] == 1);
if (ratio <= 1) {
return waveform;
}
const int lowpass_filter_width = 6;
const double rolloff = 0.99;
const int width = static_cast<int>(std::ceil(static_cast<double>(lowpass_filter_width) / rolloff));
const int kernel_size = 2 * width * ratio + 1;
const int pad = width;
const int pad_left = 2 * width * ratio;
const int pad_right = kernel_size - ratio;
const int64_t time = waveform->ne[0];
const int64_t channels = waveform->ne[1];
const int64_t batch = waveform->ne[2];
GGML_ASSERT(filter->ne[0] == kernel_size);
auto x = ggml_reshape_3d(ctx, waveform, time, channels * batch, 1);
x = replicate_pad_1d(runner_ctx, x, pad, pad);
x = depthwise_conv_transpose1d(ctx, x, filter, ratio);
x = ggml_ext_slice(ctx, x, 0, pad_left, x->ne[0] - pad_right);
return ggml_reshape_3d(ctx, x, x->ne[0], channels, batch);
}
static ggml_tensor* crop_waveform_samples(ggml_context* ctx,
ggml_tensor* waveform,
int64_t target_samples) {
GGML_ASSERT(ctx != nullptr);
GGML_ASSERT(waveform != nullptr);
if (waveform->ne[0] == target_samples) {
return waveform;
}
GGML_ASSERT(waveform->ne[0] > target_samples);
return ggml_ext_slice(ctx, waveform, 0, 0, target_samples);
}
struct PixelNorm2D : public UnaryBlock {
@ -950,41 +930,66 @@ namespace LTXV {
}
}
ggml_tensor* decode_to_mel(GGMLRunnerContext* ctx,
ggml_tensor* latent,
int target_time,
int target_freq) {
auto mean = params["audio_vae.per_channel_statistics.mean-of-means"];
auto stddev = params["audio_vae.per_channel_statistics.std-of-means"];
auto decoder = std::dynamic_pointer_cast<AudioDecoder>(blocks["audio_vae.decoder"]);
return decoder->forward(ctx, latent, mean, stddev, target_time, target_freq);
}
ggml_tensor* decode(GGMLRunnerContext* ctx,
ggml_tensor* latent,
ggml_tensor* bwe_skip_filter) {
int target_time = static_cast<int>(latent->ne[1]) * config.latent_downsample_factor() -
(config.latent_downsample_factor() - 1);
int target_freq = config.mel_bins;
ggml_tensor* run_vocoder(GGMLRunnerContext* ctx, ggml_tensor* mel) {
auto vocoder = std::dynamic_pointer_cast<Vocoder>(blocks["vocoder.vocoder"]);
return vocoder->forward(ctx, mel);
}
auto decoder = std::dynamic_pointer_cast<AudioDecoder>(blocks["audio_vae.decoder"]);
auto mean = params["audio_vae.per_channel_statistics.mean-of-means"];
auto stddev = params["audio_vae.per_channel_statistics.std-of-means"];
auto mel = decoder->forward(ctx, latent, mean, stddev, target_time, target_freq);
auto vocoder = std::dynamic_pointer_cast<Vocoder>(blocks["vocoder.vocoder"]);
auto waveform = vocoder->forward(ctx, mel);
ggml_tensor* run_bwe_generator(GGMLRunnerContext* ctx, ggml_tensor* mel) {
GGML_ASSERT(config.has_bwe);
auto bwe_generator = std::dynamic_pointer_cast<Vocoder>(blocks["vocoder.bwe_generator"]);
return bwe_generator->forward(ctx, mel);
}
if (config.has_bwe) {
GGML_ASSERT(bwe_skip_filter != nullptr);
const int bwe_ratio = config.bwe_output_sample_rate / config.bwe_input_sample_rate;
const int64_t low_time = waveform->ne[0];
const int64_t out_time = low_time * bwe_ratio;
int64_t remainder = low_time % config.bwe_hop_length;
auto bwe_waveform = waveform;
if (remainder != 0) {
bwe_waveform = ggml_pad_ext(ctx->ggml_ctx,
bwe_waveform,
0,
static_cast<int>(config.bwe_hop_length - remainder),
0,
0,
0,
0,
0,
0);
}
ggml_tensor* mel_basis_tensor() const {
auto iter = params.find("vocoder.mel_stft.mel_basis");
return iter == params.end() ? nullptr : iter->second;
}
auto mel_basis = params["vocoder.mel_stft.mel_basis"];
auto stft_basis = params["vocoder.mel_stft.stft_fn.forward_basis"];
GGML_ASSERT(mel_basis != nullptr && stft_basis != nullptr);
auto bwe_mel = compute_log_mel_spectrogram(ctx, bwe_waveform, stft_basis, mel_basis, config.bwe_hop_length);
auto bwe_generator = std::dynamic_pointer_cast<Vocoder>(blocks["vocoder.bwe_generator"]);
auto residual = bwe_generator->forward(ctx, bwe_mel);
ggml_tensor* stft_forward_basis_tensor() const {
auto iter = params.find("vocoder.mel_stft.stft_fn.forward_basis");
return iter == params.end() ? nullptr : iter->second;
auto skip = upsample_waveform_hann(ctx,
bwe_waveform,
bwe_skip_filter,
bwe_ratio);
waveform = ggml_clamp(ctx->ggml_ctx,
ggml_add(ctx->ggml_ctx, residual, skip),
-1.0f,
1.0f);
waveform = crop_waveform_samples(ctx->ggml_ctx, waveform, out_time);
}
return waveform;
}
};
struct LTXAudioVAERunner : public GGMLRunner {
LTXAudioVAEConfig config;
LTXAudioVAE model;
sd::Tensor<float> bwe_skip_filter_tensor;
LTXAudioVAERunner(ggml_backend_t backend,
ggml_backend_t params_backend,
@ -994,6 +999,10 @@ namespace LTXV {
config(LTXAudioVAEConfig::detect_from_weights(tensor_storage_map)),
model(config) {
model.init(params_ctx, tensor_storage_map, prefix);
if (config.has_bwe) {
const int bwe_ratio = config.bwe_output_sample_rate / config.bwe_input_sample_rate;
bwe_skip_filter_tensor = sd::Tensor<float>::from_vector(build_hann_resample_filter(bwe_ratio));
}
}
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) {
@ -1008,77 +1017,22 @@ namespace LTXV {
return "ltx_audio_vae";
}
ggml_cgraph* build_base_graph(const sd::Tensor<float>& latent_tensor) {
auto latent = make_input(latent_tensor);
int target_time = static_cast<int>(latent_tensor.shape()[1]) * config.latent_downsample_factor() -
(config.latent_downsample_factor() - 1);
int target_freq = config.mel_bins;
ggml_cgraph* gf = new_graph_custom(655360);
auto runner_ctx = GGMLRunner::get_context();
auto mel = model.decode_to_mel(&runner_ctx, latent, target_time, target_freq);
auto waveform = model.run_vocoder(&runner_ctx, mel);
ggml_build_forward_expand(gf, waveform);
return gf;
}
ggml_cgraph* build_bwe_graph(const sd::Tensor<float>& mel_tensor) {
auto mel = make_input(mel_tensor);
ggml_cgraph* gf = new_graph_custom(655360);
auto runner_ctx = GGMLRunner::get_context();
auto residual = model.run_bwe_generator(&runner_ctx, mel);
ggml_build_forward_expand(gf, residual);
return gf;
}
sd::Tensor<float> compute_base_waveform(int n_threads,
const sd::Tensor<float>& latent_tensor) {
auto get_graph = [&]() -> ggml_cgraph* {
return build_base_graph(latent_tensor);
};
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), 4);
}
sd::Tensor<float> compute_bwe_residual(int n_threads,
const sd::Tensor<float>& mel_tensor) {
auto get_graph = [&]() -> ggml_cgraph* {
return build_bwe_graph(mel_tensor);
};
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), 4);
}
sd::Tensor<float> decode(int n_threads,
const sd::Tensor<float>& latent_tensor) {
auto waveform = compute_base_waveform(n_threads, latent_tensor);
if (!config.has_bwe || waveform.empty()) {
return waveform;
}
auto waveform_host = normalize_waveform_for_host(waveform);
const int64_t low_time = waveform_host.shape()[0];
const int64_t out_time = low_time * config.bwe_output_sample_rate / config.bwe_input_sample_rate;
int64_t remainder = low_time % config.bwe_hop_length;
if (remainder != 0) {
sd::Tensor<float> padded({low_time + (config.bwe_hop_length - remainder), waveform_host.shape()[1], waveform_host.shape()[2]});
sd::ops::slice_assign(&padded, 0, 0, low_time, waveform_host);
waveform_host = std::move(padded);
}
auto mel_basis_tensor = model.mel_basis_tensor();
auto stft_basis_tensor = model.stft_forward_basis_tensor();
GGML_ASSERT(mel_basis_tensor != nullptr && stft_basis_tensor != nullptr);
auto mel_basis = load_param_tensor_f32(mel_basis_tensor);
auto forward_basis = load_param_tensor_f32(stft_basis_tensor);
auto bwe_mel = compute_log_mel_spectrogram(waveform_host, forward_basis, mel_basis, config.bwe_hop_length);
auto residual_raw = compute_bwe_residual(n_threads, bwe_mel);
if (residual_raw.empty()) {
return waveform;
}
auto residual = normalize_waveform_for_host(residual_raw);
auto skip = upsample_waveform_hann(waveform_host, config.bwe_output_sample_rate / config.bwe_input_sample_rate);
auto combined = sd::ops::clamp(residual + skip, -1.0f, 1.0f);
auto cropped = crop_waveform_samples(combined, out_time);
return restore_trailing_singleton_dims(cropped, 4);
int64_t t0 = ggml_time_ms();
auto get_graph = [&]() -> ggml_cgraph* {
auto latent = make_input(latent_tensor);
ggml_tensor* bwe_skip_filter = config.has_bwe ? make_input(bwe_skip_filter_tensor) : nullptr;
ggml_cgraph* gf = new_graph_custom(655360);
auto runner_ctx = GGMLRunner::get_context();
auto waveform = model.decode(&runner_ctx, latent, bwe_skip_filter);
ggml_build_forward_expand(gf, waveform);
return gf;
};
auto result = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), 4);
int64_t t1 = ggml_time_ms();
LOG_INFO("ltx audio vae decode completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
return result;
}
void test(const std::string& input_path) {

View file

@ -1,6 +1,7 @@
#ifndef __SD_LTX_VAE_HPP__
#define __SD_LTX_VAE_HPP__
#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
@ -143,16 +144,25 @@ namespace LTXVAE {
std::vector<ggml_tensor*>& feat_map,
int& feat_idx,
int chunk_idx,
bool causal = true) {
bool causal = true,
int temporal_pad = 0) {
auto conv = std::dynamic_pointer_cast<Conv3d>(blocks["conv"]);
const int pad = causal ? (time_kernel_size - 1) : (time_kernel_size - 1) / 2;
ggml_tensor* prev = (feat_idx < (int)feat_map.size()) ? feat_map[feat_idx] : nullptr;
GGML_ASSERT(x->ne[2] >= temporal_pad);
int end_idx = x->ne[2] - temporal_pad;
int start_idx = std::max(end_idx - pad, 0);
// Save a contiguous copy of the last `pad` frames so the large `x`
// tensor is not kept alive across iterations by a dangling view.
if (feat_idx < (int)feat_map.size() && pad > 0 && x->ne[2] >= pad) {
auto slice = ggml_ext_slice(ctx->ggml_ctx, x, 2, x->ne[2] - pad, x->ne[2]);
if (feat_idx < (int)feat_map.size() && end_idx - start_idx > 0) {
GGML_ASSERT(start_idx >= 0);
GGML_ASSERT(end_idx > 0);
auto slice = ggml_ext_slice(ctx->ggml_ctx, x, 2, start_idx, end_idx);
feat_map[feat_idx] = ggml_cont(ctx->ggml_ctx, slice);
}
feat_idx++;
@ -284,7 +294,8 @@ namespace LTXVAE {
bool causal,
std::vector<ggml_tensor*>& feat_map,
int& feat_idx,
int chunk_idx) {
int chunk_idx,
int temporal_pad = 0) {
auto norm1 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm1"]);
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
auto norm2 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm2"]);
@ -311,14 +322,14 @@ namespace LTXVAE {
h = apply_scale_shift(ctx->ggml_ctx, h, scale1, shift1);
}
h = ggml_silu_inplace(ctx->ggml_ctx, h);
h = conv1->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal);
h = conv1->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal, temporal_pad);
h = norm2->forward(ctx, h);
if (timestep_conditioning) {
h = apply_scale_shift(ctx->ggml_ctx, h, scale2, shift2);
}
h = ggml_silu_inplace(ctx->ggml_ctx, h);
h = conv2->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal);
h = conv2->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal, temporal_pad);
return ggml_add(ctx->ggml_ctx, h, x);
}
@ -367,7 +378,8 @@ namespace LTXVAE {
bool causal,
std::vector<ggml_tensor*>& feat_map,
int& feat_idx,
int chunk_idx) {
int chunk_idx,
int temporal_pad = 0) {
ggml_tensor* timestep_embed = nullptr;
if (timestep_conditioning) {
GGML_ASSERT(timestep != nullptr);
@ -376,7 +388,7 @@ namespace LTXVAE {
}
for (int i = 0; i < num_layers; i++) {
auto resnet = std::dynamic_pointer_cast<ResnetBlock3D>(blocks["res_blocks." + std::to_string(i)]);
x = resnet->forward(ctx, x, timestep_embed, causal, feat_map, feat_idx, chunk_idx);
x = resnet->forward(ctx, x, timestep_embed, causal, feat_map, feat_idx, chunk_idx, temporal_pad);
}
return x;
}
@ -437,7 +449,8 @@ namespace LTXVAE {
bool causal,
std::vector<ggml_tensor*>& feat_map,
int& feat_idx,
int chunk_idx) {
int chunk_idx,
int temporal_pad = 0) {
auto conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv"]);
bool drop_first = (chunk_idx == 0) && (factor_t > 1);
@ -453,7 +466,7 @@ namespace LTXVAE {
x_in = res;
}
x = conv->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal);
x = conv->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal, temporal_pad);
x = depth_to_space_3d(ctx->ggml_ctx, x, get_output_channels(), factor_t, factor_s, drop_first);
if (residual) {
x = ggml_add(ctx->ggml_ctx, x, x_in);
@ -986,7 +999,8 @@ namespace LTXVAE {
ggml_tensor* timestep,
std::vector<ggml_tensor*>& feat_map,
int& feat_idx,
int chunk_idx) {
int chunk_idx,
int& temporal_pad) {
auto conv_in = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_in"]);
auto conv_norm_out = std::dynamic_pointer_cast<PixelNorm3D>(blocks["conv_norm_out"]);
auto conv_out = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_out"]);
@ -998,7 +1012,7 @@ namespace LTXVAE {
}
// conv_in with feat_map for left temporal context
x = conv_in->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder);
x = conv_in->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder, temporal_pad);
// up_blocks
int block_idx = 0;
@ -1006,12 +1020,13 @@ namespace LTXVAE {
auto mid_block = std::dynamic_pointer_cast<UNetMidBlock3D>(blocks["up_blocks." + std::to_string(block_idx)]);
if (mid_block) {
x = mid_block->forward(ctx, x, scaled_timestep, causal_decoder,
feat_map, feat_idx, chunk_idx);
feat_map, feat_idx, chunk_idx, temporal_pad);
} else {
auto upsample = std::dynamic_pointer_cast<DepthToSpaceUpsample>(
blocks["up_blocks." + std::to_string(block_idx)]);
x = upsample->forward(ctx, x, causal_decoder,
feat_map, feat_idx, chunk_idx);
feat_map, feat_idx, chunk_idx, temporal_pad);
temporal_pad *= upsample->factor_t;
}
block_idx++;
}
@ -1028,7 +1043,7 @@ namespace LTXVAE {
x = apply_scale_shift(ctx->ggml_ctx, x, scale, shift);
}
x = ggml_silu_inplace(ctx->ggml_ctx, x);
x = conv_out->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder);
x = conv_out->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder, temporal_pad);
return x;
}
};
@ -1084,7 +1099,9 @@ namespace LTXVAE {
// tensors can be freed by GGML before the next iteration starts.
ggml_tensor* decode_tiled(GGMLRunnerContext* ctx,
ggml_tensor* z,
ggml_tensor* timestep) {
ggml_tensor* timestep,
int temporal_window_size = 1,
int temporal_tile_overlap = 0) {
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
auto processor = std::dynamic_pointer_cast<PerChannelStatistics>(blocks["per_channel_statistics"]);
auto latents = processor->un_normalize(ctx, z);
@ -1099,18 +1116,69 @@ namespace LTXVAE {
// 128 slots is generous enough for any supported decoder configuration.
std::vector<ggml_tensor*> feat_map(128, nullptr);
// Ensure window size is at least 1
int window = std::max(1, temporal_window_size);
int overlap = std::max(0, temporal_tile_overlap);
if (overlap >= window) {
LOG_WARN("temporal_tile_overlap (%d) is greater than or equal to temporal_tile_frames (%d), adjusting values to avoid empty decode windows",
overlap, window);
overlap = window - 1;
}
LOG_DEBUG("Using temporal tiling: temporal_tile_frames = %d, temporal_tile_overlap = %d, total frames = %d, resulting in %d tiles",
window,
overlap,
(int)T,
(T + window - overlap - 1) / (window - overlap));
ggml_tensor* out = nullptr;
for (int i = 0; i < (int)T; i++) {
for (int i = 0; i < (int)T - overlap; i += (window - overlap)) {
int feat_idx = 0;
auto z_i = ggml_ext_slice(ctx->ggml_ctx, latents, 2, i, i + 1);
auto out_i = decoder->forward_tiled_frame(ctx, z_i, timestep,
feat_map, feat_idx, i);
out = (out == nullptr) ? out_i : ggml_concat(ctx->ggml_ctx, out, out_i, 2);
// Calculate the end index for the current temporal chunk
int end_i = std::min((int)T, i + window);
if (end_i >= (int)T) {
overlap = 0; // avoid overlap issues in the last chunk
}
int chunk_overlap = overlap; // modified by forward_tiled_frame temporal inflation
auto z_chunk = ggml_ext_slice(ctx->ggml_ctx, latents, 2, i, end_i);
auto out_chunk = decoder->forward_tiled_frame(ctx, z_chunk, timestep,
feat_map, feat_idx, i, chunk_overlap);
// discard overlap frames if it's not the final chunk
if (overlap > 0 && end_i < (int)T) {
out_chunk = ggml_ext_slice(ctx->ggml_ctx, out_chunk, 2, 0, out_chunk->ne[2] - chunk_overlap);
}
out = (out == nullptr) ? out_chunk : ggml_concat(ctx->ggml_ctx, out, out_chunk, 2);
}
return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1);
}
ggml_tensor* decode_tiled_chunk(GGMLRunnerContext* ctx,
ggml_tensor* z,
ggml_tensor* timestep,
std::vector<ggml_tensor*>& feat_map,
int chunk_idx,
int temporal_tile_overlap,
int& feat_idx) {
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
auto processor = std::dynamic_pointer_cast<PerChannelStatistics>(blocks["per_channel_statistics"]);
auto latents = processor->un_normalize(ctx, z);
feat_idx = 0;
int chunk_overlap = temporal_tile_overlap; // modified by forward_tiled_frame temporal inflation
auto out_chunk = decoder->forward_tiled_frame(ctx, latents, timestep,
feat_map, feat_idx, chunk_idx, chunk_overlap);
if (chunk_overlap > 0) {
out_chunk = ggml_ext_slice(ctx->ggml_ctx, out_chunk, 2, 0, out_chunk->ne[2] - chunk_overlap);
}
return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out_chunk, patch_size, 1);
}
ggml_tensor* encode(GGMLRunnerContext* ctx,
ggml_tensor* x) {
GGML_ASSERT(!decode_only);
@ -1140,8 +1208,13 @@ namespace LTXVAE {
} // namespace LTXVAE
struct LTXVideoVAE : public VAE {
static constexpr int DEFAULT_TEMPORAL_TILE_FRAMES = 4;
static constexpr int DEFAULT_TEMPORAL_TILE_OVERLAP = 1;
bool decode_only;
bool temporal_tiling_enabled = false;
int temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES;
int temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP;
int ltx_vae_version;
bool timestep_conditioning;
int patch_size;
@ -1178,10 +1251,64 @@ struct LTXVideoVAE : public VAE {
temporal_tiling_enabled = enabled;
}
void set_tiling_params(const sd_tiling_params_t& params) override {
temporal_tiling_enabled = params.temporal_tiling;
temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES;
temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP;
for (const auto& [key, value] : parse_key_value_args(params.extra_tiling_args, "LTX VAE extra tiling arg")) {
int parsed = 0;
if (!parse_strict_int(value, parsed)) {
LOG_WARN("ignoring invalid LTX VAE extra tiling arg '%s=%s'", key.c_str(), value.c_str());
} else if (key == "temporal_tile_frames") {
temporal_tile_frames = std::max(1, parsed);
} else if (key == "temporal_tile_overlap") {
temporal_tile_overlap = std::max(0, parsed);
} else {
LOG_WARN("ignoring unknown LTX VAE extra tiling arg '%s'", key.c_str());
}
}
}
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) override {
vae.get_param_tensors(tensors, prefix);
}
struct TemporalTilePlan {
int frames = 1;
int overlap = 0;
int stride = 1;
int num_tiles = 1;
};
TemporalTilePlan resolve_temporal_tile_plan(int64_t total_frames) const {
TemporalTilePlan plan;
plan.frames = std::max(1, temporal_tile_frames);
plan.overlap = std::max(0, temporal_tile_overlap);
if (plan.overlap >= plan.frames) {
LOG_WARN("temporal_tile_overlap (%d) is greater than or equal to temporal_tile_frames (%d), adjusting values to avoid empty decode windows",
plan.overlap,
plan.frames);
plan.overlap = plan.frames - 1;
}
if (total_frames > 1 && plan.overlap >= total_frames) {
LOG_WARN("temporal_tile_overlap (%d) is greater than or equal to total latent frames (%lld), adjusting values to decode at least one tile",
plan.overlap,
(long long)total_frames);
plan.overlap = static_cast<int>(total_frames - 1);
}
plan.stride = std::max(1, plan.frames - plan.overlap);
int64_t tiled_frames = std::max<int64_t>(1, total_frames - plan.overlap);
plan.num_tiles = total_frames > 0 ? static_cast<int>((tiled_frames + plan.stride - 1) / plan.stride) : 0;
return plan;
}
std::string temporal_feat_cache_name(size_t feat_idx) const {
return "ltx_vae_temporal_feat:" + std::to_string(feat_idx);
}
ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) {
ggml_cgraph* gf = new_graph_custom(20480);
ggml_tensor* z = make_input(z_tensor);
@ -1192,18 +1319,97 @@ struct LTXVideoVAE : public VAE {
auto runner_ctx = get_context();
ggml_tensor* out;
bool use_tiled = decode_graph && temporal_tiling_enabled &&
z_tensor.dim() == 5 && z_tensor.shape()[2] > 1;
if (use_tiled) {
out = vae.decode_tiled(&runner_ctx, z, timestep);
} else {
out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z);
}
out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z);
ggml_build_forward_expand(gf, out);
return gf;
}
ggml_cgraph* build_temporal_tile_graph(const sd::Tensor<float>& z_chunk_tensor,
int chunk_idx,
int chunk_overlap) {
ggml_cgraph* gf = new_graph_custom(20480);
ggml_tensor* z = make_input(z_chunk_tensor);
ggml_tensor* timestep = nullptr;
if (timestep_conditioning) {
timestep = make_input(decode_timestep_tensor);
}
std::vector<ggml_tensor*> feat_map(128, nullptr);
for (size_t feat_idx = 0; feat_idx < feat_map.size(); ++feat_idx) {
feat_map[feat_idx] = get_cache_tensor_by_name(temporal_feat_cache_name(feat_idx));
}
auto runner_ctx = get_context();
int feat_count = 0;
ggml_tensor* out = vae.decode_tiled_chunk(&runner_ctx,
z,
timestep,
feat_map,
chunk_idx,
chunk_overlap,
feat_count);
for (int feat_idx = 0; feat_idx < feat_count && feat_idx < static_cast<int>(feat_map.size()); ++feat_idx) {
ggml_tensor* feat_cache = feat_map[static_cast<size_t>(feat_idx)];
if (feat_cache != nullptr) {
cache(temporal_feat_cache_name(static_cast<size_t>(feat_idx)), feat_cache);
ggml_build_forward_expand(gf, feat_cache);
}
}
ggml_build_forward_expand(gf, out);
return gf;
}
sd::Tensor<float> decode_temporal_tiled_streaming(const int n_threads,
const sd::Tensor<float>& input,
size_t expected_dim) {
const int64_t total_frames = input.shape()[2];
TemporalTilePlan plan = resolve_temporal_tile_plan(total_frames);
LOG_DEBUG("Using streaming temporal tiling: temporal_tile_frames=%d, temporal_tile_overlap=%d, total latent frames=%lld, resulting in %d tiles",
plan.frames,
plan.overlap,
(long long)total_frames,
plan.num_tiles);
free_cache_ctx_and_buffer();
cache_tensor_map.clear();
sd::Tensor<float> output;
for (int64_t start = 0; start < total_frames - plan.overlap; start += plan.stride) {
const int64_t end = std::min<int64_t>(total_frames, start + plan.frames);
const int chunk_overlap = end < total_frames ? plan.overlap : 0;
auto z_chunk = sd::ops::slice(input, 2, start, end);
LOG_DEBUG("LTX VAE temporal tile %lld/%d: latent frames [%lld, %lld), overlap=%d",
(long long)(start / plan.stride + 1),
plan.num_tiles,
(long long)start,
(long long)end,
chunk_overlap);
auto get_graph = [&]() -> ggml_cgraph* {
return build_temporal_tile_graph(z_chunk,
static_cast<int>(start),
chunk_overlap);
};
auto chunk = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, true),
expected_dim);
if (chunk.empty()) {
free_cache_ctx_and_buffer();
cache_tensor_map.clear();
return {};
}
output = output.empty() ? std::move(chunk) : sd::ops::concat(output, chunk, 2);
}
free_cache_ctx_and_buffer();
cache_tensor_map.clear();
return output;
}
ggml_cgraph* build_latent_statistics_graph(const sd::Tensor<float>& z_tensor, bool normalize) {
ggml_cgraph* gf = new_graph_custom(1024);
ggml_tensor* z = make_input(z_tensor);
@ -1239,6 +1445,9 @@ struct LTXVideoVAE : public VAE {
input = sd::ops::slice(input, 2, 0, cropped_t);
}
}
if (decode_graph && temporal_tiling_enabled && input.dim() == 5 && input.shape()[2] > 1) {
return decode_temporal_tiled_streaming(n_threads, input, expected_dim);
}
auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(input, decode_graph);
};

View file

@ -155,7 +155,7 @@ public:
bool kcpp_lora_cache_populate = false;
std::string taesd_path;
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0};
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0, nullptr};
bool offload_params_to_cpu = false;
float max_vram = 0.f;
bool use_pmid = false;
@ -2921,7 +2921,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
sd_img_gen_params->batch_count = 1;
sd_img_gen_params->control_strength = 0.9f;
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
sd_cache_params_init(&sd_img_gen_params->cache);
sd_hires_params_init(&sd_img_gen_params->hires);
}
@ -2950,7 +2950,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
"increase_ref_index: %s\n"
"control_strength: %.2f\n"
"photo maker: {style_strength = %.2f, id_images_count = %d, id_embed_path = %s}\n"
"VAE tiling: %s (temporal=%s)\n"
"VAE tiling: %s (temporal=%s, extra_tiling_args=%s)\n"
"hires: {enabled=%s, upscaler=%s, model_path=%s, scale=%.2f, target=%dx%d, steps=%d, denoising_strength=%.2f}\n",
SAFE_STR(sd_img_gen_params->prompt),
SAFE_STR(sd_img_gen_params->negative_prompt),
@ -2970,6 +2970,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
SAFE_STR(sd_img_gen_params->pm_params.id_embed_path),
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled),
BOOL_STR(sd_img_gen_params->vae_tiling_params.temporal_tiling),
SAFE_STR(sd_img_gen_params->vae_tiling_params.extra_tiling_args),
BOOL_STR(sd_img_gen_params->hires.enabled),
sd_hires_upscaler_name(sd_img_gen_params->hires.upscaler),
SAFE_STR(sd_img_gen_params->hires.model_path),
@ -3007,7 +3008,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
sd_vid_gen_params->fps = 16;
sd_vid_gen_params->moe_boundary = 0.875f;
sd_vid_gen_params->vace_strength = 1.f;
sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
sd_vid_gen_params->hires.enabled = false;
sd_vid_gen_params->hires.upscaler = SD_HIRES_UPSCALER_LATENT;
sd_vid_gen_params->hires.scale = 2.f;
@ -5460,14 +5461,24 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
sd_ctx->sd->diffusion_model->free_params_buffer();
}
int64_t latent_end = ggml_time_ms();
LOG_INFO("generating latent video completed, taking %.2fs", (latent_end - latent_start) * 1.0f / 1000);
sd_audio_t* generated_audio = nullptr;
if (sd_version_is_ltxav(sd_ctx->sd->version) &&
latents.audio_length > 0 &&
sd_ctx->sd->audio_vae_model != nullptr) {
int64_t audio_latent_decode_start = ggml_time_ms();
auto audio_latent = unpack_ltxav_audio_latent(final_latent,
latents.audio_length,
sd_ctx->sd->get_latent_channel());
if (!audio_latent.empty()) {
LOG_DEBUG("decode audio latent %dx%dx%dx%d",
(int)audio_latent.shape()[0],
(int)audio_latent.shape()[1],
(int)audio_latent.shape()[2],
(int)audio_latent.shape()[3]);
auto waveform = sd_ctx->sd->decode_ltx_audio_latent(audio_latent);
if (!waveform.empty()) {
generated_audio = waveform_to_sd_audio(sd_ctx->sd, waveform);
@ -5475,6 +5486,8 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
LOG_WARN("LTX audio latent decode failed; continuing with silent video output");
}
}
int64_t audio_latent_decode_end = ggml_time_ms();
LOG_INFO("decoding audio latent completed, taking %.2fs", (audio_latent_decode_end - audio_latent_decode_start) * 1.0f / 1000);
}
if (latents.video_conditioning_frame_count > 0) {
@ -5487,9 +5500,6 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
final_latent = sd::ops::slice(final_latent, 2, latents.ref_image_num, final_latent.shape()[2]);
}
int64_t latent_end = ggml_time_ms();
LOG_INFO("generating latent video completed, taking %.2fs", (latent_end - latent_start) * 1.0f / 1000);
auto result = decode_video_outputs(sd_ctx, latent_upscale_enabled ? hires_request : request, final_latent, num_frames_out);
if (result == nullptr) {
free_sd_audio(generated_audio);

View file

@ -160,6 +160,7 @@ typedef struct {
float target_overlap;
float rel_size_x;
float rel_size_y;
const char* extra_tiling_args;
} sd_tiling_params_t;
typedef struct {

View file

@ -259,10 +259,54 @@ public:
}
};
ggml_tensor* patchify(ggml_context* ctx,
ggml_tensor* x,
int64_t patch_size,
int64_t b = 1) {
class WideMemBlock : public GGMLBlock {
bool has_skip_conv = false;
public:
WideMemBlock(int channels, int out_channels)
: has_skip_conv(channels != out_channels) {
int groups = std::max(1, out_channels / 64);
blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels * 2, out_channels, {1, 1}, {1, 1}));
blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d_grouped(out_channels, out_channels, groups, {3, 3}, {1, 1}, {1, 1}));
blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {1, 1}, {1, 1}));
blocks["conv.6"] = std::shared_ptr<GGMLBlock>(new Conv2d_grouped(out_channels, out_channels, groups, {3, 3}, {1, 1}, {1, 1}));
if (has_skip_conv) {
blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
}
}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* past) {
// x: [n, channels, h, w]
auto conv0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]);
auto conv1 = std::dynamic_pointer_cast<Conv2d_grouped>(blocks["conv.2"]);
auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
auto conv3 = std::dynamic_pointer_cast<Conv2d_grouped>(blocks["conv.6"]);
auto h = ggml_concat(ctx->ggml_ctx, x, past, 2);
h = conv0->forward(ctx, h);
h = ggml_relu_inplace(ctx->ggml_ctx, h);
h = conv1->forward(ctx, h);
h = ggml_relu_inplace(ctx->ggml_ctx, h);
h = conv2->forward(ctx, h);
h = ggml_relu_inplace(ctx->ggml_ctx, h);
h = conv3->forward(ctx, h);
auto skip = x;
if (has_skip_conv) {
auto skip_conv = std::dynamic_pointer_cast<Conv2d>(blocks["skip"]);
skip = skip_conv->forward(ctx, x);
}
h = ggml_add_inplace(ctx->ggml_ctx, h, skip);
h = ggml_relu_inplace(ctx->ggml_ctx, h);
return h;
}
};
ggml_tensor*
patchify(ggml_context* ctx,
ggml_tensor* x,
int64_t patch_size,
int64_t b = 1) {
// x: [f, b*c, h*q, w*r]
// return: [f, b*c*r*q, h, w]
if (patch_size == 1) {
@ -325,7 +369,6 @@ public:
int t_downscale = 1;
TinyVideoEncoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_downscale = {true, true, false})
: z_channels(z_channels), patch_size(patch_size) {
// self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool))
t_downscale = 1;
for (bool downscale : time_downscale) {
if (downscale) {
@ -384,11 +427,18 @@ class TinyVideoDecoder : public UnaryBlock {
int channels[num_layers + 1] = {256, 128, 64, 64};
int patch_size = 1;
int t_upscale = 1;
bool is_wide = false;
public:
TinyVideoDecoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_upscale = {false, true, true})
: z_channels(z_channels), patch_size(patch_size) {
TinyVideoDecoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_upscale = {false, true, true}, bool is_wide = false)
: z_channels(z_channels), patch_size(patch_size), is_wide(is_wide) {
t_upscale = 1;
if (is_wide) {
channels[0] = 1024;
channels[1] = 512;
channels[2] = 256;
}
for (bool upscale : time_upscale) {
if (upscale) {
t_upscale *= 2;
@ -400,7 +450,11 @@ public:
for (int i = 0; i < num_layers; i++) {
int stride = time_upscale[i] ? 2 : 1;
for (int j = 0; j < num_blocks; j++) {
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(channels[i], channels[i]));
if (is_wide) {
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new WideMemBlock(channels[i], channels[i]));
} else {
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(channels[i], channels[i]));
}
}
index++; // nn.Upsample()
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TGrow(channels[i], stride));
@ -425,10 +479,15 @@ public:
int index = 3;
for (int i = 0; i < num_layers; i++) {
for (int j = 0; j < num_blocks; j++) {
auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0);
mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0);
h = block->forward(ctx, h, mem);
auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0);
mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0);
if (is_wide) {
auto block = std::dynamic_pointer_cast<WideMemBlock>(blocks[std::to_string(index++)]);
h = block->forward(ctx, h, mem);
} else {
auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
h = block->forward(ctx, h, mem);
}
}
// upsample
index++;
@ -455,6 +514,7 @@ class TAEHV : public GGMLBlock {
protected:
bool decode_only;
SDVersion version;
bool is_wide;
public:
int z_channels = 16;
@ -462,8 +522,8 @@ public:
std::vector<bool> time_upscale = {false, true, true};
public:
TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2)
: decode_only(decode_only), version(version) {
TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2, bool is_wide = false)
: decode_only(decode_only), version(version), is_wide(is_wide) {
int patch = 1;
if (version == VERSION_WAN2_2_TI2V) {
z_channels = 48;
@ -474,7 +534,7 @@ public:
time_downscale = {true, true, true};
time_upscale = {true, true, true};
}
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoDecoder(z_channels, patch, time_upscale));
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoDecoder(z_channels, patch, time_upscale, is_wide));
if (!decode_only) {
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoEncoder(z_channels, patch, time_downscale));
}
@ -623,6 +683,7 @@ struct TinyImageAutoEncoder : public VAE {
struct TinyVideoAutoEncoder : public VAE {
TAEHV taehv;
bool decode_only = false;
bool is_wide = false;
TinyVideoAutoEncoder(ggml_backend_t backend,
ggml_backend_t params_backend,
@ -631,8 +692,14 @@ struct TinyVideoAutoEncoder : public VAE {
bool decoder_only = true,
SDVersion version = VERSION_WAN2)
: decode_only(decoder_only),
taehv(decoder_only, version),
VAE(version, backend, params_backend) {
for (auto tensor_storage : tensor_storage_map) {
if (tensor_storage.first.find(prefix + ".3.conv.6.weight") != std::string::npos) {
is_wide = true;
break;
}
}
taehv = TAEHV(decoder_only, version, is_wide);
scale_input = false;
taehv.init(params_ctx, tensor_storage_map, prefix);
}
@ -663,7 +730,8 @@ struct TinyVideoAutoEncoder : public VAE {
}
ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) {
ggml_cgraph* gf = ggml_new_graph(compute_ctx);
ggml_cgraph* gf = decode_graph && is_wide ? ggml_new_graph_custom(compute_ctx, 4096, false)
: ggml_new_graph(compute_ctx);
ggml_tensor* z = make_input(z_tensor);
auto runner_ctx = get_context();
ggml_tensor* out = decode_graph ? taehv.decode(&runner_ctx, z) : taehv.encode(&runner_ctx, z);

View file

@ -1,8 +1,10 @@
#include "util.h"
#include <algorithm>
#include <cctype>
#include <cmath>
#include <codecvt>
#include <cstdarg>
#include <exception>
#include <filesystem>
#include <fstream>
#include <locale>
@ -420,6 +422,88 @@ std::vector<std::string> split_string(const std::string& str, char delimiter) {
return result;
}
KeyValueArgs parse_key_value_args(const char* args, const char* context) {
KeyValueArgs pairs;
if (args == nullptr || args[0] == '\0') {
return pairs;
}
std::string raw(args);
size_t start = 0;
for (size_t pos = 0; pos <= raw.size(); ++pos) {
if (pos != raw.size() && raw[pos] != ',' && raw[pos] != ';') {
continue;
}
std::string token = trim(raw.substr(start, pos - start));
if (!token.empty()) {
size_t eq = token.find('=');
if (eq == std::string::npos) {
const char* log_context = context ? context : "key=value arg";
LOG_WARN("ignoring malformed %s '%s'", log_context, token.c_str());
} else {
std::string key = trim(token.substr(0, eq));
std::string value = trim(token.substr(eq + 1));
pairs.emplace_back(std::move(key), std::move(value));
}
}
start = pos + 1;
}
return pairs;
}
KeyValueArgs parse_key_value_args(const std::string& args, const char* context) {
return parse_key_value_args(args.c_str(), context);
}
bool parse_strict_float(const std::string& text, float& value) {
try {
size_t consumed = 0;
float parsed = std::stof(text, &consumed);
if (!trim(text.substr(consumed)).empty()) {
return false;
}
value = parsed;
return true;
} catch (const std::exception&) {
return false;
}
}
bool parse_strict_int(const std::string& text, int& value) {
try {
size_t consumed = 0;
int parsed = std::stoi(text, &consumed);
if (!trim(text.substr(consumed)).empty()) {
return false;
}
value = parsed;
return true;
} catch (const std::exception&) {
return false;
}
}
bool parse_strict_bool(const std::string& text, bool& value) {
std::string lowered = trim(text);
std::transform(lowered.begin(), lowered.end(), lowered.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
if (lowered == "1" || lowered == "true" || lowered == "yes" || lowered == "on") {
value = true;
return true;
}
if (lowered == "0" || lowered == "false" || lowered == "no" || lowered == "off") {
value = false;
return true;
}
return false;
}
// { kcpp
static int sdloglevel = 0; //-1 = hide all, 0 = normal, 1 = showall
static bool sdquiet = false;

View file

@ -4,6 +4,7 @@
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ggml-backend.h"
@ -68,6 +69,15 @@ protected:
std::string path_join(const std::string& p1, const std::string& p2);
std::vector<std::string> split_string(const std::string& str, char delimiter);
using KeyValueArgs = std::vector<std::pair<std::string, std::string>>;
KeyValueArgs parse_key_value_args(const char* args, const char* context = "key=value arg");
KeyValueArgs parse_key_value_args(const std::string& args, const char* context = "key=value arg");
bool parse_strict_float(const std::string& text, float& value);
bool parse_strict_int(const std::string& text, int& value);
bool parse_strict_bool(const std::string& text, bool& value);
void pretty_progress(int step, int steps, float time);
void pretty_bytes_progress(int step, int steps, uint64_t bytes_processed, float elapsed_seconds);

View file

@ -167,6 +167,7 @@ public:
int64_t t0 = ggml_time_ms();
sd::Tensor<float> input = x;
sd::Tensor<float> output;
set_tiling_params(tiling_params);
if (tiling_params.enabled) {
const int scale_factor = get_scale_factor();
@ -216,6 +217,9 @@ public:
virtual void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) = 0;
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
virtual void set_temporal_tiling_enabled(bool enabled) { SD_UNUSED(enabled); };
virtual void set_tiling_params(const sd_tiling_params_t& params) {
set_temporal_tiling_enabled(params.temporal_tiling);
};
};
struct FakeVAE : public VAE {