integrated mirostat as a launch parameter, works on all models

This commit is contained in:
Concedo 2023-05-06 00:47:17 +08:00
parent 851f55325a
commit 8a964e76c8
3 changed files with 102 additions and 26 deletions

View file

@ -157,7 +157,7 @@ def load_model(model_filename):
ret = handle.load_model(inputs)
return ret
def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=100,top_p=0.85, typical_p=1.0, tfs=1.0 ,rep_pen=1.1,rep_pen_range=128,mirostat=0,mirostat_lr=0.1,mirostat_ent=5.0,seed=-1,stop_sequence=[]):
def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=100,top_p=0.85, typical_p=1.0, tfs=1.0 ,rep_pen=1.1,rep_pen_range=128,seed=-1,stop_sequence=[]):
inputs = generation_inputs()
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
inputs.prompt = prompt.encode("UTF-8")
@ -170,9 +170,12 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=
inputs.tfs = tfs
inputs.rep_pen = rep_pen
inputs.rep_pen_range = rep_pen_range
inputs.mirostat = mirostat
inputs.mirostat_eta = mirostat_lr
inputs.mirostat_tau = mirostat_ent
if args.usemirostat and args.usemirostat[0]>0:
inputs.mirostat = int(args.usemirostat[0])
inputs.mirostat_tau = float(args.usemirostat[1])
inputs.mirostat_eta = float(args.usemirostat[2])
else:
inputs.mirostat = inputs.mirostat_tau = inputs.mirostat_eta = 0
inputs.seed = seed
for n in range(0,stop_token_max):
if not stop_sequence or n >= len(stop_sequence):
@ -317,9 +320,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
tfs=genparams.get('tfs', 1.0),
rep_pen=genparams.get('rep_pen', 1.1),
rep_pen_range=genparams.get('rep_pen_range', 128),
mirostat=genparams.get('mirostat', 0),
mirostat_lr=genparams.get('mirostat_lr', 0.1),
mirostat_ent=genparams.get('mirostat_ent', 5.0),
seed=-1,
stop_sequence=genparams.get('stop_sequence', [])
)
@ -336,9 +336,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
tfs=genparams.get('tfs', 1.0),
rep_pen=genparams.get('rep_pen', 1.1),
rep_pen_range=genparams.get('rep_pen_range', 128),
mirostat=genparams.get('mirostat', 0),
mirostat_lr=genparams.get('mirostat_lr', 0.1),
mirostat_ent=genparams.get('mirostat_ent', 5.0),
seed=-1,
stop_sequence=genparams.get('stop_sequence', [])
)
@ -620,14 +617,15 @@ if __name__ == '__main__':
physical_core_limit = int(os.cpu_count()/2)
default_threads = (physical_core_limit if physical_core_limit<=3 else max(3,physical_core_limit-1))
parser.add_argument("--threads", help="Use a custom number of threads if specified. Otherwise, uses an amount based on CPU cores", type=int, default=default_threads)
parser.add_argument("--blasthreads", help="Use a different number of threads during BLAS if specified. Otherwise, has the same value as --threads", type=int, default=0)
parser.add_argument("--blasthreads", help="Use a different number of threads during BLAS if specified. Otherwise, has the same value as --threads",metavar=('[threads]'), type=int, default=0)
parser.add_argument("--psutil_set_threads", help="Experimental flag. If set, uses psutils to determine thread count based on physical cores.", action='store_true')
parser.add_argument("--highpriority", help="Experimental flag. If set, increases the process CPU priority, potentially speeding up generation. Use caution.", action='store_true')
parser.add_argument("--blasbatchsize", help="Sets the batch size used in BLAS processing (default 512)", type=int,choices=[32,64,128,256,512,1024], default=512)
parser.add_argument("--stream", help="Uses pseudo streaming when generating tokens. Only for the Kobold Lite UI.", action='store_true')
parser.add_argument("--smartcontext", help="Reserving a portion of context to try processing less frequently.", action='store_true')
parser.add_argument("--unbantokens", help="Normally, KoboldAI prevents certain tokens such as EOS and Square Brackets. This flag unbans them.", action='store_true')
parser.add_argument("--forceversion", help="If the model file format detection fails (e.g. rogue modified model) you can set this to override the detected format (enter desired version, e.g. 401 for GPTNeoX-Type2).", type=int, default=0)
parser.add_argument("--usemirostat", help="Experimental! Replaces your samplers with mirostat. Takes 3 params = [type(0/1/2), tau(5.0), eta(0.1)].",metavar=('[type]', '[tau]', '[eta]'), type=float, nargs=3)
parser.add_argument("--forceversion", help="If the model file format detection fails (e.g. rogue modified model) you can set this to override the detected format (enter desired version, e.g. 401 for GPTNeoX-Type2).",metavar=('[version]'), type=int, default=0)
parser.add_argument("--nommap", help="If set, do not use mmap to load newer models", action='store_true')
parser.add_argument("--usemlock", help="For Apple Systems. Force system to keep model in RAM rather than swapping or compressing", action='store_true')
parser.add_argument("--noavx2", help="Do not use AVX2 instructions, a slower compatibility mode for older devices. Does not work with --clblast.", action='store_true')