diff --git a/otherarch/acestep/ace-qwen3.cpp b/otherarch/acestep/ace-qwen3.cpp index 4e0012b4a..bdc9ca4ba 100644 --- a/otherarch/acestep/ace-qwen3.cpp +++ b/otherarch/acestep/ace-qwen3.cpp @@ -865,6 +865,17 @@ static std::vector generate_phase1_batch( std::vector logits_uncond(V * N); std::vector tokens(N); + // CFG: single forward with 2*N (cond + uncond) + int N2 = use_cfg ? 2 * N : N; + std::vector tokens_2n(N2), sets_2n(N2); + std::vector logits_2n((size_t)V * N2); + if (use_cfg) { + for (int i = 0; i < N; i++) { + sets_2n[i] = cond_sets[i]; + sets_2n[N + i] = uncond_sets[i]; + } + } + int n_active = N; for (int i = 0; i < N; i++) if (seqs[i].done) n_active--; @@ -876,9 +887,18 @@ static std::vector generate_phase1_batch( for (int i = 0; i < N; i++) tokens[i] = seqs[i].last_token; - qw3lm_forward_batch(m, tokens.data(), cond_sets.data(), N, logits_cond.data()); - if (use_cfg) - qw3lm_forward_batch(m, tokens.data(), uncond_sets.data(), N, logits_uncond.data()); + if (use_cfg) { + // Single batched forward: cond[0..N-1] + uncond[N..2N-1] + for (int i = 0; i < N; i++) { + tokens_2n[i] = tokens[i]; + tokens_2n[N + i] = tokens[i]; + } + qw3lm_forward_batch(m, tokens_2n.data(), sets_2n.data(), N2, logits_2n.data()); + memcpy(logits_cond.data(), logits_2n.data(), (size_t)V * N * sizeof(float)); + memcpy(logits_uncond.data(), logits_2n.data() + (size_t)V * N, (size_t)V * N * sizeof(float)); + } else { + qw3lm_forward_batch(m, tokens.data(), cond_sets.data(), N, logits_cond.data()); + } for (int i = 0; i < N; i++) { if (seqs[i].done) continue; @@ -1086,6 +1106,17 @@ static std::vector run_phase2_batch( std::vector logits_uncond(V * N); std::vector tokens(N); + // CFG: single forward with 2*N (cond + uncond) + int N2 = use_cfg ? 2 * N : N; + std::vector tokens_2n(N2), sets_2n(N2); + std::vector logits_2n((size_t)V * N2); + if (use_cfg) { + for (int i = 0; i < N; i++) { + sets_2n[i] = cond_sets[i]; + sets_2n[N + i] = uncond_sets[i]; + } + } + int n_active = N; for (int i = 0; i < N; i++) if (seqs[i].done) n_active--; @@ -1095,12 +1126,18 @@ static std::vector run_phase2_batch( for (int i = 0; i < N; i++) tokens[i] = seqs[i].last_token; - // Batched forward: cond - qw3lm_forward_batch(m, tokens.data(), cond_sets.data(), N, logits_cond.data()); - - // Batched forward: uncond - if (use_cfg) - qw3lm_forward_batch(m, tokens.data(), uncond_sets.data(), N, logits_uncond.data()); + if (use_cfg) { + // Single batched forward: cond[0..N-1] + uncond[N..2N-1] + for (int i = 0; i < N; i++) { + tokens_2n[i] = tokens[i]; + tokens_2n[N + i] = tokens[i]; + } + qw3lm_forward_batch(m, tokens_2n.data(), sets_2n.data(), N2, logits_2n.data()); + memcpy(logits_cond.data(), logits_2n.data(), (size_t)V * N * sizeof(float)); + memcpy(logits_uncond.data(), logits_2n.data() + (size_t)V * N, (size_t)V * N * sizeof(float)); + } else { + qw3lm_forward_batch(m, tokens.data(), cond_sets.data(), N, logits_cond.data()); + } // Per-sequence: CFG combine + sample for (int i = 0; i < N; i++) {