mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
wip adding embeddings support
This commit is contained in:
parent
b1641ee4a2
commit
3992fb79cc
7 changed files with 378 additions and 12 deletions
45
koboldcpp.py
45
koboldcpp.py
|
@ -318,6 +318,25 @@ class tts_generation_outputs(ctypes.Structure):
|
|||
_fields_ = [("status", ctypes.c_int),
|
||||
("data", ctypes.c_char_p)]
|
||||
|
||||
class embeddings_load_model_inputs(ctypes.Structure):
    """Argument struct for the native ``embeddings_load_model`` call.

    ctypes lays the struct out from ``_fields_`` in declaration order, so the
    order and types below must stay in sync with the matching struct on the
    native side (not visible in this file).
    """
    _fields_ = [("threads", ctypes.c_int),
                ("model_filename", ctypes.c_char_p),   # encoded bytes, not str
                ("executable_path", ctypes.c_char_p),
                ("clblast_info", ctypes.c_int),
                ("cublas_info", ctypes.c_int),
                ("vulkan_info", ctypes.c_char_p),
                ("gpulayers", ctypes.c_int),
                ("flash_attention", ctypes.c_bool),
                ("quiet", ctypes.c_bool),
                ("debugmode", ctypes.c_int)]
|
||||
|
||||
class embeddings_generation_inputs(ctypes.Structure):
    """Argument struct for the native ``embeddings_generate`` call.

    Holds a single C string; callers must assign encoded ``bytes`` (a Python
    ``str`` is rejected by ``c_char_p``).
    """
    _fields_ = [("prompt", ctypes.c_char_p)]
|
||||
|
||||
class embeddings_generation_outputs(ctypes.Structure):
    """Result struct returned by the native ``embeddings_generate`` call.

    ``status`` is an integer result code and ``data`` a C string payload;
    ``data`` reads back as ``None`` when the native pointer is NULL.
    """
    _fields_ = [("status", ctypes.c_int),
                ("data", ctypes.c_char_p)]
|
||||
|
||||
def getdirpath():
    """Return the directory containing this source file, with symlinks resolved."""
    return os.path.dirname(os.path.realpath(__file__))
|
||||
def getabspath():
|
||||
|
@ -491,6 +510,10 @@ def init_library():
|
|||
handle.tts_load_model.restype = ctypes.c_bool
|
||||
handle.tts_generate.argtypes = [tts_generation_inputs]
|
||||
handle.tts_generate.restype = tts_generation_outputs
|
||||
handle.embeddings_load_model.argtypes = [embeddings_load_model_inputs]
|
||||
handle.embeddings_load_model.restype = ctypes.c_bool
|
||||
handle.embeddings_generate.argtypes = [embeddings_generation_inputs]
|
||||
handle.embeddings_generate.restype = embeddings_generation_outputs
|
||||
handle.last_logprobs.restype = last_logprobs_outputs
|
||||
handle.detokenize.argtypes = [token_count_outputs]
|
||||
handle.detokenize.restype = ctypes.c_char_p
|
||||
|
@ -1564,6 +1587,28 @@ def tts_generate(genparams):
|
|||
outstr = ret.data.decode("UTF-8","ignore")
|
||||
return outstr
|
||||
|
||||
def embeddings_load_model(model_filename):
    """Load an embeddings model through the native library.

    Builds an ``embeddings_load_model_inputs`` struct from ``model_filename``
    and the global ``args``, passes it through ``set_backend_props`` and hands
    it to ``handle.embeddings_load_model``.

    Args:
        model_filename: Path of the model file as a ``str``; it is UTF-8
            encoded before being stored in the struct.

    Returns:
        Whatever ``handle.embeddings_load_model`` returns (its ``restype`` is
        declared as ``ctypes.c_bool`` in ``init_library``).
    """
    global args
    inputs = embeddings_load_model_inputs()
    inputs.model_filename = model_filename.encode("UTF-8")
    # NOTE(review): GPU offload is keyed off the TTS flag (args.ttsgpu);
    # presumably a placeholder until an embeddings-specific option exists - confirm.
    inputs.gpulayers = (999 if args.ttsgpu else 0)
    inputs.flash_attention = args.flashattention
    inputs.threads = args.threads
    # set_backend_props is defined elsewhere; presumably it fills in the
    # remaining backend-related fields of the struct.
    inputs = set_backend_props(inputs)
    ret = handle.embeddings_load_model(inputs)
    return ret
|
||||
|
||||
def embeddings_generate(genparams):
    """Run the native embeddings generator on the request's input text.

    Args:
        genparams: Mapping of request parameters; only the ``"input"`` key is
            read, defaulting to an empty string when absent. The value is
            expected to be a ``str`` since it is UTF-8 encoded directly.

    Returns:
        The native result's ``data`` decoded as UTF-8 (undecodable bytes are
        dropped) when the returned ``status`` equals 1, otherwise ``""``.
    """
    global args
    prompt = genparams.get("input", "")
    inputs = embeddings_generation_inputs()
    inputs.prompt = prompt.encode("UTF-8")
    ret = handle.embeddings_generate(inputs)
    outstr = ""
    if ret.status==1:
        outstr = ret.data.decode("UTF-8","ignore")
    return outstr
|
||||
|
||||
def tokenize_ids(countprompt,tcaddspecial):
|
||||
rawcountdata = handle.token_count(countprompt.encode("UTF-8"),tcaddspecial)
|
||||
countlimit = rawcountdata.count if (rawcountdata.count>=0 and rawcountdata.count<50000) else 0
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue