From 2093ca4c736649bee937d7de71229c16ca9d2771 Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Sun, 15 Mar 2026 20:58:45 +0800 Subject: [PATCH] ace step optimizations --- otherarch/acestep/dit-vae.cpp | 2 +- otherarch/acestep/mp3/mp3enc-quant.h | 67 ++++++++++--- otherarch/acestep/mp3/mp3enc.h | 142 +++++++++++++++++++-------- otherarch/utils.cpp | 35 +++++-- 4 files changed, 184 insertions(+), 62 deletions(-) diff --git a/otherarch/acestep/dit-vae.cpp b/otherarch/acestep/dit-vae.cpp index 172e04898..d7078a444 100644 --- a/otherarch/acestep/dit-vae.cpp +++ b/otherarch/acestep/dit-vae.cpp @@ -335,7 +335,7 @@ std::string acestep_generate_audio(const music_generation_inputs inputs) float duration = req.duration > 0 ? req.duration : 120.0f; long long seed = req.seed; int num_steps = req.inference_steps > 0 ? req.inference_steps : 8; - float guidance_scale = req.guidance_scale > 0 ? req.guidance_scale : 7.0f; + float guidance_scale = req.guidance_scale > 0 ? req.guidance_scale : 1.0f; float shift = req.shift > 0 ? req.shift : 1.0f; if (is_turbo && guidance_scale > 1.0f) { diff --git a/otherarch/acestep/mp3/mp3enc-quant.h b/otherarch/acestep/mp3/mp3enc-quant.h index 0d67242de..4afb0ea1a 100644 --- a/otherarch/acestep/mp3/mp3enc-quant.h +++ b/otherarch/acestep/mp3/mp3enc-quant.h @@ -487,31 +487,74 @@ static int mp3enc_outer_loop(const float * xr, // - stop when all bands are under threshold or no bits left // // Max 25 passes: enough for convergence at all bitrates. 
- float noise[21]; + float noise[22]; // 22 SFB bands before table terminator int best_ix[576]; - mp3enc_granule_info best_gi = gi; - int best_total = gi.part2_3_length; - int best_over = 21; // start pessimistic + mp3enc_granule_info best_gi = gi; + int best_total = gi.part2_3_length; + int best_over = 21; // start pessimistic + float best_max_db = 999.0f; // worst-band noise in dB over threshold + float best_tot_db = 999.0f; // total over-threshold noise in dB memcpy(best_ix, ix, sizeof(best_ix)); for (int iter = 0; iter < 25; iter++) { // Compute noise per SFB with current quantization mp3enc_calc_noise(xr, ix, gi.global_gain, gi.scalefac_l, gi.scalefac_scale, gi.preflag, sfb_table, noise); - // Count bands over threshold - int over_count = 0; + // Compute noise metrics for 3-axis comparison (GPSYCHO approach). + // Instead of just counting bands over threshold, track: + // - max_over_db: worst violation in dB (peak distortion) + // - tot_over_db: sum of violations in dB (for average) + // - over_count: number of distorted bands + // This prefers solutions that minimize peak distortion and spread + // remaining noise evenly, rather than concentrating it in one band. + int over_count = 0; + float max_over_db = 0.0f; + float tot_over_db = 0.0f; + for (int sfb = 0; sfb < 21; sfb++) { if (xmin[sfb] > 0.0f && noise[sfb] > xmin[sfb]) { over_count++; + float over_db = 10.0f * log10f(noise[sfb] / xmin[sfb]); + tot_over_db += over_db; + if (over_db > max_over_db) { + max_over_db = over_db; + } } } - // Track best result: prefer fewer over-threshold bands, - // then fewer total bits as tiebreaker. - if (over_count < best_over || (over_count == best_over && gi.part2_3_length < best_total)) { - best_gi = gi; - best_total = gi.part2_3_length; - best_over = over_count; + // 3-axis quant_compare (inspired by LAME GPSYCHO outer_loop): + // 1. Clean (over=0) always beats dirty (over>0) + // 2. Among clean solutions: prefer fewer bits + // 3. 
Among dirty solutions: minimize peak, then average, then count + bool is_better = false; + if (over_count == 0 && best_over > 0) { + is_better = true; + } else if (over_count == 0 && best_over == 0) { + is_better = (gi.part2_3_length < best_total); + } else if (over_count > 0 && best_over > 0) { + // both dirty: compare peak distortion first + if (max_over_db < best_max_db - 0.5f) { + // significantly lower peak -> better + is_better = true; + } else if (max_over_db < best_max_db + 0.5f) { + // similar peak: compare average violation + float avg = tot_over_db / (float) over_count; + float best_avg = (best_over > 0) ? best_tot_db / (float) best_over : 0.0f; + if (avg < best_avg - 0.3f) { + is_better = true; + } else if (avg < best_avg + 0.3f) { + // similar average: prefer fewer violated bands + is_better = (over_count < best_over); + } + } + } + + if (is_better) { + best_gi = gi; + best_total = gi.part2_3_length; + best_over = over_count; + best_max_db = max_over_db; + best_tot_db = tot_over_db; memcpy(best_ix, ix, sizeof(best_ix)); } diff --git a/otherarch/acestep/mp3/mp3enc.h b/otherarch/acestep/mp3/mp3enc.h index 84987d4cd..517d5687f 100644 --- a/otherarch/acestep/mp3/mp3enc.h +++ b/otherarch/acestep/mp3/mp3enc.h @@ -214,10 +214,11 @@ static int mp3enc_encode_frame(mp3enc_t * enc, const float * pcm) { int nch = enc->channels; int padding = mp3enc_get_padding(enc); - // Use MS stereo for stereo input (joint stereo mode) - // mode=1 (joint), mode_ext=2 (MS on, intensity off) + // MS stereo decision is deferred until after MDCT energy analysis. + // mode=1 (joint) allows switching M/S on or off per frame. + // mode=0 (stereo) would force L/R always. mode=3 (mono). int mode = (nch == 1) ? 3 : 1; - int mode_ext = (nch == 1) ? 
0 : 2; + int mode_ext = 0; // set after MDCT energy analysis // Setup header mp3enc_header hdr; @@ -257,30 +258,28 @@ static int mp3enc_encode_frame(mp3enc_t * enc, const float * pcm) { // Total main_data bits: this frame's area + reservoir from previous frame int total_md_bits = main_data_bits + resv_bytes * 8; - // Mean bits per granule (from total budget) - int mean_bits = total_md_bits / 2; + // Phase 1: compute MDCT + psy for ALL granules before bit allocation. + // We need PE from both granules to weight the bit budget. + float mdct_all[2][2][576]; // [granule][channel][576 MDCT lines] + float saved_xmin[2][2][MP3ENC_PSY_SFB_MAX]; + float saved_pe[2] = { 0.0f, 0.0f }; - int ix[2][2][576]; // [granule][channel][line] - float mdct_lr[2][576]; // MDCT output per channel before M/S transform - - int total_bits_used = 0; - int intra_resv = 0; // intra-frame reservoir: bits saved by granule 0 for granule 1 + // Accumulators for MS stereo decision + float energy_mid = 0.0f; + float energy_side = 0.0f; for (int gr = 0; gr < 2; gr++) { int pcm_offset = gr * 576; - // Step 1: filterbank + MDCT for all channels + // filterbank + MDCT for all channels for (int ch = 0; ch < nch; ch++) { const float * ch_pcm = pcm + ch * 1152 + pcm_offset; - - // Run analysis filterbank: 576 PCM samples = 18 calls of 32 samples - float sb_out[32]; + float sb_out[32]; for (int slot = 0; slot < 18; slot++) { enc->filter[ch].process(ch_pcm + slot * 32, sb_out); for (int sb = 0; sb < 32; sb++) { enc->sb_cur[ch][sb][slot] = sb_out[sb]; } - // frequency inversion: negate odd subbands at odd time slots if (slot & 1) { for (int sb = 1; sb < 32; sb += 2) { @@ -288,43 +287,107 @@ static int mp3enc_encode_frame(mp3enc_t * enc, const float * pcm) { } } } - - // MDCT: transform subbands to 576 frequency lines - mp3enc_mdct_granule(enc->sb_prev[ch], enc->sb_cur[ch], mdct_lr[ch]); - - // Save current subbands as previous for next granule + mp3enc_mdct_granule(enc->sb_prev[ch], enc->sb_cur[ch], 
mdct_all[gr][ch]); memcpy(enc->sb_prev[ch], enc->sb_cur[ch], sizeof(enc->sb_cur[ch])); } - // Step 2: MS stereo transform + // MS stereo energy analysis (before transform, on L/R data). + // Accumulate mid and side energy across both granules to decide + // whether M/S coding is beneficial for this frame. + // M/S wins when L and R are correlated (side energy is small). if (nch == 2) { + for (int i = 0; i < enc->lowpass_line && i < 576; i++) { + float l = mdct_all[gr][0][i]; + float r = mdct_all[gr][1][i]; + float m = l + r; + float s = l - r; + energy_mid += m * m; + energy_side += s * s; + } + } + } + + // MS stereo decision: use M/S when channels are correlated enough + // that the side channel is cheap to encode. FhG "almost always uses + // ms_stereo" (GPSYCHO docs). We use it unless side energy dominates, + // which means the channels are very different (rare for music). + bool use_ms = false; + if (nch == 2) { + // Use M/S unless side energy is at least 1.2x the mid energy; also use it when mid energy is effectively zero. + // This is generous -- almost always enables M/S (like FhG). + use_ms = (energy_side < energy_mid * 1.2f) || (energy_mid < 1e-20f); + mode_ext = use_ms ? 
2 : 0; + hdr.mode_ext = mode_ext; + } + + // Now apply M/S transform + lowpass + psy for both granules + for (int gr = 0; gr < 2; gr++) { + // MS stereo transform + if (use_ms) { static const float ms_scale = 0.7071067811865476f; // 1/sqrt(2) for (int i = 0; i < 576; i++) { - float l = mdct_lr[0][i]; - float r = mdct_lr[1][i]; - mdct_lr[0][i] = (l + r) * ms_scale; - mdct_lr[1][i] = (l - r) * ms_scale; + float l = mdct_all[gr][0][i]; + float r = mdct_all[gr][1][i]; + mdct_all[gr][0][i] = (l + r) * ms_scale; + mdct_all[gr][1][i] = (l - r) * ms_scale; } } - // Step 2b: adaptive lowpass + // adaptive lowpass for (int ch = 0; ch < nch; ch++) { for (int i = enc->lowpass_line; i < 576; i++) { - mdct_lr[ch][i] = 0.0f; + mdct_all[gr][ch][i] = 0.0f; } } - // Step 3: run psy for all channels before bit allocation - float saved_xmin[2][MP3ENC_PSY_SFB_MAX]; + // psy: compute masking thresholds and perceptual entropy for (int ch = 0; ch < nch; ch++) { - enc->psy.compute(mdct_lr[ch], sfb_long, enc->sr_index, ch); - memcpy(saved_xmin[ch], enc->psy.xmin, sizeof(enc->psy.xmin)); + enc->psy.compute(mdct_all[gr][ch], sfb_long, enc->sr_index, ch); + memcpy(saved_xmin[gr][ch], enc->psy.xmin, sizeof(enc->psy.xmin)); + saved_pe[gr] += enc->psy.pe; } + } - // Step 4: bit allocation - int max_bits = mean_bits + intra_resv; + // Phase 2: PE-weighted bit allocation + quantization. + // Give more bits to granules with higher perceptual entropy (more complex + // signal). ISO 11172-3 computes PE but then ignores it for allocation; + // LAME uses it to drive the bit reservoir. We weight the per-granule + // budget proportionally to PE, clamped to avoid starving either granule. 
+ int gr_budget[2]; + float pe_sum = saved_pe[0] + saved_pe[1]; + if (pe_sum > 1e-6f) { + float frac0 = saved_pe[0] / pe_sum; + gr_budget[0] = (int) ((float) total_md_bits * frac0); + gr_budget[1] = total_md_bits - gr_budget[0]; + // clamp: neither granule gets less than 20% or more than 80% + int lo = total_md_bits / 5; + int hi = total_md_bits * 4 / 5; + for (int gr = 0; gr < 2; gr++) { + if (gr_budget[gr] < lo) { + gr_budget[gr] = lo; + } + if (gr_budget[gr] > hi) { + gr_budget[gr] = hi; + } + } + // re-normalize after clamp + int clamped_sum = gr_budget[0] + gr_budget[1]; + int bits_to_spread = total_md_bits - clamped_sum; + gr_budget[0] += bits_to_spread / 2; + gr_budget[1] += bits_to_spread - bits_to_spread / 2; + } else { + gr_budget[0] = total_md_bits / 2; + gr_budget[1] = total_md_bits - gr_budget[0]; + } - // Don't exceed remaining budget + int ix[2][2][576]; + int total_bits_used = 0; + int intra_resv = 0; + + for (int gr = 0; gr < 2; gr++) { + int max_bits = gr_budget[gr] + intra_resv; + + // don't exceed remaining budget int remaining_bits = total_md_bits - total_bits_used; if (max_bits > remaining_bits) { max_bits = remaining_bits; @@ -335,17 +398,16 @@ static int mp3enc_encode_frame(mp3enc_t * enc, const float * pcm) { int bits_per_ch = max_bits / nch; - // Step 5: quantize each channel with saved thresholds + // quantize each channel with psy thresholds int gr_bits_used = 0; for (int ch = 0; ch < nch; ch++) { - int bits = mp3enc_outer_loop(mdct_lr[ch], ix[gr][ch], si.gr[gr][ch], saved_xmin[ch], bits_per_ch, sfb_long, - enc->sr_index, gr, si.scfsi[ch]); + int bits = mp3enc_outer_loop(mdct_all[gr][ch], ix[gr][ch], si.gr[gr][ch], saved_xmin[gr][ch], bits_per_ch, + sfb_long, enc->sr_index, gr, si.scfsi[ch]); gr_bits_used += bits; } - // Track intra-frame savings: bits not used by this granule - // become available for the next granule in this frame. 
- intra_resv += mean_bits - gr_bits_used; + // intra-frame savings: unused bits carry to next granule + intra_resv += gr_budget[gr] - gr_bits_used; if (intra_resv < 0) { intra_resv = 0; } diff --git a/otherarch/utils.cpp b/otherarch/utils.cpp index 228e90dd2..96b8e733d 100644 --- a/otherarch/utils.cpp +++ b/otherarch/utils.cpp @@ -418,11 +418,30 @@ std::vector resample_wav(int num_channels,const std::vector& input std::vector output((size_t)n_out * num_channels); const int half_len = 32; + const int taps = half_len * 2; + const double beta = 9.0; const double inv_i0b = 1.0 / audio_resample_bessel_i0(beta); const double fc = 0.5 * ((ratio < 1.0) ? ratio : 1.0); + // PRECOMPUTE KAISER WINDOW + std::vector window(taps + 1); + + for (int k = -half_len; k <= half_len; k++) + { + double t = (double)k / (double)half_len; + + double win; + + if (t < -1.0 || t > 1.0) + win = 0.0; + else + win = audio_resample_bessel_i0(beta * std::sqrt(1.0 - t * t)) * inv_i0b; + + window[k + half_len] = win; + } + for (int ch = 0; ch < num_channels; ch++) { const float* src = input.data() + ch * n_in; @@ -432,8 +451,10 @@ std::vector resample_wav(int num_channels,const std::vector& input { double center = (double)i / ratio; - int start = (int)std::floor(center) - half_len + 1; - int end = (int)std::floor(center) + half_len; + int base = (int)std::floor(center); + + int start = base - half_len + 1; + int end = base + half_len; double sum = 0.0; double wgt = 0.0; @@ -443,22 +464,18 @@ std::vector resample_wav(int num_channels,const std::vector& input double d = center - (double)j; double sinc_val; + if (std::fabs(d) < 1e-9) sinc_val = 2.0 * fc; else sinc_val = std::sin(2.0 * M_PI * fc * d) / (M_PI * d); - double t = d / (double)half_len; - - double win; - if (t < -1.0 || t > 1.0) - win = 0.0; - else - win = audio_resample_bessel_i0(beta * std::sqrt(1.0 - t * t)) * inv_i0b; + double win = window[j - start]; double h = sinc_val * win; int idx = j; + if (idx < 0) idx = 0; if (idx >= n_in) idx = 
n_in - 1;