Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-10 17:14:36 +00:00)
refactored a lot of code, remove bantokens, move it to api

commit c230b78906
parent 4ec8a9c57b

6 changed files with 214 additions and 76 deletions
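In effect, banned token substrings stop being a launcher flag (--bantokens) and become a per-request generation parameter. A minimal client-side sketch of the new usage, assuming the usual KoboldCpp Kobold-style endpoint and response shape; only the "banned_tokens" payload key is confirmed by this diff, the URL and result fields are assumptions:

import json
import urllib.request

# hypothetical request against a local koboldcpp instance on its default port
payload = {
    "prompt": "Once upon a time",
    "max_length": 64,
    "banned_tokens": ["foo", "bar"],  # substrings the model should never emit
}
req = urllib.request.Request(
    "http://localhost:5001/api/v1/generate",  # assumed endpoint path
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    result = json.loads(resp.read())
    print(result["results"][0]["text"])  # response shape assumed from the Kobold API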
koboldcpp.py (24 changed lines)
@@ -56,7 +56,6 @@ class load_model_inputs(ctypes.Structure):
                 ("gpulayers", ctypes.c_int),
                 ("rope_freq_scale", ctypes.c_float),
                 ("rope_freq_base", ctypes.c_float),
-                ("banned_tokens", ctypes.c_char_p * ban_token_max),
                 ("tensor_split", ctypes.c_float * tensor_split_max)]

 class generation_inputs(ctypes.Structure):
@@ -91,7 +90,8 @@ class generation_inputs(ctypes.Structure):
                 ("dynatemp_range", ctypes.c_float),
                 ("dynatemp_exponent", ctypes.c_float),
                 ("smoothing_factor", ctypes.c_float),
-                ("logit_biases", logit_bias * logit_bias_max)]
+                ("logit_biases", logit_bias * logit_bias_max),
+                ("banned_tokens", ctypes.c_char_p * ban_token_max)]

 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
@@ -391,16 +391,10 @@ def load_model(model_filename):

     inputs.executable_path = (getdirpath()+"/").encode("UTF-8")
     inputs.debugmode = args.debugmode
-    banned_tokens = args.bantokens
-    for n in range(ban_token_max):
-        if not banned_tokens or n >= len(banned_tokens):
-            inputs.banned_tokens[n] = "".encode("UTF-8")
-        else:
-            inputs.banned_tokens[n] = banned_tokens[n].encode("UTF-8")
     ret = handle.load_model(inputs)
     return ret

-def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False):
+def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[]):
     global maxctx, args, currentusergenkey, totalgens, pendingabortkey
     inputs = generation_inputs()
     inputs.prompt = prompt.encode("UTF-8")
@@ -487,6 +481,12 @@ def generate(prompt, memory="", images=[], max_length=32, max_context_length=512
                 inputs.logit_biases[n] = logit_bias(-1, 0.0)
                 print(f"Skipped unparsable logit bias:{ex}")

+    for n in range(ban_token_max):
+        if not banned_tokens or n >= len(banned_tokens):
+            inputs.banned_tokens[n] = "".encode("UTF-8")
+        else:
+            inputs.banned_tokens[n] = banned_tokens[n].encode("UTF-8")
+
     currentusergenkey = genkey
     totalgens += 1
     #early exit if aborted
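For context, the loop added above fills a fixed-size ctypes array of C strings, padding unused slots with empty strings. A standalone sketch of that marshalling pattern; the ban_token_max value and the structure name here are illustrative assumptions, not taken from the diff:

import ctypes

ban_token_max = 16  # assumed value; the real constant is defined elsewhere in koboldcpp.py

class demo_inputs(ctypes.Structure):
    _fields_ = [("banned_tokens", ctypes.c_char_p * ban_token_max)]

inputs = demo_inputs()
banned_tokens = ["foo", "bar"]
for n in range(ban_token_max):
    # an empty string marks an unused slot, mirroring the loop in generate()
    text = banned_tokens[n] if n < len(banned_tokens) else ""
    inputs.banned_tokens[n] = text.encode("UTF-8")

print([t for t in inputs.banned_tokens if t])  # [b'foo', b'bar']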
@@ -672,6 +672,10 @@ def transform_genparams(genparams, api_format):
         genparams["top_k"] = int(genparams.get('top_k', 120))
         genparams["max_length"] = genparams.get('max', 100)

+    elif api_format==2:
+        if "ignore_eos" in genparams and not ("use_default_badwordsids" in genparams):
+            genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False)
+
     elif api_format==3 or api_format==4:
         genparams["max_length"] = genparams.get('max_tokens', 100)
         presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
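The api_format==2 branch above simply aliases one client flag onto an internal one. Restated as a standalone helper for clarity (the helper name is hypothetical):

def _map_ignore_eos(genparams):
    # If the client sent ignore_eos but not use_default_badwordsids, reuse its value.
    if "ignore_eos" in genparams and "use_default_badwordsids" not in genparams:
        genparams["use_default_badwordsids"] = genparams.get("ignore_eos", False)
    return genparams

print(_map_ignore_eos({"ignore_eos": True}))
# {'ignore_eos': True, 'use_default_badwordsids': True}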
@@ -813,6 +817,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                     smoothing_factor=genparams.get('smoothing_factor', 0.0),
                     logit_biases=genparams.get('logit_bias', {}),
                     render_special=genparams.get('render_special', False),
+                    banned_tokens=genparams.get('banned_tokens', []),
                 )

                 genout = {"text":"","status":-1,"stopreason":-1}
@@ -3281,7 +3286,6 @@ if __name__ == '__main__':
     parser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. Experimental.", metavar=('[lora_filename]', '[lora_base]'), nargs='+')
     parser.add_argument("--smartcontext", help="Reserving a portion of context to try processing less frequently.", action='store_true')
     parser.add_argument("--noshift", help="If set, do not attempt to Trim and Shift the GGUF context.", action='store_true')
-    parser.add_argument("--bantokens", help="You can manually specify a list of token SUBSTRINGS that the AI cannot use. This bans ALL instances of that substring.", metavar=('[token_substrings]'), nargs='+')
     parser.add_argument("--forceversion", help="If the model file format detection fails (e.g. rogue modified model) you can set this to override the detected format (enter desired version, e.g. 401 for GPTNeoX-Type2).",metavar=('[version]'), type=int, default=0)
     parser.add_argument("--nommap", help="If set, do not use mmap to load newer models", action='store_true')
     parser.add_argument("--usemlock", help="For Apple Systems. Force system to keep model in RAM rather than swapping or compressing", action='store_true')