Mirror of https://github.com/LostRuins/koboldcpp.git
added extra endpoints for abort gen and polled streaming
This commit is contained in:
parent 5bd9cef9fa
commit 43f7e40470

5 changed files with 63 additions and 25 deletions
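The hunks below appear to come from the C++ adapter (gpttype_adapter.cpp); the server-side changes among the five touched files are not shown on this page. As a rough illustration of what the new entry points enable, here is a minimal, self-contained sketch of the polled-streaming-plus-abort pattern. Only the names remaining_tokens, concat_output, gpttype_generate_abort and gpttype_get_pending_output come from the diff; the worker thread, the fake token source, the mutex, and the copy-returning signature of gpttype_get_pending_output are illustrative assumptions rather than koboldcpp code (compile with -pthread):

#include <atomic>
#include <chrono>
#include <cstdio>
#include <mutex>
#include <string>
#include <thread>

static std::atomic<int> remaining_tokens{0};
static std::string concat_output;
static std::mutex output_mutex;

// Mirrors the function added in this commit: aborting just zeroes the counter
// that the generation loop counts down, so the loop exits on its next check.
bool gpttype_generate_abort()
{
    remaining_tokens = 0;
    return true;
}

// The real function returns a const std::string&; returning a copy under a
// lock here is a thread-safety simplification for this standalone sketch.
std::string gpttype_get_pending_output()
{
    std::lock_guard<std::mutex> lock(output_mutex);
    return concat_output;
}

// Hypothetical stand-in for gpttype_generate(): appends one fake "token"
// per step until the remaining-token budget is exhausted or aborted.
static void fake_generate(int n_predict)
{
    remaining_tokens = n_predict;
    while (remaining_tokens > 0)
    {
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
        std::lock_guard<std::mutex> lock(output_mutex);
        concat_output += "tok ";
        --remaining_tokens;
    }
}

int main()
{
    std::thread worker(fake_generate, 1000);

    // Polled streaming: periodically fetch whatever has been produced so far
    // instead of holding an SSE connection open.
    for (int i = 0; i < 5; ++i)
    {
        std::this_thread::sleep_for(std::chrono::milliseconds(200));
        std::printf("pending: %s\n", gpttype_get_pending_output().c_str());
    }

    gpttype_generate_abort();
    worker.join();
    return 0;
}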
@@ -28,6 +28,11 @@
#include "neox_v3.cpp"
#include "mpt_v3.cpp"

//shared
std::string executable_path = "";
std::string lora_filename = "";
bool generation_finished;
std::vector<std::string> generated_tokens;

//return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
static FileFormat file_format = FileFormat::BADFORMAT;
@@ -63,7 +68,6 @@ static bool useSmartContext = false;
static bool unbanTokens = false;
static int blasbatchsize = 512;
static bool debugmode = false;
static bool stream_sse = true;
static std::string modelname;
static std::vector<gpt_vocab::id> last_n_tokens;
static std::vector<gpt_vocab::id> current_context_tokens;
@@ -72,6 +76,8 @@ static std::vector<float> logits;
static std::vector<int> smartcontext;
static std::vector<std::string> stop_sequence;
static std::vector<llama_token_data> top_picks;
static int remaining_tokens = 0;
static std::string concat_output = "";

inline bool IsNanCheck(float f)
{
@@ -707,6 +713,16 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in

}

bool gpttype_generate_abort()
{
    remaining_tokens = 0;
    return true;
}

const std::string & gpttype_get_pending_output()
{
    return concat_output;
}

generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
{
@@ -735,6 +751,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
    params.n_ctx = inputs.max_context_length;
    params.n_batch = n_batch;
    params.n_threads = n_threads;
    bool stream_sse = inputs.stream_sse;

    generation_finished = false; // Set current generation status
    generated_tokens.clear(); // New Generation, new tokens
@@ -837,11 +854,11 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o

    current_context_tokens.resize(n_past);

    int remaining_tokens = params.n_predict;
    remaining_tokens = params.n_predict;
    int stopper_unused_tokens = 0;
    int input_consumed = 0;
    std::mt19937 rng(params.seed);
    std::string concat_output = "";
    concat_output = "";

    bool startedsampling = false;
@@ -1153,8 +1170,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
        for (auto id : embd)
        {
            std::string tokenizedstr = FileFormatTokenizeID(id, file_format);

            if (stream_sse)
            if(stream_sse)
            {
                generated_tokens.push_back(tokenizedstr);
            }