refactored a lot of code: removed --bantokens and moved banned tokens to the API

This commit is contained in:
Concedo 2024-04-27 17:57:13 +08:00
parent 4ec8a9c57b
commit c230b78906
6 changed files with 214 additions and 76 deletions

koboldcpp.py

@ -56,7 +56,6 @@ class load_model_inputs(ctypes.Structure):
("gpulayers", ctypes.c_int),
("rope_freq_scale", ctypes.c_float),
("rope_freq_base", ctypes.c_float),
("banned_tokens", ctypes.c_char_p * ban_token_max),
("tensor_split", ctypes.c_float * tensor_split_max)]
class generation_inputs(ctypes.Structure):
@@ -91,7 +90,8 @@ class generation_inputs(ctypes.Structure):
                 ("dynatemp_range", ctypes.c_float),
                 ("dynatemp_exponent", ctypes.c_float),
                 ("smoothing_factor", ctypes.c_float),
-                ("logit_biases", logit_bias * logit_bias_max)]
+                ("logit_biases", logit_bias * logit_bias_max),
+                ("banned_tokens", ctypes.c_char_p * ban_token_max)]
 
 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
@@ -391,16 +391,10 @@ def load_model(model_filename):
     inputs.executable_path = (getdirpath()+"/").encode("UTF-8")
     inputs.debugmode = args.debugmode
-    banned_tokens = args.bantokens
-    for n in range(ban_token_max):
-        if not banned_tokens or n >= len(banned_tokens):
-            inputs.banned_tokens[n] = "".encode("UTF-8")
-        else:
-            inputs.banned_tokens[n] = banned_tokens[n].encode("UTF-8")
     ret = handle.load_model(inputs)
     return ret
 
-def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False):
+def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[]):
     global maxctx, args, currentusergenkey, totalgens, pendingabortkey
     inputs = generation_inputs()
     inputs.prompt = prompt.encode("UTF-8")
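With the loop gone from load_model(), the ban list becomes a per-request keyword argument of generate() instead of a launch-time setting. A toy stand-in (not the real function, which needs a loaded model) just to illustrate the new calling convention:

    # Hypothetical stub mirroring only the signature change from this diff.
    def generate_stub(prompt, banned_tokens=[], **kwargs):
        return f"generating from {prompt!r}, banning {banned_tokens}"

    print(generate_stub("Once upon a time", banned_tokens=["badword"]))
    print(generate_stub("Another prompt"))  # no ban list given: defaults to []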
@@ -487,6 +481,12 @@ def generate(prompt, memory="", images=[], max_length=32, max_context_length=512
                 inputs.logit_biases[n] = logit_bias(-1, 0.0)
                 print(f"Skipped unparsable logit bias:{ex}")
 
+    for n in range(ban_token_max):
+        if not banned_tokens or n >= len(banned_tokens):
+            inputs.banned_tokens[n] = "".encode("UTF-8")
+        else:
+            inputs.banned_tokens[n] = banned_tokens[n].encode("UTF-8")
+
     currentusergenkey = genkey
     totalgens += 1
     #early exit if aborted
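The added loop pads every unused slot with an empty string, so the native side always sees ban_token_max valid pointers rather than NULLs. The same padding logic in isolation (array size assumed for the sketch):

    import ctypes

    ban_token_max = 16  # assumed; see the note on the struct sketch above

    def fill_banned(arr, banned_tokens):
        # Mirrors the loop added above: real entries first, "" padding after.
        for n in range(ban_token_max):
            if not banned_tokens or n >= len(banned_tokens):
                arr[n] = "".encode("UTF-8")
            else:
                arr[n] = banned_tokens[n].encode("UTF-8")

    arr = (ctypes.c_char_p * ban_token_max)()
    fill_banned(arr, ["foo", "bar"])
    print([arr[n] for n in range(4)])  # [b'foo', b'bar', b'', b'']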
@@ -672,6 +672,10 @@ def transform_genparams(genparams, api_format):
         genparams["top_k"] = int(genparams.get('top_k', 120))
         genparams["max_length"] = genparams.get('max', 100)
 
+    elif api_format==2:
+        if "ignore_eos" in genparams and not ("use_default_badwordsids" in genparams):
+            genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False)
+
     elif api_format==3 or api_format==4:
         genparams["max_length"] = genparams.get('max_tokens', 100)
         presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
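The new api_format==2 branch maps an incoming ignore_eos flag onto use_default_badwordsids, but only when the caller did not set the latter explicitly. The mapping in isolation:

    def map_ignore_eos(genparams):
        # Same condition as the branch added above.
        if "ignore_eos" in genparams and not ("use_default_badwordsids" in genparams):
            genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False)
        return genparams

    print(map_ignore_eos({"ignore_eos": True}))
    # -> {'ignore_eos': True, 'use_default_badwordsids': True}
    print(map_ignore_eos({"ignore_eos": True, "use_default_badwordsids": False}))
    # -> unchanged; an explicit use_default_badwordsids wins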
@@ -813,6 +817,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             smoothing_factor=genparams.get('smoothing_factor', 0.0),
             logit_biases=genparams.get('logit_bias', {}),
             render_special=genparams.get('render_special', False),
+            banned_tokens=genparams.get('banned_tokens', []),
             )
 
         genout = {"text":"","status":-1,"stopreason":-1}
@@ -3281,7 +3286,6 @@ if __name__ == '__main__':
     parser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. Experimental.", metavar=('[lora_filename]', '[lora_base]'), nargs='+')
     parser.add_argument("--smartcontext", help="Reserving a portion of context to try processing less frequently.", action='store_true')
     parser.add_argument("--noshift", help="If set, do not attempt to Trim and Shift the GGUF context.", action='store_true')
-    parser.add_argument("--bantokens", help="You can manually specify a list of token SUBSTRINGS that the AI cannot use. This bans ALL instances of that substring.", metavar=('[token_substrings]'), nargs='+')
     parser.add_argument("--forceversion", help="If the model file format detection fails (e.g. rogue modified model) you can set this to override the detected format (enter desired version, e.g. 401 for GPTNeoX-Type2).",metavar=('[version]'), type=int, default=0)
     parser.add_argument("--nommap", help="If set, do not use mmap to load newer models", action='store_true')
     parser.add_argument("--usemlock", help="For Apple Systems. Force system to keep model in RAM rather than swapping or compressing", action='store_true')