ace step optimizations

This commit is contained in:
Concedo 2026-03-15 20:58:45 +08:00
parent ccd4745e0c
commit 2093ca4c73
4 changed files with 184 additions and 62 deletions

View file

@ -335,7 +335,7 @@ std::string acestep_generate_audio(const music_generation_inputs inputs)
float duration = req.duration > 0 ? req.duration : 120.0f;
long long seed = req.seed;
int num_steps = req.inference_steps > 0 ? req.inference_steps : 8;
float guidance_scale = req.guidance_scale > 0 ? req.guidance_scale : 7.0f;
float guidance_scale = req.guidance_scale > 0 ? req.guidance_scale : 1.0f;
float shift = req.shift > 0 ? req.shift : 1.0f;
if (is_turbo && guidance_scale > 1.0f) {

View file

@ -487,31 +487,74 @@ static int mp3enc_outer_loop(const float * xr,
// - stop when all bands are under threshold or no bits left
//
// Max 25 passes: enough for convergence at all bitrates.
float noise[21];
float noise[22]; // 22 SFB bands before table terminator
int best_ix[576];
mp3enc_granule_info best_gi = gi;
int best_total = gi.part2_3_length;
int best_over = 21; // start pessimistic
mp3enc_granule_info best_gi = gi;
int best_total = gi.part2_3_length;
int best_over = 21; // start pessimistic
float best_max_db = 999.0f; // worst-band noise in dB over threshold
float best_tot_db = 999.0f; // total over-threshold noise in dB
memcpy(best_ix, ix, sizeof(best_ix));
for (int iter = 0; iter < 25; iter++) {
// Compute noise per SFB with current quantization
mp3enc_calc_noise(xr, ix, gi.global_gain, gi.scalefac_l, gi.scalefac_scale, gi.preflag, sfb_table, noise);
// Count bands over threshold
int over_count = 0;
// Compute noise metrics for 3-axis comparison (GPSYCHO approach).
// Instead of just counting bands over threshold, track:
// - max_over_db: worst violation in dB (peak distortion)
// - tot_over_db: sum of violations in dB (for average)
// - over_count: number of distorted bands
// This prefers solutions that minimize peak distortion and spread
// remaining noise evenly, rather than concentrating it in one band.
int over_count = 0;
float max_over_db = 0.0f;
float tot_over_db = 0.0f;
for (int sfb = 0; sfb < 21; sfb++) {
if (xmin[sfb] > 0.0f && noise[sfb] > xmin[sfb]) {
over_count++;
float over_db = 10.0f * log10f(noise[sfb] / xmin[sfb]);
tot_over_db += over_db;
if (over_db > max_over_db) {
max_over_db = over_db;
}
}
}
// Track best result: prefer fewer over-threshold bands,
// then fewer total bits as tiebreaker.
if (over_count < best_over || (over_count == best_over && gi.part2_3_length < best_total)) {
best_gi = gi;
best_total = gi.part2_3_length;
best_over = over_count;
// 3-axis quant_compare (inspired by LAME GPSYCHO outer_loop):
// 1. Clean (over=0) always beats dirty (over>0)
// 2. Among clean solutions: prefer fewer bits
// 3. Among dirty solutions: minimize peak, then average, then count
bool is_better = false;
if (over_count == 0 && best_over > 0) {
is_better = true;
} else if (over_count == 0 && best_over == 0) {
is_better = (gi.part2_3_length < best_total);
} else if (over_count > 0 && best_over > 0) {
// both dirty: compare peak distortion first
if (max_over_db < best_max_db - 0.5f) {
// significantly lower peak -> better
is_better = true;
} else if (max_over_db < best_max_db + 0.5f) {
// similar peak: compare average violation
float avg = tot_over_db / (float) over_count;
float best_avg = (best_over > 0) ? best_tot_db / (float) best_over : 0.0f;
if (avg < best_avg - 0.3f) {
is_better = true;
} else if (avg < best_avg + 0.3f) {
// similar average: prefer fewer violated bands
is_better = (over_count < best_over);
}
}
}
if (is_better) {
best_gi = gi;
best_total = gi.part2_3_length;
best_over = over_count;
best_max_db = max_over_db;
best_tot_db = tot_over_db;
memcpy(best_ix, ix, sizeof(best_ix));
}

View file

@ -214,10 +214,11 @@ static int mp3enc_encode_frame(mp3enc_t * enc, const float * pcm) {
int nch = enc->channels;
int padding = mp3enc_get_padding(enc);
// Use MS stereo for stereo input (joint stereo mode)
// mode=1 (joint), mode_ext=2 (MS on, intensity off)
// MS stereo decision is deferred until after MDCT energy analysis.
// mode=1 (joint) allows switching M/S on or off per frame.
// mode=0 (stereo) would force L/R always. mode=3 (mono).
int mode = (nch == 1) ? 3 : 1;
int mode_ext = (nch == 1) ? 0 : 2;
int mode_ext = 0; // set after MDCT energy analysis
// Setup header
mp3enc_header hdr;
@ -257,30 +258,28 @@ static int mp3enc_encode_frame(mp3enc_t * enc, const float * pcm) {
// Total main_data bits: this frame's area + reservoir from previous frame
int total_md_bits = main_data_bits + resv_bytes * 8;
// Mean bits per granule (from total budget)
int mean_bits = total_md_bits / 2;
// Phase 1: compute MDCT + psy for ALL granules before bit allocation.
// We need PE from both granules to weight the bit budget.
float mdct_all[2][2][576]; // [granule][channel][576 MDCT lines]
float saved_xmin[2][2][MP3ENC_PSY_SFB_MAX];
float saved_pe[2] = { 0.0f, 0.0f };
int ix[2][2][576]; // [granule][channel][line]
float mdct_lr[2][576]; // MDCT output per channel before M/S transform
int total_bits_used = 0;
int intra_resv = 0; // intra-frame reservoir: bits saved by granule 0 for granule 1
// Accumulators for MS stereo decision
float energy_mid = 0.0f;
float energy_side = 0.0f;
for (int gr = 0; gr < 2; gr++) {
int pcm_offset = gr * 576;
// Step 1: filterbank + MDCT for all channels
// filterbank + MDCT for all channels
for (int ch = 0; ch < nch; ch++) {
const float * ch_pcm = pcm + ch * 1152 + pcm_offset;
// Run analysis filterbank: 576 PCM samples = 18 calls of 32 samples
float sb_out[32];
float sb_out[32];
for (int slot = 0; slot < 18; slot++) {
enc->filter[ch].process(ch_pcm + slot * 32, sb_out);
for (int sb = 0; sb < 32; sb++) {
enc->sb_cur[ch][sb][slot] = sb_out[sb];
}
// frequency inversion: negate odd subbands at odd time slots
if (slot & 1) {
for (int sb = 1; sb < 32; sb += 2) {
@ -288,43 +287,107 @@ static int mp3enc_encode_frame(mp3enc_t * enc, const float * pcm) {
}
}
}
// MDCT: transform subbands to 576 frequency lines
mp3enc_mdct_granule(enc->sb_prev[ch], enc->sb_cur[ch], mdct_lr[ch]);
// Save current subbands as previous for next granule
mp3enc_mdct_granule(enc->sb_prev[ch], enc->sb_cur[ch], mdct_all[gr][ch]);
memcpy(enc->sb_prev[ch], enc->sb_cur[ch], sizeof(enc->sb_cur[ch]));
}
// Step 2: MS stereo transform
// MS stereo energy analysis (before transform, on L/R data).
// Accumulate mid and side energy across both granules to decide
// whether M/S coding is beneficial for this frame.
// M/S wins when L and R are correlated (side energy is small).
if (nch == 2) {
for (int i = 0; i < enc->lowpass_line && i < 576; i++) {
float l = mdct_all[gr][0][i];
float r = mdct_all[gr][1][i];
float m = l + r;
float s = l - r;
energy_mid += m * m;
energy_side += s * s;
}
}
}
// MS stereo decision: use M/S when channels are correlated enough
// that the side channel is cheap to encode. FhG "almost always uses
// ms_stereo" (GPSYCHO docs). We use it unless side energy dominates,
// which means the channels are very different (rare for music).
bool use_ms = false;
if (nch == 2) {
// Use M/S unless side channel has more energy than mid.
// This is generous -- almost always enables M/S (like FhG).
use_ms = (energy_side < energy_mid * 1.2f) || (energy_mid < 1e-20f);
mode_ext = use_ms ? 2 : 0;
hdr.mode_ext = mode_ext;
}
// Now apply M/S transform + lowpass + psy for both granules
for (int gr = 0; gr < 2; gr++) {
// MS stereo transform
if (use_ms) {
static const float ms_scale = 0.7071067811865476f; // 1/sqrt(2)
for (int i = 0; i < 576; i++) {
float l = mdct_lr[0][i];
float r = mdct_lr[1][i];
mdct_lr[0][i] = (l + r) * ms_scale;
mdct_lr[1][i] = (l - r) * ms_scale;
float l = mdct_all[gr][0][i];
float r = mdct_all[gr][1][i];
mdct_all[gr][0][i] = (l + r) * ms_scale;
mdct_all[gr][1][i] = (l - r) * ms_scale;
}
}
// Step 2b: adaptive lowpass
// adaptive lowpass
for (int ch = 0; ch < nch; ch++) {
for (int i = enc->lowpass_line; i < 576; i++) {
mdct_lr[ch][i] = 0.0f;
mdct_all[gr][ch][i] = 0.0f;
}
}
// Step 3: run psy for all channels before bit allocation
float saved_xmin[2][MP3ENC_PSY_SFB_MAX];
// psy: compute masking thresholds and perceptual entropy
for (int ch = 0; ch < nch; ch++) {
enc->psy.compute(mdct_lr[ch], sfb_long, enc->sr_index, ch);
memcpy(saved_xmin[ch], enc->psy.xmin, sizeof(enc->psy.xmin));
enc->psy.compute(mdct_all[gr][ch], sfb_long, enc->sr_index, ch);
memcpy(saved_xmin[gr][ch], enc->psy.xmin, sizeof(enc->psy.xmin));
saved_pe[gr] += enc->psy.pe;
}
}
// Step 4: bit allocation
int max_bits = mean_bits + intra_resv;
// Phase 2: PE-weighted bit allocation + quantization.
// Give more bits to granules with higher perceptual entropy (more complex
// signal). ISO 11172-3 computes PE but then ignores it for allocation;
// LAME uses it to drive the bit reservoir. We weight the per-granule
// budget proportionally to PE, clamped to avoid starving either granule.
int gr_budget[2];
float pe_sum = saved_pe[0] + saved_pe[1];
if (pe_sum > 1e-6f) {
float frac0 = saved_pe[0] / pe_sum;
gr_budget[0] = (int) ((float) total_md_bits * frac0);
gr_budget[1] = total_md_bits - gr_budget[0];
// clamp: neither granule gets less than 20% or more than 80%
int lo = total_md_bits / 5;
int hi = total_md_bits * 4 / 5;
for (int gr = 0; gr < 2; gr++) {
if (gr_budget[gr] < lo) {
gr_budget[gr] = lo;
}
if (gr_budget[gr] > hi) {
gr_budget[gr] = hi;
}
}
// re-normalize after clamp
int clamped_sum = gr_budget[0] + gr_budget[1];
int bits_to_spread = total_md_bits - clamped_sum;
gr_budget[0] += bits_to_spread / 2;
gr_budget[1] += bits_to_spread - bits_to_spread / 2;
} else {
gr_budget[0] = total_md_bits / 2;
gr_budget[1] = total_md_bits - gr_budget[0];
}
// Don't exceed remaining budget
int ix[2][2][576];
int total_bits_used = 0;
int intra_resv = 0;
for (int gr = 0; gr < 2; gr++) {
int max_bits = gr_budget[gr] + intra_resv;
// don't exceed remaining budget
int remaining_bits = total_md_bits - total_bits_used;
if (max_bits > remaining_bits) {
max_bits = remaining_bits;
@ -335,17 +398,16 @@ static int mp3enc_encode_frame(mp3enc_t * enc, const float * pcm) {
int bits_per_ch = max_bits / nch;
// Step 5: quantize each channel with saved thresholds
// quantize each channel with psy thresholds
int gr_bits_used = 0;
for (int ch = 0; ch < nch; ch++) {
int bits = mp3enc_outer_loop(mdct_lr[ch], ix[gr][ch], si.gr[gr][ch], saved_xmin[ch], bits_per_ch, sfb_long,
enc->sr_index, gr, si.scfsi[ch]);
int bits = mp3enc_outer_loop(mdct_all[gr][ch], ix[gr][ch], si.gr[gr][ch], saved_xmin[gr][ch], bits_per_ch,
sfb_long, enc->sr_index, gr, si.scfsi[ch]);
gr_bits_used += bits;
}
// Track intra-frame savings: bits not used by this granule
// become available for the next granule in this frame.
intra_resv += mean_bits - gr_bits_used;
// intra-frame savings: unused bits carry to next granule
intra_resv += gr_budget[gr] - gr_bits_used;
if (intra_resv < 0) {
intra_resv = 0;
}

View file

@ -418,11 +418,30 @@ std::vector<float> resample_wav(int num_channels,const std::vector<float>& input
std::vector<float> output((size_t)n_out * num_channels);
const int half_len = 32;
const int taps = half_len * 2;
const double beta = 9.0;
const double inv_i0b = 1.0 / audio_resample_bessel_i0(beta);
const double fc = 0.5 * ((ratio < 1.0) ? ratio : 1.0);
// PRECOMPUTE KAISER WINDOW
std::vector<double> window(taps + 1);
for (int k = -half_len; k <= half_len; k++)
{
double t = (double)k / (double)half_len;
double win;
if (t < -1.0 || t > 1.0)
win = 0.0;
else
win = audio_resample_bessel_i0(beta * std::sqrt(1.0 - t * t)) * inv_i0b;
window[k + half_len] = win;
}
for (int ch = 0; ch < num_channels; ch++)
{
const float* src = input.data() + ch * n_in;
@ -432,8 +451,10 @@ std::vector<float> resample_wav(int num_channels,const std::vector<float>& input
{
double center = (double)i / ratio;
int start = (int)std::floor(center) - half_len + 1;
int end = (int)std::floor(center) + half_len;
int base = (int)std::floor(center);
int start = base - half_len + 1;
int end = base + half_len;
double sum = 0.0;
double wgt = 0.0;
@ -443,22 +464,18 @@ std::vector<float> resample_wav(int num_channels,const std::vector<float>& input
double d = center - (double)j;
double sinc_val;
if (std::fabs(d) < 1e-9)
sinc_val = 2.0 * fc;
else
sinc_val = std::sin(2.0 * M_PI * fc * d) / (M_PI * d);
double t = d / (double)half_len;
double win;
if (t < -1.0 || t > 1.0)
win = 0.0;
else
win = audio_resample_bessel_i0(beta * std::sqrt(1.0 - t * t)) * inv_i0b;
double win = window[j - start];
double h = sinc_val * win;
int idx = j;
if (idx < 0) idx = 0;
if (idx >= n_in) idx = n_in - 1;