mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 09:04:36 +00:00
handle memory separately for kcpp
This commit is contained in:
parent
f277ed0e8c
commit
fb3bcac368
4 changed files with 105 additions and 22 deletions
|
@ -49,6 +49,7 @@ class load_model_inputs(ctypes.Structure):
|
|||
class generation_inputs(ctypes.Structure):
|
||||
_fields_ = [("seed", ctypes.c_int),
|
||||
("prompt", ctypes.c_char_p),
|
||||
("memory", ctypes.c_char_p),
|
||||
("max_context_length", ctypes.c_int),
|
||||
("max_length", ctypes.c_int),
|
||||
("temperature", ctypes.c_float),
|
||||
|
@ -73,7 +74,7 @@ class generation_inputs(ctypes.Structure):
|
|||
|
||||
class generation_outputs(ctypes.Structure):
|
||||
_fields_ = [("status", ctypes.c_int),
|
||||
("text", ctypes.c_char * 24576)]
|
||||
("text", ctypes.c_char * 32768)]
|
||||
|
||||
handle = None
|
||||
|
||||
|
@ -297,11 +298,12 @@ def load_model(model_filename):
|
|||
ret = handle.load_model(inputs)
|
||||
return ret
|
||||
|
||||
def generate(prompt,max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey=''):
|
||||
def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey=''):
|
||||
global maxctx, args, currentusergenkey, totalgens
|
||||
inputs = generation_inputs()
|
||||
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
|
||||
inputs.prompt = prompt.encode("UTF-8")
|
||||
inputs.memory = memory.encode("UTF-8")
|
||||
if max_length >= max_context_length:
|
||||
max_length = max_context_length-1
|
||||
inputs.max_context_length = max_context_length # this will resize the context buffer if changed
|
||||
|
@ -379,7 +381,7 @@ maxhordelen = 256
|
|||
modelbusy = threading.Lock()
|
||||
requestsinqueue = 0
|
||||
defaultport = 5001
|
||||
KcppVersion = "1.48.1"
|
||||
KcppVersion = "1.49"
|
||||
showdebug = True
|
||||
showsamplerwarning = True
|
||||
showmaxctxwarning = True
|
||||
|
@ -474,6 +476,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
|
||||
return generate(
|
||||
prompt=genparams.get('prompt', ""),
|
||||
memory=genparams.get('memory', ""),
|
||||
max_context_length=genparams.get('max_context_length', maxctx),
|
||||
max_length=genparams.get('max_length', 80),
|
||||
temperature=genparams.get('temperature', 0.7),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue