add presence penalty

This commit is contained in:
Concedo 2023-12-19 23:18:56 +08:00
parent da2db0302c
commit 3f863eed72
3 changed files with 18 additions and 9 deletions

View file

@ -60,6 +60,7 @@ class generation_inputs(ctypes.Structure):
("tfs", ctypes.c_float),
("rep_pen", ctypes.c_float),
("rep_pen_range", ctypes.c_int),
("presence_penalty", ctypes.c_float),
("mirostat", ctypes.c_int),
("mirostat_tau", ctypes.c_float),
("mirostat_eta", ctypes.c_float),
@ -302,7 +303,7 @@ def load_model(model_filename):
ret = handle.load_model(inputs)
return ret
def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False):
def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False):
global maxctx, args, currentusergenkey, totalgens
inputs = generation_inputs()
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
@ -327,6 +328,7 @@ def generate(prompt, memory="", max_length=32, max_context_length=512, temperatu
inputs.tfs = tfs
inputs.rep_pen = rep_pen
inputs.rep_pen_range = rep_pen_range
inputs.presence_penalty = presence_penalty
inputs.stream_sse = stream_sse
inputs.quiet = quiet
inputs.grammar = grammar.encode("UTF-8")
@ -440,10 +442,11 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
genparams["max_length"] = genparams.get('max', 100)
elif api_format==3 or api_format==4:
frqp = genparams.get('frequency_penalty', 0.1)
scaled_rep_pen = genparams.get('presence_penalty', frqp) + 1
genparams["max_length"] = genparams.get('max_tokens', 100)
genparams["rep_pen"] = scaled_rep_pen
presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
genparams["presence_penalty"] = presence_penalty
if presence_penalty > 0:
genparams["rep_pen"] = 1.0
# openai allows either a string or a list as a stop sequence
if isinstance(genparams.get('stop',[]), list):
genparams["stop_sequence"] = genparams.get('stop', [])
@ -500,6 +503,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
tfs=genparams.get('tfs', 1.0),
rep_pen=genparams.get('rep_pen', 1.1),
rep_pen_range=genparams.get('rep_pen_range', 256),
presence_penalty=genparams.get('presence_penalty', 0.0),
mirostat=genparams.get('mirostat', 0),
mirostat_tau=genparams.get('mirostat_tau', 5.0),
mirostat_eta=genparams.get('mirostat_eta', 0.1),