diff --git a/expose.cpp b/expose.cpp
index b9c57f7ea..d94c53ab3 100644
--- a/expose.cpp
+++ b/expose.cpp
@@ -300,6 +300,18 @@ extern "C"
         return output;
     }
 
+    static std::string detokenized_str = ""; //just share a static object for detokenizing
+    const char * detokenize(const token_count_outputs input)
+    {
+        std::vector<int> input_arr;
+        for(int i=0;i<input.count;++i)
+        {
+            input_arr.push_back(input.ids[i]);
+        }
+        detokenized_str = gpttype_detokenize(input_arr, true);
+        return detokenized_str.c_str();
+    }
+
     static std::vector<TopPicksData> last_logprob_toppicks;
     static std::vector<logprob_item> last_logprob_items;
     last_logprobs_outputs last_logprobs()
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index d7e14105a..16b409a8a 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -2528,6 +2528,17 @@ std::vector<int> gpttype_get_token_arr(const std::string & input, bool addbos)
     return toks;
 }
 
+std::string gpttype_detokenize(const std::vector<int> & inputids, bool render_special)
+{
+    std::string output = "";
+    for (auto eid : inputids)
+    {
+        std::string tokenizedstr = FileFormatTokenizeID(eid, file_format, render_special);
+        output += tokenizedstr;
+    }
+    return output;
+}
+
 const std::string & gpttype_get_pending_output()
 {
     if(kcpp_data==nullptr)
diff --git a/koboldcpp.py b/koboldcpp.py
index 71ac46476..02ac4ab10 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -471,6 +471,8 @@ def init_library():
     handle.whisper_generate.argtypes = [whisper_generation_inputs]
     handle.whisper_generate.restype = whisper_generation_outputs
     handle.last_logprobs.restype = last_logprobs_outputs
+    handle.detokenize.argtypes = [token_count_outputs]
+    handle.detokenize.restype = ctypes.c_char_p
 
 def set_backend_props(inputs):
     clblastids = 0
@@ -1310,7 +1312,7 @@ def parse_last_logprobs(lastlogprobs):
 
 def transform_genparams(genparams, api_format):
     global chatcompl_adapter
-    #api format 1=basic,2=kai,3=oai,4=oai-chat,5=interrogate
+    #api format 1=basic,2=kai,3=oai,4=oai-chat,5=interrogate,6=ollama
     #alias all nonstandard alternative names for rep pen.
     rp1 = genparams.get('repeat_penalty', 1.0)
     rp2 = genparams.get('repetition_penalty', 1.0)
@@ -1460,6 +1462,8 @@ ws ::= | " " | "\n" [ \t]{0,20}
         user_message_start = adapter_obj.get("user_start", "### Instruction:")
         assistant_message_start = adapter_obj.get("assistant_start", "### Response:")
         genparams["prompt"] = f"{user_message_start} In one sentence, write a descriptive caption for this image.\n{assistant_message_start}"
+    elif api_format==6:
+        genparams["prompt"] = genparams.get('system', "") + genparams.get('prompt', "")
 
     return genparams
 
@@ -1563,6 +1567,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                     "choices": [{"index": 0, "message": {"role": "assistant", "content": recvtxt, "tool_calls": tool_calls}, "finish_reason": currfinishreason, "logprobs":logprobsdict}]}
             elif api_format == 5:
                 res = {"caption": end_trim_to_sentence(recvtxt)}
+            elif api_format == 6:
+                res = {"model": friendlymodelname,"created_at": str(datetime.now(timezone.utc).isoformat()),"response":recvtxt,"done": True,"context": [1,2,3],"total_duration": 1,"load_duration": 1,"prompt_eval_count": prompttokens,"prompt_eval_duration": 1,"eval_count": comptokens,"eval_duration": 1}
             else:
                 res = {"results": [{"text": recvtxt, "finish_reason": currfinishreason, "logprobs":logprobsdict, "prompt_tokens": prompttokens, "completion_tokens": comptokens}]}
 
@@ -2025,7 +2031,7 @@ Enter Prompt:<br>
         response_body = None
         response_code = 200
 
-        if self.path.endswith(('/api/extra/tokencount')):
+        if self.path.endswith('/api/extra/tokencount') or self.path.endswith('/api/extra/tokenize'):
             if not self.secure_endpoint():
                 return
             try:
@@ -2043,6 +2049,28 @@ Enter Prompt:<br>
                 response_code = 400
                 response_body = (json.dumps({"value": -1}).encode())
 
+        elif self.path.endswith('/api/extra/detokenize'):
+            if not self.secure_endpoint():
+                return
+            try:
+                genparams = json.loads(body)
+                tokids = genparams.get('ids', [])
+                tokidslen = len(tokids)
+                detokstr = ""
+                if tokidslen > 0 and tokidslen < 65536:
+                    inputs = token_count_outputs()
+                    inputs.count = tokidslen
+                    inputs.ids = (ctypes.c_int * tokidslen)()
+                    for i, cid in enumerate(tokids):
+                        inputs.ids[i] = cid
+                    detok = handle.detokenize(inputs)
+                    detokstr = ctypes.string_at(detok).decode("UTF-8","ignore")
+                response_body = (json.dumps({"result": detokstr,"success":True}).encode())
+            except Exception as e:
+                utfprint("Detokenize Error: " + str(e))
+                response_code = 400
+                response_body = (json.dumps({"result": "","success":False}).encode())
+
         elif self.path.endswith('/api/extra/abort'):
             if not self.secure_endpoint():
                 return
@@ -2101,7 +2129,7 @@ Enter Prompt:<br>
             logprobsdict = parse_last_logprobs(lastlogprobs)
             response_body = (json.dumps({"logprobs":logprobsdict}).encode())
 
-        elif self.path=="/api/extra/multiplayer/status":
+        elif self.path.endswith('/api/extra/multiplayer/status'):
             if not self.secure_endpoint():
                 return
             if not has_multiplayer:
@@ -2120,7 +2148,7 @@ Enter Prompt:<br>
                 multiplayer_lastactive[sender] = int(time.time())
                 response_body = (json.dumps({"turn_major":multiplayer_turn_major,"turn_minor":multiplayer_turn_minor,"idle":self.get_multiplayer_idle_state(sender),"data_format":multiplayer_dataformat}).encode())
 
-        elif self.path=="/api/extra/multiplayer/getstory":
+        elif self.path.endswith('/api/extra/multiplayer/getstory'):
             if not self.secure_endpoint():
                 return
             if not has_multiplayer:
@@ -2130,7 +2158,7 @@ Enter Prompt:<br>
             else:
                 response_body = multiplayer_story_data_compressed.encode()
 
-        elif self.path=="/api/extra/multiplayer/setstory":
+        elif self.path.endswith('/api/extra/multiplayer/setstory'):
             if not self.secure_endpoint():
                 return
             if not has_multiplayer:
@@ -2197,7 +2225,7 @@ Enter Prompt:<br>
         try:
             sse_stream_flag = False
 
-            api_format = 0 #1=basic,2=kai,3=oai,4=oai-chat,5=interrogate
+            api_format = 0 #1=basic,2=kai,3=oai,4=oai-chat,5=interrogate,6=ollama
             is_imggen = False
             is_transcribe = False
 
@@ -2229,6 +2257,9 @@ Enter Prompt:<br>
                     return
                 api_format = 5
 
+            if self.path.endswith('/api/generate'):
+                api_format = 6
+
             if self.path.endswith('/sdapi/v1/txt2img') or self.path.endswith('/sdapi/v1/img2img'):
                 is_imggen = True
 
@@ -2239,7 +2270,7 @@ Enter Prompt:<br>
             global last_req_time
             last_req_time = time.time()
 
-            if not is_imggen and not is_transcribe and api_format<5:
+            if not is_imggen and not is_transcribe and api_format!=5:
                 if not self.secure_endpoint():
                     return
 
diff --git a/model_adapter.h b/model_adapter.h
index 5c448320f..48195a26d 100644
--- a/model_adapter.h
+++ b/model_adapter.h
@@ -95,6 +95,7 @@
 std::string gpttype_get_chat_template();
 const std::string & gpttype_get_pending_output();
 std::vector<int> gpttype_get_token_arr(const std::string & input, bool addbos);
+std::string gpttype_detokenize(const std::vector<int> & input, bool render_special);
 const std::vector<TopPicksData> gpttype_get_top_picks_data();
 
 bool sdtype_load_model(const sd_load_model_inputs inputs);
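
Not part of the diff: a minimal round-trip sketch for the two new token endpoints, assuming a koboldcpp server on localhost:5001. The detokenize request and response fields ("ids", "result", "success") are taken from the handler above; the assumption that the /api/extra/tokenize alias returns an "ids" list alongside the count comes from the tokencount handler it reuses and should be verified against your build.

```python
import json
import urllib.request

BASE_URL = "http://localhost:5001"  # assumed local koboldcpp instance

def post_json(path, payload):
    # Minimal JSON POST helper using only the standard library.
    req = urllib.request.Request(
        BASE_URL + path,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read().decode("utf-8"))

# Tokenize some text via the new alias (it shares the tokencount handler).
toks = post_json("/api/extra/tokenize", {"prompt": "Hello world!"})
ids = toks.get("ids", [])
print("token ids:", ids)

# Feed the ids back through the new /api/extra/detokenize endpoint.
detok = post_json("/api/extra/detokenize", {"ids": ids})
print("detokenized:", repr(detok.get("result", "")), "success:", detok.get("success", False))
```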
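
Likewise hypothetical, a non-streaming call against the Ollama-style /api/generate route added here (api_format 6): transform_genparams simply concatenates "system" and "prompt", and the response mirrors Ollama's field names with placeholder context and timing values.

```python
import json
import urllib.request

BASE_URL = "http://localhost:5001"  # assumed local koboldcpp instance

payload = {
    "model": "koboldcpp",                     # accepted but not used for routing here
    "system": "You are a terse assistant. ",  # prepended verbatim to the prompt
    "prompt": "Name three prime numbers.",
}
req = urllib.request.Request(
    BASE_URL + "/api/generate",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    out = json.loads(resp.read().decode("utf-8"))

# Per the api_format==6 branch above: "response" holds the generated text,
# "done" is always True, and the duration/context fields are placeholders.
print(out["model"], out["done"])
print(out["response"])
```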