Mirror of https://github.com/LostRuins/koboldcpp.git
refactor - do not use a copy buffer to store generation outputs, instead return a cpp allocated ptr
Commit 524ba12abd (parent f75e479db0)
6 changed files with 32 additions and 26 deletions
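In short: `generation_outputs` no longer carries a fixed 24 KB `char` buffer that each result is copied into; it now holds a `const char *` pointing at a C++-side `std::string` (`concat_output_reader_copy_res`) with static lifetime, and the generate functions return the struct by value instead of filling a caller-supplied reference. A minimal sketch of that pattern, with hypothetical names (`result_store`, `run_generation`) standing in for the repo's statics:

    #include <mutex>
    #include <string>

    struct generation_outputs_sketch {
        int status = -1;
        const char *text = nullptr;   // borrowed pointer into C++-owned storage
    };

    // Static storage keeps the pointed-to bytes alive after the call returns;
    // the mutex mirrors concat_output_mtx in the real code.
    static std::mutex result_mtx;
    static std::string result_store;

    generation_outputs_sketch run_generation(const std::string &generated)
    {
        generation_outputs_sketch output;       // returned by value, no out-parameter
        std::lock_guard<std::mutex> lock(result_mtx);
        result_store = generated;               // one copy into the long-lived string
        output.text = result_store.c_str();     // valid until the next generation overwrites it
        output.status = 1;
        return output;
    }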
@@ -206,18 +206,18 @@ extern "C"
     }
     }

-    generation_outputs generate(const generation_inputs inputs, generation_outputs &output)
+    generation_outputs generate(const generation_inputs inputs)
     {
-        return gpttype_generate(inputs, output);
+        return gpttype_generate(inputs);
     }

     bool load_model_sd(const load_sd_model_inputs inputs)
     {
         return sdtype_load_model(inputs);
     }
-    sd_generation_outputs generate_sd(const sd_generation_inputs inputs, sd_generation_outputs &output)
+    sd_generation_outputs generate_sd(const sd_generation_inputs inputs)
     {
-        return sdtype_generate(inputs, output);
+        return sdtype_generate(inputs);
     }

     const char * new_token(int idx) {
expose.h (4 changes)
@@ -92,7 +92,7 @@ struct generation_inputs
 struct generation_outputs
 {
     int status = -1;
-    char text[24576]; //24kb should be enough for any response
+    const char * text; //response will now be stored in c++ allocated memory
 };
 struct token_count_outputs
 {
@@ -115,7 +115,7 @@ struct sd_generation_inputs
 struct sd_generation_outputs
 {
     int status = -1;
-    char data[24576];
+    const char * data;
 };

 extern std::string executable_path;
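One consequence of the new `const char *` fields worth noting for consumers of expose.h: the pointer refers to memory owned by the library (a static `std::string` on the C++ side), so it is only guaranteed valid until the next generate call replaces that string, and it may be null on the failure paths shown below. A caller-side sketch, assuming expose.h is on the include path and using a hypothetical helper name:

    #include <string>
    #include "expose.h"   // declares generation_outputs with the new const char * field

    // Hypothetical helper: snapshot the borrowed pointer into owned storage
    // before the next generate call can overwrite the backing string.
    std::string copy_generation_text(const generation_outputs &out)
    {
        if (out.status != 1 || out.text == nullptr) {
            return std::string();   // failure paths set text to nullptr in this commit
        }
        return std::string(out.text);
    }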
@@ -94,7 +94,8 @@ static int remaining_tokens = 0;
 static int stopper_unused_tokens = 0;
 static std::mutex concat_output_mtx;
 static std::string concat_output = "";
-static std::string concat_output_reader_copy = "";
+static std::string concat_output_reader_copy_poll = ""; //for streaming
+static std::string concat_output_reader_copy_res = ""; //for gen response
 static std::vector<logit_bias> logit_biases;

 const int extra_context_handle_fragmentation = 80;
@@ -1469,12 +1470,12 @@ const std::string & gpttype_get_pending_output()
     if(kcpp_params==nullptr)
     {
         printf("\nWarning: KCPP not initialized!\n");
-        return concat_output_reader_copy;
+        return concat_output_reader_copy_poll;
     }
     concat_output_mtx.lock();
-    concat_output_reader_copy = concat_output;
+    concat_output_reader_copy_poll = concat_output;
     concat_output_mtx.unlock();
-    return concat_output_reader_copy;
+    return concat_output_reader_copy_poll;
 }

 int GetThreadsToUse(bool blasmode)
@@ -1493,12 +1494,14 @@ int GetThreadsToUse(bool blasmode)
     return kcpp_params->n_threads;
 }

-generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
+generation_outputs gpttype_generate(const generation_inputs inputs)
 {
+    generation_outputs output;
+
     if(kcpp_params==nullptr)
     {
         printf("\nWarning: KCPP not initialized!\n");
-        snprintf(output.text, sizeof(output.text), "%s", "");
+        output.text = nullptr;
         output.status = 0;
         generation_finished = true;
         return output;
@@ -1511,7 +1514,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o

     concat_output_mtx.lock();
     concat_output = "";
-    concat_output_reader_copy = "";
+    concat_output_reader_copy_poll = "";
+    concat_output_reader_copy_res = "";
     concat_output_mtx.unlock();
     last_stop_reason = stop_reason::OUT_OF_TOKENS;
     stop_sequence.clear();
@@ -1897,7 +1901,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     if (!evalres)
     {
         fprintf(stderr, "\nFailed to predict at %d! Check your context buffer sizes!\n",n_past);
-        snprintf(output.text, sizeof(output.text), "%s", "");
+        output.text = nullptr;
         output.status = 0;
         generation_finished = true;
         return output;
@@ -2092,7 +2096,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     last_token_count = realnpredict;
     last_seed = kcpp_params->seed;
     total_gens += 1;
-    snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());

+    concat_output_mtx.lock();
+    concat_output_reader_copy_res = concat_output;
+    concat_output_mtx.unlock();
+    output.text = concat_output_reader_copy_res.c_str();
     return output;
 }
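The reason the single `concat_output_reader_copy` becomes two strings in the hunks above is spelled out in the code comments: one copy (`_poll`) backs the streaming poll in `gpttype_get_pending_output()`, the other (`_res`) backs the final pointer handed out in `generation_outputs.text`. If both paths shared one string, a streaming poll arriving after generation finished could overwrite the bytes the caller's pointer still refers to. A minimal sketch of the two-copy idea, with hypothetical names standing in for those statics:

    #include <mutex>
    #include <string>

    static std::mutex out_mtx;
    static std::string live_output;        // appended to while generating (concat_output)
    static std::string poll_snapshot;      // read only by the streaming poller (_poll)
    static std::string result_snapshot;    // backs the final returned pointer (_res)

    // Streaming side: refresh and return its own snapshot only.
    const std::string &pending_output()
    {
        std::lock_guard<std::mutex> lock(out_mtx);
        poll_snapshot = live_output;
        return poll_snapshot;
    }

    // Generation side: the final pointer is backed by a snapshot the poller never touches.
    const char *final_output_ptr()
    {
        std::lock_guard<std::mutex> lock(out_mtx);
        result_snapshot = live_output;
        return result_snapshot.c_str();
    }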
@@ -87,7 +87,7 @@ class generation_inputs(ctypes.Structure):

 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
-                ("text", ctypes.c_char * 24576)]
+                ("text", ctypes.c_char_p)]

 class token_count_outputs(ctypes.Structure):
     _fields_ = [("count", ctypes.c_int),
@@ -242,7 +242,7 @@ def init_library():

     handle.load_model.argtypes = [load_model_inputs]
     handle.load_model.restype = ctypes.c_bool
-    handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever
+    handle.generate.argtypes = [generation_inputs]
     handle.generate.restype = generation_outputs
     handle.new_token.restype = ctypes.c_char_p
     handle.new_token.argtypes = [ctypes.c_int]
@@ -350,7 +350,6 @@ def load_model(model_filename):
 def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}):
     global maxctx, args, currentusergenkey, totalgens, pendingabortkey
     inputs = generation_inputs()
-    outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
     inputs.prompt = prompt.encode("UTF-8")
     inputs.memory = memory.encode("UTF-8")
     if max_length >= (max_context_length-1):
@@ -438,7 +437,7 @@ def generate(prompt, memory="", max_length=32, max_context_length=512, temperatu
         pendingabortkey = ""
         return ""
     else:
-        ret = handle.generate(inputs,outputs)
+        ret = handle.generate(inputs)
     outstr = ""
     if ret.status==1:
         outstr = ret.text.decode("UTF-8","ignore")
@@ -73,13 +73,13 @@ enum ModelLoadResult
 };

 ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format, FileFormatExtraMeta file_format_meta);
-generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output);
+generation_outputs gpttype_generate(const generation_inputs inputs);
 bool gpttype_generate_abort();
 const std::string & gpttype_get_pending_output();
 std::vector<int> gpttype_get_token_arr(const std::string & input);

 bool sdtype_load_model(const load_sd_model_inputs inputs);
-sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs, sd_generation_outputs &output);
+sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs);

 void timer_start();
 double timer_check();
@@ -144,12 +144,13 @@ bool sdtype_load_model(const load_sd_model_inputs inputs) {

 }

-sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs, sd_generation_outputs &output)
+sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
 {
+    sd_generation_outputs output;
     if(sd_ctx == nullptr || sd_params == nullptr)
     {
         printf("\nError: KCPP SD is not initialized!\n");
-        snprintf(output.data, sizeof(output.data), "%s", "");
+        output.data = nullptr;
         output.status = 0;
         return output;
     }
@@ -208,7 +209,7 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs, sd_gene

     if (results == NULL) {
         printf("\nKCPP SD generate failed!\n");
-        snprintf(output.data, sizeof(output.data), "%s", "");
+        output.data = nullptr;
         output.status = 0;
         return output;
     }
@@ -230,7 +231,7 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs, sd_gene

     free(results);

-    snprintf(output.data, sizeof(output.data), "%s", "");
+    output.data = nullptr;
     output.status = 1;
     return output;
 }
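In the stable-diffusion path every branch shown above, including the success branch, now sets `output.data = nullptr`, so `status` remains the field callers should branch on and `data` must be treated as possibly null. A defensive caller-side sketch, using a hypothetical helper name and the struct declared in expose.h:

    #include <string>
    #include "expose.h"   // declares sd_generation_outputs with the new const char * data field

    // Hypothetical helper: trust status, and only touch data after a null check.
    bool sd_generation_ok(const sd_generation_outputs &out, std::string &data_out)
    {
        if (out.status != 1) {
            return false;   // generation failed or KCPP SD was never initialized
        }
        data_out = (out.data != nullptr) ? std::string(out.data) : std::string();
        return true;
    }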