Mirror of https://github.com/LostRuins/koboldcpp.git
generate passes whole object now

commit 139ab3d198 (parent 7fab499b79)
2 changed files with 63 additions and 49 deletions
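In short: generate() no longer takes dozens of positional and keyword arguments; it now receives the whole parameter object as a single dict and reads each field with a default. A minimal sketch of the new calling convention, with illustrative values (not taken from the repo):

    genp = {"prompt": "Hello", "max_length": 120, "temperature": 0.7, "top_k": 100}
    genout = generate(genparams=genp)   # unspecified fields fall back to defaults
    print(genout['text'])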
class.py (4 changes)

@@ -306,11 +306,11 @@ class model_backend(InferenceModel):
         # Store context in memory to use it for comparison with generated content
         utils.koboldai_vars.lastctx = decoded_prompt

-        self.input_queue.put({'command': 'generate', 'data': [(decoded_prompt,), {'max_length': max_new, 'max_context_length': utils.koboldai_vars.max_length,
+        self.input_queue.put({'command': 'generate', 'data': {'prompt':decoded_prompt, 'max_length': max_new, 'max_context_length': utils.koboldai_vars.max_length,
         'temperature': gen_settings.temp, 'top_k': int(gen_settings.top_k), 'top_a': gen_settings.top_a, 'top_p': gen_settings.top_p,
         'typical_p': gen_settings.typical, 'tfs': gen_settings.tfs, 'rep_pen': gen_settings.rep_pen, 'rep_pen_range': gen_settings.rep_pen_range,
         "sampler_order": gen_settings.sampler_order, "use_default_badwordsids": utils.koboldai_vars.use_default_badwordsids}
-        ]})
+        })

         #genresult = koboldcpp.generate(decoded_prompt,"",max_new,utils.koboldai_vars.max_length,
         #gen_settings.temp,int(gen_settings.top_k),gen_settings.top_a,gen_settings.top_p,
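The payload change on the inter-process queue, side by side; a hedged sketch using the names from the diff (decoded_prompt and max_new are the caller's locals):

    # Old shape: an (args, kwargs) pair the worker splatted into generate(*args, **kwargs).
    old_msg = {'command': 'generate', 'data': [(decoded_prompt,), {'max_length': max_new}]}

    # New shape: one flat dict the worker forwards as generate(genparams=data['data']).
    new_msg = {'command': 'generate', 'data': {'prompt': decoded_prompt, 'max_length': max_new}}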
koboldcpp.py (104 changes)

@@ -865,8 +865,51 @@ def load_model(model_filename):
     ret = handle.load_model(inputs)
     return ret

-def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, rep_pen_slope=1.0, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, dry_multiplier=0.0, dry_base=1.75, dry_allowed_length=2, dry_penalty_last_n=0, dry_sequence_breakers=[], sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], ban_eos_token=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[], bypass_eos_token=False):
+def generate(genparams, is_quiet=False, stream_flag=False):
+    global maxctx, args, currentusergenkey, totalgens, pendingabortkey
+
+    prompt = genparams.get('prompt', "")
+    memory = genparams.get('memory', "")
+    images = genparams.get('images', [])
+    max_context_length = genparams.get('max_context_length', maxctx)
+    max_length = genparams.get('max_length', 180)
+    temperature = genparams.get('temperature', 0.7)
+    top_k = genparams.get('top_k', 100)
+    top_a = genparams.get('top_a', 0.0)
+    top_p = genparams.get('top_p', 0.92)
+    min_p = genparams.get('min_p', 0.0)
+    typical_p = genparams.get('typical', 1.0)
+    tfs = genparams.get('tfs', 1.0)
+    rep_pen = genparams.get('rep_pen', 1.0)
+    rep_pen_range = genparams.get('rep_pen_range', 256)
+    rep_pen_slope = genparams.get('rep_pen_slope', 1.0)
+    presence_penalty = genparams.get('presence_penalty', 0.0)
+    mirostat = genparams.get('mirostat', 0)
+    mirostat_tau = genparams.get('mirostat_tau', 5.0)
+    mirostat_eta = genparams.get('mirostat_eta', 0.1)
+    dry_multiplier = genparams.get('dry_multiplier', 0.0)
+    dry_base = genparams.get('dry_base', 1.75)
+    dry_allowed_length = genparams.get('dry_allowed_length', 2)
+    dry_penalty_last_n = genparams.get('dry_penalty_last_n', 0)
+    dry_sequence_breakers = genparams.get('dry_sequence_breakers', [])
+    sampler_order = genparams.get('sampler_order', [6, 0, 1, 3, 4, 2, 5])
+    seed = tryparseint(genparams.get('sampler_seed', -1))
+    stop_sequence = genparams.get('stop_sequence', [])
+    ban_eos_token = genparams.get('ban_eos_token', False)
+    stream_sse = stream_flag
+    grammar = genparams.get('grammar', '')
+    grammar_retain_state = genparams.get('grammar_retain_state', False)
+    genkey = genparams.get('genkey', '')
+    trimstop = genparams.get('trim_stop', False)
+    quiet = is_quiet
+    dynatemp_range = genparams.get('dynatemp_range', 0.0)
+    dynatemp_exponent = genparams.get('dynatemp_exponent', 1.0)
+    smoothing_factor = genparams.get('smoothing_factor', 0.0)
+    logit_biases = genparams.get('logit_bias', {})
+    render_special = genparams.get('render_special', False)
+    banned_tokens = genparams.get('banned_tokens', [])
+    bypass_eos_token = genparams.get('bypass_eos', False)

     inputs = generation_inputs()
     inputs.prompt = prompt.encode("UTF-8")
     inputs.memory = memory.encode("UTF-8")
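Every field now comes from genparams.get(key, default), so all parameters are optional. Note that the effective defaults are the ones in the new .get() calls (max_length 180, rep_pen_range 256), not the old signature's 32 and 128. A self-contained sketch of the fallback behavior; the tryparseint stand-in below is an assumption, not the repo's exact helper:

    def tryparseint(value, fallback=-1):
        # Assumed behavior of koboldcpp's helper: coerce to int or fall back.
        try:
            return int(value)
        except (TypeError, ValueError):
            return fallback

    genparams = {"prompt": "Hello"}                         # caller omits everything else
    top_k = genparams.get('top_k', 100)                     # -> 100, the default in the diff
    seed = tryparseint(genparams.get('sampler_seed', -1))   # -> -1 when unset
    print(top_k, seed)                                      # 100 -1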
@@ -1362,47 +1405,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         last_non_horde_req_time = time.time()

         return generate(
-            prompt=genparams.get('prompt', ""),
-            memory=genparams.get('memory', ""),
-            images=genparams.get('images', []),
-            max_context_length=genparams.get('max_context_length', maxctx),
-            max_length=genparams.get('max_length', 180),
-            temperature=genparams.get('temperature', 0.7),
-            top_k=genparams.get('top_k', 100),
-            top_a=genparams.get('top_a', 0.0),
-            top_p=genparams.get('top_p', 0.92),
-            min_p=genparams.get('min_p', 0.0),
-            typical_p=genparams.get('typical', 1.0),
-            tfs=genparams.get('tfs', 1.0),
-            rep_pen=genparams.get('rep_pen', 1.0),
-            rep_pen_range=genparams.get('rep_pen_range', 256),
-            rep_pen_slope=genparams.get('rep_pen_slope', 1.0),
-            presence_penalty=genparams.get('presence_penalty', 0.0),
-            mirostat=genparams.get('mirostat', 0),
-            mirostat_tau=genparams.get('mirostat_tau', 5.0),
-            mirostat_eta=genparams.get('mirostat_eta', 0.1),
-            dry_multiplier=genparams.get('dry_multiplier', 0.0),
-            dry_base=genparams.get('dry_base', 1.75),
-            dry_allowed_length=genparams.get('dry_allowed_length', 2),
-            dry_penalty_last_n=genparams.get('dry_penalty_last_n', 0),
-            dry_sequence_breakers=genparams.get('dry_sequence_breakers', []),
-            sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]),
-            seed=tryparseint(genparams.get('sampler_seed', -1)),
-            stop_sequence=genparams.get('stop_sequence', []),
-            ban_eos_token=genparams.get('ban_eos_token', False),
-            stream_sse=stream_flag,
-            grammar=genparams.get('grammar', ''),
-            grammar_retain_state = genparams.get('grammar_retain_state', False),
-            genkey=genparams.get('genkey', ''),
-            trimstop=genparams.get('trim_stop', False),
-            quiet=is_quiet,
-            dynatemp_range=genparams.get('dynatemp_range', 0.0),
-            dynatemp_exponent=genparams.get('dynatemp_exponent', 1.0),
-            smoothing_factor=genparams.get('smoothing_factor', 0.0),
-            logit_biases=genparams.get('logit_bias', {}),
-            render_special=genparams.get('render_special', False),
-            banned_tokens=genparams.get('banned_tokens', []),
-            bypass_eos_token=genparams.get('bypass_eos', False),
+            genparams=genparams,
+            is_quiet=is_quiet,
+            stream_flag=stream_flag
         )

         genout = {"text": "", "status": -1, "stopreason": -1}
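With the handler forwarding the parsed request body wholesale, a client's JSON maps one-to-one onto genparams. A hedged sketch of a client call, assuming KoboldCpp's usual /api/v1/generate endpoint on the default port 5001 and its standard results/text response shape:

    import json, urllib.request

    payload = {"prompt": "Once upon a time", "max_length": 120, "temperature": 0.7}
    req = urllib.request.Request(
        "http://localhost:5001/api/v1/generate",   # default local endpoint; adjust as needed
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        print(json.load(resp)["results"][0]["text"])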
@@ -4237,7 +4242,16 @@ def main(launch_args,start_server=True):
         benchprompt = " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1"
         for i in range(0,14): #generate massive prompt
             benchprompt += benchprompt
-        genout = generate(benchprompt,memory="",images=[],max_length=benchlen,max_context_length=benchmaxctx,temperature=benchtemp,top_k=benchtopk,rep_pen=benchreppen,ban_eos_token=benchbaneos)
+        genp = {
+            "prompt":benchprompt,
+            "max_length":benchlen,
+            "max_context_length":benchmaxctx,
+            "temperature":benchtemp,
+            "top_k":benchtopk,
+            "rep_pen":benchreppen,
+            "ban_eos_token":benchbaneos
+        }
+        genout = generate(genparams=genp)
         result = genout['text']
         if args.prompt and not args.benchmark:
             restore_stdout()
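The benchmark's doubling loop grows the 16-token seed prompt by a factor of 2^14, so it overflows any context window and the run always exercises a full max_context_length. The arithmetic, assuming each " 1" is roughly one token:

    seed_tokens = 16                      # " 1" repeated 16 times
    doublings = 14
    print(seed_tokens * 2 ** doublings)   # 262144, far beyond any benchmaxctx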
@@ -4300,8 +4314,8 @@ def run_in_queue(launch_args, input_queue, output_queue):
         while not input_queue.empty():
             data = input_queue.get()
             if data['command'] == 'generate':
-                (args, kwargs) = data['data']
-                genout = generate(*args, **kwargs)
+                pl = data['data']
+                genout = generate(genparams=pl)
                 result = genout['text']
                 output_queue.put({'command': 'generated text', 'data': result})
         time.sleep(0.2)
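For orientation, a self-contained sketch of the queue protocol the two files now share; a stub generate and a plain queue.Queue stand in for the real worker process:

    import queue

    def generate(genparams):
        # Stub standing in for koboldcpp.generate; echoes the prompt back.
        return {"text": "echo: " + genparams.get("prompt", ""), "status": 1, "stopreason": 0}

    input_queue, output_queue = queue.Queue(), queue.Queue()

    # Producer side (class.py): the whole parameter object rides in 'data'.
    input_queue.put({'command': 'generate', 'data': {'prompt': 'Hi', 'max_length': 8}})

    # Consumer side (run_in_queue): no more (args, kwargs) unpacking.
    data = input_queue.get()
    if data['command'] == 'generate':
        genout = generate(genparams=data['data'])
        output_queue.put({'command': 'generated text', 'data': genout['text']})

    print(output_queue.get())   # {'command': 'generated text', 'data': 'echo: Hi'}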