wip adding embeddings support

This commit is contained in:
Concedo 2025-03-24 18:01:23 +08:00
parent b1641ee4a2
commit 3992fb79cc
7 changed files with 378 additions and 12 deletions

View file

@ -318,6 +318,25 @@ class tts_generation_outputs(ctypes.Structure):
_fields_ = [("status", ctypes.c_int),
("data", ctypes.c_char_p)]
class embeddings_load_model_inputs(ctypes.Structure):
_fields_ = [("threads", ctypes.c_int),
("model_filename", ctypes.c_char_p),
("executable_path", ctypes.c_char_p),
("clblast_info", ctypes.c_int),
("cublas_info", ctypes.c_int),
("vulkan_info", ctypes.c_char_p),
("gpulayers", ctypes.c_int),
("flash_attention", ctypes.c_bool),
("quiet", ctypes.c_bool),
("debugmode", ctypes.c_int)]
class embeddings_generation_inputs(ctypes.Structure):
_fields_ = [("prompt", ctypes.c_char_p)]
class embeddings_generation_outputs(ctypes.Structure):
_fields_ = [("status", ctypes.c_int),
("data", ctypes.c_char_p)]
def getdirpath():
return os.path.dirname(os.path.realpath(__file__))
def getabspath():
@ -491,6 +510,10 @@ def init_library():
handle.tts_load_model.restype = ctypes.c_bool
handle.tts_generate.argtypes = [tts_generation_inputs]
handle.tts_generate.restype = tts_generation_outputs
handle.embeddings_load_model.argtypes = [embeddings_load_model_inputs]
handle.embeddings_load_model.restype = ctypes.c_bool
handle.embeddings_generate.argtypes = [embeddings_generation_inputs]
handle.embeddings_generate.restype = embeddings_generation_outputs
handle.last_logprobs.restype = last_logprobs_outputs
handle.detokenize.argtypes = [token_count_outputs]
handle.detokenize.restype = ctypes.c_char_p
@ -1564,6 +1587,28 @@ def tts_generate(genparams):
outstr = ret.data.decode("UTF-8","ignore")
return outstr
def embeddings_load_model(model_filename):
global args
inputs = embeddings_load_model_inputs()
inputs.model_filename = model_filename.encode("UTF-8")
inputs.gpulayers = (999 if args.ttsgpu else 0)
inputs.flash_attention = args.flashattention
inputs.threads = args.threads
inputs = set_backend_props(inputs)
ret = handle.embeddings_load_model(inputs)
return ret
def embeddings_generate(genparams):
global args
prompt = genparams.get("input", "")
inputs = embeddings_generation_inputs()
inputs.prompt = prompt.encode("UTF-8")
ret = handle.embeddings_generate(inputs)
outstr = ""
if ret.status==1:
outstr = ret.data.decode("UTF-8","ignore")
return outstr
def tokenize_ids(countprompt,tcaddspecial):
rawcountdata = handle.token_count(countprompt.encode("UTF-8"),tcaddspecial)
countlimit = rawcountdata.count if (rawcountdata.count>=0 and rawcountdata.count<50000) else 0