added grammar sampling

This commit is contained in:
Concedo 2023-09-18 23:02:00 +08:00
parent 951614bfc6
commit 8c453d1e4e
6 changed files with 291 additions and 205 deletions
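Grammar sampling constrains generation to strings accepted by a GBNF grammar supplied with each request (the inputs.grammar field used later in this diff). As a minimal, purely illustrative sketch of what a caller might pass in; the variable name below is hypothetical and not part of this commit:

// Hypothetical caller-side grammar string; restricts output to "yes" or "no".
const std::string example_grammar = "root ::= (\"yes\" | \"no\")";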


@@ -11,6 +11,7 @@
#include <mutex>
#include "model_adapter.h"
#include "otherarch.h"
#include "grammar-parser.h"
//for easier compilation
//concat source files into one file for compilation purposes
@@ -41,10 +42,14 @@ int last_token_count = 0;
stop_reason last_stop_reason = stop_reason::INVALID;
std::vector<std::string> generated_tokens;
llama_grammar * grammar = nullptr; //currently used grammar
grammar_parser::parse_state parsed_grammar;
//return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
static FileFormat file_format = FileFormat::BADFORMAT;
static gpt_vocab vocab;
static int32_t n_vocab = 0;
static gptj_v1_model gptj_ctx_v1;
static gptj_v2_model gptj_ctx_v2;
@@ -61,6 +66,7 @@ static mpt_model mpt_ctx_v3;
static rwkv_v2_context * rwkv_ctx_v2;
static rwkv_context * rwkv_ctx_v3;
static llama_v2_context * llama_ctx_v2;
static llama_v3_context * llama_ctx_v3;
static llama_context * llama_ctx_v4;
@@ -115,6 +121,133 @@ inline bool LogitsDuplicated(std::vector<float> & arr1, std::vector<float> & arr
}
static std::string FileFormatTokenizeID(int id, FileFormat file_format)
{
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
return std::string(llama_v2_token_to_str(llama_ctx_v2, id));
}
else if (file_format == FileFormat::GGJT_3)
{
return std::string(llama_v3_token_to_str(llama_ctx_v3, id));
}
else if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
{
return std::string(llama_token_to_str(llama_ctx_v4, id));
}
else
{
return vocab.id_to_token[id];
}
}
static void TokenizeString(const std::string & str_to_tokenize, std::vector<int> & output_tokens, FileFormat file_format)
{
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
{
if(file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 )
{
output_tokens = ::llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true);
}
else if (file_format == FileFormat::GGML)
{
output_tokens = ::legacy_llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true);
}
else if (file_format == FileFormat::GGJT_3)
{
output_tokens = ::llama_v3_tokenize(llama_ctx_v3, str_to_tokenize, true);
}
else
{
output_tokens = ::llama_tokenize(llama_ctx_v4, str_to_tokenize, true);
}
}
else
{
// tokenize the prompt
output_tokens = ::gpt_tokenize(vocab, str_to_tokenize);
}
}
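For context, a minimal sketch of how this dispatcher is typically called when preparing a prompt; the surrounding variable names here are illustrative, not taken from this diff:

std::vector<int> embd_inp;                          // prompt token ids
TokenizeString(prompt_text, embd_inp, file_format); // routes to the tokenizer matching the loaded format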
static int GetEosID(FileFormat file_format, int32_t n_vocab)
{
unsigned int eosID = 0;
if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
{
if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
{
eosID = llama_token_eos(llama_ctx_v4);
}
else if(file_format == FileFormat::GGJT_3)
{
eosID = llama_v3_token_eos();
}
else
{
eosID = llama_v3_token_eos();
}
}
else
{
if (file_format == FileFormat::GPT2_1 ||
file_format == FileFormat::GPT2_2 ||
file_format == FileFormat::GPT2_3 ||
file_format == FileFormat::GPT2_4 ||
file_format == FileFormat::GPTJ_1 ||
file_format == FileFormat::GPTJ_2 ||
file_format == FileFormat::GPTJ_3 ||
file_format == FileFormat::GPTJ_4 ||
file_format == FileFormat::GPTJ_5)
{
eosID = 50256;
if (n_vocab <= eosID)
{
//special case, starcoder models use ID 0 for EOS
eosID = 0;
}
}
if (file_format == FileFormat::RWKV_1 ||
file_format == FileFormat::RWKV_2 ||
file_format == FileFormat::NEOX_1 ||
file_format == FileFormat::NEOX_2 ||
file_format == FileFormat::NEOX_3 ||
file_format == FileFormat::NEOX_4 ||
file_format == FileFormat::NEOX_5 ||
file_format == FileFormat::NEOX_6 ||
file_format == FileFormat::NEOX_7 ||
file_format == FileFormat::MPT_1)
{
eosID = 0;
}
}
return eosID;
}
static float LowestLogit(const std::vector<float> & logits)
{
int topid = std::min_element(logits.begin(), logits.end()) - logits.begin();
float v = logits[topid];
return (v < 0 ? (v-8) : 0);
}
static float LowestLogit(const float *logits, size_t size)
{
if (size == 0) {
// Handle the case of an empty array
return 0.0;
}
int topid = std::min_element(logits, logits + size) - logits;
float v = logits[topid];
return (v < 0 ? (v-8) : 0);
}
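Both overloads return a value at or below the smallest observed logit (the minimum minus 8 when it is negative, otherwise 0), so a token forced to that value is effectively never sampled. A minimal sketch of how it pairs with GetEosID to suppress EOS and banned tokens, mirroring the call sites later in this diff:

float lowestLogit = LowestLogit(logitsPtr, n_vocab);
logitsPtr[GetEosID(file_format, n_vocab)] = lowestLogit; // keep generation from stopping early
for (int t = 0; t < btsize; ++t)
    logitsPtr[banned_token_ids[t]] = lowestLogit;        // mask user-banned token ids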
static std::string RemoveBell(const std::string & input) //removes the bell character
{
std::string word2;
std::remove_copy(input.begin(), input.end(), std::back_inserter(word2), '\a');
return word2;
}
llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng)
{
llama_sample_softmax(nullptr, candidates);
@@ -256,8 +389,47 @@ void sample_temperature(llama_token_data_array * candidates_p, float temp)
}
}
void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
const int64_t t_start_sample_us = ggml_time_us();
bool allow_eos = false;
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
allow_eos = true;
break;
}
}
const llama_token eos = GetEosID(file_format,n_vocab);
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
std::vector<llama_grammar_candidate> candidates_grammar;
for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
const std::string piece = FileFormatTokenizeID(id,file_format);
if (id == eos) {
if (!allow_eos) {
candidates->data[i].logit = -INFINITY;
}
} else if (piece.empty() || piece[0] == 0) {
candidates->data[i].logit = -INFINITY;
} else {
candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8));
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
}
}
const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
for (const auto & reject : rejects) {
candidates->data[reject.index].logit = -INFINITY;
}
}
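sample_grammar masks (sets to -INFINITY) every candidate token whose UTF-8 piece cannot extend any active parse stack of the grammar, and also masks EOS unless at least one stack has been fully consumed. A minimal sketch of the intended call pattern, reusing the names that appear inside SampleLogits below (the candidate-construction loop is a sketch, since that part of SampleLogits is elided from this hunk):

std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token id = 0; id < n_vocab; ++id) {
    candidates.emplace_back(llama_token_data{id, logits[id], 0.0f});
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
if (grammar != nullptr) {
    sample_grammar(file_format, n_vocab, &candidates_p, grammar);
}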
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_a, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers> & sampler_order)
int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers> & sampler_order, llama_grammar * grammar)
{
int id = 0;
std::vector<llama_token_data> candidates;
@@ -268,6 +440,10 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
if (grammar != nullptr) {
sample_grammar(file_format, n_vocab, &candidates_p, grammar);
}
if (mirostat == 1 || mirostat == 2)
{
static float mirostat_mu = 2.0f * mirostat_tau;
@@ -321,76 +497,48 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers
return id;
}
static std::string FileFormatTokenizeID(int id, FileFormat file_format)
static void grammar_accept_token(FileFormat file_format, int32_t n_vocab, struct llama_grammar * grammar, llama_token token)
{
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
return std::string(llama_v2_token_to_str(llama_ctx_v2, id));
if (token == GetEosID(file_format,n_vocab)) {
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
return;
}
}
GGML_ASSERT(false);
}
else if (file_format == FileFormat::GGJT_3)
{
return std::string(llama_v3_token_to_str(llama_ctx_v3, id));
}
else if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
{
return std::string(llama_token_to_str(llama_ctx_v4, id));
}
else
{
return vocab.id_to_token[id];
const std::string piece = FileFormatTokenizeID(token,file_format); //llama_token_to_str(ctx, token);
// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8);
const auto & code_points = decoded.first;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
}
grammar->partial_utf8 = decoded.second;
GGML_ASSERT(!grammar->stacks.empty());
}
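After a token has been sampled, grammar_accept_token feeds its decoded code points through llama_grammar_accept so the parse stacks advance and the next sampling step is constrained accordingly. The pairing with SampleLogits, as it appears in the generation loop later in this diff:

id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty,
                  top_k, top_a, top_p, typical_p, tfs_z, temp, rng,
                  params.mirostat, params.mirostat_tau, params.mirostat_eta, sampler_order, grammar);
if (grammar != nullptr) {
    grammar_accept_token(file_format, n_vocab, grammar, id);
}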
static void TokenizeString(const std::string & str_to_tokenize, std::vector<int> & output_tokens, FileFormat file_format)
static void load_grammar(const std::string & grammarstr)
{
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
if(grammar!=nullptr) //on demand free when next grammar is loaded
{
if(file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 )
{
output_tokens = ::llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true);
}
else if (file_format == FileFormat::GGML)
{
output_tokens = ::legacy_llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true);
}
else if (file_format == FileFormat::GGJT_3)
{
output_tokens = ::llama_v3_tokenize(llama_ctx_v3, str_to_tokenize, true);
}
else
{
output_tokens = ::llama_tokenize(llama_ctx_v4, str_to_tokenize, true);
}
llama_grammar_free(grammar);
grammar = nullptr;
}
else
{
// tokenize the prompt
output_tokens = ::gpt_tokenize(vocab, str_to_tokenize);
}
}
static float LowestLogit(const std::vector<float> & logits)
{
int topid = std::min_element(logits.begin(), logits.end()) - logits.begin();
float v = logits[topid];
return (v < 0 ? (v-8) : 0);
}
static float LowestLogit(const float *logits, size_t size)
{
if (size == 0) {
// Handle the case of an empty array
return 0.0;
if (!grammarstr.empty()) {
parsed_grammar = grammar_parser::parse(grammarstr.c_str());
// will be empty (default) if there are parse errors
if (parsed_grammar.rules.empty()) {
printf("\nIgnored invalid grammar sampler.");
return;
}
grammar_parser::print_grammar(stderr, parsed_grammar);
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
int topid = std::min_element(logits, logits + size) - logits;
float v = logits[topid];
return (v < 0 ? (v-8) : 0);
}
static std::string RemoveBell(const std::string & input) //removes the bell character
{
std::string word2;
std::remove_copy(input.begin(), input.end(), std::back_inserter(word2), '\a');
return word2;
}
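load_grammar frees any grammar left over from a previous request before parsing the new string, and leaves the global grammar as nullptr when the string is empty or fails to parse, which simply disables grammar sampling for that request. A minimal caller-side sketch mirroring gpttype_generate below:

std::string grammarstr = inputs.grammar; // empty string means no grammar constraint
load_grammar(grammarstr);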
ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format, FileFormatExtraMeta file_format_meta)
@@ -522,6 +670,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
}
}
n_vocab = llama_v2_n_vocab(llama_ctx_v2);
//determine mem per token
const std::vector<int> tmp = {1, 2, 3, 4};
llama_v2_eval(llama_ctx_v2, tmp.data(), tmp.size(), 0, params.n_threads);
@@ -587,6 +737,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
}
}
n_vocab = llama_v3_n_vocab(llama_ctx_v3);
//determine mem per token
const std::vector<int> tmp = {1, 2, 3, 4};
auto er = llama_v3_eval(llama_ctx_v3, tmp.data(), tmp.size(), 0, params.n_threads);
@@ -663,6 +815,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
}
}
n_vocab = llama_n_vocab(llama_ctx_v4);
//determine mem per token
const std::vector<int> tmp = {1, 2, 3, 4};
auto er = llama_eval(llama_ctx_v4, tmp.data(), tmp.size(), 0, params.n_threads);
@@ -720,6 +874,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
printf("\nRWKV Vocab: %u\n", vocabsiz);
logits.resize(vocabsiz);
n_vocab = vocab.id_to_token.size(); //handled separately
if (file_format == FileFormat::RWKV_1)
{
n_batch = 1;
@@ -790,6 +946,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
printf("\nTensor Transposition Detected! Retrying GPT-2 model loading...");
return res;
}
n_vocab = gpt2_ctx_v1.hparams.n_vocab;
// determine the required inference memory per token:
legacy_gpt2_eval(gpt2_ctx_v1, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format);
return ModelLoadResult::SUCCESS;
@@ -809,6 +968,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
printf("\nTensor Transposition Detected! Retrying GPT-2 model loading...");
return res;
}
n_vocab = gpt2_ctx_v3.hparams.n_vocab;
// determine the required inference memory per token:
gpt2_eval(gpt2_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, use_scratch);
return ModelLoadResult::SUCCESS;
@@ -829,6 +991,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
printf("\nTensor Transposition Detected! Retrying GPT-2 model loading...");
return res;
}
n_vocab = gpt2_ctx_v2.hparams.n_vocab;
// determine the required inference memory per token:
gpt2_v2_eval(gpt2_ctx_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format);
return ModelLoadResult::SUCCESS;
@@ -847,6 +1012,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
printf("\nTensor Transposition Detected! Retrying GPT-J model loading...");
return res;
}
n_vocab = gptj_ctx_v1.hparams.n_vocab;
// determine the required inference memory per token:
legacy_gptj_eval(gptj_ctx_v1, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format);
@@ -876,6 +1044,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
return loadresult;
}
n_vocab = gptj_ctx_v3.hparams.n_vocab;
// determine the required inference memory per token:
gptj_eval(gptj_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, use_scratch);
@@ -912,6 +1082,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
return loadresult;
}
n_vocab = gptj_ctx_v2.hparams.n_vocab;
// determine the required inference memory per token:
gptj_v2_eval(gptj_ctx_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
@@ -948,6 +1120,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
return res;
}
n_vocab = neox_ctx_v3.hparams.n_vocab;
// determine the required inference memory per token:
gpt_neox_eval(neox_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, use_scratch);
@@ -970,6 +1144,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
return res;
}
n_vocab = neox_ctx_v2.hparams.n_vocab;
// determine the required inference memory per token:
gpt_neox_v2_eval(neox_ctx_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
@@ -1005,6 +1181,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
return ModelLoadResult::FAIL;
}
n_vocab = mpt_ctx_v3.hparams.n_vocab;
// determine the required inference memory per token:
mpt_eval(mpt_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, false, mem_per_token, use_scratch);
return ModelLoadResult::SUCCESS;
@@ -1084,6 +1262,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
generation_finished = false; // Set current generation status
generated_tokens.clear(); // New Generation, new tokens
std::string grammarstr = inputs.grammar;
load_grammar(grammarstr);
if (params.repeat_last_n < 1)
{
params.repeat_last_n = 1;
@@ -1193,59 +1374,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
timer_start();
double time1 = 0, time2 = 0;
int32_t n_vocab = 0;
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
{
n_vocab = llama_v2_n_vocab(llama_ctx_v2);
}
else if(file_format == FileFormat::GGJT_3)
{
n_vocab = llama_v3_n_vocab(llama_ctx_v3);
}
else if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
{
n_vocab = llama_n_vocab(llama_ctx_v4);
}
else if (file_format == FileFormat::GPTJ_1 || file_format == FileFormat::GPTJ_2)
{
n_vocab = gptj_ctx_v1.hparams.n_vocab;
}
else if(file_format == FileFormat::GPTJ_3 || file_format==FileFormat::GPTJ_4)
{
n_vocab = gptj_ctx_v2.hparams.n_vocab;
}
else if(file_format==FileFormat::GPTJ_5)
{
n_vocab = gptj_ctx_v3.hparams.n_vocab;
}
else if(file_format == FileFormat::GPT2_1)
{
n_vocab = gpt2_ctx_v1.hparams.n_vocab;
}
else if(file_format == FileFormat::GPT2_2 || file_format==FileFormat::GPT2_3)
{
n_vocab = gpt2_ctx_v2.hparams.n_vocab;
}
else if(file_format==FileFormat::GPT2_4)
{
n_vocab = gpt2_ctx_v3.hparams.n_vocab;
}
else if(file_format == FileFormat::NEOX_1 || file_format == FileFormat::NEOX_2 || file_format == FileFormat::NEOX_3 || file_format==FileFormat::NEOX_4 || file_format==FileFormat::NEOX_5)
{
n_vocab = neox_ctx_v2.hparams.n_vocab;
}
else if( file_format==FileFormat::NEOX_6|| file_format==FileFormat::NEOX_7)
{
n_vocab = neox_ctx_v3.hparams.n_vocab;
}
else if( file_format==FileFormat::MPT_1)
{
n_vocab = mpt_ctx_v3.hparams.n_vocab;
}
else if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
{
n_vocab = vocab.id_to_token.size(); //handled separately
if(n_past==0)
{
if(file_format == FileFormat::RWKV_1)
@@ -1276,9 +1407,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
}
}
}
else
if(n_vocab<=0)
{
printf("Bad format!");
printf("\nWarning! n_vocab is invalid, maybe bad format!");
}
//prepare banned tokens
@@ -1459,107 +1591,52 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
}
}
unsigned int eosID = 0;
unsigned int eosID = GetEosID(file_format, n_vocab);
float * logitsPtr;
float lowestLogit = 0;
int btsize = banned_token_ids.size();
if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
{
if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
{
logitsPtr = llama_get_logits(llama_ctx_v4);
eosID = llama_token_eos(llama_ctx_v4);
}
else if(file_format == FileFormat::GGJT_3)
{
logitsPtr = llama_v3_get_logits(llama_ctx_v3);
eosID = llama_v3_token_eos();
}
else
{
logitsPtr = llama_v2_get_logits(llama_ctx_v2);
eosID = llama_v3_token_eos();
}
float lowestLogit = LowestLogit(logitsPtr,n_vocab);
if (!unbanTokens && !inputs.unban_tokens_rt)
{
// set the logit of the eos token (2) to -INF to avoid sampling it
logitsPtr[eosID] = lowestLogit;
}
if(btsize>0)
{
for(int t=0;t<btsize;++t)
{
logitsPtr[banned_token_ids[t]]=lowestLogit;
}
}
lowestLogit = LowestLogit(logitsPtr,n_vocab);
}
else
{
logitsPtr = logits.data();
float lowestLogit = LowestLogit(logits);
if (!unbanTokens && !inputs.unban_tokens_rt)
lowestLogit = LowestLogit(logits);
}
if (!unbanTokens && !inputs.unban_tokens_rt)
{
// set the logit of the eos token to very low to avoid sampling it
logitsPtr[eosID] = lowestLogit;
}
if(btsize>0)
{
for(int t=0;t<btsize;++t)
{
//gpt2 uses negative logits, so we can't zero it
// set the logit of the eos token to minimum to avoid sampling it
if (file_format == FileFormat::GPT2_1 ||
file_format == FileFormat::GPT2_2 ||
file_format == FileFormat::GPT2_3 ||
file_format == FileFormat::GPT2_4 ||
file_format == FileFormat::GPTJ_1 ||
file_format == FileFormat::GPTJ_2 ||
file_format == FileFormat::GPTJ_3 ||
file_format == FileFormat::GPTJ_4 ||
file_format == FileFormat::GPTJ_5)
{
eosID = 50256;
if(logits.size() > eosID)
{
logits[eosID] = lowestLogit;
}
else
{
//special case, starcoder models use ID 0 for EOS
if (file_format == FileFormat::GPT2_3 || file_format == FileFormat::GPT2_4)
{
eosID = 0;
logits[eosID] = lowestLogit;
}
}
}
// set the logit of the eos token (0) to minimum to avoid sampling it
if (file_format == FileFormat::RWKV_1 ||
file_format == FileFormat::RWKV_2 ||
file_format == FileFormat::NEOX_1 ||
file_format == FileFormat::NEOX_2 ||
file_format == FileFormat::NEOX_3 ||
file_format == FileFormat::NEOX_4 ||
file_format == FileFormat::NEOX_5 ||
file_format == FileFormat::NEOX_6 ||
file_format == FileFormat::NEOX_7 ||
file_format == FileFormat::MPT_1)
{
eosID = 0;
logits[eosID] = lowestLogit;
}
}
if(btsize>0)
{
for (int t = 0; t < btsize; ++t)
{
logits[banned_token_ids[t]] = lowestLogit;
}
logitsPtr[banned_token_ids[t]]=lowestLogit;
}
}
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty,
top_k, top_a, top_p, typical_p, tfs_z, temp, rng,
params.mirostat, params.mirostat_tau, params.mirostat_eta, sampler_order);
params.mirostat, params.mirostat_tau, params.mirostat_eta, sampler_order, grammar);
if (grammar != nullptr) {
grammar_accept_token(file_format, n_vocab, grammar, id);
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);