mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-07 00:41:50 +00:00
ace step codes api
This commit is contained in:
parent
d100c8660e
commit
c2b0cb26a8
6 changed files with 70 additions and 16 deletions
2
expose.h
2
expose.h
|
|
@ -342,7 +342,7 @@ struct music_load_model_inputs
|
|||
struct music_generation_inputs
|
||||
{
|
||||
const bool is_codes = false; //if true, generate codes, else, generate diffusion music
|
||||
const char * caption = nullptr;
|
||||
const char * input_json = nullptr;
|
||||
};
|
||||
struct music_generation_outputs
|
||||
{
|
||||
|
|
|
|||
49
koboldcpp.py
49
koboldcpp.py
|
|
@ -451,7 +451,7 @@ class music_load_model_inputs(ctypes.Structure):
|
|||
|
||||
class music_generation_inputs(ctypes.Structure):
|
||||
_fields_ = [("is_codes", ctypes.c_bool),
|
||||
("caption", ctypes.c_char_p)]
|
||||
("input_json", ctypes.c_char_p)]
|
||||
|
||||
class music_generation_outputs(ctypes.Structure):
|
||||
_fields_ = [("status", ctypes.c_int),
|
||||
|
|
@ -2376,16 +2376,20 @@ def music_load_model(musicllm,musicembedding,musicdiffusion,musicvae):
|
|||
|
||||
def music_generate_codes(genparams):
|
||||
global args
|
||||
caption = genparams.get("caption", "interesting music song")
|
||||
input_json = json.dumps(genparams)
|
||||
inputs = music_generation_inputs()
|
||||
inputs.is_codes = True
|
||||
inputs.caption = caption.encode("UTF-8")
|
||||
inputs.input_json = input_json.encode("UTF-8")
|
||||
ret = handle.music_generate(inputs)
|
||||
outstr = ""
|
||||
if ret.status==1:
|
||||
outstr = ret.codes_json.decode("UTF-8","ignore")
|
||||
outstr = json.dumps(json.loads(outstr))
|
||||
return outstr
|
||||
|
||||
def music_generate_audio(genparams):
|
||||
return ""
|
||||
|
||||
def tokenize_ids(countprompt,tcaddspecial):
|
||||
rawcountdata = handle.token_count(countprompt.encode("UTF-8"),tcaddspecial)
|
||||
countlimit = rawcountdata.count if (rawcountdata.count>=0 and rawcountdata.count<50000) else 0
|
||||
|
|
@ -4724,6 +4728,8 @@ Change Mode<br>
|
|||
is_transcribe = False
|
||||
is_tts = False
|
||||
is_embeddings = False
|
||||
is_music_codes = False
|
||||
is_music_audio = False
|
||||
response_body = None
|
||||
use_jinja = args.jinja
|
||||
|
||||
|
|
@ -4820,13 +4826,17 @@ Change Mode<br>
|
|||
is_tts = True
|
||||
elif self.path.endswith('/api/extra/embeddings') or self.path.endswith('/v1/embeddings'):
|
||||
is_embeddings = True
|
||||
elif self.path.endswith('/api/extra/music/prepare'):
|
||||
is_music_codes = True
|
||||
elif self.path.endswith('/api/extra/music/generate'):
|
||||
is_music_audio = True
|
||||
|
||||
if response_body is not None:
|
||||
self.send_response(response_code)
|
||||
self.send_header('content-length', str(len(response_body)))
|
||||
self.end_headers(content_type='application/json')
|
||||
self.wfile.write(response_body)
|
||||
elif is_imggen or is_img_upscale or is_transcribe or is_tts or is_embeddings or api_format > 0:
|
||||
elif is_imggen or is_img_upscale or is_transcribe or is_tts or is_embeddings or is_music_codes or is_music_audio or api_format > 0:
|
||||
global last_req_time
|
||||
last_req_time = time.time()
|
||||
|
||||
|
|
@ -5131,6 +5141,37 @@ Change Mode<br>
|
|||
print("Create Embeddings: The response could not be sent, maybe connection was terminated?")
|
||||
time.sleep(0.2) #short delay
|
||||
return
|
||||
elif is_music_codes:
|
||||
try:
|
||||
gendat = music_generate_codes(genparams)
|
||||
genresp = (json.dumps({"error":"music code generation failed"}).encode())
|
||||
if gendat:
|
||||
genresp = gendat.encode()
|
||||
self.send_response(200)
|
||||
self.send_header('content-length', str(len(genresp)))
|
||||
self.end_headers(content_type='application/json')
|
||||
self.wfile.write(genresp)
|
||||
except Exception as ex:
|
||||
utfprint(ex,1)
|
||||
print("Music Gen Codes: The response could not be sent, maybe connection was terminated?")
|
||||
time.sleep(0.2) #short delay
|
||||
return
|
||||
elif is_music_audio:
|
||||
try:
|
||||
gendat = music_generate_audio(genparams)
|
||||
wav_data = b''
|
||||
if gendat:
|
||||
wav_data = base64.b64decode(gendat) # Decode the Base64 string into binary data
|
||||
self.send_response(200)
|
||||
self.send_header('content-length', str(len(wav_data))) # Set content length
|
||||
self.send_header('Content-Disposition', 'attachment; filename="output.wav"')
|
||||
self.end_headers(content_type='audio/wav')
|
||||
self.wfile.write(wav_data) # Write the binary WAV data to the response
|
||||
except Exception as ex:
|
||||
utfprint(ex,1)
|
||||
print("Music Gen Audio: The response could not be sent, maybe connection was terminated?")
|
||||
time.sleep(0.2) #short delay
|
||||
return
|
||||
|
||||
finally:
|
||||
time.sleep(0.05)
|
||||
|
|
|
|||
|
|
@ -1132,11 +1132,12 @@ std::string acestep_prepare_request(const music_generation_inputs inputs)
|
|||
|
||||
// Read request and set essentials
|
||||
AceRequest req;
|
||||
request_init(&req);
|
||||
req.caption = inputs.caption;
|
||||
req.lyrics = ""; //can be overridden or left auto
|
||||
req.inference_steps = 8;
|
||||
req.vocal_language = "en";
|
||||
std::string injson = inputs.input_json;
|
||||
if (!request_parse_from_str(&req, injson))
|
||||
{
|
||||
fprintf(stderr, "\nMusic JSON parse error\n");
|
||||
return "";
|
||||
}
|
||||
|
||||
int seed = req.seed;
|
||||
if (seed <= 0 || seed==0xFFFFFFFF)
|
||||
|
|
|
|||
|
|
@ -83,6 +83,13 @@ music_generation_outputs musictype_generate(const music_generation_inputs inputs
|
|||
printf("\nMusic Gen Generating Codes...");
|
||||
}
|
||||
codes_json_str = acestep_prepare_request(inputs);
|
||||
if(codes_json_str=="")
|
||||
{
|
||||
printf("\nMusic codes generation failed!\n");
|
||||
output.status = 0;
|
||||
output.codes_json = "";
|
||||
return output;
|
||||
}
|
||||
output.status = 1;
|
||||
output.codes_json = codes_json_str.c_str();
|
||||
if (!music_is_quiet) {
|
||||
|
|
|
|||
|
|
@ -198,19 +198,22 @@ static std::string read_file(const char * path) {
|
|||
return buf;
|
||||
}
|
||||
|
||||
// Public API
|
||||
bool request_parse(AceRequest * r, const char * path) {
|
||||
request_init(r);
|
||||
|
||||
bool request_parse(AceRequest * r, const char * path)
|
||||
{
|
||||
std::string json = read_file(path);
|
||||
if (json.empty()) {
|
||||
fprintf(stderr, "[Request] ERROR: cannot read %s\n", path);
|
||||
return false;
|
||||
}
|
||||
return request_parse_from_str(r, json);
|
||||
}
|
||||
|
||||
// Public API
|
||||
bool request_parse_from_str(AceRequest * r, std::string json) {
|
||||
request_init(r);
|
||||
std::vector<JsonPair> pairs;
|
||||
if (!parse_json_flat(json.c_str(), &pairs)) {
|
||||
fprintf(stderr, "[Request] ERROR: malformed JSON in %s\n", path);
|
||||
fprintf(stderr, "[Request] ERROR: malformed JSON\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -247,7 +250,7 @@ bool request_parse(AceRequest * r, const char * path) {
|
|||
// unknown keys: silently ignored (forward compat)
|
||||
}
|
||||
|
||||
fprintf(stderr, "[Request] parsed %s (%zu fields)\n", path, pairs.size());
|
||||
fprintf(stderr, "[Request] parsed json (%zu fields)\n", pairs.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -50,6 +50,8 @@ void request_init(AceRequest * r);
|
|||
// Returns false on file error or malformed JSON.
|
||||
bool request_parse(AceRequest * r, const char * path);
|
||||
|
||||
bool request_parse_from_str(AceRequest * r, std::string json);
|
||||
|
||||
// Write struct to JSON file (overwrites). Returns false on file error.
|
||||
bool request_write(const AceRequest * r, const char * path);
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue