added a kobold API compatible implementation of stopping sequences

This commit is contained in:
Concedo 2023-04-16 18:37:49 +08:00
parent 8bf2e50a11
commit 525184930d
7 changed files with 79 additions and 11 deletions

View file

@ -36,8 +36,8 @@ static std::vector<gpt_vocab::id> last_n_tokens;
static std::vector<gpt_vocab::id> current_context_tokens;
static size_t mem_per_token = 0;
static std::vector<float> logits;
static std::vector<int> smartcontext;
static std::vector<std::string> stop_sequence;
inline bool IsNanCheck(float f)
{
@ -154,6 +154,14 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
{
stop_sequence.clear();
for(int x=0;x<stop_token_max;++x)
{
if(inputs.stop_sequence[x]!="")
{
stop_sequence.push_back(inputs.stop_sequence[x]);
}
}
params.prompt = inputs.prompt;
params.seed = inputs.seed;
params.n_predict = inputs.max_length;
@ -333,9 +341,18 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
// decrement remaining sampling budget
--remaining_tokens;
for (auto id : embd) {
for (auto id : embd)
{
concat_output += vocab.id_to_token[id].c_str();
for (const auto &matched : stop_sequence)
{
if (concat_output.find(matched) != std::string::npos)
{
remaining_tokens = 0;
break;
}
}
}
}
else