mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
token count includes ids
This commit is contained in:
parent
0ca814e544
commit
6570a2005b
5 changed files with 26 additions and 9 deletions
13
koboldcpp.py
13
koboldcpp.py
|
@ -77,6 +77,10 @@ class generation_outputs(ctypes.Structure):
|
|||
_fields_ = [("status", ctypes.c_int),
|
||||
("text", ctypes.c_char * 32768)]
|
||||
|
||||
class token_count_outputs(ctypes.Structure):
|
||||
_fields_ = [("count", ctypes.c_int),
|
||||
("ids", ctypes.POINTER(ctypes.c_int))]
|
||||
|
||||
handle = None
|
||||
|
||||
def getdirpath():
|
||||
|
@ -218,7 +222,7 @@ def init_library():
|
|||
handle.get_total_gens.restype = ctypes.c_int
|
||||
handle.get_last_stop_reason.restype = ctypes.c_int
|
||||
handle.abort_generate.restype = ctypes.c_bool
|
||||
handle.token_count.restype = ctypes.c_int
|
||||
handle.token_count.restype = token_count_outputs
|
||||
handle.get_pending_output.restype = ctypes.c_char_p
|
||||
|
||||
def load_model(model_filename):
|
||||
|
@ -729,8 +733,11 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
try:
|
||||
genparams = json.loads(body)
|
||||
countprompt = genparams.get('prompt', "")
|
||||
count = handle.token_count(countprompt.encode("UTF-8"))
|
||||
response_body = (json.dumps({"value": count}).encode())
|
||||
rawcountdata = handle.token_count(countprompt.encode("UTF-8"))
|
||||
countlimit = rawcountdata.count if (rawcountdata.count>=0 and rawcountdata.count<50000) else 0
|
||||
# the above protects the server in case the count limit got corrupted
|
||||
countdata = [rawcountdata.ids[i] for i in range(countlimit)]
|
||||
response_body = (json.dumps({"value": len(countdata),"ids": countdata}).encode())
|
||||
|
||||
except Exception as e:
|
||||
utfprint("Count Tokens - Body Error: " + str(e))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue