wip logprobs data

This commit is contained in:
Concedo 2024-10-30 00:59:34 +08:00
parent bd05efd648
commit 90f5cd0f67
6 changed files with 70 additions and 19 deletions

View file

@ -101,6 +101,14 @@ class token_count_outputs(ctypes.Structure):
_fields_ = [("count", ctypes.c_int),
("ids", ctypes.POINTER(ctypes.c_int))]
# returns top 5 logprobs per token
class last_logprobs_outputs(ctypes.Structure):
_fields_ = [("count", ctypes.c_int),
("selected_token", ctypes.POINTER(ctypes.c_char_p)),
("selected_logprob", ctypes.POINTER(ctypes.c_float)),
("tokens", ctypes.POINTER(5 * ctypes.c_char_p)),
("logprobs", ctypes.POINTER(5 * ctypes.c_float))]
class load_model_inputs(ctypes.Structure):
_fields_ = [("threads", ctypes.c_int),
("blasthreads", ctypes.c_int),
@ -445,6 +453,7 @@ def init_library():
handle.whisper_load_model.restype = ctypes.c_bool
handle.whisper_generate.argtypes = [whisper_generation_inputs]
handle.whisper_generate.restype = whisper_generation_outputs
handle.last_logprobs.restype = last_logprobs_outputs
def set_backend_props(inputs):
clblastids = 0