mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-07 00:41:50 +00:00
ace step codes generation now working
This commit is contained in:
parent
71d42fae85
commit
4be93db21c
4 changed files with 56 additions and 14 deletions
4
expose.h
4
expose.h
|
|
@ -341,11 +341,13 @@ struct music_load_model_inputs
|
|||
};
|
||||
struct music_generation_inputs
|
||||
{
|
||||
const char * prompt = nullptr;
|
||||
const bool is_codes = false; //if true, generate codes, else, generate diffusion music
|
||||
const char * caption = nullptr;
|
||||
};
|
||||
struct music_generation_outputs
|
||||
{
|
||||
int status = -1;
|
||||
const char * codes_json = "";
|
||||
};
|
||||
|
||||
extern std::string executable_path;
|
||||
|
|
|
|||
15
koboldcpp.py
15
koboldcpp.py
|
|
@ -450,10 +450,12 @@ class music_load_model_inputs(ctypes.Structure):
|
|||
("debugmode", ctypes.c_int)]
|
||||
|
||||
class music_generation_inputs(ctypes.Structure):
|
||||
_fields_ = [("prompt", ctypes.c_char_p)]
|
||||
_fields_ = [("is_codes", ctypes.c_bool),
|
||||
("caption", ctypes.c_char_p)]
|
||||
|
||||
class music_generation_outputs(ctypes.Structure):
|
||||
_fields_ = [("status", ctypes.c_int)]
|
||||
_fields_ = [("status", ctypes.c_int),
|
||||
("codes_json", ctypes.c_char_p)]
|
||||
|
||||
class StdoutRedirector:
|
||||
def __init__(self, writer):
|
||||
|
|
@ -2372,15 +2374,16 @@ def music_load_model(musicllm,musicembedding,musicdiffusion,musicvae):
|
|||
ret = handle.music_load_model(inputs)
|
||||
return ret
|
||||
|
||||
def music_generate(genparams):
|
||||
def music_generate_codes(genparams):
|
||||
global args
|
||||
prompt = genparams.get("prompt", "")
|
||||
caption = genparams.get("caption", "interesting music song")
|
||||
inputs = music_generation_inputs()
|
||||
inputs.prompt = prompt.encode("UTF-8")
|
||||
inputs.is_codes = True
|
||||
inputs.caption = caption.encode("UTF-8")
|
||||
ret = handle.music_generate(inputs)
|
||||
outstr = ""
|
||||
if ret.status==1:
|
||||
outstr = ret.data.decode("UTF-8","ignore")
|
||||
outstr = ret.codes_json.decode("UTF-8","ignore")
|
||||
return outstr
|
||||
|
||||
def tokenize_ids(countprompt,tcaddspecial):
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@
|
|||
#include <chrono>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <iomanip>
|
||||
|
||||
// Timer
|
||||
struct Timer {
|
||||
|
|
@ -1120,7 +1121,7 @@ bool load_acestep(std::string model_path)
|
|||
return true;
|
||||
}
|
||||
|
||||
AceRequest acestep_prepare_request()
|
||||
std::string acestep_prepare_request(const music_generation_inputs inputs)
|
||||
{
|
||||
const int batch_size = 1;
|
||||
bool use_fsm = true;
|
||||
|
|
@ -1132,7 +1133,7 @@ AceRequest acestep_prepare_request()
|
|||
// Read request and set essentials
|
||||
AceRequest req;
|
||||
request_init(&req);
|
||||
req.caption = "doom";
|
||||
req.caption = inputs.caption;
|
||||
req.lyrics = ""; //can be overridden or left auto
|
||||
req.inference_steps = 8;
|
||||
req.vocal_language = "en";
|
||||
|
|
@ -1254,7 +1255,33 @@ AceRequest acestep_prepare_request()
|
|||
if (!batch_codes[0].empty()) rr.audio_codes = batch_codes[0];
|
||||
rr.seed = seed;
|
||||
|
||||
return rr;
|
||||
//now convert to string
|
||||
std::ostringstream oss;
|
||||
oss << "{\n";
|
||||
oss << " \"caption\": \"" << json_escape(rr.caption) << "\",\n";
|
||||
oss << " \"lyrics\": \"" << json_escape(rr.lyrics) << "\",\n";
|
||||
if (rr.instrumental) {
|
||||
oss << " \"instrumental\": true,\n";
|
||||
}
|
||||
oss << " \"bpm\": " << rr.bpm << ",\n";
|
||||
oss << " \"duration\": " << std::fixed << std::setprecision(1) << rr.duration << ",\n";
|
||||
oss << " \"keyscale\": \"" << json_escape(rr.keyscale) << "\",\n";
|
||||
oss << " \"timesignature\": \"" << json_escape(rr.timesignature) << "\",\n";
|
||||
oss << " \"vocal_language\": \"" << json_escape(rr.vocal_language) << "\",\n";
|
||||
oss << " \"task_type\": \"" << json_escape(rr.task_type) << "\",\n";
|
||||
oss << " \"seed\": " << rr.seed << ",\n";
|
||||
oss << " \"thinking\": " << (rr.thinking ? "true" : "false") << ",\n";
|
||||
oss << " \"lm_temperature\": " << std::fixed << std::setprecision(2) << rr.lm_temperature << ",\n";
|
||||
oss << " \"lm_cfg_scale\": " << std::fixed << std::setprecision(1) << rr.lm_cfg_scale << ",\n";
|
||||
oss << " \"lm_top_p\": " << std::fixed << std::setprecision(2) << rr.lm_top_p << ",\n";
|
||||
oss << " \"lm_negative_prompt\": \"" << json_escape(rr.lm_negative_prompt) << "\",\n";
|
||||
oss << " \"inference_steps\": " << rr.inference_steps << ",\n";
|
||||
oss << " \"guidance_scale\": " << std::fixed << std::setprecision(1) << rr.guidance_scale << ",\n";
|
||||
oss << " \"shift\": " << std::fixed << std::setprecision(1) << rr.shift << ",\n";
|
||||
oss << " \"audio_codes\": \"" << json_escape(rr.audio_codes) << "\"\n";
|
||||
oss << "}\n";
|
||||
std::string output_json = oss.str();
|
||||
return output_json;
|
||||
}
|
||||
|
||||
void unload_acestep()
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ static bool music_is_quiet = false;
|
|||
static bool musicgen_loaded = false;
|
||||
static std::string musicvulkandeviceenv;
|
||||
|
||||
static std::string codes_json_str = "";
|
||||
|
||||
bool musictype_load_model(const music_load_model_inputs inputs)
|
||||
{
|
||||
music_is_quiet = inputs.quiet;
|
||||
|
|
@ -72,14 +74,22 @@ music_generation_outputs musictype_generate(const music_generation_inputs inputs
|
|||
{
|
||||
printf("\nWarning: KCPP music gen not initialized!\n");
|
||||
output.status = 0;
|
||||
output.codes_json = "";
|
||||
return output;
|
||||
}
|
||||
|
||||
if(!music_is_quiet)
|
||||
{
|
||||
printf("\nMusic Gen Generating...");
|
||||
if (inputs.is_codes) {
|
||||
if (!music_is_quiet) {
|
||||
printf("\nMusic Gen Generating Codes...");
|
||||
}
|
||||
codes_json_str = acestep_prepare_request(inputs);
|
||||
output.status = 1;
|
||||
output.codes_json = codes_json_str.c_str();
|
||||
if (!music_is_quiet) {
|
||||
printf("\nMusic Gen Codes Done:\n%s\n",codes_json_str.c_str());
|
||||
}
|
||||
} else {
|
||||
}
|
||||
|
||||
output.status = 1;
|
||||
return output;
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue