mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-07 00:41:50 +00:00
reenable cfg
This commit is contained in:
parent
b7d2fe68e7
commit
fb3f7d92bc
1 changed file with 46 additions and 9 deletions
|
|
@@ -865,6 +865,17 @@ static std::vector<std::string> generate_phase1_batch(
|
|||
std::vector<float> logits_uncond(V * N);
|
||||
std::vector<int> tokens(N);
|
||||
|
||||
// CFG: single forward with 2*N (cond + uncond)
|
||||
int N2 = use_cfg ? 2 * N : N;
|
||||
std::vector<int> tokens_2n(N2), sets_2n(N2);
|
||||
std::vector<float> logits_2n((size_t)V * N2);
|
||||
if (use_cfg) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
sets_2n[i] = cond_sets[i];
|
||||
sets_2n[N + i] = uncond_sets[i];
|
||||
}
|
||||
}
|
||||
|
||||
int n_active = N;
|
||||
for (int i = 0; i < N; i++)
|
||||
if (seqs[i].done) n_active--;
|
||||
|
|
@@ -876,9 +887,18 @@ static std::vector<std::string> generate_phase1_batch(
|
|||
for (int i = 0; i < N; i++)
|
||||
tokens[i] = seqs[i].last_token;
|
||||
|
||||
qw3lm_forward_batch(m, tokens.data(), cond_sets.data(), N, logits_cond.data());
|
||||
if (use_cfg)
|
||||
qw3lm_forward_batch(m, tokens.data(), uncond_sets.data(), N, logits_uncond.data());
|
||||
if (use_cfg) {
|
||||
// Single batched forward: cond[0..N-1] + uncond[N..2N-1]
|
||||
for (int i = 0; i < N; i++) {
|
||||
tokens_2n[i] = tokens[i];
|
||||
tokens_2n[N + i] = tokens[i];
|
||||
}
|
||||
qw3lm_forward_batch(m, tokens_2n.data(), sets_2n.data(), N2, logits_2n.data());
|
||||
memcpy(logits_cond.data(), logits_2n.data(), (size_t)V * N * sizeof(float));
|
||||
memcpy(logits_uncond.data(), logits_2n.data() + (size_t)V * N, (size_t)V * N * sizeof(float));
|
||||
} else {
|
||||
qw3lm_forward_batch(m, tokens.data(), cond_sets.data(), N, logits_cond.data());
|
||||
}
|
||||
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (seqs[i].done) continue;
|
||||
|
|
@@ -1086,6 +1106,17 @@ static std::vector<std::string> run_phase2_batch(
|
|||
std::vector<float> logits_uncond(V * N);
|
||||
std::vector<int> tokens(N);
|
||||
|
||||
// CFG: single forward with 2*N (cond + uncond)
|
||||
int N2 = use_cfg ? 2 * N : N;
|
||||
std::vector<int> tokens_2n(N2), sets_2n(N2);
|
||||
std::vector<float> logits_2n((size_t)V * N2);
|
||||
if (use_cfg) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
sets_2n[i] = cond_sets[i];
|
||||
sets_2n[N + i] = uncond_sets[i];
|
||||
}
|
||||
}
|
||||
|
||||
int n_active = N;
|
||||
for (int i = 0; i < N; i++)
|
||||
if (seqs[i].done) n_active--;
|
||||
|
|
@@ -1095,12 +1126,18 @@ static std::vector<std::string> run_phase2_batch(
|
|||
for (int i = 0; i < N; i++)
|
||||
tokens[i] = seqs[i].last_token;
|
||||
|
||||
// Batched forward: cond
|
||||
qw3lm_forward_batch(m, tokens.data(), cond_sets.data(), N, logits_cond.data());
|
||||
|
||||
// Batched forward: uncond
|
||||
if (use_cfg)
|
||||
qw3lm_forward_batch(m, tokens.data(), uncond_sets.data(), N, logits_uncond.data());
|
||||
if (use_cfg) {
|
||||
// Single batched forward: cond[0..N-1] + uncond[N..2N-1]
|
||||
for (int i = 0; i < N; i++) {
|
||||
tokens_2n[i] = tokens[i];
|
||||
tokens_2n[N + i] = tokens[i];
|
||||
}
|
||||
qw3lm_forward_batch(m, tokens_2n.data(), sets_2n.data(), N2, logits_2n.data());
|
||||
memcpy(logits_cond.data(), logits_2n.data(), (size_t)V * N * sizeof(float));
|
||||
memcpy(logits_uncond.data(), logits_2n.data() + (size_t)V * N, (size_t)V * N * sizeof(float));
|
||||
} else {
|
||||
qw3lm_forward_batch(m, tokens.data(), cond_sets.data(), N, logits_cond.data());
|
||||
}
|
||||
|
||||
// Per-sequence: CFG combine + sample
|
||||
for (int i = 0; i < N; i++) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue