mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-13 15:39:11 +00:00
1536 lines
55 KiB
C++
1536 lines
55 KiB
C++
// ace-qwen3.cpp : ACE-Step 5Hz LM inference (GGML)
|
|
// Qwen3 causal LM: CoT reasoning + audio code generation
|
|
// ace-qwen3: Qwen3 causal LM for ACE-Step music generation (GGML backend)
|
|
#include "./qwen3-lm.h"
|
|
#include "./bpe.h"
|
|
#include "./request.h"
|
|
|
|
#include <cstdio>
|
|
#include <cstdlib>
|
|
#include <cstring>
|
|
#include <cmath>
|
|
#include <string>
|
|
#include <vector>
|
|
#include <algorithm>
|
|
#include <random>
|
|
#include <chrono>
|
|
#include <map>
|
|
#include <unordered_map>
|
|
#include <iomanip>
|
|
#include <sstream>
|
|
|
|
// Timer
|
|
// Stopwatch over std::chrono::steady_clock. Starts counting at construction.
struct Timer {
    std::chrono::steady_clock::time_point t;  // start of the current measurement

    Timer() : t(std::chrono::steady_clock::now()) {}

    // Milliseconds elapsed since construction or the most recent reset().
    double ms() const {
        const auto elapsed = std::chrono::steady_clock::now() - t;
        return std::chrono::duration<double, std::milli>(elapsed).count();
    }

    // Restart the measurement from the current instant.
    void reset() { t = std::chrono::steady_clock::now(); }
};
|
|
|
|
// Special token IDs (Qwen3 extended vocab)
|
|
// Opens a chat turn; every prompt builder in this file pushes it before a role header.
#define TOKEN_IM_START 151644
// Closes a chat turn; sampling it marks a sequence as finished during generation.
#define TOKEN_IM_END 151645
// Rendered as "<think>" by bpe_decode.
#define TOKEN_THINK 151667
// Rendered as "</think>" by bpe_decode.
#define TOKEN_THINK_END 151668
// First audio-code token id; bpe_decode skips every id at or above this value.
#define AUDIO_CODE_BASE 151669
// Width of the audio-code id range [AUDIO_CODE_BASE, AUDIO_CODE_BASE + AUDIO_CODE_COUNT).
#define AUDIO_CODE_COUNT 65535
|
|
|
|
//
|
|
// Sampling
|
|
//
|
|
|
|
// A vocabulary index paired with a score (logit or probability).
// Member order matters: the struct is brace-initialised as {id, prob}.
struct TokenProb {
    int   id;    // vocabulary index
    float prob;  // score attached to that index
};
|
|
|
|
// Sampling: temperature -> top_k -> top_p -> softmax -> multinomial
|
|
// Matches nano-vLLM Sampler: div_(temperature) -> apply_top_k_top_p -> softmax -> sample
|
|
// Draw one token id from `logits` (length V): temperature -> top_k -> top_p -> softmax -> multinomial.
//
//   logits       scratch buffer, overwritten in place (scaled, masked, then exponentiated)
//   V            number of entries in `logits`
//   temperature  <= 0 selects the arg-max (greedy) and leaves `logits` untouched
//   top_p        nucleus mass; values outside (0, 1) disable the filter
//   top_k        keep the k best logits; values outside (0, V) disable the filter
//   rng          random source for the final draw
//
// Returns the sampled index, or 0 when V <= 0.
static int sample_top_k_p(float * logits, int V, float temperature, float top_p, int top_k, std::mt19937 & rng) {
    if (V <= 0) return 0;

    if (temperature <= 0.0f) {
        // greedy
        return (int) (std::max_element(logits, logits + V) - logits);
    }

    // 1. temperature
    const float inv_temp = 1.0f / temperature;
    for (int i = 0; i < V; i++) logits[i] *= inv_temp;

    // 2. top_k: everything strictly below the k-th largest logit becomes -inf
    //    (entries tied with the threshold all survive).
    if (top_k > 0 && top_k < V) {
        std::vector<float> tmp(logits, logits + V);
        std::nth_element(tmp.begin(), tmp.begin() + (top_k - 1), tmp.end(),
                         [](float a, float b) { return a > b; });
        const float threshold = tmp[top_k - 1];
        for (int i = 0; i < V; i++)
            if (logits[i] < threshold) logits[i] = -INFINITY;
    }

    // 3. top_p: walk tokens by descending logit and mask each one whose preceding
    //    cumulative probability already reached top_p. The best token is always kept.
    if (top_p > 0.0f && top_p < 1.0f) {
        std::vector<int> order(V);
        for (int i = 0; i < V; i++) order[i] = i;
        // Ties are broken by index so the kept set does not depend on the sort implementation.
        std::sort(order.begin(), order.end(), [logits](int a, int b) {
            if (logits[a] != logits[b]) return logits[a] > logits[b];
            return a < b;
        });

        const float top = logits[order[0]];
        std::vector<float> probs(V);
        float total = 0.0f;
        for (int i = 0; i < V; i++) {
            probs[i] = expf(logits[order[i]] - top);
            total += probs[i];
        }
        const float inv = 1.0f / total;

        float cum = 0.0f;
        for (int i = 0; i < V; i++) {
            if (i > 0 && cum >= top_p) logits[order[i]] = -INFINITY;
            cum += probs[i] * inv;
        }
    }

    // 4. softmax numerators (temperature already applied), then a multinomial draw.
    float max_val = -INFINITY;
    for (int i = 0; i < V; i++)
        if (logits[i] > max_val) max_val = logits[i];

    float sum = 0.0f;
    for (int i = 0; i < V; i++) {
        logits[i] = expf(logits[i] - max_val);
        sum += logits[i];
    }

    std::uniform_real_distribution<float> dist(0.0f, sum);
    const float r = dist(rng);

    // r lies in [0, sum). Zero-weight (masked) entries are skipped and the comparison is
    // strict, so a draw of exactly 0 cannot select a masked token. If rounding leaves the
    // running total short of r, fall back to the last token that had weight.
    float acc = 0.0f;
    int last_weighted = 0;
    for (int i = 0; i < V; i++) {
        if (!(logits[i] > 0.0f)) continue;
        acc += logits[i];
        last_weighted = i;
        if (acc > r) return i;
    }
    return last_weighted;
}
|
|
|
|
//
|
|
// BPE decode (token IDs -> text)
|
|
//
|
|
|
|
static std::string bpe_decode(const BPETokenizer & bpe, const std::vector<int> & ids) {
|
|
static std::unordered_map<int, uint8_t> byte_dec;
|
|
static bool init = false;
|
|
if (!init) {
|
|
for (int b = 0; b < 256; b++) {
|
|
int adv;
|
|
int cp = utf8_codepoint(bpe.byte2str[b].c_str(), &adv);
|
|
byte_dec[cp] = (uint8_t)b;
|
|
}
|
|
init = true;
|
|
}
|
|
|
|
std::string result;
|
|
for (int id : ids) {
|
|
if (id == TOKEN_THINK) { result += "<think>"; continue; }
|
|
if (id == TOKEN_THINK_END) { result += "</think>"; continue; }
|
|
if (id == TOKEN_IM_START || id == TOKEN_IM_END) continue;
|
|
if (id >= AUDIO_CODE_BASE) continue;
|
|
if (id < 0 || id >= (int)bpe.id_to_str.size()) continue;
|
|
const std::string & s = bpe.id_to_str[id];
|
|
if (s.empty()) continue;
|
|
const char * p = s.c_str();
|
|
while (*p) {
|
|
int adv;
|
|
int cp = utf8_codepoint(p, &adv);
|
|
auto it = byte_dec.find(cp);
|
|
if (it != byte_dec.end()) result += (char)it->second;
|
|
p += adv;
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
//
|
|
// ACE-Step prompt
|
|
//
|
|
|
|
// Conditions for one generation request. Empty strings and non-positive numbers
// mean "not provided"; the parsing and YAML code in this file tests them that way.
struct AcePrompt {
    std::string caption;
    std::string lyrics;
    float       duration = 0.0f;  // written as a whole number by build_cot_yaml; <= 0 means unset
    int         bpm = 0;          // <= 0 means unset
    std::string keyscale;
    std::string timesignature;
    std::string vocal_language;
};
|
|
|
|
// Random source handed to the sampler by generate_phase1_batch.
static std::mt19937 acestep_lm_rng;
// When true, generate_phase1_batch prints the decoded prompt and every sampled token.
static bool acestep_lm_dbg = false;
// Queue of token ids that override the sampler, consumed one per decode step.
static std::vector<int32_t> forced_tokens;
// Caption token ids; when non-empty they are queued into forced_tokens once the FSM
// has emitted the "caption:" field name.
static std::vector<int> caption_tokens = std::vector<int>();
// Maximum number of recent tokens kept for the repetition penalty.
const int rep_pen_range = 48;
|
|
|
|
//
|
|
// CoT parsing (extract metadata + lyrics from LLM Phase1 output)
|
|
//
|
|
|
|
static bool parse_cot_and_lyrics(const std::string & text, AcePrompt * out) {
|
|
// Extract CoT content between <think>...</think>
|
|
size_t ts = text.find("<think>");
|
|
size_t te = text.find("</think>");
|
|
|
|
std::string cot;
|
|
std::string lyrics_after;
|
|
|
|
if (ts != std::string::npos && te != std::string::npos) {
|
|
cot = text.substr(ts + 7, te - ts - 7);
|
|
lyrics_after = text.substr(te + 8);
|
|
} else if (te != std::string::npos) {
|
|
cot = text.substr(0, te);
|
|
lyrics_after = text.substr(te + 8);
|
|
} else {
|
|
cot = text;
|
|
}
|
|
|
|
// Parse YAML-like fields from CoT
|
|
auto get_field = [&](const std::string & key) -> std::string {
|
|
std::string needle = key + ":";
|
|
size_t p = cot.find(needle);
|
|
if (p == std::string::npos) return "";
|
|
p += needle.size();
|
|
while (p < cot.size() && (cot[p] == ' ' || cot[p] == '\'')) p++;
|
|
size_t end = cot.find('\n', p);
|
|
if (end == std::string::npos) end = cot.size();
|
|
std::string val = cot.substr(p, end - p);
|
|
// Strip trailing whitespace and quotes
|
|
while (!val.empty() && (val.back() == ' ' || val.back() == '\'' || val.back() == '\r'))
|
|
val.pop_back();
|
|
return val;
|
|
};
|
|
|
|
std::string bpm_s = get_field("bpm");
|
|
if (!bpm_s.empty()) out->bpm = atoi(bpm_s.c_str());
|
|
|
|
std::string dur_s = get_field("duration");
|
|
if (!dur_s.empty()) out->duration = (float)atof(dur_s.c_str());
|
|
|
|
std::string ks = get_field("keyscale");
|
|
if (!ks.empty()) out->keyscale = ks;
|
|
|
|
std::string ts_s = get_field("timesignature");
|
|
if (!ts_s.empty()) out->timesignature = ts_s;
|
|
|
|
std::string lang = get_field("language");
|
|
if (!lang.empty()) out->vocal_language = lang;
|
|
|
|
std::string cap = get_field("caption");
|
|
if (!cap.empty()) {
|
|
// Caption may span multiple lines (yaml word-wrap)
|
|
size_t cp = cot.find("caption:");
|
|
if (cp != std::string::npos) {
|
|
cp += 8;
|
|
size_t end = cot.find("\nduration:", cp);
|
|
if (end == std::string::npos) end = cot.find("\nkeyscale:", cp);
|
|
if (end == std::string::npos) end = cot.size();
|
|
std::string full_cap = cot.substr(cp, end - cp);
|
|
// Trim and collapse whitespace
|
|
std::string cleaned;
|
|
bool in_space = true;
|
|
for (char ch : full_cap) {
|
|
if (ch == '\n' || ch == '\r') ch = ' ';
|
|
if (ch == ' ') {
|
|
if (!in_space) cleaned += ' ';
|
|
in_space = true;
|
|
} else {
|
|
cleaned += ch;
|
|
in_space = false;
|
|
}
|
|
}
|
|
while (!cleaned.empty() && cleaned.back() == ' ') cleaned.pop_back();
|
|
while (!cleaned.empty() && cleaned.front() == ' ') cleaned.erase(cleaned.begin());
|
|
if (!cleaned.empty()) out->caption = cleaned;
|
|
}
|
|
}
|
|
|
|
// Lyrics after </think>
|
|
if (!lyrics_after.empty()) {
|
|
// Trim leading whitespace
|
|
size_t s = lyrics_after.find_first_not_of(" \t\n\r");
|
|
if (s != std::string::npos)
|
|
lyrics_after = lyrics_after.substr(s);
|
|
// Trim trailing whitespace
|
|
while (!lyrics_after.empty() &&
|
|
(lyrics_after.back() == ' ' || lyrics_after.back() == '\n' || lyrics_after.back() == '\r'))
|
|
lyrics_after.pop_back();
|
|
if (!lyrics_after.empty())
|
|
out->lyrics = lyrics_after;
|
|
}
|
|
|
|
return (out->bpm > 0 || out->duration > 0);
|
|
}
|
|
|
|
//
|
|
// Prompt building (Qwen3 chat template)
|
|
//
|
|
|
|
// Conditional prompt in the Qwen3 chat layout:
//   <|im_start|>system ... <|im_end|>\n<|im_start|>user (caption + lyrics) <|im_end|>\n<|im_start|>assistant\n
// Returns the token ids; the assistant turn is left open for generation.
static std::vector<int> build_lm_prompt(BPETokenizer & bpe, const AcePrompt & prompt) {
    std::vector<int> ids;

    // Tokenise `text` and append the result.
    auto push_text = [&ids, &bpe](const std::string & text) {
        const auto piece = bpe_encode(&bpe, text, false);
        ids.insert(ids.end(), piece.begin(), piece.end());
    };
    // Append "<|im_start|>" + body + "<|im_end|>" + "\n".
    auto push_turn = [&](const std::string & body) {
        ids.push_back(TOKEN_IM_START);
        push_text(body);
        ids.push_back(TOKEN_IM_END);
        push_text("\n");
    };

    push_turn("system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n");
    push_turn("user\n# Caption\n" + prompt.caption + "\n\n# Lyric\n" + prompt.lyrics + "\n");

    ids.push_back(TOKEN_IM_START);
    push_text("assistant\n");
    return ids;
}
|
|
|
|
// Unconditional counterpart of build_lm_prompt, used as the CFG reference.
// With a usable negative prompt (non-null, non-empty, not "NO USER INPUT") the caption
// section carries that text; otherwise the caption section is omitted and only the
// lyrics are given. The assistant turn is left open.
static std::vector<int> build_lm_prompt_uncond(BPETokenizer & bpe, const AcePrompt & prompt,
                                               const char * negative_prompt) {
    std::vector<int> ids;

    // Tokenise `text` and append the result.
    auto push_text = [&ids, &bpe](const std::string & text) {
        const auto piece = bpe_encode(&bpe, text, false);
        ids.insert(ids.end(), piece.begin(), piece.end());
    };
    // Append "<|im_start|>" + body + "<|im_end|>" + "\n".
    auto push_turn = [&](const std::string & body) {
        ids.push_back(TOKEN_IM_START);
        push_text(body);
        ids.push_back(TOKEN_IM_END);
        push_text("\n");
    };

    push_turn("system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n");

    const bool has_neg = negative_prompt && strlen(negative_prompt) > 0
                      && strcmp(negative_prompt, "NO USER INPUT") != 0;
    if (has_neg) {
        push_turn("user\n# Caption\n" + std::string(negative_prompt) + "\n\n# Lyric\n" + prompt.lyrics + "\n");
    } else {
        push_turn("user\n# Lyric\n" + prompt.lyrics + "\n");
    }

    ids.push_back(TOKEN_IM_START);
    push_text("assistant\n");
    return ids;
}
|
|
|
|
// Build CoT YAML content (matching Python yaml.dump sort_keys=True)
|
|
static std::string build_cot_yaml(const AcePrompt & prompt) {
|
|
auto yaml_wrap = [](const std::string & key, const std::string & val) -> std::string {
|
|
std::string result = key + ":";
|
|
int col = (int)(key.size() + 1);
|
|
size_t i = 0;
|
|
while (i < val.size()) {
|
|
size_t end = val.find(' ', i);
|
|
if (end == std::string::npos) end = val.size();
|
|
std::string word = val.substr(i, end - i);
|
|
if (col > 80) {
|
|
result += "\n ";
|
|
col = 2;
|
|
} else {
|
|
result += " ";
|
|
col += 1;
|
|
}
|
|
result += word;
|
|
col += (int)word.size();
|
|
i = (end < val.size()) ? end + 1 : val.size();
|
|
}
|
|
result += "\n";
|
|
return result;
|
|
};
|
|
|
|
std::string yaml;
|
|
if (prompt.bpm > 0)
|
|
yaml += "bpm: " + std::to_string(prompt.bpm) + "\n";
|
|
// if (!prompt.caption.empty())
|
|
// yaml += yaml_wrap("caption", prompt.caption);
|
|
if (prompt.duration > 0)
|
|
yaml += "duration: " + std::to_string((int)prompt.duration) + "\n";
|
|
if (!prompt.keyscale.empty())
|
|
yaml += "keyscale: " + prompt.keyscale + "\n";
|
|
if (!prompt.vocal_language.empty())
|
|
yaml += "language: " + prompt.vocal_language + "\n";
|
|
if (!prompt.timesignature.empty())
|
|
yaml += "timesignature: " + prompt.timesignature + "\n";
|
|
return yaml;
|
|
}
|
|
|
|
// Prompt with injected CoT (Phase 2: all metas known)
|
|
// Phase 2 conditional prompt: the same system and user turns as build_lm_prompt, followed
// by a complete assistant turn whose reasoning block contains `cot_yaml`:
//   <|im_start|>assistant\n<think>\n{cot_yaml}</think>\n\n<|im_end|>\n
static std::vector<int> build_lm_prompt_with_cot(BPETokenizer & bpe, const AcePrompt & prompt,
                                                 const std::string & cot_yaml) {
    std::vector<int> ids;

    // Tokenise `text` and append the result.
    auto push_text = [&ids, &bpe](const std::string & text) {
        const auto piece = bpe_encode(&bpe, text, false);
        ids.insert(ids.end(), piece.begin(), piece.end());
    };
    // Append "<|im_start|>" + body + "<|im_end|>" + "\n".
    auto push_turn = [&](const std::string & body) {
        ids.push_back(TOKEN_IM_START);
        push_text(body);
        ids.push_back(TOKEN_IM_END);
        push_text("\n");
    };

    push_turn("system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n");
    push_turn("user\n# Caption\n" + prompt.caption + "\n\n# Lyric\n" + prompt.lyrics + "\n");

    ids.push_back(TOKEN_IM_START);
    push_text("assistant\n");
    ids.push_back(TOKEN_THINK);
    push_text("\n" + cot_yaml);
    ids.push_back(TOKEN_THINK_END);
    push_text("\n\n");
    ids.push_back(TOKEN_IM_END);
    push_text("\n");
    return ids;
}
|
|
|
|
// Unconditional prompt with empty CoT for CFG (Phase 2)
|
|
// Phase 2 unconditional prompt for CFG: same layout as build_lm_prompt_with_cot but with
// an empty reasoning block ("<think>\n\n</think>"). The caption is the negative prompt when
// one is usable (non-null, non-empty, not "NO USER INPUT"), otherwise prompt.caption.
static std::vector<int> build_lm_prompt_uncond_with_cot(BPETokenizer & bpe, const AcePrompt & prompt,
                                                        const char * negative_prompt) {
    std::vector<int> ids;

    // Tokenise `text` and append the result.
    auto push_text = [&ids, &bpe](const std::string & text) {
        const auto piece = bpe_encode(&bpe, text, false);
        ids.insert(ids.end(), piece.begin(), piece.end());
    };
    // Append "<|im_start|>" + body + "<|im_end|>" + "\n".
    auto push_turn = [&](const std::string & body) {
        ids.push_back(TOKEN_IM_START);
        push_text(body);
        ids.push_back(TOKEN_IM_END);
        push_text("\n");
    };

    push_turn("system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n");

    const bool has_neg = negative_prompt && strlen(negative_prompt) > 0
                      && strcmp(negative_prompt, "NO USER INPUT") != 0;
    const std::string cap = has_neg ? std::string(negative_prompt) : prompt.caption;
    push_turn("user\n# Caption\n" + cap + "\n\n# Lyric\n" + prompt.lyrics + "\n");

    ids.push_back(TOKEN_IM_START);
    push_text("assistant\n");
    ids.push_back(TOKEN_THINK);
    push_text("\n\n");
    ids.push_back(TOKEN_THINK_END);
    push_text("\n\n");
    ids.push_back(TOKEN_IM_END);
    push_text("\n");
    return ids;
}
|
|
|
|
// Build Qwen3 chat prompt: <|im_start|>system\n...<|im_end|>\n<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n
|
|
// Generic chat prompt with caller-supplied system and user text; the assistant turn is
// left open. A null `sys` or `user` is treated as an empty string.
static std::vector<int> build_custom_prompt(BPETokenizer & bpe, const char * sys, const char * user) {
    std::vector<int> ids;

    // Tokenise `text` and append the result.
    auto push_text = [&ids, &bpe](const std::string & text) {
        const auto piece = bpe_encode(&bpe, text, false);
        ids.insert(ids.end(), piece.begin(), piece.end());
    };
    // Append "<|im_start|>" + body + "<|im_end|>" + "\n".
    auto push_turn = [&](const std::string & body) {
        ids.push_back(TOKEN_IM_START);
        push_text(body);
        ids.push_back(TOKEN_IM_END);
        push_text("\n");
    };

    // Constructing std::string from a null pointer is undefined, so substitute "".
    const std::string sys_text  = sys  ? sys  : "";
    const std::string user_text = user ? user : "";

    push_turn("system\n" + sys_text + "\n");
    push_turn("user\n" + user_text + "\n");

    ids.push_back(TOKEN_IM_START);
    push_text("assistant\n");
    return ids;
}
|
|
|
|
//
|
|
// Prefix tree for FSM constrained decoding
|
|
//
|
|
|
|
// Set of token sequences indexed by prefix, used to constrain decoding to known values.
struct PrefixTree {
    // prefix (token sequence) -> distinct tokens that may follow it, in insertion order
    std::map<std::vector<int>, std::vector<int>> nodes;

    // Register `seq`: each proper prefix (starting with the empty one) gains the token that
    // follows it. Tokens already recorded for a prefix are not duplicated.
    void add(const std::vector<int> & seq) {
        std::vector<int> prefix;  // grown one token at a time instead of rebuilt per step
        prefix.reserve(seq.size());
        for (const int next : seq) {
            std::vector<int> & followers = nodes[prefix];
            if (std::find(followers.begin(), followers.end(), next) == followers.end())
                followers.push_back(next);
            prefix.push_back(next);
        }
    }

    // Tokens allowed after `prefix`, or nullptr when the prefix was never registered
    // (which includes complete sequences, since they have no follower entry).
    const std::vector<int> * get(const std::vector<int> & prefix) const {
        const auto it = nodes.find(prefix);
        return it != nodes.end() ? &it->second : nullptr;
    }
};
|
|
|
|
//
|
|
// Metadata FSM (constrained decoding for CoT fields)
|
|
//
|
|
|
|
struct MetadataFSM {
|
|
enum State {
|
|
BPM_NAME, BPM_VALUE,
|
|
CAPTION_NAME, CAPTION_VALUE,
|
|
DURATION_NAME, DURATION_VALUE,
|
|
KEYSCALE_NAME, KEYSCALE_VALUE,
|
|
LANGUAGE_NAME, LANGUAGE_VALUE,
|
|
TIMESIG_NAME, TIMESIG_VALUE,
|
|
THINK_END,
|
|
CODES,
|
|
DISABLED
|
|
};
|
|
|
|
State state = DISABLED;
|
|
int caption_newline_counter = 0;
|
|
int name_pos = 0;
|
|
std::vector<int> value_acc;
|
|
bool enabled = false;
|
|
bool rewrite_caption = false;
|
|
|
|
std::vector<int> bpm_name, caption_name, duration_name;
|
|
std::vector<int> keyscale_name, language_name, timesig_name;
|
|
PrefixTree bpm_tree, duration_tree, keyscale_tree, language_tree, timesig_tree;
|
|
int newline_tok = -1;
|
|
int think_end_tok = TOKEN_THINK_END;
|
|
int vocab_size = 0;
|
|
|
|
static std::vector<int> tokenize_strip(BPETokenizer & bpe,
|
|
const std::string & full,
|
|
const std::string & prefix) {
|
|
std::vector<int> full_tok = bpe_encode(&bpe, full, false);
|
|
std::vector<int> pre_tok = bpe_encode(&bpe, prefix, false);
|
|
if (full_tok.size() >= pre_tok.size() &&
|
|
std::equal(pre_tok.begin(), pre_tok.end(), full_tok.begin()))
|
|
return std::vector<int>(full_tok.begin() + pre_tok.size(), full_tok.end());
|
|
return full_tok;
|
|
}
|
|
|
|
void build_value_tree(BPETokenizer & bpe, PrefixTree & tree,
|
|
const std::string & field_prefix,
|
|
const std::vector<std::string> & values) {
|
|
for (auto & val : values) {
|
|
std::string full = field_prefix + val + "\n";
|
|
std::vector<int> vtok = tokenize_strip(bpe, full, field_prefix);
|
|
tree.add(vtok);
|
|
}
|
|
}
|
|
|
|
void init(BPETokenizer & bpe, int vsize, bool rew_caption) {
|
|
vocab_size = vsize;
|
|
auto nl = bpe_encode(&bpe, "\n", false);
|
|
newline_tok = nl.empty() ? -1 : nl[0];
|
|
|
|
bpm_name = bpe_encode(&bpe, "bpm:", false);
|
|
caption_name = bpe_encode(&bpe, "caption:", false);
|
|
duration_name = bpe_encode(&bpe, "duration:", false);
|
|
keyscale_name = bpe_encode(&bpe, "keyscale:", false);
|
|
language_name = bpe_encode(&bpe, "language:", false);
|
|
timesig_name = bpe_encode(&bpe, "timesignature:", false);
|
|
|
|
// BPM 30-195
|
|
{
|
|
std::vector<std::string> vals;
|
|
for (int v = 30; v <= 195; v++) vals.push_back(std::to_string(v));
|
|
build_value_tree(bpe, bpm_tree, "bpm:", vals);
|
|
}
|
|
// Duration 40-450
|
|
{
|
|
std::vector<std::string> vals;
|
|
for (int v = 40; v <= 450; v++) vals.push_back(std::to_string(v));
|
|
build_value_tree(bpe, duration_tree, "duration:", vals);
|
|
}
|
|
// Keyscale
|
|
{
|
|
const char * notes[] = {"A","B","C","D","E","F","G"};
|
|
const char * accs[] = {"","b","#"};
|
|
const char * modes[] = {
|
|
"major","minor","dorian","phrygian","lydian","mixolydian",
|
|
"aeolian","locrian","chromatic","blues","pentatonic",
|
|
"harmonic minor","melodic minor"
|
|
};
|
|
std::vector<std::string> vals;
|
|
for (auto n : notes)
|
|
for (auto a : accs)
|
|
for (auto m : modes)
|
|
vals.push_back(std::string(n) + a + " " + m);
|
|
build_value_tree(bpe, keyscale_tree, "keyscale:", vals);
|
|
}
|
|
// Language
|
|
{
|
|
std::vector<std::string> vals = {
|
|
"en","zh","ja","ko","es","fr","de","uk","ru","pt",
|
|
"it","ar","tr","pl","sv","nl","unknown"
|
|
};
|
|
build_value_tree(bpe, language_tree, "language:", vals);
|
|
}
|
|
// Time signature
|
|
{
|
|
std::vector<std::string> vals = {"2","3","4","6"};
|
|
build_value_tree(bpe, timesig_tree, "timesignature:", vals);
|
|
}
|
|
|
|
fprintf(stderr, "[FSM] Prefix trees: bpm=%zu, dur=%zu, key=%zu, lang=%zu, tsig=%zu nodes\n",
|
|
bpm_tree.nodes.size(), duration_tree.nodes.size(),
|
|
keyscale_tree.nodes.size(), language_tree.nodes.size(),
|
|
timesig_tree.nodes.size());
|
|
enabled = true;
|
|
state = BPM_NAME;
|
|
name_pos = 0;
|
|
caption_newline_counter = 0;
|
|
rewrite_caption = rew_caption;
|
|
value_acc.clear();
|
|
}
|
|
|
|
void reset() {
|
|
state = BPM_NAME;
|
|
name_pos = 0;
|
|
caption_newline_counter = 0;
|
|
value_acc.clear();
|
|
}
|
|
|
|
// Force FSM to only allow a specific language value
|
|
void force_language(BPETokenizer & bpe, const std::string & lang) {
|
|
language_tree = PrefixTree();
|
|
build_value_tree(bpe, language_tree, "language:", {lang});
|
|
}
|
|
|
|
const std::vector<int> * current_name_tokens() const {
|
|
switch (state) {
|
|
case BPM_NAME: return &bpm_name;
|
|
case CAPTION_NAME: return &caption_name;
|
|
case DURATION_NAME: return &duration_name;
|
|
case KEYSCALE_NAME: return &keyscale_name;
|
|
case LANGUAGE_NAME: return &language_name;
|
|
case TIMESIG_NAME: return ×ig_name;
|
|
default: return nullptr;
|
|
}
|
|
}
|
|
|
|
const PrefixTree * current_value_tree() const {
|
|
switch (state) {
|
|
case BPM_VALUE: return &bpm_tree;
|
|
case DURATION_VALUE: return &duration_tree;
|
|
case KEYSCALE_VALUE: return &keyscale_tree;
|
|
case LANGUAGE_VALUE: return &language_tree;
|
|
case TIMESIG_VALUE: return ×ig_tree;
|
|
default: return nullptr;
|
|
}
|
|
}
|
|
|
|
State next_name_state() const {
|
|
switch (state) {
|
|
case BPM_NAME: case BPM_VALUE: return CAPTION_NAME;
|
|
case CAPTION_NAME: case CAPTION_VALUE: return DURATION_NAME;
|
|
case DURATION_NAME: case DURATION_VALUE: return KEYSCALE_NAME;
|
|
case KEYSCALE_NAME: case KEYSCALE_VALUE: return LANGUAGE_NAME;
|
|
case LANGUAGE_NAME: case LANGUAGE_VALUE: return TIMESIG_NAME;
|
|
case TIMESIG_NAME: case TIMESIG_VALUE: return THINK_END;
|
|
default: return CODES;
|
|
}
|
|
}
|
|
|
|
void apply_mask(float * logits) {
|
|
if (!enabled || state == CODES || state == DISABLED) return;
|
|
|
|
const std::vector<int> * name = current_name_tokens();
|
|
if (name && name_pos < (int)name->size()) {
|
|
int forced = (*name)[name_pos];
|
|
for (int v = 0; v < vocab_size; v++)
|
|
if (v != forced) logits[v] = -1e9f;
|
|
return;
|
|
}
|
|
|
|
const PrefixTree * tree = current_value_tree();
|
|
if (tree) {
|
|
const std::vector<int> * allowed = tree->get(value_acc);
|
|
if (allowed && !allowed->empty()) {
|
|
std::vector<float> saved(allowed->size());
|
|
for (size_t i = 0; i < allowed->size(); i++)
|
|
saved[i] = logits[(*allowed)[i]];
|
|
for (int v = 0; v < vocab_size; v++) logits[v] = -1e9f;
|
|
for (size_t i = 0; i < allowed->size(); i++)
|
|
logits[(*allowed)[i]] = saved[i];
|
|
} else {
|
|
if (newline_tok >= 0) {
|
|
for (int v = 0; v < vocab_size; v++)
|
|
if (v != newline_tok) logits[v] = -1e9f;
|
|
}
|
|
}
|
|
return;
|
|
}
|
|
|
|
if (state == CAPTION_VALUE) {
|
|
for (int v = AUDIO_CODE_BASE; v < AUDIO_CODE_BASE + AUDIO_CODE_COUNT; v++)
|
|
if (v < vocab_size) logits[v] = -1e9f;
|
|
return;
|
|
}
|
|
|
|
if (state == THINK_END) {
|
|
for (int v = 0; v < vocab_size; v++)
|
|
if (v != think_end_tok) logits[v] = -1e9f;
|
|
return;
|
|
}
|
|
}
|
|
|
|
void update(int token) {
|
|
if (!enabled || state == CODES || state == DISABLED) return;
|
|
|
|
const std::vector<int> * name = current_name_tokens();
|
|
if (name && name_pos < (int)name->size()) {
|
|
name_pos++;
|
|
if (name_pos >= (int)name->size()) {
|
|
switch (state) {
|
|
case BPM_NAME: state = BPM_VALUE; break;
|
|
case CAPTION_NAME:
|
|
state = CAPTION_VALUE;
|
|
if(caption_tokens.size()>0)
|
|
{
|
|
forced_tokens.clear();
|
|
forced_tokens = caption_tokens;
|
|
}
|
|
break;
|
|
case DURATION_NAME: state = DURATION_VALUE; break;
|
|
case KEYSCALE_NAME: state = KEYSCALE_VALUE; break;
|
|
case LANGUAGE_NAME: state = LANGUAGE_VALUE; break;
|
|
case TIMESIG_NAME: state = TIMESIG_VALUE; break;
|
|
default: break;
|
|
}
|
|
value_acc.clear();
|
|
}
|
|
return;
|
|
}
|
|
|
|
if (current_value_tree()) {
|
|
if (token == newline_tok) {
|
|
state = next_name_state();
|
|
name_pos = 0;
|
|
value_acc.clear();
|
|
} else {
|
|
value_acc.push_back(token);
|
|
}
|
|
return;
|
|
}
|
|
|
|
if (state == CAPTION_VALUE) {
|
|
if(token==newline_tok && !rewrite_caption)
|
|
{
|
|
//not changing caption, so end it now
|
|
caption_newline_counter = 0;
|
|
state = DURATION_NAME;
|
|
name_pos = 0;
|
|
value_acc.clear();
|
|
}
|
|
else if (token == newline_tok) {
|
|
// caption_newline_counter += 1;
|
|
// if(caption_newline_counter>=3) //ingest 3 newlines for the caption, making it longer
|
|
// {
|
|
caption_newline_counter = 0;
|
|
state = DURATION_NAME;
|
|
name_pos = 0;
|
|
value_acc.clear();
|
|
// }
|
|
}
|
|
else if(token == 17021) // fallback for duration token, immediately go next
|
|
{
|
|
caption_newline_counter = 0;
|
|
state = DURATION_NAME;
|
|
name_pos = 0;
|
|
value_acc.clear();
|
|
}
|
|
return;
|
|
}
|
|
|
|
if (state == THINK_END) {
|
|
state = CODES;
|
|
return;
|
|
}
|
|
}
|
|
};
|
|
|
|
//
|
|
// Generation
|
|
//
|
|
|
|
|
|
// Text-only generation (Phase 1: no CFG, stops at EOS)
|
|
// Join audio codes into a comma-separated decimal string with no spaces
// (e.g. {1, 22, 3} -> "1,22,3"). An empty input yields an empty string.
static std::string codes_to_string(const std::vector<int> & codes) {
    std::string joined;
    for (size_t i = 0; i < codes.size(); i++) {
        if (i > 0) joined += ',';
        joined += std::to_string(codes[i]);
    }
    return joined;
}
|
|
|
|
// Phase 2: run audio code generation with all metas known
|
|
// Returns comma-separated codes string (empty on failure)
|
|
|
|
// Parse N Phase 1 outputs into N AcePrompts, merging into base.
|
|
// merge_lyrics: true for simple mode (Phase 1 generates lyrics),
|
|
// false for partial mode (user provided lyrics).
|
|
static void parse_phase1_into_aces(
|
|
const std::vector<std::string> & texts, const AcePrompt & base,
|
|
std::vector<AcePrompt> & aces, long long base_seed,
|
|
const char * label, bool merge_lyrics) {
|
|
int N = (int)texts.size();
|
|
aces.resize(N);
|
|
for (int i = 0; i < N; i++) {
|
|
fprintf(stderr, "[%s Batch%d] seed=%lld:\n%s\n", label, i, base_seed + i, texts[i].c_str());
|
|
AcePrompt parsed = {};
|
|
if (!parse_cot_and_lyrics(texts[i], &parsed))
|
|
fprintf(stderr, "WARNING: batch %d CoT parse incomplete\n", i);
|
|
aces[i] = base;
|
|
if (parsed.bpm > 0) aces[i].bpm = parsed.bpm;
|
|
if (parsed.duration > 0) aces[i].duration = parsed.duration;
|
|
if (!parsed.keyscale.empty()) aces[i].keyscale = parsed.keyscale;
|
|
if (!parsed.timesignature.empty()) aces[i].timesignature = parsed.timesignature;
|
|
if (!parsed.vocal_language.empty()) aces[i].vocal_language = parsed.vocal_language;
|
|
if (!parsed.caption.empty()) aces[i].caption = parsed.caption;
|
|
if (merge_lyrics && !parsed.lyrics.empty()) aces[i].lyrics = parsed.lyrics;
|
|
if (aces[i].duration <= 0) aces[i].duration = 120.0f;
|
|
if (aces[i].duration > 600) aces[i].duration = 600.0f;
|
|
}
|
|
}
|
|
|
|
//hack for kcpp: forcing the correct tokens after end of thinking
|
|
// Token ids queued into forced_tokens by generate_phase1_batch as soon as </think> is produced;
// they are then emitted one per decode step in place of the sampled token.
// NOTE(review): id 271 is presumably the "\n\n" token, which would make the decoded text
// "\n\n# Lyric\n" rather than the "\n# Lyric\n" shown below - confirm against the tokenizer.
const std::vector<int> think_chain = {271,2,15953,2216,198}; // "\n# Lyric\n"
|
|
|
|
// Batched Phase 1: N text generations with shared prompt, different seeds.
|
|
// No CFG. Each element gets its own FSM state and RNG.
|
|
// Returns N generated text strings.
|
|
// Batched Phase 1 decode: N continuations of one shared prompt.
//
//   m, bpe             model and tokenizer
//   prompt_tokens      conditional prompt, prefilled once into KV set 0 and copied to 1..N-1
//   max_new_tokens     upper bound on decode steps (the first token, sampled from the
//                      prefill logits, is not counted)
//   temperature, top_p, top_k, rep_pen
//                      forwarded unchanged to kcpp_quick_sample
//   base_seed          seeds each element's private rng and labels log lines
//   N                  batch size; N <= 0 returns an empty vector
//   fsm_template       optional metadata FSM, copied into every element
//   lyrics_mode        true: audio-code tokens are always suppressed;
//                      false: after </think> only audio codes and <|im_end|> remain allowed
//   cfg_scale, uncond_tokens
//                      classifier-free guidance is used when cfg_scale > 1 and uncond_tokens
//                      is non-empty; the unconditional prompt occupies KV sets N..2N-1
//   stop_at_reasoning  finish an element as soon as it emits </think>
//
// Returns the decoded text of every element (via bpe_decode).
//
// NOTE(review): sampling uses the file-level acestep_lm_rng, and both the repetition window
// and the forced_tokens queue are shared by all N elements, so for N > 1 the elements are
// not independent. The per-element rng is seeded but never passed to the sampler.
static std::vector<std::string> generate_phase1_batch(
        Qwen3LM * m, BPETokenizer * bpe,
        const std::vector<int> & prompt_tokens,
        int max_new_tokens, float temperature, float top_p, int top_k, float rep_pen,
        long long base_seed, int N,
        MetadataFSM * fsm_template,
        bool lyrics_mode,
        float cfg_scale = 1.0f,
        const std::vector<int> * uncond_tokens = nullptr,
        bool stop_at_reasoning = false) {

    if (N <= 0) return {};

    const int  V       = m->cfg.vocab_size;
    const bool use_cfg = cfg_scale > 1.0f && uncond_tokens && !uncond_tokens->empty();
    forced_tokens.clear();

    // Sliding window of recent tokens for the repetition penalty.
    std::vector<int32_t> quicklastntoks;
    auto remember = [&quicklastntoks](int tok) {
        quicklastntoks.push_back(tok);
        if (quicklastntoks.size() > (size_t) rep_pen_range)
            quicklastntoks.erase(quicklastntoks.begin());
    };

    // Audio-code id range, clamped so the mask loops never index past the logit row.
    const int audio_begin = std::min(V, (int) AUDIO_CODE_BASE);
    const int audio_end   = std::min(V, (int) (AUDIO_CODE_BASE + AUDIO_CODE_COUNT));

    // KV sets: cond [0..N-1], uncond [N..2N-1] if CFG
    for (int i = 0; i < N; i++) qw3lm_reset_kv(m, i);
    if (use_cfg)
        for (int i = 0; i < N; i++) qw3lm_reset_kv(m, N + i);

    // Prefill cond once into set 0, then copy to 1..N-1
    Timer t_prefill;
    std::vector<float> prefill_logits(V);
    qw3lm_forward(m, prompt_tokens.data(), (int) prompt_tokens.size(), 0, prefill_logits.data());
    for (int i = 1; i < N; i++) qw3lm_copy_kv(m, 0, i);

    // Prefill uncond once into set N, then copy to N+1..2N-1
    std::vector<float> prefill_logits_uncond(use_cfg ? V : 0);
    if (use_cfg) {
        qw3lm_forward(m, uncond_tokens->data(), (int) uncond_tokens->size(), N, prefill_logits_uncond.data());
        for (int i = 1; i < N; i++) qw3lm_copy_kv(m, N, N + i);
    }

    fprintf(stderr, "[Phase1] Prefill %.0fms, %zu tokens, N=%d, CFG=%.2f\n",
            t_prefill.ms(), prompt_tokens.size(), N, cfg_scale);

    if (acestep_lm_dbg) {
        const std::string tks = bpe_decode(*bpe, prompt_tokens);
        // size() is a size_t, so the conversion must be %zu rather than %d.
        printf("\nN:%zu Prompt: %s", prompt_tokens.size(), tks.c_str());
    }

    // Per-element state
    struct P1Seq {
        std::mt19937 rng;
        MetadataFSM fsm;
        std::vector<int> gen_tokens;
        int last_token = 0;
        bool codes_phase = false;  // true once </think> has been emitted
        bool done = false;
    };
    std::vector<P1Seq> seqs(N);

    // First token: every element samples from the shared prefill logits.
    for (int i = 0; i < N; i++) {
        P1Seq & s = seqs[i];
        s.rng.seed((uint32_t) (base_seed + i));
        if (fsm_template) s.fsm = *fsm_template;
        s.codes_phase = false;
        s.done = false;

        std::vector<float> lg(prefill_logits);
        if (use_cfg) {
            for (int v = 0; v < V; v++)
                lg[v] = prefill_logits_uncond[v] + cfg_scale * (lg[v] - prefill_logits_uncond[v]);
        }
        if (fsm_template && fsm_template->enabled)
            s.fsm.apply_mask(lg.data());

        const int tok = kcpp_quick_sample(lg.data(), V, quicklastntoks, rep_pen, top_p, top_k, temperature, acestep_lm_rng);
        remember(tok);

        if (tok == TOKEN_IM_END) {
            s.done = true;
        } else {
            if (fsm_template && fsm_template->enabled)
                s.fsm.update(tok);
            if (tok == TOKEN_THINK_END) {
                s.codes_phase = true;
                if (stop_at_reasoning) s.done = true;
            }
            s.gen_tokens.push_back(tok);
        }
        s.last_token = tok;
    }

    // One forward per step covers the cond rows [0..N-1] and, with CFG, the uncond rows
    // [N..2N-1]; row r of `logits` belongs to KV set kv_sets[r].
    const int N2 = use_cfg ? 2 * N : N;
    std::vector<int> step_tokens(N2), kv_sets(N2);
    for (int r = 0; r < N2; r++) kv_sets[r] = r;
    std::vector<float> logits((size_t) V * N2);

    int n_active = N;
    for (int i = 0; i < N; i++)
        if (seqs[i].done) n_active--;

    // Batched decode
    Timer t_decode;
    for (int step = 0; step < max_new_tokens && n_active > 0; step++) {
        // Finished elements are still fed their last token so the batch shape stays fixed.
        for (int i = 0; i < N; i++) {
            step_tokens[i] = seqs[i].last_token;
            if (use_cfg) step_tokens[N + i] = seqs[i].last_token;
        }
        qw3lm_forward_batch(m, step_tokens.data(), kv_sets.data(), N2, logits.data());

        for (int i = 0; i < N; i++) {
            P1Seq & s = seqs[i];
            if (s.done) continue;

            float * lc = logits.data() + (size_t) i * V;

            // CFG combine
            if (use_cfg) {
                const float * lu = logits.data() + (size_t) (N + i) * V;
                for (int v = 0; v < V; v++)
                    lc[v] = lu[v] + cfg_scale * (lc[v] - lu[v]);
            }

            // FSM mask (before </think>)
            if (fsm_template && s.fsm.enabled && !s.codes_phase)
                s.fsm.apply_mask(lc);

            // After </think>: audio code constraint unless lyrics_mode
            if (s.codes_phase && !lyrics_mode) {
                for (int v = 0; v < audio_begin; v++)
                    if (v != TOKEN_IM_END) lc[v] = -1e9f;
            }

            // kcpp: prevent outputting audio codes during lyrics
            if (lyrics_mode) {
                for (int v = audio_begin; v < audio_end; v++) lc[v] = -1e9f;
            }

            int tok = kcpp_quick_sample(lc, V, quicklastntoks, rep_pen, top_p, top_k, temperature, acestep_lm_rng);
            remember(tok);  // the window records the sampled token even when it is overridden below

            // kcpp: a queued forced token replaces whatever was sampled
            if (forced_tokens.size() > 0) {
                tok = forced_tokens[0];
                forced_tokens.erase(forced_tokens.begin());
            }
            // kcpp: force lyrics tokens right after think
            if (tok == TOKEN_THINK_END) {
                forced_tokens.clear();
                forced_tokens = think_chain;
            }

            if (acestep_lm_dbg) {
                const std::string tks = bpe_decode(*bpe, std::vector<int>({tok}));
                printf("\nDebug temp: %f, top_p:%f, tok:%d = %s (%zu)",
                       temperature, top_p, tok, tks.c_str(), forced_tokens.size());
            }

            if (tok == TOKEN_IM_END) {
                s.done = true;
                n_active--;
            } else {
                if (s.fsm.enabled && !s.codes_phase)
                    s.fsm.update(tok);
                if (tok == TOKEN_THINK_END && !s.codes_phase) {
                    s.codes_phase = true;
                    if (stop_at_reasoning) {
                        s.gen_tokens.push_back(tok);
                        s.done = true;
                        n_active--;
                        continue;  // last_token is intentionally left as it was
                    }
                }
                s.gen_tokens.push_back(tok);
            }
            s.last_token = tok;
        }

        if ((step + 1) % 100 == 0) {
            const double elapsed = t_decode.ms() / 1000.0;
            fprintf(stderr, "[Phase1] step %d, %d active, %.1f tok/s\n",
                    step + 1, n_active, (double) (step + 1) * N / elapsed);
        }
    }

    fprintf(stderr, "[Phase1] Decode %.0fms\n", t_decode.ms());

    // Decode tokens to text
    std::vector<std::string> results(N);
    for (int i = 0; i < N; i++) {
        results[i] = bpe_decode(*bpe, seqs[i].gen_tokens);
        fprintf(stderr, "[Phase1 Batch%d] seed=%lld, %zu tokens\n",
                i, base_seed + i, seqs[i].gen_tokens.size());
    }
    return results;
}
|
|
|
|
|
|
// Batched Phase 2: N sequences with potentially different prompts.
|
|
// aces.size() == N: each element gets its own lyrics/metadata.
|
|
// aces.size() == 1: single prompt replicated for all N (prefill once, copy KV).
|
|
// Returns N code strings. Seeds = base_seed + 0, 1, ..., N-1.
|
|
static std::vector<std::string> run_phase2_batch(
|
|
Qwen3LM * m, BPETokenizer & bpe, const std::vector<AcePrompt> & aces,
|
|
float temperature, float top_p, int top_k, float rep_pen, long long base_seed, int N,
|
|
float cfg_scale, const char * negative_prompt) {
|
|
|
|
int V = m->cfg.vocab_size;
|
|
bool use_cfg = cfg_scale > 1.0f;
|
|
bool shared_prompt = ((int)aces.size() == 1);
|
|
|
|
std::vector<int32_t> quicklastntoks;
|
|
|
|
// Build per-element prompts
|
|
std::vector<std::vector<int>> prompts(N), unconds(N);
|
|
int max_tokens = 0;
|
|
for (int i = 0; i < N; i++) {
|
|
const AcePrompt & a = shared_prompt ? aces[0] : aces[i];
|
|
std::string cot = build_cot_yaml(a);
|
|
if (i == 0)
|
|
fprintf(stderr, "[Phase2] N=%d, CoT[0]:\n%s", N, cot.c_str());
|
|
prompts[i] = build_lm_prompt_with_cot(bpe, a, cot);
|
|
if (use_cfg)
|
|
unconds[i] = build_lm_prompt_uncond_with_cot(bpe, a, negative_prompt);
|
|
int mt = (int)(a.duration * 5) + 100;
|
|
if (mt > max_tokens) max_tokens = mt;
|
|
}
|
|
fprintf(stderr, "[Phase2] max_tokens: %d, CFG: %.2f, seeds: %lld..%lld\n",
|
|
max_tokens, cfg_scale, base_seed, base_seed + N - 1);
|
|
|
|
// Reset all KV sets: cond [0..N-1], uncond [N..2N-1]
|
|
for (int i = 0; i < N; i++) qw3lm_reset_kv(m, i);
|
|
if (use_cfg)
|
|
for (int i = 0; i < N; i++) qw3lm_reset_kv(m, N + i);
|
|
|
|
// Prefill: if shared prompt, prefill once + copy KV. Otherwise prefill each.
|
|
Timer t_prefill;
|
|
std::vector<std::vector<float>> prefill_logits_vec(N, std::vector<float>(V));
|
|
|
|
if(acestep_lm_dbg)
|
|
{
|
|
std::string tks = bpe_decode(bpe,prompts[0]);
|
|
printf("\nPhase2: UseCFG:%d, Promptsiz:%d, Prompt: %s",use_cfg,prompts[0].size(),tks.c_str());
|
|
}
|
|
|
|
if (shared_prompt) {
|
|
qw3lm_forward(m, prompts[0].data(), (int)prompts[0].size(), 0, prefill_logits_vec[0].data());
|
|
for (int i = 1; i < N; i++) {
|
|
qw3lm_copy_kv(m, 0, i);
|
|
prefill_logits_vec[i] = prefill_logits_vec[0];
|
|
}
|
|
} else {
|
|
for (int i = 0; i < N; i++)
|
|
qw3lm_forward(m, prompts[i].data(), (int)prompts[i].size(), i, prefill_logits_vec[i].data());
|
|
}
|
|
|
|
// Prefill uncond
|
|
std::vector<std::vector<float>> prefill_logits_uncond_vec(N, std::vector<float>(V));
|
|
if (use_cfg) {
|
|
if (shared_prompt) {
|
|
qw3lm_forward(m, unconds[0].data(), (int)unconds[0].size(), N, prefill_logits_uncond_vec[0].data());
|
|
for (int i = 1; i < N; i++) {
|
|
qw3lm_copy_kv(m, N, N + i);
|
|
prefill_logits_uncond_vec[i] = prefill_logits_uncond_vec[0];
|
|
}
|
|
} else {
|
|
for (int i = 0; i < N; i++)
|
|
qw3lm_forward(m, unconds[i].data(), (int)unconds[i].size(), N + i, prefill_logits_uncond_vec[i].data());
|
|
}
|
|
}
|
|
|
|
double prefill_ms = t_prefill.ms();
|
|
fprintf(stderr, "[Phase2] Prefill %.0fms (%s)\n",
|
|
prefill_ms, shared_prompt ? "shared, 1 cond + 1 uncond" : "individual, N cond + N uncond");
|
|
|
|
// Per-sequence state
|
|
struct BatchSeq {
|
|
std::mt19937 rng;
|
|
std::vector<int> audio_codes;
|
|
int last_token;
|
|
bool done;
|
|
};
|
|
std::vector<BatchSeq> seqs(N);
|
|
|
|
// Sample first token from per-element prefill logits (N different seeds)
|
|
for (int i = 0; i < N; i++) {
|
|
seqs[i].rng.seed((uint32_t)(base_seed + i));
|
|
seqs[i].done = false;
|
|
|
|
std::vector<float> lg(prefill_logits_vec[i]); // copy
|
|
if (use_cfg) {
|
|
float * lu = prefill_logits_uncond_vec[i].data();
|
|
for (int v = 0; v < V; v++)
|
|
lg[v] = lu[v] + cfg_scale * (lg[v] - lu[v]);
|
|
}
|
|
// Only audio codes + EOS (codes_phase = true from start)
|
|
for (int v = 0; v < AUDIO_CODE_BASE; v++)
|
|
if (v != TOKEN_IM_END) lg[v] = -1e9f;
|
|
|
|
int tok = kcpp_quick_sample(lg.data(),V,quicklastntoks,rep_pen,top_p,top_k,temperature,acestep_lm_rng);
|
|
quicklastntoks.push_back(tok);
|
|
if (quicklastntoks.size()>rep_pen_range) {
|
|
quicklastntoks.erase(quicklastntoks.begin());
|
|
}
|
|
|
|
seqs[i].last_token = tok;
|
|
|
|
if (tok == TOKEN_IM_END) {
|
|
seqs[i].done = true;
|
|
} else if (tok >= AUDIO_CODE_BASE && tok < AUDIO_CODE_BASE + AUDIO_CODE_COUNT) {
|
|
seqs[i].audio_codes.push_back(tok - AUDIO_CODE_BASE);
|
|
}
|
|
}
|
|
|
|
// KV set arrays for batched forward
|
|
std::vector<int> cond_sets(N), uncond_sets(N);
|
|
for (int i = 0; i < N; i++) {
|
|
cond_sets[i] = i;
|
|
uncond_sets[i] = N + i;
|
|
}
|
|
|
|
// Batched decode loop
|
|
Timer t_decode;
|
|
std::vector<float> logits_cond(V * N);
|
|
std::vector<float> logits_uncond(V * N);
|
|
std::vector<int> tokens(N);
|
|
|
|
// CFG: single forward with 2*N (cond + uncond)
|
|
int N2 = use_cfg ? 2 * N : N;
|
|
std::vector<int> tokens_2n(N2), sets_2n(N2);
|
|
std::vector<float> logits_2n((size_t)V * N2);
|
|
if (use_cfg) {
|
|
for (int i = 0; i < N; i++) {
|
|
sets_2n[i] = cond_sets[i];
|
|
sets_2n[N + i] = uncond_sets[i];
|
|
}
|
|
}
|
|
|
|
int n_active = N;
|
|
for (int i = 0; i < N; i++)
|
|
if (seqs[i].done) n_active--;
|
|
|
|
for (int step = 0; step < max_tokens && n_active > 0; step++) {
|
|
// Collect tokens (done sequences feed their last token, result ignored)
|
|
for (int i = 0; i < N; i++)
|
|
tokens[i] = seqs[i].last_token;
|
|
|
|
if (use_cfg) {
|
|
// Single batched forward: cond[0..N-1] + uncond[N..2N-1]
|
|
for (int i = 0; i < N; i++) {
|
|
tokens_2n[i] = tokens[i];
|
|
tokens_2n[N + i] = tokens[i];
|
|
}
|
|
qw3lm_forward_batch(m, tokens_2n.data(), sets_2n.data(), N2, logits_2n.data());
|
|
memcpy(logits_cond.data(), logits_2n.data(), (size_t)V * N * sizeof(float));
|
|
memcpy(logits_uncond.data(), logits_2n.data() + (size_t)V * N, (size_t)V * N * sizeof(float));
|
|
} else {
|
|
qw3lm_forward_batch(m, tokens.data(), cond_sets.data(), N, logits_cond.data());
|
|
}
|
|
|
|
// Per-sequence: CFG combine + sample
|
|
for (int i = 0; i < N; i++) {
|
|
if (seqs[i].done) continue;
|
|
|
|
float * lc = logits_cond.data() + (size_t)i * V;
|
|
if (use_cfg) {
|
|
float * lu = logits_uncond.data() + (size_t)i * V;
|
|
for (int v = 0; v < V; v++)
|
|
lc[v] = lu[v] + cfg_scale * (lc[v] - lu[v]);
|
|
}
|
|
|
|
// Only audio codes + EOS
|
|
for (int v = 0; v < AUDIO_CODE_BASE; v++)
|
|
if (v != TOKEN_IM_END) lc[v] = -1e9f;
|
|
|
|
int tok = kcpp_quick_sample(lc,V,quicklastntoks,rep_pen,top_p,top_k,temperature,acestep_lm_rng);
|
|
quicklastntoks.push_back(tok);
|
|
if (quicklastntoks.size()>rep_pen_range) {
|
|
quicklastntoks.erase(quicklastntoks.begin());
|
|
}
|
|
seqs[i].last_token = tok;
|
|
|
|
if (tok == TOKEN_IM_END) {
|
|
seqs[i].done = true;
|
|
n_active--;
|
|
} else if (tok >= AUDIO_CODE_BASE && tok < AUDIO_CODE_BASE + AUDIO_CODE_COUNT) {
|
|
seqs[i].audio_codes.push_back(tok - AUDIO_CODE_BASE);
|
|
}
|
|
}
|
|
|
|
int total_codes = 0;
|
|
for (int i = 0; i < N; i++) total_codes += (int)seqs[i].audio_codes.size();
|
|
|
|
if ((step + 1) % 50 == 0) {
|
|
double elapsed = t_decode.ms() / 1000.0;
|
|
fprintf(stderr, "[Decode] step %d, %d active, %d total codes, %.1f tok/s\n",
|
|
step + 1, n_active, total_codes, (double)(step + 1) * N / elapsed);
|
|
}
|
|
}
|
|
|
|
double decode_ms = t_decode.ms();
|
|
fprintf(stderr, "[Phase2] Decode %.0fms\n", decode_ms);
|
|
|
|
// Build results
|
|
std::vector<std::string> results(N);
|
|
for (int i = 0; i < N; i++) {
|
|
results[i] = codes_to_string(seqs[i].audio_codes);
|
|
fprintf(stderr, "[Batch %d] seed=%lld, %zu codes\n",
|
|
i, base_seed + i, seqs[i].audio_codes.size());
|
|
}
|
|
return results;
|
|
}
|
|
|
|
//
|
|
// CLI
|
|
//
|
|
|
|
// Print the command-line help to stderr. `prog` is substituted into the
// "Usage:" line; the rest of the text is fixed.
static void usage(const char * prog) {
    fprintf(stderr, "Usage: %s --request <json> --model <gguf> [options]\n", prog);
    fputs(
        "\n"
        "Required:\n"
        "  --request <json>       Input request JSON\n"
        "  --model <gguf>         Model GGUF file\n"
        "\n"
        "Batch:\n"
        "  --batch <N>            Batch N sequences (default: 1)\n"
        "\n"
        "Output naming: input.json -> input0.json, input1.json, ... (last digit = batch index)\n"
        "\n"
        "Debug:\n"
        "  --max-seq <N>          KV cache size (default: 8192)\n"
        "  --no-fsm               Disable FSM constrained decoding\n"
        "  --dump-logits <path>   Dump prefill logits (binary f32)\n"
        "  --dump-tokens <path>   Dump prompt token IDs (CSV)\n",
        stderr);
}
|
|
|
|
//kcpp stuff
|
|
|
|
static Qwen3LM acestep_llm;
|
|
static BPETokenizer acestep_bpe;
|
|
static bool acestep_lm_loaded = false;
|
|
static std::string acestep_lm_path = "";
|
|
static bool acestep_lm_lowvram = false;
|
|
|
|
void unload_acestep_lm()
|
|
{
|
|
if(acestep_lm_loaded)
|
|
{
|
|
acestep_lm_loaded = false;
|
|
qw3lm_free(&acestep_llm);
|
|
printf("Unload Music LM model...\n");
|
|
}
|
|
}
|
|
|
|
bool load_acestep_lm(std::string model_path, bool lowvram, bool musicdebugmode)
|
|
{
|
|
acestep_lm_dbg = musicdebugmode;
|
|
if(acestep_lm_loaded)
|
|
{
|
|
unload_acestep_lm();
|
|
}
|
|
|
|
acestep_lm_lowvram = lowvram;
|
|
acestep_lm_path = model_path;
|
|
acestep_lm_loaded = false;
|
|
int max_seq = 8192;
|
|
const int batch_size = 1; //only bs 1 is allowed
|
|
if (!load_bpe_from_gguf(&acestep_bpe, model_path.c_str())) {
|
|
return false;
|
|
}
|
|
// Load model
|
|
int n_kv_sets = 2 * batch_size;
|
|
if (!qw3lm_load(&acestep_llm, model_path.c_str(), max_seq, n_kv_sets)) {
|
|
return false;
|
|
}
|
|
acestep_lm_loaded = true;
|
|
return true;
|
|
}
|
|
|
|
std::string acestep_prepare_request(const music_generation_inputs inputs)
|
|
{
|
|
if(!acestep_lm_loaded && acestep_lm_path!="")
|
|
{
|
|
printf("\nRuntime reload Music LM model...\n");
|
|
bool ok = load_acestep_lm(acestep_lm_path, acestep_lm_lowvram, acestep_lm_dbg);
|
|
if(!ok)
|
|
{
|
|
printf("\nERROR: Acestep LM load fail\n");
|
|
return "";
|
|
}
|
|
}
|
|
|
|
bool rewrite_caption = inputs.rewrite_caption;
|
|
|
|
const int batch_size = 1;
|
|
bool use_fsm = true;
|
|
MetadataFSM fsm;
|
|
if (use_fsm) {
|
|
fsm.init(acestep_bpe, acestep_llm.cfg.vocab_size,rewrite_caption);
|
|
}
|
|
|
|
// Read request and set essentials
|
|
AceRequest req;
|
|
std::string injson = inputs.input_json;
|
|
if (!request_parse_from_str(&req, injson))
|
|
{
|
|
fprintf(stderr, "\nMusic JSON parse error\n");
|
|
return "";
|
|
}
|
|
|
|
int seed = req.seed;
|
|
if (seed <= 0 || seed==0xFFFFFFFF)
|
|
{
|
|
seed = (((uint32_t)time(NULL)) % 1000000u);
|
|
}
|
|
req.seed = seed;
|
|
acestep_lm_rng = std::mt19937(seed);
|
|
|
|
if (req.caption.empty()) {
|
|
req.caption = "An interesting song";
|
|
}
|
|
|
|
// Generation params from request
|
|
float temperature = req.lm_temperature;
|
|
float top_p = req.lm_top_p;
|
|
int top_k = req.lm_top_k;
|
|
float rep_pen = req.lm_rep_pen;
|
|
float cfg_scale = req.lm_cfg_scale;
|
|
float phase1cfg = 1.0f;
|
|
const char * neg_prompt = req.lm_negative_prompt.c_str();
|
|
float codes_temperature = req.codes_temperature;
|
|
float codes_top_p = req.codes_top_p;
|
|
int codes_top_k = req.codes_top_k;
|
|
|
|
// Copy request -> AcePrompt (internal LLM struct)
|
|
AcePrompt ace = {};
|
|
ace.caption = req.caption;
|
|
ace.lyrics = req.lyrics;
|
|
ace.duration = req.duration;
|
|
ace.bpm = req.bpm;
|
|
ace.keyscale = req.keyscale;
|
|
ace.timesignature = req.timesignature;
|
|
ace.vocal_language = req.vocal_language;
|
|
|
|
bool user_has_codes = !req.audio_codes.empty();
|
|
bool need_lm_codes = inputs.gen_codes && !user_has_codes;
|
|
|
|
bool is_simple = ace.lyrics.empty();
|
|
|
|
std::vector<int> prompt;
|
|
std::vector<AcePrompt> aces; // populated by Phase 1 (simple or partial)
|
|
|
|
caption_tokens.clear();
|
|
if(ace.caption!="" && !rewrite_caption)
|
|
{
|
|
caption_tokens = bpe_encode(&acestep_bpe, ace.caption+"\n", false);
|
|
}
|
|
|
|
// Preprocessor: simple mode generates lyrics + metas from caption
|
|
if (is_simple) {
|
|
fprintf(stderr, "[Simple] Inspiration\n");
|
|
|
|
const char * sys =
|
|
"# Instruction\n"
|
|
"Expand the user's input into a more detailed"
|
|
" and specific musical description:\n";
|
|
std::string user_msg = "# Caption\n"+ace.caption + "\n\ninstrumental: "
|
|
+ std::string(req.instrumental ? "true" : "false");
|
|
prompt = build_custom_prompt(acestep_bpe, sys, user_msg.c_str());
|
|
|
|
// FSM: reset then optionally force language (shared for both paths)
|
|
fsm.reset();
|
|
if (use_fsm && ace.vocal_language != "unknown" && !ace.vocal_language.empty())
|
|
fsm.force_language(acestep_bpe, ace.vocal_language);
|
|
|
|
// Phase 1: N lyrics + metadata generations (always batched, N=batch_size)
|
|
fprintf(stderr, "[Simple] %zu tokens, N=%d, seeds: %lld..%lld\n",
|
|
prompt.size(), batch_size, seed, seed + batch_size - 1);
|
|
|
|
auto phase1_texts = generate_phase1_batch(
|
|
&acestep_llm, &acestep_bpe, prompt, 2048, temperature, top_p, top_k, rep_pen,
|
|
seed, batch_size, use_fsm ? &fsm : nullptr, true);
|
|
|
|
parse_phase1_into_aces(phase1_texts, ace, aces, seed, "Simple", true);
|
|
|
|
for (int i = 0; i < batch_size; i++) qw3lm_reset_kv(&acestep_llm, i);
|
|
}
|
|
|
|
// Re-evaluate after possible simple enrichment
|
|
const AcePrompt & ace_ref = aces.empty() ? ace : aces[0];
|
|
bool has_all_metas = (ace_ref.bpm > 0 && ace_ref.duration > 0 &&
|
|
!ace_ref.keyscale.empty() && !ace_ref.timesignature.empty());
|
|
|
|
if (!has_all_metas) {
|
|
// Partial-metas: Phase 1 with CFG to fill missing fields
|
|
prompt = build_lm_prompt(acestep_bpe, ace);
|
|
std::vector<int> uncond;
|
|
if (phase1cfg > 1.0f)
|
|
uncond = build_lm_prompt_uncond(acestep_bpe, ace, neg_prompt);
|
|
|
|
fprintf(stderr, "[Partial] %zu tokens, CFG: %.2f, N=%d, seeds: %lld..%lld\n",
|
|
prompt.size(), phase1cfg, batch_size, seed, seed + batch_size - 1);
|
|
|
|
fsm.reset();
|
|
auto phase1_texts = generate_phase1_batch(
|
|
&acestep_llm, &acestep_bpe, prompt, 2048, temperature, top_p, top_k, rep_pen,
|
|
seed, batch_size, use_fsm ? &fsm : nullptr, false,
|
|
phase1cfg, uncond.empty() ? nullptr : &uncond, true);
|
|
|
|
parse_phase1_into_aces(phase1_texts, ace, aces, seed, "Partial", false);
|
|
|
|
for (int i = 0; i < 2 * batch_size; i++) qw3lm_reset_kv(&acestep_llm, i);
|
|
}
|
|
fsm.reset();
|
|
|
|
// Guarantee aces is populated (all-metas: single shared ace for prefill optimization)
|
|
if (aces.empty()) {
|
|
aces = { ace };
|
|
}
|
|
|
|
// Phase 2: generate audio codes (always batched, N=batch_size)
|
|
std::vector<std::string> batch_codes(batch_size);
|
|
if (need_lm_codes) {
|
|
batch_codes = run_phase2_batch(&acestep_llm, acestep_bpe, aces,
|
|
temperature, top_p, top_k, rep_pen, seed, batch_size, cfg_scale, neg_prompt);
|
|
} else {
|
|
fprintf(stderr, "[Skip] %s, no code generation\n",
|
|
user_has_codes ? "user codes present" : "thinking=false");
|
|
}
|
|
|
|
// only batch size 1 is allowed
|
|
AceRequest rr = req;
|
|
const AcePrompt & a = aces[0];
|
|
rr.caption = a.caption;
|
|
rr.lyrics = a.lyrics;
|
|
rr.bpm = a.bpm;
|
|
rr.duration = a.duration;
|
|
rr.keyscale = a.keyscale;
|
|
rr.timesignature = a.timesignature;
|
|
rr.vocal_language = a.vocal_language;
|
|
if (!batch_codes[0].empty()) rr.audio_codes = batch_codes[0];
|
|
|
|
std::string prefix_erase = "# Lyric";
|
|
// Check if the string is long enough and starts with the prefix
|
|
if (rr.lyrics.size() >= prefix_erase.size() && rr.lyrics.compare(0, prefix_erase.size(), prefix_erase) == 0) {
|
|
rr.lyrics = rr.lyrics.substr(prefix_erase.size()); // Returns a new string starting after the prefix
|
|
}
|
|
|
|
prefix_erase = "keyscale:";
|
|
// Check if the string is long enough and starts with the prefix
|
|
if (rr.keyscale.size() >= prefix_erase.size() && rr.keyscale.compare(0, prefix_erase.size(), prefix_erase) == 0) {
|
|
rr.keyscale = rr.keyscale.substr(prefix_erase.size()); // Returns a new string starting after the prefix
|
|
}
|
|
|
|
std::string delimiter = "durationduration";
|
|
size_t ddpos = rr.caption.find(delimiter);
|
|
rr.caption = ((ddpos != std::string::npos) ? rr.caption.substr(0, ddpos) : rr.caption);
|
|
|
|
//now convert to string
|
|
std::ostringstream oss;
|
|
oss << "{\n";
|
|
oss << " \"caption\": \"" << json_escape(rr.caption) << "\",\n";
|
|
oss << " \"lyrics\": \"" << json_escape(rr.lyrics) << "\",\n";
|
|
if (rr.instrumental) {
|
|
oss << " \"instrumental\": true,\n";
|
|
}
|
|
oss << " \"bpm\": " << rr.bpm << ",\n";
|
|
oss << " \"duration\": " << std::fixed << std::setprecision(1) << rr.duration << ",\n";
|
|
oss << " \"keyscale\": \"" << json_escape(rr.keyscale) << "\",\n";
|
|
oss << " \"timesignature\": \"" << json_escape(rr.timesignature) << "\",\n";
|
|
oss << " \"vocal_language\": \"" << json_escape(rr.vocal_language) << "\",\n";
|
|
oss << " \"task_type\": \"" << json_escape(rr.task_type) << "\",\n";
|
|
oss << " \"seed\": " << rr.seed << ",\n";
|
|
oss << " \"thinking\": " << (rr.thinking ? "true" : "false") << ",\n";
|
|
oss << " \"lm_temperature\": " << std::fixed << std::setprecision(2) << rr.lm_temperature << ",\n";
|
|
oss << " \"lm_cfg_scale\": " << std::fixed << std::setprecision(1) << rr.lm_cfg_scale << ",\n";
|
|
oss << " \"lm_top_p\": " << std::fixed << std::setprecision(2) << rr.lm_top_p << ",\n";
|
|
oss << " \"lm_top_k\": " << std::fixed << std::setprecision(1) << rr.lm_top_k << ",\n";
|
|
oss << " \"codes_temperature\": " << std::fixed << std::setprecision(2) << rr.codes_temperature << ",\n";
|
|
oss << " \"codes_top_p\": " << std::fixed << std::setprecision(2) << rr.codes_top_p << ",\n";
|
|
oss << " \"codes_top_k\": " << std::fixed << std::setprecision(1) << rr.codes_top_k << ",\n";
|
|
oss << " \"lm_rep_pen\": " << std::fixed << std::setprecision(2) << rr.lm_rep_pen << ",\n";
|
|
oss << " \"lm_negative_prompt\": \"" << json_escape(rr.lm_negative_prompt) << "\",\n";
|
|
oss << " \"inference_steps\": " << rr.inference_steps << ",\n";
|
|
oss << " \"guidance_scale\": " << std::fixed << std::setprecision(1) << rr.guidance_scale << ",\n";
|
|
oss << " \"shift\": " << std::fixed << std::setprecision(1) << rr.shift << ",\n";
|
|
oss << " \"audio_codes\": \"" << json_escape(rr.audio_codes) << "\"\n";
|
|
oss << "}\n";
|
|
std::string output_json = oss.str();
|
|
|
|
if(acestep_lm_lowvram)
|
|
{
|
|
unload_acestep_lm();
|
|
}
|
|
|
|
return output_json;
|
|
}
|
|
|