ace step codes generation now working

This commit is contained in:
Concedo 2026-02-23 00:27:26 +08:00
parent 71d42fae85
commit 4be93db21c
4 changed files with 56 additions and 14 deletions

View file

@ -341,11 +341,13 @@ struct music_load_model_inputs
};
struct music_generation_inputs
{
const char * prompt = nullptr;
const bool is_codes = false; //if true, generate codes, else, generate diffusion music
const char * caption = nullptr;
};
struct music_generation_outputs
{
int status = -1;
const char * codes_json = "";
};
extern std::string executable_path;

View file

@ -450,10 +450,12 @@ class music_load_model_inputs(ctypes.Structure):
("debugmode", ctypes.c_int)]
class music_generation_inputs(ctypes.Structure):
_fields_ = [("prompt", ctypes.c_char_p)]
_fields_ = [("is_codes", ctypes.c_bool),
("caption", ctypes.c_char_p)]
class music_generation_outputs(ctypes.Structure):
_fields_ = [("status", ctypes.c_int)]
_fields_ = [("status", ctypes.c_int),
("codes_json", ctypes.c_char_p)]
class StdoutRedirector:
def __init__(self, writer):
@ -2372,15 +2374,16 @@ def music_load_model(musicllm,musicembedding,musicdiffusion,musicvae):
ret = handle.music_load_model(inputs)
return ret
def music_generate(genparams):
def music_generate_codes(genparams):
global args
prompt = genparams.get("prompt", "")
caption = genparams.get("caption", "interesting music song")
inputs = music_generation_inputs()
inputs.prompt = prompt.encode("UTF-8")
inputs.is_codes = True
inputs.caption = caption.encode("UTF-8")
ret = handle.music_generate(inputs)
outstr = ""
if ret.status==1:
outstr = ret.data.decode("UTF-8","ignore")
outstr = ret.codes_json.decode("UTF-8","ignore")
return outstr
def tokenize_ids(countprompt,tcaddspecial):

View file

@ -16,6 +16,7 @@
#include <chrono>
#include <map>
#include <unordered_map>
#include <iomanip>
// Timer
struct Timer {
@ -1120,7 +1121,7 @@ bool load_acestep(std::string model_path)
return true;
}
AceRequest acestep_prepare_request()
std::string acestep_prepare_request(const music_generation_inputs inputs)
{
const int batch_size = 1;
bool use_fsm = true;
@ -1132,7 +1133,7 @@ AceRequest acestep_prepare_request()
// Read request and set essentials
AceRequest req;
request_init(&req);
req.caption = "doom";
req.caption = inputs.caption;
req.lyrics = ""; //can be overridden or left auto
req.inference_steps = 8;
req.vocal_language = "en";
@ -1254,7 +1255,33 @@ AceRequest acestep_prepare_request()
if (!batch_codes[0].empty()) rr.audio_codes = batch_codes[0];
rr.seed = seed;
return rr;
//now convert to string
std::ostringstream oss;
oss << "{\n";
oss << " \"caption\": \"" << json_escape(rr.caption) << "\",\n";
oss << " \"lyrics\": \"" << json_escape(rr.lyrics) << "\",\n";
if (rr.instrumental) {
oss << " \"instrumental\": true,\n";
}
oss << " \"bpm\": " << rr.bpm << ",\n";
oss << " \"duration\": " << std::fixed << std::setprecision(1) << rr.duration << ",\n";
oss << " \"keyscale\": \"" << json_escape(rr.keyscale) << "\",\n";
oss << " \"timesignature\": \"" << json_escape(rr.timesignature) << "\",\n";
oss << " \"vocal_language\": \"" << json_escape(rr.vocal_language) << "\",\n";
oss << " \"task_type\": \"" << json_escape(rr.task_type) << "\",\n";
oss << " \"seed\": " << rr.seed << ",\n";
oss << " \"thinking\": " << (rr.thinking ? "true" : "false") << ",\n";
oss << " \"lm_temperature\": " << std::fixed << std::setprecision(2) << rr.lm_temperature << ",\n";
oss << " \"lm_cfg_scale\": " << std::fixed << std::setprecision(1) << rr.lm_cfg_scale << ",\n";
oss << " \"lm_top_p\": " << std::fixed << std::setprecision(2) << rr.lm_top_p << ",\n";
oss << " \"lm_negative_prompt\": \"" << json_escape(rr.lm_negative_prompt) << "\",\n";
oss << " \"inference_steps\": " << rr.inference_steps << ",\n";
oss << " \"guidance_scale\": " << std::fixed << std::setprecision(1) << rr.guidance_scale << ",\n";
oss << " \"shift\": " << std::fixed << std::setprecision(1) << rr.shift << ",\n";
oss << " \"audio_codes\": \"" << json_escape(rr.audio_codes) << "\"\n";
oss << "}\n";
std::string output_json = oss.str();
return output_json;
}
void unload_acestep()

View file

@ -24,6 +24,8 @@ static bool music_is_quiet = false;
static bool musicgen_loaded = false;
static std::string musicvulkandeviceenv;
static std::string codes_json_str = "";
bool musictype_load_model(const music_load_model_inputs inputs)
{
music_is_quiet = inputs.quiet;
@ -72,14 +74,22 @@ music_generation_outputs musictype_generate(const music_generation_inputs inputs
{
printf("\nWarning: KCPP music gen not initialized!\n");
output.status = 0;
output.codes_json = "";
return output;
}
if(!music_is_quiet)
{
printf("\nMusic Gen Generating...");
if (inputs.is_codes) {
if (!music_is_quiet) {
printf("\nMusic Gen Generating Codes...");
}
codes_json_str = acestep_prepare_request(inputs);
output.status = 1;
output.codes_json = codes_json_str.c_str();
if (!music_is_quiet) {
printf("\nMusic Gen Codes Done:\n%s\n",codes_json_str.c_str());
}
} else {
}
output.status = 1;
return output;
}