From 62dde8cfb24d31ce5bbe201be93f6b3234a908b3 Mon Sep 17 00:00:00 2001
From: Concedo <39025047+LostRuins@users.noreply.github.com>
Date: Sat, 23 Nov 2024 23:31:37 +0800
Subject: [PATCH] ollama sync completions mostly working. stupid api.

---
 gpttype_adapter.cpp |  4 ++++
 koboldcpp.py        | 57 ++++++++++++++++++++++++++++++++-------------
 2 files changed, 45 insertions(+), 16 deletions(-)

diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 16b409a8a..d967bdacf 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -2533,6 +2533,10 @@ std::string gpttype_detokenize(const std::vector<int> & inputids, bool render_sp
     std::string output = "";
     for (auto eid : inputids)
     {
+        if(eid<0 || eid>=n_vocab)
+        {
+            continue;
+        }
         std::string tokenizedstr = FileFormatTokenizeID(eid, file_format, render_special);
         output += tokenizedstr;
     }
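Note on the gpttype_adapter.cpp change: Ollama clients echo back the `context` token ids from a previous response, so gpttype_detokenize can no longer trust its input; any id outside [0, n_vocab) is now skipped rather than being passed to FileFormatTokenizeID. The same guard, sketched in Python purely for illustration (the helper name is hypothetical; n_vocab stands for the loaded model's vocabulary size):

    def sanitize_token_ids(token_ids, n_vocab):
        # Keep only ids inside the valid range [0, n_vocab), so a
        # corrupted or foreign context array is filtered out instead
        # of reaching the vocab lookup with an out-of-bounds id.
        return [tid for tid in token_ids if 0 <= tid < n_vocab]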
diff --git a/koboldcpp.py b/koboldcpp.py
index 02ac4ab10..11fed1002 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -1243,6 +1243,26 @@ def whisper_generate(genparams):
     outstr = ret.data.decode("UTF-8","ignore")
     return outstr
 
+def tokenize_ids(countprompt,tcaddspecial):
+    rawcountdata = handle.token_count(countprompt.encode("UTF-8"),tcaddspecial)
+    countlimit = rawcountdata.count if (rawcountdata.count>=0 and rawcountdata.count<50000) else 0
+    # the above protects the server in case the count limit got corrupted
+    countdata = [rawcountdata.ids[i] for i in range(countlimit)]
+    return countdata
+
+def detokenize_ids(tokids):
+    tokidslen = len(tokids)
+    detokstr = ""
+    if tokidslen > 0 and tokidslen < 65536:
+        inputs = token_count_outputs()
+        inputs.count = tokidslen
+        inputs.ids = (ctypes.c_int * tokidslen)()
+        for i, cid in enumerate(tokids):
+            inputs.ids[i] = cid
+        detok = handle.detokenize(inputs)
+        detokstr = ctypes.string_at(detok).decode("UTF-8","ignore")
+    return detokstr
+
 #################################################################
 ### A hacky simple HTTP server simulating a kobold api by Concedo
 ### we are intentionally NOT using flask, because we want MINIMAL dependencies
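The two helpers above lift logic that previously lived inline in the tokencount and detokenize endpoints (refactored in the hunks further down), and they are designed to round-trip: ids produced by tokenize_ids can be handed straight back to detokenize_ids. A usage sketch, assuming a model is already loaded so `handle` and the ctypes structs are initialized:

    ids = tokenize_ids("Hello world", False)  # text -> list of token ids, no special tokens
    text = detokenize_ids(ids)                # token ids -> text again
    # detokenize_ids returns "" for an empty list or an absurd length (>= 65536),
    # mirroring the 50000-count sanity cap inside tokenize_ids.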
@@ -1463,7 +1483,22 @@ ws ::= | " " | "\n" [ \t]{0,20}
         assistant_message_start = adapter_obj.get("assistant_start", "### Response:")
         genparams["prompt"] = f"{user_message_start} In one sentence, write a descriptive caption for this image.\n{assistant_message_start}"
     elif api_format==6:
-        genparams["prompt"] = genparams.get('system', "") + genparams.get('prompt', "")
+        detokstr = ""
+        tokids = genparams.get('context', [])
+        try:
+            detokstr = detokenize_ids(tokids)
+        except Exception as e:
+            utfprint("Ollama Context Error: " + str(e))
+        ollamasysprompt = genparams.get('system', "")
+        ollamabodyprompt = detokstr + "\n\n### Instruction:\n" + genparams.get('prompt', "") + "\n\n### Response:\n"
+        genparams["stop_sequence"] = genparams.get('stop', [])
+        genparams["stop_sequence"].append("\n### Instruction:")
+        genparams["stop_sequence"].append("\n### Response:")
+        genparams["trim_stop"] = True
+        genparams["ollamasysprompt"] = ollamasysprompt
+        genparams["ollamabodyprompt"] = ollamabodyprompt
+        genparams["prompt"] = ollamasysprompt + ollamabodyprompt
+        utfprint(genparams["prompt"])
 
     return genparams
 
@@ -1568,7 +1603,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             elif api_format == 5:
                 res = {"caption": end_trim_to_sentence(recvtxt)}
             elif api_format == 6:
-                res = {"model": friendlymodelname,"created_at": str(datetime.now(timezone.utc).isoformat()),"response":recvtxt,"done": True,"context": [1,2,3],"total_duration": 1,"load_duration": 1,"prompt_eval_count": prompttokens,"prompt_eval_duration": 1,"eval_count": comptokens,"eval_duration": 1}
+                oldprompt = genparams.get('ollamabodyprompt', "")
+                tokarr = tokenize_ids(oldprompt+recvtxt,False)
+                res = {"model": friendlymodelname,"created_at": str(datetime.now(timezone.utc).isoformat()),"response":recvtxt,"done": True,"context": tokarr,"total_duration": 1,"load_duration": 1,"prompt_eval_count": prompttokens,"prompt_eval_duration": 1,"eval_count": comptokens,"eval_duration": 1}
             else:
                 res = {"results": [{"text": recvtxt, "finish_reason": currfinishreason, "logprobs":logprobsdict, "prompt_tokens": prompttokens, "completion_tokens": comptokens}]}
 
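Together, the two hunks above implement the Ollama /api/generate contract for non-streamed completions: incoming `context` ids are detokenized back into text and prepended as history, and the reply carries a fresh `context` tokenized from the old body prompt plus the new completion (replacing the earlier [1,2,3] placeholder). A client-side sketch of one call; the host and port are assumptions (KoboldCpp's usual default, 5001), while the field names follow the code above:

    import requests

    r = requests.post("http://localhost:5001/api/generate", json={
        "model": "koboldcpp",           # model name is illustrative
        "prompt": "Why is the sky blue?",
        "system": "You are a helpful assistant.",
        "context": [],                  # token ids returned by the previous turn, if any
        "stop": [],
        "stream": False,                # this sketch exercises only the non-streamed path
    }).json()
    print(r["response"])                # the generated text
    next_context = r["context"]         # echo this back on the next request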
@@ -2038,10 +2075,7 @@ Enter Prompt:
                     genparams = json.loads(body)
                     countprompt = genparams.get('prompt', "")
                     tcaddspecial = genparams.get('special', True)
-                    rawcountdata = handle.token_count(countprompt.encode("UTF-8"),tcaddspecial)
-                    countlimit = rawcountdata.count if (rawcountdata.count>=0 and rawcountdata.count<50000) else 0
-                    # the above protects the server in case the count limit got corrupted
-                    countdata = [rawcountdata.ids[i] for i in range(countlimit)]
+                    countdata = tokenize_ids(countprompt,tcaddspecial)
                     response_body = (json.dumps({"value": len(countdata),"ids": countdata}).encode())
 
                 except Exception as e:
@@ -2055,16 +2089,7 @@ Enter Prompt:
                 try:
                     genparams = json.loads(body)
                     tokids = genparams.get('ids', [])
-                    tokidslen = len(tokids)
-                    detokstr = ""
-                    if tokidslen > 0 and tokidslen < 65536:
-                        inputs = token_count_outputs()
-                        inputs.count = tokidslen
-                        inputs.ids = (ctypes.c_int * tokidslen)()
-                        for i, cid in enumerate(tokids):
-                            inputs.ids[i] = cid
-                        detok = handle.detokenize(inputs)
-                        detokstr = ctypes.string_at(detok).decode("UTF-8","ignore")
+                    detokstr = detokenize_ids(tokids)
                     response_body = (json.dumps({"result": detokstr,"success":True}).encode())
                 except Exception as e:
                     utfprint("Detokenize Error: " + str(e))
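With the shared helpers in place, both of KoboldCpp's own utility endpoints now delegate to them instead of duplicating the ctypes plumbing, which keeps the Ollama path and the extra API consistent. Their JSON shapes, per the code above (the endpoint paths and port are the usual KoboldCpp ones, stated here as assumptions rather than taken from this patch):

    import requests

    base = "http://localhost:5001"
    tc = requests.post(base + "/api/extra/tokencount",
                       json={"prompt": "Hello world", "special": True}).json()
    # -> {"value": <token count>, "ids": [<token ids>]}

    dt = requests.post(base + "/api/extra/detokenize",
                       json={"ids": tc["ids"]}).json()
    # -> {"result": "Hello world", "success": true}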