Added the ability to skip the BOS token for the tokenize endpoint

This commit is contained in:
Concedo 2024-06-05 10:49:11 +08:00
parent 5789417802
commit 10b148f4c2
5 changed files with 11 additions and 12 deletions

View file

@@ -277,11 +277,11 @@ extern "C"
     }
     static std::vector<int> toks; //just share a static object for token counting
-    token_count_outputs token_count(const char * input)
+    token_count_outputs token_count(const char * input, bool addbos)
     {
         std::string inputstr = input;
         token_count_outputs output;
-        toks = gpttype_get_token_arr(inputstr);
+        toks = gpttype_get_token_arr(inputstr,addbos);
        output.count = toks.size();
        output.ids = toks.data(); //this may be slightly unsafe
        return output;

View file

@@ -798,10 +798,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     kcpp_params->n_threads_batch = inputs.blasthreads;
     bool isGguf = (file_format == FileFormat::GGUF_GENERIC);
     kcpp_params->n_batch = GetBatchSize(inputs.blasbatchsize, in_file_format);
-    if(kcpp_params->n_batch>512)
-    {
-        kcpp_params->n_ubatch = (kcpp_params->n_batch>1024?1024:kcpp_params->n_batch);
-    }
+    kcpp_params->n_ubatch = kcpp_params->n_batch;
     kcpp_params->flash_attn = inputs.flash_attention;
     modelname = kcpp_params->model = inputs.model_filename;
     useSmartContext = inputs.use_smartcontext;
@@ -1544,7 +1541,7 @@ bool gpttype_generate_abort()
    return true;
}
-std::vector<int> gpttype_get_token_arr(const std::string & input)
+std::vector<int> gpttype_get_token_arr(const std::string & input, bool addbos)
{
    std::vector<int> toks;
    if(kcpp_params==nullptr)
@@ -1556,7 +1553,7 @@ std::vector<int> gpttype_get_token_arr(const std::string & input)
    {
        printf("\nFileFormat: %d, Tokenizing: %s",file_format ,input.c_str());
    }
-    TokenizeString(input, toks, file_format);
+    TokenizeString(input, toks, file_format,addbos);
    int tokcount = toks.size();
    if(debugmode==1)
    {

View file

@@ -7022,7 +7022,7 @@ Current version: 145
    {
        key = parseInt(tokarr[x]);
        val = parseInt(val);
-        if (!isNaN(key) && key!=1) {
+        if (!isNaN(key)) {
            dict[key] = parseInt(val);
        }
    }
@@ -12839,7 +12839,8 @@ Current version: 145
function kcpp_tokenize(prompt,onDone)
{
    let payload = {
-        "prompt": prompt
+        "prompt": prompt,
+        "special": false,
    };
    fetch(apply_proxy_url(custom_kobold_endpoint + koboldcpp_tokenize_endpoint), {
        method: 'POST',

View file

@@ -1352,7 +1352,8 @@ Enter Prompt:<br>
        try:
            genparams = json.loads(body)
            countprompt = genparams.get('prompt', "")
-            rawcountdata = handle.token_count(countprompt.encode("UTF-8"))
+            tcaddspecial = genparams.get('special', True)
+            rawcountdata = handle.token_count(countprompt.encode("UTF-8"),tcaddspecial)
            countlimit = rawcountdata.count if (rawcountdata.count>=0 and rawcountdata.count<50000) else 0
            # the above protects the server in case the count limit got corrupted
            countdata = [rawcountdata.ids[i] for i in range(countlimit)]

View file

@@ -77,7 +77,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
generation_outputs gpttype_generate(const generation_inputs inputs);
bool gpttype_generate_abort();
const std::string & gpttype_get_pending_output();
-std::vector<int> gpttype_get_token_arr(const std::string & input);
+std::vector<int> gpttype_get_token_arr(const std::string & input, bool addbos);
bool sdtype_load_model(const sd_load_model_inputs inputs);
sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs);