Mirror of https://github.com/LostRuins/koboldcpp.git
generate passes whole object now
commit 139ab3d198 (parent 7fab499b79)
2 changed files with 63 additions and 49 deletions
class.py (4 changed lines)
@@ -306,11 +306,11 @@ class model_backend(InferenceModel):
         # Store context in memory to use it for comparison with generated content
         utils.koboldai_vars.lastctx = decoded_prompt
 
-        self.input_queue.put({'command': 'generate', 'data': [(decoded_prompt,), {'max_length': max_new, 'max_context_length': utils.koboldai_vars.max_length,
+        self.input_queue.put({'command': 'generate', 'data': {'prompt':decoded_prompt, 'max_length': max_new, 'max_context_length': utils.koboldai_vars.max_length,
         'temperature': gen_settings.temp, 'top_k': int(gen_settings.top_k), 'top_a': gen_settings.top_a, 'top_p': gen_settings.top_p,
         'typical_p': gen_settings.typical, 'tfs': gen_settings.tfs, 'rep_pen': gen_settings.rep_pen, 'rep_pen_range': gen_settings.rep_pen_range,
         "sampler_order": gen_settings.sampler_order, "use_default_badwordsids": utils.koboldai_vars.use_default_badwordsids}
-        ]})
+        })
 
         #genresult = koboldcpp.generate(decoded_prompt,"",max_new,utils.koboldai_vars.max_length,
         #gen_settings.temp,int(gen_settings.top_k),gen_settings.top_a,gen_settings.top_p,
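The producer side in class.py now sends one flat dict as the 'data' payload instead of a positional-args tuple paired with a kwargs dict. A minimal sketch of the message shape before and after this commit; the prompt text and sampler values here are illustrative, not taken from the commit:

# Old payload: a positional-args tuple plus a kwargs dict.
old_msg = {'command': 'generate',
           'data': [("Once upon a time",), {'max_length': 80, 'temperature': 0.7}]}

# New payload: everything, including the prompt, lives in a single dict.
new_msg = {'command': 'generate',
           'data': {'prompt': "Once upon a time", 'max_length': 80, 'temperature': 0.7}}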
koboldcpp.py (104 changed lines)
@@ -865,8 +865,51 @@ def load_model(model_filename):
     ret = handle.load_model(inputs)
     return ret
 
-def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, rep_pen_slope=1.0, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, dry_multiplier=0.0, dry_base=1.75, dry_allowed_length=2, dry_penalty_last_n=0, dry_sequence_breakers=[], sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], ban_eos_token=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[], bypass_eos_token=False):
+def generate(genparams, is_quiet=False, stream_flag=False):
     global maxctx, args, currentusergenkey, totalgens, pendingabortkey
+
+    prompt = genparams.get('prompt', "")
+    memory = genparams.get('memory', "")
+    images = genparams.get('images', [])
+    max_context_length = genparams.get('max_context_length', maxctx)
+    max_length = genparams.get('max_length', 180)
+    temperature = genparams.get('temperature', 0.7)
+    top_k = genparams.get('top_k', 100)
+    top_a = genparams.get('top_a', 0.0)
+    top_p = genparams.get('top_p', 0.92)
+    min_p = genparams.get('min_p', 0.0)
+    typical_p = genparams.get('typical', 1.0)
+    tfs = genparams.get('tfs', 1.0)
+    rep_pen = genparams.get('rep_pen', 1.0)
+    rep_pen_range = genparams.get('rep_pen_range', 256)
+    rep_pen_slope = genparams.get('rep_pen_slope', 1.0)
+    presence_penalty = genparams.get('presence_penalty', 0.0)
+    mirostat = genparams.get('mirostat', 0)
+    mirostat_tau = genparams.get('mirostat_tau', 5.0)
+    mirostat_eta = genparams.get('mirostat_eta', 0.1)
+    dry_multiplier = genparams.get('dry_multiplier', 0.0)
+    dry_base = genparams.get('dry_base', 1.75)
+    dry_allowed_length = genparams.get('dry_allowed_length', 2)
+    dry_penalty_last_n = genparams.get('dry_penalty_last_n', 0)
+    dry_sequence_breakers = genparams.get('dry_sequence_breakers', [])
+    sampler_order = genparams.get('sampler_order', [6, 0, 1, 3, 4, 2, 5])
+    seed = tryparseint(genparams.get('sampler_seed', -1))
+    stop_sequence = genparams.get('stop_sequence', [])
+    ban_eos_token = genparams.get('ban_eos_token', False)
+    stream_sse = stream_flag
+    grammar = genparams.get('grammar', '')
+    grammar_retain_state = genparams.get('grammar_retain_state', False)
+    genkey = genparams.get('genkey', '')
+    trimstop = genparams.get('trim_stop', False)
+    quiet = is_quiet
+    dynatemp_range = genparams.get('dynatemp_range', 0.0)
+    dynatemp_exponent = genparams.get('dynatemp_exponent', 1.0)
+    smoothing_factor = genparams.get('smoothing_factor', 0.0)
+    logit_biases = genparams.get('logit_bias', {})
+    render_special = genparams.get('render_special', False)
+    banned_tokens = genparams.get('banned_tokens', [])
+    bypass_eos_token = genparams.get('bypass_eos', False)
+
     inputs = generation_inputs()
     inputs.prompt = prompt.encode("UTF-8")
     inputs.memory = memory.encode("UTF-8")
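Because every setting is now read with genparams.get() and a fallback, callers only have to pass the keys they care about. A minimal sketch of a call under the new interface; the prompt and parameter choices are illustrative:

genout = generate(genparams={"prompt": "Once upon a time",  # omitted keys fall back
                             "max_length": 64,              # to the .get() defaults
                             "temperature": 0.8},           # shown in the hunk above
                  is_quiet=True)
print(genout['text'])

Note the renamed lookups kept for API compatibility: the 'typical' key maps to typical_p, 'logit_bias' to logit_biases, and 'bypass_eos' to bypass_eos_token, so existing callers keep their field names.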
@@ -1362,47 +1405,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         last_non_horde_req_time = time.time()
 
         return generate(
-            prompt=genparams.get('prompt', ""),
-            memory=genparams.get('memory', ""),
-            images=genparams.get('images', []),
-            max_context_length=genparams.get('max_context_length', maxctx),
-            max_length=genparams.get('max_length', 180),
-            temperature=genparams.get('temperature', 0.7),
-            top_k=genparams.get('top_k', 100),
-            top_a=genparams.get('top_a', 0.0),
-            top_p=genparams.get('top_p', 0.92),
-            min_p=genparams.get('min_p', 0.0),
-            typical_p=genparams.get('typical', 1.0),
-            tfs=genparams.get('tfs', 1.0),
-            rep_pen=genparams.get('rep_pen', 1.0),
-            rep_pen_range=genparams.get('rep_pen_range', 256),
-            rep_pen_slope=genparams.get('rep_pen_slope', 1.0),
-            presence_penalty=genparams.get('presence_penalty', 0.0),
-            mirostat=genparams.get('mirostat', 0),
-            mirostat_tau=genparams.get('mirostat_tau', 5.0),
-            mirostat_eta=genparams.get('mirostat_eta', 0.1),
-            dry_multiplier=genparams.get('dry_multiplier', 0.0),
-            dry_base=genparams.get('dry_base', 1.75),
-            dry_allowed_length=genparams.get('dry_allowed_length', 2),
-            dry_penalty_last_n=genparams.get('dry_penalty_last_n', 0),
-            dry_sequence_breakers=genparams.get('dry_sequence_breakers', []),
-            sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]),
-            seed=tryparseint(genparams.get('sampler_seed', -1)),
-            stop_sequence=genparams.get('stop_sequence', []),
-            ban_eos_token=genparams.get('ban_eos_token', False),
-            stream_sse=stream_flag,
-            grammar=genparams.get('grammar', ''),
-            grammar_retain_state = genparams.get('grammar_retain_state', False),
-            genkey=genparams.get('genkey', ''),
-            trimstop=genparams.get('trim_stop', False),
-            quiet=is_quiet,
-            dynatemp_range=genparams.get('dynatemp_range', 0.0),
-            dynatemp_exponent=genparams.get('dynatemp_exponent', 1.0),
-            smoothing_factor=genparams.get('smoothing_factor', 0.0),
-            logit_biases=genparams.get('logit_bias', {}),
-            render_special=genparams.get('render_special', False),
-            banned_tokens=genparams.get('banned_tokens', []),
-            bypass_eos_token=genparams.get('bypass_eos', False),
+            genparams=genparams,
+            is_quiet=is_quiet,
+            stream_flag=stream_flag
         )
 
         genout = {"text": "", "status": -1, "stopreason": -1}
@@ -4237,7 +4242,16 @@ def main(launch_args,start_server=True):
         benchprompt = " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1"
         for i in range(0,14): #generate massive prompt
             benchprompt += benchprompt
-        genout = generate(benchprompt,memory="",images=[],max_length=benchlen,max_context_length=benchmaxctx,temperature=benchtemp,top_k=benchtopk,rep_pen=benchreppen,ban_eos_token=benchbaneos)
+        genp = {
+            "prompt":benchprompt,
+            "max_length":benchlen,
+            "max_context_length":benchmaxctx,
+            "temperature":benchtemp,
+            "top_k":benchtopk,
+            "rep_pen":benchreppen,
+            "ban_eos_token":benchbaneos
+        }
+        genout = generate(genparams=genp)
         result = genout['text']
         if args.prompt and not args.benchmark:
             restore_stdout()
@@ -4300,8 +4314,8 @@ def run_in_queue(launch_args, input_queue, output_queue):
         while not input_queue.empty():
             data = input_queue.get()
             if data['command'] == 'generate':
-                (args, kwargs) = data['data']
-                genout = generate(*args, **kwargs)
+                pl = data['data']
+                genout = generate(genparams=pl)
                 result = genout['text']
                 output_queue.put({'command': 'generated text', 'data': result})
         time.sleep(0.2)
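End to end, the class.py producer and this worker loop now agree on a single dict payload. A self-contained sketch of the round trip using queue.Queue, with a stub standing in for the real generate():

import queue

def generate_stub(genparams, is_quiet=False, stream_flag=False):
    # Stand-in for koboldcpp.generate: read what we need, default the rest.
    return {'text': genparams.get('prompt', '')[:genparams.get('max_length', 180)]}

input_queue, output_queue = queue.Queue(), queue.Queue()
input_queue.put({'command': 'generate',
                 'data': {'prompt': 'hello world', 'max_length': 5}})

while not input_queue.empty():
    data = input_queue.get()
    if data['command'] == 'generate':
        genout = generate_stub(genparams=data['data'])
        output_queue.put({'command': 'generated text', 'data': genout['text']})

print(output_queue.get())  # {'command': 'generated text', 'data': 'hello'}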