mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-20 17:54:31 +00:00
music lm finally working
This commit is contained in:
parent
cf042af701
commit
edbc4fe592
4 changed files with 89 additions and 68 deletions
|
|
@ -156,7 +156,9 @@ audio{width:100%;margin-top:6px;}
|
|||
</head>
|
||||
<body>
|
||||
|
||||
<header>🎵 KoboldCpp Music Generation</header>
|
||||
<header>
|
||||
<img style="width: 36px; height: 36px; vertical-align: middle; margin-right:4px" src="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAMAAACdt4HsAAAAAXNSR0IB2cksfwAAAAlwSFlzAAALEwAACxMBAJqcGAAAADxQTFRFS2Si+X5+pmBfHyApLjZSS2SjP057Vzw5EA4Sf1ZT+9Sv1WpqnYx/7qaYw7vUAAAAS2Sj9PPzgnrLS2SjAzrF9gAAABR0Uk5T///////w////////////AKj//yMlHqVpAAAD3klEQVR4nKWXi7KjIAyGFSgxEjhV3/9d90+8onZPd810prWSDwi50fyoTNP7/X79g2D4NJlqo+rvV/Mf8npPM2B6/4+6ihKaB/pGaH4e6IPw00y3+48xhBC3J32Id+NeUzN9UPfer4RoD/eIqbnuwLS7zncLAfqdPvvDmvY9XAE6vuuImEAw8fNT1/kr4Qqw+YhdIocfJl0glxyTvyG8m7MNY1B9diAkmgGUODnH7Km7AF53AGEjUJtWYdUPzn0LyC6AQO0qCUCi1PKXAM5tCwXeAC0ROf36AqA2VACmbQ8yP9DVimeA6lPKkLaW3EPylXAARBXV701OhOVPI6hcAXH1mTyP7e8AMyEc4mQDzP7XrfOfl5D7ndAdfXID6NwMyXACEpEbgPTCLJn1hEGoAep/OKheQiCEEhj1HgBQX1ZxQMPLlyVsABwejkp8EGEQAkxRA4RgIRYhTxme1fkKoBZwAHjLA+b/cgLQ8gZ4gZ+tVtgAnboaa+Lg0IwRhBqAmX0cI0WFqHN3FUAXAOPpzIWhPzZYQgUAu4ljiaKTaKwtZtwAIdv8XkocR9+UYM5/BMTRxzJKsWEu+RPAAsBxKSWWgTHS18cofiwhlCJD4cApUb0CNWKA/5dhwAqKD2UIXAEoFgUMkIJTCCcjzkGE890BQhXA685WQNqD6ujKWDRhhI7EdKUCtKSGxd8ASEr+6sqNApKPeD/iFEpT6nAUcAMgMmBzqwVPgJCd80X3AIlDDcjSzH8PJbD7AGiT020WjfcCN0jI5WwJGk5axP4eikeyvQd4HE5i7I4xEpWANKg0m2p0OUIcQKJnd7uCaABMRebOSOoB1WUVYACzaGSs012NaI5gAC0GcPWD9iLI6/qVdGeXY7R6xu1M0FAhG7s865ctw97Zoz85kuXi5T2EbaZatLileQA+VifrYGrT7ruL+lbZ0orYcXQJpry/tl+26l1s8sOy+BxMqKjr23nf7mhFnktbOgJOGQmnVG0ZVve06VvDUFmEztGIhHAy2YHA+qsCuFNS1T0Edf41AOZ1b7uwH1tYYFA4p3U1owiOOu+AsyxrQ3AIXwrLXtryL4BPpW0rrvMaPgHSx+K6l3cj3Oin1lH6S3nfd+KDa51lAjJhE6ddz7XRu29xUH51O95SgNOahDTB3PPvLc7cZPWYEVlVlp5AkGtJK/63XZoq0jBsvUrPeNDvr/tE1SnD3qxIEVuNfAsY0J9w4Ux2ZKizHPLHFdw127r7HIS2ZpvFTHHbbN+3+2Qm29p9NvXv2v3twkHHCwd9vnA8vvI8vnQ9vvY9v3g+vvo+v3w/u/7/AZoAPJwrbZ1IAAAAAElFTkSuQmCC"/>
|
||||
<span>KoboldCpp Music Generation UI</span></header>
|
||||
|
||||
<div class="wrapper">
|
||||
|
||||
|
|
@ -297,10 +299,15 @@ function getFormData(){
|
|||
}
|
||||
|
||||
function updateForm(data){
|
||||
//let origseed = document.getElementById("seed").value;
|
||||
Object.keys(data).forEach(k=>{
|
||||
if(document.getElementById(k))
|
||||
document.getElementById(k).value=data[k]??"";
|
||||
});
|
||||
//if(origseed=="-1" || origseed=="")
|
||||
//{
|
||||
// document.getElementById("seed").value = "-1";
|
||||
//}
|
||||
}
|
||||
|
||||
function deriveTitle(caption){
|
||||
|
|
|
|||
|
|
@ -169,6 +169,9 @@ struct AcePrompt {
|
|||
std::string vocal_language;
|
||||
};
|
||||
|
||||
static std::mt19937 acestep_lm_rng;
|
||||
static bool acestep_lm_dbg = false;
|
||||
|
||||
//
|
||||
// CoT parsing (extract metadata + lyrics from LLM Phase1 output)
|
||||
//
|
||||
|
|
@ -758,13 +761,16 @@ static void parse_phase1_into_aces(
|
|||
}
|
||||
}
|
||||
|
||||
//hack for kcpp: forcing the correct tokens after end of thinking
|
||||
const std::vector<int> think_chain = {271,2,15953,2216,198}; // "\n# Lyric\n"
|
||||
|
||||
// Batched Phase 1: N text generations with shared prompt, different seeds.
|
||||
// No CFG. Each element gets its own FSM state and RNG.
|
||||
// Returns N generated text strings.
|
||||
static std::vector<std::string> generate_phase1_batch(
|
||||
Qwen3LM * m, BPETokenizer * bpe,
|
||||
const std::vector<int> & prompt_tokens,
|
||||
int max_new_tokens, float temperature, float top_p, int top_k,
|
||||
int max_new_tokens, float temperature, float top_p,
|
||||
long long base_seed, int N,
|
||||
MetadataFSM * fsm_template,
|
||||
bool lyrics_mode,
|
||||
|
|
@ -798,6 +804,12 @@ static std::vector<std::string> generate_phase1_batch(
|
|||
fprintf(stderr, "[Phase1] Prefill %.0fms, %zu tokens, N=%d, CFG=%.2f\n",
|
||||
t_prefill.ms(), prompt_tokens.size(), N, cfg_scale);
|
||||
|
||||
if(acestep_lm_dbg)
|
||||
{
|
||||
std::string tks = bpe_decode(*bpe,prompt_tokens);
|
||||
printf("\nN:%d Prompt: %s",prompt_tokens.size(),tks.c_str());
|
||||
}
|
||||
|
||||
// Per-element state
|
||||
struct P1Seq {
|
||||
std::mt19937 rng;
|
||||
|
|
@ -824,7 +836,7 @@ static std::vector<std::string> generate_phase1_batch(
|
|||
if (fsm_template && fsm_template->enabled)
|
||||
seqs[i].fsm.apply_mask(lg.data());
|
||||
|
||||
int tok = sample_top_k_p(lg.data(), V, temperature, top_p, top_k, seqs[i].rng);
|
||||
int tok = kcpp_quick_sample(lg.data(),V,std::vector<int32_t>(),1.03f,top_p,25,temperature,acestep_lm_rng);
|
||||
|
||||
if (tok == TOKEN_IM_END) {
|
||||
seqs[i].done = true;
|
||||
|
|
@ -853,37 +865,20 @@ static std::vector<std::string> generate_phase1_batch(
|
|||
std::vector<float> logits_uncond(V * N);
|
||||
std::vector<int> tokens(N);
|
||||
|
||||
// CFG: single forward with 2*N (cond + uncond)
|
||||
int N2 = use_cfg ? 2 * N : N;
|
||||
std::vector<int> tokens_2n(N2), sets_2n(N2);
|
||||
std::vector<float> logits_2n((size_t)V * N2);
|
||||
if (use_cfg) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
sets_2n[i] = cond_sets[i];
|
||||
sets_2n[N + i] = uncond_sets[i];
|
||||
}
|
||||
}
|
||||
|
||||
int n_active = N;
|
||||
for (int i = 0; i < N; i++)
|
||||
if (seqs[i].done) n_active--;
|
||||
|
||||
std::vector<int32_t> quicklastntoks;
|
||||
std::vector<int32_t> forced_tokens;
|
||||
|
||||
for (int step = 0; step < max_new_tokens && n_active > 0; step++) {
|
||||
for (int i = 0; i < N; i++)
|
||||
tokens[i] = seqs[i].last_token;
|
||||
|
||||
if (use_cfg) {
|
||||
// Single batched forward: cond[0..N-1] + uncond[N..2N-1]
|
||||
for (int i = 0; i < N; i++) {
|
||||
tokens_2n[i] = tokens[i];
|
||||
tokens_2n[N + i] = tokens[i];
|
||||
}
|
||||
qw3lm_forward_batch(m, tokens_2n.data(), sets_2n.data(), N2, logits_2n.data());
|
||||
memcpy(logits_cond.data(), logits_2n.data(), (size_t)V * N * sizeof(float));
|
||||
memcpy(logits_uncond.data(), logits_2n.data() + (size_t)V * N, (size_t)V * N * sizeof(float));
|
||||
} else {
|
||||
qw3lm_forward_batch(m, tokens.data(), cond_sets.data(), N, logits_cond.data());
|
||||
}
|
||||
qw3lm_forward_batch(m, tokens.data(), cond_sets.data(), N, logits_cond.data());
|
||||
if (use_cfg)
|
||||
qw3lm_forward_batch(m, tokens.data(), uncond_sets.data(), N, logits_uncond.data());
|
||||
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (seqs[i].done) continue;
|
||||
|
|
@ -902,12 +897,38 @@ static std::vector<std::string> generate_phase1_batch(
|
|||
seqs[i].fsm.apply_mask(lc);
|
||||
|
||||
// After </think>: audio code constraint unless lyrics_mode
|
||||
if (seqs[i].codes_phase && !lyrics_mode) {
|
||||
for (int v = 0; v < AUDIO_CODE_BASE; v++)
|
||||
if (v != TOKEN_IM_END) lc[v] = -1e9f;
|
||||
// if (seqs[i].codes_phase && !lyrics_mode) {
|
||||
// for (int v = 0; v < AUDIO_CODE_BASE; v++)
|
||||
// if (v != TOKEN_IM_END) lc[v] = -1e9f;
|
||||
// }
|
||||
|
||||
// kcpp: prevent outputting audio codes
|
||||
for (int v = AUDIO_CODE_BASE; v < AUDIO_CODE_COUNT+AUDIO_CODE_BASE; v++)
|
||||
if (v != TOKEN_IM_END) lc[v] = -1e9f;
|
||||
|
||||
int tok = kcpp_quick_sample(lc,V,quicklastntoks,1.03f,top_p,25,temperature,acestep_lm_rng);
|
||||
quicklastntoks.push_back(tok);
|
||||
if (quicklastntoks.size()>32) {
|
||||
quicklastntoks.erase(quicklastntoks.begin());
|
||||
}
|
||||
|
||||
int tok = sample_top_k_p(lc, V, temperature, top_p, top_k, seqs[i].rng);
|
||||
//kcpp: force lyrics tokens right after think
|
||||
if(forced_tokens.size()>0)
|
||||
{
|
||||
tok = forced_tokens[0];
|
||||
forced_tokens.erase(forced_tokens.begin());
|
||||
}
|
||||
if (tok == TOKEN_THINK_END)
|
||||
{
|
||||
forced_tokens.clear();
|
||||
forced_tokens = think_chain;
|
||||
}
|
||||
|
||||
if(acestep_lm_dbg)
|
||||
{
|
||||
std::string tks = bpe_decode(*bpe,std::vector<int>({tok}));
|
||||
printf("\nDebug temp: %f, top_p:%f, tok:%d = %s (%d)",temperature,top_p,tok,tks.c_str(),forced_tokens.size());
|
||||
}
|
||||
|
||||
if (tok == TOKEN_IM_END) {
|
||||
seqs[i].done = true;
|
||||
|
|
@ -955,7 +976,7 @@ static std::vector<std::string> generate_phase1_batch(
|
|||
// Returns N code strings. Seeds = base_seed + 0, 1, ..., N-1.
|
||||
static std::vector<std::string> run_phase2_batch(
|
||||
Qwen3LM * m, BPETokenizer & bpe, const std::vector<AcePrompt> & aces,
|
||||
float temperature, float top_p, int top_k, long long base_seed, int N,
|
||||
float temperature, float top_p, long long base_seed, int N,
|
||||
float cfg_scale, const char * negative_prompt) {
|
||||
|
||||
int V = m->cfg.vocab_size;
|
||||
|
|
@ -1042,7 +1063,7 @@ static std::vector<std::string> run_phase2_batch(
|
|||
for (int v = 0; v < AUDIO_CODE_BASE; v++)
|
||||
if (v != TOKEN_IM_END) lg[v] = -1e9f;
|
||||
|
||||
int tok = sample_top_k_p(lg.data(), V, temperature, top_p, top_k, seqs[i].rng);
|
||||
int tok = kcpp_quick_sample(lg.data(),V,std::vector<int32_t>(),1.00f,top_p,25,temperature,acestep_lm_rng);
|
||||
seqs[i].last_token = tok;
|
||||
|
||||
if (tok == TOKEN_IM_END) {
|
||||
|
|
@ -1065,17 +1086,6 @@ static std::vector<std::string> run_phase2_batch(
|
|||
std::vector<float> logits_uncond(V * N);
|
||||
std::vector<int> tokens(N);
|
||||
|
||||
// CFG: single forward with 2*N (cond + uncond)
|
||||
int N2 = use_cfg ? 2 * N : N;
|
||||
std::vector<int> tokens_2n(N2), sets_2n(N2);
|
||||
std::vector<float> logits_2n((size_t)V * N2);
|
||||
if (use_cfg) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
sets_2n[i] = cond_sets[i];
|
||||
sets_2n[N + i] = uncond_sets[i];
|
||||
}
|
||||
}
|
||||
|
||||
int n_active = N;
|
||||
for (int i = 0; i < N; i++)
|
||||
if (seqs[i].done) n_active--;
|
||||
|
|
@ -1085,18 +1095,12 @@ static std::vector<std::string> run_phase2_batch(
|
|||
for (int i = 0; i < N; i++)
|
||||
tokens[i] = seqs[i].last_token;
|
||||
|
||||
if (use_cfg) {
|
||||
// Single batched forward: cond[0..N-1] + uncond[N..2N-1]
|
||||
for (int i = 0; i < N; i++) {
|
||||
tokens_2n[i] = tokens[i];
|
||||
tokens_2n[N + i] = tokens[i];
|
||||
}
|
||||
qw3lm_forward_batch(m, tokens_2n.data(), sets_2n.data(), N2, logits_2n.data());
|
||||
memcpy(logits_cond.data(), logits_2n.data(), (size_t)V * N * sizeof(float));
|
||||
memcpy(logits_uncond.data(), logits_2n.data() + (size_t)V * N, (size_t)V * N * sizeof(float));
|
||||
} else {
|
||||
qw3lm_forward_batch(m, tokens.data(), cond_sets.data(), N, logits_cond.data());
|
||||
}
|
||||
// Batched forward: cond
|
||||
qw3lm_forward_batch(m, tokens.data(), cond_sets.data(), N, logits_cond.data());
|
||||
|
||||
// Batched forward: uncond
|
||||
if (use_cfg)
|
||||
qw3lm_forward_batch(m, tokens.data(), uncond_sets.data(), N, logits_uncond.data());
|
||||
|
||||
// Per-sequence: CFG combine + sample
|
||||
for (int i = 0; i < N; i++) {
|
||||
|
|
@ -1113,7 +1117,7 @@ static std::vector<std::string> run_phase2_batch(
|
|||
for (int v = 0; v < AUDIO_CODE_BASE; v++)
|
||||
if (v != TOKEN_IM_END) lc[v] = -1e9f;
|
||||
|
||||
int tok = sample_top_k_p(lc, V, temperature, top_p, top_k, seqs[i].rng);
|
||||
int tok = kcpp_quick_sample(lc,V,std::vector<int32_t>(),1.00f,top_p,25,temperature,acestep_lm_rng);
|
||||
seqs[i].last_token = tok;
|
||||
|
||||
if (tok == TOKEN_IM_END) {
|
||||
|
|
@ -1436,8 +1440,9 @@ void unload_acestep_lm()
|
|||
}
|
||||
}
|
||||
|
||||
bool load_acestep_lm(std::string model_path, bool lowvram)
|
||||
bool load_acestep_lm(std::string model_path, bool lowvram, bool musicdebugmode)
|
||||
{
|
||||
acestep_lm_dbg = musicdebugmode;
|
||||
if(acestep_lm_loaded)
|
||||
{
|
||||
unload_acestep_lm();
|
||||
|
|
@ -1465,7 +1470,7 @@ std::string acestep_prepare_request(const music_generation_inputs inputs)
|
|||
if(!acestep_lm_loaded && acestep_lm_path!="")
|
||||
{
|
||||
printf("\nRuntime reload Music LM model...\n");
|
||||
bool ok = load_acestep_lm(acestep_lm_path, acestep_lm_lowvram);
|
||||
bool ok = load_acestep_lm(acestep_lm_path, acestep_lm_lowvram, acestep_lm_dbg);
|
||||
if(!ok)
|
||||
{
|
||||
printf("\nERROR: Acestep LM load fail\n");
|
||||
|
|
@ -1495,6 +1500,11 @@ std::string acestep_prepare_request(const music_generation_inputs inputs)
|
|||
seed = (((uint32_t)time(NULL)) % 1000000u);
|
||||
}
|
||||
req.seed = seed;
|
||||
acestep_lm_rng = std::mt19937(seed);
|
||||
|
||||
if (req.caption.empty()) {
|
||||
req.caption = "An interesting song";
|
||||
}
|
||||
|
||||
// Generation params from request
|
||||
float temperature = req.lm_temperature;
|
||||
|
|
@ -1518,7 +1528,7 @@ std::string acestep_prepare_request(const music_generation_inputs inputs)
|
|||
req.audio_codes = "";
|
||||
|
||||
bool user_has_codes = !req.audio_codes.empty();
|
||||
bool need_lm_codes = req.thinking && !user_has_codes;
|
||||
bool need_lm_codes = false;//req.thinking && !user_has_codes;
|
||||
|
||||
bool is_simple = ace.lyrics.empty();
|
||||
|
||||
|
|
@ -1547,7 +1557,7 @@ std::string acestep_prepare_request(const music_generation_inputs inputs)
|
|||
prompt.size(), batch_size, seed, seed + batch_size - 1);
|
||||
|
||||
auto phase1_texts = generate_phase1_batch(
|
||||
&acestep_llm, &acestep_bpe, prompt, 2048, temperature, 0.95f, 40,
|
||||
&acestep_llm, &acestep_bpe, prompt, 2048, temperature, top_p,
|
||||
seed, batch_size, use_fsm ? &fsm : nullptr, true);
|
||||
|
||||
parse_phase1_into_aces(phase1_texts, ace, aces, seed, "Simple", true);
|
||||
|
|
@ -1572,7 +1582,7 @@ std::string acestep_prepare_request(const music_generation_inputs inputs)
|
|||
|
||||
fsm.reset();
|
||||
auto phase1_texts = generate_phase1_batch(
|
||||
&acestep_llm, &acestep_bpe, prompt, 2048, temperature, top_p, top_k,
|
||||
&acestep_llm, &acestep_bpe, prompt, 2048, temperature, top_p,
|
||||
seed, batch_size, use_fsm ? &fsm : nullptr, false,
|
||||
cfg_scale, uncond.empty() ? nullptr : &uncond, true);
|
||||
|
||||
|
|
@ -1590,7 +1600,7 @@ std::string acestep_prepare_request(const music_generation_inputs inputs)
|
|||
std::vector<std::string> batch_codes(batch_size);
|
||||
if (need_lm_codes) {
|
||||
batch_codes = run_phase2_batch(&acestep_llm, acestep_bpe, aces,
|
||||
temperature, top_p, top_k, seed, batch_size, cfg_scale, neg_prompt);
|
||||
temperature, top_p, seed, batch_size, cfg_scale, neg_prompt);
|
||||
} else {
|
||||
fprintf(stderr, "[Skip] %s, no code generation\n",
|
||||
user_has_codes ? "user codes present" : "thinking=false");
|
||||
|
|
@ -1607,7 +1617,6 @@ std::string acestep_prepare_request(const music_generation_inputs inputs)
|
|||
rr.timesignature = a.timesignature;
|
||||
rr.vocal_language = a.vocal_language;
|
||||
if (!batch_codes[0].empty()) rr.audio_codes = batch_codes[0];
|
||||
rr.seed = seed;
|
||||
|
||||
std::string prefix_erase = "# Lyric";
|
||||
// Check if the string is long enough and starts with the prefix
|
||||
|
|
@ -1615,6 +1624,12 @@ std::string acestep_prepare_request(const music_generation_inputs inputs)
|
|||
rr.lyrics = rr.lyrics.substr(prefix_erase.size()); // Returns a new string starting after the prefix
|
||||
}
|
||||
|
||||
prefix_erase = "keyscale:";
|
||||
// Check if the string is long enough and starts with the prefix
|
||||
if (rr.keyscale.size() >= prefix_erase.size() && rr.keyscale.compare(0, prefix_erase.size(), prefix_erase) == 0) {
|
||||
rr.keyscale = rr.keyscale.substr(prefix_erase.size()); // Returns a new string starting after the prefix
|
||||
}
|
||||
|
||||
//now convert to string
|
||||
std::ostringstream oss;
|
||||
oss << "{\n";
|
||||
|
|
|
|||
|
|
@ -774,10 +774,9 @@ std::string acestep_generate_audio(const music_generation_inputs inputs)
|
|||
guidance_scale = 1.0f;
|
||||
}
|
||||
|
||||
if (seed < 0) {
|
||||
std::random_device rd;
|
||||
seed = (long long)rd() << 32 | rd();
|
||||
if (seed < 0) seed = -seed;
|
||||
if (seed <= 0 || seed==0xFFFFFFFF)
|
||||
{
|
||||
seed = (((uint32_t)time(NULL)) % 1000000u);
|
||||
}
|
||||
fprintf(stderr, "[Pipeline] seed=%lld, steps=%d, guidance=%.1f, shift=%.1f, duration=%.1fs\n",
|
||||
seed, num_steps, guidance_scale, shift, duration);
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ bool musictype_load_model(const music_load_model_inputs inputs)
|
|||
musicllm_filename.c_str(),musicembedding_filename.c_str(),musicdiffusion_filename.c_str(),musicvae_filename.c_str());
|
||||
musicdebugmode = inputs.debugmode;
|
||||
|
||||
bool ok = load_acestep_lm(musicllm_filename,lowvram);
|
||||
bool ok = load_acestep_lm(musicllm_filename,lowvram,musicdebugmode);
|
||||
if (!ok) {
|
||||
printf("\nFailed to load Music Gen LM Model!\n");
|
||||
return false;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue