refactor - do not use a copy buffer to store generation outputs; instead, return a pointer to C++-allocated memory

This commit is contained in:
Concedo 2024-02-29 14:02:20 +08:00
parent f75e479db0
commit 524ba12abd
6 changed files with 32 additions and 26 deletions

View file

@ -206,18 +206,18 @@ extern "C"
} }
} }
generation_outputs generate(const generation_inputs inputs, generation_outputs &output) generation_outputs generate(const generation_inputs inputs)
{ {
return gpttype_generate(inputs, output); return gpttype_generate(inputs);
} }
bool load_model_sd(const load_sd_model_inputs inputs) bool load_model_sd(const load_sd_model_inputs inputs)
{ {
return sdtype_load_model(inputs); return sdtype_load_model(inputs);
} }
sd_generation_outputs generate_sd(const sd_generation_inputs inputs, sd_generation_outputs &output) sd_generation_outputs generate_sd(const sd_generation_inputs inputs)
{ {
return sdtype_generate(inputs, output); return sdtype_generate(inputs);
} }
const char * new_token(int idx) { const char * new_token(int idx) {

View file

@ -92,7 +92,7 @@ struct generation_inputs
struct generation_outputs struct generation_outputs
{ {
int status = -1; int status = -1;
char text[24576]; //24kb should be enough for any response const char * text; //response will now be stored in c++ allocated memory
}; };
struct token_count_outputs struct token_count_outputs
{ {
@ -115,7 +115,7 @@ struct sd_generation_inputs
struct sd_generation_outputs struct sd_generation_outputs
{ {
int status = -1; int status = -1;
char data[24576]; const char * data;
}; };
extern std::string executable_path; extern std::string executable_path;

View file

@ -94,7 +94,8 @@ static int remaining_tokens = 0;
static int stopper_unused_tokens = 0; static int stopper_unused_tokens = 0;
static std::mutex concat_output_mtx; static std::mutex concat_output_mtx;
static std::string concat_output = ""; static std::string concat_output = "";
static std::string concat_output_reader_copy = ""; static std::string concat_output_reader_copy_poll = ""; //for streaming
static std::string concat_output_reader_copy_res = ""; //for gen response
static std::vector<logit_bias> logit_biases; static std::vector<logit_bias> logit_biases;
const int extra_context_handle_fragmentation = 80; const int extra_context_handle_fragmentation = 80;
@ -1469,12 +1470,12 @@ const std::string & gpttype_get_pending_output()
if(kcpp_params==nullptr) if(kcpp_params==nullptr)
{ {
printf("\nWarning: KCPP not initialized!\n"); printf("\nWarning: KCPP not initialized!\n");
return concat_output_reader_copy; return concat_output_reader_copy_poll;
} }
concat_output_mtx.lock(); concat_output_mtx.lock();
concat_output_reader_copy = concat_output; concat_output_reader_copy_poll = concat_output;
concat_output_mtx.unlock(); concat_output_mtx.unlock();
return concat_output_reader_copy; return concat_output_reader_copy_poll;
} }
int GetThreadsToUse(bool blasmode) int GetThreadsToUse(bool blasmode)
@ -1493,12 +1494,14 @@ int GetThreadsToUse(bool blasmode)
return kcpp_params->n_threads; return kcpp_params->n_threads;
} }
generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output) generation_outputs gpttype_generate(const generation_inputs inputs)
{ {
generation_outputs output;
if(kcpp_params==nullptr) if(kcpp_params==nullptr)
{ {
printf("\nWarning: KCPP not initialized!\n"); printf("\nWarning: KCPP not initialized!\n");
snprintf(output.text, sizeof(output.text), "%s", ""); output.text = nullptr;
output.status = 0; output.status = 0;
generation_finished = true; generation_finished = true;
return output; return output;
@ -1511,7 +1514,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
concat_output_mtx.lock(); concat_output_mtx.lock();
concat_output = ""; concat_output = "";
concat_output_reader_copy = ""; concat_output_reader_copy_poll = "";
concat_output_reader_copy_res = "";
concat_output_mtx.unlock(); concat_output_mtx.unlock();
last_stop_reason = stop_reason::OUT_OF_TOKENS; last_stop_reason = stop_reason::OUT_OF_TOKENS;
stop_sequence.clear(); stop_sequence.clear();
@ -1897,7 +1901,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
if (!evalres) if (!evalres)
{ {
fprintf(stderr, "\nFailed to predict at %d! Check your context buffer sizes!\n",n_past); fprintf(stderr, "\nFailed to predict at %d! Check your context buffer sizes!\n",n_past);
snprintf(output.text, sizeof(output.text), "%s", ""); output.text = nullptr;
output.status = 0; output.status = 0;
generation_finished = true; generation_finished = true;
return output; return output;
@ -2092,7 +2096,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
last_token_count = realnpredict; last_token_count = realnpredict;
last_seed = kcpp_params->seed; last_seed = kcpp_params->seed;
total_gens += 1; total_gens += 1;
snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str()); concat_output_mtx.lock();
concat_output_reader_copy_res = concat_output;
concat_output_mtx.unlock();
output.text = concat_output_reader_copy_res.c_str();
return output; return output;
} }

View file

@ -87,7 +87,7 @@ class generation_inputs(ctypes.Structure):
class generation_outputs(ctypes.Structure): class generation_outputs(ctypes.Structure):
_fields_ = [("status", ctypes.c_int), _fields_ = [("status", ctypes.c_int),
("text", ctypes.c_char * 24576)] ("text", ctypes.c_char_p)]
class token_count_outputs(ctypes.Structure): class token_count_outputs(ctypes.Structure):
_fields_ = [("count", ctypes.c_int), _fields_ = [("count", ctypes.c_int),
@ -242,7 +242,7 @@ def init_library():
handle.load_model.argtypes = [load_model_inputs] handle.load_model.argtypes = [load_model_inputs]
handle.load_model.restype = ctypes.c_bool handle.load_model.restype = ctypes.c_bool
handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever handle.generate.argtypes = [generation_inputs]
handle.generate.restype = generation_outputs handle.generate.restype = generation_outputs
handle.new_token.restype = ctypes.c_char_p handle.new_token.restype = ctypes.c_char_p
handle.new_token.argtypes = [ctypes.c_int] handle.new_token.argtypes = [ctypes.c_int]
@ -350,7 +350,6 @@ def load_model(model_filename):
def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}): def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}):
global maxctx, args, currentusergenkey, totalgens, pendingabortkey global maxctx, args, currentusergenkey, totalgens, pendingabortkey
inputs = generation_inputs() inputs = generation_inputs()
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
inputs.prompt = prompt.encode("UTF-8") inputs.prompt = prompt.encode("UTF-8")
inputs.memory = memory.encode("UTF-8") inputs.memory = memory.encode("UTF-8")
if max_length >= (max_context_length-1): if max_length >= (max_context_length-1):
@ -438,7 +437,7 @@ def generate(prompt, memory="", max_length=32, max_context_length=512, temperatu
pendingabortkey = "" pendingabortkey = ""
return "" return ""
else: else:
ret = handle.generate(inputs,outputs) ret = handle.generate(inputs)
outstr = "" outstr = ""
if ret.status==1: if ret.status==1:
outstr = ret.text.decode("UTF-8","ignore") outstr = ret.text.decode("UTF-8","ignore")

View file

@ -73,13 +73,13 @@ enum ModelLoadResult
}; };
ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format, FileFormatExtraMeta file_format_meta); ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format, FileFormatExtraMeta file_format_meta);
generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output); generation_outputs gpttype_generate(const generation_inputs inputs);
bool gpttype_generate_abort(); bool gpttype_generate_abort();
const std::string & gpttype_get_pending_output(); const std::string & gpttype_get_pending_output();
std::vector<int> gpttype_get_token_arr(const std::string & input); std::vector<int> gpttype_get_token_arr(const std::string & input);
bool sdtype_load_model(const load_sd_model_inputs inputs); bool sdtype_load_model(const load_sd_model_inputs inputs);
sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs, sd_generation_outputs &output); sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs);
void timer_start(); void timer_start();
double timer_check(); double timer_check();

View file

@ -144,12 +144,13 @@ bool sdtype_load_model(const load_sd_model_inputs inputs) {
} }
sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs, sd_generation_outputs &output) sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
{ {
sd_generation_outputs output;
if(sd_ctx == nullptr || sd_params == nullptr) if(sd_ctx == nullptr || sd_params == nullptr)
{ {
printf("\nError: KCPP SD is not initialized!\n"); printf("\nError: KCPP SD is not initialized!\n");
snprintf(output.data, sizeof(output.data), "%s", ""); output.data = nullptr;
output.status = 0; output.status = 0;
return output; return output;
} }
@ -208,7 +209,7 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs, sd_gene
if (results == NULL) { if (results == NULL) {
printf("\nKCPP SD generate failed!\n"); printf("\nKCPP SD generate failed!\n");
snprintf(output.data, sizeof(output.data), "%s", ""); output.data = nullptr;
output.status = 0; output.status = 0;
return output; return output;
} }
@ -230,7 +231,7 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs, sd_gene
free(results); free(results);
snprintf(output.data, sizeof(output.data), "%s", ""); output.data = nullptr;
output.status = 1; output.status = 1;
return output; return output;
} }