From c2b0cb26a8cd20a5ce7a6c516677787f37ef010a Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Mon, 23 Feb 2026 14:04:45 +0800 Subject: [PATCH] ace step codes api --- expose.h | 2 +- koboldcpp.py | 49 ++++++++++++++++++++++++++--- otherarch/acestep/ace-qwen3.cpp | 11 ++++--- otherarch/acestep/music_adapter.cpp | 7 +++++ otherarch/acestep/request.cpp | 15 +++++---- otherarch/acestep/request.h | 2 ++ 6 files changed, 70 insertions(+), 16 deletions(-) diff --git a/expose.h b/expose.h index 287476f0c..9c8e4e0a6 100644 --- a/expose.h +++ b/expose.h @@ -342,7 +342,7 @@ struct music_load_model_inputs struct music_generation_inputs { const bool is_codes = false; //if true, generate codes, else, generate diffusion music - const char * caption = nullptr; + const char * input_json = nullptr; }; struct music_generation_outputs { diff --git a/koboldcpp.py b/koboldcpp.py index 358a4a99f..d95cb5850 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -451,7 +451,7 @@ class music_load_model_inputs(ctypes.Structure): class music_generation_inputs(ctypes.Structure): _fields_ = [("is_codes", ctypes.c_bool), - ("caption", ctypes.c_char_p)] + ("input_json", ctypes.c_char_p)] class music_generation_outputs(ctypes.Structure): _fields_ = [("status", ctypes.c_int), @@ -2376,16 +2376,20 @@ def music_load_model(musicllm,musicembedding,musicdiffusion,musicvae): def music_generate_codes(genparams): global args - caption = genparams.get("caption", "interesting music song") + input_json = json.dumps(genparams) inputs = music_generation_inputs() inputs.is_codes = True - inputs.caption = caption.encode("UTF-8") + inputs.input_json = input_json.encode("UTF-8") ret = handle.music_generate(inputs) outstr = "" if ret.status==1: outstr = ret.codes_json.decode("UTF-8","ignore") + outstr = json.dumps(json.loads(outstr)) return outstr +def music_generate_audio(genparams): + return "" + def tokenize_ids(countprompt,tcaddspecial): rawcountdata = handle.token_count(countprompt.encode("UTF-8"),tcaddspecial) countlimit = rawcountdata.count if (rawcountdata.count>=0 and rawcountdata.count<50000) else 0 @@ -4724,6 +4728,8 @@ Change Mode
is_transcribe = False is_tts = False is_embeddings = False + is_music_codes = False + is_music_audio = False response_body = None use_jinja = args.jinja @@ -4820,13 +4826,17 @@ Change Mode
is_tts = True elif self.path.endswith('/api/extra/embeddings') or self.path.endswith('/v1/embeddings'): is_embeddings = True + elif self.path.endswith('/api/extra/music/prepare'): + is_music_codes = True + elif self.path.endswith('/api/extra/music/generate'): + is_music_audio = True if response_body is not None: self.send_response(response_code) self.send_header('content-length', str(len(response_body))) self.end_headers(content_type='application/json') self.wfile.write(response_body) - elif is_imggen or is_img_upscale or is_transcribe or is_tts or is_embeddings or api_format > 0: + elif is_imggen or is_img_upscale or is_transcribe or is_tts or is_embeddings or is_music_codes or is_music_audio or api_format > 0: global last_req_time last_req_time = time.time() @@ -5131,6 +5141,37 @@ Change Mode
print("Create Embeddings: The response could not be sent, maybe connection was terminated?") time.sleep(0.2) #short delay return + elif is_music_codes: + try: + gendat = music_generate_codes(genparams) + genresp = (json.dumps({"error":"music code generation failed"}).encode()) + if gendat: + genresp = gendat.encode() + self.send_response(200) + self.send_header('content-length', str(len(genresp))) + self.end_headers(content_type='application/json') + self.wfile.write(genresp) + except Exception as ex: + utfprint(ex,1) + print("Music Gen Codes: The response could not be sent, maybe connection was terminated?") + time.sleep(0.2) #short delay + return + elif is_music_audio: + try: + gendat = music_generate_audio(genparams) + wav_data = b'' + if gendat: + wav_data = base64.b64decode(gendat) # Decode the Base64 string into binary data + self.send_response(200) + self.send_header('content-length', str(len(wav_data))) # Set content length + self.send_header('Content-Disposition', 'attachment; filename="output.wav"') + self.end_headers(content_type='audio/wav') + self.wfile.write(wav_data) # Write the binary WAV data to the response + except Exception as ex: + utfprint(ex,1) + print("Music Gen Audio: The response could not be sent, maybe connection was terminated?") + time.sleep(0.2) #short delay + return finally: time.sleep(0.05) diff --git a/otherarch/acestep/ace-qwen3.cpp b/otherarch/acestep/ace-qwen3.cpp index e5a715bd5..e24af1b43 100644 --- a/otherarch/acestep/ace-qwen3.cpp +++ b/otherarch/acestep/ace-qwen3.cpp @@ -1132,11 +1132,12 @@ std::string acestep_prepare_request(const music_generation_inputs inputs) // Read request and set essentials AceRequest req; - request_init(&req); - req.caption = inputs.caption; - req.lyrics = ""; //can be overridden or left auto - req.inference_steps = 8; - req.vocal_language = "en"; + std::string injson = inputs.input_json; + if (!request_parse_from_str(&req, injson)) + { + fprintf(stderr, "\nMusic JSON parse error\n"); + return ""; + } int seed = req.seed; if (seed <= 0 || seed==0xFFFFFFFF) diff --git a/otherarch/acestep/music_adapter.cpp b/otherarch/acestep/music_adapter.cpp index e9b600c59..4e89fa85e 100644 --- a/otherarch/acestep/music_adapter.cpp +++ b/otherarch/acestep/music_adapter.cpp @@ -83,6 +83,13 @@ music_generation_outputs musictype_generate(const music_generation_inputs inputs printf("\nMusic Gen Generating Codes..."); } codes_json_str = acestep_prepare_request(inputs); + if(codes_json_str=="") + { + printf("\nMusic codes generation failed!\n"); + output.status = 0; + output.codes_json = ""; + return output; + } output.status = 1; output.codes_json = codes_json_str.c_str(); if (!music_is_quiet) { diff --git a/otherarch/acestep/request.cpp b/otherarch/acestep/request.cpp index cc6a45b5a..9c5b430e0 100644 --- a/otherarch/acestep/request.cpp +++ b/otherarch/acestep/request.cpp @@ -198,19 +198,22 @@ static std::string read_file(const char * path) { return buf; } -// Public API -bool request_parse(AceRequest * r, const char * path) { - request_init(r); - +bool request_parse(AceRequest * r, const char * path) +{ std::string json = read_file(path); if (json.empty()) { fprintf(stderr, "[Request] ERROR: cannot read %s\n", path); return false; } + return request_parse_from_str(r, json); +} +// Public API +bool request_parse_from_str(AceRequest * r, std::string json) { + request_init(r); std::vector pairs; if (!parse_json_flat(json.c_str(), &pairs)) { - fprintf(stderr, "[Request] ERROR: malformed JSON in %s\n", path); + fprintf(stderr, "[Request] ERROR: malformed JSON\n"); return false; } @@ -247,7 +250,7 @@ bool request_parse(AceRequest * r, const char * path) { // unknown keys: silently ignored (forward compat) } - fprintf(stderr, "[Request] parsed %s (%zu fields)\n", path, pairs.size()); + fprintf(stderr, "[Request] parsed json (%zu fields)\n", pairs.size()); return true; } diff --git a/otherarch/acestep/request.h b/otherarch/acestep/request.h index f9e540988..1e9d5d723 100644 --- a/otherarch/acestep/request.h +++ b/otherarch/acestep/request.h @@ -50,6 +50,8 @@ void request_init(AceRequest * r); // Returns false on file error or malformed JSON. bool request_parse(AceRequest * r, const char * path); +bool request_parse_from_str(AceRequest * r, std::string json); + // Write struct to JSON file (overwrites). Returns false on file error. bool request_write(const AceRequest * r, const char * path);