mirror of https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
wip ollama emulation, added detokenize endpoint
This commit is contained in:
parent c0da7e4dcf
commit 2c1a06a07d
4 changed files with 62 additions and 7 deletions
expose.cpp (12 changed lines)
@@ -300,6 +300,18 @@ extern "C"
         return output;
     }
 
+    static std::string detokenized_str = ""; //just share a static object for detokenizing
+    const char * detokenize(const token_count_outputs input)
+    {
+        std::vector<int> input_arr;
+        for(int i=0;i<input.count;++i)
+        {
+            input_arr.push_back(input.ids[i]);
+        }
+        detokenized_str = gpttype_detokenize(input_arr,false);
+        return detokenized_str.c_str();
+    }
+
     static std::vector<TopPicksData> last_logprob_toppicks;
     static std::vector<logprob_item> last_logprob_items;
     last_logprobs_outputs last_logprobs()
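As a quick illustration, here is a minimal standalone sketch of driving the new detokenize export through ctypes, mirroring the koboldcpp.py wiring further down in this commit. The struct field names come from the Python-side usage below; the library filename is an assumption that varies by build.

import ctypes

class token_count_outputs(ctypes.Structure):
    _fields_ = [("count", ctypes.c_int),
                ("ids", ctypes.POINTER(ctypes.c_int))]

handle = ctypes.CDLL("./koboldcpp_default.so")  # assumed build artifact name
handle.detokenize.argtypes = [token_count_outputs]
handle.detokenize.restype = ctypes.c_char_p

def detokenize_ids(token_ids):
    inputs = token_count_outputs()
    inputs.count = len(token_ids)
    inputs.ids = (ctypes.c_int * len(token_ids))(*token_ids)
    # The C side reuses one static std::string, so the returned buffer is only
    # valid until the next detokenize call, and the export is not thread-safe.
    raw = handle.detokenize(inputs)
    return (raw or b"").decode("utf-8", "ignore")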
gpttype_adapter.cpp (11 changed lines)

@@ -2528,6 +2528,17 @@ std::vector<int> gpttype_get_token_arr(const std::string & input, bool addbos)
     return toks;
 }
 
+std::string gpttype_detokenize(const std::vector<int> & inputids, bool render_special)
+{
+    std::string output = "";
+    for (auto eid : inputids)
+    {
+        std::string tokenizedstr = FileFormatTokenizeID(eid, file_format, render_special);
+        output += tokenizedstr;
+    }
+    return output;
+}
+
 const std::string & gpttype_get_pending_output()
 {
     if(kcpp_data==nullptr)
koboldcpp.py (45 changed lines)
@@ -471,6 +471,8 @@ def init_library():
     handle.whisper_generate.argtypes = [whisper_generation_inputs]
     handle.whisper_generate.restype = whisper_generation_outputs
     handle.last_logprobs.restype = last_logprobs_outputs
+    handle.detokenize.argtypes = [token_count_outputs]
+    handle.detokenize.restype = ctypes.c_char_p
 
 def set_backend_props(inputs):
     clblastids = 0
@@ -1310,7 +1312,7 @@ def parse_last_logprobs(lastlogprobs):
 
 def transform_genparams(genparams, api_format):
     global chatcompl_adapter
-    #api format 1=basic,2=kai,3=oai,4=oai-chat,5=interrogate
+    #api format 1=basic,2=kai,3=oai,4=oai-chat,5=interrogate,6=ollama
     #alias all nonstandard alternative names for rep pen.
     rp1 = genparams.get('repeat_penalty', 1.0)
     rp2 = genparams.get('repetition_penalty', 1.0)
@@ -1460,6 +1462,8 @@ ws ::= | " " | "\n" [ \t]{0,20}
         user_message_start = adapter_obj.get("user_start", "### Instruction:")
         assistant_message_start = adapter_obj.get("assistant_start", "### Response:")
         genparams["prompt"] = f"{user_message_start} In one sentence, write a descriptive caption for this image.\n{assistant_message_start}"
+    elif api_format==6:
+        genparams["prompt"] = genparams.get('system', "") + genparams.get('prompt', "")
 
     return genparams
 
@@ -1563,6 +1567,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                     "choices": [{"index": 0, "message": {"role": "assistant", "content": recvtxt, "tool_calls": tool_calls}, "finish_reason": currfinishreason, "logprobs":logprobsdict}]}
             elif api_format == 5:
                 res = {"caption": end_trim_to_sentence(recvtxt)}
+            elif api_format == 6:
+                res = {"model": friendlymodelname,"created_at": str(datetime.now(timezone.utc).isoformat()),"response":recvtxt,"done": True,"context": [1,2,3],"total_duration": 1,"load_duration": 1,"prompt_eval_count": prompttokens,"prompt_eval_duration": 1,"eval_count": comptokens,"eval_duration": 1}
             else:
                 res = {"results": [{"text": recvtxt, "finish_reason": currfinishreason, "logprobs":logprobsdict, "prompt_tokens": prompttokens, "completion_tokens": comptokens}]}
 
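For reference, the api_format == 6 branch above emits an Ollama-style /api/generate body shaped like the sketch below. The field values here are illustrative only; note that "context" and all the duration fields are hard-coded placeholders in this WIP commit.

# Illustrative shape of the emulated Ollama response; values are examples.
example_response = {
    "model": "koboldcpp/MyModel",               # friendlymodelname
    "created_at": "2025-01-01T00:00:00+00:00",  # UTC ISO timestamp
    "response": "generated text",
    "done": True,
    "context": [1, 2, 3],                       # placeholder, not a real context
    "total_duration": 1, "load_duration": 1,    # placeholders
    "prompt_eval_count": 12, "prompt_eval_duration": 1,
    "eval_count": 48, "eval_duration": 1,
}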
@@ -2025,7 +2031,7 @@ Enter Prompt:<br>
         response_body = None
         response_code = 200
 
-        if self.path.endswith(('/api/extra/tokencount')):
+        if self.path.endswith('/api/extra/tokencount') or self.path.endswith('/api/extra/tokenize'):
             if not self.secure_endpoint():
                 return
             try:
@@ -2043,6 +2049,28 @@ Enter Prompt:<br>
                 response_code = 400
                 response_body = (json.dumps({"value": -1}).encode())
 
+        elif self.path.endswith('/api/extra/detokenize'):
+            if not self.secure_endpoint():
+                return
+            try:
+                genparams = json.loads(body)
+                tokids = genparams.get('ids', [])
+                tokidslen = len(tokids)
+                detokstr = ""
+                if tokidslen > 0 and tokidslen < 65536:
+                    inputs = token_count_outputs()
+                    inputs.count = tokidslen
+                    inputs.ids = (ctypes.c_int * tokidslen)()
+                    for i, cid in enumerate(tokids):
+                        inputs.ids[i] = cid
+                    detok = handle.detokenize(inputs)
+                    detokstr = ctypes.string_at(detok).decode("UTF-8","ignore")
+                response_body = (json.dumps({"result": detokstr,"success":True}).encode())
+            except Exception as e:
+                utfprint("Detokenize Error: " + str(e))
+                response_code = 400
+                response_body = (json.dumps({"result": "","success":False}).encode())
+
         elif self.path.endswith('/api/extra/abort'):
             if not self.secure_endpoint():
                 return
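A round-trip sketch for the new endpoint, assuming a local server on koboldcpp's default port 5001 and assuming the tokencount response carries the token ids alongside the count:

# Hypothetical round trip: tokenize text, then detokenize the ids back.
import json, urllib.request

def post(path, payload):
    req = urllib.request.Request("http://localhost:5001" + path,  # assumed host/port
                                 data=json.dumps(payload).encode(),
                                 headers={"Content-Type": "application/json"})
    return json.loads(urllib.request.urlopen(req).read())

ids = post("/api/extra/tokencount", {"prompt": "Hello world"})["ids"]  # assumes 'ids' in reply
print(post("/api/extra/detokenize", {"ids": ids}))
# expected: {"result": "Hello world", "success": true}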
@@ -2101,7 +2129,7 @@ Enter Prompt:<br>
                 logprobsdict = parse_last_logprobs(lastlogprobs)
                 response_body = (json.dumps({"logprobs":logprobsdict}).encode())
 
-        elif self.path=="/api/extra/multiplayer/status":
+        elif self.path.endswith('/api/extra/multiplayer/status'):
             if not self.secure_endpoint():
                 return
             if not has_multiplayer:
@@ -2120,7 +2148,7 @@ Enter Prompt:<br>
                 multiplayer_lastactive[sender] = int(time.time())
                 response_body = (json.dumps({"turn_major":multiplayer_turn_major,"turn_minor":multiplayer_turn_minor,"idle":self.get_multiplayer_idle_state(sender),"data_format":multiplayer_dataformat}).encode())
 
-        elif self.path=="/api/extra/multiplayer/getstory":
+        elif self.path.endswith('/api/extra/multiplayer/getstory'):
             if not self.secure_endpoint():
                 return
             if not has_multiplayer:
@@ -2130,7 +2158,7 @@ Enter Prompt:<br>
             else:
                 response_body = multiplayer_story_data_compressed.encode()
 
-        elif self.path=="/api/extra/multiplayer/setstory":
+        elif self.path.endswith('/api/extra/multiplayer/setstory'):
             if not self.secure_endpoint():
                 return
             if not has_multiplayer:
@@ -2197,7 +2225,7 @@ Enter Prompt:<br>
         try:
             sse_stream_flag = False
 
-            api_format = 0 #1=basic,2=kai,3=oai,4=oai-chat,5=interrogate
+            api_format = 0 #1=basic,2=kai,3=oai,4=oai-chat,5=interrogate,6=ollama
             is_imggen = False
             is_transcribe = False
 
@@ -2229,6 +2257,9 @@ Enter Prompt:<br>
                     return
                 api_format = 5
 
+            if self.path.endswith('/api/generate'):
+                api_format = 6
+
             if self.path.endswith('/sdapi/v1/txt2img') or self.path.endswith('/sdapi/v1/img2img'):
                 is_imggen = True
 
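With this routing in place, a minimal Ollama-style client call against the emulated route might look as follows. Note that transform_genparams only consumes 'system' and 'prompt' in this commit, so other Ollama request options would be ignored; host and port are assumptions.

# Sketch of a client hitting the emulated Ollama /api/generate route.
import json, urllib.request

payload = {"model": "anything", "system": "You are terse. ", "prompt": "Say hi."}
req = urllib.request.Request("http://localhost:5001/api/generate",
                             data=json.dumps(payload).encode(),
                             headers={"Content-Type": "application/json"})
print(json.loads(urllib.request.urlopen(req).read())["response"])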
@@ -2239,7 +2270,7 @@ Enter Prompt:<br>
                 global last_req_time
                 last_req_time = time.time()
 
-                if not is_imggen and not is_transcribe and api_format<5:
+                if not is_imggen and not is_transcribe and api_format!=5:
                     if not self.secure_endpoint():
                         return
 
model_adapter.h (1 changed line)

@@ -95,6 +95,7 @@ std::string gpttype_get_chat_template();
 
 const std::string & gpttype_get_pending_output();
 std::vector<int> gpttype_get_token_arr(const std::string & input, bool addbos);
+std::string gpttype_detokenize(const std::vector<int> & input, bool render_special);
 const std::vector<TopPicksData> gpttype_get_top_picks_data();
 
 bool sdtype_load_model(const sd_load_model_inputs inputs);