refactor - do not use a copy buffer to store generation outputs, instead return a cpp allocated ptr

This commit is contained in:
Concedo 2024-02-29 14:02:20 +08:00
parent f75e479db0
commit 524ba12abd
6 changed files with 32 additions and 26 deletions

View file

@ -87,7 +87,7 @@ class generation_inputs(ctypes.Structure):
class generation_outputs(ctypes.Structure):
_fields_ = [("status", ctypes.c_int),
("text", ctypes.c_char * 24576)]
("text", ctypes.c_char_p)]
class token_count_outputs(ctypes.Structure):
_fields_ = [("count", ctypes.c_int),
@ -242,7 +242,7 @@ def init_library():
handle.load_model.argtypes = [load_model_inputs]
handle.load_model.restype = ctypes.c_bool
handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever
handle.generate.argtypes = [generation_inputs]
handle.generate.restype = generation_outputs
handle.new_token.restype = ctypes.c_char_p
handle.new_token.argtypes = [ctypes.c_int]
@ -350,7 +350,6 @@ def load_model(model_filename):
def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}):
global maxctx, args, currentusergenkey, totalgens, pendingabortkey
inputs = generation_inputs()
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
inputs.prompt = prompt.encode("UTF-8")
inputs.memory = memory.encode("UTF-8")
if max_length >= (max_context_length-1):
@ -438,7 +437,7 @@ def generate(prompt, memory="", max_length=32, max_context_length=512, temperatu
pendingabortkey = ""
return ""
else:
ret = handle.generate(inputs,outputs)
ret = handle.generate(inputs)
outstr = ""
if ret.status==1:
outstr = ret.text.decode("UTF-8","ignore")