added a kobold API compatible implementation of stopping sequences

Concedo 2023-04-16 18:37:49 +08:00
parent 8bf2e50a11
commit 525184930d
7 changed files with 79 additions and 11 deletions

@@ -34,6 +34,7 @@ static llama_context *ctx;
static std::vector<llama_token> last_n_tokens;
static std::vector<llama_token> current_context_tokens;
static std::vector<llama_token> smartcontext;
static std::vector<std::string> stop_sequence;
bool llama_load_model(const load_model_inputs inputs, FileFormat in_file_format)
{
@@ -81,6 +82,15 @@ bool llama_load_model(const load_model_inputs inputs, FileFormat in_file_format)
generation_outputs llama_generate(const generation_inputs inputs, generation_outputs &output)
{
    stop_sequence.clear();
    for(int x=0;x<stop_token_max;++x)
    {
        std::string stopper = inputs.stop_sequence[x];
        if(stopper!="")
        {
            stop_sequence.push_back(stopper);
        }
    }
    params.prompt = inputs.prompt;
    params.seed = inputs.seed;
    params.n_predict = inputs.max_length;
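
The hunk above copies any non-empty client-supplied stop strings out of the request before generation starts. Below is a minimal standalone sketch of that step; the stop_token_max value and the generation_inputs_sketch struct are assumptions standing in for the real definitions elsewhere in the repo, which this diff does not show.

#include <string>
#include <vector>

const int stop_token_max = 10;                     // assumed capacity; the real constant lives in the repo's headers

struct generation_inputs_sketch {                  // hypothetical stand-in for generation_inputs
    const char *stop_sequence[stop_token_max];     // one optional stop string per slot
};

static std::vector<std::string> stop_sequence;

void collect_stop_sequences(const generation_inputs_sketch &inputs)
{
    stop_sequence.clear();                         // drop stop strings left over from the previous request
    for (int x = 0; x < stop_token_max; ++x)
    {
        std::string stopper = inputs.stop_sequence[x] ? inputs.stop_sequence[x] : "";
        if (!stopper.empty())                      // keep only the slots the client actually filled
        {
            stop_sequence.push_back(stopper);
        }
    }
}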
@@ -231,6 +241,14 @@ generation_outputs llama_generate(const generation_inputs inputs, generation_outputs &output)
        --remaining_tokens;
        //printf("\nid:%d word:%s\n",id,llama_token_to_str(ctx, id));
        concat_output += llama_token_to_str(ctx, id);
        for (const auto &matched : stop_sequence)
        {
            if (concat_output.find(matched) != std::string::npos)
            {
                remaining_tokens = 0;
                break;
            }
        }
    }
    else
    {
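
After each generated token is appended to concat_output, the new loop in the hunk above scans for any of the stored stop strings and, on a hit, zeroes remaining_tokens so the generation loop exits. A simplified sketch of that check follows, factored into a helper for illustration only; the helper name is hypothetical and not part of the commit.

#include <string>
#include <vector>

// Returns true if the accumulated output already contains one of the stop strings.
bool hit_stop_sequence(const std::string &concat_output,
                       const std::vector<std::string> &stop_sequence)
{
    for (const auto &matched : stop_sequence)
    {
        // plain substring search over everything generated so far
        if (concat_output.find(matched) != std::string::npos)
        {
            return true; // caller would set remaining_tokens = 0 and stop generating
        }
    }
    return false;
}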