kobold integration of min_p sampler (+1 squashed commits)

Squashed commits:

[8ad2e349] kobold integration for min_p sampler
This commit is contained in:
Concedo 2023-11-01 19:07:26 +08:00
parent bcb397953f
commit ae2cd56de8
5 changed files with 38 additions and 7 deletions

View file

@ -55,6 +55,7 @@ class generation_inputs(ctypes.Structure):
("top_k", ctypes.c_int),
("top_a", ctypes.c_float),
("top_p", ctypes.c_float),
("min_p", ctypes.c_float),
("typical_p", ctypes.c_float),
("tfs", ctypes.c_float),
("rep_pen", ctypes.c_float),
@ -286,7 +287,7 @@ def load_model(model_filename):
ret = handle.load_model(inputs)
return ret
def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey=''):
def generate(prompt,max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey=''):
global maxctx, args, currentusergenkey, totalgens
inputs = generation_inputs()
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
@ -303,6 +304,7 @@ def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_
inputs.top_k = top_k
inputs.top_a = top_a
inputs.top_p = top_p
inputs.min_p = min_p
inputs.typical_p = typical_p
inputs.tfs = tfs
inputs.rep_pen = rep_pen
@ -463,10 +465,11 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
prompt=genparams.get('prompt', ""),
max_context_length=genparams.get('max_context_length', maxctx),
max_length=genparams.get('max_length', 80),
temperature=genparams.get('temperature', 0.8),
top_k=genparams.get('top_k', 120),
temperature=genparams.get('temperature', 0.7),
top_k=genparams.get('top_k', 100),
top_a=genparams.get('top_a', 0.0),
top_p=genparams.get('top_p', 0.85),
top_p=genparams.get('top_p', 0.92),
min_p=genparams.get('min_p', 0.0),
typical_p=genparams.get('typical', 1.0),
tfs=genparams.get('tfs', 1.0),
rep_pen=genparams.get('rep_pen', 1.1),