mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 09:04:36 +00:00
kobold integration of min_p sampler (+1 squashed commits)
Squashed commits: [8ad2e349] kobold integration for min_p sampler
This commit is contained in:
parent
bcb397953f
commit
ae2cd56de8
5 changed files with 38 additions and 7 deletions
11
koboldcpp.py
11
koboldcpp.py
|
@@ -55,6 +55,7 @@ class generation_inputs(ctypes.Structure):
|
|||
("top_k", ctypes.c_int),
|
||||
("top_a", ctypes.c_float),
|
||||
("top_p", ctypes.c_float),
|
||||
("min_p", ctypes.c_float),
|
||||
("typical_p", ctypes.c_float),
|
||||
("tfs", ctypes.c_float),
|
||||
("rep_pen", ctypes.c_float),
|
||||
|
@@ -286,7 +287,7 @@ def load_model(model_filename):
|
|||
ret = handle.load_model(inputs)
|
||||
return ret
|
||||
|
||||
def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey=''):
|
||||
def generate(prompt,max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey=''):
|
||||
global maxctx, args, currentusergenkey, totalgens
|
||||
inputs = generation_inputs()
|
||||
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
|
||||
|
@@ -303,6 +304,7 @@ def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_
|
|||
inputs.top_k = top_k
|
||||
inputs.top_a = top_a
|
||||
inputs.top_p = top_p
|
||||
inputs.min_p = min_p
|
||||
inputs.typical_p = typical_p
|
||||
inputs.tfs = tfs
|
||||
inputs.rep_pen = rep_pen
|
||||
|
@@ -463,10 +465,11 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
prompt=genparams.get('prompt', ""),
|
||||
max_context_length=genparams.get('max_context_length', maxctx),
|
||||
max_length=genparams.get('max_length', 80),
|
||||
temperature=genparams.get('temperature', 0.8),
|
||||
top_k=genparams.get('top_k', 120),
|
||||
temperature=genparams.get('temperature', 0.7),
|
||||
top_k=genparams.get('top_k', 100),
|
||||
top_a=genparams.get('top_a', 0.0),
|
||||
top_p=genparams.get('top_p', 0.85),
|
||||
top_p=genparams.get('top_p', 0.92),
|
||||
min_p=genparams.get('min_p', 0.0),
|
||||
typical_p=genparams.get('typical', 1.0),
|
||||
tfs=genparams.get('tfs', 1.0),
|
||||
rep_pen=genparams.get('rep_pen', 1.1),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue