Re-enable CFG

This commit is contained in:
Concedo 2026-02-26 14:51:15 +08:00
parent b7d2fe68e7
commit fb3f7d92bc

View file

@ -865,6 +865,17 @@ static std::vector<std::string> generate_phase1_batch(
std::vector<float> logits_uncond(V * N);
std::vector<int> tokens(N);
// CFG: single forward with 2*N (cond + uncond)
int N2 = use_cfg ? 2 * N : N;
std::vector<int> tokens_2n(N2), sets_2n(N2);
std::vector<float> logits_2n((size_t)V * N2);
if (use_cfg) {
for (int i = 0; i < N; i++) {
sets_2n[i] = cond_sets[i];
sets_2n[N + i] = uncond_sets[i];
}
}
int n_active = N;
for (int i = 0; i < N; i++)
if (seqs[i].done) n_active--;
@ -876,9 +887,18 @@ static std::vector<std::string> generate_phase1_batch(
for (int i = 0; i < N; i++)
tokens[i] = seqs[i].last_token;
qw3lm_forward_batch(m, tokens.data(), cond_sets.data(), N, logits_cond.data());
if (use_cfg)
qw3lm_forward_batch(m, tokens.data(), uncond_sets.data(), N, logits_uncond.data());
if (use_cfg) {
// Single batched forward: cond[0..N-1] + uncond[N..2N-1]
for (int i = 0; i < N; i++) {
tokens_2n[i] = tokens[i];
tokens_2n[N + i] = tokens[i];
}
qw3lm_forward_batch(m, tokens_2n.data(), sets_2n.data(), N2, logits_2n.data());
memcpy(logits_cond.data(), logits_2n.data(), (size_t)V * N * sizeof(float));
memcpy(logits_uncond.data(), logits_2n.data() + (size_t)V * N, (size_t)V * N * sizeof(float));
} else {
qw3lm_forward_batch(m, tokens.data(), cond_sets.data(), N, logits_cond.data());
}
for (int i = 0; i < N; i++) {
if (seqs[i].done) continue;
@ -1086,6 +1106,17 @@ static std::vector<std::string> run_phase2_batch(
std::vector<float> logits_uncond(V * N);
std::vector<int> tokens(N);
// CFG: single forward with 2*N (cond + uncond)
int N2 = use_cfg ? 2 * N : N;
std::vector<int> tokens_2n(N2), sets_2n(N2);
std::vector<float> logits_2n((size_t)V * N2);
if (use_cfg) {
for (int i = 0; i < N; i++) {
sets_2n[i] = cond_sets[i];
sets_2n[N + i] = uncond_sets[i];
}
}
int n_active = N;
for (int i = 0; i < N; i++)
if (seqs[i].done) n_active--;
@ -1095,12 +1126,18 @@ static std::vector<std::string> run_phase2_batch(
for (int i = 0; i < N; i++)
tokens[i] = seqs[i].last_token;
// Batched forward: cond
qw3lm_forward_batch(m, tokens.data(), cond_sets.data(), N, logits_cond.data());
// Batched forward: uncond
if (use_cfg)
qw3lm_forward_batch(m, tokens.data(), uncond_sets.data(), N, logits_uncond.data());
if (use_cfg) {
// Single batched forward: cond[0..N-1] + uncond[N..2N-1]
for (int i = 0; i < N; i++) {
tokens_2n[i] = tokens[i];
tokens_2n[N + i] = tokens[i];
}
qw3lm_forward_batch(m, tokens_2n.data(), sets_2n.data(), N2, logits_2n.data());
memcpy(logits_cond.data(), logits_2n.data(), (size_t)V * N * sizeof(float));
memcpy(logits_uncond.data(), logits_2n.data() + (size_t)V * N, (size_t)V * N * sizeof(float));
} else {
qw3lm_forward_batch(m, tokens.data(), cond_sets.data(), N, logits_cond.data());
}
// Per-sequence: CFG combine + sample
for (int i = 0; i < N; i++) {